[Speculate Decoding] Fix bug of reasoning_phase_token_constraint kernel (#7349)

Co-authored-by: guanshihui <guanshihui@baidu.com>
This commit is contained in:
lonelygsh
2026-04-14 20:57:11 +08:00
committed by GitHub
parent 7b0baced17
commit e0a1653b26
2 changed files with 25 additions and 19 deletions
@@ -38,7 +38,7 @@
// - In MTP mode, accept_num must be 1 in verify kernel
//
// Transition condition (x = 1 -> x = 2):
// - step_idx >= 3
// - step_idx >= 4
// - pre_ids[-4:] exactly match:
// "\n</think>\n\n"
//
@@ -83,10 +83,10 @@ __global__ void update_reasoning_status_kernel(
int64_t cur_step = step_idx[tid];
const int64_t* pre_ids_now =
token_ids_all + tid * max_seq_len + prompt_lens[tid];
int64_t t0 = (cur_step >= 0) ? pre_ids_now[cur_step] : -1;
int64_t t1 = (cur_step >= 1) ? pre_ids_now[cur_step - 1] : -1;
int64_t t2 = (cur_step >= 2) ? pre_ids_now[cur_step - 2] : -1;
int64_t t3 = (cur_step >= 3) ? pre_ids_now[cur_step - 3] : -1;
int64_t t0 = (cur_step >= 1) ? pre_ids_now[cur_step - 1] : -1;
int64_t t1 = (cur_step >= 2) ? pre_ids_now[cur_step - 2] : -1;
int64_t t2 = (cur_step >= 3) ? pre_ids_now[cur_step - 3] : -1;
int64_t t3 = (cur_step >= 4) ? pre_ids_now[cur_step - 4] : -1;
int32_t new_status = status;
@@ -104,7 +104,7 @@ __global__ void update_reasoning_status_kernel(
// x = 1 -> x = 2 (include think_end_id)
// or x = 1 -> x = 3 (not include think_end_id)
// The transition checks must be evaluated sequentially (hence new_status)
if (new_status == 1 && cur_step >= 3) {
if (new_status == 1 && cur_step >= 4) {
if (t3 == line_break_id && t2 == think_end_id && t1 == line_break_id &&
t0 == line_break_id) {
new_status = 2;