Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2026-04-23 00:17:25 +08:00
[Optimization] merge_allreduce (#7039)
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
@@ -26,6 +26,7 @@ from paddleformers.utils.log import logger
 
 import fastdeploy
 from fastdeploy.config import FDConfig
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.graph_optimization.decorator import (
     support_graph_optimization,
@@ -158,8 +159,16 @@ class Glm4Moe(nn.Layer):
             default_initializer=paddle.nn.initializer.Constant(0),
         )
 
+        # In pure-TP mode (tp>1, ep=1) both branches return partial sums, so we
+        # defer the all-reduce to after combining them — saving one collective.
+        # In all other modes (EP, EP+attn-TP, no parallelism) each branch handles
+        # its own reduction internally (reduce_results default=True), so we must
+        # NOT add an extra all-reduce here.
+        self.merge_ffn_tp = self.use_tp and not self.use_ep
+
         self.experts = FusedMoE(
             fd_config,
+            reduce_results=not self.merge_ffn_tp,
             renormalize=self.norm_topk_prob,
             moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
             num_experts=fd_config.model_config.n_routed_experts,
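Why the deferral is safe: a sum all-reduce is linear, so reducing each branch separately and then adding gives the same result as adding the per-rank partial sums first and reducing once. A minimal sketch of that identity in plain NumPy (the simulated all_reduce and the rank data are illustrative stand-ins, not FastDeploy code):

import numpy as np

NUM_RANKS = 4
rng = np.random.default_rng(0)

# Per-rank partial outputs of the two FFN branches: routed experts and
# shared experts. In pure-TP mode each rank holds only a partial sum.
routed = [rng.standard_normal(8) for _ in range(NUM_RANKS)]
shared = [rng.standard_normal(8) for _ in range(NUM_RANKS)]

def all_reduce(partials):
    # Simulated sum all-reduce: every rank would end up with this total.
    return sum(partials)

# Before the change: each branch reduces its own output (two collectives).
two_collectives = all_reduce(routed) + all_reduce(shared)

# After the change: combine the partial sums first, then reduce once.
one_collective = all_reduce([r + s for r, s in zip(routed, shared)])

assert np.allclose(two_collectives, one_collective)

In pure-TP deployments this halves the number of FFN all-reduces per decoder layer, which is exactly the saving the comment in the hunk above describes.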
@@ -181,14 +190,16 @@ class Glm4Moe(nn.Layer):
                 intermediate_size=shared_experts_intermediate_size,
                 layer_id=layer_id,
                 prefix=f"{prefix}.shared_experts",
+                reduce_results=not self.merge_ffn_tp,
             )
 
     def forward(self, x, forward_meta: ForwardMeta = None):
         out = self.experts(x, self.gate, forward_meta)
         if self.n_shared_experts > 0:
-            shared_experts_out = self.shared_experts(x)
-            out = out + shared_experts_out
+            out = out + self.shared_experts(x)
+        if self.merge_ffn_tp:
+            # Both branches produced partial sums; combine first, then single all-reduce.
+            out = tensor_model_parallel_all_reduce(out, self.tp_group)
         return out
 
 
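The mode gating computed in __init__ only enables the deferred all-reduce when tensor parallelism is active and expert parallelism is not. A hypothetical standalone helper (illustrative names, not the FastDeploy API) makes the truth table explicit:

def should_merge_ffn_allreduce(tp_size: int, ep_size: int) -> bool:
    # Mirrors `self.merge_ffn_tp = self.use_tp and not self.use_ep`:
    # defer the all-reduce only when both FFN branches emit partial sums.
    use_tp = tp_size > 1
    use_ep = ep_size > 1
    return use_tp and not use_ep

assert should_merge_ffn_allreduce(tp_size=8, ep_size=1)      # pure TP: merge
assert not should_merge_ffn_allreduce(tp_size=2, ep_size=4)  # EP (+ attn-TP): branches self-reduce
assert not should_merge_ffn_allreduce(tp_size=1, ep_size=1)  # no parallelism: nothing to reduce

In every case where the flag is False, reduce_results stays True and FusedMoE and the shared experts each perform their own reduction, so adding another all-reduce in forward would double-count.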