mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Unify Spec and non-spec branch (#6685)
* optimize spec-inference architecture
* delete debug log
* optimize spec_method usage && fix unit_test
* add claude unit-test skill
* fix some ugly bugs
* enhance robustness and bounds checks
* unify method & spec_method into method to avoid bugs
* activate CI
* fix unit test
* unify logprobs computation for naive and speculative decoding, fix CUDA kernel
* fix logprob bug && optimize verify kernel
* fix exist_decode() check
This commit is contained in:
@@ -14,26 +14,8 @@
|
||||
"""
|
||||
speculative decoding module
|
||||
"""
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
from .base import Proposer
|
||||
from .mtp import MTPProposer
|
||||
from .types import SpecMethod, VerifyStrategy
|
||||
|
||||
# XPU does not support the ngram proposer yet
|
||||
if not current_platform.is_xpu():
|
||||
from .ngram import NgramProposer
|
||||
__all__ = ["Proposer", "MTPProposer", "NgramProposer"]
|
||||
|
||||
# Suffix proposer requires arctic_inference
|
||||
try:
|
||||
from .suffix import SuffixProposer
|
||||
|
||||
_suffix_proposer_available = True
|
||||
except ImportError:
|
||||
_suffix_proposer_available = False
|
||||
SuffixProposer = None
|
||||
|
||||
if _suffix_proposer_available:
|
||||
__all__ = ["Proposer", "MTPProposer", "NgramProposer", "SuffixProposer"]
|
||||
else:
|
||||
__all__ = ["Proposer", "MTPProposer", "NgramProposer"]
|
||||
__all__ = ["Proposer", "SpecMethod", "VerifyStrategy"]
|
||||
|
||||
@@ -16,14 +16,16 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import paddle.distributed as dist
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.utils import spec_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
|
||||
class Proposer(ABC):
|
||||
"""
|
||||
@@ -33,7 +35,7 @@ class Proposer(ABC):
|
||||
the speculative decoding framework
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
def __init__(self, fd_config: "FDConfig"):
|
||||
"""
|
||||
Init Speculative proposer
|
||||
"""
|
||||
|
||||
@@ -16,14 +16,13 @@
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.engine.request import Request, RequestType
|
||||
from fastdeploy.inter_communicator import IPCSignal
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
@@ -81,6 +80,9 @@ from fastdeploy.worker.input_batch import (
|
||||
|
||||
from .base import Proposer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
|
||||
class MTPProposer(Proposer):
|
||||
"""
|
||||
@@ -89,7 +91,7 @@ class MTPProposer(Proposer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
fd_config: "FDConfig",
|
||||
main_model: ModelForCasualLM,
|
||||
local_rank: int,
|
||||
device_id: int, # physical device id
|
||||
@@ -724,7 +726,7 @@ class MTPProposer(Proposer):
|
||||
self.target_model_inputs["is_block_step"],
|
||||
self.target_model_inputs["draft_tokens"],
|
||||
self.num_model_steps,
|
||||
self.speculative_method in ["eagle", "mtp"],
|
||||
True,
|
||||
self.role == "prefill",
|
||||
use_v1_cache_scheduler,
|
||||
)
|
||||
|
||||
@@ -14,13 +14,17 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.ops.gpu import ngram_match
|
||||
|
||||
from .base import Proposer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
|
||||
class NgramProposer(Proposer):
|
||||
"""
|
||||
@@ -29,7 +33,7 @@ class NgramProposer(Proposer):
|
||||
Matching corresponding tokens in input and output as draft tokens.
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
def __init__(self, fd_config: "FDConfig"):
|
||||
super().__init__(fd_config)
|
||||
self.max_ngram_size = self.speculative_config.max_ngram_size
|
||||
self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
|
||||
|
||||
@@ -14,13 +14,17 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.utils import spec_logger
|
||||
|
||||
from .base import Proposer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
try:
|
||||
from arctic_inference.suffix_decoding import SuffixDecodingCache
|
||||
except ImportError:
|
||||
@@ -34,7 +38,7 @@ class SuffixProposer(Proposer):
|
||||
Uses SuffixDecodingCache to generate draft tokens based on suffix tree matching.
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
def __init__(self, fd_config: "FDConfig"):
|
||||
super().__init__(fd_config)
|
||||
|
||||
if SuffixDecodingCache is None:
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastdeploy.spec_decode.base import Proposer
|
||||
|
||||
|
||||
class VerifyStrategy(int, Enum):
    """Draft token verification strategy enum.

    Used in the verify_draft_tokens kernel to control how draft tokens are
    verified and how bonus/correction tokens are sampled.

    Values match the kernel's internal constants:
        0 = TOPP: draft in top-p candidate set, stochastic sampling for bonus
        1 = GREEDY: draft == argmax, deterministic argmax for bonus
        2 = TARGET_MATCH: draft == target sampled token, use target sample
    """

    TOPP = 0
    GREEDY = 1
    TARGET_MATCH = 2

    @classmethod
    def from_string(cls, value: str) -> "VerifyStrategy":
        """Create a VerifyStrategy from its name.

        Matching is case-insensitive and ignores surrounding whitespace, so
        values copied from config files or CLI arguments are accepted as-is.

        Args:
            value: Strategy name (e.g., "topp", "GREEDY", "Target_Match").

        Returns:
            The matching VerifyStrategy member.

        Raises:
            TypeError: If value is not a string.
            ValueError: If the strategy name is not recognized.
        """
        if not isinstance(value, str):
            raise TypeError(
                f"Expected string input for VerifyStrategy.from_string(), "
                f"but got {type(value).__name__}: {value}. "
                f"If you have an int value, use VerifyStrategy(value) directly."
            )
        try:
            # Members are looked up by name, which is always upper-case.
            return cls[value.strip().upper()]
        except KeyError:
            valid_names = [s.name for s in cls]
            # `from None` hides the internal KeyError: the caller only needs
            # to see the ValueError describing the accepted names.
            raise ValueError(
                f"Invalid verify strategy '{value}'. Must be one of: {valid_names} (case-insensitive)"
            ) from None
|
||||
|
||||
|
||||
class SpecMethod(str, Enum):
    """Speculative decoding method enum.

    Value is the config string passed via --speculative-config '{"method": "mtp"}'.
    """

    NAIVE = "naive"
    MTP = "mtp"
    NGRAM = "ngram"
    SUFFIX = "suffix"

    def create_proposer(self, fd_config, **kwargs) -> Optional["Proposer"]:
        """Factory method: create the appropriate Proposer for this method.

        Each proposer class is imported inside its own branch, so only the
        module of the selected method is loaded.

        Args:
            fd_config: FDConfig instance.
            **kwargs: Method-specific args forwarded to the Proposer constructor.
                MTP requires: main_model, local_rank, device_id, share_inputs.

        Returns:
            Proposer instance, or None for NAIVE.

        Raises:
            KeyError: If MTP is selected and a required keyword argument is missing.
            NotImplementedError: If this member has no proposer wired up here.
        """
        if self == SpecMethod.NAIVE:
            return None
        elif self == SpecMethod.MTP:
            from fastdeploy.spec_decode.mtp import MTPProposer

            return MTPProposer(
                fd_config,
                kwargs["main_model"],
                kwargs["local_rank"],
                kwargs["device_id"],
                kwargs["share_inputs"],
            )
        elif self == SpecMethod.NGRAM:
            from fastdeploy.spec_decode.ngram import NgramProposer

            return NgramProposer(fd_config)
        elif self == SpecMethod.SUFFIX:
            from fastdeploy.spec_decode.suffix import SuffixProposer

            return SuffixProposer(fd_config)
        # Fail loudly instead of implicitly returning None: None is the
        # legitimate NAIVE result, so a member added later without a branch
        # here would otherwise silently run without a proposer.
        raise NotImplementedError(f"No proposer is registered for speculative method '{self.value}'")

    @property
    def needs_proposer(self) -> bool:
        """Whether this method requires a proposer model."""
        return self != SpecMethod.NAIVE

    @property
    def needs_kv_cache(self) -> bool:
        """Whether the proposer needs its own KV cache layer."""
        return self == SpecMethod.MTP

    @classmethod
    def from_string(cls, value: str) -> "SpecMethod":
        """Create a SpecMethod from its config string.

        Matching is case-insensitive and ignores surrounding whitespace. The
        legacy name "ngram_match" is accepted as an alias of "ngram".

        Args:
            value: Method name (e.g., "mtp", "NGRAM", "Naive").

        Returns:
            The matching SpecMethod member.

        Raises:
            TypeError: If value is not a string.
            ValueError: If the method name is not recognized.
        """
        if not isinstance(value, str):
            raise TypeError(
                f"Expected string input for SpecMethod.from_string(), "
                f"but got {type(value).__name__}: {value}. "
                f"If you have an enum value, use SpecMethod(value) directly."
            )
        # Backward-compatible aliases. Kept local on purpose: a plain
        # class-level name inside an Enum body would become an enum member.
        aliases = {"ngram_match": "ngram"}
        key = value.strip().lower()
        normalized = aliases.get(key, key)
        try:
            return cls(normalized)
        except ValueError:
            valid_names = [m.value for m in cls]
            # `from None` hides the generic "is not a valid SpecMethod" error
            # raised by the Enum lookup in favour of the message below.
            raise ValueError(
                f"Invalid speculative method '{value}'. Must be one of: {valid_names} (case-insensitive)"
            ) from None
|
||||
Reference in New Issue
Block a user