[Bugfix]fix model weight signal tensor num (#5900)

This commit is contained in:
gaoziyuan
2026-01-06 14:36:59 +08:00
committed by GitHub
parent 9445fbe054
commit e99ec4c9d5
+2 -1
View File
@@ -282,7 +282,8 @@ class PaddleDisWorkerProc:
def _broadcast_model_weights_signal(self, src: int, group) -> int:
model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
return model_weights_signal_tensor.item()
value = model_weights_signal_tensor.numpy()[0]
return int(value)
def _tp_barrier_wait(self):
if current_platform.is_xpu():