[Speculative Decoding] Support mtp expert-parallel and support different modality deploy (#7018)

* support mtp ep and support different modality

* fix default arg
This commit is contained in:
freeliuzc
2026-03-26 13:52:16 +08:00
committed by GitHub
parent 61ebac49ef
commit 4fd877ed43
10 changed files with 112 additions and 19 deletions
+7 -3
View File
@@ -102,6 +102,8 @@ class MTPProposer(Proposer):
self.num_main_model_layers = self.model_config.num_hidden_layers
self.local_rank = local_rank
self.device_id = device_id
self.use_attn_mask_offset = self.enable_mm and self.fd_config.deploy_modality != "text"
self._update_mtp_config(main_model)
self._load_model()
self.target_model_inputs = target_model_inputs
@@ -162,6 +164,8 @@ class MTPProposer(Proposer):
self.model_config.quantization = self.speculative_config.quantization
self.model_config.start_layer_index = self.num_main_model_layers
self.speculative_config.model_type = "mtp"
if not self.use_attn_mask_offset:
self.model_config.causal = True
def _load_model(self):
"""
@@ -503,7 +507,7 @@ class MTPProposer(Proposer):
self.model_inputs["step_idx"][idx : idx + 1] = (
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
)
if self.enable_mm:
if self.use_attn_mask_offset:
inputs = request.multimodal_inputs
self.model_inputs["attn_mask_offsets_full"][idx][0 : prefill_end_index - prefill_start_index] = (
paddle.to_tensor(
@@ -662,7 +666,7 @@ class MTPProposer(Proposer):
kv_batch_ids=self.model_inputs["kv_batch_ids"],
kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"],
kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"],
attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.enable_mm else None,
attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.use_attn_mask_offset else None,
)
# Initialzie attention meta data
@@ -888,7 +892,7 @@ class MTPProposer(Proposer):
self.model_inputs["seq_lens_decoder"],
)
if self.enable_mm:
if self.use_attn_mask_offset:
attn_mask_offsets = update_attn_mask_offsets(
ids_remove_padding,
getattr(