[Speculative Decoding] Support suffix decoding (#6403)

* support suffix decoding
This commit is contained in:
GoldPancake
2026-02-26 11:42:05 +08:00
committed by GitHub
parent 6d3fede240
commit 2178f2829b
18 changed files with 587 additions and 30 deletions
@@ -66,13 +66,11 @@ __global__ void update_attn_mask_offsets_kernel(
attn_mask_offsets_decoder[bid] += seq_len_this_time; attn_mask_offsets_decoder[bid] += seq_len_this_time;
// Speculative decoding in text_generation // Speculative decoding in text_generation
if (seq_len_this_time > 1) { for (int i = 0; i < decode_states_len; i++) {
for (int i = 0; i < decode_states_len; i++) { if (i < seq_len_this_time && decode_states_now[i] != 1) {
if (i < seq_len_this_time) { decode_states_now[i] = 0;
decode_states_now[i] = 0; } else {
} else { decode_states_now[i] = -1;
decode_states_now[i] = -1;
}
} }
} }
} }
+35 -1
View File
@@ -12,6 +12,8 @@ This project implements an efficient **Speculative Decoding** inference framewor
- **Ngram** - **Ngram**
- **Suffix Decoding**
- **MTP (Multi-Token Prediction)** - **MTP (Multi-Token Prediction)**
- ✅ Supported: TP Sharding - ✅ Supported: TP Sharding
- ✅ Supported: Shared Prefix - ✅ Supported: Shared Prefix
@@ -52,7 +54,7 @@ This project implements an efficient **Speculative Decoding** inference framewor
## 🔧 Configuration Parameters ## 🔧 Configuration Parameters
- `method`: The speculative decoding strategy, currently supports `["mtp", "ngram"]`. - `method`: The speculative decoding strategy, currently supports `["mtp", "ngram", "suffix"]`.
- `num_speculative_tokens`: Number of speculative tokens to generate; max is 5, currently MTP supports only 1. - `num_speculative_tokens`: Number of speculative tokens to generate; max is 5, currently MTP supports only 1.
- `model`: Path to the MTP draft model when using the `"mtp"` method. - `model`: Path to the MTP draft model when using the `"mtp"` method.
- `quantization`: Quantization method of the MTP model (e.g., WINT4). - `quantization`: Quantization method of the MTP model (e.g., WINT4).
@@ -162,3 +164,35 @@ python -m fastdeploy.entrypoints.openai.api_server \
--speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}' --speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}'
``` ```
## 🌲 Using Suffix Decoding
Suffix Decoding is a model-free speculative decoding method that accelerates repetitive inference tasks (e.g., agent workflows, coding) using efficient CPU-based suffix trees for rapid draft token prediction, eliminating GPU overhead.
Run on 4 × H100 GPUs with WINT4 quantization:
> Config file: benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml
```
python -m fastdeploy.entrypoints.openai.api_server \
--model ${path_to_main_model} \
--tensor-parallel-size 4 \
--config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \
--speculative-config '{"method": "suffix", "num_speculative_tokens": 4, "suffix_decoding_max_tree_depth": 64, "suffix_decoding_max_cached_requests": 10000, "suffix_decoding_max_spec_factor": 1.0, "suffix_decoding_min_token_prob": 0.1}'
```
Parameter Descriptions
```
# The maximum length of token sequences cached in suffix trees.
self.suffix_decoding_max_tree_depth: int = 64
# The maximum number of requests that can be stored in the cache.
self.suffix_decoding_max_cached_requests: int = -1
# The factor applied to the matched length: num_draft_tokens = suffix_decoding_max_spec_factor * matched_length
self.suffix_decoding_max_spec_factor: float = 1.0
# The probability threshold for speculated tokens.
self.suffix_decoding_min_token_prob: float = 0.1
```
+35 -1
View File
@@ -8,6 +8,8 @@
- **Ngram** - **Ngram**
- **后缀解码**
- **MTP (Multi-Token Prediction)** - **MTP (Multi-Token Prediction)**
- ✅ 已支持:TP 切分 - ✅ 已支持:TP 切分
- ✅ 已支持:共享前缀 - ✅ 已支持:共享前缀
@@ -36,7 +38,7 @@
- **高效 DraftModel/MTP 框架**:开发多个融合 Cuda Kernel,统一完成模型类方法的前后处理,相比传统的循环、切片方法,性能高效且易维护 - **高效 DraftModel/MTP 框架**:开发多个融合 Cuda Kernel,统一完成模型类方法的前后处理,相比传统的循环、切片方法,性能高效且易维护
## 🔧 参数说明 ## 🔧 参数说明
- `method`: 解码策略,可选值为 `"mtp"` `"ngram"` - `method`: 解码策略,可选值为 `"mtp"` `"ngram"` `"suffix"`
- `num_speculative_tokens`: 每轮预测的 Token 数,最大支持 5(当前 MTP 仅支持 1) - `num_speculative_tokens`: 每轮预测的 Token 数,最大支持 5(当前 MTP 仅支持 1)
- `model`: 若选择 MTP,则需指定 MTP 模型路径 - `model`: 若选择 MTP,则需指定 MTP 模型路径
- `quantization`: 模型量化方式,推荐使用 `wint8` - `quantization`: 模型量化方式,推荐使用 `wint8`
@@ -133,3 +135,35 @@ python -m fastdeploy.entrypoints.openai.api_server \
--config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \ --config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \
--speculative-config '{"method": "ngram", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}' --speculative-config '{"method": "ngram", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}'
``` ```
## 🌲 使用后缀解码 (Suffix Decoding)
后缀解码是一种无模型推理框架,通过在 CPU 上使用高效后缀树进行快速草稿 Token 预测,加速重复性推理任务(如代理工作流程、编码等),消除 GPU 开销。
使用 4×H100;量化方式选择 WINT4:
> 配置文件:benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml
```
python -m fastdeploy.entrypoints.openai.api_server \
--model ${path_to_main_model} \
--tensor-parallel-size 4 \
--config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \
--speculative-config '{"method": "suffix", "num_speculative_tokens": 4, "suffix_decoding_max_tree_depth": 64, "suffix_decoding_max_cached_requests": 10000, "suffix_decoding_max_spec_factor": 1.0, "suffix_decoding_min_token_prob": 0.1}'
```
参数描述
```
# 后缀树中缓存的token序列的最大长度
self.suffix_decoding_max_tree_depth: int = 64
# 缓存中可存储的请求数量上限
self.suffix_decoding_max_cached_requests: int = -1
# 匹配长度的系数,计算方式为 num_draft_tokens = suffix_decoding_max_spec_factor * matched_length
self.suffix_decoding_max_spec_factor: float = 1.0
# 推测token的概率阈值
self.suffix_decoding_min_token_prob: float = 0.1
```
+44 -3
View File
@@ -721,7 +721,7 @@ class SpeculativeConfig:
self, self,
args, args,
): ):
self.method_list = ["ngram_match", "mtp"] self.method_list = ["ngram_match", "mtp", "suffix"]
self.mtp_strategy_list = ["default", "with_ngram"] self.mtp_strategy_list = ["default", "with_ngram"]
# speculative method, choose in [None, "ngram_match", "mtp", "hybrid_mtp_ngram"] # speculative method, choose in [None, "ngram_match", "mtp", "hybrid_mtp_ngram"]
@@ -739,6 +739,15 @@ class SpeculativeConfig:
# ngram match # ngram match
self.max_ngram_size: int = 5 self.max_ngram_size: int = 5
self.min_ngram_size: int = 2 self.min_ngram_size: int = 2
# Suffix Decoding options (consumed by SuffixProposer).
# Maximum length of the token sequences cached in the suffix trees; SuffixProposer
# also uses it as the size of the context window it passes to speculate().
self.suffix_decoding_max_tree_depth: int = 64
# Maximum number of requests that can be stored in the suffix cache.
# NOTE(review): -1 presumably means "no limit" -- confirm against arctic_inference.
self.suffix_decoding_max_cached_requests: int = -1
# Factor applied to the matched length when sizing a draft:
# num_draft_tokens = suffix_decoding_max_spec_factor * matched_length
self.suffix_decoding_max_spec_factor: float = 1.0
# Probability threshold for speculated tokens.
self.suffix_decoding_min_token_prob: float = 0.1
# model for mtp/eagle/draft_model # model for mtp/eagle/draft_model
self.model: Optional[str] = None self.model: Optional[str] = None
# quantization of model # quantization of model
@@ -939,6 +948,8 @@ class GraphOptimizationConfig:
self.max_capture_size: int = None self.max_capture_size: int = None
""" Record maps mapped from real shape to captured size to reduce runtime overhead """ """ Record maps mapped from real shape to captured size to reduce runtime overhead """
self.real_shape_to_captured_size: dict[int, int] = None self.real_shape_to_captured_size: dict[int, int] = None
""" Record maps mapped from real batch size to captured size"""
self.real_bsz_to_captured_size: dict[int, int] = {}
""" Whether to use shared memory pool for multi capture_size """ """ Whether to use shared memory pool for multi capture_size """
self.use_unique_memory_pool: bool = True self.use_unique_memory_pool: bool = True
""" Whether to use cudagraph for draft model.""" """ Whether to use cudagraph for draft model."""
@@ -1694,15 +1705,20 @@ class FDConfig:
# Initialize cuda graph capture list # Initialize cuda graph capture list
max_capture_shape = self.scheduler_config.max_num_seqs max_capture_shape = self.scheduler_config.max_num_seqs
if self.speculative_config is not None and self.speculative_config.method == "mtp": if self.speculative_config is not None and self.speculative_config.method in ["mtp", "suffix"]:
max_capture_shape = self.scheduler_config.max_num_seqs * ( max_capture_shape = self.scheduler_config.max_num_seqs * (
self.speculative_config.num_speculative_tokens + 1 self.speculative_config.num_speculative_tokens + 1
) )
assert max_capture_shape % 2 == 0, "CUDAGraph only supports capturing even token nums in MTP scenarios." assert max_capture_shape % 2 == 0, "CUDAGraph only supports capturing even token nums in MTP scenarios."
self.graph_opt_config.real_bsz_to_captured_size = {
k: 0 for k in range(1, self.scheduler_config.max_num_seqs + 1)
}
if self.graph_opt_config.cudagraph_only_prefill: if self.graph_opt_config.cudagraph_only_prefill:
max_capture_shape = 512 max_capture_shape = 512
else: else:
max_capture_shape = min(512, max_capture_shape) max_capture_shape = (
max_capture_shape if self.speculative_config is not None else min(512, max_capture_shape)
)
max_capture_shape_prefill = graph_opt_config.max_capture_shape_prefill max_capture_shape_prefill = graph_opt_config.max_capture_shape_prefill
@@ -1717,6 +1733,31 @@ class FDConfig:
max_capture_shape_prefill=max_capture_shape_prefill, max_capture_shape_prefill=max_capture_shape_prefill,
dec_token_per_query_per_step=dec_token_per_query_per_step, dec_token_per_query_per_step=dec_token_per_query_per_step,
) )
if self.speculative_config is not None and self.speculative_config.method in ["mtp", "suffix"]:
real_bsz_to_captured_size = {}
for capture_size in self.graph_opt_config.cudagraph_capture_sizes:
dummy_batch_size = int(capture_size / (self.speculative_config.num_speculative_tokens + 1))
real_bsz_to_captured_size[dummy_batch_size] = capture_size
def expand_bsz_map(real_bsz_to_captured_size):
    """
    Densify a sparse batch-size -> captured-size table.

    Every batch size between two consecutive keys of the sparse table
    (exclusive of the lower key, inclusive of the upper key) is assigned the
    captured size of the upper key. Batch sizes start at 1.

    Args:
        real_bsz_to_captured_size (dict): Sparse batch size to capture size mapping.

    Returns:
        dict: Dense batch size to capture size mapping.
    """
    dense = {}
    lower = 0
    # Walk the sparse keys in ascending order and fill the gap below each one.
    for upper, captured in sorted(real_bsz_to_captured_size.items()):
        dense.update(dict.fromkeys(range(lower + 1, upper + 1), captured))
        lower = upper
    return dense
