mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] support compute shared experts before combine for better overlap (#6697)
* [Feature] support compute shared experts before combine for better overlap
* fix test
* fix xpu
* fix
This commit is contained in:
@@ -217,6 +217,7 @@ class MoEMethodBase(QuantMethodBase):
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
@@ -226,11 +227,15 @@ class MoEMethodBase(QuantMethodBase):
|
||||
if layer.fd_config.model_config.moe_phase.phase == "prefill":
|
||||
if layer.fd_config.scheduler_config.splitwise_role == "mixed" and is_moe_start_layer:
|
||||
self.ep_prefill_runner.clean_low_latency_buffer()
|
||||
return self.apply_ep_prefill(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
return self.apply_ep_prefill(
|
||||
layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
else:
|
||||
if layer.fd_config.scheduler_config.splitwise_role == "mixed" and is_moe_start_layer:
|
||||
self.ep_decoder_runner.clean_low_latency_buffer()
|
||||
return self.apply_ep_decode(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
return self.apply_ep_decode(
|
||||
layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
else:
|
||||
return self.apply_tp(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
|
||||
|
||||
@@ -137,6 +137,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP prefill method.
|
||||
@@ -229,6 +230,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP decoder method.
|
||||
|
||||
@@ -320,6 +320,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP prefill method.
|
||||
@@ -337,11 +338,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
|
||||
# 2. Dynamic compute blockwise quantization scales
|
||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
x_fp8, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
x, self.quant_config.weight_block_size[0], self.quant_config.deepgemm_scale_ue8m0
|
||||
)
|
||||
else:
|
||||
x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
x_fp8, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
x,
|
||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
|
||||
@@ -366,7 +367,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
handle,
|
||||
event,
|
||||
) = self.ep_prefill_runner.dispatch(
|
||||
x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128, previous_event=event
|
||||
x_fp8, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128, previous_event=event
|
||||
)
|
||||
|
||||
if self.ep_prefill_runner.num_worst_tokens > 0:
|
||||
@@ -599,6 +600,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
)
|
||||
else:
|
||||
tmp_ffn_out = paddle.empty([0, hidden_size], paddle.bfloat16)
|
||||
|
||||
if shared_experts is not None:
|
||||
s_x = shared_experts(x)
|
||||
|
||||
# 5. EP combine
|
||||
event = deep_ep.Buffer.capture()
|
||||
if self.ep_prefill_runner.num_worst_tokens <= 0:
|
||||
@@ -614,6 +619,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
event.current_stream_wait()
|
||||
|
||||
global_values[thread_name]["combine_out"] = tmp_ffn_out
|
||||
if shared_experts is not None:
|
||||
tmp_ffn_out += s_x
|
||||
|
||||
return tmp_ffn_out
|
||||
|
||||
@@ -623,6 +630,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP decoder method.
|
||||
@@ -690,8 +698,16 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
token_nums_per_expert,
|
||||
expected_m,
|
||||
)
|
||||
|
||||
if shared_experts is not None:
|
||||
s_x = shared_experts(x)
|
||||
|
||||
# 4. EP combine
|
||||
return self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle)
|
||||
out = self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle)
|
||||
|
||||
if shared_experts is not None:
|
||||
out += s_x
|
||||
return out
|
||||
|
||||
def apply_tp(
|
||||
self,
|
||||
|
||||
@@ -242,6 +242,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Marlin compute Fused MoE.
|
||||
|
||||
@@ -289,6 +289,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
@@ -677,6 +678,7 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
@@ -971,6 +973,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
@@ -1756,6 +1759,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
|
||||
@@ -299,6 +299,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Use Wint2 Triton Fusedmoe compute Fused MoE.
|
||||
@@ -371,6 +372,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Use Wint2 Triton Fusedmoe compute Fused MoE.
|
||||
|
||||
@@ -650,7 +650,9 @@ class FusedMoE(nn.Layer):
|
||||
else:
|
||||
self.quant_method.process_loaded_weights(self, state_dict)
|
||||
|
||||
def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None):
|
||||
def forward_split_allgather(
|
||||
self, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None, shared_experts: nn.Layer = None
|
||||
):
|
||||
"""
|
||||
Forward split allgather function.
|
||||
"""
|
||||
@@ -665,14 +667,21 @@ class FusedMoE(nn.Layer):
|
||||
if end_offset > token_num:
|
||||
end_offset = token_num
|
||||
part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
|
||||
out = self.quant_method.apply(self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
if current_platform.is_cuda():
|
||||
out = self.quant_method.apply(
|
||||
self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
else:
|
||||
out = self.quant_method.apply(self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
multi_outs = paddle.zeros([token_num_per_rank * self.attn_tp_size, x.shape[1]], dtype=x.dtype)
|
||||
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
|
||||
out = multi_outs[:token_num, :]
|
||||
|
||||
return out
|
||||
|
||||
def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None):
|
||||
def forward(
|
||||
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None, shared_experts: nn.Layer = None
|
||||
):
|
||||
"""
|
||||
Defines the forward computation of the moe layer.
|
||||
|
||||
@@ -701,7 +710,9 @@ class FusedMoE(nn.Layer):
|
||||
)
|
||||
|
||||
if current_platform.is_intel_hpu():
|
||||
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
out = self.forward_normal(
|
||||
x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
if self.reduce_results and (self.ep_size > 1 or self.tp_size > 1):
|
||||
tensor_model_parallel_all_reduce_custom(out)
|
||||
return out
|
||||
@@ -713,23 +724,29 @@ class FusedMoE(nn.Layer):
|
||||
and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
|
||||
and token_num >= self.attn_tp_size
|
||||
):
|
||||
out = self.forward_split_allgather(x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
out = self.forward_split_allgather(
|
||||
x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
|
||||
out = self.forward_chunked_moe(
|
||||
x,
|
||||
gate,
|
||||
forward_meta,
|
||||
topk_ids_hookfunc=topk_ids_hookfunc,
|
||||
x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
else:
|
||||
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
out = self.forward_normal(
|
||||
x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
out = tensor_model_parallel_all_reduce(out, self.tp_group)
|
||||
return out
|
||||
|
||||
def forward_chunked_moe(
|
||||
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
|
||||
self,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
):
|
||||
"""
|
||||
Split input to multi chunk to reduce the memory usage of moe.
|
||||
@@ -755,23 +772,34 @@ class FusedMoE(nn.Layer):
|
||||
for i in range(forward_meta.max_moe_num_chunk):
|
||||
if i < forward_meta.moe_num_chunk:
|
||||
out_split_list[i] = self.quant_method.apply(
|
||||
self, x_split_list[i], gate, topk_ids_hookfunc=topk_ids_hookfunc
|
||||
self, x_split_list[i], gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
else:
|
||||
# just need to use real data to infer max_moe_num_chunk times.
|
||||
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
self.quant_method.apply(
|
||||
self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
|
||||
out = paddle.concat(out_split_list, axis=0)
|
||||
else:
|
||||
# when only one chunk, just need to use real data to infer once.
|
||||
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
out = self.quant_method.apply(
|
||||
self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
for i in range(forward_meta.max_moe_num_chunk - 1):
|
||||
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
self.quant_method.apply(
|
||||
self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def forward_normal(
|
||||
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
|
||||
self,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
):
|
||||
"""
|
||||
Normal mode of forward.
|
||||
@@ -783,5 +811,10 @@ class FusedMoE(nn.Layer):
|
||||
Tensor: Output tensor.
|
||||
|
||||
"""
|
||||
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
if current_platform.is_cuda():
|
||||
out = self.quant_method.apply(
|
||||
self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
else:
|
||||
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user