[XPU] Add TP broadcast after sampling in XPU model runner to ensure consistent results across ranks. (#7096)

Author: Jiajun Ji
Date: 2026-04-08 19:26:53 +08:00
Committed by: GitHub
Commit: 9b970de029 (parent: 3749457476)
Diffstat: +27 lines
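Why the broadcast is needed: after tensor-parallel execution every TP rank holds the same logits, but sampling draws from each rank's own RNG state, so the sampled token ids can drift apart and the ranks decode divergent sequences. Broadcasting the result from one designated rank restores consistency. Below is a minimal, self-contained sketch of the pattern, assuming a two-process launch under paddle.distributed.launch; the shapes, vocabulary size, and variable names are illustrative, not taken from the patch.

    # Illustrative sketch: run as two processes under paddle.distributed.launch.
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()

    # Each rank samples with its own RNG state, so these ids can differ
    # across ranks even when the logits they came from were identical.
    sampled_token_ids = paddle.randint(0, 32000, shape=[4, 1], dtype="int64")
    print(f"rank {dist.get_rank()} before: {sampled_token_ids.numpy().ravel()}")

    # In-place broadcast from rank 0: afterwards every rank holds rank 0's
    # ids, mirroring the patch's broadcast from the TP group's first rank.
    dist.broadcast(sampled_token_ids, src=0)
    print(f"rank {dist.get_rank()} after:  {sampled_token_ids.numpy().ravel()}")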
@@ -1585,6 +1585,12 @@ class XPUModelRunner(ModelRunnerBase):
         sampler_output = None
         if not self.speculative_decoding:
             sampler_output = self.sampler(logits, self.sampling_metadata)
+            if self.parallel_config.tensor_parallel_size > 1:
+                paddle.distributed.broadcast(
+                    sampler_output.sampled_token_ids,
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
         else:
             sampler_output = self.sampler(
                 logits,
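The broadcast source used throughout the patch is data_parallel_rank * tensor_parallel_size, the global rank of the first worker in this DP replica's TP group. A quick illustration of that arithmetic, assuming the conventional layout where each DP replica owns a contiguous block of global ranks (an assumption about the deployment, not code from the commit):

    # Hypothetical rank layout: contiguous TP groups per DP replica.
    tp_size = 4
    for dp_rank in range(2):
        src = dp_rank * tp_size  # first global rank of this replica's TP group
        group = list(range(src, src + tp_size))
        print(f"dp_rank={dp_rank}: tp_group={group}, broadcast src={src}")
    # dp_rank=0: tp_group=[0, 1, 2, 3], broadcast src=0
    # dp_rank=1: tp_group=[4, 5, 6, 7], broadcast src=4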
@@ -1592,6 +1598,27 @@ class XPUModelRunner(ModelRunnerBase):
                 self.model_config.max_model_len,
                 self.share_inputs,
             )
+            if self.parallel_config.tensor_parallel_size > 1:
+                paddle.distributed.broadcast(
+                    self.share_inputs["accept_tokens"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
+                paddle.distributed.broadcast(
+                    self.share_inputs["accept_num"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
+                paddle.distributed.broadcast(
+                    self.share_inputs["step_idx"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
+                paddle.distributed.broadcast(
+                    self.share_inputs["stop_flags"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
 
         prompt_logprobs_list = None
         if not self.speculative_decoding:
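In the speculative-decoding branch the four broadcasts differ only in the tensor being sent; the src and group arguments are identical. A hedged sketch of a helper that would factor out that repetition (hypothetical refactor, not part of this commit):

    import paddle.distributed as dist

    def broadcast_from_tp_root(tensors, parallel_config):
        """Broadcast each tensor in place from the first rank of the TP group.

        Hypothetical helper; assumes the same parallel_config fields the
        patch uses (data_parallel_rank, tensor_parallel_size, tp_group).
        """
        src = parallel_config.data_parallel_rank * parallel_config.tensor_parallel_size
        for t in tensors:
            dist.broadcast(t, src=src, group=parallel_config.tp_group)

    # Usage equivalent to the four calls above:
    # broadcast_from_tp_root(
    #     [self.share_inputs[k] for k in ("accept_tokens", "accept_num", "step_idx", "stop_flags")],
    #     self.parallel_config,
    # )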