[Optimization] merge_allreduce (#7039)

This commit is contained in:
fxyfxy777
2026-04-02 19:52:13 +08:00
committed by GitHub
parent f142b486c9
commit 9f3b3ce7f5
3 changed files with 17 additions and 6 deletions
+14 -3
View File
@@ -26,6 +26,7 @@ from paddleformers.transformers import PretrainedModel
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
    support_graph_optimization,
@@ -160,8 +161,16 @@ class Glm4Moe(nn.Layer):
    default_initializer=paddle.nn.initializer.Constant(0),
)
# In pure-TP mode (tp>1, ep=1) both branches return partial sums, so we
# defer the all-reduce to after combining them — saving one collective.
# In all other modes (EP, EP+attn-TP, no parallelism) each branch handles
# its own reduction internally (reduce_results default=True), so we must
# NOT add an extra all-reduce here.
self.merge_ffn_tp = self.use_tp and not self.use_ep
self.experts = FusedMoE(
    fd_config,
    reduce_results=not self.merge_ffn_tp,
    renormalize=self.norm_topk_prob,
    moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
    num_experts=fd_config.model_config.n_routed_experts,
@@ -182,14 +191,16 @@ class Glm4Moe(nn.Layer):
    intermediate_size=shared_experts_intermediate_size,
    layer_id=layer_id,
    prefix=f"{prefix}.shared_experts",
    reduce_results=not self.merge_ffn_tp,
)
def forward(self, x, forward_meta: ForwardMeta = None):
    out = self.experts(x, self.gate, forward_meta)
    if self.n_shared_experts > 0:
        out = out + self.shared_experts(x)
    if self.merge_ffn_tp:
        # Both branches produced partial sums; combine first, then single all-reduce.
        out = tensor_model_parallel_all_reduce(out, self.tp_group)
    return out
@@ -185,7 +185,7 @@ def test_lm_head_fp32(api_url, headers, consistent_payload):
# 校验返回内容与概率信息
assert (
    resp_json["choices"][0]["message"]["content"]
    == "\n<think>我需要回答牛顿的三大运动定律是什么。牛顿的三大运动定律是经典"
), f"The response content is not as expected {resp_json['choices'][0]['message']['content']}."
@@ -157,10 +157,10 @@ def check_routing_replay_chat_completion(openai_client, moe_layer_num: int, mode
model_path = os.getenv("MODEL_PATH")
if model_path:
    baseline_path = os.path.join(
        model_path, f"R3_BaseLine_dev_uint8_0402/routing_replay_output_baseline_{model_name}"
    )
else:
    baseline_path = f"./R3_BaseLine_dev_uint8_0402/routing_replay_output_baseline_{model_name}"
stream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_stream")
nonstream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_nonstream")