self.graph_opt_config.real_bsz_to_captured_size = expand_bsz_map(real_bsz_to_captured_size)
self.graph_opt_config.init_with_cudagrpah_size( self.graph_opt_config.init_with_cudagrpah_size(
max_capture_size=max_capture_shape, max_capture_size=max_capture_shape,
max_capture_shape_prefill=max_capture_shape_prefill, max_capture_shape_prefill=max_capture_shape_prefill,
@@ -119,6 +119,10 @@ class CudaGraphPiecewiseBackend:
if self.fd_config.graph_opt_config.graph_opt_level > 0: if self.fd_config.graph_opt_config.graph_opt_level > 0:
self.cuda_graph_manager = Dy2StCudaGraphManager() self.cuda_graph_manager = Dy2StCudaGraphManager()
self.speculative_decoding = fd_config.speculative_config.method is not None
self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
self.real_bsz_to_captured_size = fd_config.graph_opt_config.real_bsz_to_captured_size
def run_static_model(self, entry: ConcreteSizeEntry, **kwargs): def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
if not entry.captured: if not entry.captured:
@@ -153,7 +157,10 @@ class CudaGraphPiecewiseBackend:
# Get real shape (total num tokens) # Get real shape (total num tokens)
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
real_shape = ids_remove_padding.shape[0] real_shape = ids_remove_padding.shape[0]
if self.speculative_decoding and all(self.real_bsz_to_captured_size.values()):
seq_lens_this_time: paddle.Tensor = kwargs["forward_meta"].seq_lens_this_time
num_running_requests = seq_lens_this_time.squeeze(axis=-1).nonzero(as_tuple=False)[-1].item() + 1
real_shape = self.real_bsz_to_captured_size[num_running_requests]
exist_prefill = kwargs["forward_meta"].exist_prefill exist_prefill = kwargs["forward_meta"].exist_prefill
# Static split graph mode: use Static + CUDAGraph for prefill/mixed phase # Static split graph mode: use Static + CUDAGraph for prefill/mixed phase
static_cudagraph_for_prefill = exist_prefill and not self.full_cuda_graph and self.dy2st static_cudagraph_for_prefill = exist_prefill and not self.full_cuda_graph and self.dy2st
@@ -188,6 +195,9 @@ class CudaGraphPiecewiseBackend:
# Capture a new cuda graph # Capture a new cuda graph
if entry.cuda_graph is None: if entry.cuda_graph is None:
assert (
real_shape == padding_real_shape
), f"real_shape:{real_shape} is not equal to padding_real_shape:{padding_real_shape} when capture new graph."
# Warmup the model # Warmup the model
for n in range(entry.num_finished_warmup, self.warm_up_size): for n in range(entry.num_finished_warmup, self.warm_up_size):
entry.num_finished_warmup += 1 entry.num_finished_warmup += 1
+2
View File
@@ -787,6 +787,8 @@ class TokenProcessor:
+ i * MAX_DRAFT_TOKENS + i * MAX_DRAFT_TOKENS
+ accept_num[i] + accept_num[i]
].tolist() ].tolist()
if accept_num[i] == 0:
continue
else: else:
token_id = int(tokens[i, 0]) token_id = int(tokens[i, 0])
token_ids = [token_id] token_ids = [token_id]
+14
View File
@@ -23,3 +23,17 @@ from .mtp import MTPProposer
if not current_platform.is_xpu(): if not current_platform.is_xpu():
from .ngram import NgramProposer from .ngram import NgramProposer
__all__ = ["Proposer", "MTPProposer", "NgramProposer"] __all__ = ["Proposer", "MTPProposer", "NgramProposer"]
# SuffixProposer is optional: it is exported only when its module imports cleanly
# (the suffix proposer depends on arctic_inference).
try:
    from .suffix import SuffixProposer
