fix eb5 mtp(mix) (#6800)

This commit is contained in:
cmcamdy
2026-03-13 17:36:57 +08:00
committed by GitHub
parent 8c1a2827d3
commit 7591e0d6bc
5 changed files with 55 additions and 5 deletions
+21 -2
View File
@@ -47,6 +47,7 @@ if current_platform.is_xpu():
mtp_step_paddle,
set_data_ipc,
share_external_data,
update_attn_mask_offsets,
)
from fastdeploy.model_executor.xpu_pre_and_post_process import (
xpu_pre_process,
@@ -513,7 +514,6 @@ class MTPProposer(Proposer):
# NOTE(liuzichang):
# extra 1: the P-D split needs to roll back one step
self.model_inputs["mask_rollback"][idx : idx + 1] = 1
# has_prefill_task = True
elif request.task_type.value == RequestType.DECODE.value: # decode task
encoder_block_num = len(request.block_tables)
@@ -684,6 +684,8 @@ class MTPProposer(Proposer):
if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
self.forward_meta.kv_signal_sender = self.target_model_inputs["kv_signal_sender"]
self.forward_meta.is_draft = True
# Initialize attention metadata
for attn_backend in self.attn_backends:
attn_backend.init_attention_metadata(self.forward_meta)
@@ -998,7 +1000,6 @@ class MTPProposer(Proposer):
step_use_cudagraph: bool
Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP.
"""
# TODO(chenhuan09): check multi step
for substep in range(self.num_model_steps):
if self.model_inputs["not_need_stop"]:
self.model_inputs["substep"] = substep
@@ -1013,6 +1014,24 @@ class MTPProposer(Proposer):
self.model_inputs["seq_lens_encoder"],
self.model_inputs["seq_lens_decoder"],
)
if self.enable_mm:
attn_mask_offsets = update_attn_mask_offsets(
self.model_inputs["ids_remove_padding"],
getattr(
self.model_inputs, "seq_lens_this_time", self.model_inputs["seq_lens_this_time_buffer"]
),
self.model_inputs["seq_lens_encoder"],
self.model_inputs["seq_lens_decoder"],
self.model_inputs["cu_seqlens_q"],
self.model_inputs["attn_mask_offsets_full"],
self.model_inputs["attn_mask_offsets_decoder"],
self.model_inputs["is_block_step"],
self.model_inputs["decode_states"],
self.model_inputs["mask_rollback"],
)
self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False)
self._initialize_forward_meta_xpu()
# Get sampling metadata
self.sampling_metadata = SamplingMetadata(