[Speculative Decoding] Unify Spec and non-spec branch (#6685)

* optimize spec-inference architecture

* delete debug log

* optimize spec_method usage and fix unit tests

* add claude unit-test skill

* fix some ugly bugs

* enhance robustness and bounds check

* unify method & spec_method into a single method field to avoid bugs

* activate CI

* fix unit test

* Unify logprobs computation for naive and speculative decoding, fix CUDA kernel

* fix logprob bug and optimize verify kernel

* fix exist_decode() check
freeliuzc
2026-03-11 14:58:44 +08:00
committed by GitHub
parent b6190de557
commit cf7934a4b2
41 changed files with 3428 additions and 392 deletions
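For context: the hunks below replace raw-string checks like `method in ["mtp"]` with comparisons against a `SpecMethod` enum imported from `fastdeploy.spec_decode`. A minimal sketch of what such an enum could look like — only the `MTP` member is confirmed by the diff; the other member and the helper are illustrative, and the real definition in `fastdeploy/spec_decode` may differ:

```python
# Hypothetical sketch of the SpecMethod enum referenced in the hunks below.
# Only MTP appears in the diff; NONE and from_config are illustrative.
from enum import Enum

class SpecMethod(Enum):
    NONE = "none"
    MTP = "mtp"

    @classmethod
    def from_config(cls, value: str) -> "SpecMethod":
        # Enum lookup by value: SpecMethod.from_config("mtp") -> SpecMethod.MTP
        return cls(value)
```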
@@ -45,6 +45,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
 )
 from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
 from fastdeploy.platforms import current_platform
+from fastdeploy.spec_decode import SpecMethod
 @dataclass
@@ -143,10 +144,10 @@ class AppendAttentionBackend(AttentionBackend):
 if fd_config.speculative_config.model_type != "main":
     self.rope_3d = False
 self.causal: bool = getattr(fd_config.model_config, "causal", True)
-self.speculative_method: str = fd_config.speculative_config.method
+self.speculative_method = fd_config.speculative_config.method
 self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
 self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
-self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
+self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
 self.kv_num_heads: int = kv_num_heads
 self.num_heads: int = num_heads
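The `in ["mtp"]` → `== SpecMethod.MTP` change matters because `speculative_config.method` now holds an enum member rather than a raw string (note the dropped `: str` annotation above). Assuming `SpecMethod` is a plain `Enum` as sketched earlier, the old membership test would silently evaluate to 0:

```python
method = SpecMethod.MTP
int(method in ["mtp"])         # 0: a plain Enum member never equals the raw string
int(method == SpecMethod.MTP)  # 1: compare against the enum member instead
```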
@@ -68,6 +68,8 @@ else:
 merge_prefill_decode_output = None
+from fastdeploy.spec_decode import SpecMethod
 FLASH_ATTN_VERSION = None
@@ -255,7 +257,7 @@ class FlashAttentionBackend(AttentionBackend):
 self.use_speculate = self.speculative_method is not None
 self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
 self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
-self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
+self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
 self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
@@ -44,6 +44,7 @@ if TYPE_CHECKING:
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.platforms import current_platform
+from fastdeploy.spec_decode import SpecMethod
 if current_platform.is_cuda():
     from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output
@@ -106,7 +107,7 @@ class FlashMaskAttentionBackend(AttentionBackend):
 self.use_speculate = self.speculative_method is not None
 self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
 self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
-self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
+self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
 self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
@@ -60,6 +60,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
 AttentionMetadata,
 )
 from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
+from fastdeploy.spec_decode import SpecMethod
 @triton.jit()
@@ -257,11 +258,11 @@ class MLAAttentionBackend(AttentionBackend):
 )
 self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
 self.causal: bool = getattr(fd_config.model_config, "causal", True)
-self.speculative_method: str = fd_config.speculative_config.method
+self.speculative_method = fd_config.speculative_config.method
 self.use_speculate: bool = self.speculative_method is not None
 self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
 self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
-self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
+self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
 self.num_heads: int = num_heads
 self.head_dim: int = fd_config.model_config.head_dim
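Putting the pieces together, here is a toy stand-in showing how the backends above derive their speculative-decoding flags; the field names mirror the diff, but the config class and values are made up for illustration:

```python
from dataclasses import dataclass
from typing import Optional

# Toy stand-in for fd_config.speculative_config; field names mirror the diff.
@dataclass
class SpeculativeConfig:
    method: Optional[SpecMethod] = None
    num_speculative_tokens: int = 0
    model_type: str = "main"

cfg = SpeculativeConfig(method=SpecMethod.MTP, num_speculative_tokens=3, model_type="mtp")
use_speculate = cfg.method is not None                      # True
keep_pd_step_flag = cfg.model_type == "mtp"                 # True
num_layers_draft_model = int(cfg.method == SpecMethod.MTP)  # 1: one extra draft layer for MTP
```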