except ImportError:
    SuffixProposer = None
    _suffix_proposer_available = False
else:
    _suffix_proposer_available = True

__all__ = ["Proposer", "MTPProposer", "NgramProposer"]
if _suffix_proposer_available:
    __all__ = ["Proposer", "MTPProposer", "NgramProposer", "SuffixProposer"]
+24
View File
@@ -96,3 +96,27 @@ class Proposer(ABC):
bool: True if chunk prefill is enabled; False otherwise. bool: True if chunk prefill is enabled; False otherwise.
""" """
return False return False
def prepare_dummy_speculative_drafts(
    self,
    share_inputs,
    batch_size: int,
) -> None:
    """
    Construct a set of dummy draft tokens for CUDAGraph capture scenarios,
    used only to stabilize shape/step count, with no requirement for semantic correctness.

    The stop flag of the first slot decides the outcome for the whole dummy
    batch: when it is not set, the first ``batch_size`` rows get
    ``max_draft_token_num`` filler draft tokens and a step length of
    ``max_draft_token_num + 1``; otherwise their step length is set to 0 and
    the draft tokens are left untouched.

    Args:
        share_inputs: share_inputs dict maintained by GPUModelRunner; reads
            "stop_flags" and writes "draft_tokens" / "seq_lens_this_time" in place.
        batch_size: current batch_size for dummy_run
    """
    # Arbitrary filler token id; the value is irrelevant, only the shapes matter.
    dummy_token_id = 5
    max_fake_drafts = self.max_draft_token_num

    stop = share_inputs["stop_flags"][0].item()
    if not stop:
        share_inputs["draft_tokens"][:batch_size, :max_fake_drafts] = dummy_token_id
        share_inputs["seq_lens_this_time"][:batch_size] = max_fake_drafts + 1
    else:
        share_inputs["seq_lens_this_time"][:batch_size] = 0
