diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu index 18aa5d53d2..ae1315848a 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu @@ -98,10 +98,7 @@ __global__ void speculate_limit_thinking_content_length_kernel( if (max_think_len > 0) { // A) 超长触发:到达 max_think_len 时开始注入(从本 token 起输出 // inject_token_ids[0]) - if (status == 0 && - (current_step - 1) == - max_think_len) { // current_step - 1 是因为 speculate_verify 里 - // step_idx + 1 了 + if (status == 0 && current_step == max_think_len) { status = (inject_len > 0) ? 1 : done_status; } } else if (max_think_len == 0) { diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu index ee364884e9..13df3e18eb 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu @@ -58,7 +58,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags, bool is_end = false; // 遍历起始位置 for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) { - if (step_idx_now + accept_idx + 1 < stop_seq_len) { + if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) { #ifdef DEBUG_SPEC_STOP_SEQS printf("num %d < stop_seq_len %d\n", step_idx_now - accept_num + accept_idx + 1, @@ -71,7 +71,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags, int64_t cur_token_idx = -1; // 通过当前值判断 token 是在 pre_ids 还是 accept_token 里 - if (stop_seq_len - 1 - i < accept_idx) { + if (stop_seq_len - 1 - i <= accept_idx) { #ifdef DEBUG_SPEC_STOP_SEQS printf( "AcceptTokens bid:%d. tid:%d, accept_idx:%d, " @@ -83,7 +83,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags, accept_idx - (stop_seq_len - 1 - i) - 1); #endif cur_token_idx = - accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1]; + accept_tokens_now[accept_idx - (stop_seq_len - 1 - i)]; } else { #ifdef DEBUG_SPEC_STOP_SEQS printf( @@ -98,7 +98,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags, (stop_seq_len - 1 - i)); #endif int pre_ids_idx = - step_idx_now + accept_idx - (stop_seq_len - 1 - i); + step_idx_now - accept_num + accept_idx - (stop_seq_len - 1 - i); // EC3 // 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23, // 导致异常结束 diff --git a/tests/operators/test_speculate_set_stop_value_multi_seqs.py b/tests/operators/test_speculate_set_stop_value_multi_seqs.py index 45d8a0ef34..d982e91cc1 100644 --- a/tests/operators/test_speculate_set_stop_value_multi_seqs.py +++ b/tests/operators/test_speculate_set_stop_value_multi_seqs.py @@ -175,17 +175,19 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str accept_idx = 0 is_end = False while accept_idx <= an - 1 and not is_end: - if step_idx_now + accept_idx + 1 < stop_seq_len: + if step_idx_now - an + accept_idx + 1 < stop_seq_len: accept_idx += 1 continue # Check one stop_seq match for i in range(stop_seq_len - 1, -1, -1): cur_token_idx = -1 - if stop_seq_len - 1 - i < accept_idx: - cur_token_idx = accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1] + # 注意:新版本kernel改成了 <=,并且去掉了 -1 + if stop_seq_len - 1 - i <= accept_idx: + cur_token_idx = accept_tokens_now[accept_idx - (stop_seq_len - 1 - i)] else: - pre_ids_idx = step_idx_now + accept_idx - (stop_seq_len - 1 - i) + # 新版本:step_idx已经包含accept_num,所以要减去 + pre_ids_idx = step_idx_now - an + accept_idx - (stop_seq_len - 1 - i) if pre_ids_idx <= 0: break cur_token_idx = pre_ids_now[pre_ids_idx] @@ -290,12 +292,12 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 6 inputs["accept_num"][:] = 3 - # Kernel matching at accept_idx=2 (3rd token, 0-indexed): - # i=2(last): stop_seq_len-1-i=0 < accept_idx(2) -> accept_tokens[2-0-1]=accept_tokens[1] - # i=1: stop_seq_len-1-i=1 < accept_idx(2) -> accept_tokens[2-1-1]=accept_tokens[0] - # i=0: stop_seq_len-1-i=2 >= accept_idx(2) -> pre_ids[step_idx+2-(3-1-0)]=pre_ids[6] - # So stop_seq should be [pre_ids[6], accept_tokens[0], accept_tokens[1]] - inputs["token_ids_all"][0, 6] = 99 + # stop_seq spans pre_ids and accept_tokens + # For accept_idx=1: step_idx_now - accept_num + 1 + 1 = 6-3+1+1 = 5 >= stop_seq_len=3, so we check + # i=2: stop_seq_len-1-i=0 <= accept_idx(1) -> accept_tokens[1-0] = accept_tokens[1] = 22 + # i=1: stop_seq_len-1-i=1 <= accept_idx(1) -> accept_tokens[1-1] = accept_tokens[0] = 11 + # i=0: stop_seq_len-1-i=2 > accept_idx(1) -> pre_ids_idx = 6-3+1-(3-1-0) = 4-2 = 2 -> pre_ids[2] = 99 + inputs["token_ids_all"][0, 2] = 99 inputs["accept_tokens"][0, :3] = [11, 22, 33] inputs["stop_seqs"][0, 0, :3] = [99, 11, 22] inputs["stop_seqs_len"][0, 0] = 3 @@ -303,9 +305,9 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) - # Match at accept_idx=2, loop increments to 3 - self.assertEqual(outputs["accept_num"][0], 3) - self.assertEqual(outputs["accept_tokens"][0, 2], -1) + # Match at accept_idx=1, loop increments to 2 + self.assertEqual(outputs["accept_num"][0], 2) + self.assertEqual(outputs["accept_tokens"][0, 1], inputs["end_ids"][0]) def test_match_in_pre_ids_only(self): """Stop seq found entirely within token_ids_all (pre_ids), matching at accept_idx=0.""" @@ -314,29 +316,29 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): accept_tokens_len=5, max_model_len=32, stop_seqs_bs=1, - stop_seqs_max_len=3, + stop_seqs_max_len=4, # 需要4个元素 seed=30, ) inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # pre_ids at step_idx positions: token_ids_all[0, 6]=50, [0,7]=60, [0,8]=70 - # stop_seq = [50, 60, 70], all 3 tokens are in pre_ids - # For accept_idx=0: step_idx_now + 0 + 1 = 9 >= stop_seq_len=3, so we check - # i=2: pre_ids_idx = 8+0-(3-1-2) = 8 -> pre_ids_now[8] = 70 - # i=1: pre_ids_idx = 8+0-(3-1-1) = 7 -> pre_ids_now[7] = 60 - # i=0: pre_ids_idx = 8+0-(3-1-0) = 6 -> pre_ids_now[6] = 50 - inputs["token_ids_all"][0, 6] = 50 - inputs["token_ids_all"][0, 7] = 60 - inputs["token_ids_all"][0, 8] = 70 - inputs["accept_tokens"][0, :3] = [1, 2, 3] - inputs["stop_seqs"][0, 0, :3] = [50, 60, 70] - inputs["stop_seqs_len"][0, 0] = 3 + # stop_seq partially in pre_ids, partially in accept_tokens + # For accept_idx=1: step_idx_now - accept_num + 1 + 1 = 8-3+1+1 = 7 >= stop_seq_len=4, so we check + # i=3: stop_seq_len-1-i=0 <= accept_idx(1) -> accept_tokens[1-0] = accept_tokens[1] = 22 + # i=2: stop_seq_len-1-i=1 <= accept_idx(1) -> accept_tokens[1-1] = accept_tokens[0] = 11 + # i=1: stop_seq_len-1-i=2 > accept_idx(1) -> pre_ids_idx = 8-3+1-(4-1-1) = 6-2 = 4 -> pre_ids[4] = 60 + # i=0: stop_seq_len-1-i=3 > accept_idx(1) -> pre_ids_idx = 8-3+1-(4-1-0) = 6-3 = 3 -> pre_ids[3] = 50 + inputs["token_ids_all"][0, 3] = 50 + inputs["token_ids_all"][0, 4] = 60 + inputs["accept_tokens"][0, :3] = [11, 22, 3] + inputs["stop_seqs"][0, 0, :4] = [50, 60, 11, 22] + inputs["stop_seqs_len"][0, 0] = 4 inputs["stop_flags"][:] = False inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) - self.assertEqual(outputs["accept_num"][0], 1) + # Match at accept_idx=1, loop increments to 2 + self.assertEqual(outputs["accept_num"][0], 2) def test_already_stopped(self): """Kernel skips sequences with stop_flags=True."""