[Feature] support compute shared experts before combine for better overlap (#6697)

* [Feature] support compute shared experts before combine for better overlap

* fix test

* fix xpu

* fix
Author: Longzhi Wang
Date: 2026-03-17 15:18:51 +08:00 (committed by GitHub)
Parent: 12eb001d0c
Commit: daaf498213
15 changed files with 104 additions and 27 deletions
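This change threads an optional shared_experts layer through every MoE quant method so its GEMMs can be launched before the expert-parallel combine. The combine step is an all-to-all and communication-bound, while the shared experts depend only on the original input x, so the two can overlap. A minimal sketch of the pattern (combine_fn is a stand-in of mine, not an API from this PR):

    import paddle

    def moe_with_shared_overlap(x, ffn_out, shared_experts, combine_fn):
        # Shared experts depend only on `x`, not on the combine result, so
        # enqueue their GEMMs first; they execute while the all-to-all
        # combine below is in flight on the communication stream.
        s_x = shared_experts(x) if shared_experts is not None else None
        out = combine_fn(ffn_out)  # EP combine (communication-bound)
        if s_x is not None:
            out += s_x             # merge the shared-expert output in place
        return out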
@@ -217,6 +217,7 @@ class MoEMethodBase(QuantMethodBase):
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
) -> paddle.Tensor:
"""
Compute Fused MoE with Paddle Cutlass.
@@ -226,11 +227,15 @@ class MoEMethodBase(QuantMethodBase):
if layer.fd_config.model_config.moe_phase.phase == "prefill":
if layer.fd_config.scheduler_config.splitwise_role == "mixed" and is_moe_start_layer:
self.ep_prefill_runner.clean_low_latency_buffer()
return self.apply_ep_prefill(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
return self.apply_ep_prefill(
layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
else:
if layer.fd_config.scheduler_config.splitwise_role == "mixed" and is_moe_start_layer:
self.ep_decoder_runner.clean_low_latency_buffer()
return self.apply_ep_decode(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
return self.apply_ep_decode(
layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
else:
return self.apply_tp(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
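The base class now forwards shared_experts into both EP paths, while the TP path is unchanged. From the caller's side the intended usage looks roughly like this (a hypothetical MoE block; the names are illustrative, not from this PR):

    import paddle
    from paddle import nn

    class MoEBlock(nn.Layer):
        """Illustrative caller; `fused_moe` is a FusedMoE instance."""

        def __init__(self, fused_moe, gate, shared_experts):
            super().__init__()
            self.fused_moe = fused_moe
            self.gate = gate
            self.shared_experts = shared_experts  # dense FFN over all tokens

        def forward(self, x, forward_meta):
            # Old pattern: out = self.fused_moe(x, self.gate, forward_meta)
            #              out = out + self.shared_experts(x)
            # New pattern: the quant method overlaps the shared experts
            # with the EP combine internally.
            return self.fused_moe(
                x, self.gate, forward_meta, shared_experts=self.shared_experts
            )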
@@ -137,6 +137,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -229,6 +230,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -320,6 +320,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -337,11 +338,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
# 2. Dynamic compute blockwise quantization scales
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
x_fp8, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
x, self.quant_config.weight_block_size[0], self.quant_config.deepgemm_scale_ue8m0
)
else:
x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
x_fp8, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
x,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
@@ -366,7 +367,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
handle,
event,
) = self.ep_prefill_runner.dispatch(
x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128, previous_event=event
x_fp8, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128, previous_event=event
)
if self.ep_prefill_runner.num_worst_tokens > 0:
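The rename from x to x_fp8 is what makes the shared-expert call added later in this file possible: the blockwise quantization used to overwrite the bf16 activations, but the shared experts must read the original-precision input after dispatch. A runnable toy of the idea (the quant helper is a stand-in, not the real per_token_quant op):

    import paddle

    def fake_per_token_quant(t):
        # Stand-in for ops.gpu.per_token_quant: the real op emits FP8
        # values plus scales; here we return a scaled copy so this runs.
        scale = t.abs().max(axis=-1, keepdim=True)
        return t / scale, scale

    x = paddle.randn([4, 8])
    x_fp8, x_scale_tensor = fake_per_token_quant(x)  # new name: `x` survives
    s_x = x * 2.0  # stand-in shared experts still read the original `x`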
@@ -599,6 +600,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
)
else:
tmp_ffn_out = paddle.empty([0, hidden_size], paddle.bfloat16)
if shared_experts is not None:
s_x = shared_experts(x)
# 5. EP combine
event = deep_ep.Buffer.capture()
if self.ep_prefill_runner.num_worst_tokens <= 0:
@@ -614,6 +619,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
event.current_stream_wait()
global_values[thread_name]["combine_out"] = tmp_ffn_out
if shared_experts is not None:
tmp_ffn_out += s_x
return tmp_ffn_out
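Why the ordering matters, as a back-of-the-envelope timing model: treating combine as communication-bound and the shared experts as compute-bound, launching the shared experts first pushes wall time toward max(combine, shared) instead of their sum. A pure-Python analogy for the two CUDA streams (the sleeps stand in for real work):

    import time
    from concurrent.futures import ThreadPoolExecutor

    def ep_combine():           # stands in for the all-to-all combine
        time.sleep(0.05)        # communication-bound

    def shared_expert_gemms():  # stands in for the shared-expert FFN
        time.sleep(0.03)        # compute-bound

    t0 = time.perf_counter()
    ep_combine(); shared_expert_gemms()         # old order: serial
    serial = time.perf_counter() - t0

    t0 = time.perf_counter()
    with ThreadPoolExecutor(max_workers=1) as pool:
        fut = pool.submit(shared_expert_gemms)  # launch shared experts first
        ep_combine()                            # combine overlaps with them
        fut.result()
    overlapped = time.perf_counter() - t0
    print(f"serial {serial*1e3:.0f} ms vs overlapped {overlapped*1e3:.0f} ms")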
@@ -623,6 +630,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -690,8 +698,16 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
token_nums_per_expert,
expected_m,
)
if shared_experts is not None:
s_x = shared_experts(x)
# 4. EP combine
return self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle)
out = self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle)
if shared_experts is not None:
out += s_x
return out
def apply_tp(
self,
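The decode path makes the same reorder, capturing the combine result instead of returning it directly so the shared-expert output can be added afterwards. The change is numerics-preserving; a quick sanity sketch with stubs for the EP runner and the expert layers:

    import paddle

    def combine_stub(t):
        # Placeholder for ep_decoder_runner.combine: identity on this toy
        return t * 1.0

    x = paddle.randn([4, 16])
    ffn_out = x * 0.5                    # pretend routed-expert output
    shared = paddle.nn.Linear(16, 16)    # pretend shared-expert FFN

    old = combine_stub(ffn_out) + shared(x)  # previous call-site pattern
    s_x = shared(x)                          # new: shared experts first ...
    new = combine_stub(ffn_out)              # ... combine runs concurrently ...
    new += s_x                               # ... then the in-place add
    assert paddle.allclose(old, new)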
@@ -242,6 +242,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
) -> paddle.Tensor:
"""
Compute Fused MoE with Marlin.
@@ -289,6 +289,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
) -> paddle.Tensor:
"""
Compute Fused MoE with Triton.
@@ -677,6 +678,7 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
) -> paddle.Tensor:
"""
Compute Fused MoE with Triton.
@@ -971,6 +973,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
) -> paddle.Tensor:
"""
Compute Fused MoE with Triton.
@@ -1756,6 +1759,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
) -> paddle.Tensor:
"""
Compute Fused MoE with Triton.
@@ -299,6 +299,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
) -> paddle.Tensor:
"""
Compute Fused MoE with the Wint2 Triton FusedMoE kernel.
@@ -371,6 +372,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
) -> paddle.Tensor:
"""
Compute Fused MoE with the Wint2 Triton FusedMoE kernel.
@@ -650,7 +650,9 @@ class FusedMoE(nn.Layer):
else:
self.quant_method.process_loaded_weights(self, state_dict)
def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None):
def forward_split_allgather(
self, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None, shared_experts: nn.Layer = None
):
"""
Forward split allgather function.
"""
@@ -665,14 +667,21 @@ class FusedMoE(nn.Layer):
if end_offset > token_num:
end_offset = token_num
part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
out = self.quant_method.apply(self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
if current_platform.is_cuda():
out = self.quant_method.apply(
self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
else:
out = self.quant_method.apply(self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
multi_outs = paddle.zeros([token_num_per_rank * self.attn_tp_size, x.shape[1]], dtype=x.dtype)
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
out = multi_outs[:token_num, :]
return out
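Note the is_cuda() gate above: only the CUDA quant methods accept shared_experts so far (the "fix xpu" commit in the message restores the old call on other platforms). For context, this function shards tokens evenly across the attention TP group, runs MoE on each shard, and all-gathers the results. A sketch of the offset math (an illustrative helper; ceil-division sharding is inferred from the clamped end_offset and the trailing slice):

    import math

    def shard_bounds(token_num, attn_tp_size, rank):
        # Assumed sharding: ceil-divide tokens across ranks; the last
        # rank's shard is clamped, and the padding rows are dropped
        # after all_gather via `multi_outs[:token_num, :]`.
        per_rank = math.ceil(token_num / attn_tp_size)
        start = rank * per_rank
        end = min(start + per_rank, token_num)
        return start, end, per_rank

    # 10 tokens over 4 ranks -> [0,3), [3,6), [6,9), [9,10), per_rank=3
    for r in range(4):
        print(r, shard_bounds(10, 4, r))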
def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None):
def forward(
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None, shared_experts: nn.Layer = None
):
"""
Defines the forward computation of the moe layer.
@@ -701,7 +710,9 @@ class FusedMoE(nn.Layer):
)
if current_platform.is_intel_hpu():
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)
out = self.forward_normal(
x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
if self.reduce_results and (self.ep_size > 1 or self.tp_size > 1):
tensor_model_parallel_all_reduce_custom(out)
return out
@@ -713,23 +724,29 @@ class FusedMoE(nn.Layer):
and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
and token_num >= self.attn_tp_size
):
out = self.forward_split_allgather(x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
out = self.forward_split_allgather(
x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
out = self.forward_chunked_moe(
x,
gate,
forward_meta,
topk_ids_hookfunc=topk_ids_hookfunc,
x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
else:
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)
out = self.forward_normal(
x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
if self.reduce_results and self.tp_size > 1:
out = tensor_model_parallel_all_reduce(out, self.tp_group)
return out
def forward_chunked_moe(
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
self,
x: paddle.Tensor,
gate: nn.Layer,
forward_meta: ForwardMeta,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
):
"""
Split input to multi chunk to reduce the memory usage of moe.
@@ -755,23 +772,34 @@ class FusedMoE(nn.Layer):
for i in range(forward_meta.max_moe_num_chunk):
if i < forward_meta.moe_num_chunk:
out_split_list[i] = self.quant_method.apply(
self, x_split_list[i], gate, topk_ids_hookfunc=topk_ids_hookfunc
self, x_split_list[i], gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
else:
# just need to use real data to infer max_moe_num_chunk times.
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
self.quant_method.apply(
self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
out = paddle.concat(out_split_list, axis=0)
else:
# when only one chunk, just need to use real data to infer once.
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
out = self.quant_method.apply(
self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
for i in range(forward_meta.max_moe_num_chunk - 1):
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
self.quant_method.apply(
self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
return out
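forward_chunked_moe now forwards shared_experts into every chunk, including the padding passes. The comments in the hunk suggest each rank runs apply exactly max_moe_num_chunk times, presumably so every EP rank issues the same number of dispatch/combine collectives; ranks with fewer real chunks re-run on fake_x and discard the output. A condensed sketch of that control flow (stand-in names):

    import paddle

    def chunked_apply(apply_fn, x_chunks, max_num_chunk, fake_x):
        outs = []
        for i in range(max_num_chunk):
            if i < len(x_chunks):
                outs.append(apply_fn(x_chunks[i]))  # real chunk
            else:
                apply_fn(fake_x)  # padding pass keeps collectives aligned
        return paddle.concat(outs, axis=0)

    chunks = paddle.split(paddle.randn([8, 4]), num_or_sections=2)
    out = chunked_apply(lambda t: t * 2, chunks, max_num_chunk=3,
                        fake_x=paddle.zeros([1, 4]))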
def forward_normal(
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
self,
x: paddle.Tensor,
gate: nn.Layer,
forward_meta: ForwardMeta,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
):
"""
Normal mode of forward.
@@ -783,5 +811,10 @@ class FusedMoE(nn.Layer):
Tensor: Output tensor.
"""
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
if current_platform.is_cuda():
out = self.quant_method.apply(
self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
else:
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
return out
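forward_normal mirrors the gating in forward_split_allgather: only the CUDA quant methods consume shared_experts, so on other platforms the argument is dropped before apply. A hedged caller-side fallback sketch (both the import path and the fallback itself are assumptions, not shown in this diff):

    from fastdeploy.platforms import current_platform  # assumed import path

    def moe_block_forward(moe, x, gate, forward_meta, shared_experts=None):
        out = moe(x, gate, forward_meta, shared_experts=shared_experts)
        if shared_experts is not None and not current_platform.is_cuda():
            # Non-CUDA quant methods ignore shared_experts, so add it here.
            out = out + shared_experts(x)
        return out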