+230
View File
@@ -0,0 +1,230 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import numpy as np
from fastdeploy.config import FDConfig
from fastdeploy.utils import spec_logger
from .base import Proposer
try:
from arctic_inference.suffix_decoding import SuffixDecodingCache
except ImportError:
SuffixDecodingCache = None
class SuffixProposer(Proposer):
    """
    Proposer for Suffix Decoding method.

    Uses SuffixDecodingCache to generate draft tokens based on suffix tree matching.
    The proposer keeps a host-side copy of every slot's token history
    (``context_tokens``) and a two-way mapping between batch slots and request
    ids so that accepted tokens can be fed back into the cache each step.
    """

    def __init__(self, fd_config: FDConfig):
        """
        Build the suffix cache and the per-slot bookkeeping.

        Args:
            fd_config: Global FastDeploy configuration, forwarded to ``Proposer``.

        Raises:
            ImportError: If ``arctic_inference.suffix_decoding`` could not be imported.
        """
        super().__init__(fd_config)

        if SuffixDecodingCache is None:
            raise ImportError(
                "arctic_inference.suffix_decoding is not available. Please install arctic-inference package."
            )

        # Initialize SuffixDecodingCache
        self.suffix_cache = SuffixDecodingCache(
            max_tree_depth=self.speculative_config.suffix_decoding_max_tree_depth,
            max_cached_requests=self.speculative_config.suffix_decoding_max_cached_requests,
        )

        self.max_tree_depth = self.speculative_config.suffix_decoding_max_tree_depth
        self.max_spec_factor = self.speculative_config.suffix_decoding_max_spec_factor
        self.min_token_prob = self.speculative_config.suffix_decoding_min_token_prob

        # Track active requests: req_id -> idx mapping (and its inverse)
        self.req_id_to_idx = {}
        self.idx_to_req_id = {}

        # Host-side token history per slot; -1 marks positions that hold no token.
        self.context_tokens = np.full(
            (self.max_num_seqs, self.max_model_len),
            -1,
            dtype=np.int32,
        )

        # A draft is cut off right before the first of these token ids (see _run_impl).
        # NOTE(review): the meaning of these ids is not visible here -- presumably
        # model-specific special tokens; confirm.
        self.ban_tokens = {101031, 101032, 101033}

    def _update_request_mapping(self, idx: int, req_id: str):
        """
        Update the mapping between request ID and batch index.

        Both directions are kept consistent: the request that previously
        occupied ``idx`` is forgotten, and if ``req_id`` was registered under a
        different slot that stale reverse entry is removed as well, so no slot
        keeps pointing at a request that has moved.

        Args:
            idx: Batch index
            req_id: Request identifier
        """
        # Clean up old mapping if exists
        previous_req_id = self.idx_to_req_id.get(idx)
        if previous_req_id is not None:
            self.req_id_to_idx.pop(previous_req_id, None)

        # Drop the reverse entry of the slot this request used to occupy, if any.
        previous_idx = self.req_id_to_idx.get(req_id)
        if previous_idx is not None and previous_idx != idx and self.idx_to_req_id.get(previous_idx) == req_id:
            del self.idx_to_req_id[previous_idx]

        # Set new mapping
        self.req_id_to_idx[req_id] = idx
        self.idx_to_req_id[idx] = req_id

    def start_request(self, idx: int, req_id: str, prompt_token_ids: list[int]):
        """
        Start a new request in the suffix cache.

        Does nothing when the request is already active in the cache.

        Args:
            idx: Batch index (slot) the request occupies
            req_id: Request identifier
            prompt_token_ids: List of prompt token IDs
        """
        if req_id in self.suffix_cache.active_requests:
            # Request already active, skip
            return

        prompt_array = np.array(prompt_token_ids, dtype=np.int32)
        if not prompt_array.flags["CONTIGUOUS"]:
            prompt_array = np.ascontiguousarray(prompt_array)

        # Reset the slot's history, then seed it with the prompt.
        self.context_tokens[idx, :] = -1
        self.context_tokens[idx, : len(prompt_token_ids)] = prompt_array

        self._update_request_mapping(idx, req_id)

        # The request is known to be inactive here (see the early return above).
        if req_id in self.suffix_cache.cached_requests:
            # Reset the suffix cache for current req_id
            self.suffix_cache.evict_cached_response(req_id)
            spec_logger.debug(f"[SuffixDecoding] Reset suffix cache for request {req_id}.")
        self.suffix_cache.start_request(req_id, prompt_array)
        spec_logger.debug(f"[SuffixDecoding] Start request {req_id}.")

    def stop_request(self, req_id: str):
        """
        Stop a request in the suffix cache.

        Args:
            req_id: Request identifier
        """
        if req_id in self.suffix_cache.active_requests:
            self.suffix_cache.stop_request(req_id)

        # Clean up mappings
        if req_id in self.req_id_to_idx:
            idx = self.req_id_to_idx[req_id]
            del self.req_id_to_idx[req_id]
            if idx in self.idx_to_req_id:
                del self.idx_to_req_id[idx]
        spec_logger.debug(f"[SuffixDecoding] Stop request {req_id}.")

    def add_active_response(self, req_id: str, token_ids: list[int]):
        """
        Add newly sampled tokens to the suffix cache for a request.

        Does nothing when the request is not active in the cache.

        Args:
            req_id: Request identifier
            token_ids: List of newly sampled token IDs
        """
        if req_id not in self.suffix_cache.active_requests:
            return

        token_array = np.array(token_ids, dtype=np.int32)
        if not token_array.flags["CONTIGUOUS"]:
            token_array = np.ascontiguousarray(token_array)

        self.suffix_cache.add_active_response(req_id, token_array)

    def _run_impl(self, share_inputs):
        """
        Propose draft tokens for every slot of the batch.

        Works on host copies of the shared buffers and writes
        "draft_tokens" and "seq_lens_this_time" back at the end. For each slot:
        stopped slots get a step length of 0; every other slot defaults to a
        step length of 1 (no drafts) and is then extended with the tokens
        speculated by the suffix cache, truncated at the first banned token.

        Args:
            share_inputs: share_inputs dict; reads "stop_flags", "is_block_step",
                "accept_tokens", "accept_num", "seq_lens_encoder",
                "seq_lens_decoder", "draft_tokens" and "seq_lens_this_time".
        """
        stop_flags_cpu = share_inputs["stop_flags"].cpu().numpy().flatten()
        is_block_step_cpu = share_inputs["is_block_step"].cpu().numpy().flatten()
        accept_tokens_cpu = share_inputs["accept_tokens"].cpu()
        accept_num_cpu = share_inputs["accept_num"].cpu().numpy().flatten()
        seq_lens_encoder = share_inputs["seq_lens_encoder"].cpu().numpy().flatten()
        seq_lens_decoder = share_inputs["seq_lens_decoder"].cpu().numpy().flatten()
        draft_tokens_cpu = share_inputs["draft_tokens"].cpu()
        seq_lens_this_time_cpu = share_inputs["seq_lens_this_time"].cpu()

        total_lens = seq_lens_encoder + seq_lens_decoder
        batch_size = seq_lens_this_time_cpu.shape[0]

        for bid in range(batch_size):
            req_id = self.idx_to_req_id.get(bid)

            # 1. Stop condition has the highest priority
            if stop_flags_cpu[bid]:
                seq_lens_this_time_cpu[bid] = 0
                draft_tokens_cpu[bid, :] = -1
                # Slots flagged as block-step keep their request alive in the cache.
                if not is_block_step_cpu[bid]:
                    if req_id is not None and req_id in self.suffix_cache.active_requests:
                        self.stop_request(req_id)
                continue
            else:
                # Default for a running slot: one token this step, no drafts.
                seq_lens_this_time_cpu[bid] = 1
                draft_tokens_cpu[bid, 1:] = -1

            # 2. Skip some cases
            num_tokens = total_lens[bid]
            max_spec_tokens = min(
                self.max_draft_token_num,
                self.max_model_len - num_tokens - 1,
            )
            if max_spec_tokens <= 1:
                continue
            if req_id is None:
                continue

            # 3. Add accept tokens to context
            acc_num = int(accept_num_cpu[bid])
            assert (
                acc_num > 0
            ), f"Request {req_id} (bid {bid}) must have at least one accepted token, but got {acc_num}."
            if acc_num > 0:
                token_ids = accept_tokens_cpu[bid, :acc_num]
                # The accepted tokens are the last acc_num positions counted by seq_lens_decoder.
                ctx_start = seq_lens_decoder[bid] - acc_num
                self.context_tokens[bid, ctx_start : ctx_start + acc_num] = token_ids
                self.add_active_response(req_id, token_ids)

            # 4. Get context
            start = max(0, num_tokens - self.max_tree_depth)
            ctx = self.context_tokens[bid, start:num_tokens]
            ctx = ctx[ctx >= 0]  # drop unfilled (-1) positions
            if ctx.size == 0:
                continue
            if not ctx.flags["CONTIGUOUS"]:
                ctx = np.ascontiguousarray(ctx, dtype=np.int32)
            else:
                ctx = ctx.astype(np.int32, copy=False)

            # 5. Speculate
            draft = self.suffix_cache.speculate(
                req_id,
                ctx,
                max_spec_tokens=max_spec_tokens,
                max_spec_factor=self.max_spec_factor,
                min_token_prob=self.min_token_prob,
            )
            token_ids = draft.token_ids

            # Keep only the prefix of the draft that precedes the first banned token.
            counter = 0
            for token in token_ids:
                if token in self.ban_tokens:
                    break
                else:
                    counter += 1

            if counter > 0:
                draft_tokens_cpu[bid, 1 : 1 + counter] = np.array(token_ids[:counter])
                draft_tokens_cpu[bid, 1 + counter :] = -1
                seq_lens_this_time_cpu[bid] = 1 + counter

        share_inputs["draft_tokens"][:] = draft_tokens_cpu.cuda()
        share_inputs["seq_lens_this_time"][:] = seq_lens_this_time_cpu.cuda()
+29 -12
View File
@@ -91,7 +91,7 @@ from fastdeploy.model_executor.pre_and_post_process import (
) )
if not (current_platform.is_dcu() or current_platform.is_iluvatar()): if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
from fastdeploy.spec_decode import MTPProposer, NgramProposer from fastdeploy.spec_decode import MTPProposer, NgramProposer, SuffixProposer
import zmq import zmq
@@ -390,6 +390,8 @@ class GPUModelRunner(ModelRunnerBase):
self.device_id, self.device_id,
self.share_inputs, self.share_inputs,
) )
elif self.speculative_method == "suffix":
self.proposer = SuffixProposer(self.fd_config)
else: else:
self.proposer = None self.proposer = None
@@ -810,6 +812,13 @@ class GPUModelRunner(ModelRunnerBase):
self.forward_batch_reqs_list[idx] = request self.forward_batch_reqs_list[idx] = request
has_prefill_task = True has_prefill_task = True
if self.speculative_decoding and self.speculative_method == "suffix" and self.proposer is not None:
if isinstance(request.prompt_token_ids, np.ndarray):
prompt_token_ids = request.prompt_token_ids.tolist()
else:
prompt_token_ids = request.prompt_token_ids
self.proposer.start_request(idx, request.request_id, prompt_token_ids)
# Routing Replay # Routing Replay
if self.fd_config.routing_replay_config.enable_routing_replay: if self.fd_config.routing_replay_config.enable_routing_replay:
if prefill_start_index == 0: if prefill_start_index == 0:
@@ -1063,6 +1072,15 @@ class GPUModelRunner(ModelRunnerBase):
else: else:
return default_value return default_value
# Start suffix decoding request if using suffix proposer
if self.speculative_decoding and self.speculative_method == "suffix" and self.proposer is not None:
    prompt_token_ids = request.prompt_token_ids
    if isinstance(prompt_token_ids, np.ndarray):
        prompt_token_ids = prompt_token_ids.tolist()
    # start_request takes the batch slot first and records the slot <-> request-id
    # mapping itself, so no separate mapping call is needed.
    self.proposer.start_request(idx, request.request_id, prompt_token_ids)
assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1) self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7) self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
@@ -1768,6 +1786,7 @@ class GPUModelRunner(ModelRunnerBase):
self, self,
hidden_states: paddle.Tensor, hidden_states: paddle.Tensor,
model_output: paddle.Tensor, model_output: paddle.Tensor,
batch_size: int,
accept_all_drafts=False, accept_all_drafts=False,
reject_all_drafts=False, reject_all_drafts=False,
) -> paddle.Tensor: ) -> paddle.Tensor:
@@ -1876,8 +1895,7 @@ class GPUModelRunner(ModelRunnerBase):
is_dummy_run=True, is_dummy_run=True,
) )
else: else:
self.proposer.run(share_inputs=self.share_inputs) self.proposer.prepare_dummy_speculative_drafts(share_inputs=self.share_inputs, batch_size=batch_size)
return sampler_output return sampler_output
def _dummy_run( def _dummy_run(
@@ -1954,7 +1972,7 @@ class GPUModelRunner(ModelRunnerBase):
(self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None), (self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None),
(self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None), (self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None),
) )
self._dummy_sampler_run(hidden_states, model_output, accept_all_drafts, reject_all_drafts) self._dummy_sampler_run(hidden_states, model_output, batch_size, accept_all_drafts, reject_all_drafts)
# 7. Updata 'infer_seed' and step_cuda() # 7. Updata 'infer_seed' and step_cuda()
self.share_inputs["infer_seed"].add_(self.infer_seed_increment) self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
@@ -2064,22 +2082,19 @@ class GPUModelRunner(ModelRunnerBase):
logger.info( logger.info(
f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}" f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
) )
elif self.speculative_decoding and self.speculative_method == "mtp": elif self.speculative_decoding:
# Capture Target Model without bsz 1 # Capture Target Model without bsz 1
for capture_size in sorted(capture_sizes, reverse=True): for capture_size in sorted(capture_sizes, reverse=True):
expected_decode_len = self.speculative_config.num_speculative_tokens * 2 + 1
self._dummy_run( self._dummy_run(
num_tokens=( num_tokens=self.fd_config.get_max_chunk_tokens(),
self.scheduler_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1)
if self.scheduler_config.splitwise_role == "decode"
else self.fd_config.get_max_chunk_tokens()
),
batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)), batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)),
in_capturing=True, in_capturing=True,
expected_decode_len=self.speculative_config.num_speculative_tokens * 2 + 1, expected_decode_len=expected_decode_len,
accept_all_drafts=True, accept_all_drafts=True,
) )
logger.info( logger.info(
f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}" f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{expected_decode_len}"
) )
else: else:
for batch_size in sorted(capture_sizes, reverse=True): for batch_size in sorted(capture_sizes, reverse=True):
@@ -2536,6 +2551,8 @@ class GPUModelRunner(ModelRunnerBase):
self.proposer.run( self.proposer.run(
full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
) )
elif self.speculative_method == "suffix":
self.proposer.run(share_inputs=self.share_inputs)
else: else:
self.proposer.run(share_inputs=self.share_inputs) self.proposer.run(share_inputs=self.share_inputs)
+2 -2
View File
@@ -267,7 +267,7 @@ class InputBatch:
self.accept_num = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32") self.accept_num = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32")
self.draft_tokens = paddle.full( self.draft_tokens = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1], shape=[max_num_seqs, max_draft_token_num + 1],
fill_value=-1, fill_value=0,
dtype="int64", dtype="int64",
) )
@@ -293,7 +293,7 @@ class InputBatch:
# For V1_KVCACHE_SCHEDULER # For V1_KVCACHE_SCHEDULER
self.step_draft_tokens = paddle.full( self.step_draft_tokens = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1], shape=[max_num_seqs, max_draft_token_num + 1],
fill_value=0, fill_value=-1,
dtype="int64", dtype="int64",
) )
self.step_seq_lens_this_time = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.step_seq_lens_this_time = paddle.full([max_num_seqs, 1], 0, dtype="int32")
+1
View File
@@ -48,3 +48,4 @@ p2pstore
py-cpuinfo py-cpuinfo
flashinfer-python-paddle flashinfer-python-paddle
flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl
arctic_inference @ https://paddle-qa.bj.bcebos.com/ernie/arctic_inference-0.1.3-cp310-cp310-linux_x86_64.whl
@@ -25,6 +25,7 @@ from fastdeploy.config import (
GraphOptimizationConfig, GraphOptimizationConfig,
ParallelConfig, ParallelConfig,
SchedulerConfig, SchedulerConfig,
SpeculativeConfig,
) )
from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import ( from fastdeploy.model_executor.graph_optimization.decorator import (
@@ -110,6 +111,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
cache_config = CacheConfig(args={}) cache_config = CacheConfig(args={})
scheduler_config.max_num_seqs = 1 scheduler_config.max_num_seqs = 1
parallel_config = ParallelConfig(args={}) parallel_config = ParallelConfig(args={})
speculative_config = SpeculativeConfig(args={})
model_config = Mock() model_config = Mock()
model_config.max_model_len = 5120 model_config.max_model_len = 5120
model_config.architectures = ["test_model"] model_config.architectures = ["test_model"]
@@ -120,6 +122,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
cache_config=cache_config, cache_config=cache_config,
model_config=model_config, model_config=model_config,
parallel_config=parallel_config, parallel_config=parallel_config,
speculative_config=speculative_config,
) )
# Run Test Case1 # Run Test Case1
@@ -25,6 +25,7 @@ from fastdeploy.config import (
GraphOptimizationConfig, GraphOptimizationConfig,
ParallelConfig, ParallelConfig,
SchedulerConfig, SchedulerConfig,
SpeculativeConfig,
) )
from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import ( from fastdeploy.model_executor.graph_optimization.decorator import (
@@ -103,6 +104,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase):
scheduler_config.max_num_seqs = 1 scheduler_config.max_num_seqs = 1
cache_config = CacheConfig({}) cache_config = CacheConfig({})
parallel_config = ParallelConfig(args={}) parallel_config = ParallelConfig(args={})
speculative_config = SpeculativeConfig(args={})
model_config = Mock() model_config = Mock()
model_config.max_model_len = 512 model_config.max_model_len = 512
model_config.architectures = ["test_model"] model_config.architectures = ["test_model"]
@@ -116,6 +118,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase):
cache_config=cache_config, cache_config=cache_config,
parallel_config=parallel_config, parallel_config=parallel_config,
model_config=model_config, model_config=model_config,
speculative_config=speculative_config,
test_mode=True, test_mode=True,
) )
@@ -27,6 +27,7 @@ from fastdeploy.config import (
GraphOptimizationConfig, GraphOptimizationConfig,
ParallelConfig, ParallelConfig,
SchedulerConfig, SchedulerConfig,
SpeculativeConfig,
) )
from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import ( from fastdeploy.model_executor.graph_optimization.decorator import (
@@ -144,6 +145,7 @@ class TestGraphOptBackend(unittest.TestCase):
# Setup cache config # Setup cache config
cache_config = CacheConfig({}) cache_config = CacheConfig({})
parallel_config = ParallelConfig(args={}) parallel_config = ParallelConfig(args={})
speculative_config = SpeculativeConfig(args={})
model_config = Mock() model_config = Mock()
model_config.max_model_len = 512 model_config.max_model_len = 512
model_config.architectures = ["test_model"] model_config.architectures = ["test_model"]
@@ -156,6 +158,7 @@ class TestGraphOptBackend(unittest.TestCase):
cache_config=cache_config, cache_config=cache_config,
parallel_config=parallel_config, parallel_config=parallel_config,
model_config=model_config, model_config=model_config,
speculative_config=speculative_config,
test_mode=True, test_mode=True,
) )
@@ -31,6 +31,7 @@ from fastdeploy.config import (
GraphOptimizationConfig, GraphOptimizationConfig,
ParallelConfig, ParallelConfig,
SchedulerConfig, SchedulerConfig,
SpeculativeConfig,
) )
from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import ( from fastdeploy.model_executor.graph_optimization.decorator import (
@@ -97,6 +98,7 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase):
graph_opt_config.init_with_cudagrpah_size(max_capture_size=scheduler_config.max_num_seqs) graph_opt_config.init_with_cudagrpah_size(max_capture_size=scheduler_config.max_num_seqs)
cache_config = CacheConfig({}) cache_config = CacheConfig({})
parallel_config = ParallelConfig(args={}) parallel_config = ParallelConfig(args={})
speculative_config = SpeculativeConfig(args={})
model_config = Mock() model_config = Mock()
model_config.max_model_len = 512 model_config.max_model_len = 512
model_config.architectures = ["test_model"] model_config.architectures = ["test_model"]
@@ -107,6 +109,7 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase):
cache_config=cache_config, cache_config=cache_config,
parallel_config=parallel_config, parallel_config=parallel_config,
model_config=model_config, model_config=model_config,
speculative_config=speculative_config,
test_mode=True, test_mode=True,
) )
+2 -3
View File
@@ -112,9 +112,8 @@ def py_update_attn_mask_offsets_op(
attn_mask_offsets_decoder[bid] = int(attn_mask_offsets_decoder[bid]) + seq_len_this attn_mask_offsets_decoder[bid] = int(attn_mask_offsets_decoder[bid]) + seq_len_this
# speculative decoding: if seq_len_this > 1 then set decode_states_now[i] accordingly # speculative decoding: if seq_len_this > 1 then set decode_states_now[i] accordingly
if seq_len_this > 1: for i in range(decode_states_len):
for i in range(decode_states_len): decode_now[i] = 0 if i < seq_len_this and decode_now[i] != 1 else -1
decode_now[i] = 0 if i < seq_len_this else -1
# done decoder branch # done decoder branch
continue continue
+141
View File
@@ -0,0 +1,141 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
import numpy as np
import paddle
from utils import FakeModelConfig, get_default_test_fd_config
from fastdeploy.config import SpeculativeConfig
from fastdeploy.spec_decode.suffix import SuffixProposer
class TestSuffixProposer(unittest.TestCase):
    """Unit tests for ``SuffixProposer`` request bookkeeping and draft proposal."""

    def setUp(self):
        """Build a suffix-decoding ``fd_config`` and a minimal ``share_inputs`` dict."""
        self.fd_config = get_default_test_fd_config()
        self.fd_config.model_config = FakeModelConfig()
        self.fd_config.model_config.max_model_len = 2048

        self.fd_config.speculative_config = SpeculativeConfig({})
        spec_config = self.fd_config.speculative_config
        spec_config.method = "suffix"
        spec_config.num_speculative_tokens = 4
        spec_config.suffix_decoding_max_tree_depth = 64
        spec_config.suffix_decoding_max_cached_requests = 4
        spec_config.suffix_decoding_max_spec_factor = 1.0
        spec_config.suffix_decoding_min_token_prob = 0.1

        self.fd_config.scheduler_config.max_num_seqs = 4

        bsz = self.fd_config.scheduler_config.max_num_seqs
        max_draft_tokens = spec_config.num_speculative_tokens

        # Only the tensors these tests populate or check are created here.
        self.share_inputs = {
            "stop_flags": paddle.full([bsz, 1], fill_value=False, dtype="bool"),
            "is_block_step": paddle.full([bsz], fill_value=False, dtype="bool"),
            "accept_tokens": paddle.zeros([bsz, max_draft_tokens], dtype="int64"),
            "accept_num": paddle.zeros([bsz], dtype="int32"),
            "seq_lens_this_time": paddle.zeros([bsz, 1], dtype="int32"),
            "seq_lens_encoder": paddle.zeros([bsz, 1], dtype="int32"),
            "seq_lens_decoder": paddle.zeros([bsz, 1], dtype="int32"),
            "draft_tokens": paddle.zeros([bsz, max_draft_tokens], dtype="int64"),
        }

    def _assert_request_state(self, proposer, context_tokens, req_id_to_idx, idx_to_req_id):
        """Check the proposer's context buffer and both request-id lookup tables."""
        np.testing.assert_array_equal(proposer.context_tokens, context_tokens)
        # The lookup tables are compared against plain dicts, so use assertEqual
        # (readable key-by-key diff on failure) rather than an array comparison.
        self.assertEqual(proposer.req_id_to_idx, req_id_to_idx)
        self.assertEqual(proposer.idx_to_req_id, idx_to_req_id)

    def test_start_and_stop_request(self):
        """start_request/stop_request keep context tokens and id maps in sync."""
        proposer = SuffixProposer(self.fd_config)

        # Reference state, updated in lock-step with the calls made on the proposer.
        refs_context_tokens = np.full(
            (self.fd_config.scheduler_config.max_num_seqs, self.fd_config.model_config.max_model_len),
            -1,
            dtype=np.int32,
        )
        refs_req_id_to_idx = {}
        refs_idx_to_req_id = {}

        # First request goes into slot 0.
        idx = 0
        req_id = "req-001"
        prompt_token_ids = [1, 2, 3, 4]
        proposer.start_request(idx, req_id, prompt_token_ids)

        refs_context_tokens[idx, : len(prompt_token_ids)] = prompt_token_ids
        refs_req_id_to_idx[req_id] = idx
        refs_idx_to_req_id[idx] = req_id

        self.assertIsNotNone(proposer.suffix_cache)
        self._assert_request_state(proposer, refs_context_tokens, refs_req_id_to_idx, refs_idx_to_req_id)

        # Second request goes into slot 1.
        idx = 1
        req_id = "req-002"
        prompt_token_ids = [5, 6, 7, 8]
        proposer.start_request(idx, req_id, prompt_token_ids)

        refs_context_tokens[idx, : len(prompt_token_ids)] = prompt_token_ids
        refs_req_id_to_idx[req_id] = idx
        refs_idx_to_req_id[idx] = req_id

        self._assert_request_state(proposer, refs_context_tokens, refs_req_id_to_idx, refs_idx_to_req_id)

        # Stopping the first request must drop it from both maps; the reference
        # context buffer is deliberately left untouched, i.e. the test expects
        # stop_request not to clear the stored tokens.
        proposer.stop_request("req-001")
        refs_req_id_to_idx.pop("req-001")
        refs_idx_to_req_id.pop(0)

        self._assert_request_state(proposer, refs_context_tokens, refs_req_id_to_idx, refs_idx_to_req_id)

    def test_propose(self):
        """run() rewrites draft_tokens / seq_lens_this_time for four active requests."""
        # Every row: two accepted tokens (value 42), decoder length 100.
        self.share_inputs["accept_tokens"][:, :2] = 42
        self.share_inputs["accept_num"][:] = 2
        self.share_inputs["seq_lens_this_time"][:, :] = 2
        self.share_inputs["seq_lens_encoder"][:, :] = 0
        self.share_inputs["seq_lens_decoder"][:, :] = 100
        # Trailing draft slots hold a distinct value (53) so the assertion below
        # can tell whether run() overwrote them.
        self.share_inputs["draft_tokens"][:, :2] = 42
        self.share_inputs["draft_tokens"][:, 2:] = 53

        proposer = SuffixProposer(self.fd_config)

        req_ids = ["req-001", "req-002", "req-003", "req-004"]
        prompt_token_ids_list = [
            [1, 2, 3, 4],
            [5, 6, 7, 8],
            [9, 10, 11, 12],
            [13, 14, 15, 16],
        ]
        for idx, (req_id, prompt_token_ids) in enumerate(zip(req_ids, prompt_token_ids_list)):
            proposer.start_request(idx, req_id, prompt_token_ids)

        proposer.run(self.share_inputs)

        # Expected: first two slots keep 42, the remaining slots become -1, and
        # seq_lens_this_time stays at 2.
        # NOTE(review): presumably no suffix match exists for these fresh
        # requests, so nothing is proposed -- confirm against SuffixProposer.run.
        refs_draft_tokens = np.array([[42, 42, -1, -1]] * len(req_ids), dtype=np.int64)
        refs_seq_lens_this_time = np.array([[2]] * len(req_ids), dtype=np.int32)

        np.testing.assert_array_equal(self.share_inputs["draft_tokens"].numpy(), refs_draft_tokens)
        np.testing.assert_array_equal(self.share_inputs["seq_lens_this_time"].numpy(), refs_seq_lens_this_time)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()