[Speculative Decoding] Support suffix decoding (#6403)

* support suffix decoding
This commit is contained in:
GoldPancake
2026-02-26 11:42:05 +08:00
committed by GitHub
parent 6d3fede240
commit 2178f2829b
18 changed files with 587 additions and 30 deletions
@@ -66,13 +66,11 @@ __global__ void update_attn_mask_offsets_kernel(
attn_mask_offsets_decoder[bid] += seq_len_this_time;
// Speculative decoding in text_generation
if (seq_len_this_time > 1) {
for (int i = 0; i < decode_states_len; i++) {
if (i < seq_len_this_time) {
decode_states_now[i] = 0;
} else {
decode_states_now[i] = -1;
}
for (int i = 0; i < decode_states_len; i++) {
if (i < seq_len_this_time && decode_states_now[i] != 1) {
decode_states_now[i] = 0;
} else {
decode_states_now[i] = -1;
}
}
}
+35 -1
View File
@@ -12,6 +12,8 @@ This project implements an efficient **Speculative Decoding** inference framewor
- **Ngram**
- **Suffix Decoding**
- **MTP (Multi-Token Prediction)**
- ✅ Supported: TP Sharding
- ✅ Supported: Shared Prefix
@@ -52,7 +54,7 @@ This project implements an efficient **Speculative Decoding** inference framewor
## 🔧 Configuration Parameters
- `method`: The speculative decoding strategy, currently supports `["mtp", "ngram"]`.
- `method`: The speculative decoding strategy, currently supports `["mtp", "ngram", "suffix"]`.
- `num_speculative_tokens`: Number of speculative tokens to generate; max is 5, currently MTP supports only 1.
- `model`: Path to the MTP draft model when using the `"mtp"` method.
- `quantization`: Quantization method of the MTP model (e.g., WINT4).
@@ -162,3 +164,35 @@ python -m fastdeploy.entrypoints.openai.api_server \
--speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}'
```
## 🌲 Using Suffix Decoding
Suffix Decoding is a model-free speculative decoding method that accelerates repetitive inference tasks (e.g., agent workflows, coding) using efficient CPU-based suffix trees for rapid draft token prediction, eliminating GPU overhead.
Run on 4 × H100 GPUs with WINT4 quantization:
> Config file: benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml
```
python -m fastdeploy.entrypoints.openai.api_server \
--model ${path_to_main_model} \
--tensor-parallel-size 4 \
--config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \
--speculative-config '{"method": "mtp", "num_speculative_tokens": 4, "suffix_decoding_max_tree_depth": 64, "suffix_decoding_max_cached_requests": 10000, "suffix_decoding_max_spec_factor": 1.0, "suffix_decoding_min_token_prob": 0.1}'
```
Parameter Descriptions
```
# The maximum length of token sequences cached in suffix trees.
self.suffix_decoding_max_tree_depth: int = 64
# The limits of requests that can be stored in the cache.
self.suffix_decoding_max_cached_requests: int = -1
# The factor of matched length, calculated as num_draft_tokens = suffix_max_spec_factor * matched_length
self.suffix_decoding_max_spec_factor: float = 1.0
# The probability threshold for speculated tokens.
self.suffix_decoding_min_token_prob: float = 0.1
```
+35 -1
View File
@@ -8,6 +8,8 @@
- **Ngram**
- **后缀解码**
- **MTP (Multi-Token Prediction)**
- ✅ 已支持:TP 切分
- ✅ 已支持:共享前缀
@@ -36,7 +38,7 @@
- **高效 DraftModel/MTP 框架**:开发多个融合 Cuda Kernel,统一完成模型类方法的前后处理,相比传统的循环、切片方法,性能高效且易维护
## 🔧 参数说明
- `method`: 解码策略,可选值为 `"mtp"` `"ngram"`
- `method`: 解码策略,可选值为 `"mtp"` `"ngram"``"suffix"`
- `num_speculative_tokens`: 每轮预测的 Token 数,最大支持 5(当前 MTP 仅支持 1)
- `model`: 若选择 MTP,则需指定 MTP 模型路径
- `quantization`: 模型量化方式,推荐使用 `wint8`
@@ -133,3 +135,35 @@ python -m fastdeploy.entrypoints.openai.api_server \
--config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \
--speculative-config '{"method": "ngram", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}'
```
## 🌲 使用后缀解码 (Suffix Decoding)
后缀解码是一种无模型推理框架,通过在 CPU 上使用高效后缀树进行快速草稿 Token 预测,加速重复性推理任务(如代理工作流程、编码等),消除 GPU 开销。
使用 4×H100;量化方式选择 WINT4:
> 配置文件:benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml
```
python -m fastdeploy.entrypoints.openai.api_server \
--model ${path_to_main_model} \
--tensor-parallel-size 4 \
--config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \
--speculative-config '{"method": "mtp", "num_speculative_tokens": 4, "suffix_decoding_max_tree_depth": 64, "suffix_decoding_max_cached_requests": 10000, "suffix_decoding_max_spec_factor": 1.0, "suffix_decoding_min_token_prob": 0.1}'
```
参数描述
```
# 后缀树中缓存的token序列的最大长度
self.suffix_decoding_max_tree_depth: int = 64
# 缓存中可存储的请求数量上限
self.suffix_decoding_max_cached_requests: int = -1
# 匹配长度的系数,计算方式为 num_draft_tokens = suffix_max_spec_factor * matched_length
self.suffix_decoding_max_spec_factor: float = 1.0
# 推测token的概率阈值
self.suffix_decoding_min_token_prob: float = 0.1
```
+44 -3
View File
@@ -721,7 +721,7 @@ class SpeculativeConfig:
self,
args,
):
self.method_list = ["ngram_match", "mtp"]
self.method_list = ["ngram_match", "mtp", "suffix"]
self.mtp_strategy_list = ["default", "with_ngram"]
# speculative method, choose in [None, "ngram_match", "mtp", "hybrid_mtp_ngram"]
@@ -739,6 +739,15 @@ class SpeculativeConfig:
# ngram match
self.max_ngram_size: int = 5
self.min_ngram_size: int = 2
# Suffix Decoding
# The maximum length of token sequences cached in suffix trees.
self.suffix_decoding_max_tree_depth: int = 64
# The limits of requests that can be stored in the cache.
self.suffix_decoding_max_cached_requests: int = -1
# The factor of matched length, calculated as num_draft_tokens = suffix_max_spec_factor * matched_length
self.suffix_decoding_max_spec_factor: float = 1.0
# The probability threshold for speculated tokens.
self.suffix_decoding_min_token_prob: float = 0.1
# model for mtp/eagle/draft_model
self.model: Optional[str] = None
# quantization of model
@@ -939,6 +948,8 @@ class GraphOptimizationConfig:
self.max_capture_size: int = None
""" Record maps mapped from real shape to captured size to reduce runtime overhead """
self.real_shape_to_captured_size: dict[int, int] = None
""" Record maps mapped from real batch size to captured size"""
self.real_bsz_to_captured_size: dict[int, int] = {}
""" Whether to use shared memory pool for multi capture_size """
self.use_unique_memory_pool: bool = True
""" Whether to use cudagraph for draft model."""
@@ -1694,15 +1705,20 @@ class FDConfig:
# Initialize cuda graph capture list
max_capture_shape = self.scheduler_config.max_num_seqs
if self.speculative_config is not None and self.speculative_config.method == "mtp":
if self.speculative_config is not None and self.speculative_config.method in ["mtp", "suffix"]:
max_capture_shape = self.scheduler_config.max_num_seqs * (
self.speculative_config.num_speculative_tokens + 1
)
assert max_capture_shape % 2 == 0, "CUDAGraph only supports capturing even token nums in MTP scenarios."
self.graph_opt_config.real_bsz_to_captured_size = {
k: 0 for k in range(1, self.scheduler_config.max_num_seqs + 1)
}
if self.graph_opt_config.cudagraph_only_prefill:
max_capture_shape = 512
else:
max_capture_shape = min(512, max_capture_shape)
max_capture_shape = (
max_capture_shape if self.speculative_config is not None else min(512, max_capture_shape)
)
max_capture_shape_prefill = graph_opt_config.max_capture_shape_prefill
@@ -1717,6 +1733,31 @@ class FDConfig:
max_capture_shape_prefill=max_capture_shape_prefill,
dec_token_per_query_per_step=dec_token_per_query_per_step,
)
if self.speculative_config is not None and self.speculative_config.method in ["mtp", "suffix"]:
real_bsz_to_captured_size = {}
for capture_size in self.graph_opt_config.cudagraph_capture_sizes:
dummy_batch_size = int(capture_size / (self.speculative_config.num_speculative_tokens + 1))
real_bsz_to_captured_size[dummy_batch_size] = capture_size
def expand_bsz_map(real_bsz_to_captured_size):
"""
Expand a sparse batch size mapping into a dense one.
Args:
real_bsz_to_captured_size (dict): Sparse batch size to capture size mapping.
Returns:
dict: Dense batch size to capture size mapping.
"""
sorted_items = sorted(real_bsz_to_captured_size.items())
result = {}
prev_bsz = 0
for curr_bsz, cap in sorted_items:
for bsz in range(prev_bsz + 1, curr_bsz + 1):
result[bsz] = cap
prev_bsz = curr_bsz
return result
self.graph_opt_config.real_bsz_to_captured_size = expand_bsz_map(real_bsz_to_captured_size)
self.graph_opt_config.init_with_cudagrpah_size(
max_capture_size=max_capture_shape,
max_capture_shape_prefill=max_capture_shape_prefill,
@@ -119,6 +119,10 @@ class CudaGraphPiecewiseBackend:
if self.fd_config.graph_opt_config.graph_opt_level > 0:
self.cuda_graph_manager = Dy2StCudaGraphManager()
self.speculative_decoding = fd_config.speculative_config.method is not None
self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
self.real_bsz_to_captured_size = fd_config.graph_opt_config.real_bsz_to_captured_size
def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
if not entry.captured:
@@ -153,7 +157,10 @@ class CudaGraphPiecewiseBackend:
# Get real shape (total num tokens)
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
real_shape = ids_remove_padding.shape[0]
if self.speculative_decoding and all(self.real_bsz_to_captured_size.values()):
seq_lens_this_time: paddle.Tensor = kwargs["forward_meta"].seq_lens_this_time
num_running_requests = seq_lens_this_time.squeeze(axis=-1).nonzero(as_tuple=False)[-1].item() + 1
real_shape = self.real_bsz_to_captured_size[num_running_requests]
exist_prefill = kwargs["forward_meta"].exist_prefill
# Static split graph mode: use Static + CUDAGraph for prefill/mixed phase
static_cudagraph_for_prefill = exist_prefill and not self.full_cuda_graph and self.dy2st
@@ -188,6 +195,9 @@ class CudaGraphPiecewiseBackend:
# Capture a new cuda graph
if entry.cuda_graph is None:
assert (
real_shape == padding_real_shape
), f"real_shape:{real_shape} is not equal to padding_real_shape:{padding_real_shape} when capture new graph."
# Warmup the model
for n in range(entry.num_finished_warmup, self.warm_up_size):
entry.num_finished_warmup += 1
+2
View File
@@ -787,6 +787,8 @@ class TokenProcessor:
+ i * MAX_DRAFT_TOKENS
+ accept_num[i]
].tolist()
if accept_num[i] == 0:
continue
else:
token_id = int(tokens[i, 0])
token_ids = [token_id]
+14
View File
@@ -23,3 +23,17 @@ from .mtp import MTPProposer
if not current_platform.is_xpu():
from .ngram import NgramProposer
__all__ = ["Proposer", "MTPProposer", "NgramProposer"]
# Suffix proposer requires arctic_inference
try:
from .suffix import SuffixProposer
_suffix_proposer_available = True
except ImportError:
_suffix_proposer_available = False
SuffixProposer = None
if _suffix_proposer_available:
__all__ = ["Proposer", "MTPProposer", "NgramProposer", "SuffixProposer"]
else:
__all__ = ["Proposer", "MTPProposer", "NgramProposer"]
+24
View File
@@ -96,3 +96,27 @@ class Proposer(ABC):
bool: True if chunk prefill is enabled; False otherwise.
"""
return False
def prepare_dummy_speculative_drafts(
self,
share_inputs,
batch_size: int,
) -> None:
"""
Construct a set of dummy draft tokens for CUDAGraph capture scenarios,
used only to stabilize shape/step count, with no requirement for semantic correctness.
Args:
share_inputs: share_inputs dict maintained by GPUModelRunner
batch_size: current batch_size for dummy_run
expected_decode_len: expected number of decode steps (must match what's passed to _dummy_run)
"""
max_fake_drafts = self.max_draft_token_num
stop = share_inputs["stop_flags"][0].item()
if not stop:
share_inputs["draft_tokens"][:batch_size, :max_fake_drafts] = 5
share_inputs["seq_lens_this_time"][:batch_size] = max_fake_drafts + 1
else:
share_inputs["seq_lens_this_time"][:batch_size] = 0
+230
View File
@@ -0,0 +1,230 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import numpy as np
from fastdeploy.config import FDConfig
from fastdeploy.utils import spec_logger
from .base import Proposer
try:
from arctic_inference.suffix_decoding import SuffixDecodingCache
except ImportError:
SuffixDecodingCache = None
class SuffixProposer(Proposer):
"""
Proposer for Suffix Decoding method.
Uses SuffixDecodingCache to generate draft tokens based on suffix tree matching.
"""
def __init__(self, fd_config: FDConfig):
super().__init__(fd_config)
if SuffixDecodingCache is None:
raise ImportError(
"arctic_inference.suffix_decoding is not available. Please install arctic-inference package."
)
# Initialize SuffixDecodingCache
self.suffix_cache = SuffixDecodingCache(
max_tree_depth=self.speculative_config.suffix_decoding_max_tree_depth,
max_cached_requests=self.speculative_config.suffix_decoding_max_cached_requests,
)
self.max_tree_depth = self.speculative_config.suffix_decoding_max_tree_depth
self.max_spec_factor = self.speculative_config.suffix_decoding_max_spec_factor
self.min_token_prob = self.speculative_config.suffix_decoding_min_token_prob
# Track active requests: req_id -> idx mapping
self.req_id_to_idx = {}
self.idx_to_req_id = {}
self.context_tokens = np.full(
(self.max_num_seqs, self.max_model_len),
-1,
dtype=np.int32,
)
self.ban_tokens = set([101031, 101032, 101033])
def _update_request_mapping(self, idx: int, req_id: str):
"""
Update the mapping between request ID and batch index.
Args:
req_id: Request identifier
idx: Batch index
"""
# Clean up old mapping if exists
if idx in self.idx_to_req_id:
old_req_id = self.idx_to_req_id[idx]
if old_req_id in self.req_id_to_idx:
del self.req_id_to_idx[old_req_id]
# Set new mapping
self.req_id_to_idx[req_id] = idx
self.idx_to_req_id[idx] = req_id
def start_request(self, idx: int, req_id: str, prompt_token_ids: list[int]):
"""
Start a new request in the suffix cache.
Args:
req_id: Request identifier
prompt_token_ids: List of prompt token IDs
"""
if req_id in self.suffix_cache.active_requests:
# Request already active, skip
return
prompt_array = np.array(prompt_token_ids, dtype=np.int32)
if not prompt_array.flags["CONTIGUOUS"]:
prompt_array = np.ascontiguousarray(prompt_array)
self.context_tokens[idx, :] = -1
self.context_tokens[idx, : len(prompt_token_ids)] = prompt_array
self._update_request_mapping(idx, req_id)
if req_id not in self.suffix_cache.active_requests:
if req_id in self.suffix_cache.cached_requests:
# Reset the suffix cache for current req_id
self.suffix_cache.evict_cached_response(req_id)
spec_logger.debug(f"[SuffixDecoding] Reset suffix cache for request {req_id}.")
self.suffix_cache.start_request(req_id, prompt_array)
spec_logger.debug(f"[SuffixDecoding] Start request {req_id}.")
def stop_request(self, req_id: str):
"""
Stop a request in the suffix cache.
Args:
req_id: Request identifier
"""
if req_id in self.suffix_cache.active_requests:
self.suffix_cache.stop_request(req_id)
# Clean up mappings
if req_id in self.req_id_to_idx:
idx = self.req_id_to_idx[req_id]
del self.req_id_to_idx[req_id]
if idx in self.idx_to_req_id:
del self.idx_to_req_id[idx]
spec_logger.debug(f"[SuffixDecoding] Stop request {req_id}.")
def add_active_response(self, req_id: str, token_ids: list[int]):
"""
Add newly sampled tokens to the suffix cache for a request.
Args:
req_id: Request identifier
token_ids: List of newly sampled token IDs
"""
if req_id not in self.suffix_cache.active_requests:
return
token_array = np.array(token_ids, dtype=np.int32)
if not token_array.flags["CONTIGUOUS"]:
token_array = np.ascontiguousarray(token_array)
self.suffix_cache.add_active_response(req_id, token_array)
def _run_impl(self, share_inputs):
stop_flags_cpu = share_inputs["stop_flags"].cpu().numpy().flatten()
is_block_step_cpu = share_inputs["is_block_step"].cpu().numpy().flatten()
accept_tokens_cpu = share_inputs["accept_tokens"].cpu()
accept_num_cpu = share_inputs["accept_num"].cpu().numpy().flatten()
seq_lens_encoder = share_inputs["seq_lens_encoder"].cpu().numpy().flatten()
seq_lens_decoder = share_inputs["seq_lens_decoder"].cpu().numpy().flatten()
draft_tokens_cpu = share_inputs["draft_tokens"].cpu()
seq_lens_this_time_cpu = share_inputs["seq_lens_this_time"].cpu()
total_lens = seq_lens_encoder + seq_lens_decoder
batch_size = seq_lens_this_time_cpu.shape[0]
for bid in range(batch_size):
req_id = self.idx_to_req_id.get(bid)
# 1. Stop condition has the highest priority
if stop_flags_cpu[bid]:
seq_lens_this_time_cpu[bid] = 0
draft_tokens_cpu[bid, :] = -1
if not is_block_step_cpu[bid]:
if req_id is not None and req_id in self.suffix_cache.active_requests:
self.stop_request(req_id)
continue
else:
seq_lens_this_time_cpu[bid] = 1
draft_tokens_cpu[bid, 1:] = -1
# 2. Skip some cases
num_tokens = total_lens[bid]
max_spec_tokens = min(
self.max_draft_token_num,
self.max_model_len - num_tokens - 1,
)
if max_spec_tokens <= 1:
continue
if req_id is None:
continue
# 3. Add accept tokens to context
acc_num = int(accept_num_cpu[bid])
assert (
acc_num > 0
), f"Request {req_id} (bid {bid}) must have at least one accepted token, but got {acc_num}."
if acc_num > 0:
token_ids = accept_tokens_cpu[bid, :acc_num]
ctx_start = seq_lens_decoder[bid] - acc_num
self.context_tokens[bid, ctx_start : ctx_start + acc_num] = token_ids
self.add_active_response(req_id, token_ids)
# 4. Get context
start = max(0, num_tokens - self.max_tree_depth)
ctx = self.context_tokens[bid, start:num_tokens]
ctx = ctx[ctx >= 0]
if ctx.size == 0:
continue
if not ctx.flags["CONTIGUOUS"]:
ctx = np.ascontiguousarray(ctx, dtype=np.int32)
else:
ctx = ctx.astype(np.int32, copy=False)
# 5. Speculate
draft = self.suffix_cache.speculate(
req_id,
ctx,
max_spec_tokens=max_spec_tokens,
max_spec_factor=self.max_spec_factor,
min_token_prob=self.min_token_prob,
)
token_ids = draft.token_ids
counter = 0
for token in token_ids:
if token in self.ban_tokens:
break
else:
counter += 1
if counter > 0:
draft_tokens_cpu[bid, 1 : 1 + counter] = np.array(token_ids[:counter])
draft_tokens_cpu[bid, 1 + counter :] = -1
seq_lens_this_time_cpu[bid] = 1 + counter
share_inputs["draft_tokens"][:] = draft_tokens_cpu.cuda()
share_inputs["seq_lens_this_time"][:] = seq_lens_this_time_cpu.cuda()
+29 -12
View File
@@ -91,7 +91,7 @@ from fastdeploy.model_executor.pre_and_post_process import (
)
if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
from fastdeploy.spec_decode import MTPProposer, NgramProposer
from fastdeploy.spec_decode import MTPProposer, NgramProposer, SuffixProposer
import zmq
@@ -390,6 +390,8 @@ class GPUModelRunner(ModelRunnerBase):
self.device_id,
self.share_inputs,
)
elif self.speculative_method == "suffix":
self.proposer = SuffixProposer(self.fd_config)
else:
self.proposer = None
@@ -810,6 +812,13 @@ class GPUModelRunner(ModelRunnerBase):
self.forward_batch_reqs_list[idx] = request
has_prefill_task = True
if self.speculative_decoding and self.speculative_method == "suffix" and self.proposer is not None:
if isinstance(request.prompt_token_ids, np.ndarray):
prompt_token_ids = request.prompt_token_ids.tolist()
else:
prompt_token_ids = request.prompt_token_ids
self.proposer.start_request(idx, request.request_id, prompt_token_ids)
# Routing Replay
if self.fd_config.routing_replay_config.enable_routing_replay:
if prefill_start_index == 0:
@@ -1063,6 +1072,15 @@ class GPUModelRunner(ModelRunnerBase):
else:
return default_value
# Start suffix decoding request if using suffix proposer
if self.speculative_decoding and self.speculative_method == "suffix" and self.proposer is not None:
if isinstance(request.prompt_token_ids, np.ndarray):
prompt_token_ids = request.prompt_token_ids.tolist()
else:
prompt_token_ids = request.prompt_token_ids
self.proposer.start_request(request.request_id, prompt_token_ids)
self.proposer.update_request_mapping(request.request_id, idx)
assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
@@ -1768,6 +1786,7 @@ class GPUModelRunner(ModelRunnerBase):
self,
hidden_states: paddle.Tensor,
model_output: paddle.Tensor,
batch_size: int,
accept_all_drafts=False,
reject_all_drafts=False,
) -> paddle.Tensor:
@@ -1876,8 +1895,7 @@ class GPUModelRunner(ModelRunnerBase):
is_dummy_run=True,
)
else:
self.proposer.run(share_inputs=self.share_inputs)
self.proposer.prepare_dummy_speculative_drafts(share_inputs=self.share_inputs, batch_size=batch_size)
return sampler_output
def _dummy_run(
@@ -1954,7 +1972,7 @@ class GPUModelRunner(ModelRunnerBase):
(self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None),
(self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None),
)
self._dummy_sampler_run(hidden_states, model_output, accept_all_drafts, reject_all_drafts)
self._dummy_sampler_run(hidden_states, model_output, batch_size, accept_all_drafts, reject_all_drafts)
# 7. Updata 'infer_seed' and step_cuda()
self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
@@ -2064,22 +2082,19 @@ class GPUModelRunner(ModelRunnerBase):
logger.info(
f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
)
elif self.speculative_decoding and self.speculative_method == "mtp":
elif self.speculative_decoding:
# Capture Target Model without bsz 1
for capture_size in sorted(capture_sizes, reverse=True):
expected_decode_len = self.speculative_config.num_speculative_tokens * 2 + 1
self._dummy_run(
num_tokens=(
self.scheduler_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1)
if self.scheduler_config.splitwise_role == "decode"
else self.fd_config.get_max_chunk_tokens()
),
num_tokens=self.fd_config.get_max_chunk_tokens(),
batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)),
in_capturing=True,
expected_decode_len=self.speculative_config.num_speculative_tokens * 2 + 1,
expected_decode_len=expected_decode_len,
accept_all_drafts=True,
)
logger.info(
f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{expected_decode_len}"
)
else:
for batch_size in sorted(capture_sizes, reverse=True):
@@ -2536,6 +2551,8 @@ class GPUModelRunner(ModelRunnerBase):
self.proposer.run(
full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
)
elif self.speculative_method == "suffix":
self.proposer.run(share_inputs=self.share_inputs)
else:
self.proposer.run(share_inputs=self.share_inputs)
+2 -2
View File
@@ -267,7 +267,7 @@ class InputBatch:
self.accept_num = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32")
self.draft_tokens = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1],
fill_value=-1,
fill_value=0,
dtype="int64",
)
@@ -293,7 +293,7 @@ class InputBatch:
# For V1_KVCACHE_SCHEDULER
self.step_draft_tokens = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1],
fill_value=0,
fill_value=-1,
dtype="int64",
)
self.step_seq_lens_this_time = paddle.full([max_num_seqs, 1], 0, dtype="int32")
+1
View File
@@ -48,3 +48,4 @@ p2pstore
py-cpuinfo
flashinfer-python-paddle
flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl
arctic_inference @ https://paddle-qa.bj.bcebos.com/ernie/arctic_inference-0.1.3-cp310-cp310-linux_x86_64.whl
@@ -25,6 +25,7 @@ from fastdeploy.config import (
GraphOptimizationConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
)
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
@@ -110,6 +111,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
cache_config = CacheConfig(args={})
scheduler_config.max_num_seqs = 1
parallel_config = ParallelConfig(args={})
speculative_config = SpeculativeConfig(args={})
model_config = Mock()
model_config.max_model_len = 5120
model_config.architectures = ["test_model"]
@@ -120,6 +122,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
cache_config=cache_config,
model_config=model_config,
parallel_config=parallel_config,
speculative_config=speculative_config,
)
# Run Test Case1
@@ -25,6 +25,7 @@ from fastdeploy.config import (
GraphOptimizationConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
)
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
@@ -103,6 +104,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase):
scheduler_config.max_num_seqs = 1
cache_config = CacheConfig({})
parallel_config = ParallelConfig(args={})
speculative_config = SpeculativeConfig(args={})
model_config = Mock()
model_config.max_model_len = 512
model_config.architectures = ["test_model"]
@@ -116,6 +118,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase):
cache_config=cache_config,
parallel_config=parallel_config,
model_config=model_config,
speculative_config=speculative_config,
test_mode=True,
)
@@ -27,6 +27,7 @@ from fastdeploy.config import (
GraphOptimizationConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
)
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
@@ -144,6 +145,7 @@ class TestGraphOptBackend(unittest.TestCase):
# Setup cache config
cache_config = CacheConfig({})
parallel_config = ParallelConfig(args={})
speculative_config = SpeculativeConfig(args={})
model_config = Mock()
model_config.max_model_len = 512
model_config.architectures = ["test_model"]
@@ -156,6 +158,7 @@ class TestGraphOptBackend(unittest.TestCase):
cache_config=cache_config,
parallel_config=parallel_config,
model_config=model_config,
speculative_config=speculative_config,
test_mode=True,
)
@@ -31,6 +31,7 @@ from fastdeploy.config import (
GraphOptimizationConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
)
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
@@ -97,6 +98,7 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase):
graph_opt_config.init_with_cudagrpah_size(max_capture_size=scheduler_config.max_num_seqs)
cache_config = CacheConfig({})
parallel_config = ParallelConfig(args={})
speculative_config = SpeculativeConfig(args={})
model_config = Mock()
model_config.max_model_len = 512
model_config.architectures = ["test_model"]
@@ -107,6 +109,7 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase):
cache_config=cache_config,
parallel_config=parallel_config,
model_config=model_config,
speculative_config=speculative_config,
test_mode=True,
)
+2 -3
View File
@@ -112,9 +112,8 @@ def py_update_attn_mask_offsets_op(
attn_mask_offsets_decoder[bid] = int(attn_mask_offsets_decoder[bid]) + seq_len_this
# speculative decoding: if seq_len_this > 1 then set decode_states_now[i] accordingly
if seq_len_this > 1:
for i in range(decode_states_len):
decode_now[i] = 0 if i < seq_len_this else -1
for i in range(decode_states_len):
decode_now[i] = 0 if i < seq_len_this and decode_now[i] != 1 else -1
# done decoder branch
continue
+141
View File
@@ -0,0 +1,141 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
import numpy as np
import paddle
from utils import FakeModelConfig, get_default_test_fd_config
from fastdeploy.config import SpeculativeConfig
from fastdeploy.spec_decode.suffix import SuffixProposer
class TestSuffixProposer(unittest.TestCase):
def setUp(self):
self.fd_config = get_default_test_fd_config()
self.fd_config.model_config = FakeModelConfig()
self.fd_config.model_config.max_model_len = 2048
self.fd_config.speculative_config = SpeculativeConfig({})
self.fd_config.speculative_config.method = "suffix"
self.fd_config.speculative_config.num_speculative_tokens = 4
self.fd_config.speculative_config.suffix_decoding_max_tree_depth = 64
self.fd_config.speculative_config.suffix_decoding_max_cached_requests = 4
self.fd_config.speculative_config.suffix_decoding_max_spec_factor = 1.0
self.fd_config.speculative_config.suffix_decoding_min_token_prob = 0.1
self.fd_config.scheduler_config.max_num_seqs = 4
bsz = self.fd_config.scheduler_config.max_num_seqs
max_draft_tokens = self.fd_config.speculative_config.num_speculative_tokens
self.share_inputs = {
"stop_flags": paddle.full([bsz, 1], fill_value=False, dtype="bool"),
"is_block_step": paddle.full([bsz], fill_value=False, dtype="bool"),
"accept_tokens": paddle.zeros([bsz, max_draft_tokens], dtype="int64"),
"accept_num": paddle.zeros([bsz], dtype="int32"),
"seq_lens_this_time": paddle.zeros([bsz, 1], dtype="int32"),
"seq_lens_encoder": paddle.zeros([bsz, 1], dtype="int32"),
"seq_lens_decoder": paddle.zeros([bsz, 1], dtype="int32"),
"draft_tokens": paddle.zeros([bsz, max_draft_tokens], dtype="int64"),
}
def test_start_and_stop_request(self):
proposer = SuffixProposer(self.fd_config)
idx = 0
req_id = "req-001"
prompt_token_ids = [1, 2, 3, 4]
proposer.start_request(idx, req_id, prompt_token_ids)
refs_context_tokens = np.full(
(self.fd_config.scheduler_config.max_num_seqs, self.fd_config.model_config.max_model_len),
-1,
dtype=np.int32,
)
refs_req_id_to_idx = {}
refs_idx_to_req_id = {}
refs_context_tokens[idx, : len(prompt_token_ids)] = prompt_token_ids
refs_req_id_to_idx[req_id] = idx
refs_idx_to_req_id[idx] = req_id
self.assertIsNotNone(proposer.suffix_cache)
np.testing.assert_array_equal(proposer.context_tokens, refs_context_tokens)
np.testing.assert_array_equal(proposer.req_id_to_idx, refs_req_id_to_idx)
np.testing.assert_array_equal(proposer.idx_to_req_id, refs_idx_to_req_id)
idx = 1
req_id = "req-002"
prompt_token_ids = [5, 6, 7, 8]
proposer.start_request(idx, req_id, prompt_token_ids)
refs_context_tokens[idx, : len(prompt_token_ids)] = prompt_token_ids
refs_req_id_to_idx[req_id] = idx
refs_idx_to_req_id[idx] = req_id
np.testing.assert_array_equal(proposer.context_tokens, refs_context_tokens)
np.testing.assert_array_equal(proposer.req_id_to_idx, refs_req_id_to_idx)
np.testing.assert_array_equal(proposer.idx_to_req_id, refs_idx_to_req_id)
proposer.stop_request("req-001")
refs_req_id_to_idx.pop("req-001")
refs_idx_to_req_id.pop(0)
np.testing.assert_array_equal(proposer.context_tokens, refs_context_tokens)
np.testing.assert_array_equal(proposer.req_id_to_idx, refs_req_id_to_idx)
np.testing.assert_array_equal(proposer.idx_to_req_id, refs_idx_to_req_id)
def test_propose(self):
self.share_inputs["accept_tokens"][:, :2] = 42
self.share_inputs["accept_num"][:] = 2
self.share_inputs["seq_lens_this_time"][:, :] = 2
self.share_inputs["seq_lens_encoder"][:, :] = 0
self.share_inputs["seq_lens_decoder"][:, :] = 100
self.share_inputs["draft_tokens"][:, :2] = 42
self.share_inputs["draft_tokens"][:, 2:] = 53
print(self.share_inputs)
proposer = SuffixProposer(self.fd_config)
ids = [0, 1, 2, 3]
req_ids = ["req-001", "req-002", "req-003", "req-004"]
prompt_token_ids_list = [
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
]
for idx, req_id, prompt_token_ids in zip(ids, req_ids, prompt_token_ids_list):
proposer.start_request(idx, req_id, prompt_token_ids)
proposer.run(self.share_inputs)
refs_draft_tokens = np.array(
[
[42, 42, -1, -1],
[42, 42, -1, -1],
[42, 42, -1, -1],
[42, 42, -1, -1],
],
dtype=np.int64,
)
refs_seq_lens_this_time = np.array([[2], [2], [2], [2]], dtype=np.int32)
np.testing.assert_array_equal(self.share_inputs["draft_tokens"].numpy(), refs_draft_tokens)
np.testing.assert_array_equal(self.share_inputs["seq_lens_this_time"].numpy(), refs_seq_lens_this_time)
if __name__ == "__main__":
unittest.main()