diff --git a/.claude/skills/cuda-kernel-unittest.md b/.claude/skills/cuda-kernel-unittest.md new file mode 100644 index 0000000000..600708f578 --- /dev/null +++ b/.claude/skills/cuda-kernel-unittest.md @@ -0,0 +1,174 @@ +# Skill: CUDA Kernel Unit Test + +Write unit tests for PaddlePaddle CUDA custom ops following a modular 4-layer architecture. + +## Trigger + +When the user asks to write/create/add unit tests for a CUDA kernel (`.cu` file in `custom_ops/`). + +## Steps + +1. **Read the CUDA kernel source** to understand: input/output tensors, dtypes, shapes, which tensors are CPU vs GPU, scalar attrs, in-place semantics. +2. **Write the test file** in `tests/operators/test_.py` following the structure below. + +## Test File Structure + +```python +import unittest +from typing import Any, Dict +import numpy as np +import paddle + +# --- Import ops (bypass fastdeploy.__init__) --- +try: + import sys, os + _fd_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + if _fd_root not in sys.path: + sys.path.insert(0, _fd_root) + from fastdeploy.import_ops import import_custom_ops + _package = "fastdeploy.model_executor.ops.gpu" + import_custom_ops(_package, ".fastdeploy_ops", globals()) +except ImportError as e: + print(f"Import error: {e}") + raise + +CUDA_PLACE = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace() +CPU_PLACE = paddle.CPUPlace() + + +# ============================================================ +# Layer 1: Helpers — tensor creation / kernel invocation / output extraction +# ============================================================ + +def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]: + """Convert numpy dict → paddle tensors. 
CPU tensors must be explicitly handled.""" + paddle_inputs = {} + for k, v in inputs.items(): + if isinstance(v, (int, bool, float, str)): + paddle_inputs[k] = v + elif k in ("",): # <-- tensors the kernel expects on CPU + paddle_inputs[k] = paddle.to_tensor(v, place=CPU_PLACE) + elif v is not None: + paddle_inputs[k] = paddle.to_tensor(v, place=CUDA_PLACE) + else: + paddle_inputs[k] = None + return paddle_inputs + +def run_kernel(paddle_inputs, inputs): + """Call the CUDA kernel with paddle tensors + scalar attrs.""" + kernel_name( + paddle_inputs["tensor_a"], + # ... all tensor args ... + inputs["scalar_attr"], # scalar attrs from raw dict + ) + +def get_outputs(paddle_inputs) -> Dict[str, np.ndarray]: + """Extract ALL in-place-modified tensors back to numpy.""" + keys = ["tensor_a", "tensor_b", ...] + return {k: paddle_inputs[k].numpy() for k in keys} + + +# ============================================================ +# Layer 2: Input generation +# ============================================================ + +def gen__inputs(real_bsz=8, ..., seed=42) -> Dict[str, Any]: + """Generate randomized test inputs. Returns dict with both numpy arrays and scalar configs.""" + rng = np.random.default_rng(seed) + # ... generate all numpy arrays with correct dtypes/shapes ... + return { "tensor_a": ..., "scalar_attr": ..., "real_bsz": real_bsz, ... } + + +# ============================================================ +# Layer 3: Reference implementation (pure Python/NumPy) +# ============================================================ + +def reference_(inputs: Dict[str, Any]) -> Dict[str, Any]: + """Python reference — must match CUDA kernel logic exactly.""" + # Deep-copy all mutable arrays + tensor_a = inputs["tensor_a"].copy() + # ... replicate kernel logic ... 
+ return {"tensor_a": tensor_a, ...} + + +# ============================================================ +# Layer 4a: TEST_CONFIGS — all pure-parameter test scenarios +# ============================================================ + +TEST_CONFIGS = [ + # Each config is a dict of gen__inputs kwargs + a "name" key. + # Pure parameter variations go here — do NOT create separate test methods for them. + # + # --- basic coverage --- + {"name": "small_batch", "real_bsz": 1, "seed": 42, ...}, + {"name": "large_batch", "real_bsz": 64, "seed": 42, ...}, + # --- mode / strategy variants --- + {"name": "mode_a", "real_bsz": 8, "mode": "a", "seed": 42, ...}, + {"name": "mode_b", "real_bsz": 8, "mode": "b", "seed": 42, ...}, + # --- flags --- + {"name": "reject_all", "real_bsz": 8, "reject_all": True, "seed": 42, ...}, + {"name": "accept_all", "real_bsz": 8, "accept_all": True, "seed": 42, ...}, + # --- edge cases --- + {"name": "min_batch", "real_bsz": 1, "max_tokens": 1, "seed": 42, ...}, +] + + +# ============================================================ +# Layer 4b: Test suite +# ============================================================ + +class Test(unittest.TestCase): + + # ------ shared helpers ------ + + def _run_and_get(self, inputs): + paddle_inputs = to_paddle_inputs(inputs) + run_kernel(paddle_inputs, inputs) + return get_outputs(paddle_inputs) + + def _check_all_outputs(self, inputs, outputs): + """Compare ALL output tensors against reference + sanity checks.""" + ref = reference_(inputs) + all_keys = ["tensor_a", "tensor_b", ...] 
+ for key in all_keys: + np.testing.assert_array_equal( + outputs[key], ref[key], err_msg=f"{key} mismatch" + ) + # Add domain-specific sanity checks here + + def _run_full_test(self, config): + inputs = gen__inputs(**config) + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + return outputs + + # ------ test cases ------ + + def test_configs(self): + """Run all TEST_CONFIGS via subTest (one subTest per config).""" + for cfg in TEST_CONFIGS: + with self.subTest(name=cfg["name"]): + test_cfg = {k: v for k, v in cfg.items() if k != "name"} + self._run_full_test(test_cfg) + + # Only keep separate test methods for scenarios that need tensor overrides: + def test_special_scenario(self): + """Scenarios that need manual tensor setup beyond gen_inputs params.""" + inputs = gen__inputs(real_bsz=2, seed=42) + inputs["some_tensor"][0, 2] = special_value # override specific tensor + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + +if __name__ == "__main__": + unittest.main() +``` + +## Key Rules + +1. **CPU vs GPU tensors**: Read the CUDA kernel `.cu` file carefully. If a tensor is `copy_to(place, false)` inside the host function, it's a CPU tensor input — must use `CPU_PLACE` in `to_paddle_inputs`. +2. **`_check_all_outputs` checks ALL tensors**: Every in-place-modified output tensor must be compared against reference. Never scatter `assertEqual`/`assertTrue` across individual test methods — all checks go through `_check_all_outputs`. +3. **Stochastic kernels**: If the kernel uses `curand` (e.g., top-p sampling), compare only deterministic positions. Skip the last sampled token in `compare_results`. Note: `curand_states` in reference should be sized to `max_step_tokens` (position count), not `bsz` (batch count). +4. **TEST_CONFIGS for pure-parameter scenarios**: Any test that only differs by `gen_inputs` parameters belongs in `TEST_CONFIGS`, not a separate `test_*` method. 
Only create separate methods when you need to **override specific tensor values** after generation. +5. **Test cases are thin**: Each `test_*` method should be 3-15 lines. It either calls `_run_full_test(config)` or does `gen → override → _run_and_get → _check_all_outputs`. +6. **No `fastdeploy.__init__`**: Import ops via `import_custom_ops` directly to avoid heavy dependency chain. +7. **Padding slots**: Kernel may have `max_bsz > real_bsz`. Reference impl must handle padding slots the same way as the kernel (typically no-op or stop_count++). diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 56e0d9cbee..2b2cfe90f5 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -800,6 +800,29 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor& accept_tokens, const paddle::Tensor& end_ids, const paddle::Tensor& min_tokens); +void VerifyDraftTokens(const paddle::Tensor& step_output_ids, + const paddle::Tensor& step_output_len, + const paddle::Tensor& step_input_ids, + const paddle::optional& target_tokens, + const paddle::optional& candidate_ids, + const paddle::optional& candidate_scores, + const paddle::optional& candidate_lens, + const paddle::Tensor& topp, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& end_tokens, + const paddle::Tensor& is_block_step, + const paddle::Tensor& cu_seqlens_q_output, + const paddle::Tensor& reasoning_status, + const paddle::Tensor& max_dec_len, + const paddle::Tensor& step_idx, + int max_seq_len, + int verify_window, + int verify_strategy, + bool reject_all, + bool accept_all); + void SpeculateVerify(const paddle::Tensor& sampled_token_ids, const paddle::Tensor& accept_tokens, const paddle::Tensor& accept_num, @@ -837,6 +860,25 @@ void SpeculateUpdate(const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& is_block_step, const paddle::Tensor& 
mask_rollback); +void UnifiedUpdateModelStatus(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& has_running_seqs, + const paddle::Tensor& step_input_ids, + const paddle::Tensor& adaptive_step_input_len, + const paddle::Tensor& step_output_ids, + const paddle::Tensor& step_output_len, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& is_paused, + const paddle::Tensor& mask_rollback, + const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& step_idx, + const paddle::Tensor& end_tokens, + const paddle::Tensor& max_dec_len, + const bool is_naive_mode, + const bool prefill_one_step_stop); + void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor& token_ids_all, const paddle::Tensor& prompt_lens, const paddle::Tensor& accept_tokens, @@ -1675,11 +1717,18 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("speculate_set_stop_value_multi_seqs", &SpecGetStopFlagsMultiSeqs, "speculate_set_stop_value_multi_seqs function"); - m.def("speculate_verify", &SpeculateVerify, "speculate_verify function"); + m.def("verify_draft_tokens", + &VerifyDraftTokens, + "verify_draft_tokens function"); + m.def("speculate_update", &SpeculateUpdate, "Speculate Update Kernel"); + m.def("unified_update_model_status", + &UnifiedUpdateModelStatus, + "unified_update_model_status function"); + m.def("speculate_set_value_by_flags_and_idx", &SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function"); diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_update.cu b/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_update.cu index f217879362..2255de39ef 100644 --- a/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_update.cu +++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_update.cu @@ -78,6 +78,7 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens, 
} // multi_end + // TODO(liuzichang): Don't check eos in future if (is_in_end(token_this_time, end_ids, end_ids_len) || prefill_one_step_stop) { stop_flags[tid] = true; diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu index 37dabeae07..76a84f30d4 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu @@ -100,6 +100,7 @@ void SpeculateGetLogits(const paddle::Tensor& draft_logits, const_cast(&cu_batch_token_offset.data()[1]), real_bsz, cu_stream); + cudaFree(temp_storage1); void* temp_storage2 = nullptr; size_t temp_storage_bytes2 = 0; @@ -118,6 +119,7 @@ void SpeculateGetLogits(const paddle::Tensor& draft_logits, const_cast(&cu_next_token_offset.data()[1]), real_bsz, cu_stream); + cudaFree(temp_storage2); constexpr int PackSize = VEC_16B / sizeof(float); dim3 grid_dim(real_bsz); @@ -184,7 +186,7 @@ void SpeculateInsertFirstToken(const paddle::Tensor& token_ids, template __global__ void speculate_get_target_logits_kernel( - float* target_logtis, + float* target_logits, const float* logits, const int* cu_batch_token_offset, const int* ori_cu_batch_token_offset, @@ -197,18 +199,18 @@ __global__ void speculate_get_target_logits_kernel( const int bid = blockIdx.x; const int tid = threadIdx.x; if (bid < real_bsz) { - auto* target_logtis_now = - target_logtis + cu_batch_token_offset[bid] * vocab_size; + auto* target_logits_now = + target_logits + cu_batch_token_offset[bid] * vocab_size; auto* logits_now = logits + ori_cu_batch_token_offset[bid] * vocab_size; for (int i = tid * VecSize; i < vocab_size; i += blockDim.x * VecSize) { if (seq_lens_encoder[bid] > 0) { Load(&logits_now[i], &src_vec); - Store(src_vec, &target_logtis_now[i]); + Store(src_vec, &target_logits_now[i]); } else { for (int j = 0; j < accept_num[bid]; j++) { Load(&logits_now[j * vocab_size + i], &src_vec); 
Store(src_vec, - &target_logtis_now[j * vocab_size + i]); + &target_logits_now[j * vocab_size + i]); } } } diff --git a/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu b/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu new file mode 100644 index 0000000000..0a16c1cd17 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu @@ -0,0 +1,298 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" + +/** + * @file unified_update_model_status.cu + * @brief Unified kernel for updating model status after token generation. + * + * Launched as a single block of 1024 threads (max_bsz <= 1024). + */ + +/** + * @brief Check if token is an end token. + */ +__device__ __forceinline__ bool is_end_token(int64_t token, + const int64_t *end_tokens, + int num_end_tokens) { +#pragma unroll 4 + for (int i = 0; i < num_end_tokens; i++) { + if (token == end_tokens[i]) return true; + } + return false; +} + +/** + * @brief Main unified update kernel. 
+ */ +template +__global__ void unified_update_model_status_kernel(int *seq_lens_encoder, + int *seq_lens_decoder, + bool *has_running_seqs, + int *mask_rollback, + int64_t *step_input_ids, + int *adaptive_step_input_len, + int64_t *step_output_ids, + int *step_output_len, + bool *stop_flags, + int *seq_lens_this_time, + const bool *is_paused, + int64_t *token_ids_all, + const int64_t *prompt_lens, + int64_t *step_idx, + const int64_t *end_tokens, + const int64_t *max_dec_len, + int real_bsz, + int max_bsz, + int max_step_tokens, + int max_model_len, + int num_end_tokens, + bool is_naive_mode, + bool prefill_one_step_stop) { + const int batch_id = blockIdx.x * BLOCK_SIZE + threadIdx.x; + const bool is_valid_slot = batch_id < max_bsz; + int stop_flag_int = 0; + + if (is_valid_slot) { + // Read state + int cur_seq_len_encoder = seq_lens_encoder[batch_id]; + int cur_seq_len_decoder = seq_lens_decoder[batch_id]; + bool cur_stop_flag = stop_flags[batch_id]; + int output_len = 0; + int64_t cur_step_idx = step_idx[batch_id]; + bool cur_is_paused = is_paused[batch_id]; + + bool is_running = !cur_stop_flag && !cur_is_paused; + + // Compute output length + if (is_running) { + if (is_naive_mode) { + output_len = 1; + } else { + output_len = step_output_len[batch_id]; + } + } + + // EOS detection + if (is_running && output_len > 0) { + bool hit_stop = false; + int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens]; + + for (int i = 0; i < output_len; i++) { + cur_step_idx++; + int64_t token = output_ids[i]; + bool is_eos = is_end_token(token, end_tokens, num_end_tokens); + bool max_len_hit = (cur_step_idx >= max_dec_len[batch_id]); + + if (is_eos || max_len_hit) { + if (!is_eos) output_ids[i] = end_tokens[0]; + output_len = i + 1; + cur_stop_flag = true; + hit_stop = true; + break; + } + } + + if (!hit_stop && prefill_one_step_stop && cur_seq_len_encoder > 0) { + cur_stop_flag = true; + } + } + + // Update state and write back + if (is_running) { + if 
(cur_stop_flag) { + stop_flag_int = 1; + if (output_len == 0) cur_seq_len_decoder = 0; + stop_flags[batch_id] = true; + mask_rollback[batch_id] = 0; + } else if (cur_seq_len_encoder == 0) { + cur_seq_len_decoder += output_len; + mask_rollback[batch_id] = seq_lens_this_time[batch_id] - output_len; + } else { + mask_rollback[batch_id] = 0; + } + + if (cur_seq_len_encoder > 0) { + cur_seq_len_decoder += cur_seq_len_encoder; + cur_seq_len_encoder = 0; + } + + seq_lens_encoder[batch_id] = cur_seq_len_encoder; + seq_lens_decoder[batch_id] = cur_seq_len_decoder; + step_output_len[batch_id] = output_len; + step_idx[batch_id] = cur_step_idx; + + // Write history to token_ids_all + if (cur_step_idx > 0 && output_len > 0) { + // Bounds check: highest write index is prompt_lens + cur_step_idx + if (prompt_lens[batch_id] + cur_step_idx < max_model_len) { + int64_t *token_ids_all_now = + &token_ids_all[batch_id * max_model_len + prompt_lens[batch_id]]; + int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens]; + for (int i = 0; i < output_len; i++) { + token_ids_all_now[cur_step_idx - i] = + output_ids[output_len - 1 - i]; + } + } + } + + // Setup next input + if (output_len > 0) { + step_input_ids[batch_id * max_step_tokens] = + step_output_ids[batch_id * max_step_tokens + output_len - 1]; + } + + if (is_naive_mode) { + seq_lens_this_time[batch_id] = cur_stop_flag ? 
0 : 1; + } + } else if (batch_id >= real_bsz) { + // Padding slot: just count as stopped, don't modify state + stop_flag_int = 1; + } else { + // Stopped or paused slot (batch_id < real_bsz) + stop_flag_int = 1; + stop_flags[batch_id] = true; + seq_lens_decoder[batch_id] = 0; + seq_lens_this_time[batch_id] = 0; + step_output_len[batch_id] = 0; + } + } + + // Simple block-level reduction using shared memory + __syncthreads(); + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + // printf("stop_flag_now_int %d \n", stop_flag_int); + int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_int); + + if (threadIdx.x == 0) { + // printf("stop_sum %d \n", stop_sum); + has_running_seqs[0] = stop_sum < max_bsz; + } +} + +// Host interface +void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &has_running_seqs, + const paddle::Tensor &step_input_ids, + const paddle::Tensor &adaptive_step_input_len, + const paddle::Tensor &step_output_ids, + const paddle::Tensor &step_output_len, + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &is_paused, + const paddle::Tensor &mask_rollback, + const paddle::Tensor &token_ids_all, + const paddle::Tensor &prompt_lens, + const paddle::Tensor &step_idx, + const paddle::Tensor &end_tokens, + const paddle::Tensor &max_dec_len, + const bool is_naive_mode, + const bool prefill_one_step_stop) { + const int real_bsz = seq_lens_this_time.shape()[0]; + const int max_bsz = stop_flags.shape()[0]; + PADDLE_ENFORCE_LE( + max_bsz, + 1024, + phi::errors::InvalidArgument( + "unified_update_model_status: max_bsz (%d) must be <= 1024 " + "(single-block launch limit).", + max_bsz)); + const int max_step_tokens = step_input_ids.shape()[1]; + const int max_model_len = token_ids_all.shape()[1]; + const int num_end_tokens = end_tokens.shape()[0]; + + constexpr int BlockSize = 
1024; + + // has_running_seqs is CPU tensor, need to copy to GPU first + auto has_running_seqs_gpu = + has_running_seqs.copy_to(seq_lens_this_time.place(), false); + unified_update_model_status_kernel + <<<1, BlockSize, 0, seq_lens_this_time.stream()>>>( + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(has_running_seqs_gpu.data()), + const_cast(mask_rollback.data()), + const_cast(step_input_ids.data()), + const_cast(adaptive_step_input_len.data()), + const_cast(step_output_ids.data()), + const_cast(step_output_len.data()), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(is_paused.data()), + const_cast(token_ids_all.data()), + prompt_lens.data(), + const_cast(step_idx.data()), + end_tokens.data(), + max_dec_len.data(), + real_bsz, + max_bsz, + max_step_tokens, + max_model_len, + num_end_tokens, + is_naive_mode, + prefill_one_step_stop); + // Copy result back to CPU + auto has_running_seqs_cpu = + has_running_seqs_gpu.copy_to(has_running_seqs.place(), false); + bool *out_data = const_cast(has_running_seqs.data()); + out_data[0] = has_running_seqs_cpu.data()[0]; +} + +PD_BUILD_STATIC_OP(unified_update_model_status) + .Inputs({"seq_lens_encoder", + "seq_lens_decoder", + "has_running_seqs", + "step_input_ids", + "adaptive_step_input_len", + "step_output_ids", + "step_output_len", + "stop_flags", + "seq_lens_this_time", + "is_paused", + "mask_rollback", + "token_ids_all", + "prompt_lens", + "step_idx", + "end_tokens", + "max_dec_len"}) + .Attrs({"is_naive_mode: bool", "prefill_one_step_stop: bool"}) + .Outputs({"seq_lens_encoder_out", + "seq_lens_decoder_out", + "has_running_seqs_out", + "step_input_ids_out", + "adaptive_step_input_len_out", + "step_output_ids_out", + "step_output_len_out", + "stop_flags_out", + "seq_lens_this_time_out", + "mask_rollback_out", + "token_ids_all_out", + "step_idx_out"}) + .SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", 
"seq_lens_decoder_out"}, + {"has_running_seqs", "has_running_seqs_out"}, + {"step_input_ids", "step_input_ids_out"}, + {"adaptive_step_input_len", "adaptive_step_input_len_out"}, + {"step_output_ids", "step_output_ids_out"}, + {"step_output_len", "step_output_len_out"}, + {"stop_flags", "stop_flags_out"}, + {"seq_lens_this_time", "seq_lens_this_time_out"}, + {"mask_rollback", "mask_rollback_out"}, + {"token_ids_all", "token_ids_all_out"}, + {"step_idx", "step_idx_out"}}) + .SetKernelFn(PD_KERNEL(UnifiedUpdateModelStatus)); diff --git a/custom_ops/gpu_ops/speculate_decoding/verify_draft_tokens.cu b/custom_ops/gpu_ops/speculate_decoding/verify_draft_tokens.cu new file mode 100644 index 0000000000..9f51c8ad08 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/verify_draft_tokens.cu @@ -0,0 +1,525 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Verification kernel — outputs step_output_ids + step_output_len, +// and performs EOS / max_dec_len detection (read-only on step_idx). +// step_idx is NOT modified here; all state updates (including step_idx) +// are handled by unified_update_model_status. 
+// +// Verification strategies: +// 0 = TOPP : draft token in top-p candidate set (+ verify_window +// fallback) 1 = GREEDY : draft token == top-1 token (strict argmax +// match) 2 = TARGET_MATCH : draft token == target model's sampled token + +#include +#include "helper.h" // NOLINT + +// ============================================================ +// Persistent curand state — allocated once, reused across calls. +// Only needed for TOPP strategy (Phase 2 stochastic sampling). +// ============================================================ +static curandState_t *dev_curand_states = nullptr; +static int allocated_bsz = 0; +static uint64_t seed = 0; +static uint64_t offset = 0; + +__global__ void setup_seed_kernel(curandState_t *state, + const uint64_t seed, + const uint64_t offset, + const int bs, + const bool need_batch_random) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = idx; i < bs; i += gridDim.x * blockDim.x) { + if (need_batch_random) { + curand_init(seed, i, offset, &state[i]); + } else { + curand_init(seed, 0, offset, &state[i]); + } + } +} + +// ============================================================ +// Phase 1 helpers — single-step draft token verification +// ============================================================ + +// Check if draft_token appears in the candidate set +__device__ inline bool is_in(const int64_t *candidates, + const int64_t draft, + const int candidate_len) { + for (int i = 0; i < candidate_len; i++) { + if (draft == candidates[i]) return true; + } + return false; +} + +// TOPP: draft in top-p filtered candidate set +__device__ inline bool verify_one_topp(const int64_t *verify_tokens_row, + int64_t draft_token, + int actual_cand_len) { + return is_in(verify_tokens_row, draft_token, actual_cand_len); +} + +// GREEDY / TARGET_MATCH: exact single-token match +__device__ inline bool verify_one_match(int64_t target_token, + int64_t draft_token) { + return target_token == draft_token; +} + +// 
============================================================ +// VerifyContext — per-batch mutable state + accept helpers. +// Eliminates repeated EOS/max_dec_len check and output write +// patterns across Phase 1 and Phase 2. +// ============================================================ +struct VerifyContext { + // Immutable per-batch (set once at kernel entry) + int bid; + int max_step_tokens; + int end_length; + const int64_t *end_tokens; + const int64_t *max_dec_len; + const int64_t *step_input_ids_now; + int64_t *step_output_ids; + + // Mutable per-batch state + int64_t cur_step_idx; + int output_len_now; + bool stopped; + + // Emit a token at position `pos` to output in Phase 1. + // Performs: step_idx check, EOS detection, token replacement, output write. + // Returns true if this sequence should stop (EOS or max_dec_len hit). + __device__ __forceinline__ bool emit_token(int pos, int64_t token) { + cur_step_idx++; + bool is_eos = is_in_end(token, end_tokens, end_length); + bool max_len_hit = (cur_step_idx >= max_dec_len[bid]); + if ((is_eos || max_len_hit) && !is_eos) { + token = end_tokens[0]; + } + step_output_ids[bid * max_step_tokens + pos] = token; + output_len_now++; + if (is_eos || max_len_hit) { + stopped = true; + return true; + } + return false; + } + + // Emit the final token at position `pos` in Phase 2. + // Same EOS/max_dec_len logic, but does NOT increment output_len_now + // (Phase 2's token is already counted in the initial output_len_now=1). + __device__ __forceinline__ void emit_final_token(int pos, int64_t token) { + cur_step_idx++; + bool is_eos = is_in_end(token, end_tokens, end_length); + bool max_len_hit = (cur_step_idx >= max_dec_len[bid]); + if ((is_eos || max_len_hit) && !is_eos) { + token = end_tokens[0]; + } + step_output_ids[bid * max_step_tokens + pos] = token; + } + + // TOPP-only: verify_window bulk-accept fallback. 
+ // + // When draft token is NOT in top-p set but IS the top-2 token, + // check verify_window consecutive positions for top-1 match. + // If all match, bulk-accept from position i through ii. + // + // Returns the new loop position (i) after handling. + // Sets *rejected=true if fallback was not triggered (caller should break). + __device__ __forceinline__ int try_verify_window_fallback( + int i, + bool *rejected, + const int64_t *verify_tokens_now, + int seq_len_this_time, + int max_candidate_len, + int verify_window) { + int ii = i; + if (max_candidate_len >= 2 && + verify_tokens_now[ii * max_candidate_len + 1] == + step_input_ids_now[ii + 1]) { + // top-2 matches — scan verify_window consecutive top-1 matches + int j = 0; + ii += 1; + for (; j < verify_window && ii < seq_len_this_time - 1; j++, ii++) { + if (verify_tokens_now[ii * max_candidate_len] != + step_input_ids_now[ii + 1]) { + break; + } + } + if (j >= verify_window) { + // Bulk accept all tokens from i to ii + for (; i < ii; i++) { + if (emit_token(i, step_input_ids_now[i + 1])) return i; + } + return i; // continue outer loop from position ii + } + } + // Fallback not triggered or insufficient window — reject + *rejected = true; + return i; + } +}; + +// ============================================================ +// Phase 2 helpers — sample token for rejected/last position +// ============================================================ + +__device__ inline int64_t topp_sampling_kernel(const int64_t *candidate_ids, + const float *candidate_scores, + curandState_t *curand_states, + const int candidate_len, + const float topp) { + // Use bid (blockIdx.x-based) index, not threadIdx.x — curand_states is + // allocated with size bsz, and each batch element uses one thread. 
+ const int bid = blockIdx.x * blockDim.x + threadIdx.x; + float sum_scores = 0.0f; + float rand_top_p = curand_uniform(curand_states + bid) * topp; + for (int i = 0; i < candidate_len; i++) { + sum_scores += candidate_scores[i]; + if (rand_top_p <= sum_scores) { + return candidate_ids[i]; + } + } + return candidate_ids[0]; +} + +// ============================================================ +// Main verification kernel +// ============================================================ +// +// Input parameter groups by strategy: +// - target_tokens: GREEDY=argmax, TARGET_MATCH=sampled, TOPP=unused +// (None) +// - candidate_ids/scores: TOPP=full candidate set, GREEDY/TARGET_MATCH=unused +// (None) +// - candidate_lens: TOPP=actual length per position, +// GREEDY/TARGET_MATCH=unused (None) +// +// All parameters may be empty tensors for strategies that don't use them. +// +__global__ void verify_draft_tokens( + // Core I/O + int64_t *step_output_ids, + int *step_output_len, + const int64_t *step_input_ids, // draft tokens + // Target model outputs (strategy-dependent interpretation) + const int64_t + *target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused + // Candidate set for TOPP/GREEDY (TARGET_MATCH: unused) + const int64_t *candidate_ids, + const float *candidate_scores, + const int *candidate_lens, + // Sampling params + curandState_t *curand_states, // nullptr for GREEDY/TARGET_MATCH + const float *topp, + // Metadata + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_this_time, + const int64_t *end_tokens, + const bool *is_block_step, + const int *cu_seqlens_q_output, + const int *reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection (read-only) + const int64_t *max_dec_len, + const int64_t *step_idx, + // Dimensions and config + const int max_bsz, + const int real_bsz, + const int max_step_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int 
verify_window, + const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH + const bool reject_all, + const bool accept_all) { + const int bid = threadIdx.x; + + // Initialize step_output_len to 0 for ALL slots + if (bid < max_bsz) { + step_output_len[bid] = 0; + } else { + return; + } + + if (bid >= real_bsz || is_block_step[bid] || stop_flags[bid]) return; + + const int start_token_id = cu_seqlens_q_output[bid]; + // Pointers are strategy-dependent (may be nullptr for unused params) + auto *candidate_ids_now = + candidate_ids ? candidate_ids + start_token_id * max_candidate_len + : nullptr; + auto *candidate_scores_now = + candidate_scores ? candidate_scores + start_token_id * max_candidate_len + : nullptr; + auto *candidate_lens_now = + candidate_lens ? candidate_lens + start_token_id : nullptr; + auto *target_tokens_now = + target_tokens ? target_tokens + start_token_id : nullptr; + + // Initialize per-batch verification context + VerifyContext ctx; + ctx.bid = bid; + ctx.max_step_tokens = max_step_tokens; + ctx.end_length = end_length; + ctx.end_tokens = end_tokens; + ctx.max_dec_len = max_dec_len; + ctx.step_input_ids_now = step_input_ids + bid * max_step_tokens; + ctx.step_output_ids = step_output_ids; + ctx.cur_step_idx = step_idx[bid]; + ctx.output_len_now = 1; + ctx.stopped = false; + + // ======== Phase 1: Verify draft tokens ======== + int i = 0; + for (; i < seq_lens_this_time[bid] - 1; i++) { + // Early exit conditions: reject-all, prefill, reasoning + if (reject_all || seq_lens_encoder[bid] != 0 || + reasoning_status[bid] == 1) { + break; + } + + // Accept-all override (debug/warmup) + if (accept_all) { + if (ctx.emit_token(i, ctx.step_input_ids_now[i + 1])) break; + continue; + } + + // Strategy dispatch + bool accepted = false; + switch (verify_strategy) { + case 0: { // TOPP + auto actual_cand_len = candidate_lens_now[i] > max_candidate_len + ? 
max_candidate_len + : candidate_lens_now[i]; + accepted = verify_one_topp(candidate_ids_now + i * max_candidate_len, + ctx.step_input_ids_now[i + 1], + actual_cand_len); + if (!accepted) { + bool rejected = false; + i = ctx.try_verify_window_fallback(i, + &rejected, + candidate_ids_now, + seq_lens_this_time[bid], + max_candidate_len, + verify_window); + if (ctx.stopped || rejected) goto phase1_done; + continue; // bulk accept succeeded, continue from new i + } + break; + } + case 1: // GREEDY + case 2: // TARGET_MATCH + accepted = verify_one_match(target_tokens_now[i], + ctx.step_input_ids_now[i + 1]); + break; + } + + if (accepted) { + if (ctx.emit_token(i, ctx.step_input_ids_now[i + 1])) break; + } else { + break; // reject + } + } +phase1_done: + + // ======== Phase 2: Output token for rejected/last position ======== + if (!ctx.stopped) { + int64_t output_token; + switch (verify_strategy) { + case 0: { // TOPP — stochastic sampling from candidate set + auto actual_cand_len = candidate_lens_now[i] > max_candidate_len + ? 
max_candidate_len + : candidate_lens_now[i]; + output_token = + topp_sampling_kernel(candidate_ids_now + i * max_candidate_len, + candidate_scores_now + i * max_candidate_len, + curand_states, + actual_cand_len, + topp[bid]); + break; + } + case 1: // GREEDY — deterministic argmax from target_tokens + case 2: // TARGET_MATCH — target model's sampled token + output_token = target_tokens_now[i]; + break; + } + ctx.emit_final_token(i, output_token); + } + step_output_len[bid] = ctx.output_len_now; +} + +// ============================================================ +// Host function +// ============================================================ +void VerifyDraftTokens( + // Core I/O + const paddle::Tensor &step_output_ids, + const paddle::Tensor &step_output_len, + const paddle::Tensor &step_input_ids, + // Target model outputs (optional, required for TARGET_MATCH) + const paddle::optional &target_tokens, + // Candidate set (optional, required for TOPP/GREEDY) + const paddle::optional &candidate_ids, + const paddle::optional &candidate_scores, + const paddle::optional &candidate_lens, + // Sampling params + const paddle::Tensor &topp, + // Metadata + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &end_tokens, + const paddle::Tensor &is_block_step, + const paddle::Tensor &cu_seqlens_q_output, + const paddle::Tensor &reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection + const paddle::Tensor &max_dec_len, + const paddle::Tensor &step_idx, + int max_seq_len, + int verify_window, + int verify_strategy, + bool reject_all, + bool accept_all) { + auto bsz = step_output_ids.shape()[0]; + auto real_bsz = seq_lens_this_time.shape()[0]; + auto max_step_tokens = step_input_ids.shape()[1]; + auto end_length = end_tokens.shape()[0]; + // max_candidate_len: 1 if candidate_ids not provided, else from shape + int max_candidate_len = candidate_ids ? 
candidate_ids->shape()[1] : 1; + + constexpr int BlockSize = 1024; + PADDLE_ENFORCE_LE(bsz, + BlockSize, + phi::errors::InvalidArgument( + "verify_draft_tokens: bsz (%d) exceeds BlockSize (%d). " + "Increase BlockSize or reduce max_num_seqs.", + bsz, + BlockSize)); + auto stream = step_output_ids.stream(); + + // curand state: only needed for TOPP(0) strategy (stochastic sampling) + curandState_t *curand_ptr = nullptr; + if (verify_strategy == + 0 /* TOPP only - GREEDY and TARGET_MATCH use deterministic output */) { + if (dev_curand_states == nullptr || bsz > allocated_bsz) { + if (dev_curand_states) cudaFree(dev_curand_states); + cudaMalloc(&dev_curand_states, sizeof(curandState_t) * bsz); + allocated_bsz = bsz; + } + setup_seed_kernel<<<1, BlockSize, 0, stream>>>( + dev_curand_states, seed, offset, bsz, true); + seed++; + offset++; + curand_ptr = dev_curand_states; + } + + // Get data pointers (nullptr if optional not provided) + const int64_t *target_tokens_ptr = + target_tokens ? target_tokens->data() : nullptr; + const int64_t *candidate_ids_ptr = + candidate_ids ? candidate_ids->data() : nullptr; + const float *candidate_scores_ptr = + candidate_scores ? candidate_scores->data() : nullptr; + const int *candidate_lens_ptr = + candidate_lens ? candidate_lens->data() : nullptr; + + // Validate parameters based on verify_strategy. + // Note: empty_input_forward may lead to empty optional tensors — only + // validate when bsz > 0 (i.e. there are active sequences). 
+ if (bsz > 0) { + if (verify_strategy == 0 /* TOPP */) { + if (!candidate_ids_ptr || !candidate_scores_ptr || !candidate_lens_ptr) { + PD_THROW( + "verify_strategy=TOPP (0) requires candidate_ids, " + "candidate_scores, candidate_lens"); + } + } else if (verify_strategy == 1 /* GREEDY */) { + if (!target_tokens_ptr) { + PD_THROW("verify_strategy=GREEDY (1) requires target_tokens (argmax)"); + } + } else if (verify_strategy == 2 /* TARGET_MATCH */) { + if (!target_tokens_ptr) { + PD_THROW( + "verify_strategy=TARGET_MATCH (2) requires target_tokens " + "(sampled)"); + } + } + } + + verify_draft_tokens<<<1, BlockSize, 0, stream>>>( + // Core I/O + const_cast(step_output_ids.data()), + const_cast(step_output_len.data()), + step_input_ids.data(), + // Target model outputs + target_tokens_ptr, + // Candidate set + candidate_ids_ptr, + candidate_scores_ptr, + candidate_lens_ptr, + // Sampling params + curand_ptr, + topp.data(), + // Metadata + stop_flags.data(), + seq_lens_encoder.data(), + seq_lens_this_time.data(), + end_tokens.data(), + is_block_step.data(), + cu_seqlens_q_output.data(), + reasoning_status.data(), + // max_dec_len / step_idx + max_dec_len.data(), + step_idx.data(), + // Dimensions and config + bsz, // max_bsz + real_bsz, // real_bsz + max_step_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + verify_strategy, + reject_all, + accept_all); +} + +PD_BUILD_STATIC_OP(verify_draft_tokens) + .Inputs({"step_output_ids", + "step_output_len", + "step_input_ids", + paddle::Optional("target_tokens"), + paddle::Optional("candidate_ids"), + paddle::Optional("candidate_scores"), + paddle::Optional("candidate_lens"), + "topp", + "stop_flags", + "seq_lens_encoder", + "seq_lens_this_time", + "end_tokens", + "is_block_step", + "cu_seqlens_q_output", + "reasoning_status", + "max_dec_len", + "step_idx"}) + .Outputs({"step_output_ids_out", "step_output_len_out"}) + .Attrs({"max_seq_len: int", + "verify_window: int", + "verify_strategy: int", + 
"reject_all: bool", + "accept_all: bool"}) + .SetInplaceMap({{"step_output_ids", "step_output_ids_out"}, + {"step_output_len", "step_output_len_out"}}) + .SetKernelFn(PD_KERNEL(VerifyDraftTokens)); diff --git a/docs/features/speculative_decoding.md b/docs/features/speculative_decoding.md index 6a14d57dbb..2ffca5cc58 100644 --- a/docs/features/speculative_decoding.md +++ b/docs/features/speculative_decoding.md @@ -10,7 +10,9 @@ This project implements an efficient **Speculative Decoding** inference framewor ### Supported -- **Ngram** +- **Naive**: Normal decoding mode that uses the speculative decoding code path without generating draft tokens, useful for testing the speculative decoding framework + +- **Ngram**: N-gram matching based speculative decoding - **Suffix Decoding** @@ -54,12 +56,41 @@ This project implements an efficient **Speculative Decoding** inference framewor ## 🔧 Configuration Parameters -- `method`: The speculative decoding strategy, currently supports `["mtp", "ngram", "suffix"]`. +### Basic Parameters + +- `method`: The speculative decoding strategy, supports `["mtp", "ngram", "naive", "suffix"]`. + - `naive`: Normal decoding mode using speculative decoding code path without generating draft tokens + - `ngram`: N-gram matching based speculative decoding + - `mtp`: Multi-Token Prediction + - `suffix`: Suffix decoding based speculative decoding - `num_speculative_tokens`: Number of speculative tokens to generate; max is 5, currently MTP supports only 1. +- `num_model_steps`: MTP model steps, must satisfy `num_speculative_tokens >= num_model_steps` - `model`: Path to the MTP draft model when using the `"mtp"` method. - `quantization`: Quantization method of the MTP model (e.g., WINT4). 
- Max `batch_size`: 256 +### Verification Strategy (verify_strategy) + +Controls how draft tokens are verified: +- `topp` (default): Top-P sampling verification, draft token must be in top-p candidate set +- `greedy`: Greedy verification, draft token must equal target model's argmax output +- `target_match`: Target match verification, draft token must equal target model's sampled output + +```bash +--speculative-config '{"method": "mtp", "verify_strategy": "greedy", "num_speculative_tokens": 1, "model": "${path_to_mtp_model}"}' +``` + +### Accept Policy (accept_policy) + +Controls draft token acceptance behavior: +- `normal` (default): Normal verification flow +- `accept_all`: Accept all draft tokens (for debugging) +- `reject_all`: Reject all draft tokens (for debugging) + +```bash +--speculative-config '{"method": "mtp", "accept_policy": "accept_all", "num_speculative_tokens": 1}' +``` + --- ## 🚀 Using Multi-Token Prediction (MTP) @@ -161,7 +192,7 @@ python -m fastdeploy.entrypoints.openai.api_server \ --model ${path_to_main_model} \ --tensor-parallel-size 4 \ --config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \ - --speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}' + --speculative-config '{"method": "ngram", "num_speculative_tokens": 1}' ``` @@ -196,3 +227,17 @@ self.suffix_decoding_max_spec_factor: float = 1.0 # The probability threshold for speculated tokens. self.suffix_decoding_min_token_prob: float = 0.1 ``` +--- + +## 📝 Using Naive Mode (Normal Decoding) + +Naive mode uses the speculative decoding code path without generating draft tokens, useful for testing the correctness of the speculative decoding framework or establishing performance baselines. 
+
+```bash
+python -m fastdeploy.entrypoints.openai.api_server \
+    --model ${path_to_main_model} \
+    --tensor-parallel-size 4 \
+    --speculative-config '{"method": "naive"}'
+```
+
+**Note**: In Naive mode, `num_speculative_tokens` will be forced to 0.
diff --git a/docs/zh/features/speculative_decoding.md b/docs/zh/features/speculative_decoding.md
index c1401b0beb..3392f0634d 100644
--- a/docs/zh/features/speculative_decoding.md
+++ b/docs/zh/features/speculative_decoding.md
@@ -6,7 +6,9 @@
 ## ✅ 投机解码方法支持
 
 ### ✅ 支持列表
-- **Ngram**
+- **Naive**: 普通解码模式,走投机解码代码路径但不生成草稿Token,用于测试投机解码框架的正确性
+
+- **Ngram**: 基于n-gram匹配的投机解码方法
 
 - **后缀解码**
 
@@ -38,12 +40,39 @@
 
 - **高效 DraftModel/MTP 框架**:开发多个融合 Cuda Kernel,统一完成模型类方法的前后处理,相比传统的循环、切片方法,性能高效且易维护
 
 ## 🔧 参数说明
-- `method`: 解码策略,可选值为 `"mtp"` 、 `"ngram"` 或 `"suffix"`
+
+### 基础参数
+- `method`: 解码策略,可选值为 `"mtp"`、`"ngram"`、`"naive"` 或 `"suffix"`
+  - `naive`: 普通解码模式,走投机解码代码路径但不生成草稿Token
+  - `ngram`: 基于n-gram匹配的投机解码
+  - `mtp`: 多Token预测(Multi-Token Prediction)
+  - `suffix`: 基于后缀解码的投机解码
 - `num_speculative_tokens`: 每轮预测的 Token 数,最大支持 5(当前 MTP 仅支持 1)
+- `num_model_steps`: MTP 模型步数,需满足 `num_speculative_tokens >= num_model_steps`
 - `model`: 若选择 MTP,则需指定 MTP 模型路径
 - `quantization`: 模型量化方式,推荐使用 `wint8`
 - `batch_size`: 当前支持最大值为 256
 
+### 验证策略参数 (verify_strategy)
+控制草稿Token的验证方式:
+- `topp` (默认): Top-P采样验证,草稿Token需在Top-P候选集中
+- `greedy`: 贪婪验证,草稿Token需等于目标模型的argmax输出
+- `target_match`: 目标匹配验证,草稿Token需等于目标模型的采样输出
+
+```bash
+--speculative-config '{"method": "mtp", "verify_strategy": "greedy", "num_speculative_tokens": 1, "model": "${path_to_mtp_model}"}'
+```
+
+### 接受策略参数 (accept_policy)
+控制草稿Token的接受行为:
+- `normal` (默认): 正常验证流程
+- `accept_all`: 接受所有草稿Token(调试用)
+- `reject_all`: 拒绝所有草稿Token(调试用)
+
+```bash
+--speculative-config '{"method": "mtp", "accept_policy": "accept_all", "num_speculative_tokens": 1}'
+```
+
 ## 🚀 使用 Multi-Token-Prediction(MTP) 解码
 详见论文:[DeepSeek-V3](https://arxiv.org/pdf/2412.19437)
 ### TP 并行部署
@@ -133,9 
+162,21 @@ python -m fastdeploy.entrypoints.openai.api_server \ --model ${path_to_main_model} \ --tensor-parallel-size 4 \ --config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \ - --speculative-config '{"method": "ngram", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}' + --speculative-config '{"method": "ngram", "num_speculative_tokens": 1}' ``` +## 📝 使用 Naive 模式(普通解码) +Naive 模式走投机解码代码路径但不生成草稿 Token,用于测试投机解码框架的正确性或对比性能基线。 + +``` +python -m fastdeploy.entrypoints.openai.api_server \ + --model ${path_to_main_model} \ + --tensor-parallel-size 4 \ + --speculative-config '{"method": "naive"}' +``` + +**注意**: Naive 模式下 `num_speculative_tokens` 会被强制设置为 0。 + ## 🌲 使用后缀解码 (Suffix Decoding) 后缀解码是一种无模型推理框架,通过在 CPU 上使用高效后缀树进行快速草稿 Token 预测,加速重复性推理任务(如代理工作流程、编码等),消除 GPU 开销。 @@ -149,7 +190,7 @@ python -m fastdeploy.entrypoints.openai.api_server \ --model ${path_to_main_model} \ --tensor-parallel-size 4 \ --config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \ - --speculative-config '{"method": "mtp", "num_speculative_tokens": 4, "suffix_decoding_max_tree_depth": 64, "suffix_decoding_max_cached_requests": 10000, "suffix_decoding_max_spec_factor": 1.0, "suffix_decoding_min_token_prob": 0.1}' + --speculative-config '{"method": "suffix", "num_speculative_tokens": 4, "suffix_decoding_max_tree_depth": 64, "suffix_decoding_max_cached_requests": 10000, "suffix_decoding_max_spec_factor": 1.0, "suffix_decoding_min_token_prob": 0.1}' ``` 参数描述 diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 4ebfd4584b..264f385563 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -33,6 +33,7 @@ from fastdeploy import envs from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase from fastdeploy.platforms import current_platform from fastdeploy.scheduler import SchedulerConfig +from fastdeploy.spec_decode import SpecMethod from fastdeploy.transformer_utils.config import get_pooling_config 
from fastdeploy.utils import ( ceil_div, @@ -672,6 +673,10 @@ class ParallelConfig: else: self.pd_disaggregation_mode = "None" + # Prefill node one step stop (PD disaggregation specific) + # When enabled, prefill node stops after one decoding step + self.prefill_one_step_stop: bool = os.getenv("PREFILL_NODE_ONE_STEP_STOP", "0") == "1" + # disable_sequence_parallel_moe: qkv_linear + attn + out_linear + allreduce # use_sequence_parallel_moe: allgather + qkv_linear + attn + all2all + out_linear self.use_sequence_parallel_moe = ( @@ -719,69 +724,118 @@ class SpeculativeConfig: Configuration for speculative decoding. """ + # Class-level default values for all config options + _DEFAULTS = { + "method": None, + "mtp_strategy": "default", + "num_speculative_tokens": 1, + "num_model_steps": 1, + "max_candidate_len": 5, + "verify_window": 2, + "max_ngram_size": 5, + "min_ngram_size": 2, + # Suffix Decoding + "suffix_decoding_max_tree_depth": 64, + "suffix_decoding_max_cached_requests": -1, + "suffix_decoding_max_spec_factor": 1.0, + "suffix_decoding_min_token_prob": 0.1, + "model": None, + "quantization": None, + "num_gpu_block_expand_ratio": 1.0, + "model_type": "main", + "sharing_model": None, + "benchmark_mode": False, + "enf_gen_phase_tag": False, + "enable_draft_logprob": False, + "verify_strategy": "topp", + "accept_policy": "normal", + } + + # Environment variable to config mapping for backward compatibility + # Format: env_var: (config_key, value_when_set) + _ENV_OVERRIDES = { + "SPECULATE_VERIFY_USE_TOPK": ("verify_strategy", "greedy"), + "SPECULATE_VERIFY_USE_TARGET_SAMPLING": ("verify_strategy", "target_match"), + } + def __init__( self, args, ): - self.method_list = ["ngram_match", "mtp", "suffix"] + # Valid value lists (not defaults, but valid options) + self.method_list = ["ngram", "mtp", "naive", "suffix"] self.mtp_strategy_list = ["default", "with_ngram"] - # speculative method, choose in [None, "ngram_match", "mtp", "hybrid_mtp_ngram"] - self.method: 
Optional[str] = None - # mtp strategy in mtp-method - self.mtp_strategy = "default" - # the max length of speculative tokens - self.num_speculative_tokens: int = 1 - # the model runner step of draft model/mtp... - self.num_model_steps: int = 1 - # the max length of candidate tokens for speculative method - self.max_candidate_len: int = 5 - # the max length of verify window for speculative method - self.verify_window: int = 2 - # ngram match - self.max_ngram_size: int = 5 - self.min_ngram_size: int = 2 - # Suffix Decoding - # The maximum length of token sequences cached in suffix trees. - self.suffix_decoding_max_tree_depth: int = 64 - # The limits of requests that can be stored in the cache. - self.suffix_decoding_max_cached_requests: int = -1 - # The factor of matched length, calculated as num_draft_tokens = suffix_max_spec_factor * matched_length - self.suffix_decoding_max_spec_factor: float = 1.0 - # The probability threshold for speculated tokens. - self.suffix_decoding_min_token_prob: float = 0.1 - # model for mtp/eagle/draft_model - self.model: Optional[str] = None - # quantization of model - self.quantization: Optional[Dict[str, Any]] = None - # allocate more blocks to prevent mtp from finishing the block earlier than the main model - # Fixed now - self.num_gpu_block_expand_ratio: Optional[float] = 1 - # To distinguish the main model and draft model(mtp/eagle/draftmodel) - # ["main", "mtp"] - self.model_type: Optional[str] = "main" - # TODO(liuzichang): To reduce memory usage, MTP shares the main model's lm_head and embedding layers. - # A trick method is currently used to enable this sharing. - # This will be replaced with a more standardized solution in the future. - self.sharing_model = None - # During benchmarking, we need to enforce that the number of accepted tokens is 1. - # This means no tokens from MTP are accepted. - # This ensures that the specified simulation acceptance rate is not affected. 
- self.benchmark_mode: bool = False - # Enable token constraint enforcement in generation phase - # When enabled, enforces specific tokens after the reasoning phase boundary pattern - self.enf_gen_phase_tag: bool = False + # Initialize from defaults + self._init_from_defaults() + # Apply user-provided arguments (highest priority) + self._apply_user_args(args) + + # Read model config (overrides defaults but not user args) + self.read_model_config() + self._apply_model_config() + + # Apply environment variable overrides (backward compatibility) + self._apply_env_overrides(args) + + # Initialize computed fields self.num_extra_cache_layer = 0 - self.enable_draft_logprob: bool = False + # Convert and validate all parameters + self._convert_and_validate() + def _init_from_defaults(self): + """Initialize all config options from class defaults.""" + for key, value in self._DEFAULTS.items(): + setattr(self, key, value) + + def _apply_user_args(self, args: Dict[str, Any]): + """Apply user-provided arguments.""" + if args is None: + return for key, value in args.items(): if hasattr(self, key): setattr(self, key, value) - self.read_model_config() - self.reset() + def _apply_model_config(self): + """Apply configuration from model config file.""" + if not self.enabled_speculative_decoding(): + return + if self.model is None: + return + + # Model config can override certain defaults + # Currently no automatic overrides, but can be extended here + pass + + def _apply_env_overrides(self, user_args: Dict[str, Any]): + """ + Apply environment variable overrides for backward compatibility. + Only applies if user hasn't explicitly set the corresponding config. 
+ """ + for env_var, (config_key, env_value) in self._ENV_OVERRIDES.items(): + if os.environ.get(env_var, "0") == "1": + # Only apply if user didn't explicitly set this config + if user_args is None or config_key not in user_args: + setattr(self, config_key, env_value) + + def _convert_and_validate(self): + """ + Convert string configs to enums and validate all parameters. + """ + # Convert method from string to SpecMethod enum + if self.method is not None: + from fastdeploy.spec_decode import SpecMethod + + self.method = SpecMethod.from_string(self.method) + + # Set method-specific computed values + if self.method == SpecMethod.MTP: + self.num_extra_cache_layer = 1 + + # Run validation (includes dependency validation) + self.check_legality_parameters() def read_model_config(self): """ @@ -799,24 +853,6 @@ class SpeculativeConfig: if os.path.exists(self.config_path): self.model_config = json.load(open(self.config_path, "r", encoding="utf-8")) - def reset(self): - """ - Reset configuration. - """ - - def reset_value(cls, value_name, key=None, default=None): - if key is not None and key in cls.model_config: - setattr(cls, value_name, cls.model_config[key]) - elif getattr(cls, value_name, None) is None: - setattr(cls, value_name, default) - - if not self.enabled_speculative_decoding(): - return - - # NOTE(liuzichang): We will support multi-layer in future - if self.method in ["mtp"]: - self.num_extra_cache_layer = 1 - def enabled_speculative_decoding(self): """ Check if speculative decoding is enabled. @@ -846,18 +882,21 @@ class SpeculativeConfig: ) -> None: """Check the legality of parameters passed in from the command line""" if self.method is not None: - assert ( - self.method in self.method_list - ), f"speculative method only support {self.method_list} now, but get {self.method}." 
+ from fastdeploy.spec_decode import SpecMethod - assert ( - self.num_speculative_tokens >= 1 and self.num_speculative_tokens <= 5 - ), f"num_speculative_tokens only support in range[1, 5], but get {self.num_speculative_tokens}." - assert ( - self.num_model_steps >= 1 and self.num_model_steps <= 5 - ), f"num_model_steps only support in range[1, 5], but get {self.num_model_steps}." + assert self.method in [ + m.value for m in SpecMethod + ], f"speculative method only support {[m.value for m in SpecMethod]} now, but get {self.method}." - if self.method in ["mtp", "hybrid_mtp_ngram"]: + if self.method != SpecMethod.NAIVE: + assert ( + self.num_speculative_tokens >= 1 and self.num_speculative_tokens <= 5 + ), f"num_speculative_tokens only support in range[1, 5], but get {self.num_speculative_tokens}." + assert ( + self.num_model_steps >= 1 and self.num_model_steps <= 5 + ), f"num_model_steps only support in range[1, 5], but get {self.num_model_steps}." + + if self.method == SpecMethod.MTP: if self.num_speculative_tokens < self.num_model_steps: logger.warning( f"Get num_model_steps > num_speculative_tokens. 
Reset num_speculative_tokens to {self.num_model_steps}" @@ -868,6 +907,79 @@ class SpeculativeConfig: self.mtp_strategy in self.mtp_strategy_list ), f"mtp_strategy_list only support {self.mtp_strategy_list}, but get {self.mtp_strategy}" + # Validate verify strategy and accept policy + # Support case-insensitive input for better user experience + from fastdeploy.spec_decode import VerifyStrategy + + if not isinstance(self.verify_strategy, VerifyStrategy): + # Handle both string and int inputs + if isinstance(self.verify_strategy, int): + # If it's already an int (enum value), convert directly + self.verify_strategy = VerifyStrategy(self.verify_strategy) + else: + # Assume it's a string + self.verify_strategy = VerifyStrategy.from_string(self.verify_strategy) + + # Support case-insensitive accept_policy + valid_accept_policies = ["normal", "accept_all", "reject_all"] + accept_policy_lower = self.accept_policy.lower() + assert ( + accept_policy_lower in valid_accept_policies + ), f"accept_policy must be one of {valid_accept_policies} (case-insensitive), but got '{self.accept_policy}'." + self.accept_policy = accept_policy_lower + + # Validate parameter dependencies after basic validation + self._validate_dependencies() + + def _validate_dependencies(self) -> None: + """ + Validate parameter dependencies across different speculative methods. + Called by check_legality_parameters after basic validation. 
+ """ + if not self.enabled_speculative_decoding(): + return + + from fastdeploy.spec_decode import SpecMethod + + # Define parameter constraints for each speculative method + # Each constraint is a tuple: (dependent_param, operator, expected_relation) + constraints = { + SpecMethod.MTP: [ + { + "check": lambda: self.num_speculative_tokens >= self.num_model_steps, + "message": f"MTP requires num_speculative_tokens >= num_model_steps, " + f"but got {self.num_speculative_tokens} < {self.num_model_steps}", + "auto_fix": lambda: setattr(self, "num_speculative_tokens", self.num_model_steps), + } + ], + SpecMethod.NGRAM: [ + { + "check": lambda: self.max_ngram_size >= self.min_ngram_size, + "message": f"NGRAM requires max_ngram_size >= min_ngram_size, " + f"but got {self.max_ngram_size} < {self.min_ngram_size}", + "auto_fix": None, # Cannot auto-fix, user must adjust + } + ], + SpecMethod.NAIVE: [ + { + "check": lambda: self.num_speculative_tokens == 0, + "message": f"NAIVE mode requires num_speculative_tokens == 0, " + f"but got {self.num_speculative_tokens}. 
Resetting to 0.", + "auto_fix": lambda: setattr(self, "num_speculative_tokens", 0), + } + ], + } + + if self.method in constraints: + method_constraints = constraints[self.method] + for constraint in method_constraints: + if not constraint["check"](): + if constraint["auto_fix"] is not None: + logger.warning(constraint["message"] + " Applying auto-fix.") + constraint["auto_fix"]() + else: + raise ValueError(constraint["message"]) + def __str__(self) -> str: return self.to_json_string() @@ -1710,7 +1822,10 @@ class FDConfig: # Initialize cuda graph capture list max_capture_shape = self.scheduler_config.max_num_seqs - if self.speculative_config is not None and self.speculative_config.method in ["mtp", "suffix"]: + if self.speculative_config is not None and self.speculative_config.method in [ + SpecMethod.MTP, + SpecMethod.SUFFIX, + ]: max_capture_shape = self.scheduler_config.max_num_seqs * ( self.speculative_config.num_speculative_tokens + 1 ) @@ -1738,7 +1853,7 @@ class FDConfig: max_capture_shape_prefill=max_capture_shape_prefill, dec_token_per_query_per_step=dec_token_per_query_per_step, ) - if self.speculative_config is not None and self.speculative_config.method in ["mtp", "suffix"]: + if self.speculative_config is not None and self.speculative_config.method is not None: real_bsz_to_captured_size = {} for capture_size in self.graph_opt_config.cudagraph_capture_sizes: dummy_batch_size = int(capture_size / (self.speculative_config.num_speculative_tokens + 1)) @@ -1941,7 +2056,7 @@ class FDConfig: ) # adjust speculative config - if self.speculative_config is not None and self.speculative_config.method == "mtp": + if self.speculative_config is not None and self.speculative_config.method == SpecMethod.MTP: if self.scheduler_config.splitwise_role == "prefill": self.speculative_config.num_speculative_tokens = 1 self.speculative_config.num_model_steps = 1 diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 06ed0fda2c..5a060a1cad 
100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -65,6 +65,7 @@ from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.model_executor.guided_decoding import schema_checker from fastdeploy.plugins.token_processor import load_token_processor_plugins from fastdeploy.router.utils import check_service_health +from fastdeploy.spec_decode import SpecMethod from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.trace.constants import LoggingEventName @@ -575,7 +576,10 @@ class EngineService: req_out.metrics.decode_preallocate_req_time = cur_req.metrics.decode_preallocate_req_time cur_req.metrics = req_out.metrics cur_req.metrics.decode_inference_start_time = time.time() - if self.cfg.speculative_config.method in ["mtp"] and self.cfg.scheduler_config.splitwise_role == "decode": + if ( + self.cfg.speculative_config.method == SpecMethod.MTP + and self.cfg.scheduler_config.splitwise_role == "decode" + ): cur_req.draft_token_ids = copy.deepcopy(req_out.outputs.draft_token_ids) if req_out.error_code != 200: diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index dfcb4406c8..9205f1c05c 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -46,6 +46,7 @@ from fastdeploy.inter_communicator import IPCSignal from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.multimodal.hasher import MultimodalHasher from fastdeploy.platforms import current_platform +from fastdeploy.spec_decode import SpecMethod from fastdeploy.trace.constants import LoggingEventName from fastdeploy.trace.trace_logger import print as trace_print from fastdeploy.utils import download_from_bos, init_bos_client, llm_logger @@ -1359,7 +1360,7 @@ class ResourceManagerV1(ResourceManager): 
request.output_token_ids.append(request_output.outputs.token_ids[0]) request.num_cached_tokens = request_output.num_cached_tokens if ( - self.config.speculative_config.method in ["mtp"] + self.config.speculative_config.method == SpecMethod.MTP and self.config.scheduler_config.splitwise_role == "decode" ): request.draft_token_ids = copy.deepcopy(request_output.outputs.draft_token_ids) diff --git a/fastdeploy/metrics/metrics.py b/fastdeploy/metrics/metrics.py index 8c15e26318..028c4273a2 100644 --- a/fastdeploy/metrics/metrics.py +++ b/fastdeploy/metrics/metrics.py @@ -36,6 +36,7 @@ from fastdeploy.metrics.prometheus_multiprocess_setup import ( setup_multiprocess_prometheus, ) from fastdeploy.metrics.stats import ZMQMetricsStats +from fastdeploy.spec_decode import SpecMethod class SimpleCollector(Collector): @@ -668,7 +669,7 @@ class MetricsManager: "kwargs": {}, }, } - if speculative_method == "mtp": + if speculative_method == SpecMethod.MTP: self.SPECULATIVE_METRICS["spec_decode_efficiency"] = { "type": Gauge, "name": "fastdeploy:spec_decode_efficiency", diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py index cbb2aea883..b9a15c9fef 100644 --- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py +++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py @@ -159,7 +159,7 @@ class CudaGraphPiecewiseBackend: real_shape = ids_remove_padding.shape[0] if self.speculative_decoding and all(self.real_bsz_to_captured_size.values()): seq_lens_this_time: paddle.Tensor = kwargs["forward_meta"].seq_lens_this_time - num_running_requests = seq_lens_this_time.flatten().nonzero(as_tuple=False)[-1].item() + 1 + num_running_requests = int((seq_lens_this_time.flatten() > 0).sum().item()) real_shape = self.real_bsz_to_captured_size[num_running_requests] exist_prefill = kwargs["forward_meta"].exist_prefill # Static split graph 
mode: use Static + CUDAGraph for prefill/mixed phase diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index ddebd15d41..96c6905dfd 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -45,6 +45,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import ( ) from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id from fastdeploy.platforms import current_platform +from fastdeploy.spec_decode import SpecMethod @dataclass @@ -143,10 +144,10 @@ class AppendAttentionBackend(AttentionBackend): if fd_config.speculative_config.model_type != "main": self.rope_3d = False self.causal: bool = getattr(fd_config.model_config, "causal", True) - self.speculative_method: str = fd_config.speculative_config.method + self.speculative_method = fd_config.speculative_config.method self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" - self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) + self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP) self.kv_num_heads: int = kv_num_heads self.num_heads: int = num_heads diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index 751c37cac8..8fa2aa6cdb 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -68,6 +68,8 @@ else: merge_prefill_decode_output = None +from fastdeploy.spec_decode import SpecMethod + FLASH_ATTN_VERSION = None @@ -255,7 +257,7 @@ class FlashAttentionBackend(AttentionBackend): self.use_speculate = 
self.speculative_method is not None self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" - self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) + self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP) self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index 6dbcf3eed6..472e75e14b 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.platforms import current_platform +from fastdeploy.spec_decode import SpecMethod if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output @@ -106,7 +107,7 @@ class FlashMaskAttentionBackend(AttentionBackend): self.use_speculate = self.speculative_method is not None self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" - self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) + self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP) self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 08ed2d8c06..77156872b8 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ 
b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -60,6 +60,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionMetadata, ) from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id +from fastdeploy.spec_decode import SpecMethod @triton.jit() @@ -257,11 +258,11 @@ class MLAAttentionBackend(AttentionBackend): ) self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) self.causal: bool = getattr(fd_config.model_config, "causal", True) - self.speculative_method: str = fd_config.speculative_config.method + self.speculative_method = fd_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" - self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) + self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP) self.num_heads: int = num_heads self.head_dim: int = fd_config.model_config.head_dim diff --git a/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py b/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py index cd6cd5eeec..82938c8736 100644 --- a/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py +++ b/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py @@ -221,7 +221,7 @@ class HPUAttentionBackend(AttentionBackend_HPU): self.rope_theta = 10000.0 if llm_config.model_config.rope_theta is None else llm_config.model_config.rope_theta self.rope_3d = getattr(llm_config.model_config, "rope_3d", False) self.causal = getattr(llm_config.model_config, "causal", True) - self.speculative_method: str = llm_config.speculative_config.method + self.speculative_method = 
llm_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None self.speculate_max_draft_token_num: int = llm_config.speculative_config.num_speculative_tokens self.keep_pd_step_flag: bool = llm_config.speculative_config.model_type == "mtp" diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py index b7a5c48a42..74fc27f67b 100644 --- a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py @@ -34,6 +34,7 @@ from fastdeploy.model_executor.layers.backends.metax.attention.flash_attention_i from fastdeploy.model_executor.ops.gpu import cache_kv_with_rope from fastdeploy.model_executor.ops.gpu import merge_qkv as merge_qkv_cu from fastdeploy.model_executor.ops.gpu import split_qkv as split_qkv_cu +from fastdeploy.spec_decode import SpecMethod @dataclass @@ -102,11 +103,11 @@ class FlashAttentionBackend(AttentionBackend): ) self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) self.causal: bool = getattr(fd_config.model_config, "causal", True) - self.speculative_method: str = fd_config.speculative_config.method + self.speculative_method = fd_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" - self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) + self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP) self.encoder_block_shape_q: int = encoder_block_shape_q self.decoder_block_shape_q: int = decoder_block_shape_q diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py 
b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py index 2f2be8cb0b..c905086f9f 100644 --- a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py @@ -28,6 +28,7 @@ from fastdeploy.model_executor.ops.gpu import ( get_block_shape_and_split_kv_block, prefill_mla_write_cache, ) +from fastdeploy.spec_decode import SpecMethod if TYPE_CHECKING: from fastdeploy.model_executor.forward_meta import ForwardMeta @@ -106,11 +107,11 @@ class MetaxMLAAttentionBackend(AttentionBackend): ) self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) self.causal: bool = getattr(fd_config.model_config, "causal", True) - self.speculative_method: str = fd_config.speculative_config.method + self.speculative_method = fd_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" - self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) + self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP) self.kv_num_heads: int = kv_num_heads self.num_heads: int = num_heads diff --git a/fastdeploy/model_executor/layers/backends/xpu/attention.py b/fastdeploy/model_executor/layers/backends/xpu/attention.py index f4868a5d59..2223cad6df 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/attention.py +++ b/fastdeploy/model_executor/layers/backends/xpu/attention.py @@ -38,6 +38,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionMetadata, ) from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id +from fastdeploy.spec_decode import SpecMethod @dataclass @@ -92,7 +93,7 @@ class 
XPUAttentionBackend(AttentionBackend): ) self.causal: bool = getattr(fd_config.model_config, "causal", True) self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" - self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) + self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP) self.kv_num_heads: int = kv_num_heads self.num_heads: int = num_heads diff --git a/fastdeploy/model_executor/layers/sample/logprobs.py b/fastdeploy/model_executor/layers/sample/logprobs.py index affaf10346..42eb86d5d8 100644 --- a/fastdeploy/model_executor/layers/sample/logprobs.py +++ b/fastdeploy/model_executor/layers/sample/logprobs.py @@ -14,11 +14,15 @@ # limitations under the License. """ +from typing import Callable, List, Optional, Tuple + import paddle +import paddle.nn.functional as F import triton import triton.language as tl from fastdeploy.platforms import current_platform +from fastdeploy.worker.output import LogprobsTensors @triton.jit @@ -80,3 +84,134 @@ def batched_count_greater_than(x: paddle.Tensor, y: paddle.Tensor) -> paddle.Ten out = (x >= y).sum(-1) return out + + +def gather_logprobs( + logprobs: paddle.Tensor, + num_logprobs: int, + token_ids: paddle.Tensor, +) -> LogprobsTensors: + """ + Gather logprobs for topk and sampled/prompt token. + + Args: + logprobs: (num tokens) x (vocab) tensor + num_logprobs: minimum number of logprobs to retain per token + token_ids: prompt tokens (if prompt logprobs) or sampled tokens + (if sampled logprobs); 1D token ID tensor with (num tokens) elements. + Must be int64. + + Returns: + LogprobsTensors with top-k indices, top-k logprobs, and token ranks. 
+ """ + assert token_ids.dtype == paddle.int64 + token_ids = token_ids.unsqueeze(1) + logprobs.clip_(min=paddle.finfo(logprobs.dtype).min) + token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1) + + token_ranks = batched_count_greater_than(logprobs, token_logprobs) + + if num_logprobs >= 1: + topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1) + indices = paddle.concat([token_ids, topk_indices], axis=1) + top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1) + else: + indices = token_ids + top_logprobs = token_logprobs + + return LogprobsTensors(indices, top_logprobs, token_ranks) + + +def build_output_logprobs( + logits: paddle.Tensor, + sampling_metadata, + share_inputs: List[paddle.Tensor], + is_naive: bool = False, + logprobs_mode: str = "default", + compute_logprobs_fn: Optional[Callable] = None, +) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor]]: + """ + Build logprobs output for both NAIVE and speculative (MTP/Ngram) modes. + + This is a standalone function (not tied to any sampler) so that both + naive and speculative decoding paths can share the same logprob logic. + + For NAIVE mode: logits are already per-token, no extraction needed. + For speculative mode: extracts target logits for accepted token positions. + + Args: + logits: Model output logits. + sampling_metadata: Sampling parameters and metadata. + share_inputs: Shared input tensors. + is_naive: True for NAIVE mode (single token per request). + logprobs_mode: One of "raw_logprobs", "raw_logits", or "default". + compute_logprobs_fn: Callable for computing logprobs with temperature + scaling and top_p normalization. Used when logprobs_mode == "raw_logprobs". 
+ + Returns: + tuple: (logprobs_tensors, cu_batch_token_offset) + """ + num_logprobs = sampling_metadata.max_num_logprobs + logprobs_tensors = None + cu_batch_token_offset = None + + if num_logprobs is None: + return logprobs_tensors, cu_batch_token_offset + + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + + if is_naive: + # NAIVE mode: one token per request, logits are already correct + output_logits = logits + token_ids = share_inputs["accept_tokens"][:real_bsz, 0] + else: + # Speculative mode: extract target logits for accepted positions + from fastdeploy.model_executor.layers.sample.ops import ( + speculate_get_target_logits, + ) + + batch_token_num = paddle.where( + share_inputs["seq_lens_encoder"][:real_bsz] != 0, + paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]), + share_inputs["seq_lens_this_time"], + ).flatten() + + share_inputs["batch_token_num"] = batch_token_num + + ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype( + "int32" + ) + cu_batch_token_offset = paddle.concat( + [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])] + ).astype("int32") + share_inputs["cu_batch_token_offset"] = cu_batch_token_offset + + output_logits = paddle.empty( + [share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], + dtype=logits.dtype, + ) + speculate_get_target_logits( + output_logits, + logits, + cu_batch_token_offset, + ori_cu_batch_token_offset, + share_inputs["seq_lens_this_time"], + share_inputs["seq_lens_encoder"], + share_inputs["accept_num"], + ) + + idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32") + mask = idx < share_inputs["accept_num"].unsqueeze(1) + token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask) + + # Compute logprobs with temperature scaling and top_p normalization + if logprobs_mode == "raw_logprobs": + raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata) + elif logprobs_mode == 
"raw_logits": + raw_logprobs = output_logits.clone() + else: + raw_logprobs = F.log_softmax(output_logits, axis=-1) + + logprobs_tensors = gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) + + return logprobs_tensors, cu_batch_token_offset diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index b4452598b4..772917e914 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -32,19 +32,22 @@ from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase from fastdeploy.model_executor.layers.sample.early_stopper import ( get_early_stopper_cls_from_stragegy, ) -from fastdeploy.model_executor.layers.sample.logprobs import batched_count_greater_than +from fastdeploy.model_executor.layers.sample.logprobs import ( + batched_count_greater_than, + build_output_logprobs, +) from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.ops import ( apply_penalty_multi_scores, apply_speculative_penalty_multi_scores, min_p_sampling, reasoning_phase_token_constraint, - speculate_get_target_logits, speculate_insert_first_token, top_k_top_p_sampling, ) from fastdeploy.platforms import current_platform from fastdeploy.reasoning import ReasoningParser +from fastdeploy.spec_decode import SpecMethod, VerifyStrategy from fastdeploy.worker.output import LogprobsTensors, SamplerOutput @@ -638,6 +641,18 @@ class SpeculativeSampler(nn.Layer): self.line_break_id = fd_config.model_config.line_break_id self.enf_gen_phase_tag = fd_config.speculative_config.enf_gen_phase_tag + # Verify strategy derived from config (replaces env vars in CUDA kernel) + spec_config = fd_config.speculative_config + # Verify strategy enum: VerifyStrategy.TOPP/GREEDY/TARGET_MATCH + # Use .value (0/1/2) when passing to CUDA kernel + self.spec_method = spec_config.method + self.verify_strategy = 
spec_config.verify_strategy + self.prefill_one_step_stop = fd_config.parallel_config.prefill_one_step_stop + + # Accept policy from config (can be overridden by function parameters) + self.config_accept_all = spec_config.accept_policy == "accept_all" + self.config_reject_all = spec_config.accept_policy == "reject_all" + def pre_process(self, skip_idx_list: List[int] = []): """pre process before running""" pass @@ -750,6 +765,157 @@ class SpeculativeSampler(nn.Layer): return LogprobsTensors(indices, top_logprobs, token_ranks) + def _verify_and_sample( + self, + logits: paddle.Tensor, + probs: paddle.Tensor, + sampling_metadata: SamplingMetadata, + max_model_len: int, + share_inputs: List[paddle.Tensor], + accept_all_drafts: bool = False, + reject_all_drafts: bool = False, + ) -> SamplerOutput: + """ + Verify draft tokens against target model output and produce final samples. + + This is the core speculative decoding logic that compares draft tokens + with target model predictions to determine acceptance/rejection. + + Args: + logits: Target model raw logits + probs: Target model softmax output + sampling_metadata: Sampling parameters and metadata + max_model_len: Maximum model sequence length + share_inputs: Shared input tensors including draft_tokens, accept_tokens, etc. 
+ accept_all_drafts: Force accept all draft tokens (debug mode) + reject_all_drafts: Force reject all draft tokens (debug mode) + + Returns: + SamplerOutput with accepted tokens and metadata + """ + from fastdeploy.model_executor.ops.gpu import ( + top_p_candidates, + verify_draft_tokens, + ) + + # Prepare strategy-specific tensors + # TARGET_MATCH: needs target_tokens=sampled, candidates=None + # GREEDY: needs target_tokens=argmax, candidates=None + # TOPP: needs target_tokens=None, candidates=full top_p set + target_tokens, candidate_ids, candidate_scores, candidate_lens = None, None, None, None + + if self.verify_strategy == VerifyStrategy.TARGET_MATCH: + # Only TARGET_MATCH needs stochastic sampling + top_p, top_k, topp_seed = padding_sampling_params( + sampling_metadata.top_p, + sampling_metadata.top_k, + sampling_metadata.seed, + share_inputs["seq_lens_this_time"], + share_inputs["seq_lens_encoder"], + ) + _, target_tokens = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed) + elif self.verify_strategy == VerifyStrategy.GREEDY: + # GREEDY: deterministic argmax in target_tokens, no candidates needed + target_tokens = paddle.argmax(probs, axis=-1) + elif self.verify_strategy == VerifyStrategy.TOPP: # TOPP + # TOPP: needs full candidate set, target_tokens unused + candidate_scores, candidate_ids, candidate_lens = top_p_candidates( + probs, + sampling_metadata.top_p, + share_inputs["batch_id_per_token_output"], + self.speculative_max_candidate_len, + max_model_len, + ) + else: + raise ValueError(f"Unknown verify strategy: {self.verify_strategy}") + + # Accept policy: config default OR function parameter (OR logic) + final_accept_all = self.config_accept_all or accept_all_drafts + final_reject_all = self.config_reject_all or reject_all_drafts or self.speculative_benchmark_mode + + verify_draft_tokens( + # Core I/O + share_inputs["accept_tokens"], # step_output_ids + share_inputs["accept_num"], # step_output_len + 
share_inputs["draft_tokens"], # step_input_ids + # Target model outputs + target_tokens, + # Candidate set (strategy-dependent usage) + candidate_ids, + candidate_scores, + candidate_lens, + # Sampling params + sampling_metadata.top_p, + # Metadata + share_inputs["stop_flags"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_this_time"], + sampling_metadata.eos_token_ids, + share_inputs["is_block_step"], + share_inputs["cu_seqlens_q_output"], + share_inputs["reasoning_status"], + # max_dec_len / step_idx for EOS/max-len detection, only read + share_inputs["max_dec_len"], + share_inputs["step_idx"], + # Config + max_model_len, + self.speculative_verify_window, + self.verify_strategy.value, + final_reject_all, + final_accept_all, + ) + + return SamplerOutput( + sampled_token_ids=share_inputs["accept_tokens"], + logprobs_tensors=None, + token_num_per_batch=share_inputs["accept_num"], + logits=logits, + ) + + def _normal_sample( + self, + logits: paddle.Tensor, + probs: paddle.Tensor, + sampling_metadata: SamplingMetadata, + share_inputs: List[paddle.Tensor], + ) -> SamplerOutput: + """ + Normal sampling without draft token verification. + + Used by NAIVE mode: directly samples from target model output + and writes results to share_inputs["accept_tokens"]/["accept_num"]. 
+ + Args: + probs: Target model softmax output + logits: Target model output logits + sampling_metadata: Sampling parameters and metadata + share_inputs: Shared input tensors + + Returns: + SamplerOutput with sampled tokens (no logprobs; logprobs are computed in forward_cuda) + """ + # Apply min_p sampling if configured + probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list) + + # Sample tokens + _, next_tokens = top_k_top_p_sampling( + probs, + sampling_metadata.top_p, + sampling_metadata.top_k, + sampling_metadata.top_k_list, + topp_seed=sampling_metadata.seed, + ) + + # For NAIVE mode: write directly to accept_tokens/accept_num + share_inputs["accept_tokens"][: next_tokens.shape[0], 0] = next_tokens.squeeze(-1) + + return SamplerOutput( + sampled_token_ids=share_inputs["accept_tokens"], + logprobs_tensors=None, + token_num_per_batch=share_inputs["accept_num"], + logits=logits, + ) + def forward_cuda( self, logits: paddle.Tensor, @@ -758,10 +924,26 @@ class SpeculativeSampler(nn.Layer): share_inputs: List[paddle.Tensor], accept_all_drafts: bool = False, reject_all_drafts: bool = False, - ) -> paddle.Tensor: - """ """ + ) -> SamplerOutput: + """ + Forward pass for speculative sampling. 
- from fastdeploy.model_executor.ops.gpu import speculate_verify, top_p_candidates + Routes between: + - NAIVE mode: Normal sampling without draft verification + - MTP/Ngram mode: Draft token verification + sampling + + Args: + logits: Target model output logits + sampling_metadata: Sampling parameters and metadata + max_model_len: Maximum model sequence length + share_inputs: Shared input tensors + accept_all_drafts: Force accept all draft tokens (debug mode) + reject_all_drafts: Force reject all draft tokens (debug mode) + + Returns: + SamplerOutput with sampled/accepted tokens + """ + # Apply speculative penalty scores (shared path) if sampling_metadata.token_ids_all is not None: token_ids_all = sampling_metadata.token_ids_all @@ -788,11 +970,11 @@ class SpeculativeSampler(nn.Layer): max_model_len, ) + # Apply reasoning phase constraint if enabled if self.enf_gen_phase_tag: reasoning_phase_token_constraint( logits, - token_ids_all, - prompt_lens, + sampling_metadata.pre_token_ids, share_inputs["stop_flags"], share_inputs["seq_lens_this_time"], share_inputs["seq_lens_encoder"], @@ -808,102 +990,34 @@ class SpeculativeSampler(nn.Layer): probs = F.softmax(logits) - top_p, top_k, topp_seed = padding_sampling_params( - sampling_metadata.top_p, - sampling_metadata.top_k, - sampling_metadata.seed, - share_inputs["seq_lens_this_time"], - share_inputs["seq_lens_encoder"], - ) - _, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed) - - verify_scores, verify_tokens, actual_candidate_len = top_p_candidates( - probs, - sampling_metadata.top_p, - share_inputs["batch_id_per_token_output"], - self.speculative_max_candidate_len, - max_model_len, - ) - - speculate_verify( - sampled_token_ids, - share_inputs["accept_tokens"], - share_inputs["accept_num"], - share_inputs["step_idx"], - share_inputs["stop_flags"], - share_inputs["seq_lens_encoder"], - share_inputs["seq_lens_decoder"], - share_inputs[ - "draft_tokens" - ], # Both input and 
output, need to write the last 1 token accepted to position 0. - share_inputs["seq_lens_this_time"], - verify_tokens, - verify_scores, - share_inputs["max_dec_len"], - sampling_metadata.eos_token_ids, - share_inputs["is_block_step"], - share_inputs["cu_seqlens_q_output"], - actual_candidate_len, - share_inputs["actual_draft_token_num"], - sampling_metadata.top_p, - share_inputs["reasoning_status"], - max_model_len, - self.speculative_verify_window, - True, # enable_topp - (self.speculative_benchmark_mode or reject_all_drafts), - accept_all_drafts, - ) - - num_logprobs = sampling_metadata.max_num_logprobs - batch_token_num = None - if num_logprobs is not None: - real_bsz = share_inputs["seq_lens_this_time"].shape[0] - batch_token_num = paddle.where( - share_inputs["seq_lens_encoder"][:real_bsz] != 0, - paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]), - share_inputs["seq_lens_this_time"], - ).flatten() - share_inputs["batch_token_num"] = batch_token_num - ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype( - "int32" - ) - cu_batch_token_offset = paddle.concat( - [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])] - ).astype("int32") - share_inputs["cu_batch_token_offset"] = cu_batch_token_offset - target_logits = paddle.empty( - [share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype - ) - speculate_get_target_logits( - target_logits, + # Route based on spec_method + is_naive = self.spec_method is None or self.spec_method == SpecMethod.NAIVE + if is_naive: + sampler_output = self._normal_sample(logits, probs, sampling_metadata, share_inputs) + else: + sampler_output = self._verify_and_sample( logits, - cu_batch_token_offset, - ori_cu_batch_token_offset, - share_inputs["seq_lens_this_time"], - share_inputs["seq_lens_encoder"], - share_inputs["accept_num"], + probs, + sampling_metadata, + max_model_len, + share_inputs, + accept_all_drafts, + 
reject_all_drafts, ) - if self.logprobs_mode == "raw_logprobs": - raw_logprobs = self.compute_logprobs(target_logits, sampling_metadata) - elif self.logprobs_mode == "raw_logits": - raw_logprobs = target_logits.clone() - - logprobs_tensors = None - if num_logprobs is not None: - token_ids = share_inputs["accept_tokens"] - idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32") - mask = idx < share_inputs["accept_num"].unsqueeze(1) - token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask) - logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) - - sampler_output = SamplerOutput( - sampled_token_ids=share_inputs["accept_tokens"], - logprobs_tensors=logprobs_tensors, - token_num_per_batch=share_inputs["accept_num"], - cu_batch_token_offset=share_inputs["cu_batch_token_offset"], - logits=logits, - ) + # Build logprobs via unified path (outside of sampling logic) + if sampling_metadata.max_num_logprobs is not None: + logprobs_tensors, cu_batch_token_offset = build_output_logprobs( + logits, + sampling_metadata, + share_inputs, + is_naive=is_naive, + logprobs_mode=self.logprobs_mode, + compute_logprobs_fn=self.compute_logprobs, + ) + sampler_output.logprobs_tensors = logprobs_tensors + if cu_batch_token_offset is not None: + sampler_output.cu_batch_token_offset = cu_batch_token_offset return sampler_output def forward_xpu( diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 2056db648a..3bc788bbdb 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -66,14 +66,13 @@ elif current_platform.is_maca(): speculate_save_output, speculate_save_output_topk, speculate_set_stop_value_multi_seqs, - speculate_set_value_by_flags_and_idx, speculate_step_paddle, speculate_step_reschedule, speculate_step_system_cache, - speculate_update, step_paddle, step_reschedule, step_system_cache, 
+ unified_update_model_status, update_inputs, update_inputs_v1, ) @@ -88,11 +87,10 @@ else: speculate_pre_process, speculate_save_output, speculate_save_output_topk, - speculate_set_value_by_flags_and_idx, speculate_step_paddle, speculate_step_system_cache, - speculate_update, speculate_set_stop_value_multi_seqs, + unified_update_model_status, step_paddle, step_system_cache, update_inputs, @@ -425,6 +423,8 @@ def post_process_specualate( think_end_id: int = -1, splitwise_role_is_decode: bool = False, enable_entropy: bool = False, + is_naive_mode: bool = False, + prefill_one_step_stop: bool = False, ): if think_end_id > 0: speculate_limit_thinking_content_length( @@ -457,18 +457,30 @@ def post_process_specualate( if enable_entropy: speculate_calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature) - speculate_update( - model_output.seq_lens_encoder, - model_output.seq_lens_decoder, - model_output.not_need_stop, - model_output.draft_tokens, - model_output.actual_draft_token_num, - model_output.accept_tokens, - model_output.accept_num, - model_output.stop_flags, - model_output.seq_lens_this_time, - model_output.is_block_step, - model_output.mask_rollback, + # Unified state update: merges speculate_update + speculate_set_value_by_flags_and_idx + # into a single kernel launch. For MTP/ngram paths, verify_draft_tokens has already + # handled EOS/max_dec_len detection (replacing tokens + updating step_idx), so + # unified_update_model_status acts as a no-op for those checks. For naive mode + # (which skips verify), this kernel handles EOS/max_dec_len detection. 
+ unified_update_model_status( + model_output.seq_lens_encoder, # seq_lens_encoder + model_output.seq_lens_decoder, # seq_lens_decoder + model_output.not_need_stop, # has_running_seqs + model_output.draft_tokens, # step_input_ids + model_output.actual_draft_token_num, # adaptive_step_input_len + model_output.accept_tokens, # step_output_ids (read-write) + model_output.accept_num, # step_output_len (read-write) + model_output.stop_flags, # stop_flags (read-write) + model_output.seq_lens_this_time, # seq_lens_this_time + model_output.is_block_step, # is_paused + model_output.mask_rollback, # mask_rollback + model_output.token_ids_all, # token_ids_all + model_output.prompt_lens, # prompt_lens + model_output.step_idx, # step_idx (read-write) + model_output.eos_token_id, # end_tokens + model_output.max_dec_len, # max_dec_len + is_naive_mode, # is_naive_mode + prefill_one_step_stop, # prefill_one_step_stop ) if not skip_save_output: @@ -522,20 +534,6 @@ def post_process_specualate( save_each_rank, ) - # Update token_ids_all through accept tokens - - speculate_set_value_by_flags_and_idx( - model_output.token_ids_all, - model_output.prompt_lens, - model_output.accept_tokens, - model_output.accept_num, - model_output.stop_flags, - model_output.seq_lens_this_time, - model_output.seq_lens_encoder, - model_output.seq_lens_decoder, - model_output.step_idx, - ) - def post_process( sampler_or_pooler_output: Union[SamplerOutput, PoolerOutput], @@ -550,6 +548,8 @@ def post_process( think_end_id: int = -1, splitwise_role_is_decode: bool = False, enable_entropy: bool = False, + is_naive_mode: bool = False, + prefill_one_step_stop: bool = False, ) -> None: """Post-processing steps after completing a single token generation.""" @@ -575,6 +575,8 @@ def post_process( think_end_id, splitwise_role_is_decode, enable_entropy, + is_naive_mode, + prefill_one_step_stop, ) else: post_process_normal( diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 
89b767f9ed..f02cd67b4a 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -42,6 +42,7 @@ from fastdeploy.engine.request import ( from fastdeploy.inter_communicator import ZmqIpcServer from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.platforms import current_platform +from fastdeploy.spec_decode import SpecMethod from fastdeploy.trace.constants import LoggingEventName from fastdeploy.trace.trace_logger import print as trace_print from fastdeploy.utils import llm_logger, spec_logger @@ -584,7 +585,7 @@ class TokenProcessor: f" average accept len: {self.number_of_output_tokens / self.total_step}" ) - if self.cfg.speculative_config.method in ["mtp"]: + if self.cfg.speculative_config.method == SpecMethod.MTP: single_head_acceptance_rates = [] for i in range(1, self.cfg.speculative_config.num_speculative_tokens + 1): if self.accept_token_num_per_head[i - 1] != 0: @@ -1028,12 +1029,12 @@ class TokenProcessor: main_process_metrics.spec_decode_num_accepted_tokens_total.set(self.num_accepted_tokens) main_process_metrics.spec_decode_num_emitted_tokens_total.set(self.num_emitted_tokens) - if self.cfg.speculative_config.method in ["ngram"]: + if self.cfg.speculative_config.method == SpecMethod.NGRAM: main_process_metrics.spec_decode_draft_acceptance_rate.set( self.num_accepted_tokens / self.num_emitted_tokens ) - if self.cfg.speculative_config.method in ["mtp"]: + if self.cfg.speculative_config.method == SpecMethod.MTP: num_draft_tokens = len(real_accept_num) * self.cfg.speculative_config.num_speculative_tokens self.num_draft_tokens += num_draft_tokens diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py index c62a514da4..cfc0939c9a 100644 --- a/fastdeploy/rl/rollout_model.py +++ b/fastdeploy/rl/rollout_model.py @@ -69,6 +69,7 @@ from fastdeploy.model_executor.utils import ( process_final_after_loading, ) from fastdeploy.rl.rollout_config import RolloutModelConfig +from 
fastdeploy.spec_decode import SpecMethod class RolloutModel(nn.Layer): @@ -707,7 +708,7 @@ class Glm4MoeForCausalLMRL(Glm4MoeForCausalLM, BaseRLModel): self.speculative_decoding = fd_config.speculative_config.method is not None self.speculative_method = fd_config.speculative_config.method - if self.speculative_decoding and self.speculative_method == "mtp": + if self.speculative_decoding and self.speculative_method == SpecMethod.MTP: fd_config.parallel_config.tp_group = None fd_config.parallel_config.ep_group = None self.mtp_fd_config = copy.deepcopy(fd_config) @@ -745,7 +746,7 @@ class Glm4MoeForCausalLMRL(Glm4MoeForCausalLM, BaseRLModel): """state_dict""" main_state_dict = super().state_dict() state_dict = {k: v for k, v in main_state_dict.items() if not k.startswith("mtp_layers")} - if self.speculative_decoding and self.speculative_method == "mtp": + if self.speculative_decoding and self.speculative_method == SpecMethod.MTP: mtp_state_dict = self.mtp_layers.state_dict() state_dict.update(mtp_state_dict) return state_dict @@ -805,7 +806,7 @@ class Glm4MoeForCausalLMRL(Glm4MoeForCausalLM, BaseRLModel): self._complete_missing_mappings() # extra for mtp - if self.speculative_decoding and self.speculative_method == "mtp": + if self.speculative_decoding and self.speculative_method == SpecMethod.MTP: mtp_infer_to_train_mapping = self.mtp_layers.get_name_mappings_to_training(trainer_degree) self.infer_to_train_mapping.update(mtp_infer_to_train_mapping) diff --git a/fastdeploy/spec_decode/__init__.py b/fastdeploy/spec_decode/__init__.py index 0b675b826a..456ca92a03 100644 --- a/fastdeploy/spec_decode/__init__.py +++ b/fastdeploy/spec_decode/__init__.py @@ -14,26 +14,8 @@ """ speculative decoding module """ -from fastdeploy.platforms import current_platform from .base import Proposer -from .mtp import MTPProposer +from .types import SpecMethod, VerifyStrategy -# XPU is not support ngram proposer now -if not current_platform.is_xpu(): - from .ngram import NgramProposer 
-__all__ = ["Proposer", "MTPProposer", "NgramProposer"] - -# Suffix proposer requires arctic_inference -try: - from .suffix import SuffixProposer - - _suffix_proposer_available = True -except ImportError: - _suffix_proposer_available = False - SuffixProposer = None - -if _suffix_proposer_available: - __all__ = ["Proposer", "MTPProposer", "NgramProposer", "SuffixProposer"] -else: - __all__ = ["Proposer", "MTPProposer", "NgramProposer"] +__all__ = ["Proposer", "SpecMethod", "VerifyStrategy"] diff --git a/fastdeploy/spec_decode/base.py b/fastdeploy/spec_decode/base.py index 6499b35827..fa50eae462 100644 --- a/fastdeploy/spec_decode/base.py +++ b/fastdeploy/spec_decode/base.py @@ -16,14 +16,16 @@ from abc import ABC, abstractmethod from copy import deepcopy -from typing import Any +from typing import TYPE_CHECKING, Any import paddle.distributed as dist from fastdeploy import envs -from fastdeploy.config import FDConfig from fastdeploy.utils import spec_logger +if TYPE_CHECKING: + from fastdeploy.config import FDConfig + class Proposer(ABC): """ @@ -33,7 +35,7 @@ class Proposer(ABC): the speculative decoding framework """ - def __init__(self, fd_config: FDConfig): + def __init__(self, fd_config: "FDConfig"): """ Init Speculative proposer """ diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 7cc0c621ba..5461ac2833 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -16,14 +16,13 @@ import os import time -from typing import List +from typing import TYPE_CHECKING, List import numpy as np import paddle from paddleformers.utils.log import logger from fastdeploy import envs -from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request, RequestType from fastdeploy.inter_communicator import IPCSignal from fastdeploy.model_executor.forward_meta import ForwardMeta @@ -81,6 +80,9 @@ from fastdeploy.worker.input_batch import ( from .base import Proposer +if TYPE_CHECKING: + from fastdeploy.config 
import FDConfig + class MTPProposer(Proposer): """ @@ -89,7 +91,7 @@ class MTPProposer(Proposer): def __init__( self, - fd_config: FDConfig, + fd_config: "FDConfig", main_model: ModelForCasualLM, local_rank: int, device_id: int, # physical device id @@ -724,7 +726,7 @@ class MTPProposer(Proposer): self.target_model_inputs["is_block_step"], self.target_model_inputs["draft_tokens"], self.num_model_steps, - self.speculative_method in ["eagle", "mtp"], + True, self.role == "prefill", use_v1_cache_scheduler, ) diff --git a/fastdeploy/spec_decode/ngram.py b/fastdeploy/spec_decode/ngram.py index 1a766da14e..b64e8fb579 100644 --- a/fastdeploy/spec_decode/ngram.py +++ b/fastdeploy/spec_decode/ngram.py @@ -14,13 +14,17 @@ # limitations under the License. """ +from typing import TYPE_CHECKING + import paddle -from fastdeploy.config import FDConfig from fastdeploy.model_executor.ops.gpu import ngram_match from .base import Proposer +if TYPE_CHECKING: + from fastdeploy.config import FDConfig + class NgramProposer(Proposer): """ @@ -29,7 +33,7 @@ class NgramProposer(Proposer): Matching corresponding tokens in input and output as draft tokens. """ - def __init__(self, fd_config: FDConfig): + def __init__(self, fd_config: "FDConfig"): super().__init__(fd_config) self.max_ngram_size = self.speculative_config.max_ngram_size self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu() diff --git a/fastdeploy/spec_decode/suffix.py b/fastdeploy/spec_decode/suffix.py index 3f2f4586d8..f4d1495524 100644 --- a/fastdeploy/spec_decode/suffix.py +++ b/fastdeploy/spec_decode/suffix.py @@ -14,13 +14,17 @@ # limitations under the License. 
""" +from typing import TYPE_CHECKING + import numpy as np -from fastdeploy.config import FDConfig from fastdeploy.utils import spec_logger from .base import Proposer +if TYPE_CHECKING: + from fastdeploy.config import FDConfig + try: from arctic_inference.suffix_decoding import SuffixDecodingCache except ImportError: @@ -34,7 +38,7 @@ class SuffixProposer(Proposer): Uses SuffixDecodingCache to generate draft tokens based on suffix tree matching. """ - def __init__(self, fd_config: FDConfig): + def __init__(self, fd_config: "FDConfig"): super().__init__(fd_config) if SuffixDecodingCache is None: diff --git a/fastdeploy/spec_decode/types.py b/fastdeploy/spec_decode/types.py new file mode 100644 index 0000000000..3473d810bb --- /dev/null +++ b/fastdeploy/spec_decode/types.py @@ -0,0 +1,151 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from enum import Enum +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from fastdeploy.spec_decode.base import Proposer + + +class VerifyStrategy(int, Enum): + """Draft token verification strategy enum. + + Used in verify_draft_tokens kernel to control how draft tokens are verified + and how bonus/correction tokens are sampled. 
+ + Values match the kernel's internal constants: + 0 = TOPP: draft in top-p candidate set, stochastic sampling for bonus + 1 = GREEDY: draft == argmax, deterministic argmax for bonus + 2 = TARGET_MATCH: draft == target sampled token, use target sample + """ + + TOPP = 0 + GREEDY = 1 + TARGET_MATCH = 2 + + @classmethod + def from_string(cls, value: str) -> "VerifyStrategy": + """Create VerifyStrategy from string with validation (case-insensitive). + + Args: + value: Strategy name (e.g., "topp", "GREEDY", "Target_Match") + + Returns: + VerifyStrategy enum value + + Raises: + ValueError: If the strategy name is not recognized + TypeError: If value is not a string + """ + if not isinstance(value, str): + raise TypeError( + f"Expected string input for VerifyStrategy.from_string(), " + f"but got {type(value).__name__}: {value}. " + f"If you have an int value, use VerifyStrategy(value) directly." + ) + try: + return cls[value.upper()] + except KeyError: + valid_names = [s.name for s in cls] + raise ValueError( + f"Invalid verify strategy '{value}'. " f"Must be one of: {valid_names} (case-insensitive)" + ) + + +class SpecMethod(str, Enum): + """Speculative decoding method enum. + + Value is the config string passed via --speculative-config '{"method": "mtp"}'. + """ + + NAIVE = "naive" + MTP = "mtp" + NGRAM = "ngram" + SUFFIX = "suffix" + + def create_proposer(self, fd_config, **kwargs) -> Optional["Proposer"]: + """Factory method: create the appropriate Proposer for this method. + + Args: + fd_config: FDConfig instance. + **kwargs: Method-specific args forwarded to the Proposer constructor. + MTP requires: main_model, local_rank, device_id, share_inputs. + + Returns: + Proposer instance, or None for NAIVE. 
+ """ + if self == SpecMethod.NAIVE: + return None + elif self == SpecMethod.MTP: + from fastdeploy.spec_decode.mtp import MTPProposer + + return MTPProposer( + fd_config, + kwargs["main_model"], + kwargs["local_rank"], + kwargs["device_id"], + kwargs["share_inputs"], + ) + elif self == SpecMethod.NGRAM: + from fastdeploy.spec_decode.ngram import NgramProposer + + return NgramProposer(fd_config) + elif self == SpecMethod.SUFFIX: + from fastdeploy.spec_decode.suffix import SuffixProposer + + return SuffixProposer(fd_config) + + @property + def needs_proposer(self) -> bool: + """Whether this method requires a proposer model.""" + return self != SpecMethod.NAIVE + + @property + def needs_kv_cache(self) -> bool: + """Whether the proposer needs its own KV cache layer.""" + return self == SpecMethod.MTP + + @classmethod + def from_string(cls, value: str) -> "SpecMethod": + """Create SpecMethod from string with validation (case-insensitive). + + Args: + value: Method name (e.g., "mtp", "NGRAM", "Naive") + + Returns: + SpecMethod enum value + + Raises: + ValueError: If the method name is not recognized + TypeError: If value is not a string + """ + if not isinstance(value, str): + raise TypeError( + f"Expected string input for SpecMethod.from_string(), " + f"but got {type(value).__name__}: {value}. " + f"If you have an enum value, use SpecMethod(value) directly." + ) + # Backward-compatible aliases + ALIASES = {"ngram_match": "ngram"} + normalized = ALIASES.get(value.lower(), value.lower()) + try: + return cls(normalized) + except ValueError: + valid_names = [m.value for m in cls] + raise ValueError( + f"Invalid speculative method '{value}'. 
" f"Must be one of: {valid_names} (case-insensitive)" + ) diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index 20e4f3aece..42941ab77f 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -45,6 +45,7 @@ from fastdeploy.model_executor.pre_and_post_process import ( pre_process, rebuild_padding, ) +from fastdeploy.spec_decode import SpecMethod from fastdeploy.worker.model_runner_base import ModelRunnerBase from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput @@ -118,9 +119,9 @@ class GCUModelRunner(ModelRunnerBase): """ Init speculative proposer """ - if self.speculative_method == "ngram": + if self.speculative_method == SpecMethod.NGRAM: raise NotImplementedError("NgramProposer is not support by GCUModelRunner.") - elif self.speculative_method == "mtp": + elif self.speculative_method == SpecMethod.MTP: raise NotImplementedError("MTPProposer is not support by GCUModelRunner.") else: self.proposer = None @@ -290,7 +291,7 @@ class GCUModelRunner(ModelRunnerBase): self.share_inputs["not_need_stop"][0] = True - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.insert_prefill_inputs(req_dicts) self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer @@ -740,7 +741,7 @@ class GCUModelRunner(ModelRunnerBase): batch_size=batch_size, expected_decode_len=expected_decode_len, ) - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.dummy_prefill_inputs( num_tokens=num_tokens, batch_size=batch_size, @@ -840,7 +841,7 @@ class GCUModelRunner(ModelRunnerBase): ) if self.speculative_decoding: - if self.speculative_method == "mtp": + if self.speculative_method == SpecMethod.MTP: self.proposer.run(full_hidden_states=model_output) else: self.proposer.run(share_inputs=self.share_inputs) @@ -1061,7 +1062,7 @@ class GCUModelRunner(ModelRunnerBase): 
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), ) - if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill": + if self.speculative_config.method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill": skip_save_output = True else: skip_save_output = False @@ -1077,7 +1078,7 @@ class GCUModelRunner(ModelRunnerBase): # 6. Speculative decode if self.speculative_decoding: - if self.speculative_method == "mtp": + if self.speculative_method == SpecMethod.MTP: self.proposer.run(full_hidden_states=model_output) else: self.proposer.run(share_inputs=self.share_inputs) @@ -1120,7 +1121,7 @@ class GCUModelRunner(ModelRunnerBase): # 3. gc self.clear_cache() - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.clear_dummy_input() # paddle.device.cuda.synchronize() @@ -1151,7 +1152,7 @@ class GCUModelRunner(ModelRunnerBase): } ) - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.update_block_num(num_gpu_blocks) def cal_theortical_kvcache(self): @@ -1180,7 +1181,7 @@ class GCUModelRunner(ModelRunnerBase): hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads num_layers = ( self.model_config.num_hidden_layers + self.speculative_config.num_gpu_block_expand_ratio - if self.speculative_method in ["mtp"] + if self.speculative_method == SpecMethod.MTP else self.model_config.num_hidden_layers ) required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers # k + v diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 98f3ed827f..75cfae8ab6 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -53,6 +53,7 @@ from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.sampler import Sampler, 
SpeculativeSampler from fastdeploy.model_executor.model_loader import get_model_loader from fastdeploy.platforms import current_platform +from fastdeploy.spec_decode import SpecMethod from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode if current_platform.is_iluvatar(): @@ -78,17 +79,6 @@ else: unset_data_ipc, ) -from fastdeploy.model_executor.pre_and_post_process import ( - async_set_value, - post_process, - pre_process, - rebuild_padding, - save_output_normal, -) - -if not (current_platform.is_dcu() or current_platform.is_iluvatar()): - from fastdeploy.spec_decode import MTPProposer, NgramProposer, SuffixProposer - import zmq from fastdeploy import envs @@ -100,6 +90,13 @@ from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling +from fastdeploy.model_executor.pre_and_post_process import ( + async_set_value, + post_process, + pre_process, + rebuild_padding, + save_output_normal, +) from fastdeploy.output.pooler import PoolerOutput from fastdeploy.worker.model_runner_base import ( DistributedOut, @@ -124,8 +121,8 @@ class GPUModelRunner(ModelRunnerBase): self.rank = rank self.local_rank = local_rank self.device_id = device_id - self.speculative_method = self.fd_config.speculative_config.method - self.speculative_decoding = self.speculative_method is not None + self.spec_method = self.fd_config.speculative_config.method + self.speculative_decoding = self.spec_method is not None self.enable_logprob = fd_config.model_config.enable_logprob self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling" @@ -296,7 +293,9 @@ class GPUModelRunner(ModelRunnerBase): """ check whether decode stage exist """ - 
return (self.share_inputs["seq_lens_decoder"] > 0).any().cpu().numpy().item() + seq_lens_decoder = self.share_inputs["seq_lens_decoder"] + stop_flags = self.share_inputs["stop_flags"].squeeze(1) + return ((seq_lens_decoder > 0) & ~stop_flags).any().cpu().numpy().item() def _resolve_current_launch_token_num( self, cached_token_num: int, token_num_event, is_dummy_or_profile_run: bool @@ -428,21 +427,19 @@ class GPUModelRunner(ModelRunnerBase): """ Init speculative proposer """ - if self.speculative_method == "ngram": - self.proposer = NgramProposer(self.fd_config) - elif self.speculative_method == "mtp": - self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"] - self.proposer = MTPProposer( - self.fd_config, - self.get_model(), - self.local_rank, - self.device_id, - self.share_inputs, - ) - elif self.speculative_method == "suffix": - self.proposer = SuffixProposer(self.fd_config) - else: + if self.spec_method is None: self.proposer = None + return + # MTP-specific: swap seq_lens_this_time to the buffer tensor + if self.spec_method == SpecMethod.MTP: + self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"] + self.proposer = self.spec_method.create_proposer( + self.fd_config, + main_model=self.get_model(), + local_rank=self.local_rank, + device_id=self.device_id, + share_inputs=self.share_inputs, + ) def _init_logits_processor(self, request) -> tuple[Future[LogitsProcessorBase],]: """ @@ -868,7 +865,7 @@ class GPUModelRunner(ModelRunnerBase): self.prompt_logprobs_reqs[request.request_id] = request self.forward_batch_reqs_list[idx] = request - if self.speculative_decoding and self.speculative_method == "suffix" and self.proposer is not None: + if self.speculative_decoding and self.spec_method == SpecMethod.SUFFIX and self.proposer is not None: if isinstance(request.prompt_token_ids, np.ndarray): prompt_token_ids = request.prompt_token_ids.tolist() else: @@ -984,7 +981,7 @@ class 
GPUModelRunner(ModelRunnerBase): self._process_mm_features(req_dicts) self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"][:num_running_requests] - if self.speculative_method in ["mtp"]: + if self.spec_method == SpecMethod.MTP: self.proposer.insert_tasks_v1(req_dicts, num_running_requests) def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int): @@ -1228,7 +1225,7 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs.condense() reorder_split_prefill_and_decode(input_batch=self.share_inputs) if self.speculative_decoding: - if self.speculative_method == "mtp": + if self.spec_method == SpecMethod.MTP: self.proposer.reorder_inputs() def load_model(self) -> None: @@ -1249,7 +1246,7 @@ class GPUModelRunner(ModelRunnerBase): if self.fd_config.load_config.dynamic_load_weight: from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager - if self.fd_config.speculative_config.method == "mtp": + if self.spec_method == SpecMethod.MTP: self.dynamic_weight_manager = DynamicWeightManager( self.fd_config, [self.model, self.proposer.model], self.local_rank ) @@ -1745,15 +1742,19 @@ class GPUModelRunner(ModelRunnerBase): think_end_id=self.model_config.think_end_id, splitwise_role_is_decode=self.scheduler_config.splitwise_role == "decode", enable_entropy=self.enable_entropy and self.parallel_config.tensor_parallel_rank == 0, + is_naive_mode=(self.speculative_decoding and self.proposer is None), + prefill_one_step_stop=self.parallel_config.prefill_one_step_stop, ) self.exist_prefill_flag = False if self.speculative_decoding: - if self.speculative_method == "mtp": + if self.spec_method == SpecMethod.MTP: self.proposer.run( full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph, is_dummy_run=True, ) + elif self.spec_method == SpecMethod.NAIVE: + pass else: self.proposer.prepare_dummy_speculative_drafts(share_inputs=self.share_inputs, batch_size=batch_size) return 
sampler_output @@ -1789,7 +1790,7 @@ class GPUModelRunner(ModelRunnerBase): max_dec_len_list=max_dec_len_list, block_num=block_num, ) - if self.speculative_method in ["mtp"]: + if self.spec_method == SpecMethod.MTP: self.proposer.dummy_prefill_inputs( num_tokens=num_tokens, batch_size=batch_size, @@ -1803,7 +1804,6 @@ class GPUModelRunner(ModelRunnerBase): self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph self.padding_cudagraph_inputs() - # 3. Run model if self.enable_mm: model_output = self.model( self.forward_meta.ids_remove_padding, @@ -1877,7 +1877,7 @@ class GPUModelRunner(ModelRunnerBase): logger.info( f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}" ) - elif self.speculative_decoding: + elif self.speculative_decoding and self.spec_method == SpecMethod.MTP: # Capture Target Model without bsz 1 for capture_size in sorted(capture_sizes, reverse=True): expected_decode_len = self.speculative_config.num_speculative_tokens * 2 + 1 @@ -2116,7 +2116,6 @@ class GPUModelRunner(ModelRunnerBase): # 2. Padding inputs for cuda graph self.padding_cudagraph_inputs() - # 3. 
Execute model if self.enable_mm: model_output = self.model( @@ -2330,7 +2329,7 @@ class GPUModelRunner(ModelRunnerBase): enable_pd_reorder=getattr(self.share_inputs, "enable_pd_reorder", False), ) - if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill": + if self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill": skip_save_output = True else: skip_save_output = False @@ -2348,19 +2347,25 @@ class GPUModelRunner(ModelRunnerBase): think_end_id=self.model_config.think_end_id, splitwise_role_is_decode=self.scheduler_config.splitwise_role == "decode", enable_entropy=self.enable_entropy and self.parallel_config.tensor_parallel_rank == 0, + is_naive_mode=(self.speculative_decoding and self.proposer is None), + prefill_one_step_stop=self.parallel_config.prefill_one_step_stop, ) if self.guided_backend is not None and sampler_output is not None: self.sampler.post_process(sampler_output.sampled_token_ids) - # 6. Speculative decode - if self.speculative_decoding: - if self.speculative_method == "mtp": + # 6. Speculative decode -- proposer run (method="naive" has proposer=None, skip) + # For naive mode: seq_lens_this_time is already reset to 1 inside + # unified_update_model_status kernel. For MTP/Ngram, the proposer + # will overwrite it with (draft_count + 1) below. 
+ + if self.speculative_decoding and self.proposer is not None: + if self.spec_method == SpecMethod.MTP: self.proposer.run( full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph ) - elif self.speculative_method == "suffix": - self.proposer.run(share_inputs=self.share_inputs) + elif self.spec_method == SpecMethod.NAIVE: + pass else: self.proposer.run(share_inputs=self.share_inputs) @@ -2483,7 +2488,7 @@ class GPUModelRunner(ModelRunnerBase): # TODO(gongshaotian): Optimize the management logic of kvcache self.num_gpu_blocks = self.cache_config.total_block_num self.initialize_kv_cache(profile=True) - if self.speculative_method in ["mtp"]: + if self.spec_method == SpecMethod.MTP: self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True) # 1. Profile with multimodal encoder & encoder cache @@ -2499,7 +2504,7 @@ class GPUModelRunner(ModelRunnerBase): ) # 3. gc - if self.speculative_method in ["mtp"]: + if self.spec_method == SpecMethod.MTP: self.proposer.clear_mtp_cache(profile=True) self.clear_cache(profile=True) @@ -2530,7 +2535,7 @@ class GPUModelRunner(ModelRunnerBase): } ) - if self.speculative_method in ["mtp"]: + if self.spec_method == SpecMethod.MTP: self.proposer.update_mtp_block_num(num_gpu_blocks) def cal_theortical_kvcache(self): @@ -2561,7 +2566,7 @@ class GPUModelRunner(ModelRunnerBase): # NOTE(liuzichang): Implement multi-layer MTP architecture in the future num_layers = ( self.model_config.num_hidden_layers + self.speculative_config.num_gpu_block_expand_ratio - if self.speculative_method in ["mtp"] + if self.spec_method == SpecMethod.MTP else self.model_config.num_hidden_layers ) @@ -2620,7 +2625,7 @@ class GPUModelRunner(ModelRunnerBase): self.dynamic_weight_manager.clear_parameters( pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle ) - if self.speculative_method in ["mtp"]: + if self.spec_method == SpecMethod.MTP: self.proposer.clear_mtp_cache() self.clear_cache() 
paddle.device.cuda.empty_cache() @@ -2646,7 +2651,7 @@ class GPUModelRunner(ModelRunnerBase): # Reset share_inputs self.share_inputs.reset_share_inputs() - if self.speculative_method in ["mtp"]: + if self.spec_method == SpecMethod.MTP: self.proposer.model_inputs.reset_model_inputs() self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks) self.initialize_kv_cache() diff --git a/fastdeploy/worker/hpu_model_runner.py b/fastdeploy/worker/hpu_model_runner.py index d6c5108f40..7237214b00 100644 --- a/fastdeploy/worker/hpu_model_runner.py +++ b/fastdeploy/worker/hpu_model_runner.py @@ -46,6 +46,7 @@ from fastdeploy.model_executor.ops.intel_hpu import ( step_paddle, update_inputs_v3, ) +from fastdeploy.spec_decode import SpecMethod from fastdeploy.utils import get_logger from fastdeploy.worker.model_runner_base import ModelRunnerBase from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput @@ -501,9 +502,9 @@ class HPUModelRunner(ModelRunnerBase): """ Init speculative proposer """ - # if self.speculative_method == "ngram": + # if self.speculative_method == SpecMethod.NGRAM: # self.proposer = NgramProposer(self.fd_config) - # elif self.speculative_method == "mtp": + # elif self.speculative_method == SpecMethod.MTP: # self.proposer = MTPProposer(self.fd_config, self.get_model(), # self.local_rank, self.device_id, # self.share_inputs) @@ -767,7 +768,7 @@ class HPUModelRunner(ModelRunnerBase): self.share_inputs["not_need_stop"][0] = True - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.insert_prefill_inputs(req_dicts, num_running_requests) def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int): @@ -1195,7 +1196,7 @@ class HPUModelRunner(ModelRunnerBase): self._dummy_prefill_inputs( num_tokens=num_tokens, batch_size=batch_size, expected_decode_len=expected_decode_len ) - if self.speculative_method in ["mtp"]: + if self.speculative_method == 
SpecMethod.MTP: raise NotImplementedError("speculative sampling is not supported on Intel HPU.") while True: @@ -1682,7 +1683,7 @@ class HPUModelRunner(ModelRunnerBase): accept_num=self.share_inputs["accept_num"] if self.speculative_decoding else None, ) - # if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill": + # if self.speculative_config.method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill": # skip_save_output = True # else: # skip_save_output = False @@ -1702,7 +1703,7 @@ class HPUModelRunner(ModelRunnerBase): # 6. Speculative decode if self.speculative_decoding: - if self.speculative_method == "mtp": + if self.speculative_method == SpecMethod.MTP: self.proposer.run(full_hidden_states=hiddden_states) else: self.proposer.run(share_inputs=self.share_inputs) @@ -1757,7 +1758,7 @@ class HPUModelRunner(ModelRunnerBase): # 3. gc self.clear_cache() - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.clear_dummy_input() def update_share_input_block_num(self, num_gpu_blocks: int) -> None: @@ -1785,7 +1786,7 @@ class HPUModelRunner(ModelRunnerBase): self.parallel_config.do_profile = False - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.update_block_num(num_gpu_blocks) def cal_theortical_kvcache(self): @@ -1816,7 +1817,7 @@ class HPUModelRunner(ModelRunnerBase): # NOTE(liuzichang): Implement multi-layer MTP architecture in the future num_layers = ( self.model_config.num_hidden_layers + self.speculative_config.num_gpu_block_expand_ratio - if self.speculative_method in ["mtp"] + if self.speculative_method == SpecMethod.MTP else self.model_config.num_hidden_layers ) required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers # k + v diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index b265b5b525..e24b58f9f2 100644 --- 
a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -76,7 +76,7 @@ from fastdeploy.model_executor.pre_and_post_process import ( save_output_normal, ) from fastdeploy.output.pooler import PoolerOutput -from fastdeploy.spec_decode import MTPProposer, NgramProposer +from fastdeploy.spec_decode import SpecMethod from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode from fastdeploy.worker.model_runner_base import ( DistributedOut, @@ -349,19 +349,16 @@ class MetaxModelRunner(ModelRunnerBase): """ Init speculative proposer """ - if self.speculative_method == "ngram": - self.proposer = NgramProposer(self.fd_config) - elif self.speculative_method == "mtp": + if self.speculative_method == SpecMethod.MTP: self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"] - self.proposer = MTPProposer( - self.fd_config, - self.get_model(), - self.local_rank, - self.device_id, - self.share_inputs, - ) - else: - self.proposer = None + + self.proposer = self.speculative_method.create_proposer( + self.fd_config, + main_model=self.get_model(), + local_rank=self.local_rank, + device_id=self.device_id, + share_inputs=self.share_inputs, + ) def _init_logits_processor(self, request) -> tuple[Future[LogitsProcessorBase],]: """ @@ -885,7 +882,7 @@ class MetaxModelRunner(ModelRunnerBase): self.share_inputs["not_need_stop"][0] = True self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"][:num_running_requests] - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.insert_tasks_v1(req_dicts, num_running_requests) def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int): @@ -969,7 +966,7 @@ class MetaxModelRunner(ModelRunnerBase): self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"][:num_running_requests] - if self.speculative_method in ["mtp"]: + if 
self.speculative_method == SpecMethod.MTP: self.proposer.insert_prefill_inputs(req_dicts, num_running_requests) def get_input_length_list( @@ -1205,7 +1202,7 @@ class MetaxModelRunner(ModelRunnerBase): self.share_inputs.condense() reorder_split_prefill_and_decode(input_batch=self.share_inputs) if self.speculative_decoding: - if self.speculative_method == "mtp": + if self.speculative_method == SpecMethod.MTP: self.proposer.reorder_inputs() def load_model(self) -> None: @@ -1686,7 +1683,7 @@ class MetaxModelRunner(ModelRunnerBase): ) self.exist_prefill_flag = False if self.speculative_decoding: - if self.speculative_method == "mtp": + if self.speculative_method == SpecMethod.MTP: self.proposer.run( full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph, @@ -1728,7 +1725,7 @@ class MetaxModelRunner(ModelRunnerBase): max_dec_len_list=max_dec_len_list, block_num=block_num, ) - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.dummy_prefill_inputs( num_tokens=num_tokens, batch_size=batch_size, @@ -1816,7 +1813,7 @@ class MetaxModelRunner(ModelRunnerBase): logger.info( f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}" ) - elif self.speculative_decoding and self.speculative_method == "mtp": + elif self.speculative_decoding and self.speculative_method == SpecMethod.MTP: # Capture Target Model without bsz 1 for capture_size in sorted(capture_sizes, reverse=True): self._dummy_run( @@ -2233,7 +2230,7 @@ class MetaxModelRunner(ModelRunnerBase): enable_pd_reorder=getattr(self.share_inputs, "enable_pd_reorder", False), ) - if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill": + if self.speculative_config.method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill": skip_save_output = True else: skip_save_output = False @@ -2258,7 +2255,7 @@ class MetaxModelRunner(ModelRunnerBase): # 6. 
Speculative decode if self.speculative_decoding: - if self.speculative_method == "mtp": + if self.speculative_method == SpecMethod.MTP: self.proposer.run( full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph ) @@ -2386,7 +2383,7 @@ class MetaxModelRunner(ModelRunnerBase): # TODO(gongshaotian): Optimize the management logic of kvcache self.num_gpu_blocks = self.cache_config.total_block_num self.initialize_kv_cache(profile=True) - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True) # 1. Profile with multimodal encoder & encoder cache @@ -2402,7 +2399,7 @@ class MetaxModelRunner(ModelRunnerBase): ) # 3. gc - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.clear_mtp_cache(profile=True) self.clear_cache(profile=True) @@ -2433,7 +2430,7 @@ class MetaxModelRunner(ModelRunnerBase): } ) - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.update_mtp_block_num(num_gpu_blocks) def cal_theortical_kvcache(self): @@ -2464,7 +2461,7 @@ class MetaxModelRunner(ModelRunnerBase): # NOTE(liuzichang): Implement multi-layer MTP architecture in the future num_layers = ( self.model_config.num_hidden_layers + self.speculative_config.num_gpu_block_expand_ratio - if self.speculative_method in ["mtp"] + if self.speculative_method == SpecMethod.MTP else self.model_config.num_hidden_layers ) @@ -2517,7 +2514,7 @@ class MetaxModelRunner(ModelRunnerBase): self.dynamic_weight_manager.clear_parameters( pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle ) - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.clear_mtp_cache() self.clear_cache() paddle.device.empty_cache() @@ -2540,7 +2537,7 @@ class MetaxModelRunner(ModelRunnerBase): 
self.dynamic_weight_manager.update_parameters( pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle ) - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks) self.initialize_kv_cache() # Recapture CUDAGraph diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index f827c5dc00..9de01c93fd 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -62,7 +62,7 @@ from fastdeploy.model_executor.xpu_pre_and_post_process import ( xpu_pre_process, xpu_process_output, ) -from fastdeploy.spec_decode import MTPProposer +from fastdeploy.spec_decode import SpecMethod from fastdeploy.utils import get_logger from fastdeploy.worker.model_runner_base import ModelRunnerBase from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput @@ -727,7 +727,7 @@ class XPUModelRunner(ModelRunnerBase): if has_prefill_task or has_decode_task: self.share_inputs["not_need_stop"][0] = True - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.insert_tasks_v1(req_dicts, num_running_requests) def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int): @@ -876,7 +876,7 @@ class XPUModelRunner(ModelRunnerBase): self.share_inputs["not_need_stop"][0] = True - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request( request, "temp_scaled_logprobs", False ) @@ -1433,7 +1433,7 @@ class XPUModelRunner(ModelRunnerBase): block_num=block_num, ) - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.dummy_prefill_inputs( num_tokens=num_tokens, batch_size=batch_size, @@ -1450,17 +1450,16 @@ class XPUModelRunner(ModelRunnerBase): """ Init speculative 
proposer """ - if self.speculative_method == "ngram": + if self.speculative_method == SpecMethod.NGRAM: # xpu not support ngram proposer now - # self.proposer = NgramProposer(self.fd_config) self.proposer = None - elif self.speculative_method == "mtp": - self.proposer = MTPProposer( + elif self.speculative_method == SpecMethod.MTP: + self.proposer = self.speculative_method.create_proposer( self.fd_config, - self.get_model(), - self.local_rank, - self.device_id, - self.share_inputs, + main_model=self.get_model(), + local_rank=self.local_rank, + device_id=self.device_id, + share_inputs=self.share_inputs, ) else: self.proposer = None @@ -1631,7 +1630,7 @@ class XPUModelRunner(ModelRunnerBase): ) skip_save_output = is_dummy_run or ( - self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill" + self.speculative_config.method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill" ) if self.speculative_decoding: @@ -1657,7 +1656,7 @@ class XPUModelRunner(ModelRunnerBase): ) # 6. Draft model propose - if self.speculative_method == "mtp": + if self.speculative_method == SpecMethod.MTP: self.proposer.run(full_hidden_states=model_output) # 7. 
Updata 'infer_seed' and step_paddle() @@ -1711,7 +1710,7 @@ class XPUModelRunner(ModelRunnerBase): """Execute a forward pass with dummy inputs to profile the memory usage of the model""" self.num_gpu_blocks = self.cache_config.total_block_num - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True) self.initialize_kv_cache(profile=True) @@ -1733,7 +1732,7 @@ class XPUModelRunner(ModelRunnerBase): self.num_gpu_blocks = num_gpu_blocks # Reset block table and kv cache with global block num - if self.speculative_method in ["mtp"]: + if self.speculative_method == SpecMethod.MTP: self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks) self.initialize_kv_cache() diff --git a/tests/layers/test_sampler.py b/tests/layers/test_sampler.py index e3d6d426ed..d3cd8716b9 100644 --- a/tests/layers/test_sampler.py +++ b/tests/layers/test_sampler.py @@ -58,6 +58,11 @@ def _disable_triton_cuda_path(monkeypatch): monkeypatch.setattr( "fastdeploy.model_executor.layers.sample.sampler.batched_count_greater_than", lambda x, y: (x >= y).sum(-1) ) + # Also patch batched_count_greater_than in logprobs module itself, because + # build_output_logprobs -> gather_logprobs calls it from logprobs scope. 
+ monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.logprobs.batched_count_greater_than", lambda x, y: (x >= y).sum(-1) + ) @pytest.fixture @@ -296,8 +301,15 @@ def test_speculative_sampler_basic(monkeypatch): fd_config = types.SimpleNamespace( model_config=types.SimpleNamespace(logprobs_mode="raw_logits", think_end_id=1, line_break_id=2), speculative_config=types.SimpleNamespace( - verify_window=2, max_candidate_len=4, benchmark_mode=False, enf_gen_phase_tag=False + method="ngram", + verify_window=2, + max_candidate_len=4, + benchmark_mode=False, + enf_gen_phase_tag=False, + verify_strategy="topp", + accept_policy="normal", ), + parallel_config=types.SimpleNamespace(prefill_one_step_stop=False), ) monkeypatch.setattr("fastdeploy.model_executor.layers.sample.sampler.current_platform.is_cuda", lambda: True) monkeypatch.setattr("fastdeploy.model_executor.layers.sample.sampler.current_platform.is_xpu", lambda: False) diff --git a/tests/layers/test_speculative_sampler.py b/tests/layers/test_speculative_sampler.py index e0fe92aa2f..7abd94286d 100644 --- a/tests/layers/test_speculative_sampler.py +++ b/tests/layers/test_speculative_sampler.py @@ -17,6 +17,7 @@ from unittest.mock import Mock import paddle +import pytest from fastdeploy.config import ( CacheConfig, @@ -34,6 +35,22 @@ from fastdeploy.model_executor.layers.sample.sampler import ( ) +@pytest.fixture(autouse=True) +def _ensure_triton_fallback(monkeypatch): + """Ensure batched_count_greater_than uses the non-triton fallback. + + When test_sampler.py runs before this file in the same pytest session, + it installs a triton stub (triton.jit = lambda fn: fn) at module level, + which permanently turns count_greater_kernel into a plain function that + is not subscriptable. Monkeypatching the logprobs module to use the + non-triton fallback avoids the resulting TypeError. 
+ """ + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.logprobs.batched_count_greater_than", + lambda x, y: (x >= y).sum(-1), + ) + + def _create_fake_logits(batch_size: int, vocab_size: int) -> paddle.Tensor: fake_logits = paddle.rand(shape=[batch_size, vocab_size], dtype="float32") return fake_logits @@ -80,12 +97,12 @@ def _create_default_sampling_metadata( return fake_sampling_metadata -def _create_fd_config(max_model_len): +def _create_fd_config(max_model_len, method=None): model_config: Mock = Mock() model_config.max_model_len = max_model_len model_config.architectures = ["test_model"] model_config.mm_max_tokens_per_item = None - speculative_config = SpeculativeConfig({}) + speculative_config = SpeculativeConfig({"method": method} if method else {}) graph_opt_config = GraphOptimizationConfig({}) scheduler_config = SchedulerConfig({}) parallel_config = ParallelConfig({}) @@ -169,7 +186,8 @@ def test_speculative_sampler(): max_model_len = 1024 max_draft_token_num = 1 - fd_config = _create_fd_config(max_model_len) + # Use ngram method for speculative decoding + fd_config = _create_fd_config(max_model_len, method="ngram") sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len) logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size) share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size) @@ -186,7 +204,8 @@ def test_speculative_sampler_logprobs(): max_model_len = 1024 max_draft_token_num = 1 - fd_config = _create_fd_config(max_model_len) + # Use ngram method for speculative decoding + fd_config = _create_fd_config(max_model_len, method="ngram") share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size) sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0) sampling_metadata.share_inputs = share_inputs diff --git 
a/tests/operators/test_unified_update_model_status.py b/tests/operators/test_unified_update_model_status.py new file mode 100644 index 0000000000..566544bdc4 --- /dev/null +++ b/tests/operators/test_unified_update_model_status.py @@ -0,0 +1,574 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for unified_update_model_status kernel. + +Kernel semantics (from unified_update_model_status.cu): + - Launched as <<<1, 1024>>>, one thread per batch slot (max_bsz <= 1024). + - real_bsz = seq_lens_this_time.shape[0], max_bsz = stop_flags.shape[0]. + - has_running_seqs is a CPU tensor (copied to GPU, kernel writes, copied back). + - Padding slots (batch_id >= real_bsz): only counted as stopped, NO state modified. + - Stopped/paused real slots: set stop_flags=true, seq_lens_decoder=0, + seq_lens_this_time=0, step_output_len=0. + - Running slots: EOS detection → state update → token_ids_all write → next input setup. 
+""" + +import unittest +from typing import Any, Dict + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.gpu import unified_update_model_status + +CUDA_PLACE = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace() +CPU_PLACE = paddle.CPUPlace() + + +# ============================================================ +# Layer 1: Helpers — tensor creation / kernel invocation / output extraction +# ============================================================ + + +def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]: + """Convert numpy dict → paddle tensors. has_running_seqs goes to CPU.""" + paddle_inputs = {} + for k, v in inputs.items(): + if isinstance(v, (int, bool, float, str)): + paddle_inputs[k] = v + elif k == "has_running_seqs": + # Kernel host function: has_running_seqs.copy_to(GPU) → kernel → copy_to(CPU) + paddle_inputs[k] = paddle.to_tensor(v, place=CPU_PLACE) + elif v is not None: + paddle_inputs[k] = paddle.to_tensor(v, place=CUDA_PLACE) + else: + paddle_inputs[k] = None + return paddle_inputs + + +def run_kernel(paddle_inputs: Dict[str, Any], inputs: Dict[str, Any]): + """Call unified_update_model_status kernel.""" + unified_update_model_status( + paddle_inputs["seq_lens_encoder"], + paddle_inputs["seq_lens_decoder"], + paddle_inputs["has_running_seqs"], + paddle_inputs["step_input_ids"], + paddle_inputs["adaptive_step_input_len"], + paddle_inputs["step_output_ids"], + paddle_inputs["step_output_len"], + paddle_inputs["stop_flags"], + paddle_inputs["seq_lens_this_time"], + paddle_inputs["is_paused"], + paddle_inputs["mask_rollback"], + paddle_inputs["token_ids_all"], + paddle_inputs["prompt_lens"], + paddle_inputs["step_idx"], + paddle_inputs["end_tokens"], + paddle_inputs["max_dec_len"], + inputs["is_naive_mode"], + inputs["prefill_one_step_stop"], + ) + + +# All 12 in-place output keys (from SetInplaceMap in .cu) +OUTPUT_KEYS = [ + "seq_lens_encoder", + "seq_lens_decoder", + "has_running_seqs", + 
"step_input_ids", + "step_output_ids", + "step_output_len", + "stop_flags", + "seq_lens_this_time", + "mask_rollback", + "token_ids_all", + "step_idx", + # adaptive_step_input_len is in InplaceMap but kernel never writes it +] + + +def get_outputs(paddle_inputs: Dict[str, Any]) -> Dict[str, np.ndarray]: + """Extract ALL in-place-modified tensors back to numpy.""" + return {k: paddle_inputs[k].numpy() for k in OUTPUT_KEYS} + + +# ============================================================ +# Layer 2: Input generation +# ============================================================ + + +def gen_inputs( + real_bsz: int = 8, + max_step_tokens: int = 16, + max_model_len: int = 256, + seed: int = 42, + is_naive_mode: bool = False, + prefill_one_step_stop: bool = False, +) -> Dict[str, Any]: + """Generate randomized test inputs for unified_update_model_status kernel. + + Shapes follow the kernel contract: + - real_bsz = seq_lens_this_time.shape[0] + - max_bsz = stop_flags.shape[0] (= real_bsz + padding) + - is_paused.shape[0] = max_bsz + """ + rng = np.random.default_rng(seed) + max_bsz = real_bsz + 4 # padding slots + + # Per-slot arrays (size=max_bsz) + seq_lens_encoder = rng.integers(0, 5, size=max_bsz, dtype=np.int32) + seq_lens_decoder = rng.integers(10, 100, size=max_bsz, dtype=np.int32) + step_input_ids = rng.integers(0, 1000, size=(max_bsz, max_step_tokens), dtype=np.int64) + adaptive_step_input_len = rng.integers(1, max_step_tokens + 1, size=max_bsz, dtype=np.int32) + step_output_ids = rng.integers(0, 1000, size=(max_bsz, max_step_tokens), dtype=np.int64) + step_output_len = rng.integers(1, max_step_tokens + 1, size=max_bsz, dtype=np.int32) + stop_flags = np.zeros(max_bsz, dtype=bool) + # Randomly stop a few real slots + stop_flags[rng.choice(real_bsz, size=min(2, real_bsz), replace=False)] = True + # Padding slots (batch_id >= real_bsz) must be stopped — kernel accesses + # seq_lens_this_time[batch_id] which is only sized real_bsz + stop_flags[real_bsz:] = True 
+ is_paused = np.zeros(max_bsz, dtype=bool) + mask_rollback = np.zeros(max_bsz, dtype=np.int32) + prompt_lens = rng.integers(10, 50, size=max_bsz, dtype=np.int64) + token_ids_all = rng.integers(0, 1000, size=(max_bsz, max_model_len), dtype=np.int64) + step_idx = rng.integers(0, 50, size=max_bsz, dtype=np.int64) + max_dec_len = rng.integers(100, 200, size=max_bsz, dtype=np.int64) + + # Per-real-batch arrays (size=real_bsz) + seq_lens_this_time = rng.integers(1, max_step_tokens + 1, size=real_bsz, dtype=np.int32) + + # Scalar / small tensors + has_running_seqs = np.array([True], dtype=bool) + end_tokens = rng.integers(1, 1000, size=4, dtype=np.int64) + + return { + "seq_lens_encoder": seq_lens_encoder, + "seq_lens_decoder": seq_lens_decoder, + "has_running_seqs": has_running_seqs, + "step_input_ids": step_input_ids, + "adaptive_step_input_len": adaptive_step_input_len, + "step_output_ids": step_output_ids, + "step_output_len": step_output_len, + "stop_flags": stop_flags, + "seq_lens_this_time": seq_lens_this_time, + "is_paused": is_paused, + "mask_rollback": mask_rollback, + "token_ids_all": token_ids_all, + "prompt_lens": prompt_lens, + "step_idx": step_idx, + "end_tokens": end_tokens, + "max_dec_len": max_dec_len, + # Scalar configs + "real_bsz": real_bsz, + "max_bsz": max_bsz, + "max_step_tokens": max_step_tokens, + "max_model_len": max_model_len, + "is_naive_mode": is_naive_mode, + "prefill_one_step_stop": prefill_one_step_stop, + } + + +# ============================================================ +# Layer 3: Reference implementation (1:1 with CUDA kernel) +# ============================================================ + + +def reference_impl(inputs: Dict[str, Any]) -> Dict[str, Any]: + """Python reference of unified_update_model_status_kernel. + + Line references are to unified_update_model_status.cu. 
+ """ + # Deep-copy all mutable in-place tensors + seq_lens_encoder = inputs["seq_lens_encoder"].copy() + seq_lens_decoder = inputs["seq_lens_decoder"].copy() + step_output_len = inputs["step_output_len"].copy() + stop_flags = inputs["stop_flags"].copy() + seq_lens_this_time = inputs["seq_lens_this_time"].copy() + mask_rollback = inputs["mask_rollback"].copy() + token_ids_all = inputs["token_ids_all"].copy() + step_idx = inputs["step_idx"].copy() + step_input_ids = inputs["step_input_ids"].copy() + step_output_ids = inputs["step_output_ids"].copy() + + # Read-only inputs + real_bsz = inputs["real_bsz"] + max_bsz = inputs["max_bsz"] + max_model_len = inputs["max_model_len"] + is_naive_mode = inputs["is_naive_mode"] + prefill_one_step_stop = inputs["prefill_one_step_stop"] + end_tokens = inputs["end_tokens"] + num_end_tokens = len(end_tokens) + max_dec_len = inputs["max_dec_len"] + prompt_lens = inputs["prompt_lens"] + is_paused = inputs["is_paused"] + + # Block-level stop count for has_running_seqs reduction (line 175) + stop_count = 0 + + for batch_id in range(max_bsz): + # --- line 68-75: Read state --- + cur_seq_len_encoder = int(seq_lens_encoder[batch_id]) + cur_seq_len_decoder = int(seq_lens_decoder[batch_id]) + cur_stop_flag = bool(stop_flags[batch_id]) + output_len = 0 + cur_step_idx = int(step_idx[batch_id]) + cur_is_paused = bool(is_paused[batch_id]) + + # line 77 + is_running = not cur_stop_flag and not cur_is_paused + + # --- line 80-86: Compute output length --- + if is_running: + output_len = 1 if is_naive_mode else int(step_output_len[batch_id]) + + # --- line 89-110: EOS detection --- + if is_running and output_len > 0: + hit_stop = False + for i in range(output_len): + cur_step_idx += 1 # line 94 + token = int(step_output_ids[batch_id, i]) # line 95 + is_eos = any(token == end_tokens[j] for j in range(num_end_tokens)) # line 96 + max_len_hit = cur_step_idx >= int(max_dec_len[batch_id]) # line 97 + + if is_eos or max_len_hit: # line 99 + if not 
is_eos: + step_output_ids[batch_id, i] = end_tokens[0] # line 100 + output_len = i + 1 # line 101 + cur_stop_flag = True # line 102 + hit_stop = True # line 103 + break # line 104 + + # line 108-110 + if not hit_stop and prefill_one_step_stop and cur_seq_len_encoder > 0: + cur_stop_flag = True + + # --- line 114-166: Update state and write back --- + if is_running: + if cur_stop_flag: + # line 115-119 + stop_count += 1 + if output_len == 0: + cur_seq_len_decoder = 0 # line 117 + stop_flags[batch_id] = True # line 118 + mask_rollback[batch_id] = 0 # line 119 + elif cur_seq_len_encoder == 0: + # line 120-122 + cur_seq_len_decoder += output_len # line 121 + mask_rollback[batch_id] = int(seq_lens_this_time[batch_id]) - output_len # line 122 + else: + # line 123-124 (encoder > 0, not stopped) + mask_rollback[batch_id] = 0 + + # line 127-130: Fold encoder into decoder + if cur_seq_len_encoder > 0: + cur_seq_len_decoder += cur_seq_len_encoder # line 128 + cur_seq_len_encoder = 0 # line 129 + + # line 132-135: Write back scalar state + seq_lens_encoder[batch_id] = cur_seq_len_encoder + seq_lens_decoder[batch_id] = cur_seq_len_decoder + step_output_len[batch_id] = output_len + step_idx[batch_id] = cur_step_idx + + # line 138-145: Write history to token_ids_all + if cur_step_idx > 0 and output_len > 0: + base = int(prompt_lens[batch_id]) + for i in range(output_len): + # token_ids_all_now[cur_step_idx - i] = output_ids[output_len - 1 - i] + write_idx = base + cur_step_idx - i + if 0 <= write_idx < max_model_len: + token_ids_all[batch_id, write_idx] = step_output_ids[batch_id, output_len - 1 - i] + + # line 148-151: Setup next step_input_ids + if output_len > 0: + step_input_ids[batch_id, 0] = step_output_ids[batch_id, output_len - 1] + + # line 153-155: naive_mode → seq_lens_this_time + if is_naive_mode: + seq_lens_this_time[batch_id] = 0 if cur_stop_flag else 1 + + elif batch_id >= real_bsz: + # line 156-158: Padding slot — only count, don't modify state + stop_count += 1 + 
else: + # line 159-166: Stopped or paused real slot + stop_count += 1 + stop_flags[batch_id] = True # line 162 + seq_lens_decoder[batch_id] = 0 # line 163 + seq_lens_this_time[batch_id] = 0 # line 164 + step_output_len[batch_id] = 0 # line 165 + + # line 177-179: has_running_seqs = stop_sum < max_bsz + has_running_seqs = np.array([stop_count < max_bsz], dtype=bool) + + return { + "seq_lens_encoder": seq_lens_encoder, + "seq_lens_decoder": seq_lens_decoder, + "has_running_seqs": has_running_seqs, + "step_input_ids": step_input_ids, + "step_output_ids": step_output_ids, + "step_output_len": step_output_len, + "stop_flags": stop_flags, + "seq_lens_this_time": seq_lens_this_time, + "mask_rollback": mask_rollback, + "token_ids_all": token_ids_all, + "step_idx": step_idx, + } + + +# ============================================================ +# Layer 4a: TEST_CONFIGS +# ============================================================ + +TEST_CONFIGS = [ + # --- basic mode coverage --- + { + "name": "mtp_mode", + "real_bsz": 8, + "max_step_tokens": 16, + "max_model_len": 256, + "seed": 42, + "is_naive_mode": False, + }, + { + "name": "naive_mode", + "real_bsz": 8, + "max_step_tokens": 16, + "max_model_len": 256, + "seed": 42, + "is_naive_mode": True, + }, + # --- batch size --- + { + "name": "small_batch", + "real_bsz": 1, + "max_step_tokens": 8, + "max_model_len": 128, + "seed": 42, + "is_naive_mode": False, + }, + { + "name": "large_batch", + "real_bsz": 32, + "max_step_tokens": 16, + "max_model_len": 512, + "seed": 42, + "is_naive_mode": False, + }, + # --- prefill_one_step_stop --- + { + "name": "prefill_one_step_stop", + "real_bsz": 8, + "max_step_tokens": 8, + "max_model_len": 128, + "seed": 42, + "is_naive_mode": False, + "prefill_one_step_stop": True, + }, + # --- different seeds for randomized coverage --- + { + "name": "seed_100", + "real_bsz": 8, + "max_step_tokens": 16, + "max_model_len": 256, + "seed": 100, + "is_naive_mode": False, + }, + { + "name": 
"seed_200_naive", + "real_bsz": 8, + "max_step_tokens": 16, + "max_model_len": 256, + "seed": 200, + "is_naive_mode": True, + }, +] + + +# ============================================================ +# Layer 4b: Test suite +# ============================================================ + + +class TestUnifiedUpdateModelStatus(unittest.TestCase): + + def setUp(self): + if not paddle.is_compiled_with_cuda(): + self.skipTest("Requires CUDA") + + # ------ shared helpers ------ + + def _run_and_get(self, inputs: Dict[str, Any]) -> Dict[str, np.ndarray]: + paddle_inputs = to_paddle_inputs(inputs) + run_kernel(paddle_inputs, inputs) + return get_outputs(paddle_inputs) + + def _check_all_outputs(self, inputs: Dict[str, Any], outputs: Dict[str, np.ndarray]): + """Compare ALL output tensors against reference + sanity checks.""" + ref = reference_impl(inputs) + for key in OUTPUT_KEYS: + if not np.array_equal(outputs[key], ref[key]): + diff_mask = outputs[key] != ref[key] + diff_indices = np.argwhere(diff_mask) + for idx in diff_indices[:10]: # print first 10 mismatches + idx_tuple = tuple(idx) + print( + f" [{key}] mismatch at {idx_tuple}: " + f"gpu={outputs[key][idx_tuple]} ref={ref[key][idx_tuple]}" + ) + if key == "token_ids_all": + bid = idx_tuple[0] + print( + f" batch_id={bid}, prompt_lens={inputs['prompt_lens'][bid]}, " + f"step_idx(input)={inputs['step_idx'][bid]}, " + f"step_idx(gpu)={outputs['step_idx'][bid]}, " + f"step_idx(ref)={ref['step_idx'][bid]}, " + f"step_output_len(gpu)={outputs['step_output_len'][bid]}, " + f"step_output_len(ref)={ref['step_output_len'][bid]}, " + f"stop_flags(input)={inputs['stop_flags'][bid]}, " + f"is_paused={inputs['is_paused'][bid]}, " + f"seq_lens_encoder={inputs['seq_lens_encoder'][bid]}" + ) + np.testing.assert_array_equal(outputs[key], ref[key], err_msg=f"{key} mismatch") + + # Sanity: running slots must have encoder zeroed + for i in range(inputs["real_bsz"]): + if not inputs["stop_flags"][i] and not inputs["is_paused"][i]: + 
self.assertEqual(outputs["seq_lens_encoder"][i], 0, f"Running slot {i} should have encoder=0") + self.assertTrue(np.all(outputs["seq_lens_decoder"] >= 0), "negative seq_lens_decoder") + self.assertTrue(np.all(outputs["step_output_len"] >= 0), "negative step_output_len") + self.assertTrue(np.all(outputs["step_idx"] >= 0), "negative step_idx") + + def _run_full_test(self, config: Dict[str, Any]) -> Dict[str, np.ndarray]: + inputs = gen_inputs(**config) + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + return outputs + + # ------ test cases ------ + + def test_configs(self): + """Run all TEST_CONFIGS via subTest.""" + for cfg in TEST_CONFIGS: + with self.subTest(name=cfg["name"]): + test_cfg = {k: v for k, v in cfg.items() if k != "name"} + self._run_full_test(test_cfg) + + def test_eos_detection(self): + """EOS token at position 2 should truncate output_len to 3.""" + inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42) + eos_token = int(inputs["end_tokens"][0]) + inputs["step_output_ids"][0, 2] = eos_token + inputs["step_output_len"][:] = [5, 3, 0, 0, 0, 0] + inputs["stop_flags"][: inputs["real_bsz"]] = False + inputs["is_paused"][:] = False + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def test_max_dec_len_stop(self): + """step_idx near max_dec_len should trigger stop and replace with end_tokens[0].""" + # Use large max_model_len to avoid token_ids_all overflow: + # kernel doesn't bounds-check prompt_lens + step_idx < max_model_len + inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=512, seed=42) + inputs["step_idx"][:] = [95, 50, 0, 0, 0, 0] + inputs["max_dec_len"][:] = 100 + inputs["step_output_len"][:] = [10, 5, 0, 0, 0, 0] + inputs["stop_flags"][: inputs["real_bsz"]] = False + inputs["is_paused"][:] = False + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def test_paused_slots(self): + """Paused slots should be 
treated as stopped/paused (decoder=0, output_len=0).""" + inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42) + inputs["is_paused"][:] = [True, True, False, False, False, False, False, False] + inputs["stop_flags"][: inputs["real_bsz"]] = False + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def test_all_stopped(self): + """All slots stopped → has_running_seqs should be False.""" + inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42) + inputs["stop_flags"][:] = True + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def test_encoder_to_decoder(self): + """Encoder length should fold into decoder: decoder += encoder, encoder → 0.""" + inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42) + inputs["seq_lens_encoder"][:] = [10, 0, 0, 0, 0, 0] + inputs["seq_lens_decoder"][:] = [20, 30, 0, 0, 0, 0] + inputs["step_output_len"][:] = [5, 3, 0, 0, 0, 0] + inputs["stop_flags"][: inputs["real_bsz"]] = False + inputs["is_paused"][:] = False + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def test_token_ids_all_writing(self): + """token_ids_all should be written at prompt_lens + step_idx positions.""" + inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42) + inputs["step_idx"][:] = [10, 20, 0, 0, 0, 0] + inputs["prompt_lens"][:] = [5, 5, 0, 0, 0, 0] + inputs["step_output_len"][:] = [3, 2, 0, 0, 0, 0] + inputs["stop_flags"][: inputs["real_bsz"]] = False + inputs["is_paused"][:] = False + inputs["seq_lens_encoder"][:] = 0 + # Use end_tokens that won't collide with output_ids + inputs["end_tokens"][:] = [9990, 9991, 9992, 9993] + inputs["max_dec_len"][:] = 10000 + inputs["step_output_ids"][0, :3] = [100, 200, 300] + inputs["step_output_ids"][1, :2] = [400, 500] + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def 
test_zero_output_len(self): + """Running slot with output_len=0 in MTP mode: output_len stays 0.""" + inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42) + inputs["step_output_len"][:] = [0, 5, 0, 0, 0, 0] + inputs["stop_flags"][: inputs["real_bsz"]] = False + inputs["is_paused"][:] = False + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def test_prefill_one_step_stop_with_encoder(self): + """prefill_one_step_stop + encoder>0 should stop even without EOS.""" + inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42, prefill_one_step_stop=True) + inputs["seq_lens_encoder"][:] = [5, 0, 0, 0, 0, 0, 0, 0] + inputs["stop_flags"][: inputs["real_bsz"]] = False + inputs["is_paused"][:] = False + # Ensure no accidental EOS hit + inputs["end_tokens"][:] = [9990, 9991, 9992, 9993] + inputs["max_dec_len"][:] = 10000 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def test_mask_rollback(self): + """mask_rollback = seq_lens_this_time - output_len for running decode slots.""" + inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42) + inputs["stop_flags"][: inputs["real_bsz"]] = False + inputs["is_paused"][:] = False + inputs["seq_lens_encoder"][:] = 0 # All decode slots + inputs["seq_lens_this_time"][:] = [6, 4, 8, 3] + inputs["step_output_len"][:] = [3, 2, 5, 1, 0, 0, 0, 0] + # Avoid EOS/max_dec_len hits + inputs["end_tokens"][:] = [9990, 9991, 9992, 9993] + inputs["max_dec_len"][:] = 10000 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/operators/test_verify_draft_tokens.py b/tests/operators/test_verify_draft_tokens.py new file mode 100644 index 0000000000..4129b4844b --- /dev/null +++ b/tests/operators/test_verify_draft_tokens.py @@ -0,0 +1,766 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Unit tests for verify_draft_tokens kernel.

Verification strategies:
- TOPP (0): Verify draft token is in top-p candidate set
- GREEDY (1): Verify draft token matches target model's argmax
- TARGET_MATCH (2): Verify draft token matches target model's sampled token
"""

