From 2178f2829bafe7109183f4af95b303cee106ed08 Mon Sep 17 00:00:00 2001 From: GoldPancake <56388518+Deleter-D@users.noreply.github.com> Date: Thu, 26 Feb 2026 11:42:05 +0800 Subject: [PATCH] [Speculative Decoding] Support suffix decoding (#6403) * support suffix decoding --- .../gpu_ops/update_attn_mask_offsets.cu | 12 +- docs/features/speculative_decoding.md | 36 ++- docs/zh/features/speculative_decoding.md | 36 ++- fastdeploy/config.py | 47 +++- .../cudagraph_piecewise_backend.py | 12 +- fastdeploy/output/token_processor.py | 2 + fastdeploy/spec_decode/__init__.py | 14 ++ fastdeploy/spec_decode/base.py | 24 ++ fastdeploy/spec_decode/suffix.py | 230 ++++++++++++++++++ fastdeploy/worker/gpu_model_runner.py | 41 +++- fastdeploy/worker/input_batch.py | 4 +- requirements.txt | 1 + .../test_cuda_graph_recapture.py | 3 + .../test_cuda_graph_spec_decode.py | 3 + .../test_graph_opt_backend.py | 3 + .../test_static_graph_cuda_graph_split.py | 3 + tests/operators/test_update_attn_mask.py | 5 +- tests/spec_decode/test_suffix_proposer.py | 141 +++++++++++ 18 files changed, 587 insertions(+), 30 deletions(-) create mode 100644 fastdeploy/spec_decode/suffix.py create mode 100644 tests/spec_decode/test_suffix_proposer.py diff --git a/custom_ops/gpu_ops/update_attn_mask_offsets.cu b/custom_ops/gpu_ops/update_attn_mask_offsets.cu index 7d9611c7bb..f928f7d1e4 100644 --- a/custom_ops/gpu_ops/update_attn_mask_offsets.cu +++ b/custom_ops/gpu_ops/update_attn_mask_offsets.cu @@ -66,13 +66,11 @@ __global__ void update_attn_mask_offsets_kernel( attn_mask_offsets_decoder[bid] += seq_len_this_time; // Speculative decoding in text_generation - if (seq_len_this_time > 1) { - for (int i = 0; i < decode_states_len; i++) { - if (i < seq_len_this_time) { - decode_states_now[i] = 0; - } else { - decode_states_now[i] = -1; - } + for (int i = 0; i < decode_states_len; i++) { + if (i < seq_len_this_time && decode_states_now[i] != 1) { + decode_states_now[i] = 0; + } else { + decode_states_now[i] = -1; 
} } } diff --git a/docs/features/speculative_decoding.md b/docs/features/speculative_decoding.md index ae9546c4c3..6a14d57dbb 100644 --- a/docs/features/speculative_decoding.md +++ b/docs/features/speculative_decoding.md @@ -12,6 +12,8 @@ This project implements an efficient **Speculative Decoding** inference framewor - **Ngram** +- **Suffix Decoding** + - **MTP (Multi-Token Prediction)** - ✅ Supported: TP Sharding - ✅ Supported: Shared Prefix @@ -52,7 +54,7 @@ This project implements an efficient **Speculative Decoding** inference framewor ## 🔧 Configuration Parameters -- `method`: The speculative decoding strategy, currently supports `["mtp", "ngram"]`. +- `method`: The speculative decoding strategy, currently supports `["mtp", "ngram", "suffix"]`. - `num_speculative_tokens`: Number of speculative tokens to generate; max is 5, currently MTP supports only 1. - `model`: Path to the MTP draft model when using the `"mtp"` method. - `quantization`: Quantization method of the MTP model (e.g., WINT4). @@ -162,3 +164,35 @@ python -m fastdeploy.entrypoints.openai.api_server \ --speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}' ``` + +## 🌲 Using Suffix Decoding + +Suffix Decoding is a model-free speculative decoding method that accelerates repetitive inference tasks (e.g., agent workflows, coding) using efficient CPU-based suffix trees for rapid draft token prediction, eliminating GPU overhead. 
+ +Run on 4 × H100 GPUs with WINT4 quantization: + +> Config file: benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml + +``` +python -m fastdeploy.entrypoints.openai.api_server \ + --model ${path_to_main_model} \ + --tensor-parallel-size 4 \ + --config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \ + --speculative-config '{"method": "suffix", "num_speculative_tokens": 4, "suffix_decoding_max_tree_depth": 64, "suffix_decoding_max_cached_requests": 10000, "suffix_decoding_max_spec_factor": 1.0, "suffix_decoding_min_token_prob": 0.1}' +``` + +Parameter Descriptions + +``` +# The maximum length of token sequences cached in suffix trees. +self.suffix_decoding_max_tree_depth: int = 64 + +# The maximum number of requests that can be stored in the cache. +self.suffix_decoding_max_cached_requests: int = -1 + +# The factor of matched length, calculated as num_draft_tokens = suffix_decoding_max_spec_factor * matched_length +self.suffix_decoding_max_spec_factor: float = 1.0 + +# The probability threshold for speculated tokens. 
+self.suffix_decoding_min_token_prob: float = 0.1 +``` diff --git a/docs/zh/features/speculative_decoding.md b/docs/zh/features/speculative_decoding.md index 7e81b55fa4..c1401b0beb 100644 --- a/docs/zh/features/speculative_decoding.md +++ b/docs/zh/features/speculative_decoding.md @@ -8,6 +8,8 @@ - **Ngram** +- **后缀解码** + - **MTP (Multi-Token Prediction)** - ✅ 已支持:TP 切分 - ✅ 已支持:共享前缀 @@ -36,7 +38,7 @@ - **高效 DraftModel/MTP 框架**:开发多个融合 Cuda Kernel,统一完成模型类方法的前后处理,相比传统的循环、切片方法,性能高效且易维护 ## 🔧 参数说明 -- `method`: 解码策略,可选值为 `"mtp"` 或 `"ngram"` +- `method`: 解码策略,可选值为 `"mtp"` 、 `"ngram"` 或 `"suffix"` - `num_speculative_tokens`: 每轮预测的 Token 数,最大支持 5(当前 MTP 仅支持 1) - `model`: 若选择 MTP,则需指定 MTP 模型路径 - `quantization`: 模型量化方式,推荐使用 `wint8` @@ -133,3 +135,35 @@ python -m fastdeploy.entrypoints.openai.api_server \ --config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \ --speculative-config '{"method": "ngram", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}' ``` + +## 🌲 使用后缀解码 (Suffix Decoding) + +后缀解码是一种无需草稿模型的投机解码方法,通过在 CPU 上使用高效后缀树进行快速草稿 Token 预测,加速重复性推理任务(如代理工作流程、编码等),消除 GPU 开销。 + +使用 4×H100;量化方式选择 WINT4: + +> 配置文件:benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml + +``` +python -m fastdeploy.entrypoints.openai.api_server \ + --model ${path_to_main_model} \ + --tensor-parallel-size 4 \ + --config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \ + --speculative-config '{"method": "suffix", "num_speculative_tokens": 4, "suffix_decoding_max_tree_depth": 64, "suffix_decoding_max_cached_requests": 10000, "suffix_decoding_max_spec_factor": 1.0, "suffix_decoding_min_token_prob": 0.1}' +``` + +参数描述 + +``` +# 后缀树中缓存的token序列的最大长度 +self.suffix_decoding_max_tree_depth: int = 64 + +# 缓存中可存储的请求数量上限 +self.suffix_decoding_max_cached_requests: int = -1 + +# 匹配长度的系数,计算方式为 num_draft_tokens = suffix_decoding_max_spec_factor * matched_length +self.suffix_decoding_max_spec_factor: float = 1.0 + +# 推测token的概率阈值 +self.suffix_decoding_min_token_prob: float = 0.1 
+``` diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 853e84ae4f..332066eb1d 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -721,7 +721,7 @@ class SpeculativeConfig: self, args, ): - self.method_list = ["ngram_match", "mtp"] + self.method_list = ["ngram_match", "mtp", "suffix"] self.mtp_strategy_list = ["default", "with_ngram"] # speculative method, choose in [None, "ngram_match", "mtp", "hybrid_mtp_ngram"] @@ -739,6 +739,15 @@ class SpeculativeConfig: # ngram match self.max_ngram_size: int = 5 self.min_ngram_size: int = 2 + # Suffix Decoding + # The maximum length of token sequences cached in suffix trees. + self.suffix_decoding_max_tree_depth: int = 64 + # The limits of requests that can be stored in the cache. + self.suffix_decoding_max_cached_requests: int = -1 + # The factor of matched length, calculated as num_draft_tokens = suffix_max_spec_factor * matched_length + self.suffix_decoding_max_spec_factor: float = 1.0 + # The probability threshold for speculated tokens. 
+ self.suffix_decoding_min_token_prob: float = 0.1 # model for mtp/eagle/draft_model self.model: Optional[str] = None # quantization of model @@ -939,6 +948,8 @@ class GraphOptimizationConfig: self.max_capture_size: int = None """ Record maps mapped from real shape to captured size to reduce runtime overhead """ self.real_shape_to_captured_size: dict[int, int] = None + """ Record maps mapped from real batch size to captured size""" + self.real_bsz_to_captured_size: dict[int, int] = {} """ Whether to use shared memory pool for multi capture_size """ self.use_unique_memory_pool: bool = True """ Whether to use cudagraph for draft model.""" @@ -1694,15 +1705,20 @@ class FDConfig: # Initialize cuda graph capture list max_capture_shape = self.scheduler_config.max_num_seqs - if self.speculative_config is not None and self.speculative_config.method == "mtp": + if self.speculative_config is not None and self.speculative_config.method in ["mtp", "suffix"]: max_capture_shape = self.scheduler_config.max_num_seqs * ( self.speculative_config.num_speculative_tokens + 1 ) assert max_capture_shape % 2 == 0, "CUDAGraph only supports capturing even token nums in MTP scenarios." 
+ self.graph_opt_config.real_bsz_to_captured_size = { + k: 0 for k in range(1, self.scheduler_config.max_num_seqs + 1) + } if self.graph_opt_config.cudagraph_only_prefill: max_capture_shape = 512 else: - max_capture_shape = min(512, max_capture_shape) + max_capture_shape = ( + max_capture_shape if self.speculative_config is not None else min(512, max_capture_shape) + ) max_capture_shape_prefill = graph_opt_config.max_capture_shape_prefill @@ -1717,6 +1733,31 @@ class FDConfig: max_capture_shape_prefill=max_capture_shape_prefill, dec_token_per_query_per_step=dec_token_per_query_per_step, ) + if self.speculative_config is not None and self.speculative_config.method in ["mtp", "suffix"]: + real_bsz_to_captured_size = {} + for capture_size in self.graph_opt_config.cudagraph_capture_sizes: + dummy_batch_size = int(capture_size / (self.speculative_config.num_speculative_tokens + 1)) + real_bsz_to_captured_size[dummy_batch_size] = capture_size + + def expand_bsz_map(real_bsz_to_captured_size): + """ + Expand a sparse batch size mapping into a dense one. + + Args: + real_bsz_to_captured_size (dict): Sparse batch size to capture size mapping. + Returns: + dict: Dense batch size to capture size mapping. 
+ """ + sorted_items = sorted(real_bsz_to_captured_size.items()) + result = {} + prev_bsz = 0 + for curr_bsz, cap in sorted_items: + for bsz in range(prev_bsz + 1, curr_bsz + 1): + result[bsz] = cap + prev_bsz = curr_bsz + return result + + self.graph_opt_config.real_bsz_to_captured_size = expand_bsz_map(real_bsz_to_captured_size) self.graph_opt_config.init_with_cudagrpah_size( max_capture_size=max_capture_shape, max_capture_shape_prefill=max_capture_shape_prefill, diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py index e7e9c1b5af..952fc03845 100644 --- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py +++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py @@ -119,6 +119,10 @@ class CudaGraphPiecewiseBackend: if self.fd_config.graph_opt_config.graph_opt_level > 0: self.cuda_graph_manager = Dy2StCudaGraphManager() + self.speculative_decoding = fd_config.speculative_config.method is not None + self.max_num_seqs = fd_config.scheduler_config.max_num_seqs + self.real_bsz_to_captured_size = fd_config.graph_opt_config.real_bsz_to_captured_size + def run_static_model(self, entry: ConcreteSizeEntry, **kwargs): if not entry.captured: @@ -153,7 +157,10 @@ class CudaGraphPiecewiseBackend: # Get real shape (total num tokens) ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding real_shape = ids_remove_padding.shape[0] - + if self.speculative_decoding and all(self.real_bsz_to_captured_size.values()): + seq_lens_this_time: paddle.Tensor = kwargs["forward_meta"].seq_lens_this_time + num_running_requests = seq_lens_this_time.squeeze(axis=-1).nonzero(as_tuple=False)[-1].item() + 1 + real_shape = self.real_bsz_to_captured_size[num_running_requests] exist_prefill = kwargs["forward_meta"].exist_prefill # Static split graph mode: use Static + CUDAGraph for prefill/mixed phase 
static_cudagraph_for_prefill = exist_prefill and not self.full_cuda_graph and self.dy2st @@ -188,6 +195,9 @@ class CudaGraphPiecewiseBackend: # Capture a new cuda graph if entry.cuda_graph is None: + assert ( + real_shape == padding_real_shape + ), f"real_shape:{real_shape} is not equal to padding_real_shape:{padding_real_shape} when capture new graph." # Warmup the model for n in range(entry.num_finished_warmup, self.warm_up_size): entry.num_finished_warmup += 1 diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 8f2be06b1c..382b43d683 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -787,6 +787,8 @@ class TokenProcessor: + i * MAX_DRAFT_TOKENS + accept_num[i] ].tolist() + if accept_num[i] == 0: + continue else: token_id = int(tokens[i, 0]) token_ids = [token_id] diff --git a/fastdeploy/spec_decode/__init__.py b/fastdeploy/spec_decode/__init__.py index 086b5003a0..0b675b826a 100644 --- a/fastdeploy/spec_decode/__init__.py +++ b/fastdeploy/spec_decode/__init__.py @@ -23,3 +23,17 @@ from .mtp import MTPProposer if not current_platform.is_xpu(): from .ngram import NgramProposer __all__ = ["Proposer", "MTPProposer", "NgramProposer"] + +# Suffix proposer requires arctic_inference +try: + from .suffix import SuffixProposer + + _suffix_proposer_available = True +except ImportError: + _suffix_proposer_available = False + SuffixProposer = None + +if _suffix_proposer_available: + __all__ = ["Proposer", "MTPProposer", "NgramProposer", "SuffixProposer"] +else: + __all__ = ["Proposer", "MTPProposer", "NgramProposer"] diff --git a/fastdeploy/spec_decode/base.py b/fastdeploy/spec_decode/base.py index b7611f6864..6499b35827 100644 --- a/fastdeploy/spec_decode/base.py +++ b/fastdeploy/spec_decode/base.py @@ -96,3 +96,27 @@ class Proposer(ABC): bool: True if chunk prefill is enabled; False otherwise. 
""" return False + + def prepare_dummy_speculative_drafts( + self, + share_inputs, + batch_size: int, + ) -> None: + """ + Construct a set of dummy draft tokens for CUDAGraph capture scenarios, + used only to stabilize shape/step count, with no requirement for semantic correctness. + + Args: + share_inputs: share_inputs dict maintained by GPUModelRunner + batch_size: current batch_size for dummy_run + expected_decode_len: expected number of decode steps (must match what's passed to _dummy_run) + """ + + max_fake_drafts = self.max_draft_token_num + + stop = share_inputs["stop_flags"][0].item() + if not stop: + share_inputs["draft_tokens"][:batch_size, :max_fake_drafts] = 5 + share_inputs["seq_lens_this_time"][:batch_size] = max_fake_drafts + 1 + else: + share_inputs["seq_lens_this_time"][:batch_size] = 0 diff --git a/fastdeploy/spec_decode/suffix.py b/fastdeploy/spec_decode/suffix.py new file mode 100644 index 0000000000..3f2f4586d8 --- /dev/null +++ b/fastdeploy/spec_decode/suffix.py @@ -0,0 +1,230 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import numpy as np + +from fastdeploy.config import FDConfig +from fastdeploy.utils import spec_logger + +from .base import Proposer + +try: + from arctic_inference.suffix_decoding import SuffixDecodingCache +except ImportError: + SuffixDecodingCache = None + + +class SuffixProposer(Proposer): + """ + Proposer for Suffix Decoding method. 
+ + Uses SuffixDecodingCache to generate draft tokens based on suffix tree matching. + """ + + def __init__(self, fd_config: FDConfig): + super().__init__(fd_config) + + if SuffixDecodingCache is None: + raise ImportError( + "arctic_inference.suffix_decoding is not available. Please install arctic-inference package." + ) + + # Initialize SuffixDecodingCache + self.suffix_cache = SuffixDecodingCache( + max_tree_depth=self.speculative_config.suffix_decoding_max_tree_depth, + max_cached_requests=self.speculative_config.suffix_decoding_max_cached_requests, + ) + + self.max_tree_depth = self.speculative_config.suffix_decoding_max_tree_depth + self.max_spec_factor = self.speculative_config.suffix_decoding_max_spec_factor + self.min_token_prob = self.speculative_config.suffix_decoding_min_token_prob + + # Track active requests: req_id -> idx mapping + self.req_id_to_idx = {} + self.idx_to_req_id = {} + self.context_tokens = np.full( + (self.max_num_seqs, self.max_model_len), + -1, + dtype=np.int32, + ) + self.ban_tokens = set([101031, 101032, 101033]) + + def _update_request_mapping(self, idx: int, req_id: str): + """ + Update the mapping between request ID and batch index. + + Args: + req_id: Request identifier + idx: Batch index + """ + # Clean up old mapping if exists + if idx in self.idx_to_req_id: + old_req_id = self.idx_to_req_id[idx] + if old_req_id in self.req_id_to_idx: + del self.req_id_to_idx[old_req_id] + + # Set new mapping + self.req_id_to_idx[req_id] = idx + self.idx_to_req_id[idx] = req_id + + def start_request(self, idx: int, req_id: str, prompt_token_ids: list[int]): + """ + Start a new request in the suffix cache. 
+ + Args: + req_id: Request identifier + prompt_token_ids: List of prompt token IDs + """ + if req_id in self.suffix_cache.active_requests: + # Request already active, skip + return + + prompt_array = np.array(prompt_token_ids, dtype=np.int32) + if not prompt_array.flags["CONTIGUOUS"]: + prompt_array = np.ascontiguousarray(prompt_array) + + self.context_tokens[idx, :] = -1 + self.context_tokens[idx, : len(prompt_token_ids)] = prompt_array + self._update_request_mapping(idx, req_id) + if req_id not in self.suffix_cache.active_requests: + if req_id in self.suffix_cache.cached_requests: + # Reset the suffix cache for current req_id + self.suffix_cache.evict_cached_response(req_id) + spec_logger.debug(f"[SuffixDecoding] Reset suffix cache for request {req_id}.") + self.suffix_cache.start_request(req_id, prompt_array) + spec_logger.debug(f"[SuffixDecoding] Start request {req_id}.") + + def stop_request(self, req_id: str): + """ + Stop a request in the suffix cache. + + Args: + req_id: Request identifier + """ + if req_id in self.suffix_cache.active_requests: + self.suffix_cache.stop_request(req_id) + + # Clean up mappings + if req_id in self.req_id_to_idx: + idx = self.req_id_to_idx[req_id] + + del self.req_id_to_idx[req_id] + if idx in self.idx_to_req_id: + del self.idx_to_req_id[idx] + + spec_logger.debug(f"[SuffixDecoding] Stop request {req_id}.") + + def add_active_response(self, req_id: str, token_ids: list[int]): + """ + Add newly sampled tokens to the suffix cache for a request. 
+ + Args: + req_id: Request identifier + token_ids: List of newly sampled token IDs + """ + if req_id not in self.suffix_cache.active_requests: + return + + token_array = np.array(token_ids, dtype=np.int32) + if not token_array.flags["CONTIGUOUS"]: + token_array = np.ascontiguousarray(token_array) + + self.suffix_cache.add_active_response(req_id, token_array) + + def _run_impl(self, share_inputs): + + stop_flags_cpu = share_inputs["stop_flags"].cpu().numpy().flatten() + is_block_step_cpu = share_inputs["is_block_step"].cpu().numpy().flatten() + accept_tokens_cpu = share_inputs["accept_tokens"].cpu() + accept_num_cpu = share_inputs["accept_num"].cpu().numpy().flatten() + seq_lens_encoder = share_inputs["seq_lens_encoder"].cpu().numpy().flatten() + seq_lens_decoder = share_inputs["seq_lens_decoder"].cpu().numpy().flatten() + + draft_tokens_cpu = share_inputs["draft_tokens"].cpu() + seq_lens_this_time_cpu = share_inputs["seq_lens_this_time"].cpu() + + total_lens = seq_lens_encoder + seq_lens_decoder + batch_size = seq_lens_this_time_cpu.shape[0] + + for bid in range(batch_size): + req_id = self.idx_to_req_id.get(bid) + # 1. Stop condition has the highest priority + if stop_flags_cpu[bid]: + seq_lens_this_time_cpu[bid] = 0 + draft_tokens_cpu[bid, :] = -1 + if not is_block_step_cpu[bid]: + if req_id is not None and req_id in self.suffix_cache.active_requests: + self.stop_request(req_id) + continue + else: + seq_lens_this_time_cpu[bid] = 1 + draft_tokens_cpu[bid, 1:] = -1 + # 2. Skip some cases + num_tokens = total_lens[bid] + max_spec_tokens = min( + self.max_draft_token_num, + self.max_model_len - num_tokens - 1, + ) + if max_spec_tokens <= 1: + continue + if req_id is None: + continue + + # 3. Add accept tokens to context + acc_num = int(accept_num_cpu[bid]) + assert ( + acc_num > 0 + ), f"Request {req_id} (bid {bid}) must have at least one accepted token, but got {acc_num}." 
+ if acc_num > 0: + token_ids = accept_tokens_cpu[bid, :acc_num] + ctx_start = seq_lens_decoder[bid] - acc_num + self.context_tokens[bid, ctx_start : ctx_start + acc_num] = token_ids + self.add_active_response(req_id, token_ids) + + # 4. Get context + start = max(0, num_tokens - self.max_tree_depth) + ctx = self.context_tokens[bid, start:num_tokens] + ctx = ctx[ctx >= 0] + if ctx.size == 0: + continue + if not ctx.flags["CONTIGUOUS"]: + ctx = np.ascontiguousarray(ctx, dtype=np.int32) + else: + ctx = ctx.astype(np.int32, copy=False) + + # 5. Speculate + draft = self.suffix_cache.speculate( + req_id, + ctx, + max_spec_tokens=max_spec_tokens, + max_spec_factor=self.max_spec_factor, + min_token_prob=self.min_token_prob, + ) + token_ids = draft.token_ids + + counter = 0 + for token in token_ids: + if token in self.ban_tokens: + break + else: + counter += 1 + if counter > 0: + draft_tokens_cpu[bid, 1 : 1 + counter] = np.array(token_ids[:counter]) + draft_tokens_cpu[bid, 1 + counter :] = -1 + seq_lens_this_time_cpu[bid] = 1 + counter + + share_inputs["draft_tokens"][:] = draft_tokens_cpu.cuda() + share_inputs["seq_lens_this_time"][:] = seq_lens_this_time_cpu.cuda() diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 0191eaec0e..a635f77840 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -91,7 +91,7 @@ from fastdeploy.model_executor.pre_and_post_process import ( ) if not (current_platform.is_dcu() or current_platform.is_iluvatar()): - from fastdeploy.spec_decode import MTPProposer, NgramProposer + from fastdeploy.spec_decode import MTPProposer, NgramProposer, SuffixProposer import zmq @@ -390,6 +390,8 @@ class GPUModelRunner(ModelRunnerBase): self.device_id, self.share_inputs, ) + elif self.speculative_method == "suffix": + self.proposer = SuffixProposer(self.fd_config) else: self.proposer = None @@ -810,6 +812,13 @@ class GPUModelRunner(ModelRunnerBase): 
self.forward_batch_reqs_list[idx] = request has_prefill_task = True + if self.speculative_decoding and self.speculative_method == "suffix" and self.proposer is not None: + if isinstance(request.prompt_token_ids, np.ndarray): + prompt_token_ids = request.prompt_token_ids.tolist() + else: + prompt_token_ids = request.prompt_token_ids + self.proposer.start_request(idx, request.request_id, prompt_token_ids) + # Routing Replay if self.fd_config.routing_replay_config.enable_routing_replay: if prefill_start_index == 0: @@ -1063,6 +1072,15 @@ class GPUModelRunner(ModelRunnerBase): else: return default_value + # Start suffix decoding request if using suffix proposer + if self.speculative_decoding and self.speculative_method == "suffix" and self.proposer is not None: + if isinstance(request.prompt_token_ids, np.ndarray): + prompt_token_ids = request.prompt_token_ids.tolist() + else: + prompt_token_ids = request.prompt_token_ids + # start_request() also records the idx <-> request_id mapping, so no separate mapping call is needed. + self.proposer.start_request(idx, request.request_id, prompt_token_ids) + assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1) self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7) @@ -1768,6 +1786,7 @@ class GPUModelRunner(ModelRunnerBase): self, hidden_states: paddle.Tensor, model_output: paddle.Tensor, + batch_size: int, accept_all_drafts=False, reject_all_drafts=False, ) -> paddle.Tensor: @@ -1876,8 +1895,7 @@ class GPUModelRunner(ModelRunnerBase): is_dummy_run=True, ) else: - self.proposer.run(share_inputs=self.share_inputs) - + self.proposer.prepare_dummy_speculative_drafts(share_inputs=self.share_inputs, batch_size=batch_size) return sampler_output def _dummy_run( @@ -1954,7 +1972,7 @@ class GPUModelRunner(ModelRunnerBase): (self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None), 
(self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None), ) - self._dummy_sampler_run(hidden_states, model_output, accept_all_drafts, reject_all_drafts) + self._dummy_sampler_run(hidden_states, model_output, batch_size, accept_all_drafts, reject_all_drafts) # 7. Updata 'infer_seed' and step_cuda() self.share_inputs["infer_seed"].add_(self.infer_seed_increment) @@ -2064,22 +2082,19 @@ class GPUModelRunner(ModelRunnerBase): logger.info( f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}" ) - elif self.speculative_decoding and self.speculative_method == "mtp": + elif self.speculative_decoding: # Capture Target Model without bsz 1 for capture_size in sorted(capture_sizes, reverse=True): + expected_decode_len = self.speculative_config.num_speculative_tokens * 2 + 1 self._dummy_run( - num_tokens=( - self.scheduler_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1) - if self.scheduler_config.splitwise_role == "decode" - else self.fd_config.get_max_chunk_tokens() - ), + num_tokens=self.fd_config.get_max_chunk_tokens(), batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)), in_capturing=True, - expected_decode_len=self.speculative_config.num_speculative_tokens * 2 + 1, + expected_decode_len=expected_decode_len, accept_all_drafts=True, ) logger.info( - f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}" + f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{expected_decode_len}" ) else: for batch_size in sorted(capture_sizes, reverse=True): @@ -2536,6 +2551,8 @@ class GPUModelRunner(ModelRunnerBase): self.proposer.run( full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph ) + elif self.speculative_method == "suffix": + self.proposer.run(share_inputs=self.share_inputs) else: self.proposer.run(share_inputs=self.share_inputs) 
diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 92d03f2071..eb600e28e6 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -267,7 +267,7 @@ class InputBatch: self.accept_num = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32") self.draft_tokens = paddle.full( shape=[max_num_seqs, max_draft_token_num + 1], - fill_value=-1, + fill_value=0, dtype="int64", ) @@ -293,7 +293,7 @@ class InputBatch: # For V1_KVCACHE_SCHEDULER self.step_draft_tokens = paddle.full( shape=[max_num_seqs, max_draft_token_num + 1], - fill_value=0, + fill_value=-1, dtype="int64", ) self.step_seq_lens_this_time = paddle.full([max_num_seqs, 1], 0, dtype="int32") diff --git a/requirements.txt b/requirements.txt index 1edb903127..d3af47625c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -48,3 +48,4 @@ p2pstore py-cpuinfo flashinfer-python-paddle flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl +arctic_inference @ https://paddle-qa.bj.bcebos.com/ernie/arctic_inference-0.1.3-cp310-cp310-linux_x86_64.whl diff --git a/tests/graph_optimization/test_cuda_graph_recapture.py b/tests/graph_optimization/test_cuda_graph_recapture.py index 42966db900..a7640c5700 100644 --- a/tests/graph_optimization/test_cuda_graph_recapture.py +++ b/tests/graph_optimization/test_cuda_graph_recapture.py @@ -25,6 +25,7 @@ from fastdeploy.config import ( GraphOptimizationConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, ) from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.graph_optimization.decorator import ( @@ -110,6 +111,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase): cache_config = CacheConfig(args={}) scheduler_config.max_num_seqs = 1 parallel_config = ParallelConfig(args={}) + speculative_config = SpeculativeConfig(args={}) model_config = Mock() model_config.max_model_len = 5120 model_config.architectures = ["test_model"] @@ 
-120,6 +122,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase): cache_config=cache_config, model_config=model_config, parallel_config=parallel_config, + speculative_config=speculative_config, ) # Run Test Case1 diff --git a/tests/graph_optimization/test_cuda_graph_spec_decode.py b/tests/graph_optimization/test_cuda_graph_spec_decode.py index e903590fe8..6ebbac4f77 100644 --- a/tests/graph_optimization/test_cuda_graph_spec_decode.py +++ b/tests/graph_optimization/test_cuda_graph_spec_decode.py @@ -25,6 +25,7 @@ from fastdeploy.config import ( GraphOptimizationConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, ) from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.graph_optimization.decorator import ( @@ -103,6 +104,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase): scheduler_config.max_num_seqs = 1 cache_config = CacheConfig({}) parallel_config = ParallelConfig(args={}) + speculative_config = SpeculativeConfig(args={}) model_config = Mock() model_config.max_model_len = 512 model_config.architectures = ["test_model"] @@ -116,6 +118,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase): cache_config=cache_config, parallel_config=parallel_config, model_config=model_config, + speculative_config=speculative_config, test_mode=True, ) diff --git a/tests/graph_optimization/test_graph_opt_backend.py b/tests/graph_optimization/test_graph_opt_backend.py index d85d94dc5f..f4b19c24bd 100644 --- a/tests/graph_optimization/test_graph_opt_backend.py +++ b/tests/graph_optimization/test_graph_opt_backend.py @@ -27,6 +27,7 @@ from fastdeploy.config import ( GraphOptimizationConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, ) from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.graph_optimization.decorator import ( @@ -144,6 +145,7 @@ class TestGraphOptBackend(unittest.TestCase): # Setup cache config cache_config = CacheConfig({}) parallel_config = ParallelConfig(args={}) + 
speculative_config = SpeculativeConfig(args={}) model_config = Mock() model_config.max_model_len = 512 model_config.architectures = ["test_model"] @@ -156,6 +158,7 @@ class TestGraphOptBackend(unittest.TestCase): cache_config=cache_config, parallel_config=parallel_config, model_config=model_config, + speculative_config=speculative_config, test_mode=True, ) diff --git a/tests/graph_optimization/test_static_graph_cuda_graph_split.py b/tests/graph_optimization/test_static_graph_cuda_graph_split.py index 55c11d8da6..e7c4d5b32a 100644 --- a/tests/graph_optimization/test_static_graph_cuda_graph_split.py +++ b/tests/graph_optimization/test_static_graph_cuda_graph_split.py @@ -31,6 +31,7 @@ from fastdeploy.config import ( GraphOptimizationConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, ) from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.graph_optimization.decorator import ( @@ -97,6 +98,7 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase): graph_opt_config.init_with_cudagrpah_size(max_capture_size=scheduler_config.max_num_seqs) cache_config = CacheConfig({}) parallel_config = ParallelConfig(args={}) + speculative_config = SpeculativeConfig(args={}) model_config = Mock() model_config.max_model_len = 512 model_config.architectures = ["test_model"] @@ -107,6 +109,7 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase): cache_config=cache_config, parallel_config=parallel_config, model_config=model_config, + speculative_config=speculative_config, test_mode=True, ) diff --git a/tests/operators/test_update_attn_mask.py b/tests/operators/test_update_attn_mask.py index 01b4a56d74..60c17f3aa2 100644 --- a/tests/operators/test_update_attn_mask.py +++ b/tests/operators/test_update_attn_mask.py @@ -112,9 +112,8 @@ def py_update_attn_mask_offsets_op( attn_mask_offsets_decoder[bid] = int(attn_mask_offsets_decoder[bid]) + seq_len_this # speculative decoding: if seq_len_this > 1 then set decode_states_now[i] accordingly - 
if seq_len_this > 1: - for i in range(decode_states_len): - decode_now[i] = 0 if i < seq_len_this else -1 + for i in range(decode_states_len): + decode_now[i] = 0 if i < seq_len_this and decode_now[i] != 1 else -1 # done decoder branch continue diff --git a/tests/spec_decode/test_suffix_proposer.py b/tests/spec_decode/test_suffix_proposer.py new file mode 100644 index 0000000000..4faf772346 --- /dev/null +++ b/tests/spec_decode/test_suffix_proposer.py @@ -0,0 +1,141 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import unittest + +import numpy as np +import paddle +from utils import FakeModelConfig, get_default_test_fd_config + +from fastdeploy.config import SpeculativeConfig +from fastdeploy.spec_decode.suffix import SuffixProposer + + +class TestSuffixProposer(unittest.TestCase): + def setUp(self): + self.fd_config = get_default_test_fd_config() + self.fd_config.model_config = FakeModelConfig() + self.fd_config.model_config.max_model_len = 2048 + self.fd_config.speculative_config = SpeculativeConfig({}) + self.fd_config.speculative_config.method = "suffix" + self.fd_config.speculative_config.num_speculative_tokens = 4 + self.fd_config.speculative_config.suffix_decoding_max_tree_depth = 64 + self.fd_config.speculative_config.suffix_decoding_max_cached_requests = 4 + self.fd_config.speculative_config.suffix_decoding_max_spec_factor = 1.0 + self.fd_config.speculative_config.suffix_decoding_min_token_prob = 0.1 + self.fd_config.scheduler_config.max_num_seqs = 4 + + bsz = self.fd_config.scheduler_config.max_num_seqs + max_draft_tokens = self.fd_config.speculative_config.num_speculative_tokens + self.share_inputs = { + "stop_flags": paddle.full([bsz, 1], fill_value=False, dtype="bool"), + "is_block_step": paddle.full([bsz], fill_value=False, dtype="bool"), + "accept_tokens": paddle.zeros([bsz, max_draft_tokens], dtype="int64"), + "accept_num": paddle.zeros([bsz], dtype="int32"), + "seq_lens_this_time": paddle.zeros([bsz, 1], dtype="int32"), + "seq_lens_encoder": paddle.zeros([bsz, 1], dtype="int32"), + "seq_lens_decoder": paddle.zeros([bsz, 1], dtype="int32"), + "draft_tokens": paddle.zeros([bsz, max_draft_tokens], dtype="int64"), + } + + def test_start_and_stop_request(self): + proposer = SuffixProposer(self.fd_config) + + idx = 0 + req_id = "req-001" + prompt_token_ids = [1, 2, 3, 4] + proposer.start_request(idx, req_id, prompt_token_ids) + + refs_context_tokens = np.full( + (self.fd_config.scheduler_config.max_num_seqs, self.fd_config.model_config.max_model_len), + 
-1, + dtype=np.int32, + ) + refs_req_id_to_idx = {} + refs_idx_to_req_id = {} + refs_context_tokens[idx, : len(prompt_token_ids)] = prompt_token_ids + refs_req_id_to_idx[req_id] = idx + refs_idx_to_req_id[idx] = req_id + + self.assertIsNotNone(proposer.suffix_cache) + np.testing.assert_array_equal(proposer.context_tokens, refs_context_tokens) + np.testing.assert_array_equal(proposer.req_id_to_idx, refs_req_id_to_idx) + np.testing.assert_array_equal(proposer.idx_to_req_id, refs_idx_to_req_id) + + idx = 1 + req_id = "req-002" + prompt_token_ids = [5, 6, 7, 8] + proposer.start_request(idx, req_id, prompt_token_ids) + + refs_context_tokens[idx, : len(prompt_token_ids)] = prompt_token_ids + refs_req_id_to_idx[req_id] = idx + refs_idx_to_req_id[idx] = req_id + + np.testing.assert_array_equal(proposer.context_tokens, refs_context_tokens) + np.testing.assert_array_equal(proposer.req_id_to_idx, refs_req_id_to_idx) + np.testing.assert_array_equal(proposer.idx_to_req_id, refs_idx_to_req_id) + + proposer.stop_request("req-001") + + refs_req_id_to_idx.pop("req-001") + refs_idx_to_req_id.pop(0) + + np.testing.assert_array_equal(proposer.context_tokens, refs_context_tokens) + np.testing.assert_array_equal(proposer.req_id_to_idx, refs_req_id_to_idx) + np.testing.assert_array_equal(proposer.idx_to_req_id, refs_idx_to_req_id) + + def test_propose(self): + + self.share_inputs["accept_tokens"][:, :2] = 42 + self.share_inputs["accept_num"][:] = 2 + self.share_inputs["seq_lens_this_time"][:, :] = 2 + self.share_inputs["seq_lens_encoder"][:, :] = 0 + self.share_inputs["seq_lens_decoder"][:, :] = 100 + self.share_inputs["draft_tokens"][:, :2] = 42 + self.share_inputs["draft_tokens"][:, 2:] = 53 + print(self.share_inputs) + + proposer = SuffixProposer(self.fd_config) + ids = [0, 1, 2, 3] + req_ids = ["req-001", "req-002", "req-003", "req-004"] + prompt_token_ids_list = [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + for idx, req_id, prompt_token_ids in 
zip(ids, req_ids, prompt_token_ids_list): + proposer.start_request(idx, req_id, prompt_token_ids) + + proposer.run(self.share_inputs) + + refs_draft_tokens = np.array( + [ + [42, 42, -1, -1], + [42, 42, -1, -1], + [42, 42, -1, -1], + [42, 42, -1, -1], + ], + dtype=np.int64, + ) + refs_seq_lens_this_time = np.array([[2], [2], [2], [2]], dtype=np.int32) + + np.testing.assert_array_equal(self.share_inputs["draft_tokens"].numpy(), refs_draft_tokens) + np.testing.assert_array_equal(self.share_inputs["seq_lens_this_time"].numpy(), refs_seq_lens_this_time) + + +if __name__ == "__main__": + unittest.main()