[Feature] support compute shared experts before combine for better overlap (#6697)

* [Feature] support compute shared experts before combine for better overlap

* fix test

* fix xpu

* fix
This commit is contained in:
Longzhi Wang
2026-03-17 15:18:51 +08:00
committed by GitHub
parent 12eb001d0c
commit daaf498213
15 changed files with 104 additions and 27 deletions
+50 -17
View File
@@ -650,7 +650,9 @@ class FusedMoE(nn.Layer):
else:
self.quant_method.process_loaded_weights(self, state_dict)
def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None):
def forward_split_allgather(
self, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None, shared_experts: nn.Layer = None
):
"""
Forward split allgather function.
"""
@@ -665,14 +667,21 @@ class FusedMoE(nn.Layer):
if end_offset > token_num:
end_offset = token_num
part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
out = self.quant_method.apply(self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
if current_platform.is_cuda():
out = self.quant_method.apply(
self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
else:
out = self.quant_method.apply(self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
multi_outs = paddle.zeros([token_num_per_rank * self.attn_tp_size, x.shape[1]], dtype=x.dtype)
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
out = multi_outs[:token_num, :]
return out
def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None):
def forward(
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None, shared_experts: nn.Layer = None
):
"""
Defines the forward computation of the moe layer.
@@ -701,7 +710,9 @@ class FusedMoE(nn.Layer):
)
if current_platform.is_intel_hpu():
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)
out = self.forward_normal(
x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
if self.reduce_results and (self.ep_size > 1 or self.tp_size > 1):
tensor_model_parallel_all_reduce_custom(out)
return out
@@ -713,23 +724,29 @@ class FusedMoE(nn.Layer):
and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
and token_num >= self.attn_tp_size
):
out = self.forward_split_allgather(x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
out = self.forward_split_allgather(
x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
out = self.forward_chunked_moe(
x,
gate,
forward_meta,
topk_ids_hookfunc=topk_ids_hookfunc,
x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
else:
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)
out = self.forward_normal(
x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
if self.reduce_results and self.tp_size > 1:
out = tensor_model_parallel_all_reduce(out, self.tp_group)
return out
def forward_chunked_moe(
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
self,
x: paddle.Tensor,
gate: nn.Layer,
forward_meta: ForwardMeta,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
):
"""
Split input to multi chunk to reduce the memory usage of moe.
@@ -755,23 +772,34 @@ class FusedMoE(nn.Layer):
for i in range(forward_meta.max_moe_num_chunk):
if i < forward_meta.moe_num_chunk:
out_split_list[i] = self.quant_method.apply(
self, x_split_list[i], gate, topk_ids_hookfunc=topk_ids_hookfunc
self, x_split_list[i], gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
else:
# just need to use real data to infer max_moe_num_chunk times.
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
self.quant_method.apply(
self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
out = paddle.concat(out_split_list, axis=0)
else:
# when only one chunk, just need to use real data to infer once.
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
out = self.quant_method.apply(
self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
for i in range(forward_meta.max_moe_num_chunk - 1):
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
self.quant_method.apply(
self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
return out
def forward_normal(
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
self,
x: paddle.Tensor,
gate: nn.Layer,
forward_meta: ForwardMeta,
topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None,
):
"""
Normal mode of forward.
@@ -783,5 +811,10 @@ class FusedMoE(nn.Layer):
Tensor: Output tensor.s
"""
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
if current_platform.is_cuda():
out = self.quant_method.apply(
self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
)
else:
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
return out