import random
import unittest
from typing import Any, Dict

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import verify_draft_tokens
from fastdeploy.spec_decode import VerifyStrategy

CUDA_PLACE = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
CPU_PLACE = paddle.CPUPlace()


# ============================================================
# Helpers: tensor creation / kernel invocation / comparison
# ============================================================


def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Convert numpy input dict to paddle tensors on GPU.

    Scalars (int/bool/float/str) pass through unchanged; None stays None.
    All ndarray inputs are placed on CUDA_PLACE (this kernel takes no
    CPU-resident tensors).
    """
    paddle_inputs = {}
    for k, v in inputs.items():
        if isinstance(v, (int, bool, float, str)):
            paddle_inputs[k] = v
        elif v is not None:
            paddle_inputs[k] = paddle.to_tensor(v, place=CUDA_PLACE)
        else:
            paddle_inputs[k] = None
    return paddle_inputs


def run_kernel(paddle_inputs: Dict[str, Any], inputs: Dict[str, Any]):
    """Call verify_draft_tokens kernel.

    Tensor args come from ``paddle_inputs``; scalar attrs come from the raw
    ``inputs`` dict. The kernel mutates step_output_ids / step_output_len
    in place.
    """
    verify_draft_tokens(
        paddle_inputs["step_output_ids"],
        paddle_inputs["step_output_len"],
        paddle_inputs["step_input_ids"],
        paddle_inputs["target_tokens"],
        paddle_inputs["candidate_ids"],
        paddle_inputs["candidate_scores"],
        paddle_inputs["candidate_lens"],
        paddle_inputs["topp"],
        paddle_inputs["stop_flags"],
        paddle_inputs["seq_lens_encoder"],
        paddle_inputs["seq_lens_this_time"],
        paddle_inputs["end_tokens"],
        paddle_inputs["is_block_step"],
        paddle_inputs["cu_seqlens_q_output"],
        paddle_inputs["reasoning_status"],
        paddle_inputs["max_dec_len"],
        paddle_inputs["step_idx"],
        inputs["max_seq_len"],
        inputs["verify_window"],
        inputs["verify_strategy"],
        inputs["reject_all"],
        inputs["accept_all"],
    )


def run_ref(inputs: Dict[str, Any]):
    """Run reference implementation on deep-copied inputs, return (output_ids, output_len).

    Copies every ndarray so the caller's ``inputs`` stay pristine for the
    GPU run / later inspection.
    """
    ref = {k: v.copy() if isinstance(v, np.ndarray) else v for k, v in inputs.items()}
    return verify_draft_tokens_ref(
        ref["step_output_ids"],
        ref["step_output_len"],
        ref["step_input_ids"],
        ref["target_tokens"],
        ref["candidate_ids"],
        ref["candidate_scores"],
        ref["candidate_lens"],
        ref["topp"],
        ref["stop_flags"],
        ref["seq_lens_encoder"],
        ref["seq_lens_this_time"],
        ref["end_tokens"],
        ref["is_block_step"],
        ref["cu_seqlens_q_output"],
        ref["reasoning_status"],
        ref["max_dec_len"],
        ref["step_idx"],
        ref["max_seq_len"],
        ref["verify_window"],
        ref["verify_strategy"],
        ref["reject_all"],
        ref["accept_all"],
    )


def compare_results(
    paddle_inputs: Dict[str, Any],
    step_output_ids_ref: np.ndarray,
    step_output_len_ref: np.ndarray,
    inputs: Dict[str, Any],
    label: str = "unknown",
):
    """Compare GPU kernel output vs reference.

    step_output_len must match exactly for every strategy. For TOPP the
    Phase-2 final token is sampled with device curand state we cannot mirror
    bit-exactly, so only the accepted prefix (ref_len - 1 tokens) is compared;
    for GREEDY/TARGET_MATCH the full id matrix must match.
    """
    gpu_ids = paddle_inputs["step_output_ids"].numpy()
    gpu_len = paddle_inputs["step_output_len"].numpy()

    np.testing.assert_array_equal(
        gpu_len,
        step_output_len_ref,
        err_msg=f"step_output_len mismatch ({label})",
    )

    if inputs["verify_strategy"] == 0:  # TOPP — Phase 2 is stochastic
        real_bsz = inputs["seq_lens_this_time"].shape[0]
        for bid in range(real_bsz):
            ref_len = int(step_output_len_ref[bid])
            if ref_len > 1:
                np.testing.assert_array_equal(
                    gpu_ids[bid, : ref_len - 1],
                    step_output_ids_ref[bid, : ref_len - 1],
                    err_msg=f"step_output_ids (accepted) mismatch at bid={bid} ({label})",
                )
    else:
        np.testing.assert_array_equal(
            gpu_ids,
            step_output_ids_ref,
            err_msg=f"step_output_ids mismatch ({label})",
        )


# ============================================================
# Reference helpers
# ============================================================


def topp_sampling_kernel(candidate_ids, candidate_scores, curand_value, candidate_len, topp, tid=0):
    """Python mirror of the kernel's top-p sampling over one candidate list.

    Walks the (assumed pre-sorted by score — TODO confirm against kernel)
    candidates accumulating probability mass until it covers
    curand_value * topp; falls back to the top candidate.
    ``tid`` is kept only for signature symmetry with the CUDA code; unused here.
    """
    rand_top_p = curand_value * topp
    sum_scores = 0.0
    for i in range(candidate_len):
        sum_scores += candidate_scores[i]
        if rand_top_p <= sum_scores:
            return int(candidate_ids[i])
    return int(candidate_ids[0])


def is_in_end(token, end_tokens, end_length):
    """True if token is one of the first ``end_length`` EOS tokens."""
    return token in end_tokens[:end_length]


def is_in(candidate_list, token, length):
    """True if token appears in the first ``length`` candidates."""
    return token in candidate_list[:length]


class _VerifyContext:
    """Python mirror of the CUDA VerifyContext struct for reference testing."""

    def __init__(
        self,
        bid,
        max_step_tokens,
        end_length,
        end_tokens,
        max_dec_len,
        step_input_ids_now,
        step_output_ids_flat,
        cur_step_idx,
    ):
        self.bid = bid
        self.max_step_tokens = max_step_tokens
        self.end_length = end_length
        self.end_tokens = end_tokens
        self.max_dec_len = max_dec_len
        self.step_input_ids_now = step_input_ids_now
        self.step_output_ids_flat = step_output_ids_flat
        self.cur_step_idx = cur_step_idx
        # output_len_now starts at 1: the Phase-2 final token is always
        # accounted for (emit_final_token does not increment it).
        self.output_len_now = 1
        self.stopped = False

    def emit_token(self, pos, token):
        """Emit an accepted token to output. Returns True if sequence should stop.

        When max_dec_len is hit without an EOS, the token is replaced by
        end_tokens[0] so the sequence terminates with an EOS.
        """
        self.cur_step_idx += 1
        eos = is_in_end(token, self.end_tokens, self.end_length)
        max_hit = self.cur_step_idx >= int(self.max_dec_len[self.bid])
        if max_hit and not eos:  # equivalent to (eos or max_hit) and not eos
            token = int(self.end_tokens[0])
        self.step_output_ids_flat[self.bid * self.max_step_tokens + pos] = token
        self.output_len_now += 1
        if eos or max_hit:
            self.stopped = True
            return True
        return False

    def emit_final_token(self, pos, token):
        """Emit the Phase 2 final token (no output_len_now increment)."""
        self.cur_step_idx += 1
        eos = is_in_end(token, self.end_tokens, self.end_length)
        max_hit = self.cur_step_idx >= int(self.max_dec_len[self.bid])
        if max_hit and not eos:
            token = int(self.end_tokens[0])
        self.step_output_ids_flat[self.bid * self.max_step_tokens + pos] = token


def verify_draft_tokens_ref(
    step_output_ids,
    step_output_len,
    step_input_ids,
    target_tokens,
    candidate_ids,
    candidate_scores,
    candidate_lens,
    topp,
    stop_flags,
    seq_lens_encoder,
    seq_lens_this_time,
    end_tokens,
    is_block_step,
    cu_seqlens_q_output,
    reasoning_status,
    max_dec_len,
    step_idx,
    max_seq_len,
    verify_window,
    verify_strategy,
    reject_all,
    accept_all,
):
    """Reference implementation of verify_draft_tokens in Python.

    Mutates step_output_ids / step_output_len in place and also returns them.
    Phase 1 walks the seq_lens_this_time[bid]-1 verify positions accepting or
    rejecting draft tokens per strategy; Phase 2 emits one final token at the
    first rejected (or last) position.
    """
    real_bsz = seq_lens_this_time.shape[0]
    max_step_tokens = step_input_ids.shape[1]
    end_length = end_tokens.shape[0]
    max_candidate_len = candidate_ids.shape[1] if candidate_ids is not None else 1

    # NOTE(review): every entry is the first draw of a fresh Random(0), so all
    # values are identical. Presumably this mirrors the kernel's per-thread
    # curand initialization — confirm against the .cu source.
    dev_curand_states = [random.Random(0).random() for _ in range(max_step_tokens)]

    step_output_ids_flat = step_output_ids.reshape(-1)
    step_input_ids_flat = step_input_ids.reshape(-1)
    candidate_ids_flat = candidate_ids.reshape(-1) if candidate_ids is not None else None
    candidate_scores_flat = candidate_scores.reshape(-1) if candidate_scores is not None else None

    for bid in range(real_bsz):
        start_token_id = cu_seqlens_q_output[bid]

        # Blocked or already-stopped slots produce no output this step.
        if is_block_step[bid] or stop_flags[bid]:
            step_output_len[bid] = 0
            continue

        step_input_ids_now = step_input_ids_flat[bid * max_step_tokens :]
        target_tokens_now = target_tokens[start_token_id:] if target_tokens is not None else None
        candidate_ids_now = (
            candidate_ids_flat[start_token_id * max_candidate_len :] if candidate_ids_flat is not None else None
        )
        candidate_lens_now = candidate_lens[start_token_id:] if candidate_lens is not None else None
        candidate_scores_now = (
            candidate_scores_flat[start_token_id * max_candidate_len :] if candidate_scores_flat is not None else None
        )

        ctx = _VerifyContext(
            bid,
            max_step_tokens,
            end_length,
            end_tokens,
            max_dec_len,
            step_input_ids_now,
            step_output_ids_flat,
            int(step_idx[bid]),
        )

        # Phase 1: Verify
        i = 0
        while i < seq_lens_this_time[bid] - 1:
            # Prefill slots, reasoning mode and the reject_all flag skip
            # verification entirely (fall straight through to Phase 2).
            if reject_all or seq_lens_encoder[bid] != 0 or reasoning_status[bid] == 1:
                break
            if accept_all:
                if ctx.emit_token(i, step_input_ids_now[i + 1]):
                    break
                i += 1
                continue

            accepted = False
            if verify_strategy == 0:  # TOPP
                actual_cand_len = min(candidate_lens_now[i], max_candidate_len)
                accepted = is_in(
                    candidate_ids_now[i * max_candidate_len : (i + 1) * max_candidate_len],
                    step_input_ids_now[i + 1],
                    actual_cand_len,
                )
                if not accepted:
                    # verify_window fallback: if the draft token is the top-2
                    # candidate here and the next verify_window drafts each
                    # match the top-1 candidate, accept the whole run.
                    ii = i
                    if (
                        max_candidate_len >= 2
                        and candidate_ids_now[ii * max_candidate_len + 1] == step_input_ids_now[ii + 1]
                    ):
                        j, ii = 0, ii + 1
                        while j < verify_window and ii < seq_lens_this_time[bid] - 1:
                            if candidate_ids_now[ii * max_candidate_len] != step_input_ids_now[ii + 1]:
                                break
                            j += 1
                            ii += 1
                        if j >= verify_window:
                            for k in range(i, ii):
                                if ctx.emit_token(k, step_input_ids_now[k + 1]):
                                    i = k
                                    break
                            if ctx.stopped:
                                break
                            i = ii
                            continue
                    break
            elif verify_strategy in (1, 2):  # GREEDY / TARGET_MATCH
                accepted = target_tokens_now[i] == step_input_ids_now[i + 1]

            if accepted:
                if ctx.emit_token(i, step_input_ids_now[i + 1]):
                    break
            else:
                break
            i += 1

        # Phase 2: Sample for rejected/last position
        if not ctx.stopped:
            if verify_strategy == 0:
                if candidate_lens_now is not None and len(candidate_lens_now) > i:
                    actual_cand_len = min(candidate_lens_now[i], max_candidate_len)
                    accept_token = topp_sampling_kernel(
                        candidate_ids_now[i * max_candidate_len : (i + 1) * max_candidate_len],
                        candidate_scores_now[i * max_candidate_len : (i + 1) * max_candidate_len],
                        dev_curand_states[i],
                        actual_cand_len,
                        topp[bid],
                    )
                else:
                    accept_token = int(step_input_ids_now[0])
            elif verify_strategy in (1, 2):
                accept_token = (
                    int(target_tokens_now[i])
                    if target_tokens_now is not None and len(target_tokens_now) > i
                    else int(step_input_ids_now[0])
                )
            else:
                accept_token = (
                    int(candidate_ids_now[i * max_candidate_len])
                    if candidate_ids_now is not None
                    else int(step_input_ids_now[0])
                )
            ctx.emit_final_token(i, accept_token)

        step_output_len[bid] = ctx.output_len_now

    return step_output_ids, step_output_len


# ============================================================
# Input generation
# ============================================================


def gen_verify_draft_tokens_inputs(
    real_bsz: int = 32,
    max_draft_tokens: int = 16,
    max_seq_len: int = 256,
    max_candidate_len: int = 8,
    verify_window: int = 2,
    end_length: int = 4,
    verify_strategy: int = 1,
    reject_all: bool = False,
    accept_all: bool = False,
    match_ratio: float = 0.0,
    seed: int = 2025,
) -> Dict[str, Any]:
    """Generate test inputs for verify_draft_tokens kernel.

    Args:
        match_ratio: Fraction of draft token positions where target/candidates
            are forced to match step_input_ids, so the acceptance path is exercised.
            0.0 = fully random (mostly rejects), 1.0 = all positions match.

    Returns a dict holding both numpy arrays and the scalar attrs. Note that
    ``is_block_step`` is randomized; tests targeting a specific slot must
    force it (and ``stop_flags``) to False themselves.
    """
    rng = np.random.default_rng(seed)

    seq_lens_encoder = np.zeros(real_bsz, dtype=np.int32)
    seq_lens_this_time = rng.integers(1, max_draft_tokens + 1, size=real_bsz, dtype=np.int32)
    step_input_ids = rng.integers(0, 1000, size=(real_bsz, max_draft_tokens), dtype=np.int64)

    sum_seq = int(np.sum(seq_lens_this_time))

    if verify_strategy in (1, 2):  # GREEDY / TARGET_MATCH
        target_tokens = rng.integers(0, 1000, size=(sum_seq,), dtype=np.int64)
        candidate_ids = None
        candidate_scores = None
        candidate_lens = None
    else:  # TOPP
        target_tokens = None
        candidate_ids = rng.integers(0, 1000, size=(sum_seq, max_candidate_len), dtype=np.int64)
        candidate_scores = rng.random(size=(sum_seq, max_candidate_len)).astype(np.float32)
        candidate_scores = candidate_scores / candidate_scores.sum(axis=1, keepdims=True)
        candidate_lens = rng.integers(1, max_candidate_len + 1, size=sum_seq, dtype=np.int32)

    end_tokens = rng.integers(1, 1000, size=end_length, dtype=np.int64)
    is_block_step = rng.integers(0, 2, size=real_bsz, dtype=bool)

    # Exclusive prefix sums of seq_lens_this_time; the final total is dropped
    # (only indices [0, real_bsz) are read — presumably the kernel reads
    # cu_seqlens_q_output[bid] only; confirm against the .cu source).
    cu_seqlens_q_output = np.zeros(real_bsz + 1, dtype=np.int32)
    for i in range(real_bsz):
        cu_seqlens_q_output[i + 1] = cu_seqlens_q_output[i] + seq_lens_this_time[i]
    cu_seqlens_q_output = cu_seqlens_q_output[:real_bsz].astype(np.int32)

    topp = rng.uniform(0.8, 1.0, size=real_bsz).astype(np.float32)
    reasoning_status = np.zeros(real_bsz, dtype=np.int32)
    step_output_ids = np.zeros((real_bsz, max_draft_tokens), dtype=np.int64)
    step_output_len = np.zeros(real_bsz, dtype=np.int32)
    stop_flags = np.zeros(real_bsz, dtype=bool)

    # Force match_ratio fraction of positions so acceptance path is tested
    if match_ratio > 0.0:
        offset = 0
        for bid in range(real_bsz):
            slt = int(seq_lens_this_time[bid])
            n_match = max(1, int((slt - 1) * match_ratio))  # slt-1 verify positions
            for pos in range(min(n_match, slt - 1)):
                draft_token = int(step_input_ids[bid, pos + 1])
                # Ensure draft_token is not an end_token (would cause early stop)
                while draft_token in end_tokens[:end_length]:
                    draft_token = (draft_token + 1) % 1000
                step_input_ids[bid, pos + 1] = draft_token
                if verify_strategy in (1, 2) and target_tokens is not None:
                    target_tokens[offset + pos] = draft_token
                elif verify_strategy == 0 and candidate_ids is not None:
                    # Place the draft as top-1 candidate; candidate_lens is
                    # already >= 1 by construction.
                    candidate_ids[offset + pos, 0] = draft_token
            offset += slt

    return {
        "step_output_ids": step_output_ids,
        "step_output_len": step_output_len,
        "step_input_ids": step_input_ids,
        "target_tokens": target_tokens,
        "candidate_ids": candidate_ids,
        "candidate_scores": candidate_scores,
        "candidate_lens": candidate_lens,
        "topp": topp,
        "stop_flags": stop_flags,
        "seq_lens_encoder": seq_lens_encoder,
        "seq_lens_this_time": seq_lens_this_time,
        "end_tokens": end_tokens,
        "is_block_step": is_block_step,
        "cu_seqlens_q_output": cu_seqlens_q_output,
        "reasoning_status": reasoning_status,
        "max_dec_len": rng.integers(50, 200, size=real_bsz, dtype=np.int64),
        "step_idx": rng.integers(0, 30, size=real_bsz, dtype=np.int64),
        "max_seq_len": max_seq_len,
        "verify_window": verify_window,
        "verify_strategy": verify_strategy,
        "reject_all": reject_all,
        "accept_all": accept_all,
    }


# ============================================================
# Test configs
# ============================================================

TEST_CONFIGS = [
    # --- strategy coverage (random, mostly rejects) ---
    {
        "name": "greedy_small_batch",
        "real_bsz": 1,
        "max_draft_tokens": 9,
        "max_seq_len": 11,
        "max_candidate_len": 4,
        "verify_window": 2,
        "end_length": 5,
        "verify_strategy": VerifyStrategy.GREEDY.value,
        "seed": 42,
    },
    {
        "name": "greedy_medium_batch",
        "real_bsz": 33,
        "max_draft_tokens": 5,
        "max_seq_len": 10111,
        "max_candidate_len": 5,
        "verify_window": 2,
        "end_length": 6,
        "verify_strategy": VerifyStrategy.GREEDY.value,
        "seed": 42,
    },
    {
        "name": "topp_small_batch",
        "real_bsz": 6,
        "max_draft_tokens": 4,
        "max_seq_len": 10001,
        "max_candidate_len": 6,
        "verify_window": 2,
        "end_length": 7,
        "verify_strategy": VerifyStrategy.TOPP.value,
        "seed": 42,
    },
    {
        "name": "target_match_medium",
        "real_bsz": 7,
        "max_draft_tokens": 3,
        "max_seq_len": 777,
        "max_candidate_len": 7,
        "verify_window": 2,
        "end_length": 5,
        "verify_strategy": VerifyStrategy.TARGET_MATCH.value,
        "seed": 42,
    },
    {
        "name": "greedy_large_batch",
        "real_bsz": 55,
        "max_draft_tokens": 5,
        "max_seq_len": 31,
        "max_candidate_len": 9,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.GREEDY.value,
        "seed": 42,
    },
    # --- partial acceptance (match_ratio forces draft tokens to match target/candidates) ---
    {
        "name": "greedy_half_accept",
        "real_bsz": 8,
        "max_draft_tokens": 8,
        "max_seq_len": 256,
        "max_candidate_len": 4,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.GREEDY.value,
        "seed": 42,
        "match_ratio": 0.5,
    },
    {
        "name": "greedy_full_accept",
        "real_bsz": 8,
        "max_draft_tokens": 8,
        "max_seq_len": 256,
        "max_candidate_len": 4,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.GREEDY.value,
        "seed": 42,
        "match_ratio": 1.0,
    },
    {
        "name": "topp_half_accept",
        "real_bsz": 8,
        "max_draft_tokens": 8,
        "max_seq_len": 256,
        "max_candidate_len": 6,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.TOPP.value,
        "seed": 42,
        "match_ratio": 0.5,
    },
    {
        "name": "topp_full_accept",
        "real_bsz": 8,
        "max_draft_tokens": 8,
        "max_seq_len": 256,
        "max_candidate_len": 6,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.TOPP.value,
        "seed": 42,
        "match_ratio": 1.0,
    },
    {
        "name": "target_match_accept",
        "real_bsz": 8,
        "max_draft_tokens": 6,
        "max_seq_len": 256,
        "max_candidate_len": 4,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.TARGET_MATCH.value,
        "seed": 42,
        "match_ratio": 0.7,
    },
    # --- reject_all / accept_all (kernel-level flags) ---
    {
        "name": "reject_all",
        "real_bsz": 8,
        "max_draft_tokens": 5,
        "max_seq_len": 100,
        "max_candidate_len": 5,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.GREEDY.value,
        "seed": 42,
        "reject_all": True,
    },
    {
        "name": "accept_all",
        "real_bsz": 8,
        "max_draft_tokens": 5,
        "max_seq_len": 100,
        "max_candidate_len": 5,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.TOPP.value,
        "seed": 42,
        "accept_all": True,
    },
    # --- edge cases ---
    {
        "name": "empty_batch",
        "real_bsz": 1,
        "max_draft_tokens": 1,
        "max_seq_len": 10,
        "max_candidate_len": 2,
        "verify_window": 1,
        "end_length": 4,
        "verify_strategy": VerifyStrategy.GREEDY.value,
        "seed": 42,
    },
]


# ============================================================
# Test suite
# ============================================================


class TestVerifyDraftTokens(unittest.TestCase):

    def setUp(self):
        if not paddle.is_compiled_with_cuda():
            self.skipTest("Requires CUDA")

    # ------ shared run + check helper ------

    def _run_and_compare(self, inputs: Dict[str, Any], label: str = ""):
        """Convert→run kernel→run ref→compare."""
        paddle_inputs = to_paddle_inputs(inputs)
        run_kernel(paddle_inputs, inputs)
        ids_ref, len_ref = run_ref(inputs)
        compare_results(paddle_inputs, ids_ref, len_ref, inputs, label)
        return paddle_inputs

    # ------ test cases ------

    def test_verify_configs(self):
        """Test all configs in TEST_CONFIGS (strategies, reject/accept, edge cases)."""
        for cfg in TEST_CONFIGS:
            with self.subTest(name=cfg["name"]):
                test_cfg = {k: v for k, v in cfg.items() if k != "name"}
                inputs = gen_verify_draft_tokens_inputs(**test_cfg)
                self._run_and_compare(inputs, label=cfg["name"])

    def test_eos_handling(self):
        """Test EOS token in draft triggers early stop."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42
        )
        # gen_* randomizes is_block_step; a blocked/stopped slot 0 would skip
        # verification entirely and silently mask the EOS path under test.
        inputs["is_block_step"][:] = False
        inputs["stop_flags"][:] = False
        inputs["step_input_ids"][0, 2] = inputs["end_tokens"][0]
        self._run_and_compare(inputs, label="eos_handling")

    def test_max_dec_len_truncation(self):
        """Test max_dec_len causes token replacement with end_tokens[0]."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0
        )
        # Set step_idx close to max_dec_len so it triggers during verification
        inputs["step_idx"][:] = [48, 10, 10, 10]
        inputs["max_dec_len"][:] = [50, 200, 200, 200]
        inputs["is_block_step"][:] = False
        inputs["stop_flags"][:] = False
        # Ensure no accidental EOS in draft tokens
        for bid in range(4):
            for j in range(5):
                while inputs["step_input_ids"][bid, j] in inputs["end_tokens"]:
                    inputs["step_input_ids"][bid, j] = (inputs["step_input_ids"][bid, j] + 1) % 1000
        self._run_and_compare(inputs, label="max_dec_len_truncation")

    def test_verify_strategy_enum(self):
        self.assertEqual(VerifyStrategy.TOPP.value, 0)
        self.assertEqual(VerifyStrategy.GREEDY.value, 1)
        self.assertEqual(VerifyStrategy.TARGET_MATCH.value, 2)

    def test_verify_strategy_from_string(self):
        self.assertEqual(VerifyStrategy.from_string("topp"), VerifyStrategy.TOPP)
        self.assertEqual(VerifyStrategy.from_string("TOPP"), VerifyStrategy.TOPP)
        self.assertEqual(VerifyStrategy.from_string("greedy"), VerifyStrategy.GREEDY)
        self.assertEqual(VerifyStrategy.from_string("target_match"), VerifyStrategy.TARGET_MATCH)
        with self.assertRaises(ValueError):
            VerifyStrategy.from_string("invalid")

    def test_topp_verify_window_fallback(self):
        """Test TOPP verify_window fallback: top-2 match + consecutive top-1 matches."""
        real_bsz, max_draft_tokens, max_candidate_len, verify_window = 1, 8, 4, 2

        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=real_bsz,
            max_draft_tokens=max_draft_tokens,
            verify_strategy=VerifyStrategy.TOPP.value,
            max_candidate_len=max_candidate_len,
            verify_window=verify_window,
            seed=42,
        )

        # Rebuild arrays for full seq_lens_this_time
        new_slt = max_draft_tokens + 1
        inputs["seq_lens_this_time"] = np.array([new_slt], dtype=np.int32)
        inputs["cu_seqlens_q_output"] = np.array([0], dtype=np.int32)

        rng = np.random.default_rng(42)
        sum_seq = new_slt
        inputs["candidate_ids"] = rng.integers(0, 1000, size=(sum_seq, max_candidate_len), dtype=np.int64)
        inputs["candidate_scores"] = rng.random(size=(sum_seq, max_candidate_len)).astype(np.float32)
        inputs["candidate_scores"] /= inputs["candidate_scores"].sum(axis=1, keepdims=True)
        inputs["candidate_lens"] = rng.integers(1, max_candidate_len + 1, size=sum_seq, dtype=np.int32)

        # Draft tokens
        draft_tokens = [100, 200, 300, 400, 500, 600, 700]
        for i, token in enumerate(draft_tokens):
            inputs["step_input_ids"][0, i + 1] = token

        # Position 0: draft NOT in candidates, but top-2 matches draft
        inputs["candidate_ids"][0] = [999, 100, 998, 997]
        # Positions 1,2: top-1 matches next draft tokens
        inputs["candidate_ids"][1] = [200, 888, 777, 666]
        inputs["candidate_ids"][2] = [300, 555, 444, 333]
        inputs["candidate_lens"][:3] = 4
        inputs["is_block_step"] = np.zeros(real_bsz, dtype=bool)

        self._run_and_compare(inputs, label="verify_window_fallback")

    def test_topp_verify_window_no_fallback(self):
        """Test TOPP when verify_window fallback does NOT trigger."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=1,
            max_draft_tokens=5,
            verify_strategy=VerifyStrategy.TOPP.value,
            max_candidate_len=4,
            verify_window=2,
            seed=42,
        )

        # Force the slot active so the rejection path is actually exercised
        # (gen_* randomizes is_block_step).
        inputs["is_block_step"][:] = False
        inputs["stop_flags"][:] = False
        inputs["step_input_ids"][0, 1:] = [999, 998, 997, 996]
        inputs["candidate_ids"][:] = 0
        inputs["candidate_ids"][0] = [1, 2, 3, 4]
        inputs["candidate_lens"][0] = 4
        inputs["seq_lens_this_time"][0] = 5

        self._run_and_compare(inputs, label="verify_window_no_fallback")


if __name__ == "__main__":
    unittest.main()