[Speculative Decoding] Support MTP expert parallelism and different-modality deployment (#7018)

* Support MTP expert parallelism (EP) and different-modality deployment

* fix default arg
This commit is contained in:
freeliuzc
2026-03-26 13:52:16 +08:00
committed by GitHub
parent 61ebac49ef
commit 4fd877ed43
10 changed files with 112 additions and 19 deletions
+24 -14
View File
@@ -17,7 +17,13 @@
import paddle
from paddleformers.utils.log import logger
from fastdeploy.config import CacheConfig, FDConfig, ModelConfig, SpeculativeConfig
from fastdeploy.config import (
CacheConfig,
DeployModality,
FDConfig,
ModelConfig,
SpeculativeConfig,
)
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.logits_processor import build_logits_processors
from fastdeploy.platforms import current_platform
@@ -829,20 +835,23 @@ class ProposerInputBatch(InputBatch):
)
# attn_mask
if self.enable_mm:
self.attn_mask_offsets = paddle.full(
shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len],
fill_value=-1,
dtype="int32",
)
self.attn_mask_offsets_full = paddle.full(
[self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32"
)
self.attn_mask_offsets_decoder = paddle.full([self.scheduler_config.max_num_seqs, 1], -1, dtype="int32")
self.decode_states = paddle.full(
[self.scheduler_config.max_num_seqs, self.speculative_config.num_speculative_tokens + 1],
-1,
dtype="int32",
)
if self.fd_config.deploy_modality != DeployModality.TEXT:
self.attn_mask_offsets = paddle.full(
shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len],
fill_value=-1,
dtype="int32",
)
self.attn_mask_offsets_full = paddle.full(
[self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32"
)
self.attn_mask_offsets_decoder = paddle.full(
[self.scheduler_config.max_num_seqs, 1], -1, dtype="int32"
)
def swap_states(self, i1, i2) -> None:
def swap_data(tensor, idx1, idx2):
@@ -864,7 +873,7 @@ class ProposerInputBatch(InputBatch):
swap_data(self.input_ids_len, i1, i2)
swap_data(self.mask_rollback, i1, i2)
swap_data(self.recompute_token_num, i1, i2)
if self.enable_mm:
if self.enable_mm and self.fd_config.deploy_modality != DeployModality.TEXT:
swap_data(self.attn_mask_offsets_full, i1, i2)
swap_data(self.attn_mask_offsets_decoder, i1, i2)
@@ -998,10 +1007,11 @@ class ProposerInputBatch(InputBatch):
# Reset multimodal tensors if enabled
if self.enable_mm:
fill_paddle_tensor(self, "attn_mask_offsets", -1)
fill_paddle_tensor(self, "attn_mask_offsets_full", -1)
fill_paddle_tensor(self, "attn_mask_offsets_decoder", -1)
fill_paddle_tensor(self, "decode_states", -1)
if self.fd_config.deploy_modality != DeployModality.TEXT:
fill_paddle_tensor(self, "attn_mask_offsets", -1)
fill_paddle_tensor(self, "attn_mask_offsets_full", -1)
fill_paddle_tensor(self, "attn_mask_offsets_decoder", -1)
logger.info("model_inputs reset completed")
except Exception as e: