mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculate Decoding] Fix step_idx semantics in limit_thinking and set_stop_value kernels (#7166)
- speculate_limit_thinking_content_length: update current_base_step to step_idx+1 (step_idx now records history count before current round); remove incorrect step_idx decrement on accept_num truncation; mark step_idx param as const. - speculate_set_stop_value_multi_seqs: fix can_stop gate to use step_idx_now+accept_num>=min_token_limit; fix skip check and pre_ids_idx formula (remove stale -accept_num offset); use <= condition so accept_idx maps directly to the accepted token that ends the stop sequence; fix accept_tokens index (remove -1). - Update unit tests for speculate_set_stop_value_multi_seqs kernel.
This commit is contained in:
@@ -34,7 +34,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
||||
int64_t* next_tokens, // [bs, tokens_per_step]
|
||||
const int* max_think_lens, // [bs]
|
||||
int* max_reply_lens, // [bs]
|
||||
int64_t* step_idx, // [bs]
|
||||
const int64_t* step_idx, // [bs]
|
||||
const int64_t* eos_token_ids, // [eos_len]
|
||||
int* limit_status, // [bs]
|
||||
int* accept_num, // [bs]
|
||||
@@ -68,7 +68,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
||||
int new_accept_num = original_accept_num;
|
||||
|
||||
// 本 step 的 token offset 对应的绝对 step
|
||||
const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
|
||||
const int64_t current_base_step = step_idx[bid] + 1;
|
||||
|
||||
for (int token_offset = 0; token_offset < original_accept_num;
|
||||
token_offset++) {
|
||||
@@ -100,8 +100,8 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
||||
// inject_token_ids[0])
|
||||
if (status == 0 &&
|
||||
(current_step - 1) ==
|
||||
max_think_len) { // current_step - 1 是因为 speculate_verify 里
|
||||
// step_idx + 1 了
|
||||
max_think_len) { // current_step - 1 : 已输出 current_step-1
|
||||
// 个thinking token
|
||||
status = (inject_len > 0) ? 1 : done_status;
|
||||
}
|
||||
} else if (max_think_len == 0) {
|
||||
@@ -181,13 +181,6 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// 更新 step_idx / accept_num(被截断的 token 需要回退
|
||||
// step_idx)
|
||||
const int discarded_tokens = original_accept_num - new_accept_num;
|
||||
if (discarded_tokens > 0) {
|
||||
step_idx[bid] -= discarded_tokens;
|
||||
}
|
||||
|
||||
accept_num[bid] = new_accept_num;
|
||||
limit_status[bid] = status;
|
||||
max_reply_lens[bid] = max_reply_len;
|
||||
@@ -221,7 +214,7 @@ void SpeculateLimitThinkingContentLength(
|
||||
const_cast<int64_t*>(next_tokens.data<int64_t>()),
|
||||
max_think_lens.data<int>(),
|
||||
const_cast<int*>(max_reply_lens.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
step_idx.data<int64_t>(),
|
||||
eos_token_ids.data<int64_t>(),
|
||||
const_cast<int*>(limit_status.data<int>()),
|
||||
const_cast<int*>(accept_num.data<int>()),
|
||||
|
||||
@@ -51,60 +51,65 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
||||
const int64_t step_idx_now = step_idx[bid];
|
||||
const int64_t min_token_limit = min_tokens[bid];
|
||||
|
||||
const bool can_stop = (step_idx_now >= min_token_limit);
|
||||
const bool can_stop = (step_idx_now + accept_num >= min_token_limit);
|
||||
if (!can_stop) return;
|
||||
if (!stop_flags[bid]) {
|
||||
int accept_idx = 0;
|
||||
/*
|
||||
accept_idx 表示 stop_seq 最后 token 在 accept_tokens 中的位置 (0-based)
|
||||
accept_idx = -1 表示 stop_seq 最后 token 在 pre_ids 的末尾
|
||||
(pre_ids[step_idx_now - 1]),即上一轮延迟匹配的最后一个 token。
|
||||
为防止在 stop_seqs 后面追加 eos 越界,跳过 accept_tokens[accept_num-1]
|
||||
(当前轮最后一个 token),该 token 延迟到下一轮匹配。
|
||||
循环范围:accept_num > 0 时为 [-1, accept_num-2];
|
||||
accept_num = 0 时为 [-1](仅检查 pre_ids 末尾)。
|
||||
*/
|
||||
int accept_idx = -1;
|
||||
bool is_end = false;
|
||||
// 遍历起始位置
|
||||
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
|
||||
|
||||
// 统一检测:accept_idx = -1 对应上一轮延迟的最后 token 在 pre_ids 末尾
|
||||
// 完整匹配 stop_seqs 的情况;accept_idx >= 0 对应当前轮 accept_tokens
|
||||
// 中的匹配。两者共享同一套从后向前匹配逻辑。
|
||||
int loop_end = (accept_num > 0) ? accept_num - 2 : -1;
|
||||
for (; accept_idx <= loop_end && !is_end; accept_idx++) {
|
||||
if (step_idx_now + accept_idx + 1 < stop_seq_len) {
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf("num %d < stop_seq_len %d\n",
|
||||
step_idx_now - accept_num + accept_idx + 1,
|
||||
step_idx_now + accept_idx + 1,
|
||||
stop_seq_len);
|
||||
#endif
|
||||
continue;
|
||||
}
|
||||
// 遍历一个 stop_seqs
|
||||
// 从后向前匹配 stop_seq 的每个 token
|
||||
for (int i = stop_seq_len - 1; i >= 0; --i) {
|
||||
int64_t cur_token_idx = -1;
|
||||
|
||||
// 通过当前值判断 token 是在 pre_ids 还是 accept_token 里
|
||||
if (stop_seq_len - 1 - i < accept_idx) {
|
||||
int offset = stop_seq_len - 1 - i;
|
||||
int accept_tokens_idx = accept_idx - offset;
|
||||
|
||||
if (accept_tokens_idx >= 0) {
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf(
|
||||
"AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
|
||||
"accept_token_idx: "
|
||||
"%d\n",
|
||||
"accept_token_idx: %d\n",
|
||||
bid,
|
||||
tid,
|
||||
accept_idx,
|
||||
accept_idx - (stop_seq_len - 1 - i) - 1);
|
||||
accept_tokens_idx);
|
||||
#endif
|
||||
cur_token_idx =
|
||||
accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1];
|
||||
cur_token_idx = accept_tokens_now[accept_tokens_idx];
|
||||
} else {
|
||||
int pre_ids_idx = step_idx_now + accept_tokens_idx;
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf(
|
||||
"PreIds bid:%d. tid:%d, step_idx_now:%ld. "
|
||||
"accept_idx:%d. "
|
||||
"pre_id_idx: %ld\n",
|
||||
"accept_idx:%d. pre_id_idx: %d\n",
|
||||
bid,
|
||||
tid,
|
||||
step_idx_now,
|
||||
accept_idx,
|
||||
step_idx_now - accept_num + accept_idx -
|
||||
(stop_seq_len - 1 - i));
|
||||
pre_ids_idx);
|
||||
#endif
|
||||
int pre_ids_idx =
|
||||
step_idx_now + accept_idx - (stop_seq_len - 1 - i);
|
||||
// EC3
|
||||
// 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23,
|
||||
// 导致异常结束
|
||||
if (pre_ids_idx <= 0) {
|
||||
break;
|
||||
}
|
||||
if (pre_ids_idx < 0) break;
|
||||
cur_token_idx = pre_ids_now[pre_ids_idx];
|
||||
}
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
@@ -126,12 +131,11 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
||||
}
|
||||
if (is_end) {
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf("bid:%d end with accept_idx %d", bid, accept_idx);
|
||||
printf("bid:%d end with accept_idx %d\n", bid, accept_idx);
|
||||
#endif
|
||||
|
||||
accept_nums[bid] = accept_idx;
|
||||
accept_tokens_now[accept_idx - 1] = end_ids[0];
|
||||
// stop_flags[bid] = true;
|
||||
// accept_idx 在循环退出时已递增,指向 stop_seq 最后 token 的下一个位置
|
||||
accept_nums[bid] = accept_idx + 1;
|
||||
accept_tokens_now[accept_idx] = end_ids[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
|
||||
int64_t *token_ids_all_now =
|
||||
&token_ids_all[batch_id * max_model_len + prompt_len];
|
||||
int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
|
||||
int64_t base = cur_step_idx - output_len + 1;
|
||||
int64_t base = cur_step_idx - output_len;
|
||||
for (int i = 0; i < output_len; i++) {
|
||||
token_ids_all_now[base + i] = output_ids[i];
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user