[Speculative Decoding] Unify Spec and non-spec branch (#6685)

* optimize spec-inference architecture

* delete debug log

* optimize spec_method usage and fix unit tests

* add claude unit-test skill

* fix some ugly bugs

* enhance robustness and bounds check

* unify method & spec_method to method to avoid bug

* activate CI

* fix unit test

* Unify logprobs computation for naive and speculative decoding, fix CUDA kernel

* fix logprob bug && optimize verify kernel

* fix exist_decode() check
This commit is contained in:
freeliuzc
2026-03-11 14:58:44 +08:00
committed by GitHub
parent b6190de557
commit cf7934a4b2
41 changed files with 3428 additions and 392 deletions
+2 -20
View File
@@ -14,26 +14,8 @@
"""
speculative decoding module
"""
from fastdeploy.platforms import current_platform
from .base import Proposer
from .mtp import MTPProposer
from .types import SpecMethod, VerifyStrategy
# XPU is not support ngram proposer now
if not current_platform.is_xpu():
from .ngram import NgramProposer
__all__ = ["Proposer", "MTPProposer", "NgramProposer"]
# Suffix proposer requires arctic_inference
try:
from .suffix import SuffixProposer
_suffix_proposer_available = True
except ImportError:
_suffix_proposer_available = False
SuffixProposer = None
if _suffix_proposer_available:
__all__ = ["Proposer", "MTPProposer", "NgramProposer", "SuffixProposer"]
else:
__all__ = ["Proposer", "MTPProposer", "NgramProposer"]
__all__ = ["Proposer", "SpecMethod", "VerifyStrategy"]
+5 -3
View File
@@ -16,14 +16,16 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any
from typing import TYPE_CHECKING, Any
import paddle.distributed as dist
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.utils import spec_logger
if TYPE_CHECKING:
from fastdeploy.config import FDConfig
class Proposer(ABC):
"""
@@ -33,7 +35,7 @@ class Proposer(ABC):
the speculative decoding framework
"""
def __init__(self, fd_config: FDConfig):
def __init__(self, fd_config: "FDConfig"):
"""
Init Speculative proposer
"""
+6 -4
View File
@@ -16,14 +16,13 @@
import os
import time
from typing import List
from typing import TYPE_CHECKING, List
import numpy as np
import paddle
from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.forward_meta import ForwardMeta
@@ -81,6 +80,9 @@ from fastdeploy.worker.input_batch import (
from .base import Proposer
if TYPE_CHECKING:
from fastdeploy.config import FDConfig
class MTPProposer(Proposer):
"""
@@ -89,7 +91,7 @@ class MTPProposer(Proposer):
def __init__(
self,
fd_config: FDConfig,
fd_config: "FDConfig",
main_model: ModelForCasualLM,
local_rank: int,
device_id: int, # physical device id
@@ -724,7 +726,7 @@ class MTPProposer(Proposer):
self.target_model_inputs["is_block_step"],
self.target_model_inputs["draft_tokens"],
self.num_model_steps,
self.speculative_method in ["eagle", "mtp"],
True,
self.role == "prefill",
use_v1_cache_scheduler,
)
+6 -2
View File
@@ -14,13 +14,17 @@
# limitations under the License.
"""
from typing import TYPE_CHECKING
import paddle
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.ops.gpu import ngram_match
from .base import Proposer
if TYPE_CHECKING:
from fastdeploy.config import FDConfig
class NgramProposer(Proposer):
"""
@@ -29,7 +33,7 @@ class NgramProposer(Proposer):
Matching corresponding tokens in input and output as draft tokens.
"""
def __init__(self, fd_config: FDConfig):
def __init__(self, fd_config: "FDConfig"):
super().__init__(fd_config)
self.max_ngram_size = self.speculative_config.max_ngram_size
self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
+6 -2
View File
@@ -14,13 +14,17 @@
# limitations under the License.
"""
from typing import TYPE_CHECKING
import numpy as np
from fastdeploy.config import FDConfig
from fastdeploy.utils import spec_logger
from .base import Proposer
if TYPE_CHECKING:
from fastdeploy.config import FDConfig
try:
from arctic_inference.suffix_decoding import SuffixDecodingCache
except ImportError:
@@ -34,7 +38,7 @@ class SuffixProposer(Proposer):
Uses SuffixDecodingCache to generate draft tokens based on suffix tree matching.
"""
def __init__(self, fd_config: FDConfig):
def __init__(self, fd_config: "FDConfig"):
super().__init__(fd_config)
if SuffixDecodingCache is None:
+151
View File
@@ -0,0 +1,151 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from enum import Enum
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from fastdeploy.spec_decode.base import Proposer
class VerifyStrategy(int, Enum):
    """Enumeration of draft-token verification strategies.

    The verify_draft_tokens kernel uses these values to decide how a draft
    token is accepted and how the bonus/correction token is sampled. The
    integer values mirror the kernel's internal constants:

        0 = TOPP:         draft in top-p candidate set, stochastic sampling for bonus
        1 = GREEDY:       draft == argmax, deterministic argmax for bonus
        2 = TARGET_MATCH: draft == target sampled token, use target sample
    """

    TOPP = 0
    GREEDY = 1
    TARGET_MATCH = 2

    @classmethod
    def from_string(cls, value: str) -> "VerifyStrategy":
        """Look up a strategy by name, ignoring case.

        Args:
            value: Strategy name (e.g., "topp", "GREEDY", "Target_Match")

        Returns:
            The matching VerifyStrategy member.

        Raises:
            TypeError: If value is not a string.
            ValueError: If no strategy has that name.
        """
        if not isinstance(value, str):
            raise TypeError(
                f"Expected string input for VerifyStrategy.from_string(), "
                f"but got {type(value).__name__}: {value}. "
                f"If you have an int value, use VerifyStrategy(value) directly."
            )
        try:
            return cls[value.upper()]
        except KeyError:
            names = [member.name for member in cls]
            raise ValueError(f"Invalid verify strategy '{value}'. Must be one of: {names} (case-insensitive)")
class SpecMethod(str, Enum):
    """Speculative decoding method enum.

    Value is the config string passed via --speculative-config '{"method": "mtp"}'.
    """

    NAIVE = "naive"
    MTP = "mtp"
    NGRAM = "ngram"
    SUFFIX = "suffix"

    def create_proposer(self, fd_config, **kwargs) -> Optional["Proposer"]:
        """Factory method: create the appropriate Proposer for this method.

        Args:
            fd_config: FDConfig instance.
            **kwargs: Method-specific args forwarded to the Proposer constructor.
                MTP requires: main_model, local_rank, device_id, share_inputs.

        Returns:
            Proposer instance, or None for NAIVE (naive decoding needs no proposer).

        Raises:
            TypeError: If a kwarg required by this method is missing.
            ValueError: If no factory branch exists for this member (defensive;
                indicates a new SpecMethod was added without updating this method).
        """
        if self == SpecMethod.NAIVE:
            return None
        if self == SpecMethod.MTP:
            # Validate up front so a missing kwarg yields an actionable message
            # instead of a bare KeyError from the lookups below.
            required = ("main_model", "local_rank", "device_id", "share_inputs")
            missing = [name for name in required if name not in kwargs]
            if missing:
                raise TypeError(f"SpecMethod.MTP.create_proposer() missing required kwargs: {missing}")
            # Proposer modules are imported lazily to avoid heavy imports
            # (and platform-specific ops) at module load time.
            from fastdeploy.spec_decode.mtp import MTPProposer

            return MTPProposer(
                fd_config,
                kwargs["main_model"],
                kwargs["local_rank"],
                kwargs["device_id"],
                kwargs["share_inputs"],
            )
        if self == SpecMethod.NGRAM:
            from fastdeploy.spec_decode.ngram import NgramProposer

            return NgramProposer(fd_config)
        if self == SpecMethod.SUFFIX:
            from fastdeploy.spec_decode.suffix import SuffixProposer

            return SuffixProposer(fd_config)
        # Defensive: unreachable today (all members handled above), but fail
        # loudly if a new member is added without a corresponding branch.
        raise ValueError(f"No proposer factory registered for speculative method '{self.value}'")

    @property
    def needs_proposer(self) -> bool:
        """Whether this method requires a proposer model."""
        return self != SpecMethod.NAIVE

    @property
    def needs_kv_cache(self) -> bool:
        """Whether the proposer needs its own KV cache layer."""
        return self == SpecMethod.MTP

    @classmethod
    def from_string(cls, value: str) -> "SpecMethod":
        """Create SpecMethod from string with validation (case-insensitive).

        Args:
            value: Method name (e.g., "mtp", "NGRAM", "Naive")

        Returns:
            SpecMethod enum value

        Raises:
            ValueError: If the method name is not recognized
            TypeError: If value is not a string
        """
        if not isinstance(value, str):
            raise TypeError(
                f"Expected string input for SpecMethod.from_string(), "
                f"but got {type(value).__name__}: {value}. "
                f"If you have an enum value, use SpecMethod(value) directly."
            )
        # Backward-compatible aliases (legacy config names).
        ALIASES = {"ngram_match": "ngram"}
        normalized = ALIASES.get(value.lower(), value.lower())
        try:
            return cls(normalized)
        except ValueError:
            valid_names = [m.value for m in cls]
            raise ValueError(
                f"Invalid speculative method '{value}'. " f"Must be one of: {valid_names} (case-insensitive)"
            )