mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] 添加 MoE 层 latent mode 支持 (#7382)
This commit is contained in:
@@ -218,6 +218,8 @@ class MoEMethodBase(QuantMethodBase):
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
fc1_latent_proj: nn.Layer = None,
|
||||
fc2_latent_proj: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
@@ -237,7 +239,7 @@ class MoEMethodBase(QuantMethodBase):
|
||||
layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
)
|
||||
else:
|
||||
return self.apply_tp(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
return self.apply_tp(layer, x, gate, topk_ids_hookfunc)
|
||||
|
||||
|
||||
class UnquantizedFusedMoEMethod(MoEMethodBase):
|
||||
|
||||
@@ -292,6 +292,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
fc1_latent_proj: nn.Layer = None,
|
||||
fc2_latent_proj: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
@@ -681,6 +683,8 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
fc1_latent_proj: nn.Layer = None,
|
||||
fc2_latent_proj: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
@@ -980,6 +984,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
fc1_latent_proj: nn.Layer = None,
|
||||
fc2_latent_proj: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
@@ -1174,6 +1180,9 @@ def python_op_fused_moe_kernel_paddle_infer_meta(
|
||||
config: dict,
|
||||
quant_config,
|
||||
topk_ids_hookfunc,
|
||||
layer,
|
||||
fc1_latent_proj,
|
||||
fc2_latent_proj,
|
||||
):
|
||||
token_num = x.shape[0]
|
||||
return paddle.static.MetaTensor(shape=[token_num, hidden_size], dtype=x.dtype)
|
||||
@@ -1211,19 +1220,34 @@ def python_op_fused_moe_kernel_paddle(
|
||||
config: dict,
|
||||
quant_config,
|
||||
topk_ids_hookfunc,
|
||||
layer,
|
||||
fc1_latent_proj,
|
||||
fc2_latent_proj,
|
||||
):
|
||||
|
||||
token_num = x.shape[0]
|
||||
if x.shape[0] == 0:
|
||||
return paddle.zeros([token_num, hidden_size], dtype=x.dtype)
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
gate_correction_bias,
|
||||
top_k,
|
||||
True, # apply_norm_weight
|
||||
False,
|
||||
)
|
||||
if layer.topk_method == "noaux_tc":
|
||||
gate_out, topk_weights, topk_ids = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
layer.topk_group,
|
||||
layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias,
|
||||
getattr(layer, "renormalize", True),
|
||||
)
|
||||
else:
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
gate_correction_bias,
|
||||
top_k,
|
||||
True, # apply_norm_weight
|
||||
False,
|
||||
)
|
||||
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_ids)
|
||||
|
||||
@@ -1244,6 +1268,9 @@ def python_op_fused_moe_kernel_paddle(
|
||||
|
||||
from .triton_moe_kernels import fused_moe_kernel_paddle
|
||||
|
||||
if fc1_latent_proj is not None:
|
||||
x = fc1_latent_proj(x)
|
||||
|
||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0], False)
|
||||
else:
|
||||
@@ -1357,6 +1384,9 @@ def python_op_fused_moe_kernel_paddle(
|
||||
intermediate_cache3.reshape_([token_num, top_k, hidden_size])
|
||||
out = intermediate_cache3.sum(axis=1)
|
||||
|
||||
if fc2_latent_proj is not None:
|
||||
out = fc2_latent_proj(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -1808,6 +1838,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
fc1_latent_proj: nn.Layer = None,
|
||||
fc2_latent_proj: nn.Layer = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
@@ -1855,4 +1887,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
config,
|
||||
self.quant_config,
|
||||
topk_ids_hookfunc,
|
||||
layer,
|
||||
fc1_latent_proj,
|
||||
fc2_latent_proj,
|
||||
)
|
||||
|
||||
@@ -709,7 +709,13 @@ class FusedMoE(nn.Layer):
|
||||
return out
|
||||
|
||||
def forward(
|
||||
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None, shared_experts: nn.Layer = None
|
||||
self,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
forward_meta: ForwardMeta = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
fc1_latent_proj: nn.Layer = None,
|
||||
fc2_latent_proj: nn.Layer = None,
|
||||
):
|
||||
"""
|
||||
Defines the forward computation of the moe layer.
|
||||
@@ -762,7 +768,13 @@ class FusedMoE(nn.Layer):
|
||||
)
|
||||
else:
|
||||
out = self.forward_normal(
|
||||
x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
x,
|
||||
gate,
|
||||
forward_meta,
|
||||
topk_ids_hookfunc,
|
||||
shared_experts,
|
||||
fc1_latent_proj,
|
||||
fc2_latent_proj,
|
||||
)
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
@@ -829,6 +841,8 @@ class FusedMoE(nn.Layer):
|
||||
forward_meta: ForwardMeta,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
shared_experts: nn.Layer = None,
|
||||
fc1_latent_proj: nn.Layer = None,
|
||||
fc2_latent_proj: nn.Layer = None,
|
||||
):
|
||||
"""
|
||||
Normal mode of forward.
|
||||
@@ -842,7 +856,13 @@ class FusedMoE(nn.Layer):
|
||||
"""
|
||||
if current_platform.is_cuda():
|
||||
out = self.quant_method.apply(
|
||||
self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||
self,
|
||||
x,
|
||||
gate,
|
||||
topk_ids_hookfunc,
|
||||
shared_experts,
|
||||
fc1_latent_proj,
|
||||
fc2_latent_proj,
|
||||
)
|
||||
else:
|
||||
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
|
||||
Reference in New Issue
Block a user