Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2026-04-23 00:17:25 +08:00.
[XPU] Add TP broadcast after sampling in XPU model runner to ensure consistent results across ranks. (#7096)
This commit is contained in:
@@ -1585,6 +1585,12 @@ class XPUModelRunner(ModelRunnerBase):
         sampler_output = None
         if not self.speculative_decoding:
             sampler_output = self.sampler(logits, self.sampling_metadata)
+            if self.parallel_config.tensor_parallel_size > 1:
+                paddle.distributed.broadcast(
+                    sampler_output.sampled_token_ids,
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
         else:
             sampler_output = self.sampler(
                 logits,
@@ -1592,6 +1598,27 @@ class XPUModelRunner(ModelRunnerBase):
                 self.model_config.max_model_len,
                 self.share_inputs,
             )
+            if self.parallel_config.tensor_parallel_size > 1:
+                paddle.distributed.broadcast(
+                    self.share_inputs["accept_tokens"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
+                paddle.distributed.broadcast(
+                    self.share_inputs["accept_num"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
+                paddle.distributed.broadcast(
+                    self.share_inputs["step_idx"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
+                paddle.distributed.broadcast(
+                    self.share_inputs["stop_flags"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )

         prompt_logprobs_list = None
         if not self.speculative_decoding:
Reference in New Issue
Block a user