[Speculative Decoding] Fix speculate stop_seqs and fix accept_num in eos branch (#6825)

This commit is contained in:
freeliuzc
2026-03-13 14:48:24 +08:00
committed by GitHub
parent 8906e09e0f
commit 12f412448b
4 changed files with 433 additions and 232 deletions
@@ -58,7 +58,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
bool is_end = false;
// 遍历起始位置
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) {
if (step_idx_now + accept_idx + 1 < stop_seq_len) {
#ifdef DEBUG_SPEC_STOP_SEQS
printf("num %d < stop_seq_len %d\n",
step_idx_now - accept_num + accept_idx + 1,
@@ -98,7 +98,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
(stop_seq_len - 1 - i));
#endif
int pre_ids_idx =
step_idx_now - accept_num + accept_idx - (stop_seq_len - 1 - i);
step_idx_now + accept_idx - (stop_seq_len - 1 - i);
// EC3
// 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23,
// 导致异常结束
@@ -131,7 +131,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
accept_nums[bid] = accept_idx;
accept_tokens_now[accept_idx - 1] = end_ids[0];
stop_flags[bid] = true;
// stop_flags[bid] = true;
}
}
}
@@ -116,8 +116,8 @@ struct VerifyContext {
}
// Emit the final token at position `pos` in Phase 2.
// Same EOS/max_dec_len logic, but does NOT increment output_len_now
// (Phase 2's token is already counted in the initial output_len_now=1).
// Same EOS/max_dec_len logic. Increments output_len_now since
// Phase 2 produces one additional token.
__device__ __forceinline__ void emit_final_token(int pos, int64_t token) {
cur_step_idx++;
bool is_eos = is_in_end(token, end_tokens, end_length);
@@ -126,6 +126,7 @@ struct VerifyContext {
token = end_tokens[0];
}
step_output_ids[bid * max_step_tokens + pos] = token;
output_len_now++;
}
// TOPP-only: verify_window bulk-accept fallback.
@@ -278,7 +279,7 @@ __global__ void verify_draft_tokens(
ctx.step_input_ids_now = step_input_ids + bid * max_step_tokens;
ctx.step_output_ids = step_output_ids;
ctx.cur_step_idx = step_idx[bid];
ctx.output_len_now = 1;
ctx.output_len_now = 0;
ctx.stopped = false;
// ======== Phase 1: Verify draft tokens ========