Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2026-04-23 00:17:25 +08:00)
[Feature] use phi permute/unpermute & rm swiglu (#6361)
* tp: text output is correct
* eb5 mini: text output is correct on B-series GPUs
* eb5mini ep on B-series GPUs: text output is correct
* default use phi moe op
* stash
* tp works on H-series GPUs
* ep ok
* rm debug
* rm debug tool
* rm del ffn_out
* rm swiglu
* add envs to swiglu
* merge dev
* fix ci baseline

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix ci baseline 2

--------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
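One bullet above is "add envs to swiglu", i.e. keeping a SwiGLU code path selectable through an environment variable rather than removing it outright. The snippet below is only a sketch of that pattern, not FastDeploy's actual implementation: the switch name FD_USE_FUSED_SWIGLU and the helper swiglu_reference are hypothetical; only the SwiGLU formula itself (silu(gate) * up over a fused [gate; up] projection) is standard.

import os

import paddle
import paddle.nn.functional as F

# Hypothetical switch name; the real FastDeploy variable may differ.
USE_FUSED_SWIGLU = os.getenv("FD_USE_FUSED_SWIGLU", "0") == "1"


def swiglu_reference(gate_up: paddle.Tensor) -> paddle.Tensor:
    # Reference SwiGLU: split the fused projection into gate and up halves,
    # apply SiLU to the gate, and multiply elementwise.
    gate, up = paddle.chunk(gate_up, chunks=2, axis=-1)
    return F.silu(gate) * up


# Example: y has half the last-dimension width of the fused input.
y = swiglu_reference(paddle.randn([2, 8]))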
@@ -25,6 +25,7 @@ from paddleformers.transformers import PretrainedModel
from paddleformers.utils.log import logger

from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
    support_graph_optimization,
@@ -161,6 +162,7 @@ class Glm4Moe(nn.Layer):

        self.experts = FusedMoE(
            fd_config,
            reduce_results=False,
            renormalize=self.norm_topk_prob,
            moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
            num_experts=fd_config.model_config.n_routed_experts,
@@ -181,14 +183,21 @@ class Glm4Moe(nn.Layer):
            intermediate_size=shared_experts_intermediate_size,
            layer_id=layer_id,
            prefix=f"{prefix}.shared_experts",
            reduce_results=False,
        )

    def forward(self, x, forward_meta: ForwardMeta = None):
        # Both experts and shared_experts return partial sums (no all-reduce).
        # Combine them first, then do a single all-reduce — eliminating one
        # collective communication compared to the naive sequential approach.
        # NOTE: only valid for pure-TP mode (use_ep=False). In EP or EP+TP modes
        # FusedMoE uses all-to-all internally and already produces a full result,
        # so the extra all-reduce must be skipped to avoid double-reduction.
        out = self.experts(x, self.gate, forward_meta)
        if self.n_shared_experts > 0:
            shared_experts_out = self.shared_experts(x)
            out = out + shared_experts_out

        if self.use_tp and not self.use_ep:
            out = tensor_model_parallel_all_reduce(out, self.tp_group)
        return out

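The comment block in this hunk is the heart of the change: under pure tensor parallelism both the routed experts and the shared experts return partial sums, so adding the two partial results first and all-reducing once yields the same tensor as all-reducing each separately, while saving one collective. A minimal sketch of that identity, with no distributed runtime involved (the two "ranks" are just local arrays and the variable names are hypothetical):

import numpy as np

rng = np.random.default_rng(0)

# Partial results each TP rank would hold for the routed and shared experts.
experts_partial = [rng.standard_normal((4, 8)) for _ in range(2)]
shared_partial = [rng.standard_normal((4, 8)) for _ in range(2)]

# Naive: two all-reduces (modelled here as sums over ranks), then add.
naive = sum(experts_partial) + sum(shared_partial)

# Fused: add the partial sums on each rank first, then a single all-reduce.
fused = sum(e + s for e, s in zip(experts_partial, shared_partial))

assert np.allclose(naive, fused)  # same result, one collective fewer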
@@ -535,7 +544,9 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
        forward_meta: ForwardMeta,
    ):
        """ """
        paddle.cuda.nvtx.range_push("GLM4_MOE_BF")
        hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
        paddle.cuda.nvtx.range_pop()

        return hidden_states
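The forward pass above is bracketed with paddle.cuda.nvtx.range_push / range_pop so the GLM4 MoE step shows up as a named range in Nsight Systems traces. If that pattern gets reused elsewhere, a small context manager keeps the push/pop balanced even when the body raises. This is a sketch assuming the same paddle.cuda.nvtx API shown in the diff; nvtx_range is a hypothetical helper name, not part of FastDeploy.

from contextlib import contextmanager

import paddle


@contextmanager
def nvtx_range(name: str):
    # Push a named NVTX range and always pop it, even if the wrapped body raises.
    paddle.cuda.nvtx.range_push(name)
    try:
        yield
    finally:
        paddle.cuda.nvtx.range_pop()


# Usage, mirroring the hunk above:
# with nvtx_range("GLM4_MOE_BF"):
#     hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)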