mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Unify Spec and non-spec branch (#6685)
* optimize spec-inference architecture
* delete debug log
* optimize spec_method usage && fix unit_test
* add claude unit-test skill
* fix some ugly bug
* enhance robustness and bounds check
* unify method & spec_method to method to avoid bug
* activate CI
* fix unit test
* Unify logprobs computation for naive and speculative decoding, fix CUDA kernel
* fix logprob bug && optimize verify kernel
* fix exist_decode() judge
This commit is contained in:
@@ -45,6 +45,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
)
|
||||
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.spec_decode import SpecMethod
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -143,10 +144,10 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
if fd_config.speculative_config.model_type != "main":
|
||||
self.rope_3d = False
|
||||
self.causal: bool = getattr(fd_config.model_config, "causal", True)
|
||||
self.speculative_method: str = fd_config.speculative_config.method
|
||||
self.speculative_method = fd_config.speculative_config.method
|
||||
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
|
||||
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
|
||||
|
||||
self.kv_num_heads: int = kv_num_heads
|
||||
self.num_heads: int = num_heads
|
||||
|
||||
@@ -68,6 +68,8 @@ else:
|
||||
merge_prefill_decode_output = None
|
||||
|
||||
|
||||
from fastdeploy.spec_decode import SpecMethod
|
||||
|
||||
FLASH_ATTN_VERSION = None
|
||||
|
||||
|
||||
@@ -255,7 +257,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
self.use_speculate = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
|
||||
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
|
||||
|
||||
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ if TYPE_CHECKING:
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.spec_decode import SpecMethod
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output
|
||||
@@ -106,7 +107,7 @@ class FlashMaskAttentionBackend(AttentionBackend):
|
||||
self.use_speculate = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
|
||||
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
|
||||
|
||||
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
|
||||
|
||||
|
||||
@@ -60,6 +60,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionMetadata,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
|
||||
from fastdeploy.spec_decode import SpecMethod
|
||||
|
||||
|
||||
@triton.jit()
|
||||
@@ -257,11 +258,11 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
)
|
||||
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
|
||||
self.causal: bool = getattr(fd_config.model_config, "causal", True)
|
||||
self.speculative_method: str = fd_config.speculative_config.method
|
||||
self.speculative_method = fd_config.speculative_config.method
|
||||
self.use_speculate: bool = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
|
||||
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
|
||||
|
||||
self.num_heads: int = num_heads
|
||||
self.head_dim: int = fd_config.model_config.head_dim
|
||||
|
||||
Reference in New Issue
Block a user