mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Support suffix decoding (#6403)
* support suffix decoding
This commit is contained in:
@@ -66,13 +66,11 @@ __global__ void update_attn_mask_offsets_kernel(
|
|||||||
attn_mask_offsets_decoder[bid] += seq_len_this_time;
|
attn_mask_offsets_decoder[bid] += seq_len_this_time;
|
||||||
|
|
||||||
// Speculative decoding in text_generation
|
// Speculative decoding in text_generation
|
||||||
if (seq_len_this_time > 1) {
|
for (int i = 0; i < decode_states_len; i++) {
|
||||||
for (int i = 0; i < decode_states_len; i++) {
|
if (i < seq_len_this_time && decode_states_now[i] != 1) {
|
||||||
if (i < seq_len_this_time) {
|
decode_states_now[i] = 0;
|
||||||
decode_states_now[i] = 0;
|
} else {
|
||||||
} else {
|
decode_states_now[i] = -1;
|
||||||
decode_states_now[i] = -1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ This project implements an efficient **Speculative Decoding** inference framewor
|
|||||||
|
|
||||||
- **Ngram**
|
- **Ngram**
|
||||||
|
|
||||||
|
- **Suffix Decoding**
|
||||||
|
|
||||||
- **MTP (Multi-Token Prediction)**
|
- **MTP (Multi-Token Prediction)**
|
||||||
- ✅ Supported: TP Sharding
|
- ✅ Supported: TP Sharding
|
||||||
- ✅ Supported: Shared Prefix
|
- ✅ Supported: Shared Prefix
|
||||||
@@ -52,7 +54,7 @@ This project implements an efficient **Speculative Decoding** inference framewor
|
|||||||
|
|
||||||
## 🔧 Configuration Parameters
|
## 🔧 Configuration Parameters
|
||||||
|
|
||||||
- `method`: The speculative decoding strategy, currently supports `["mtp", "ngram"]`.
|
- `method`: The speculative decoding strategy, currently supports `["mtp", "ngram", "suffix"]`.
|
||||||
- `num_speculative_tokens`: Number of speculative tokens to generate; max is 5, currently MTP supports only 1.
|
- `num_speculative_tokens`: Number of speculative tokens to generate; max is 5, currently MTP supports only 1.
|
||||||
- `model`: Path to the MTP draft model when using the `"mtp"` method.
|
- `model`: Path to the MTP draft model when using the `"mtp"` method.
|
||||||
- `quantization`: Quantization method of the MTP model (e.g., WINT4).
|
- `quantization`: Quantization method of the MTP model (e.g., WINT4).
|
||||||
@@ -162,3 +164,35 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
|||||||
--speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}'
|
--speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}'
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 🌲 Using Suffix Decoding
|
||||||
|
|
||||||
|
Suffix Decoding is a model-free speculative decoding method that accelerates repetitive inference tasks (e.g., agent workflows, coding) using efficient CPU-based suffix trees for rapid draft token prediction, eliminating GPU overhead.
|
||||||
|
|
||||||
|
Run on 4 × H100 GPUs with WINT4 quantization:
|
||||||
|
|
||||||
|
> Config file: benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m fastdeploy.entrypoints.openai.api_server \
|
||||||
|
--model ${path_to_main_model} \
|
||||||
|
--tensor-parallel-size 4 \
|
||||||
|
--config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \
|
||||||
|
--speculative-config '{"method": "suffix", "num_speculative_tokens": 4, "suffix_decoding_max_tree_depth": 64, "suffix_decoding_max_cached_requests": 10000, "suffix_decoding_max_spec_factor": 1.0, "suffix_decoding_min_token_prob": 0.1}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Parameter Descriptions
|
||||||
|
|
||||||
|
```
|
||||||
|
# The maximum length of token sequences cached in suffix trees.
|
||||||
|
self.suffix_decoding_max_tree_depth: int = 64
|
||||||
|
|
||||||
|
# The maximum number of requests that can be stored in the cache.
|
||||||
|
self.suffix_decoding_max_cached_requests: int = -1
|
||||||
|
|
||||||
|
# The factor applied to the matched length, calculated as num_draft_tokens = suffix_decoding_max_spec_factor * matched_length
|
||||||
|
self.suffix_decoding_max_spec_factor: float = 1.0
|
||||||
|
|
||||||
|
# The probability threshold for speculated tokens.
|
||||||
|
self.suffix_decoding_min_token_prob: float = 0.1
|
||||||
|
```
|
||||||
|
|||||||
@@ -8,6 +8,8 @@
|
|||||||
|
|
||||||
- **Ngram**
|
- **Ngram**
|
||||||
|
|
||||||
|
- **后缀解码**
|
||||||
|
|
||||||
- **MTP (Multi-Token Prediction)**
|
- **MTP (Multi-Token Prediction)**
|
||||||
- ✅ 已支持:TP 切分
|
- ✅ 已支持:TP 切分
|
||||||
- ✅ 已支持:共享前缀
|
- ✅ 已支持:共享前缀
|
||||||
@@ -36,7 +38,7 @@
|
|||||||
- **高效 DraftModel/MTP 框架**:开发多个融合 Cuda Kernel,统一完成模型类方法的前后处理,相比传统的循环、切片方法,性能高效且易维护
|
- **高效 DraftModel/MTP 框架**:开发多个融合 Cuda Kernel,统一完成模型类方法的前后处理,相比传统的循环、切片方法,性能高效且易维护
|
||||||
|
|
||||||
## 🔧 参数说明
|
## 🔧 参数说明
|
||||||
- `method`: 解码策略,可选值为 `"mtp"` 或 `"ngram"`
|
- `method`: 解码策略,可选值为 `"mtp"` 、 `"ngram"` 或 `"suffix"`
|
||||||
- `num_speculative_tokens`: 每轮预测的 Token 数,最大支持 5(当前 MTP 仅支持 1)
|
- `num_speculative_tokens`: 每轮预测的 Token 数,最大支持 5(当前 MTP 仅支持 1)
|
||||||
- `model`: 若选择 MTP,则需指定 MTP 模型路径
|
- `model`: 若选择 MTP,则需指定 MTP 模型路径
|
||||||
- `quantization`: 模型量化方式,推荐使用 `wint8`
|
- `quantization`: 模型量化方式,推荐使用 `wint8`
|
||||||
@@ -133,3 +135,35 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
|||||||
--config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \
|
--config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \
|
||||||
--speculative-config '{"method": "ngram", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}'
|
--speculative-config '{"method": "ngram", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 🌲 使用后缀解码 (Suffix Decoding)
|
||||||
|
|
||||||
|
后缀解码是一种无模型推理框架,通过在 CPU 上使用高效后缀树进行快速草稿 Token 预测,加速重复性推理任务(如代理工作流程、编码等),消除 GPU 开销。
|
||||||
|
|
||||||
|
使用 4×H100;量化方式选择 WINT4:
|
||||||
|
|
||||||
|
> 配置文件:benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m fastdeploy.entrypoints.openai.api_server \
|
||||||
|
--model ${path_to_main_model} \
|
||||||
|
--tensor-parallel-size 4 \
|
||||||
|
--config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \
|
||||||
|
--speculative-config '{"method": "suffix", "num_speculative_tokens": 4, "suffix_decoding_max_tree_depth": 64, "suffix_decoding_max_cached_requests": 10000, "suffix_decoding_max_spec_factor": 1.0, "suffix_decoding_min_token_prob": 0.1}'
|
||||||
|
```
|
||||||
|
|
||||||
|
参数描述
|
||||||
|
|
||||||
|
```
|
||||||
|
# 后缀树中缓存的token序列的最大长度
|
||||||
|
self.suffix_decoding_max_tree_depth: int = 64
|
||||||
|
|
||||||
|
# 缓存中可存储的请求数量上限
|
||||||
|
self.suffix_decoding_max_cached_requests: int = -1
|
||||||
|
|
||||||
|
# 匹配长度的系数,计算方式为 num_draft_tokens = suffix_decoding_max_spec_factor * matched_length
|
||||||
|
self.suffix_decoding_max_spec_factor: float = 1.0
|
||||||
|
|
||||||
|
# 推测token的概率阈值
|
||||||
|
self.suffix_decoding_min_token_prob: float = 0.1
|
||||||
|
```
|
||||||
|
|||||||
+44
-3
@@ -721,7 +721,7 @@ class SpeculativeConfig:
|
|||||||
self,
|
self,
|
||||||
args,
|
args,
|
||||||
):
|
):
|
||||||
self.method_list = ["ngram_match", "mtp"]
|
self.method_list = ["ngram_match", "mtp", "suffix"]
|
||||||
self.mtp_strategy_list = ["default", "with_ngram"]
|
self.mtp_strategy_list = ["default", "with_ngram"]
|
||||||
|
|
||||||
# speculative method, choose in [None, "ngram_match", "mtp", "hybrid_mtp_ngram"]
|
# speculative method, choose in [None, "ngram_match", "mtp", "hybrid_mtp_ngram"]
|
||||||
@@ -739,6 +739,15 @@ class SpeculativeConfig:
|
|||||||
# ngram match
|
# ngram match
|
||||||
self.max_ngram_size: int = 5
|
self.max_ngram_size: int = 5
|
||||||
self.min_ngram_size: int = 2
|
self.min_ngram_size: int = 2
|
||||||
|
# Suffix Decoding
|
||||||
|
# The maximum length of token sequences cached in suffix trees.
|
||||||
|
self.suffix_decoding_max_tree_depth: int = 64
|
||||||
|
# The limits of requests that can be stored in the cache.
|
||||||
|
self.suffix_decoding_max_cached_requests: int = -1
|
||||||
|
# The factor of matched length, calculated as num_draft_tokens = suffix_max_spec_factor * matched_length
|
||||||
|
self.suffix_decoding_max_spec_factor: float = 1.0
|
||||||
|
# The probability threshold for speculated tokens.
|
||||||
|
self.suffix_decoding_min_token_prob: float = 0.1
|
||||||
# model for mtp/eagle/draft_model
|
# model for mtp/eagle/draft_model
|
||||||
self.model: Optional[str] = None
|
self.model: Optional[str] = None
|
||||||
# quantization of model
|
# quantization of model
|
||||||
@@ -939,6 +948,8 @@ class GraphOptimizationConfig:
|
|||||||
self.max_capture_size: int = None
|
self.max_capture_size: int = None
|
||||||
""" Record maps mapped from real shape to captured size to reduce runtime overhead """
|
""" Record maps mapped from real shape to captured size to reduce runtime overhead """
|
||||||
self.real_shape_to_captured_size: dict[int, int] = None
|
self.real_shape_to_captured_size: dict[int, int] = None
|
||||||
|
""" Record maps mapped from real batch size to captured size"""
|
||||||
|
self.real_bsz_to_captured_size: dict[int, int] = {}
|
||||||
""" Whether to use shared memory pool for multi capture_size """
|
""" Whether to use shared memory pool for multi capture_size """
|
||||||
self.use_unique_memory_pool: bool = True
|
self.use_unique_memory_pool: bool = True
|
||||||
""" Whether to use cudagraph for draft model."""
|
""" Whether to use cudagraph for draft model."""
|
||||||
@@ -1694,15 +1705,20 @@ class FDConfig:
|
|||||||
|
|
||||||
# Initialize cuda graph capture list
|
# Initialize cuda graph capture list
|
||||||
max_capture_shape = self.scheduler_config.max_num_seqs
|
max_capture_shape = self.scheduler_config.max_num_seqs
|
||||||
if self.speculative_config is not None and self.speculative_config.method == "mtp":
|
if self.speculative_config is not None and self.speculative_config.method in ["mtp", "suffix"]:
|
||||||
max_capture_shape = self.scheduler_config.max_num_seqs * (
|
max_capture_shape = self.scheduler_config.max_num_seqs * (
|
||||||
self.speculative_config.num_speculative_tokens + 1
|
self.speculative_config.num_speculative_tokens + 1
|
||||||
)
|
)
|
||||||
assert max_capture_shape % 2 == 0, "CUDAGraph only supports capturing even token nums in MTP scenarios."
|
assert max_capture_shape % 2 == 0, "CUDAGraph only supports capturing even token nums in MTP scenarios."
|
||||||
|
self.graph_opt_config.real_bsz_to_captured_size = {
|
||||||
|
k: 0 for k in range(1, self.scheduler_config.max_num_seqs + 1)
|
||||||
|
}
|
||||||
if self.graph_opt_config.cudagraph_only_prefill:
|
if self.graph_opt_config.cudagraph_only_prefill:
|
||||||
max_capture_shape = 512
|
max_capture_shape = 512
|
||||||
else:
|
else:
|
||||||
max_capture_shape = min(512, max_capture_shape)
|
max_capture_shape = (
|
||||||
|
max_capture_shape if self.speculative_config is not None else min(512, max_capture_shape)
|
||||||
|
)
|
||||||
|
|
||||||
max_capture_shape_prefill = graph_opt_config.max_capture_shape_prefill
|
max_capture_shape_prefill = graph_opt_config.max_capture_shape_prefill
|
||||||
|
|
||||||
@@ -1717,6 +1733,31 @@ class FDConfig:
|
|||||||
max_capture_shape_prefill=max_capture_shape_prefill,
|
max_capture_shape_prefill=max_capture_shape_prefill,
|
||||||
dec_token_per_query_per_step=dec_token_per_query_per_step,
|
dec_token_per_query_per_step=dec_token_per_query_per_step,
|
||||||
)
|
)
|
||||||
|
if self.speculative_config is not None and self.speculative_config.method in ["mtp", "suffix"]:
|
||||||
|
real_bsz_to_captured_size = {}
|
||||||
|
for capture_size in self.graph_opt_config.cudagraph_capture_sizes:
|
||||||
|
dummy_batch_size = int(capture_size / (self.speculative_config.num_speculative_tokens + 1))
|
||||||
|
real_bsz_to_captured_size[dummy_batch_size] = capture_size
|
||||||
|
|
||||||
|
def expand_bsz_map(real_bsz_to_captured_size):
    """
    Turn a sparse batch-size -> capture-size mapping into a dense one.

    Keys are visited in ascending order; every batch size above the previous
    key and up to (and including) the current key is assigned the current
    key's capture size.

    Args:
        real_bsz_to_captured_size (dict): Sparse batch size to capture size mapping.
    Returns:
        dict: Dense batch size to capture size mapping.
    """
    dense = {}
    lower = 0
    for upper in sorted(real_bsz_to_captured_size):
        capture_size = real_bsz_to_captured_size[upper]
        # Every batch size in (lower, upper] shares the capture size of `upper`.
        dense.update(dict.fromkeys(range(lower + 1, upper + 1), capture_size))
        lower = upper
    return dense
|
||||||
|
|
||||||
|
self.graph_opt_config.real_bsz_to_captured_size = expand_bsz_map(real_bsz_to_captured_size)
|
||||||
self.graph_opt_config.init_with_cudagrpah_size(
|
self.graph_opt_config.init_with_cudagrpah_size(
|
||||||
max_capture_size=max_capture_shape,
|
max_capture_size=max_capture_shape,
|
||||||
max_capture_shape_prefill=max_capture_shape_prefill,
|
max_capture_shape_prefill=max_capture_shape_prefill,
|
||||||
|
|||||||
@@ -119,6 +119,10 @@ class CudaGraphPiecewiseBackend:
|
|||||||
if self.fd_config.graph_opt_config.graph_opt_level > 0:
|
if self.fd_config.graph_opt_config.graph_opt_level > 0:
|
||||||
self.cuda_graph_manager = Dy2StCudaGraphManager()
|
self.cuda_graph_manager = Dy2StCudaGraphManager()
|
||||||
|
|
||||||
|
self.speculative_decoding = fd_config.speculative_config.method is not None
|
||||||
|
self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
|
||||||
|
self.real_bsz_to_captured_size = fd_config.graph_opt_config.real_bsz_to_captured_size
|
||||||
|
|
||||||
def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
|
def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
|
||||||
|
|
||||||
if not entry.captured:
|
if not entry.captured:
|
||||||
@@ -153,7 +157,10 @@ class CudaGraphPiecewiseBackend:
|
|||||||
# Get real shape (total num tokens)
|
# Get real shape (total num tokens)
|
||||||
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
|
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
|
||||||
real_shape = ids_remove_padding.shape[0]
|
real_shape = ids_remove_padding.shape[0]
|
||||||
|
if self.speculative_decoding and all(self.real_bsz_to_captured_size.values()):
|
||||||
|
seq_lens_this_time: paddle.Tensor = kwargs["forward_meta"].seq_lens_this_time
|
||||||
|
num_running_requests = seq_lens_this_time.squeeze(axis=-1).nonzero(as_tuple=False)[-1].item() + 1
|
||||||
|
real_shape = self.real_bsz_to_captured_size[num_running_requests]
|
||||||
exist_prefill = kwargs["forward_meta"].exist_prefill
|
exist_prefill = kwargs["forward_meta"].exist_prefill
|
||||||
# Static split graph mode: use Static + CUDAGraph for prefill/mixed phase
|
# Static split graph mode: use Static + CUDAGraph for prefill/mixed phase
|
||||||
static_cudagraph_for_prefill = exist_prefill and not self.full_cuda_graph and self.dy2st
|
static_cudagraph_for_prefill = exist_prefill and not self.full_cuda_graph and self.dy2st
|
||||||
@@ -188,6 +195,9 @@ class CudaGraphPiecewiseBackend:
|
|||||||
|
|
||||||
# Capture a new cuda graph
|
# Capture a new cuda graph
|
||||||
if entry.cuda_graph is None:
|
if entry.cuda_graph is None:
|
||||||
|
assert (
|
||||||
|
real_shape == padding_real_shape
|
||||||
|
), f"real_shape:{real_shape} is not equal to padding_real_shape:{padding_real_shape} when capture new graph."
|
||||||
# Warmup the model
|
# Warmup the model
|
||||||
for n in range(entry.num_finished_warmup, self.warm_up_size):
|
for n in range(entry.num_finished_warmup, self.warm_up_size):
|
||||||
entry.num_finished_warmup += 1
|
entry.num_finished_warmup += 1
|
||||||
|
|||||||
@@ -787,6 +787,8 @@ class TokenProcessor:
|
|||||||
+ i * MAX_DRAFT_TOKENS
|
+ i * MAX_DRAFT_TOKENS
|
||||||
+ accept_num[i]
|
+ accept_num[i]
|
||||||
].tolist()
|
].tolist()
|
||||||
|
if accept_num[i] == 0:
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
token_id = int(tokens[i, 0])
|
token_id = int(tokens[i, 0])
|
||||||
token_ids = [token_id]
|
token_ids = [token_id]
|
||||||
|
|||||||
@@ -23,3 +23,17 @@ from .mtp import MTPProposer
|
|||||||
if not current_platform.is_xpu():
|
if not current_platform.is_xpu():
|
||||||
from .ngram import NgramProposer
|
from .ngram import NgramProposer
|
||||||
__all__ = ["Proposer", "MTPProposer", "NgramProposer"]
|
__all__ = ["Proposer", "MTPProposer", "NgramProposer"]
|
||||||
|
|
||||||
|
# SuffixProposer is exported only when its module can be imported; otherwise the
# name is bound to None so that `from fastdeploy.spec_decode import SuffixProposer`
# still succeeds.
try:
    from .suffix import SuffixProposer
except ImportError:
    SuffixProposer = None

_suffix_proposer_available = SuffixProposer is not None

__all__ = ["Proposer", "MTPProposer", "NgramProposer"]
if _suffix_proposer_available:
    __all__.append("SuffixProposer")
|
||||||
|
|||||||
@@ -96,3 +96,27 @@ class Proposer(ABC):
|
|||||||
bool: True if chunk prefill is enabled; False otherwise.
|
bool: True if chunk prefill is enabled; False otherwise.
|
||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def prepare_dummy_speculative_drafts(
    self,
    share_inputs,
    batch_size: int,
) -> None:
    """
    Construct a set of dummy draft tokens for CUDAGraph capture scenarios,
    used only to stabilize shape/step count, with no requirement for semantic correctness.

    When the first slot is not stopped, the first ``batch_size`` rows receive
    ``self.max_draft_token_num`` placeholder draft tokens and a
    ``seq_lens_this_time`` of ``max_draft_token_num + 1``; otherwise their
    ``seq_lens_this_time`` is set to 0.

    Args:
        share_inputs: share_inputs dict maintained by GPUModelRunner
        batch_size: current batch_size for dummy_run
    """
    max_fake_drafts = self.max_draft_token_num

    # Only the stop flag of slot 0 is inspected; it decides the whole dummy batch.
    stop = share_inputs["stop_flags"][0].item()
    if not stop:
        # 5 is a placeholder token id (presumably arbitrary; the values carry no meaning).
        share_inputs["draft_tokens"][:batch_size, :max_fake_drafts] = 5
        share_inputs["seq_lens_this_time"][:batch_size] = max_fake_drafts + 1
    else:
        share_inputs["seq_lens_this_time"][:batch_size] = 0
|
||||||
|
|||||||
@@ -0,0 +1,230 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from fastdeploy.config import FDConfig
|
||||||
|
from fastdeploy.utils import spec_logger
|
||||||
|
|
||||||
|
from .base import Proposer
|
||||||
|
|
||||||
|
try:
|
||||||
|
from arctic_inference.suffix_decoding import SuffixDecodingCache
|
||||||
|
except ImportError:
|
||||||
|
SuffixDecodingCache = None
|
||||||
|
|
||||||
|
|
||||||
|
class SuffixProposer(Proposer):
    """
    Proposer for Suffix Decoding method.

    Uses SuffixDecodingCache to generate draft tokens based on suffix tree matching.
    The proposer keeps a per-slot copy of each request's tokens (prompt plus
    accepted tokens) and a two-way mapping between batch slots and request ids.
    """

    def __init__(self, fd_config: FDConfig):
        """
        Build the suffix cache and the per-slot bookkeeping.

        Args:
            fd_config: FastDeploy configuration forwarded to the base Proposer.

        Raises:
            ImportError: If ``arctic_inference.suffix_decoding`` could not be imported.
        """
        super().__init__(fd_config)

        if SuffixDecodingCache is None:
            raise ImportError(
                "arctic_inference.suffix_decoding is not available. Please install arctic-inference package."
            )

        spec_cfg = self.speculative_config

        # Initialize SuffixDecodingCache
        self.suffix_cache = SuffixDecodingCache(
            max_tree_depth=spec_cfg.suffix_decoding_max_tree_depth,
            max_cached_requests=spec_cfg.suffix_decoding_max_cached_requests,
        )

        self.max_tree_depth = spec_cfg.suffix_decoding_max_tree_depth
        self.max_spec_factor = spec_cfg.suffix_decoding_max_spec_factor
        self.min_token_prob = spec_cfg.suffix_decoding_min_token_prob

        # Track active requests: req_id <-> batch slot index
        self.req_id_to_idx = {}
        self.idx_to_req_id = {}
        # One row of token ids per batch slot; -1 marks positions that hold no token.
        self.context_tokens = np.full(
            (self.max_num_seqs, self.max_model_len),
            -1,
            dtype=np.int32,
        )
        # A proposed draft is cut at the first token id found in this set.
        # NOTE(review): the meaning of these ids is not visible here (presumably
        # model-specific special tokens) - confirm against the tokenizer.
        self.ban_tokens = {101031, 101032, 101033}

    @staticmethod
    def _as_int32_array(token_ids) -> np.ndarray:
        """Copy ``token_ids`` into a C-contiguous ``int32`` NumPy array."""
        array = np.array(token_ids, dtype=np.int32)
        if not array.flags["CONTIGUOUS"]:
            array = np.ascontiguousarray(array)
        return array

    def _update_request_mapping(self, idx: int, req_id: str):
        """
        Update the mapping between request ID and batch index.

        Args:
            idx: Batch index
            req_id: Request identifier
        """
        # Drop the request that previously occupied this slot, if any.
        previous_req_id = self.idx_to_req_id.get(idx)
        if previous_req_id is not None:
            self.req_id_to_idx.pop(previous_req_id, None)

        # If this request was registered on another slot, drop that reverse entry
        # as well so the old slot no longer resolves to this request.
        previous_idx = self.req_id_to_idx.get(req_id)
        if previous_idx is not None and previous_idx != idx and self.idx_to_req_id.get(previous_idx) == req_id:
            del self.idx_to_req_id[previous_idx]

        # Set new mapping
        self.req_id_to_idx[req_id] = idx
        self.idx_to_req_id[idx] = req_id

    def start_request(self, idx: int, req_id: str, prompt_token_ids: list[int]):
        """
        Start a new request in the suffix cache.

        Does nothing when the request is already active in the suffix cache.

        Args:
            idx: Batch index the request occupies
            req_id: Request identifier
            prompt_token_ids: List of prompt token IDs
        """
        if req_id in self.suffix_cache.active_requests:
            # Request already active, skip
            return

        prompt_array = self._as_int32_array(prompt_token_ids)

        # Reset the slot, then store the prompt at the start of its row.
        self.context_tokens[idx, :] = -1
        self.context_tokens[idx, : len(prompt_token_ids)] = prompt_array
        self._update_request_mapping(idx, req_id)

        if req_id in self.suffix_cache.cached_requests:
            # Reset the suffix cache for current req_id
            self.suffix_cache.evict_cached_response(req_id)
            spec_logger.debug(f"[SuffixDecoding] Reset suffix cache for request {req_id}.")
        self.suffix_cache.start_request(req_id, prompt_array)
        spec_logger.debug(f"[SuffixDecoding] Start request {req_id}.")

    def stop_request(self, req_id: str):
        """
        Stop a request in the suffix cache.

        Args:
            req_id: Request identifier
        """
        if req_id in self.suffix_cache.active_requests:
            self.suffix_cache.stop_request(req_id)

        # Clean up mappings
        if req_id in self.req_id_to_idx:
            idx = self.req_id_to_idx.pop(req_id)
            self.idx_to_req_id.pop(idx, None)

        spec_logger.debug(f"[SuffixDecoding] Stop request {req_id}.")

    def add_active_response(self, req_id: str, token_ids: list[int]):
        """
        Add newly sampled tokens to the suffix cache for a request.

        Ignored when the request is not active in the suffix cache.

        Args:
            req_id: Request identifier
            token_ids: List of newly sampled token IDs
        """
        if req_id not in self.suffix_cache.active_requests:
            return

        self.suffix_cache.add_active_response(req_id, self._as_int32_array(token_ids))

    def _run_impl(self, share_inputs):
        """
        Propose draft tokens for every batch slot and write them back to ``share_inputs``.

        For each slot: stopped slots get ``seq_lens_this_time = 0`` (and the request is
        stopped in the suffix cache unless the slot is in a block step); all other slots
        default to ``seq_lens_this_time = 1`` and, when drafts are available, are extended
        to ``1 + number_of_drafts``.

        Args:
            share_inputs: dict of shared tensors; the entries read here are
                ``stop_flags``, ``is_block_step``, ``accept_tokens``, ``accept_num``,
                ``seq_lens_encoder``, ``seq_lens_decoder``, ``draft_tokens`` and
                ``seq_lens_this_time``. ``draft_tokens`` and ``seq_lens_this_time``
                are overwritten in place.
        """
        # Everything is processed on the host; results are copied back at the end.
        stop_flags_cpu = share_inputs["stop_flags"].cpu().numpy().flatten()
        is_block_step_cpu = share_inputs["is_block_step"].cpu().numpy().flatten()
        accept_tokens_cpu = share_inputs["accept_tokens"].cpu()
        accept_num_cpu = share_inputs["accept_num"].cpu().numpy().flatten()
        seq_lens_encoder = share_inputs["seq_lens_encoder"].cpu().numpy().flatten()
        seq_lens_decoder = share_inputs["seq_lens_decoder"].cpu().numpy().flatten()

        draft_tokens_cpu = share_inputs["draft_tokens"].cpu()
        seq_lens_this_time_cpu = share_inputs["seq_lens_this_time"].cpu()

        total_lens = seq_lens_encoder + seq_lens_decoder
        batch_size = seq_lens_this_time_cpu.shape[0]

        for bid in range(batch_size):
            req_id = self.idx_to_req_id.get(bid)

            # 1. Stop condition has the highest priority
            if stop_flags_cpu[bid]:
                seq_lens_this_time_cpu[bid] = 0
                draft_tokens_cpu[bid, :] = -1
                if (
                    not is_block_step_cpu[bid]
                    and req_id is not None
                    and req_id in self.suffix_cache.active_requests
                ):
                    self.stop_request(req_id)
                continue

            # Default for a running slot: one token this step, no drafts.
            seq_lens_this_time_cpu[bid] = 1
            draft_tokens_cpu[bid, 1:] = -1

            # 2. Skip some cases
            num_tokens = int(total_lens[bid])
            max_spec_tokens = min(
                self.max_draft_token_num,
                self.max_model_len - num_tokens - 1,
            )
            if max_spec_tokens <= 1 or req_id is None:
                continue

            # 3. Add accept tokens to context
            acc_num = int(accept_num_cpu[bid])
            assert (
                acc_num > 0
            ), f"Request {req_id} (bid {bid}) must have at least one accepted token, but got {acc_num}."
            # The guard still matters when assertions are disabled (python -O).
            if acc_num > 0:
                accepted = accept_tokens_cpu[bid, :acc_num]
                ctx_start = int(seq_lens_decoder[bid]) - acc_num
                self.context_tokens[bid, ctx_start : ctx_start + acc_num] = accepted
                self.add_active_response(req_id, accepted)

            # 4. Get context: at most the last max_tree_depth tokens, unused (-1) slots removed
            start = max(0, num_tokens - self.max_tree_depth)
            window = self.context_tokens[bid, start:num_tokens]
            ctx = window[window >= 0]
            if ctx.size == 0:
                continue
            ctx = np.ascontiguousarray(ctx, dtype=np.int32)

            # 5. Speculate
            draft = self.suffix_cache.speculate(
                req_id,
                ctx,
                max_spec_tokens=max_spec_tokens,
                max_spec_factor=self.max_spec_factor,
                min_token_prob=self.min_token_prob,
            )
            proposed = draft.token_ids

            # Keep only the tokens before the first banned id.
            usable = 0
            for token in proposed:
                if token in self.ban_tokens:
                    break
                usable += 1

            if usable > 0:
                # Position 0 is left untouched; drafts occupy positions 1..usable.
                draft_tokens_cpu[bid, 1 : 1 + usable] = np.array(proposed[:usable])
                draft_tokens_cpu[bid, 1 + usable :] = -1
                seq_lens_this_time_cpu[bid] = 1 + usable

        share_inputs["draft_tokens"][:] = draft_tokens_cpu.cuda()
        share_inputs["seq_lens_this_time"][:] = seq_lens_this_time_cpu.cuda()
|
||||||
@@ -91,7 +91,7 @@ from fastdeploy.model_executor.pre_and_post_process import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
|
if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
|
||||||
from fastdeploy.spec_decode import MTPProposer, NgramProposer
|
from fastdeploy.spec_decode import MTPProposer, NgramProposer, SuffixProposer
|
||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
@@ -390,6 +390,8 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
self.device_id,
|
self.device_id,
|
||||||
self.share_inputs,
|
self.share_inputs,
|
||||||
)
|
)
|
||||||
|
elif self.speculative_method == "suffix":
|
||||||
|
self.proposer = SuffixProposer(self.fd_config)
|
||||||
else:
|
else:
|
||||||
self.proposer = None
|
self.proposer = None
|
||||||
|
|
||||||
@@ -810,6 +812,13 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
self.forward_batch_reqs_list[idx] = request
|
self.forward_batch_reqs_list[idx] = request
|
||||||
has_prefill_task = True
|
has_prefill_task = True
|
||||||
|
|
||||||
|
if self.speculative_decoding and self.speculative_method == "suffix" and self.proposer is not None:
|
||||||
|
if isinstance(request.prompt_token_ids, np.ndarray):
|
||||||
|
prompt_token_ids = request.prompt_token_ids.tolist()
|
||||||
|
else:
|
||||||
|
prompt_token_ids = request.prompt_token_ids
|
||||||
|
self.proposer.start_request(idx, request.request_id, prompt_token_ids)
|
||||||
|
|
||||||
# Routing Replay
|
# Routing Replay
|
||||||
if self.fd_config.routing_replay_config.enable_routing_replay:
|
if self.fd_config.routing_replay_config.enable_routing_replay:
|
||||||
if prefill_start_index == 0:
|
if prefill_start_index == 0:
|
||||||
@@ -1063,6 +1072,15 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
else:
|
else:
|
||||||
return default_value
|
return default_value
|
||||||
|
|
||||||
|
# Start suffix decoding request if using suffix proposer
if self.speculative_decoding and self.speculative_method == "suffix" and self.proposer is not None:
    if isinstance(request.prompt_token_ids, np.ndarray):
        prompt_token_ids = request.prompt_token_ids.tolist()
    else:
        prompt_token_ids = request.prompt_token_ids
    # SuffixProposer.start_request takes the batch slot index first and records the
    # idx <-> request_id mapping itself, so no separate mapping call is needed.
    self.proposer.start_request(idx, request.request_id, prompt_token_ids)
|
||||||
|
|
||||||
assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
|
assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
|
||||||
self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
|
self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
|
||||||
self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
|
self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
|
||||||
@@ -1768,6 +1786,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
self,
|
self,
|
||||||
hidden_states: paddle.Tensor,
|
hidden_states: paddle.Tensor,
|
||||||
model_output: paddle.Tensor,
|
model_output: paddle.Tensor,
|
||||||
|
batch_size: int,
|
||||||
accept_all_drafts=False,
|
accept_all_drafts=False,
|
||||||
reject_all_drafts=False,
|
reject_all_drafts=False,
|
||||||
) -> paddle.Tensor:
|
) -> paddle.Tensor:
|
||||||
@@ -1876,8 +1895,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
is_dummy_run=True,
|
is_dummy_run=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.proposer.run(share_inputs=self.share_inputs)
|
self.proposer.prepare_dummy_speculative_drafts(share_inputs=self.share_inputs, batch_size=batch_size)
|
||||||
|
|
||||||
return sampler_output
|
return sampler_output
|
||||||
|
|
||||||
def _dummy_run(
|
def _dummy_run(
|
||||||
@@ -1954,7 +1972,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
(self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None),
|
(self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None),
|
||||||
(self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None),
|
(self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None),
|
||||||
)
|
)
|
||||||
self._dummy_sampler_run(hidden_states, model_output, accept_all_drafts, reject_all_drafts)
|
self._dummy_sampler_run(hidden_states, model_output, batch_size, accept_all_drafts, reject_all_drafts)
|
||||||
|
|
||||||
# 7. Updata 'infer_seed' and step_cuda()
|
# 7. Updata 'infer_seed' and step_cuda()
|
||||||
self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
|
self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
|
||||||
@@ -2064,22 +2082,19 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
|
f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
|
||||||
)
|
)
|
||||||
elif self.speculative_decoding and self.speculative_method == "mtp":
|
elif self.speculative_decoding:
|
||||||
# Capture Target Model without bsz 1
|
# Capture Target Model without bsz 1
|
||||||
for capture_size in sorted(capture_sizes, reverse=True):
|
for capture_size in sorted(capture_sizes, reverse=True):
|
||||||
|
expected_decode_len = self.speculative_config.num_speculative_tokens * 2 + 1
|
||||||
self._dummy_run(
|
self._dummy_run(
|
||||||
num_tokens=(
|
num_tokens=self.fd_config.get_max_chunk_tokens(),
|
||||||
self.scheduler_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1)
|
|
||||||
if self.scheduler_config.splitwise_role == "decode"
|
|
||||||
else self.fd_config.get_max_chunk_tokens()
|
|
||||||
),
|
|
||||||
batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)),
|
batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)),
|
||||||
in_capturing=True,
|
in_capturing=True,
|
||||||
expected_decode_len=self.speculative_config.num_speculative_tokens * 2 + 1,
|
expected_decode_len=expected_decode_len,
|
||||||
accept_all_drafts=True,
|
accept_all_drafts=True,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
|
f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{expected_decode_len}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
for batch_size in sorted(capture_sizes, reverse=True):
|
for batch_size in sorted(capture_sizes, reverse=True):
|
||||||
@@ -2536,6 +2551,8 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
self.proposer.run(
|
self.proposer.run(
|
||||||
full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
|
full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
|
||||||
)
|
)
|
||||||
|
elif self.speculative_method == "suffix":
|
||||||
|
self.proposer.run(share_inputs=self.share_inputs)
|
||||||
else:
|
else:
|
||||||
self.proposer.run(share_inputs=self.share_inputs)
|
self.proposer.run(share_inputs=self.share_inputs)
|
||||||
|
|
||||||
|
|||||||
@@ -267,7 +267,7 @@ class InputBatch:
|
|||||||
self.accept_num = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32")
|
self.accept_num = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32")
|
||||||
self.draft_tokens = paddle.full(
|
self.draft_tokens = paddle.full(
|
||||||
shape=[max_num_seqs, max_draft_token_num + 1],
|
shape=[max_num_seqs, max_draft_token_num + 1],
|
||||||
fill_value=-1,
|
fill_value=0,
|
||||||
dtype="int64",
|
dtype="int64",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -293,7 +293,7 @@ class InputBatch:
|
|||||||
# For V1_KVCACHE_SCHEDULER
|
# For V1_KVCACHE_SCHEDULER
|
||||||
self.step_draft_tokens = paddle.full(
|
self.step_draft_tokens = paddle.full(
|
||||||
shape=[max_num_seqs, max_draft_token_num + 1],
|
shape=[max_num_seqs, max_draft_token_num + 1],
|
||||||
fill_value=0,
|
fill_value=-1,
|
||||||
dtype="int64",
|
dtype="int64",
|
||||||
)
|
)
|
||||||
self.step_seq_lens_this_time = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
self.step_seq_lens_this_time = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||||
|
|||||||
@@ -48,3 +48,4 @@ p2pstore
|
|||||||
py-cpuinfo
|
py-cpuinfo
|
||||||
flashinfer-python-paddle
|
flashinfer-python-paddle
|
||||||
flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl
|
flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl
|
||||||
|
arctic_inference @ https://paddle-qa.bj.bcebos.com/ernie/arctic_inference-0.1.3-cp310-cp310-linux_x86_64.whl
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from fastdeploy.config import (
|
|||||||
GraphOptimizationConfig,
|
GraphOptimizationConfig,
|
||||||
ParallelConfig,
|
ParallelConfig,
|
||||||
SchedulerConfig,
|
SchedulerConfig,
|
||||||
|
SpeculativeConfig,
|
||||||
)
|
)
|
||||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||||
from fastdeploy.model_executor.graph_optimization.decorator import (
|
from fastdeploy.model_executor.graph_optimization.decorator import (
|
||||||
@@ -110,6 +111,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
|
|||||||
cache_config = CacheConfig(args={})
|
cache_config = CacheConfig(args={})
|
||||||
scheduler_config.max_num_seqs = 1
|
scheduler_config.max_num_seqs = 1
|
||||||
parallel_config = ParallelConfig(args={})
|
parallel_config = ParallelConfig(args={})
|
||||||
|
speculative_config = SpeculativeConfig(args={})
|
||||||
model_config = Mock()
|
model_config = Mock()
|
||||||
model_config.max_model_len = 5120
|
model_config.max_model_len = 5120
|
||||||
model_config.architectures = ["test_model"]
|
model_config.architectures = ["test_model"]
|
||||||
@@ -120,6 +122,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
|
|||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
parallel_config=parallel_config,
|
parallel_config=parallel_config,
|
||||||
|
speculative_config=speculative_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run Test Case1
|
# Run Test Case1
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from fastdeploy.config import (
|
|||||||
GraphOptimizationConfig,
|
GraphOptimizationConfig,
|
||||||
ParallelConfig,
|
ParallelConfig,
|
||||||
SchedulerConfig,
|
SchedulerConfig,
|
||||||
|
SpeculativeConfig,
|
||||||
)
|
)
|
||||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||||
from fastdeploy.model_executor.graph_optimization.decorator import (
|
from fastdeploy.model_executor.graph_optimization.decorator import (
|
||||||
@@ -103,6 +104,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase):
|
|||||||
scheduler_config.max_num_seqs = 1
|
scheduler_config.max_num_seqs = 1
|
||||||
cache_config = CacheConfig({})
|
cache_config = CacheConfig({})
|
||||||
parallel_config = ParallelConfig(args={})
|
parallel_config = ParallelConfig(args={})
|
||||||
|
speculative_config = SpeculativeConfig(args={})
|
||||||
model_config = Mock()
|
model_config = Mock()
|
||||||
model_config.max_model_len = 512
|
model_config.max_model_len = 512
|
||||||
model_config.architectures = ["test_model"]
|
model_config.architectures = ["test_model"]
|
||||||
@@ -116,6 +118,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase):
|
|||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
parallel_config=parallel_config,
|
parallel_config=parallel_config,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
|
speculative_config=speculative_config,
|
||||||
test_mode=True,
|
test_mode=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from fastdeploy.config import (
|
|||||||
GraphOptimizationConfig,
|
GraphOptimizationConfig,
|
||||||
ParallelConfig,
|
ParallelConfig,
|
||||||
SchedulerConfig,
|
SchedulerConfig,
|
||||||
|
SpeculativeConfig,
|
||||||
)
|
)
|
||||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||||
from fastdeploy.model_executor.graph_optimization.decorator import (
|
from fastdeploy.model_executor.graph_optimization.decorator import (
|
||||||
@@ -144,6 +145,7 @@ class TestGraphOptBackend(unittest.TestCase):
|
|||||||
# Setup cache config
|
# Setup cache config
|
||||||
cache_config = CacheConfig({})
|
cache_config = CacheConfig({})
|
||||||
parallel_config = ParallelConfig(args={})
|
parallel_config = ParallelConfig(args={})
|
||||||
|
speculative_config = SpeculativeConfig(args={})
|
||||||
model_config = Mock()
|
model_config = Mock()
|
||||||
model_config.max_model_len = 512
|
model_config.max_model_len = 512
|
||||||
model_config.architectures = ["test_model"]
|
model_config.architectures = ["test_model"]
|
||||||
@@ -156,6 +158,7 @@ class TestGraphOptBackend(unittest.TestCase):
|
|||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
parallel_config=parallel_config,
|
parallel_config=parallel_config,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
|
speculative_config=speculative_config,
|
||||||
test_mode=True,
|
test_mode=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from fastdeploy.config import (
|
|||||||
GraphOptimizationConfig,
|
GraphOptimizationConfig,
|
||||||
ParallelConfig,
|
ParallelConfig,
|
||||||
SchedulerConfig,
|
SchedulerConfig,
|
||||||
|
SpeculativeConfig,
|
||||||
)
|
)
|
||||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||||
from fastdeploy.model_executor.graph_optimization.decorator import (
|
from fastdeploy.model_executor.graph_optimization.decorator import (
|
||||||
@@ -97,6 +98,7 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase):
|
|||||||
graph_opt_config.init_with_cudagrpah_size(max_capture_size=scheduler_config.max_num_seqs)
|
graph_opt_config.init_with_cudagrpah_size(max_capture_size=scheduler_config.max_num_seqs)
|
||||||
cache_config = CacheConfig({})
|
cache_config = CacheConfig({})
|
||||||
parallel_config = ParallelConfig(args={})
|
parallel_config = ParallelConfig(args={})
|
||||||
|
speculative_config = SpeculativeConfig(args={})
|
||||||
model_config = Mock()
|
model_config = Mock()
|
||||||
model_config.max_model_len = 512
|
model_config.max_model_len = 512
|
||||||
model_config.architectures = ["test_model"]
|
model_config.architectures = ["test_model"]
|
||||||
@@ -107,6 +109,7 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase):
|
|||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
parallel_config=parallel_config,
|
parallel_config=parallel_config,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
|
speculative_config=speculative_config,
|
||||||
test_mode=True,
|
test_mode=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -112,9 +112,8 @@ def py_update_attn_mask_offsets_op(
|
|||||||
attn_mask_offsets_decoder[bid] = int(attn_mask_offsets_decoder[bid]) + seq_len_this
|
attn_mask_offsets_decoder[bid] = int(attn_mask_offsets_decoder[bid]) + seq_len_this
|
||||||
|
|
||||||
# speculative decoding: if seq_len_this > 1 then set decode_states_now[i] accordingly
|
# speculative decoding: entries below seq_len_this become 0 unless they are currently 1; every other entry becomes -1
|
||||||
if seq_len_this > 1:
|
for i in range(decode_states_len):
|
||||||
for i in range(decode_states_len):
|
decode_now[i] = 0 if i < seq_len_this and decode_now[i] != 1 else -1
|
||||||
decode_now[i] = 0 if i < seq_len_this else -1
|
|
||||||
# done decoder branch
|
# done decoder branch
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,141 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
from utils import FakeModelConfig, get_default_test_fd_config
|
||||||
|
|
||||||
|
from fastdeploy.config import SpeculativeConfig
|
||||||
|
from fastdeploy.spec_decode.suffix import SuffixProposer
|
||||||
|
|
||||||
|
|
||||||
|
class TestSuffixProposer(unittest.TestCase):
    """Tests for SuffixProposer request tracking and draft-token proposal."""

    def setUp(self):
        """Build a suffix-decoding config and a minimal ``share_inputs`` dict."""
        self.fd_config = get_default_test_fd_config()
        self.fd_config.model_config = FakeModelConfig()
        self.fd_config.model_config.max_model_len = 2048
        self.fd_config.speculative_config = SpeculativeConfig({})
        self.fd_config.speculative_config.method = "suffix"
        self.fd_config.speculative_config.num_speculative_tokens = 4
        self.fd_config.speculative_config.suffix_decoding_max_tree_depth = 64
        self.fd_config.speculative_config.suffix_decoding_max_cached_requests = 4
        self.fd_config.speculative_config.suffix_decoding_max_spec_factor = 1.0
        self.fd_config.speculative_config.suffix_decoding_min_token_prob = 0.1
        self.fd_config.scheduler_config.max_num_seqs = 4

        bsz = self.fd_config.scheduler_config.max_num_seqs
        max_draft_tokens = self.fd_config.speculative_config.num_speculative_tokens
        self.share_inputs = {
            "stop_flags": paddle.full([bsz, 1], fill_value=False, dtype="bool"),
            "is_block_step": paddle.full([bsz], fill_value=False, dtype="bool"),
            "accept_tokens": paddle.zeros([bsz, max_draft_tokens], dtype="int64"),
            "accept_num": paddle.zeros([bsz], dtype="int32"),
            "seq_lens_this_time": paddle.zeros([bsz, 1], dtype="int32"),
            "seq_lens_encoder": paddle.zeros([bsz, 1], dtype="int32"),
            "seq_lens_decoder": paddle.zeros([bsz, 1], dtype="int32"),
            "draft_tokens": paddle.zeros([bsz, max_draft_tokens], dtype="int64"),
        }

    def _assert_tracking_state(self, proposer, context_tokens, req_id_to_idx, idx_to_req_id):
        """Check the proposer's context buffer and both request/slot mappings."""
        np.testing.assert_array_equal(proposer.context_tokens, context_tokens)
        # The mappings are compared against plain dicts, so use unittest's
        # equality assertion rather than an array comparison.
        self.assertEqual(proposer.req_id_to_idx, req_id_to_idx)
        self.assertEqual(proposer.idx_to_req_id, idx_to_req_id)

    def test_start_and_stop_request(self):
        """start_request/stop_request keep context tokens and mappings in sync.

        Expected state is mirrored in local reference structures: a context
        buffer of shape (max_num_seqs, max_model_len) pre-filled with -1, and
        two dicts mapping request id <-> slot index.
        """
        proposer = SuffixProposer(self.fd_config)

        idx = 0
        req_id = "req-001"
        prompt_token_ids = [1, 2, 3, 4]
        proposer.start_request(idx, req_id, prompt_token_ids)

        refs_context_tokens = np.full(
            (self.fd_config.scheduler_config.max_num_seqs, self.fd_config.model_config.max_model_len),
            -1,
            dtype=np.int32,
        )
        refs_req_id_to_idx = {}
        refs_idx_to_req_id = {}
        refs_context_tokens[idx, : len(prompt_token_ids)] = prompt_token_ids
        refs_req_id_to_idx[req_id] = idx
        refs_idx_to_req_id[idx] = req_id

        self.assertIsNotNone(proposer.suffix_cache)
        self._assert_tracking_state(proposer, refs_context_tokens, refs_req_id_to_idx, refs_idx_to_req_id)

        idx = 1
        req_id = "req-002"
        prompt_token_ids = [5, 6, 7, 8]
        proposer.start_request(idx, req_id, prompt_token_ids)

        refs_context_tokens[idx, : len(prompt_token_ids)] = prompt_token_ids
        refs_req_id_to_idx[req_id] = idx
        refs_idx_to_req_id[idx] = req_id

        self._assert_tracking_state(proposer, refs_context_tokens, refs_req_id_to_idx, refs_idx_to_req_id)

        proposer.stop_request("req-001")

        # Stopping a request is expected to drop only the mappings; the
        # reference context buffer is intentionally left untouched.
        refs_req_id_to_idx.pop("req-001")
        refs_idx_to_req_id.pop(0)

        self._assert_tracking_state(proposer, refs_context_tokens, refs_req_id_to_idx, refs_idx_to_req_id)

    def test_propose(self):
        """run() keeps the two leading draft tokens and resets the rest to -1.

        Every slot starts with two accepted tokens (42) and stale draft values
        (53) in the remaining positions; after ``run`` the stale positions are
        expected to be -1 and ``seq_lens_this_time`` to remain 2.
        """
        self.share_inputs["accept_tokens"][:, :2] = 42
        self.share_inputs["accept_num"][:] = 2
        self.share_inputs["seq_lens_this_time"][:, :] = 2
        self.share_inputs["seq_lens_encoder"][:, :] = 0
        self.share_inputs["seq_lens_decoder"][:, :] = 100
        self.share_inputs["draft_tokens"][:, :2] = 42
        self.share_inputs["draft_tokens"][:, 2:] = 53

        proposer = SuffixProposer(self.fd_config)
        ids = [0, 1, 2, 3]
        req_ids = ["req-001", "req-002", "req-003", "req-004"]
        prompt_token_ids_list = [
            [1, 2, 3, 4],
            [5, 6, 7, 8],
            [9, 10, 11, 12],
            [13, 14, 15, 16],
        ]
        for idx, req_id, prompt_token_ids in zip(ids, req_ids, prompt_token_ids_list):
            proposer.start_request(idx, req_id, prompt_token_ids)

        proposer.run(self.share_inputs)

        refs_draft_tokens = np.array(
            [
                [42, 42, -1, -1],
                [42, 42, -1, -1],
                [42, 42, -1, -1],
                [42, 42, -1, -1],
            ],
            dtype=np.int64,
        )
        refs_seq_lens_this_time = np.array([[2], [2], [2], [2]], dtype=np.int32)

        np.testing.assert_array_equal(self.share_inputs["draft_tokens"].numpy(), refs_draft_tokens)
        np.testing.assert_array_equal(self.share_inputs["seq_lens_this_time"].numpy(), refs_seq_lens_this_time)


if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user