[Intel HPU] Enable dist sampler on intel hpu platform (#4445)
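The change imports FastDeploy's tensor_model_parallel_all_reduce_custom into the HPU model runner and uses it, in place of paddle.distributed.broadcast from rank 0, to distribute the sampled token ids across tensor-parallel ranks at the two sampling sites in HPUModelRunner.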
@@ -24,6 +24,7 @@ import paddle.nn as nn
 from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce_custom
 from fastdeploy.engine.request import Request
 
 # from fastdeploy.spec_decode import MTPProposer, NgramProposer
@@ -944,7 +945,7 @@ class HPUModelRunner(ModelRunnerBase):
         if self.parallel_config.tensor_parallel_size > 1:
             dtype = sampled_token_ids.dtype
             sampled_token_ids = sampled_token_ids.to("float32")
-            paddle.distributed.broadcast(sampled_token_ids, 0)
+            tensor_model_parallel_all_reduce_custom(sampled_token_ids)
             sampled_token_ids = sampled_token_ids.to(dtype)
 
         # 6. post process
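The float32 round-trip around the collective is kept in place (the identical swap appears again in the hunk below at line 1273); presumably the HPU collective path does not handle the integer token-id dtype directly. The broadcast-to-all-reduce swap is behavior-preserving if only tensor-parallel rank 0 holds non-zero sampled ids, since a SUM all-reduce then reproduces broadcast(src=0) on every rank. The diff itself does not confirm that reading, so what follows is a sketch of the pattern using only stock paddle.distributed collectives, not FastDeploy's implementation; shapes and the vocab size are invented for illustration.

    # A minimal, hypothetical sketch of the pattern above -- not FastDeploy's
    # implementation. Only stock paddle.distributed collectives are used.
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()

    # Stand-in for the sampler output: assume only rank 0 actually samples
    # and every other rank holds zeros of the same shape and dtype.
    if dist.get_rank() == 0:
        sampled_token_ids = paddle.randint(0, 32000, [8, 1], dtype="int64")
    else:
        sampled_token_ids = paddle.zeros([8, 1], dtype="int64")

    # Round-trip through float32: some collective backends reject integer
    # tensors, so cast before the collective and restore the dtype after.
    dtype = sampled_token_ids.dtype
    sampled_token_ids = sampled_token_ids.to("float32")

    # With zeros on every rank but 0, the default SUM all-reduce leaves each
    # rank holding rank 0's values -- the same result as broadcast(src=0).
    dist.all_reduce(sampled_token_ids)

    sampled_token_ids = sampled_token_ids.to(dtype)

Run the sketch under the usual multi-process launcher (e.g. python -m paddle.distributed.launch) so that init_parallel_env has a process group to initialize.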
@@ -1272,7 +1273,7 @@ class HPUModelRunner(ModelRunnerBase):
         if self.parallel_config.tensor_parallel_size > 1:
             dtype = sampled_token_ids.dtype
             sampled_token_ids = sampled_token_ids.to("float32")
-            paddle.distributed.broadcast(sampled_token_ids, 0)
+            tensor_model_parallel_all_reduce_custom(sampled_token_ids)
             sampled_token_ids = sampled_token_ids.to(dtype)
         if self.is_hpu_perf_breakdown_sync_mode:
             sampled_token_ids.cpu()
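When is_hpu_perf_breakdown_sync_mode is set, the sampled_token_ids.cpu() call forces a device-to-host copy; on an asynchronously executing device such as the HPU this acts as a synchronization point, which is presumably what makes the per-step performance-breakdown timings meaningful.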