mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[Feature] support compute shared experts before combine for better overlap (#6697)
* [Feature] support compute shared experts before combine for better overlap * fix test * fix xpu * fix
This commit is contained in:
@@ -650,7 +650,9 @@ class FusedMoE(nn.Layer):
|
||||
else:
|
||||
self.quant_method.process_loaded_weights(self, state_dict)
|
||||
|
||||
def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None):
|
||||
def forward_split_allgather(
|
||||
self, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None, shared_experts: nn.Layer = None
|
||||
):
|
||||
"""
|
||||
Forward split allgather function.
|
||||
"""
|
||||
@@ -665,14 +667,21 @@ class FusedMoE(nn.Layer):
|
||||
if end_offset > token_num:
|
||||
end_offset = token_num
|
||||
part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
|
||||
out = self.quant_method.apply(self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
if current_platform.is_cuda():
|
||||
out = self.quant_method.apply(
|
||||
self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
else:
|
||||
out = self.quant_method.apply(self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
multi_outs = paddle.zeros([token_num_per_rank * self.attn_tp_size, x.shape[1]], dtype=x.dtype)
|
||||
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
|
||||
out = multi_outs[:token_num, :]
|
||||
|
||||
return out
|
||||
|
||||
def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None):
|
||||
def forward(
|
||||
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None, shared_experts: nn.Layer = None
|
||||
):
|
||||
"""
|
||||
Defines the forward computation of the moe layer.
|
||||
|
||||
@@ -701,7 +710,9 @@ class FusedMoE(nn.Layer):
|
||||
)
|
||||
|
||||
if current_platform.is_intel_hpu():
|
||||
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
out = self.forward_normal(
|
||||
x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
if self.reduce_results and (self.ep_size > 1 or self.tp_size > 1):
|
||||
tensor_model_parallel_all_reduce_custom(out)
|
||||
return out
|
||||
@@ -713,23 +724,29 @@ class FusedMoE(nn.Layer):
|
||||
and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
|
||||
and token_num >= self.attn_tp_size
|
||||
):
|
||||
out = self.forward_split_allgather(x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
out = self.forward_split_allgather(
|
||||
x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
|
||||
out = self.forward_chunked_moe(
|
||||
x,
|
||||
gate,
|
||||
forward_meta,
|
||||
topk_ids_hookfunc=topk_ids_hookfunc,
|
||||
x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
else:
|
||||
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
out = self.forward_normal(
|
||||
x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
out = tensor_model_parallel_all_reduce(out, self.tp_group)
|
||||
return out
|
||||
|
||||
def forward_chunked_moe(
|
||||
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
|
||||
self,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
):
|
||||
"""
|
||||
Split input to multi chunk to reduce the memory usage of moe.
|
||||
@@ -755,23 +772,34 @@ class FusedMoE(nn.Layer):
|
||||
for i in range(forward_meta.max_moe_num_chunk):
|
||||
if i < forward_meta.moe_num_chunk:
|
||||
out_split_list[i] = self.quant_method.apply(
|
||||
self, x_split_list[i], gate, topk_ids_hookfunc=topk_ids_hookfunc
|
||||
self, x_split_list[i], gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
else:
|
||||
# just need to use real data to infer max_moe_num_chunk times.
|
||||
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
self.quant_method.apply(
|
||||
self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
|
||||
out = paddle.concat(out_split_list, axis=0)
|
||||
else:
|
||||
# when only one chunk, just need to use real data to infer once.
|
||||
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
out = self.quant_method.apply(
|
||||
self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
for i in range(forward_meta.max_moe_num_chunk - 1):
|
||||
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
self.quant_method.apply(
|
||||
self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def forward_normal(
|
||||
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
|
||||
self,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
):
|
||||
"""
|
||||
Normal mode of forward.
|
||||
@@ -783,5 +811,10 @@ class FusedMoE(nn.Layer):
|
||||
Tensor: Output tensor.s
|
||||
|
||||
"""
|
||||
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
if current_platform.is_cuda():
|
||||
out = self.quant_method.apply(
|
||||
self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
else:
|
||||
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user