diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index d81cf9989a..e2acb43f03 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -880,7 +880,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& stop_flags, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& is_paused, - const paddle::Tensor& mask_rollback, const paddle::Tensor& token_ids_all, const paddle::Tensor& prompt_lens, const paddle::Tensor& step_idx, @@ -1146,10 +1145,8 @@ std::vector UpdateAttnMaskOffsets( const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& attn_mask_offsets_full, - const paddle::Tensor& attn_mask_offsets_decoder, const paddle::Tensor& is_block_step, - const paddle::Tensor& decode_states, - const paddle::Tensor& mask_rollback); + const paddle::Tensor& decode_states); std::vector FusedNeoxRopeEmbedding( const paddle::Tensor& qkv, diff --git a/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu b/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu index 07865b8835..8bc73bfffc 100644 --- a/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu +++ b/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu @@ -41,7 +41,6 @@ template __global__ void unified_update_model_status_kernel(int *seq_lens_encoder, int *seq_lens_decoder, bool *has_running_seqs, - int *mask_rollback, int64_t *step_input_ids, int64_t *step_output_ids, int *step_output_len, @@ -100,16 +99,12 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder, cur_seq_len_encoder = 0; } else if (cur_seq_len_decoder > 0) { cur_seq_len_decoder += output_len; - mask_rollback[batch_id] = seq_lens_this_time[batch_id] - output_len; - } else { - mask_rollback[batch_id] = 0; } if (cur_stop_flag) { // It should clear seq_lens_decoder in next step for save_output stop_flag_int = 1; stop_flags[batch_id] = true; - mask_rollback[batch_id] = 0; } // 4. Update model status @@ -174,7 +169,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &stop_flags, const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &is_paused, - const paddle::Tensor &mask_rollback, const paddle::Tensor &token_ids_all, const paddle::Tensor &prompt_lens, const paddle::Tensor &step_idx, @@ -203,7 +197,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder, const_cast(seq_lens_encoder.data()), const_cast(seq_lens_decoder.data()), const_cast(has_running_seqs_gpu.data()), - const_cast(mask_rollback.data()), const_cast(step_input_ids.data()), const_cast(step_output_ids.data()), const_cast(step_output_len.data()), @@ -237,7 +230,6 @@ PD_BUILD_STATIC_OP(unified_update_model_status) "stop_flags", "seq_lens_this_time", "is_paused", - "mask_rollback", "token_ids_all", "prompt_lens", "step_idx", @@ -251,7 +243,6 @@ PD_BUILD_STATIC_OP(unified_update_model_status) "step_output_len_out", "stop_flags_out", "seq_lens_this_time_out", - "mask_rollback_out", "token_ids_all_out", "step_idx_out"}) .SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"}, @@ -262,7 +253,6 @@ PD_BUILD_STATIC_OP(unified_update_model_status) {"step_output_len", "step_output_len_out"}, {"stop_flags", "stop_flags_out"}, {"seq_lens_this_time", "seq_lens_this_time_out"}, - {"mask_rollback", "mask_rollback_out"}, {"token_ids_all", "token_ids_all_out"}, {"step_idx", "step_idx_out"}}) .SetKernelFn(PD_KERNEL(UnifiedUpdateModelStatus)); diff --git a/custom_ops/gpu_ops/speculate_decoding/verify_draft_tokens.cu b/custom_ops/gpu_ops/speculate_decoding/verify_draft_tokens.cu index 5eb58e8ee2..eda2143c27 100644 --- a/custom_ops/gpu_ops/speculate_decoding/verify_draft_tokens.cu +++ b/custom_ops/gpu_ops/speculate_decoding/verify_draft_tokens.cu @@ -129,52 +129,11 @@ struct VerifyContext { output_len_now++; } - // TOPP-only: verify_window bulk-accept fallback. - // - // When draft token is NOT in top-p set but IS the top-2 token, - // check verify_window consecutive positions for top-1 match. - // If all match, bulk-accept from position i through ii. - // - // Returns the new loop position (i) after handling. - // Sets *rejected=true if fallback was not triggered (caller should break). - __device__ __forceinline__ int try_verify_window_fallback( - int i, - bool *rejected, - const int64_t *verify_tokens_now, - int seq_len_this_time, - int max_candidate_len, - int verify_window) { - int ii = i; - if (max_candidate_len >= 2 && - verify_tokens_now[ii * max_candidate_len + 1] == - step_input_ids_now[ii + 1]) { - // top-2 matches — scan verify_window consecutive top-1 matches - int j = 0; - ii += 1; - for (; j < verify_window && ii < seq_len_this_time - 1; j++, ii++) { - if (verify_tokens_now[ii * max_candidate_len] != - step_input_ids_now[ii + 1]) { - break; - } - } - if (j >= verify_window) { - // Bulk accept all tokens from i to ii - for (; i < ii; i++) { - if (emit_token(i, step_input_ids_now[i + 1])) return i; - } - return i; // continue outer loop from position ii - } - } - // Fallback not triggered or insufficient window — reject - *rejected = true; - return i; - } + // ============================================================ + // Phase 2 helpers — sample token for rejected/last position + // ============================================================ }; -// ============================================================ -// Phase 2 helpers — sample token for rejected/last position -// ============================================================ - __device__ inline int64_t topp_sampling_kernel(const int64_t *candidate_ids, const float *candidate_scores, curandState_t *curand_states, @@ -193,7 +152,6 @@ __device__ inline int64_t topp_sampling_kernel(const int64_t *candidate_ids, } return candidate_ids[0]; } - // ============================================================ // Main verification kernel // ============================================================ @@ -201,8 +159,8 @@ __device__ inline int64_t topp_sampling_kernel(const int64_t *candidate_ids, // Input parameter groups by strategy: // - target_tokens: GREEDY=argmax, TARGET_MATCH=sampled, TOPP=unused // (None) -// - candidate_ids/scores: TOPP=full candidate set, GREEDY/TARGET_MATCH=unused -// (None) +// - candidate_ids/scores: TOPP=full candidate set, +// GREEDY/TARGET_MATCH=unused (None) // - candidate_lens: TOPP=actual length per position, // GREEDY/TARGET_MATCH=unused (None) // @@ -250,6 +208,9 @@ __global__ void verify_draft_tokens( // Initialize step_output_len to 0 for ALL slots if (bid < max_bsz) { step_output_len[bid] = 0; + for (int i = 0; i < max_step_tokens; i++) { + step_output_ids[bid * max_step_tokens + i] = -1; + } } else { return; } @@ -307,17 +268,6 @@ __global__ void verify_draft_tokens( accepted = verify_one_topp(candidate_ids_now + i * max_candidate_len, ctx.step_input_ids_now[i + 1], actual_cand_len); - if (!accepted) { - bool rejected = false; - i = ctx.try_verify_window_fallback(i, - &rejected, - candidate_ids_now, - seq_lens_this_time[bid], - max_candidate_len, - verify_window); - if (ctx.stopped || rejected) goto phase1_done; - continue; // bulk accept succeeded, continue from new i - } break; } case 1: // GREEDY @@ -333,7 +283,6 @@ __global__ void verify_draft_tokens( break; // reject } } -phase1_done: // ======== Phase 2: Output token for rejected/last position ======== if (!ctx.stopped) { diff --git a/custom_ops/gpu_ops/update_attn_mask_offsets.cu b/custom_ops/gpu_ops/update_attn_mask_offsets.cu index f928f7d1e4..d59556407c 100644 --- a/custom_ops/gpu_ops/update_attn_mask_offsets.cu +++ b/custom_ops/gpu_ops/update_attn_mask_offsets.cu @@ -21,10 +21,8 @@ __global__ void update_attn_mask_offsets_kernel( const int* seq_lens_decoder, const int* cu_seqlens_q, const int* attn_mask_offsets_full, - int* attn_mask_offsets_decoder, const bool* is_block_step, int* decode_states, - int* mask_rollback, const int real_bsz, const int max_model_len, const int decode_states_len) { @@ -55,15 +53,10 @@ __global__ void update_attn_mask_offsets_kernel( } } } else if (seq_len_decoder > 0) { - // Status: decoder -- normal or chunk_prefill - // TODO: support speculative decoding. - attn_mask_offsets_decoder[bid] -= mask_rollback[bid]; - mask_rollback[bid] = 0; for (int i = 0; i < seq_len_this_time; i++) { attn_mask_offsets[(query_start_id + i) * 2 + 1] = - attn_mask_offsets_decoder[bid] + 1 + i; + seq_len_decoder + 1 + i; } - attn_mask_offsets_decoder[bid] += seq_len_this_time; // Speculative decoding in text_generation for (int i = 0; i < decode_states_len; i++) { @@ -85,10 +78,8 @@ std::vector UpdateAttnMaskOffsets( const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& attn_mask_offsets_full, - const paddle::Tensor& attn_mask_offsets_decoder, const paddle::Tensor& is_block_step, - const paddle::Tensor& decode_states, - const paddle::Tensor& mask_rollback) { + const paddle::Tensor& decode_states) { int max_model_len = attn_mask_offsets_full.shape()[1]; int real_bsz = seq_lens_this_time.shape()[0]; int batch_seq_lens = ids_remove_padding.shape()[0]; @@ -112,10 +103,8 @@ std::vector UpdateAttnMaskOffsets( seq_lens_decoder.data(), cu_seqlens_q.data(), attn_mask_offsets_full.data(), - const_cast(attn_mask_offsets_decoder.data()), is_block_step.data(), const_cast(decode_states.data()), - const_cast(mask_rollback.data()), real_bsz, max_model_len, decode_states_len); @@ -130,11 +119,8 @@ PD_BUILD_STATIC_OP(update_attn_mask_offsets) "seq_lens_decoder", "cu_seqlens_q", "attn_mask_offsets_full", - "attn_mask_offsets_decoder", "is_block_step", - "decode_states", - "mask_rollback"}) - .Outputs({"attn_mask_offsets", "decode_states_out", "mask_rollback_out"}) - .SetInplaceMap({{"decode_states", "decode_states_out"}, - {"mask_rollback", "mask_rollback_out"}}) + "decode_states"}) + .Outputs({"attn_mask_offsets", "decode_states_out"}) + .SetInplaceMap({{"decode_states", "decode_states_out"}}) .SetKernelFn(PD_KERNEL(UpdateAttnMaskOffsets)); diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 2002f47550..c2cab49913 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -515,7 +515,6 @@ def post_process_specualate( model_output.stop_flags, # stop_flags (read-write) model_output.seq_lens_this_time, # seq_lens_this_time model_output.is_block_step, # is_paused - model_output.mask_rollback, # mask_rollback model_output.token_ids_all, # token_ids_all model_output.prompt_lens, # prompt_lens model_output.step_idx, # step_idx (read-write) diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 5868d0bff0..e82c9e9e5b 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -510,9 +510,12 @@ class MTPProposer(Proposer): inputs["attention_mask_offset"][prefill_start_index:prefill_end_index], dtype="int32" ) ) - self.model_inputs["attn_mask_offsets_decoder"][idx : idx + 1] = ( - inputs["attention_mask_offset"][prefill_end_index - 1] + 1 - ) + # GPU don't need it anymore + # NOTE: XPU backend needs decoder attention mask offset; GPU backend does not use it + if current_platform.is_xpu(): + self.model_inputs["attn_mask_offsets_decoder"][idx : idx + 1] = ( + inputs["attention_mask_offset"][prefill_end_index - 1] + 1 + ) if ( self.fd_config.scheduler_config.splitwise_role == "decode" ): # In PD, we continue to decode after P generates first token @@ -895,10 +898,8 @@ class MTPProposer(Proposer): self.model_inputs["seq_lens_decoder"], cu_seqlens_q, self.model_inputs["attn_mask_offsets_full"], - self.model_inputs["attn_mask_offsets_decoder"], self.model_inputs["is_block_step"], self.model_inputs["decode_states"], - self.model_inputs["mask_rollback"], ) self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False) diff --git a/tests/operators/test_unified_update_model_status.py b/tests/operators/test_unified_update_model_status.py index 03498f52fc..933acd7bd5 100644 --- a/tests/operators/test_unified_update_model_status.py +++ b/tests/operators/test_unified_update_model_status.py @@ -27,9 +27,8 @@ Kernel semantics (from unified_update_model_status.cu): replace non-EOS end token with end_tokens[0], set cur_stop_flag=true. 2. seq_lens update (always executed, even when EOS is hit): encoder > 0 → decoder += encoder, encoder = 0 - decoder > 0 → decoder += output_len, mask_rollback = seq_lens_this_time - output_len - else → mask_rollback = 0 - 3. If cur_stop_flag (EOS hit): stop_flags=true, mask_rollback=0. + decoder > 0 → decoder += output_len + 3. If cur_stop_flag (EOS hit): stop_flags=true. 4. Write back: seq_lens_encoder, seq_lens_decoder, step_output_len, step_idx. 5. Write history to token_ids_all at [prompt_len + base + i] (forward loop). 6. Set step_input_ids[0] = last output token. @@ -81,7 +80,6 @@ def run_kernel(paddle_inputs: Dict[str, Any]): paddle_inputs["stop_flags"], paddle_inputs["seq_lens_this_time"], paddle_inputs["is_paused"], - paddle_inputs["mask_rollback"], paddle_inputs["token_ids_all"], paddle_inputs["prompt_lens"], paddle_inputs["step_idx"], @@ -90,7 +88,7 @@ def run_kernel(paddle_inputs: Dict[str, Any]): ) -# All 11 in-place output keys (from SetInplaceMap in .cu) +# All 10 in-place output keys (from SetInplaceMap in .cu) OUTPUT_KEYS = [ "seq_lens_encoder", "seq_lens_decoder", @@ -100,7 +98,6 @@ OUTPUT_KEYS = [ "step_output_len", "stop_flags", "seq_lens_this_time", - "mask_rollback", "token_ids_all", "step_idx", ] @@ -145,7 +142,6 @@ def gen_inputs( # seq_lens_this_time[batch_id] which is only sized real_bsz stop_flags[real_bsz:] = True is_paused = np.zeros(max_bsz, dtype=bool) - mask_rollback = np.zeros(max_bsz, dtype=np.int32) prompt_lens = rng.integers(10, 50, size=max_bsz, dtype=np.int64) token_ids_all = rng.integers(0, 1000, size=(max_bsz, max_model_len), dtype=np.int64) step_idx = rng.integers(0, 50, size=max_bsz, dtype=np.int64) @@ -168,7 +164,6 @@ def gen_inputs( "stop_flags": stop_flags, "seq_lens_this_time": seq_lens_this_time, "is_paused": is_paused, - "mask_rollback": mask_rollback, "token_ids_all": token_ids_all, "prompt_lens": prompt_lens, "step_idx": step_idx, @@ -198,7 +193,6 @@ def reference_impl(inputs: Dict[str, Any]) -> Dict[str, Any]: step_output_len = inputs["step_output_len"].copy() stop_flags = inputs["stop_flags"].copy() seq_lens_this_time = inputs["seq_lens_this_time"].copy() - mask_rollback = inputs["mask_rollback"].copy() token_ids_all = inputs["token_ids_all"].copy() step_idx = inputs["step_idx"].copy() step_input_ids = inputs["step_input_ids"].copy() @@ -256,14 +250,10 @@ def reference_impl(inputs: Dict[str, Any]) -> Dict[str, Any]: cur_seq_len_encoder = 0 elif cur_seq_len_decoder > 0: cur_seq_len_decoder += output_len - mask_rollback[batch_id] = int(seq_lens_this_time[batch_id]) - output_len - else: - mask_rollback[batch_id] = 0 if cur_stop_flag: stop_count += 1 stop_flags[batch_id] = True - mask_rollback[batch_id] = 0 # Write back scalar state seq_lens_encoder[batch_id] = cur_seq_len_encoder @@ -308,7 +298,6 @@ def reference_impl(inputs: Dict[str, Any]) -> Dict[str, Any]: "step_output_len": step_output_len, "stop_flags": stop_flags, "seq_lens_this_time": seq_lens_this_time, - "mask_rollback": mask_rollback, "token_ids_all": token_ids_all, "step_idx": step_idx, } @@ -495,19 +484,6 @@ class TestUnifiedUpdateModelStatus(unittest.TestCase): outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) - def test_mask_rollback(self): - """mask_rollback = seq_lens_this_time - output_len for running decode slots.""" - inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42) - inputs["stop_flags"][: inputs["real_bsz"]] = False - inputs["is_paused"][:] = False - inputs["seq_lens_encoder"][:] = 0 # All decode slots - inputs["seq_lens_this_time"][:] = [6, 4, 8, 3] - inputs["step_output_len"][:] = [3, 2, 5, 1, 0, 0, 0, 0] - inputs["end_tokens"][:] = [9990, 9991, 9992, 9993] - inputs["max_dec_len"][:] = 10000 - outputs = self._run_and_get(inputs) - self._check_all_outputs(inputs, outputs) - if __name__ == "__main__": unittest.main() diff --git a/tests/operators/test_update_attn_mask.py b/tests/operators/test_update_attn_mask.py index 60c17f3aa2..b320e1cc5e 100644 --- a/tests/operators/test_update_attn_mask.py +++ b/tests/operators/test_update_attn_mask.py @@ -31,10 +31,8 @@ def py_update_attn_mask_offsets_op( seq_lens_decoder, cu_seqlens_q, attn_mask_offsets_full, - attn_mask_offsets_decoder, is_block_step, decode_states, - mask_rollback, ): """ Python-side reference op that mirrors the CUDA kernel you provided (latest version). @@ -42,10 +40,8 @@ def py_update_attn_mask_offsets_op( - seq_lens_*: 1D numpy int32 arrays (len == bsz) - cu_seqlens_q: 1D numpy int32 prefix sums (len == bsz) - attn_mask_offsets_full: numpy array shape (bsz, max_model_len) - - attn_mask_offsets_decoder: 1D numpy int32 (bsz,) - is_block_step: 1D bool array (bsz,) - decode_states: numpy int32 array shape (bsz, decode_states_len) - - mask_rollback: 1D numpy int32 (bsz,) or shape (bsz,1) Returns: attn_mask_offsets_ref (1D int32 length batch_seq_lens * 2), decode_states_ref (bsz x decode_states_len int32) @@ -57,9 +53,7 @@ def py_update_attn_mask_offsets_op( cu_seqlens_q = np.array(cu_seqlens_q, dtype=np.int32).reshape(-1) is_block_step = np.array(is_block_step, dtype=bool).reshape(-1) attn_mask_offsets_full = np.array(attn_mask_offsets_full, dtype=np.int32) - attn_mask_offsets_decoder = np.array(attn_mask_offsets_decoder, dtype=np.int32).reshape(-1) decode_states = np.array(decode_states, dtype=np.int32).copy() - mask_rollback = np.array(mask_rollback, dtype=np.int32).reshape(-1) bsz = int(seq_lens_this_time.shape[0]) total_seq = int(np.sum(seq_lens_this_time)) @@ -100,16 +94,8 @@ def py_update_attn_mask_offsets_op( # decoder path (seq_len_decoder > 0) if seq_len_dec > 0: - # subtract mask rollback - rollback = int(mask_rollback[bid]) if bid < mask_rollback.shape[0] else 0 - attn_mask_offsets_decoder[bid] = int(attn_mask_offsets_decoder[bid]) - rollback - start = int(attn_mask_offsets_decoder[bid]) - for i in range(seq_len_this): - attn_mask_offsets[(query_start + i) * 2 + 1] = start + 1 + i - - # advance decoder offset - attn_mask_offsets_decoder[bid] = int(attn_mask_offsets_decoder[bid]) + seq_len_this + attn_mask_offsets[(query_start + i) * 2 + 1] = seq_len_dec + 1 + i # speculative decoding: if seq_len_this > 1 then set decode_states_now[i] accordingly for i in range(decode_states_len): @@ -152,16 +138,11 @@ class UpdateAttnMaskOffsetsTestCase(unittest.TestCase): # attn_mask_offsets_full: shape (bsz, max_model_len) attn_mask_offsets_full = np.arange(bsz * max_model_len, dtype=np.int32).reshape(bsz, max_model_len) - # attn_mask_offsets_decoder initial (use seq_lens_decoder as seed for deterministic test) - attn_mask_offsets_decoder = np.array(seq_lens_decoder, dtype=np.int32).copy() - # decode_states initial decode_states = np.full((bsz, decode_states_len), -1, dtype=np.int32) if vision_generate: decode_states[:, 0] = 2 # make first element 2 to trigger vision phase - mask_rollback = np.zeros((bsz,), dtype=np.int32) - # ids_remove_padding: length = total_seq (only length used by op) ids_remove_padding = paddle.randint(low=0, high=10, shape=[total_seq], dtype="int32") decode_states_tensor = paddle.to_tensor(decode_states, dtype="int32") @@ -173,10 +154,8 @@ class UpdateAttnMaskOffsetsTestCase(unittest.TestCase): paddle.to_tensor(seq_lens_decoder, dtype="int32"), paddle.to_tensor(cu_seqlens_q, dtype="int32"), paddle.to_tensor(attn_mask_offsets_full, dtype="int32"), - paddle.to_tensor(attn_mask_offsets_decoder, dtype="int32"), paddle.to_tensor(np.array(is_block_step, dtype=bool).reshape(-1), dtype="bool"), decode_states_tensor, - paddle.to_tensor(mask_rollback, dtype="int32"), ) # op returns [attn_mask_offsets, decode_states_out] per your PD_BUILD_STATIC_OP outputs @@ -200,10 +179,8 @@ class UpdateAttnMaskOffsetsTestCase(unittest.TestCase): seq_lens_decoder=seq_lens_decoder, cu_seqlens_q=cu_seqlens_q, attn_mask_offsets_full=attn_mask_offsets_full, - attn_mask_offsets_decoder=attn_mask_offsets_decoder.copy(), is_block_step=np.array(is_block_step, dtype=bool).reshape(-1), decode_states=decode_states.copy(), - mask_rollback=mask_rollback, ) # optionally print debug if env var set @@ -266,7 +243,7 @@ class UpdateAttnMaskOffsetsTestCase(unittest.TestCase): ) def test_decoder_case(self): - # decoder path: should write attn_mask_offsets_decoder - rollback + 1 .. +seq_len_this_time-1 + # decoder path: should write seq_len_decoder + 1 .. + seq_len_this_time - 1 self._call_and_compare( seq_lens_this_time=[2], seq_lens_encoder=[0], diff --git a/tests/operators/test_verify_draft_tokens.py b/tests/operators/test_verify_draft_tokens.py index 2ad7942fc8..8841d65c8c 100644 --- a/tests/operators/test_verify_draft_tokens.py +++ b/tests/operators/test_verify_draft_tokens.py @@ -218,6 +218,10 @@ class _VerifyContext: self.output_len_now += 1 +# NOTE: try_verify_window_fallback was removed from the CUDA kernel. +# TOPP strategy now rejects immediately when draft token is not in candidate set. + + def verify_draft_tokens_ref( step_output_ids, step_output_len, @@ -251,6 +255,8 @@ def verify_draft_tokens_ref( dev_curand_states = [random.Random(0).random() for _ in range(max_step_tokens)] step_output_ids_flat = step_output_ids.reshape(-1) + # Kernel initializes step_output_ids to -1 for all slots + step_output_ids_flat[:] = -1 step_input_ids_flat = step_input_ids.reshape(-1) candidate_ids_flat = candidate_ids.reshape(-1) if candidate_ids is not None else None candidate_scores_flat = candidate_scores.reshape(-1) if candidate_scores is not None else None @@ -302,29 +308,6 @@ def verify_draft_tokens_ref( step_input_ids_now[i + 1], actual_cand_len, ) - if not accepted: - # verify_window fallback - ii = i - if ( - max_candidate_len >= 2 - and candidate_ids_now[ii * max_candidate_len + 1] == step_input_ids_now[ii + 1] - ): - j, ii = 0, ii + 1 - while j < verify_window and ii < seq_lens_this_time[bid] - 1: - if candidate_ids_now[ii * max_candidate_len] != step_input_ids_now[ii + 1]: - break - j += 1 - ii += 1 - if j >= verify_window: - for k in range(i, ii): - if ctx.emit_token(k, step_input_ids_now[k + 1]): - i = k - break - if ctx.stopped: - break - i = ii - continue - break elif verify_strategy in (1, 2): # GREEDY / TARGET_MATCH accepted = target_tokens_now[i] == step_input_ids_now[i + 1] @@ -703,65 +686,6 @@ class TestVerifyDraftTokens(unittest.TestCase): with self.assertRaises(ValueError): VerifyStrategy.from_string("invalid") - def test_topp_verify_window_fallback(self): - """Test TOPP verify_window fallback: top-2 match + consecutive top-1 matches.""" - real_bsz, max_draft_tokens, max_candidate_len, verify_window = 1, 8, 4, 2 - - inputs = gen_verify_draft_tokens_inputs( - real_bsz=real_bsz, - max_draft_tokens=max_draft_tokens, - verify_strategy=VerifyStrategy.TOPP.value, - max_candidate_len=max_candidate_len, - verify_window=verify_window, - seed=42, - ) - - # Rebuild arrays for full seq_lens_this_time - new_slt = max_draft_tokens + 1 - inputs["seq_lens_this_time"] = np.array([new_slt], dtype=np.int32) - inputs["cu_seqlens_q_output"] = np.array([0], dtype=np.int32) - - rng = np.random.default_rng(42) - sum_seq = new_slt - inputs["candidate_ids"] = rng.integers(0, 1000, size=(sum_seq, max_candidate_len), dtype=np.int64) - inputs["candidate_scores"] = rng.random(size=(sum_seq, max_candidate_len)).astype(np.float32) - inputs["candidate_scores"] /= inputs["candidate_scores"].sum(axis=1, keepdims=True) - inputs["candidate_lens"] = rng.integers(1, max_candidate_len + 1, size=sum_seq, dtype=np.int32) - - # Draft tokens - draft_tokens = [100, 200, 300, 400, 500, 600, 700] - for i, token in enumerate(draft_tokens): - inputs["step_input_ids"][0, i + 1] = token - - # Position 0: draft NOT in candidates, but top-2 matches draft - inputs["candidate_ids"][0] = [999, 100, 998, 997] - # Positions 1,2: top-1 matches next draft tokens - inputs["candidate_ids"][1] = [200, 888, 777, 666] - inputs["candidate_ids"][2] = [300, 555, 444, 333] - inputs["candidate_lens"][:3] = 4 - inputs["is_block_step"] = np.zeros(real_bsz, dtype=bool) - - self._run_and_compare(inputs, label="verify_window_fallback") - - def test_topp_verify_window_no_fallback(self): - """Test TOPP when verify_window fallback does NOT trigger.""" - inputs = gen_verify_draft_tokens_inputs( - real_bsz=1, - max_draft_tokens=5, - verify_strategy=VerifyStrategy.TOPP.value, - max_candidate_len=4, - verify_window=2, - seed=42, - ) - - inputs["step_input_ids"][0, 1:] = [999, 998, 997, 996] - inputs["candidate_ids"][:] = 0 - inputs["candidate_ids"][0] = [1, 2, 3, 4] - inputs["candidate_lens"][0] = 4 - inputs["seq_lens_this_time"][0] = 5 - - self._run_and_compare(inputs, label="verify_window_no_fallback") - if __name__ == "__main__": unittest.main() diff --git a/tests/spec_decode/test_mtp_proposer.py b/tests/spec_decode/test_mtp_proposer.py index fe61862770..179dccfe26 100644 --- a/tests/spec_decode/test_mtp_proposer.py +++ b/tests/spec_decode/test_mtp_proposer.py @@ -294,7 +294,6 @@ class TestMTPProposer(unittest.TestCase): proposer.enable_mm = True request1.multimodal_inputs = {"attention_mask_offset": [0, 1, 2, 3, 4]} proposer.model_inputs["attn_mask_offsets_full"] = paddle.zeros([2, 2048], dtype="int32") - proposer.model_inputs["attn_mask_offsets_decoder"] = paddle.zeros([2, 1], dtype="int32") proposer.insert_tasks_v1([request1], 1) @patch("fastdeploy.spec_decode.mtp.get_model_loader") @@ -612,7 +611,6 @@ class TestMTPProposer(unittest.TestCase): ) self.assertIn("attn_mask_offsets", proposer.model_inputs) self.assertIn("attn_mask_offsets_full", proposer.model_inputs) - self.assertIn("attn_mask_offsets_decoder", proposer.model_inputs) @patch("fastdeploy.spec_decode.mtp.get_model_loader") @patch("fastdeploy.spec_decode.mtp.get_attention_backend")