mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Cherry-Pick][Speculate Decoding][Engine] Cherry-pick #7166, #7349, #7402, #7445 to release/online/20260415 (#7447)
* [Speculate Decoding] Fix step_idx semantics in limit_thinking and set_stop_value kernels (#7166) - speculate_limit_thinking_content_length: update current_base_step to step_idx+1 (step_idx now records history count before current round); remove incorrect step_idx decrement on accept_num truncation; mark step_idx param as const. - speculate_set_stop_value_multi_seqs: fix can_stop gate to use step_idx_now+accept_num>=min_token_limit; fix skip check and pre_ids_idx formula (remove stale -accept_num offset); use <= condition so accept_idx maps directly to the accepted token that ends the stop sequence; fix accept_tokens index (remove -1). - Update unit tests for speculate_set_stop_value_multi_seqs kernel. * [Speculate Decoding] Fix bug of reasoning_phase_token_constraint kernel (#7349) Co-authored-by: guanshihui <guanshihui@baidu.com> * [Speculate Decoding] Fix reasoning_phase_token_constraint call args in SpeculativeSampler (#7402) * [Interrupt reasoning] Add interrupt_requests control command support --------- Co-authored-by: guanshihui <guanshihui@baidu.com>
This commit is contained in:
@@ -38,7 +38,7 @@
|
||||
// - In MTP mode, accept_num must be 1 in verify kernel
|
||||
//
|
||||
// Transition condition (x = 1 -> x = 2):
|
||||
// - step_idx >= 3
|
||||
// - step_idx >= 4
|
||||
// - pre_ids[-4:] exactly match:
|
||||
// "\n</think>\n\n"
|
||||
//
|
||||
@@ -83,10 +83,10 @@ __global__ void update_reasoning_status_kernel(
|
||||
int64_t cur_step = step_idx[tid];
|
||||
const int64_t* pre_ids_now =
|
||||
token_ids_all + tid * max_seq_len + prompt_lens[tid];
|
||||
int64_t t0 = (cur_step >= 0) ? pre_ids_now[cur_step] : -1;
|
||||
int64_t t1 = (cur_step >= 1) ? pre_ids_now[cur_step - 1] : -1;
|
||||
int64_t t2 = (cur_step >= 2) ? pre_ids_now[cur_step - 2] : -1;
|
||||
int64_t t3 = (cur_step >= 3) ? pre_ids_now[cur_step - 3] : -1;
|
||||
int64_t t0 = (cur_step >= 1) ? pre_ids_now[cur_step - 1] : -1;
|
||||
int64_t t1 = (cur_step >= 2) ? pre_ids_now[cur_step - 2] : -1;
|
||||
int64_t t2 = (cur_step >= 3) ? pre_ids_now[cur_step - 3] : -1;
|
||||
int64_t t3 = (cur_step >= 4) ? pre_ids_now[cur_step - 4] : -1;
|
||||
|
||||
int32_t new_status = status;
|
||||
|
||||
@@ -104,7 +104,7 @@ __global__ void update_reasoning_status_kernel(
|
||||
// x = 1 -> x = 2 (include think_end_id)
|
||||
// or x = 1 -> x = 3 (not include think_end_id)
|
||||
// Here must be serial judge
|
||||
if (new_status == 1 && cur_step >= 3) {
|
||||
if (new_status == 1 && cur_step >= 4) {
|
||||
if (t3 == line_break_id && t2 == think_end_id && t1 == line_break_id &&
|
||||
t0 == line_break_id) {
|
||||
new_status = 2;
|
||||
|
||||
@@ -34,7 +34,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
||||
int64_t* next_tokens, // [bs, tokens_per_step]
|
||||
const int* max_think_lens, // [bs]
|
||||
int* max_reply_lens, // [bs]
|
||||
int64_t* step_idx, // [bs]
|
||||
const int64_t* step_idx, // [bs]
|
||||
const int64_t* eos_token_ids, // [eos_len]
|
||||
int* limit_status, // [bs]
|
||||
int* accept_num, // [bs]
|
||||
@@ -68,7 +68,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
||||
int new_accept_num = original_accept_num;
|
||||
|
||||
// 本 step 的 token offset 对应的绝对 step
|
||||
const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
|
||||
const int64_t current_base_step = step_idx[bid] + 1;
|
||||
|
||||
for (int token_offset = 0; token_offset < original_accept_num;
|
||||
token_offset++) {
|
||||
@@ -100,8 +100,8 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
||||
// inject_token_ids[0])
|
||||
if (status == 0 &&
|
||||
(current_step - 1) ==
|
||||
max_think_len) { // current_step - 1 是因为 speculate_verify 里
|
||||
// step_idx + 1 了
|
||||
max_think_len) { // current_step - 1 : 已输出 current_step-1
|
||||
// 个thinking token
|
||||
status = (inject_len > 0) ? 1 : done_status;
|
||||
}
|
||||
} else if (max_think_len == 0) {
|
||||
@@ -181,13 +181,6 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// 更新 step_idx / accept_num(被截断的 token 需要回退
|
||||
// step_idx)
|
||||
const int discarded_tokens = original_accept_num - new_accept_num;
|
||||
if (discarded_tokens > 0) {
|
||||
step_idx[bid] -= discarded_tokens;
|
||||
}
|
||||
|
||||
accept_num[bid] = new_accept_num;
|
||||
limit_status[bid] = status;
|
||||
max_reply_lens[bid] = max_reply_len;
|
||||
@@ -221,7 +214,7 @@ void SpeculateLimitThinkingContentLength(
|
||||
const_cast<int64_t*>(next_tokens.data<int64_t>()),
|
||||
max_think_lens.data<int>(),
|
||||
const_cast<int*>(max_reply_lens.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
step_idx.data<int64_t>(),
|
||||
eos_token_ids.data<int64_t>(),
|
||||
const_cast<int*>(limit_status.data<int>()),
|
||||
const_cast<int*>(accept_num.data<int>()),
|
||||
|
||||
@@ -51,60 +51,65 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
||||
const int64_t step_idx_now = step_idx[bid];
|
||||
const int64_t min_token_limit = min_tokens[bid];
|
||||
|
||||
const bool can_stop = (step_idx_now >= min_token_limit);
|
||||
const bool can_stop = (step_idx_now + accept_num >= min_token_limit);
|
||||
if (!can_stop) return;
|
||||
if (!stop_flags[bid]) {
|
||||
int accept_idx = 0;
|
||||
/*
|
||||
accept_idx 表示 stop_seq 最后 token 在 accept_tokens 中的位置 (0-based)
|
||||
accept_idx = -1 表示 stop_seq 最后 token 在 pre_ids 的末尾
|
||||
(pre_ids[step_idx_now - 1]),即上一轮延迟匹配的最后一个 token。
|
||||
为防止在 stop_seqs 后面追加 eos 越界,跳过 accept_tokens[accept_num-1]
|
||||
(当前轮最后一个 token),该 token 延迟到下一轮匹配。
|
||||
循环范围:accept_num > 0 时为 [-1, accept_num-2];
|
||||
accept_num = 0 时为 [-1](仅检查 pre_ids 末尾)。
|
||||
*/
|
||||
int accept_idx = -1;
|
||||
bool is_end = false;
|
||||
// 遍历起始位置
|
||||
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
|
||||
|
||||
// 统一检测:accept_idx = -1 对应上一轮延迟的最后 token 在 pre_ids 末尾
|
||||
// 完整匹配 stop_seqs 的情况;accept_idx >= 0 对应当前轮 accept_tokens
|
||||
// 中的匹配。两者共享同一套从后向前匹配逻辑。
|
||||
int loop_end = (accept_num > 0) ? accept_num - 2 : -1;
|
||||
for (; accept_idx <= loop_end && !is_end; accept_idx++) {
|
||||
if (step_idx_now + accept_idx + 1 < stop_seq_len) {
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf("num %d < stop_seq_len %d\n",
|
||||
step_idx_now - accept_num + accept_idx + 1,
|
||||
step_idx_now + accept_idx + 1,
|
||||
stop_seq_len);
|
||||
#endif
|
||||
continue;
|
||||
}
|
||||
// 遍历一个 stop_seqs
|
||||
// 从后向前匹配 stop_seq 的每个 token
|
||||
for (int i = stop_seq_len - 1; i >= 0; --i) {
|
||||
int64_t cur_token_idx = -1;
|
||||
|
||||
// 通过当前值判断 token 是在 pre_ids 还是 accept_token 里
|
||||
if (stop_seq_len - 1 - i < accept_idx) {
|
||||
int offset = stop_seq_len - 1 - i;
|
||||
int accept_tokens_idx = accept_idx - offset;
|
||||
|
||||
if (accept_tokens_idx >= 0) {
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf(
|
||||
"AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
|
||||
"accept_token_idx: "
|
||||
"%d\n",
|
||||
"accept_token_idx: %d\n",
|
||||
bid,
|
||||
tid,
|
||||
accept_idx,
|
||||
accept_idx - (stop_seq_len - 1 - i) - 1);
|
||||
accept_tokens_idx);
|
||||
#endif
|
||||
cur_token_idx =
|
||||
accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1];
|
||||
cur_token_idx = accept_tokens_now[accept_tokens_idx];
|
||||
} else {
|
||||
int pre_ids_idx = step_idx_now + accept_tokens_idx;
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf(
|
||||
"PreIds bid:%d. tid:%d, step_idx_now:%ld. "
|
||||
"accept_idx:%d. "
|
||||
"pre_id_idx: %ld\n",
|
||||
"accept_idx:%d. pre_id_idx: %d\n",
|
||||
bid,
|
||||
tid,
|
||||
step_idx_now,
|
||||
accept_idx,
|
||||
step_idx_now - accept_num + accept_idx -
|
||||
(stop_seq_len - 1 - i));
|
||||
pre_ids_idx);
|
||||
#endif
|
||||
int pre_ids_idx =
|
||||
step_idx_now + accept_idx - (stop_seq_len - 1 - i);
|
||||
// EC3
|
||||
// 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23,
|
||||
// 导致异常结束
|
||||
if (pre_ids_idx <= 0) {
|
||||
break;
|
||||
}
|
||||
if (pre_ids_idx < 0) break;
|
||||
cur_token_idx = pre_ids_now[pre_ids_idx];
|
||||
}
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
@@ -126,12 +131,11 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
||||
}
|
||||
if (is_end) {
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf("bid:%d end with accept_idx %d", bid, accept_idx);
|
||||
printf("bid:%d end with accept_idx %d\n", bid, accept_idx);
|
||||
#endif
|
||||
|
||||
accept_nums[bid] = accept_idx;
|
||||
accept_tokens_now[accept_idx - 1] = end_ids[0];
|
||||
// stop_flags[bid] = true;
|
||||
// accept_idx 在循环退出时已递增,指向 stop_seq 最后 token 的下一个位置
|
||||
accept_nums[bid] = accept_idx + 1;
|
||||
accept_tokens_now[accept_idx] = end_ids[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
|
||||
int64_t *token_ids_all_now =
|
||||
&token_ids_all[batch_id * max_model_len + prompt_len];
|
||||
int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
|
||||
int64_t base = cur_step_idx - output_len + 1;
|
||||
int64_t base = cur_step_idx - output_len;
|
||||
for (int i = 0; i < output_len; i++) {
|
||||
token_ids_all_now[base + i] = output_ids[i];
|
||||
}
|
||||
|
||||
@@ -996,7 +996,8 @@ class SpeculativeSampler(nn.Layer):
|
||||
if self.enf_gen_phase_tag:
|
||||
reasoning_phase_token_constraint(
|
||||
logits,
|
||||
sampling_metadata.pre_token_ids,
|
||||
token_ids_all,
|
||||
prompt_lens,
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
|
||||
@@ -105,6 +105,15 @@ class InternalAdapter:
|
||||
with self.response_lock:
|
||||
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
|
||||
|
||||
elif task["cmd"] == "interrupt_requests":
|
||||
self.engine.resource_manager.add_abort_req_ids(task["req_ids"])
|
||||
result = {
|
||||
"task_id": task_id_str,
|
||||
"result": {"success": True, "interrupted_req_ids": task["req_ids"]},
|
||||
}
|
||||
with self.response_lock:
|
||||
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}")
|
||||
|
||||
|
||||
@@ -37,22 +37,24 @@ class TestReasoningPhaseTokenConstraint(unittest.TestCase):
|
||||
# token_ids_all
|
||||
#
|
||||
# batch 0:
|
||||
# ... \n <think_end> \n \n → status 1 -> 2
|
||||
# step_idx=4, pre_ids_now[0..3]
|
||||
# pattern: \n <think_end> \n \n → status 1 -> 2
|
||||
# t3=pre_ids_now[0]=\n, t2=pre_ids_now[1]=<think_end>, t1=pre_ids_now[2]=\n, t0=pre_ids_now[3]=\n
|
||||
#
|
||||
# batch 1:
|
||||
# contains think_end, but pattern not complete → status 0 -> 1
|
||||
# contains think_end at pre_ids_now[2], but pattern not complete → status 0 -> 1
|
||||
# ------------------------
|
||||
token_ids_all = np.zeros((self.bs, self.max_seq_len), dtype=np.int64)
|
||||
self.prompt_lens = paddle.zeros([self.bs, 1], dtype="int64")
|
||||
|
||||
# batch 0
|
||||
token_ids_all[0, 1] = self.line_break_id
|
||||
token_ids_all[0, 2] = self.think_end_id
|
||||
# batch 0: pattern \n <think_end> \n \n at pre_ids_now[0..3]
|
||||
token_ids_all[0, 0] = self.line_break_id
|
||||
token_ids_all[0, 1] = self.think_end_id
|
||||
token_ids_all[0, 2] = self.line_break_id
|
||||
token_ids_all[0, 3] = self.line_break_id
|
||||
token_ids_all[0, 4] = self.line_break_id
|
||||
|
||||
# batch 1
|
||||
token_ids_all[1, 3] = self.think_end_id
|
||||
# batch 1: think_end at pre_ids_now[2]
|
||||
token_ids_all[1, 2] = self.think_end_id
|
||||
|
||||
self.token_ids_all = paddle.to_tensor(token_ids_all, dtype="int64")
|
||||
self.prompt_lens = paddle.zeros([self.bs, 1], dtype="int64")
|
||||
@@ -167,11 +169,13 @@ class TestReasoningPhaseTokenConstraint(unittest.TestCase):
|
||||
|
||||
# ------------------------
|
||||
# setup: only think_end appears
|
||||
# step_idx=4, pre_ids_now[0..3]
|
||||
# think_end at pre_ids_now[2] (cur_step - 2 = 4 - 2 = 2)
|
||||
# ------------------------
|
||||
token_ids_all = np.zeros((self.bs, self.max_seq_len), dtype=np.int64)
|
||||
|
||||
# batch 0: think_end at cur_step - 1
|
||||
token_ids_all[0, 3] = self.think_end_id
|
||||
# batch 0: think_end at pre_ids_now[2]
|
||||
token_ids_all[0, 2] = self.think_end_id
|
||||
|
||||
# batch 1: no think_end
|
||||
token_ids_all[1, :] = 0
|
||||
@@ -424,13 +428,15 @@ class TestReasoningPhaseTokenConstraint(unittest.TestCase):
|
||||
|
||||
# ------------------------
|
||||
# token_ids_all: force 1 -> 2 pattern
|
||||
# step_idx=4, pre_ids_now[0..3]
|
||||
# pattern: t3=pre_ids_now[0]=\n, t2=pre_ids_now[1]=<think_end>, t1=pre_ids_now[2]=\n, t0=pre_ids_now[3]=\n
|
||||
# ------------------------
|
||||
token_ids_all = np.zeros((bs, max_seq_len), dtype=np.int64)
|
||||
for i in range(bs):
|
||||
token_ids_all[i, 1] = line_break_id
|
||||
token_ids_all[i, 2] = think_end_id
|
||||
token_ids_all[i, 0] = line_break_id
|
||||
token_ids_all[i, 1] = think_end_id
|
||||
token_ids_all[i, 2] = line_break_id
|
||||
token_ids_all[i, 3] = line_break_id
|
||||
token_ids_all[i, 4] = line_break_id
|
||||
|
||||
token_ids_all = paddle.to_tensor(token_ids_all, dtype="int64")
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return paddle_inputs
|
||||
|
||||
|
||||
def run_kernel(paddle_inputs, inputs):
|
||||
def run_kernel(paddle_inputs):
|
||||
"""Call the CUDA kernel."""
|
||||
speculate_set_stop_value_multi_seqs(
|
||||
paddle_inputs["accept_tokens"],
|
||||
@@ -137,7 +137,18 @@ def gen_inputs(
|
||||
|
||||
|
||||
def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Python reference — must match CUDA kernel logic exactly."""
|
||||
"""Python reference — must match CUDA kernel logic exactly.
|
||||
|
||||
token_ids_all 布局 (新 step_idx 语义):
|
||||
pre_ids_now[k] = 第 k 个 output token (k >= 0, 0-indexed)
|
||||
最后一个 output token 在 pre_ids_now[step_idx - 1]
|
||||
step_idx = 历史已生成的 token 数量
|
||||
|
||||
核心设计:
|
||||
1. accept_idx 从 -1 开始,-1 表示检查 pre_ids 末尾(上一轮延迟的情况)
|
||||
2. 主循环检查 accept_idx <= accept_num - 2
|
||||
3. 匹配成功时: 保留 stop_seq 所有 token,在其后追加 eos
|
||||
"""
|
||||
accept_tokens = inputs["accept_tokens"].copy()
|
||||
accept_num = inputs["accept_num"].copy()
|
||||
stop_flags = inputs["stop_flags"].copy()
|
||||
@@ -166,27 +177,36 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
|
||||
step_idx_now = int(step_idx[bid])
|
||||
min_token_limit = int(min_tokens[bid])
|
||||
|
||||
can_stop = step_idx_now >= min_token_limit
|
||||
can_stop = step_idx_now + an >= min_token_limit
|
||||
if not can_stop:
|
||||
continue
|
||||
if stop_flags[bid]:
|
||||
continue
|
||||
|
||||
accept_idx = 0
|
||||
# CUDA kernel: accept_idx 从 -1 开始,检查 pre_ids 末尾
|
||||
accept_idx = -1
|
||||
is_end = False
|
||||
while accept_idx <= an - 1 and not is_end:
|
||||
|
||||
# loop_end = accept_num > 0 ? accept_num - 2 : -1
|
||||
loop_end = an - 2 if an > 0 else -1
|
||||
while accept_idx <= loop_end and not is_end:
|
||||
if step_idx_now + accept_idx + 1 < stop_seq_len:
|
||||
accept_idx += 1
|
||||
continue
|
||||
|
||||
# Check one stop_seq match
|
||||
# 从后向前匹配 stop_seq 的每个 token
|
||||
for i in range(stop_seq_len - 1, -1, -1):
|
||||
offset = stop_seq_len - 1 - i
|
||||
accept_tokens_idx = accept_idx - offset
|
||||
cur_token_idx = -1
|
||||
if stop_seq_len - 1 - i < accept_idx:
|
||||
cur_token_idx = accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1]
|
||||
|
||||
if accept_tokens_idx >= 0:
|
||||
cur_token_idx = accept_tokens_now[accept_tokens_idx]
|
||||
else:
|
||||
pre_ids_idx = step_idx_now + accept_idx - (stop_seq_len - 1 - i)
|
||||
if pre_ids_idx <= 0:
|
||||
# 新语义: pre_ids_idx = step_idx_now + accept_tokens_idx
|
||||
# pre_ids_now[0] 是第 1 个 output token
|
||||
pre_ids_idx = step_idx_now + accept_tokens_idx
|
||||
if pre_ids_idx < 0:
|
||||
break
|
||||
cur_token_idx = pre_ids_now[pre_ids_idx]
|
||||
|
||||
@@ -199,9 +219,10 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
|
||||
accept_idx += 1
|
||||
|
||||
if is_end:
|
||||
accept_num[bid] = accept_idx
|
||||
accept_tokens[bid, accept_idx - 1] = end_ids[0]
|
||||
# stop_flags[bid] = True # kernel no longer sets stop_flags
|
||||
# accept_idx 已递增,指向 stop_seq 最后 token 的下一个位置
|
||||
# 保留 stop_seq 所有 token,在其后追加 eos
|
||||
accept_num[bid] = accept_idx + 1
|
||||
accept_tokens[bid, accept_idx] = end_ids[0]
|
||||
|
||||
return {
|
||||
"accept_tokens": accept_tokens,
|
||||
@@ -239,7 +260,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
|
||||
def _run_and_get(self, inputs):
|
||||
paddle_inputs = to_paddle_inputs(inputs)
|
||||
run_kernel(paddle_inputs, inputs)
|
||||
run_kernel(paddle_inputs)
|
||||
return get_outputs(paddle_inputs)
|
||||
|
||||
def _check_all_outputs(self, inputs, outputs):
|
||||
@@ -264,7 +285,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
self._run_full_test(test_cfg)
|
||||
|
||||
def test_match_in_accept_tokens_only(self):
|
||||
"""Stop seq found entirely within accept_tokens."""
|
||||
"""Stop seq found entirely within accept_tokens. Eos appended after stop_seq last token."""
|
||||
inputs = gen_inputs(real_bsz=1, accept_tokens_len=5, stop_seqs_bs=1, stop_seqs_max_len=3, seed=10)
|
||||
# Place stop seq [A, B, C] at accept_tokens positions [0,1,2]
|
||||
inputs["accept_num"][:] = 4
|
||||
@@ -276,9 +297,13 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# stop_seq [10, 20, 30] matches at accept_idx=2 (window ends at accept_tokens[2]=30)
|
||||
# After loop increment, accept_idx=3, accept_num=4, eos appended at accept_tokens[3]
|
||||
self.assertEqual(outputs["accept_num"][0], 4)
|
||||
self.assertEqual(outputs["accept_tokens"][0, 3], -1) # eos appended after stop_seq
|
||||
|
||||
def test_match_spanning_pre_ids_and_accept(self):
|
||||
"""Stop seq spans token_ids_all (pre_ids) and accept_tokens."""
|
||||
"""Stop seq spans token_ids_all (pre_ids) and accept_tokens. Eos appended after stop_seq last token."""
|
||||
inputs = gen_inputs(
|
||||
real_bsz=1,
|
||||
accept_tokens_len=5,
|
||||
@@ -290,12 +315,15 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["prompt_lens"][:] = 0
|
||||
inputs["step_idx"][:] = 6
|
||||
inputs["accept_num"][:] = 3
|
||||
# Kernel matching at accept_idx=2 (3rd token, 0-indexed):
|
||||
# i=2(last): stop_seq_len-1-i=0 < accept_idx(2) -> accept_tokens[2-0-1]=accept_tokens[1]
|
||||
# i=1: stop_seq_len-1-i=1 < accept_idx(2) -> accept_tokens[2-1-1]=accept_tokens[0]
|
||||
# i=0: stop_seq_len-1-i=2 >= accept_idx(2) -> pre_ids[step_idx+2-(3-1-0)]=pre_ids[6]
|
||||
# So stop_seq should be [pre_ids[6], accept_tokens[0], accept_tokens[1]]
|
||||
inputs["token_ids_all"][0, 6] = 99
|
||||
# stop_seq = [99, 11, 22] (len=3)
|
||||
# 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx
|
||||
# pre_ids_now[k] = 第 k 个 output token (k >= 0)
|
||||
# step_idx = 6 表示有 6 个历史 output token,在 pre_ids_now[0..5]
|
||||
# At accept_idx=1 (window ends at accept_tokens[1]=22):
|
||||
# i=2: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=22 vs stop_seq[2]=22 ✓
|
||||
# i=1: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=11 vs stop_seq[1]=11 ✓
|
||||
# i=0: offset=2, accept_tokens_idx=-1 -> pre_ids_idx=6+(-1)=5 -> pre_ids[5]=99 vs stop_seq[0]=99 ✓
|
||||
inputs["token_ids_all"][0, 5] = 99 # pre_ids_now[5] = 第 6 个 output token (0-indexed)
|
||||
inputs["accept_tokens"][0, :3] = [11, 22, 33]
|
||||
inputs["stop_seqs"][0, 0, :3] = [99, 11, 22]
|
||||
inputs["stop_seqs_len"][0, 0] = 3
|
||||
@@ -303,12 +331,14 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# Match at accept_idx=2, loop increments to 3
|
||||
# Match at accept_idx=1, loop increments to 2 -> accept_num=3, eos at accept_tokens[2]
|
||||
self.assertEqual(outputs["accept_num"][0], 3)
|
||||
self.assertEqual(outputs["accept_tokens"][0, 2], -1)
|
||||
self.assertEqual(outputs["accept_tokens"][0, 2], -1) # eos appended after stop_seq
|
||||
|
||||
def test_match_in_pre_ids_only(self):
|
||||
"""Stop seq found entirely within token_ids_all (pre_ids), matching at accept_idx=0."""
|
||||
def test_match_in_pre_ids_only_not_detected(self):
|
||||
"""Stop seq ending purely in pre_ids history but NOT at the end position.
|
||||
The kernel only detects stop_seq at the very end of pre_ids via accept_idx=-1 check.
|
||||
Stop seq placed earlier in pre_ids should not be detected."""
|
||||
inputs = gen_inputs(
|
||||
real_bsz=1,
|
||||
accept_tokens_len=5,
|
||||
@@ -320,15 +350,13 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["prompt_lens"][:] = 0
|
||||
inputs["step_idx"][:] = 8
|
||||
inputs["accept_num"][:] = 3
|
||||
# pre_ids at step_idx positions: token_ids_all[0, 6]=50, [0,7]=60, [0,8]=70
|
||||
# stop_seq = [50, 60, 70], all 3 tokens are in pre_ids
|
||||
# For accept_idx=0: step_idx_now + 0 + 1 = 9 >= stop_seq_len=3, so we check
|
||||
# i=2: pre_ids_idx = 8+0-(3-1-2) = 8 -> pre_ids_now[8] = 70
|
||||
# i=1: pre_ids_idx = 8+0-(3-1-1) = 7 -> pre_ids_now[7] = 60
|
||||
# i=0: pre_ids_idx = 8+0-(3-1-0) = 6 -> pre_ids_now[6] = 50
|
||||
inputs["token_ids_all"][0, 6] = 50
|
||||
inputs["token_ids_all"][0, 7] = 60
|
||||
inputs["token_ids_all"][0, 8] = 70
|
||||
# 新语义: pre_ids_now[k] = 第 k 个 output token (k >= 0)
|
||||
# step_idx = 8 表示有 8 个历史 output token,在 pre_ids_now[0..7]
|
||||
# accept_idx=-1 会检查 pre_ids_now[7] 开始的 stop_seq
|
||||
# 把 stop_seq 放在 pre_ids_now[2,3,4] - 不会被检测到
|
||||
inputs["token_ids_all"][0, 2] = 50
|
||||
inputs["token_ids_all"][0, 3] = 60
|
||||
inputs["token_ids_all"][0, 4] = 70
|
||||
inputs["accept_tokens"][0, :3] = [1, 2, 3]
|
||||
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
|
||||
inputs["stop_seqs_len"][0, 0] = 3
|
||||
@@ -336,7 +364,8 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
self.assertEqual(outputs["accept_num"][0], 1)
|
||||
# No match: stop_seq is in pre_ids but not at the end, accept_num unchanged
|
||||
self.assertEqual(outputs["accept_num"][0], 3)
|
||||
|
||||
def test_already_stopped(self):
|
||||
"""Kernel skips sequences with stop_flags=True."""
|
||||
@@ -351,7 +380,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
np.testing.assert_array_equal(outputs["accept_num"], inputs["accept_num"])
|
||||
|
||||
def test_min_tokens_blocks_stop(self):
|
||||
"""Kernel skips stop check when step_idx < min_tokens."""
|
||||
"""Kernel skips stop check when step_idx + accept_num < min_tokens."""
|
||||
inputs = gen_inputs(
|
||||
real_bsz=1,
|
||||
accept_tokens_len=5,
|
||||
@@ -363,20 +392,24 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["prompt_lens"][:] = 0
|
||||
inputs["step_idx"][:] = 8
|
||||
inputs["accept_num"][:] = 3
|
||||
# Same setup that would match (like test_match_in_pre_ids_only)
|
||||
inputs["token_ids_all"][0, 6] = 50
|
||||
inputs["token_ids_all"][0, 7] = 60
|
||||
inputs["token_ids_all"][0, 8] = 70
|
||||
# Place stop_seq in pre_ids at end position (would be detected by accept_idx=-1)
|
||||
# pre_ids_now[0..7] = 8 个历史 output token
|
||||
# accept_idx=-1 检查 pre_ids_now[5,6,7] 对应 stop_seq[0,1,2]
|
||||
inputs["token_ids_all"][0, 5] = 50
|
||||
inputs["token_ids_all"][0, 6] = 60
|
||||
inputs["token_ids_all"][0, 7] = 70
|
||||
inputs["accept_tokens"][0, :3] = [1, 2, 3]
|
||||
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
|
||||
inputs["stop_seqs_len"][0, 0] = 3
|
||||
inputs["stop_flags"][:] = False
|
||||
inputs["min_tokens"][:] = 100 # step_idx=8 < 100, should NOT stop
|
||||
inputs["min_tokens"][:] = 100 # step_idx+accept_num=11 < 100, should NOT stop
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# min_tokens prevents stop, accept_num unchanged
|
||||
self.assertEqual(outputs["accept_num"][0], 3)
|
||||
|
||||
def test_min_tokens_allows_stop(self):
|
||||
"""Kernel allows stop when step_idx >= min_tokens."""
|
||||
"""Kernel allows stop when step_idx + accept_num >= min_tokens."""
|
||||
inputs = gen_inputs(
|
||||
real_bsz=1,
|
||||
accept_tokens_len=5,
|
||||
@@ -388,15 +421,17 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["prompt_lens"][:] = 0
|
||||
inputs["step_idx"][:] = 8
|
||||
inputs["accept_num"][:] = 3
|
||||
# Put stop_seq entirely in pre_ids (same pattern as test_match_in_pre_ids_only)
|
||||
inputs["token_ids_all"][0, 6] = 50
|
||||
inputs["token_ids_all"][0, 7] = 60
|
||||
inputs["token_ids_all"][0, 8] = 70
|
||||
inputs["accept_tokens"][0, :3] = [1, 2, 3]
|
||||
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
|
||||
inputs["stop_seqs_len"][0, 0] = 3
|
||||
# stop_seq [X, 50] spans pre_ids and accept_tokens[0].
|
||||
# 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx
|
||||
# At accept_idx=0 (window ends at accept_tokens[0]=50):
|
||||
# i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=50 vs stop_seq[1]=50 ✓
|
||||
# i=0: offset=1, accept_tokens_idx=-1 -> pre_ids_idx=8+(-1)=7 -> pre_ids[7]
|
||||
pre_val = int(inputs["token_ids_all"][0, 7]) # pre_ids_now[7]
|
||||
inputs["accept_tokens"][0, :3] = [50, 60, 70]
|
||||
inputs["stop_seqs"][0, 0, :2] = [pre_val, 50]
|
||||
inputs["stop_seqs_len"][0, 0] = 2
|
||||
inputs["stop_flags"][:] = False
|
||||
inputs["min_tokens"][:] = 5 # step_idx=8 >= 5, should stop
|
||||
inputs["min_tokens"][:] = 5 # step_idx+accept_num=11 >= 5, should stop
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
|
||||
@@ -413,20 +448,24 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["prompt_lens"][:] = 0
|
||||
inputs["step_idx"][:] = 8
|
||||
inputs["accept_num"][:] = 3
|
||||
# accept_tokens: stop_seq[20,30] matches at accept_idx=2:
|
||||
# i=1: accept_tokens[2-0-1]=accept_tokens[1]=30 vs stop_seq[1]=30 OK
|
||||
# i=0: accept_tokens[2-1-1]=accept_tokens[0]=20 vs stop_seq[0]=20 OK
|
||||
# accept_tokens: [20, 30, 40]
|
||||
# Second stop seq [20, 30] matches at accept_idx=1 (window ends at accept_tokens[1]=30):
|
||||
# i=1: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=30 vs stop_seq[1]=30 ✓
|
||||
# i=0: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=20 vs stop_seq[0]=20 ✓
|
||||
inputs["accept_tokens"][0, :3] = [20, 30, 40]
|
||||
# First stop seq doesn't match
|
||||
inputs["stop_seqs"][0, 0, :3] = [99, 98, 97]
|
||||
inputs["stop_seqs_len"][0, 0] = 3
|
||||
# Second stop seq matches
|
||||
# Second stop seq [20, 30] matches
|
||||
inputs["stop_seqs"][0, 1, :2] = [20, 30]
|
||||
inputs["stop_seqs_len"][0, 1] = 2
|
||||
inputs["stop_flags"][:] = False
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# Match at accept_idx=1 -> accept_num=3, eos at accept_tokens[2]
|
||||
self.assertEqual(outputs["accept_num"][0], 3)
|
||||
self.assertEqual(outputs["accept_tokens"][0, 2], -1) # eos appended after stop_seq
|
||||
|
||||
def test_nonzero_prompt_lens(self):
|
||||
"""Verify prompt_lens offset is applied correctly."""
|
||||
@@ -444,19 +483,104 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["accept_num"][:] = 2
|
||||
inputs["accept_tokens"][0, :2] = [55, 66]
|
||||
# pre_ids_now starts at token_ids_all[0, prompt_len:]
|
||||
# stop_seq = [X, 55] where X = token_ids_all[0, prompt_len + step_idx]
|
||||
# For accept_idx=0: pre_ids_idx = step_idx + 0 - (2-1-0) = 5-1 = 4
|
||||
# -> pre_ids_now[4] = token_ids_all[0, prompt_len + 4]
|
||||
# For accept_idx=1 (second token is accept_tokens[0,0]=55):
|
||||
# i=1: accept_tokens_now[1-(2-1-1)-1] = accept_tokens_now[0] = 55
|
||||
# i=0: pre_ids_idx = step_idx + 1 - (2-1-0) = 5+1-1 = 5 -> pre_ids_now[5]
|
||||
target_val = int(inputs["token_ids_all"][0, prompt_len + 5])
|
||||
# pre_ids_now[k] = 第 k 个 output token (k >= 0)
|
||||
# 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx
|
||||
# stop_seq = [X, 55] where X = pre_ids_now[5 + (-1)] = pre_ids_now[4]
|
||||
# At accept_idx=0 (window ends at accept_tokens[0]=55):
|
||||
# i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=55 vs stop_seq[1]=55 ✓
|
||||
# i=0: offset=1, accept_tokens_idx=-1 -> pre_ids_idx=5+(-1)=4 -> pre_ids[4]=token_ids_all[0, prompt_len+4]
|
||||
target_val = int(inputs["token_ids_all"][0, prompt_len + 4])
|
||||
inputs["stop_seqs"][0, 0, :2] = [target_val, 55]
|
||||
inputs["stop_seqs_len"][0, 0] = 2
|
||||
inputs["stop_flags"][:] = False
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# Match at accept_idx=0 -> accept_num=2, eos at accept_tokens[1]
|
||||
self.assertEqual(outputs["accept_num"][0], 2)
|
||||
self.assertEqual(outputs["accept_tokens"][0, 1], -1) # eos appended after stop_seq
|
||||
|
||||
def test_single_token_stop_seq_preserved(self):
|
||||
"""Single token stop_seq (like <|im_end|>) with eos appended after it."""
|
||||
inputs = gen_inputs(
|
||||
real_bsz=1,
|
||||
accept_tokens_len=5,
|
||||
max_model_len=32,
|
||||
stop_seqs_bs=1,
|
||||
stop_seqs_max_len=1,
|
||||
seed=90,
|
||||
)
|
||||
inputs["prompt_lens"][:] = 0
|
||||
inputs["step_idx"][:] = 10
|
||||
inputs["accept_num"][:] = 4
|
||||
# accept_tokens: [a, b, <|im_end|>, d] where <|im_end|> has token id 999
|
||||
inputs["accept_tokens"][0, :4] = [100, 200, 999, 300]
|
||||
# stop_seq = [<|im_end|>] (single token)
|
||||
inputs["stop_seqs"][0, 0, 0] = 999
|
||||
inputs["stop_seqs_len"][0, 0] = 1
|
||||
inputs["stop_flags"][:] = False
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# Match at accept_idx=2 (window ends at accept_tokens[2]=999)
|
||||
# After loop increment, accept_idx=3, accept_num=4, eos at accept_tokens[3]
|
||||
self.assertEqual(outputs["accept_num"][0], 4)
|
||||
self.assertEqual(outputs["accept_tokens"][0, 3], -1) # eos appended after stop_seq
|
||||
|
||||
def test_stop_seq_at_last_position_not_detected(self):
|
||||
"""Stop seq at the last position of accept_tokens is NOT detected (deferred to next round)."""
|
||||
inputs = gen_inputs(
|
||||
real_bsz=1,
|
||||
accept_tokens_len=5,
|
||||
max_model_len=32,
|
||||
stop_seqs_bs=1,
|
||||
stop_seqs_max_len=1,
|
||||
seed=100,
|
||||
)
|
||||
inputs["prompt_lens"][:] = 0
|
||||
inputs["step_idx"][:] = 10
|
||||
inputs["accept_num"][:] = 4
|
||||
# stop_seq [999] is at accept_tokens[3] (last valid position)
|
||||
# Since we only check up to accept_num - 2 = 2, this won't be detected
|
||||
inputs["accept_tokens"][0, :4] = [100, 200, 300, 999]
|
||||
inputs["stop_seqs"][0, 0, 0] = 999
|
||||
inputs["stop_seqs_len"][0, 0] = 1
|
||||
inputs["stop_flags"][:] = False
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# No match because accept_idx only goes up to 2, and 999 is at position 3
|
||||
# accept_num unchanged
|
||||
self.assertEqual(outputs["accept_num"][0], 4)
|
||||
|
||||
def test_stop_seq_detected_from_previous_round(self):
|
||||
"""Stop seq at the end of pre_ids (from previous round) is detected via accept_idx=-1."""
|
||||
inputs = gen_inputs(
|
||||
real_bsz=1,
|
||||
accept_tokens_len=5,
|
||||
max_model_len=32,
|
||||
stop_seqs_bs=1,
|
||||
stop_seqs_max_len=1,
|
||||
seed=110,
|
||||
)
|
||||
inputs["prompt_lens"][:] = 0
|
||||
# 新语义: pre_ids_now[k] = 第 k 个 output token (k >= 0)
|
||||
# step_idx = 10 表示有 10 个历史 output token,在 pre_ids_now[0..9]
|
||||
# accept_idx=-1 检查 pre_ids_now[9] (最后一个历史 token)
|
||||
inputs["step_idx"][:] = 10
|
||||
inputs["token_ids_all"][0, 9] = 999 # pre_ids_now[9] = 第 10 个 output token (0-indexed)
|
||||
inputs["accept_num"][:] = 3
|
||||
inputs["accept_tokens"][0, :3] = [100, 200, 300]
|
||||
inputs["stop_seqs"][0, 0, 0] = 999
|
||||
inputs["stop_seqs_len"][0, 0] = 1
|
||||
inputs["stop_flags"][:] = False
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# stop_seq [999] was in pre_ids at end, accept_idx=-1 matches
|
||||
# After loop increment, accept_idx=0, accept_num=1, eos at accept_tokens[0]
|
||||
self.assertEqual(outputs["accept_num"][0], 1)
|
||||
self.assertEqual(outputs["accept_tokens"][0, 0], -1) # replaced with eos
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -261,7 +261,9 @@ def reference_impl(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# Write history to token_ids_all (forward loop, mirrors kernel step 5)
|
||||
if output_len > 0:
|
||||
base_addr = int(prompt_lens[batch_id])
|
||||
base = cur_step_idx - output_len + 1
|
||||
# 新语义: step_idx 入口 = 历史数量,处理后 cur_step_idx = 历史 + output_len
|
||||
# 第一个 output token 写入位置 = cur_step_idx - output_len
|
||||
base = cur_step_idx - output_len
|
||||
for i in range(output_len):
|
||||
write_idx = base_addr + base + i
|
||||
if 0 <= write_idx < max_model_len:
|
||||
|
||||
Reference in New Issue
Block a user