[BugFix] Fix attention mask bug in D-Node of PD-split mode (#5245)

This commit is contained in:
freeliuzc
2025-11-26 17:56:28 +08:00
committed by GitHub
parent 61fc368066
commit ba915e03e1
3 changed files with 12 additions and 8 deletions
+6
View File
@@ -515,6 +515,12 @@ class MTPProposer(Proposer):
self.model_inputs["attn_mask_offsets_decoder"][idx : idx + 1] = (
inputs["attention_mask_offset"][prefill_end_index - 1] + 1
)
if (
self.fd_config.scheduler_config.splitwise_role == "decode"
): # In PD, we continue to decode after P generates first token
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
# P-D split need rollback one step
self.model_inputs["mask_rollback"][idx : idx + 1] = 1
# has_prefill_task = True
elif request.task_type.value == RequestType.DECODE.value: # decode task