[Intel HPU] Enable dist sampler on Intel HPU platform (#4445)

This commit is contained in:
Jianyu Li
2025-10-16 19:02:27 +08:00
committed by GitHub
parent 4251ac5e95
commit 3bbe99eae7
2 changed files with 4 additions and 3 deletions
+3 -2
View File
@@ -24,6 +24,7 @@ import paddle.nn as nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce_custom
from fastdeploy.engine.request import Request
# from fastdeploy.spec_decode import MTPProposer, NgramProposer
@@ -944,7 +945,7 @@ class HPUModelRunner(ModelRunnerBase):
if self.parallel_config.tensor_parallel_size > 1:
dtype = sampled_token_ids.dtype
sampled_token_ids = sampled_token_ids.to("float32")
-paddle.distributed.broadcast(sampled_token_ids, 0)
+tensor_model_parallel_all_reduce_custom(sampled_token_ids)
sampled_token_ids = sampled_token_ids.to(dtype)
# 6. post process
@@ -1272,7 +1273,7 @@ class HPUModelRunner(ModelRunnerBase):
if self.parallel_config.tensor_parallel_size > 1:
dtype = sampled_token_ids.dtype
sampled_token_ids = sampled_token_ids.to("float32")
-paddle.distributed.broadcast(sampled_token_ids, 0)
+tensor_model_parallel_all_reduce_custom(sampled_token_ids)
sampled_token_ids = sampled_token_ids.to(dtype)
if self.is_hpu_perf_breakdown_sync_mode:
sampled_token_ids.cpu()