mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] 添加 MoE 层 latent mode 支持 (#7382)
This commit is contained in:
@@ -218,6 +218,8 @@ class MoEMethodBase(QuantMethodBase):
|
|||||||
gate: nn.Layer,
|
gate: nn.Layer,
|
||||||
topk_ids_hookfunc: Callable = None,
|
topk_ids_hookfunc: Callable = None,
|
||||||
shared_experts: nn.Layer = None,
|
shared_experts: nn.Layer = None,
|
||||||
|
fc1_latent_proj: nn.Layer = None,
|
||||||
|
fc2_latent_proj: nn.Layer = None,
|
||||||
) -> paddle.Tensor:
|
) -> paddle.Tensor:
|
||||||
"""
|
"""
|
||||||
Paddle Cutlass compute Fused MoE.
|
Paddle Cutlass compute Fused MoE.
|
||||||
@@ -237,7 +239,7 @@ class MoEMethodBase(QuantMethodBase):
|
|||||||
layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.apply_tp(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
return self.apply_tp(layer, x, gate, topk_ids_hookfunc)
|
||||||
|
|
||||||
|
|
||||||
class UnquantizedFusedMoEMethod(MoEMethodBase):
|
class UnquantizedFusedMoEMethod(MoEMethodBase):
|
||||||
|
|||||||
@@ -292,6 +292,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
|||||||
gate: nn.Layer,
|
gate: nn.Layer,
|
||||||
topk_ids_hookfunc: Callable = None,
|
topk_ids_hookfunc: Callable = None,
|
||||||
shared_experts: nn.Layer = None,
|
shared_experts: nn.Layer = None,
|
||||||
|
fc1_latent_proj: nn.Layer = None,
|
||||||
|
fc2_latent_proj: nn.Layer = None,
|
||||||
) -> paddle.Tensor:
|
) -> paddle.Tensor:
|
||||||
"""
|
"""
|
||||||
Triton compute Fused MoE.
|
Triton compute Fused MoE.
|
||||||
@@ -681,6 +683,8 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
|
|||||||
gate: nn.Layer,
|
gate: nn.Layer,
|
||||||
topk_ids_hookfunc: Callable = None,
|
topk_ids_hookfunc: Callable = None,
|
||||||
shared_experts: nn.Layer = None,
|
shared_experts: nn.Layer = None,
|
||||||
|
fc1_latent_proj: nn.Layer = None,
|
||||||
|
fc2_latent_proj: nn.Layer = None,
|
||||||
) -> paddle.Tensor:
|
) -> paddle.Tensor:
|
||||||
"""
|
"""
|
||||||
Triton compute Fused MoE.
|
Triton compute Fused MoE.
|
||||||
@@ -980,6 +984,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
|||||||
gate: nn.Layer,
|
gate: nn.Layer,
|
||||||
topk_ids_hookfunc: Callable = None,
|
topk_ids_hookfunc: Callable = None,
|
||||||
shared_experts: nn.Layer = None,
|
shared_experts: nn.Layer = None,
|
||||||
|
fc1_latent_proj: nn.Layer = None,
|
||||||
|
fc2_latent_proj: nn.Layer = None,
|
||||||
) -> paddle.Tensor:
|
) -> paddle.Tensor:
|
||||||
"""
|
"""
|
||||||
Triton compute Fused MoE.
|
Triton compute Fused MoE.
|
||||||
@@ -1174,6 +1180,9 @@ def python_op_fused_moe_kernel_paddle_infer_meta(
|
|||||||
config: dict,
|
config: dict,
|
||||||
quant_config,
|
quant_config,
|
||||||
topk_ids_hookfunc,
|
topk_ids_hookfunc,
|
||||||
|
layer,
|
||||||
|
fc1_latent_proj,
|
||||||
|
fc2_latent_proj,
|
||||||
):
|
):
|
||||||
token_num = x.shape[0]
|
token_num = x.shape[0]
|
||||||
return paddle.static.MetaTensor(shape=[token_num, hidden_size], dtype=x.dtype)
|
return paddle.static.MetaTensor(shape=[token_num, hidden_size], dtype=x.dtype)
|
||||||
@@ -1211,19 +1220,34 @@ def python_op_fused_moe_kernel_paddle(
|
|||||||
config: dict,
|
config: dict,
|
||||||
quant_config,
|
quant_config,
|
||||||
topk_ids_hookfunc,
|
topk_ids_hookfunc,
|
||||||
|
layer,
|
||||||
|
fc1_latent_proj,
|
||||||
|
fc2_latent_proj,
|
||||||
):
|
):
|
||||||
|
|
||||||
token_num = x.shape[0]
|
token_num = x.shape[0]
|
||||||
if x.shape[0] == 0:
|
if x.shape[0] == 0:
|
||||||
return paddle.zeros([token_num, hidden_size], dtype=x.dtype)
|
return paddle.zeros([token_num, hidden_size], dtype=x.dtype)
|
||||||
|
|
||||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
if layer.topk_method == "noaux_tc":
|
||||||
gate_out,
|
gate_out, topk_weights, topk_ids = get_moe_scores(
|
||||||
gate_correction_bias,
|
gate_out,
|
||||||
top_k,
|
layer.n_group,
|
||||||
True, # apply_norm_weight
|
layer.topk_group,
|
||||||
False,
|
layer.top_k,
|
||||||
)
|
layer.routed_scaling_factor,
|
||||||
|
layer.gate_correction_bias,
|
||||||
|
getattr(layer, "renormalize", True),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||||
|
gate_out,
|
||||||
|
gate_correction_bias,
|
||||||
|
top_k,
|
||||||
|
True, # apply_norm_weight
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
if topk_ids_hookfunc is not None:
|
if topk_ids_hookfunc is not None:
|
||||||
topk_ids_hookfunc(topk_ids=topk_ids)
|
topk_ids_hookfunc(topk_ids=topk_ids)
|
||||||
|
|
||||||
@@ -1244,6 +1268,9 @@ def python_op_fused_moe_kernel_paddle(
|
|||||||
|
|
||||||
from .triton_moe_kernels import fused_moe_kernel_paddle
|
from .triton_moe_kernels import fused_moe_kernel_paddle
|
||||||
|
|
||||||
|
if fc1_latent_proj is not None:
|
||||||
|
x = fc1_latent_proj(x)
|
||||||
|
|
||||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||||
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0], False)
|
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0], False)
|
||||||
else:
|
else:
|
||||||
@@ -1357,6 +1384,9 @@ def python_op_fused_moe_kernel_paddle(
|
|||||||
intermediate_cache3.reshape_([token_num, top_k, hidden_size])
|
intermediate_cache3.reshape_([token_num, top_k, hidden_size])
|
||||||
out = intermediate_cache3.sum(axis=1)
|
out = intermediate_cache3.sum(axis=1)
|
||||||
|
|
||||||
|
if fc2_latent_proj is not None:
|
||||||
|
out = fc2_latent_proj(out)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -1808,6 +1838,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
|||||||
gate: nn.Layer,
|
gate: nn.Layer,
|
||||||
topk_ids_hookfunc: Callable = None,
|
topk_ids_hookfunc: Callable = None,
|
||||||
shared_experts: nn.Layer = None,
|
shared_experts: nn.Layer = None,
|
||||||
|
fc1_latent_proj: nn.Layer = None,
|
||||||
|
fc2_latent_proj: nn.Layer = None,
|
||||||
) -> paddle.Tensor:
|
) -> paddle.Tensor:
|
||||||
"""
|
"""
|
||||||
Triton compute Fused MoE.
|
Triton compute Fused MoE.
|
||||||
@@ -1855,4 +1887,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
|||||||
config,
|
config,
|
||||||
self.quant_config,
|
self.quant_config,
|
||||||
topk_ids_hookfunc,
|
topk_ids_hookfunc,
|
||||||
|
layer,
|
||||||
|
fc1_latent_proj,
|
||||||
|
fc2_latent_proj,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -709,7 +709,13 @@ class FusedMoE(nn.Layer):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None, shared_experts: nn.Layer = None
|
self,
|
||||||
|
x: paddle.Tensor,
|
||||||
|
gate: nn.Layer,
|
||||||
|
forward_meta: ForwardMeta = None,
|
||||||
|
shared_experts: nn.Layer = None,
|
||||||
|
fc1_latent_proj: nn.Layer = None,
|
||||||
|
fc2_latent_proj: nn.Layer = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Defines the forward computation of the moe layer.
|
Defines the forward computation of the moe layer.
|
||||||
@@ -762,7 +768,13 @@ class FusedMoE(nn.Layer):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
out = self.forward_normal(
|
out = self.forward_normal(
|
||||||
x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
x,
|
||||||
|
gate,
|
||||||
|
forward_meta,
|
||||||
|
topk_ids_hookfunc,
|
||||||
|
shared_experts,
|
||||||
|
fc1_latent_proj,
|
||||||
|
fc2_latent_proj,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.reduce_results and self.tp_size > 1:
|
if self.reduce_results and self.tp_size > 1:
|
||||||
@@ -829,6 +841,8 @@ class FusedMoE(nn.Layer):
|
|||||||
forward_meta: ForwardMeta,
|
forward_meta: ForwardMeta,
|
||||||
topk_ids_hookfunc: Callable = None,
|
topk_ids_hookfunc: Callable = None,
|
||||||
shared_experts: nn.Layer = None,
|
shared_experts: nn.Layer = None,
|
||||||
|
fc1_latent_proj: nn.Layer = None,
|
||||||
|
fc2_latent_proj: nn.Layer = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Normal mode of forward.
|
Normal mode of forward.
|
||||||
@@ -842,7 +856,13 @@ class FusedMoE(nn.Layer):
|
|||||||
"""
|
"""
|
||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
out = self.quant_method.apply(
|
out = self.quant_method.apply(
|
||||||
self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
|
self,
|
||||||
|
x,
|
||||||
|
gate,
|
||||||
|
topk_ids_hookfunc,
|
||||||
|
shared_experts,
|
||||||
|
fc1_latent_proj,
|
||||||
|
fc2_latent_proj,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||||
|
|||||||
@@ -509,6 +509,9 @@ class TestFusedMoeTritonBackend:
|
|||||||
config,
|
config,
|
||||||
quant_config,
|
quant_config,
|
||||||
hook,
|
hook,
|
||||||
|
layer,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert "topk" in captured
|
assert "topk" in captured
|
||||||
@@ -530,6 +533,9 @@ class TestFusedMoeTritonBackend:
|
|||||||
config,
|
config,
|
||||||
quant_config,
|
quant_config,
|
||||||
None,
|
None,
|
||||||
|
layer,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert meta.shape == [2, layer.hidden_size]
|
assert meta.shape == [2, layer.hidden_size]
|
||||||
|
|||||||
@@ -506,10 +506,41 @@ class FuseMoEWrapper(paddle.nn.Layer):
|
|||||||
skip_quant=True,
|
skip_quant=True,
|
||||||
weight_dtype="float32",
|
weight_dtype="float32",
|
||||||
)
|
)
|
||||||
|
self.gating.weight.set_value(paddle.rand(self.gating.weight.shape, dtype=paddle.float32))
|
||||||
|
|
||||||
|
self.fc1_latent_proj = None
|
||||||
|
self.fc2_latent_proj = None
|
||||||
|
|
||||||
|
if self.fd_config.model_config.use_latent_moe:
|
||||||
|
self.fc1_latent_proj = ReplicatedLinear(
|
||||||
|
fd_config=self.fd_config,
|
||||||
|
input_size=self.fd_config.model_config.hidden_size,
|
||||||
|
output_size=self.fd_config.model_config.moe_latent_size,
|
||||||
|
with_bias=True,
|
||||||
|
)
|
||||||
|
self.fc1_latent_proj.weight.set_value(
|
||||||
|
paddle.zeros(self.fc1_latent_proj.weight.shape).cast(paddle.float8_e4m3fn)
|
||||||
|
)
|
||||||
|
self.fc1_latent_proj.bias.set_value(paddle.zeros(self.fc1_latent_proj.bias.shape))
|
||||||
|
|
||||||
|
self.fc2_latent_proj = ReplicatedLinear(
|
||||||
|
fd_config=self.fd_config,
|
||||||
|
input_size=self.fd_config.model_config.moe_latent_size,
|
||||||
|
output_size=self.fd_config.model_config.hidden_size,
|
||||||
|
with_bias=True,
|
||||||
|
)
|
||||||
|
self.fc2_latent_proj.weight.set_value(
|
||||||
|
paddle.zeros(self.fc2_latent_proj.weight.shape).cast(paddle.float8_e4m3fn)
|
||||||
|
)
|
||||||
|
self.fc2_latent_proj.bias.set_value(paddle.zeros(self.fc2_latent_proj.bias.shape))
|
||||||
|
|
||||||
self.fused_moe = FusedMoE(
|
self.fused_moe = FusedMoE(
|
||||||
fd_config=self.fd_config,
|
fd_config=self.fd_config,
|
||||||
hidden_size=self.fd_config.model_config.hidden_size,
|
hidden_size=(
|
||||||
|
self.fd_config.model_config.moe_latent_size
|
||||||
|
if self.fd_config.model_config.use_latent_moe
|
||||||
|
else self.fd_config.model_config.hidden_size
|
||||||
|
),
|
||||||
moe_intermediate_size=self.fd_config.model_config.moe_intermediate_size,
|
moe_intermediate_size=self.fd_config.model_config.moe_intermediate_size,
|
||||||
num_experts=self.fd_config.model_config.moe_num_experts,
|
num_experts=self.fd_config.model_config.moe_num_experts,
|
||||||
top_k=self.fd_config.model_config.moe_k,
|
top_k=self.fd_config.model_config.moe_k,
|
||||||
@@ -517,8 +548,8 @@ class FuseMoEWrapper(paddle.nn.Layer):
|
|||||||
layer_idx=666,
|
layer_idx=666,
|
||||||
weight_key_map=weight_key_map,
|
weight_key_map=weight_key_map,
|
||||||
topk_method="noaux_tc",
|
topk_method="noaux_tc",
|
||||||
topk_group=4,
|
topk_group=0,
|
||||||
n_group=8,
|
n_group=0,
|
||||||
gate_correction_bias=paddle.zeros([self.fd_config.model_config.moe_num_experts], paddle.float32),
|
gate_correction_bias=paddle.zeros([self.fd_config.model_config.moe_num_experts], paddle.float32),
|
||||||
# gate_correction_bias = gate_correction_bias_real_data
|
# gate_correction_bias = gate_correction_bias_real_data
|
||||||
)
|
)
|
||||||
@@ -558,11 +589,20 @@ class FuseMoEWrapper(paddle.nn.Layer):
|
|||||||
class TestFusedMoE(unittest.TestCase):
|
class TestFusedMoE(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.architectures = ["Ernie4_5_MoeForCausalLM"]
|
self.architectures = ["Ernie4_5_MoeForCausalLM"]
|
||||||
self.hidden_size = 4096
|
self.hidden_size = 1536
|
||||||
self.moe_intermediate_size = 2048
|
self.moe_intermediate_size = 1024
|
||||||
self.moe_num_experts = 64
|
|
||||||
self.moe_k = 8
|
self.moe_num_experts = 256
|
||||||
self.num_layers = 2
|
self.moe_k = 16
|
||||||
|
self.use_latent_moe = True
|
||||||
|
self.moe_latent_size = 768
|
||||||
|
|
||||||
|
# self.moe_num_experts = 128
|
||||||
|
# self.moe_k = 8
|
||||||
|
# self.use_latent_moe = False
|
||||||
|
# self.moe_latent_size = 768
|
||||||
|
|
||||||
|
self.num_layers = 50
|
||||||
self.num_attention_heads = -1
|
self.num_attention_heads = -1
|
||||||
self.model_config = self.build_model_config()
|
self.model_config = self.build_model_config()
|
||||||
|
|
||||||
@@ -584,6 +624,8 @@ class TestFusedMoE(unittest.TestCase):
|
|||||||
"moe_k": self.moe_k,
|
"moe_k": self.moe_k,
|
||||||
"num_attention_heads": self.num_attention_heads,
|
"num_attention_heads": self.num_attention_heads,
|
||||||
"dtype": "bfloat16",
|
"dtype": "bfloat16",
|
||||||
|
"use_latent_moe": self.use_latent_moe,
|
||||||
|
"moe_latent_size": self.moe_latent_size,
|
||||||
}
|
}
|
||||||
|
|
||||||
tmp_dir = f"./tmpwedfewfef{paddle.distributed.get_rank()}"
|
tmp_dir = f"./tmpwedfewfef{paddle.distributed.get_rank()}"
|
||||||
@@ -635,9 +677,10 @@ class TestFusedMoE(unittest.TestCase):
|
|||||||
out = cache_hidden_states + cache_hidden_states
|
out = cache_hidden_states + cache_hidden_states
|
||||||
else:
|
else:
|
||||||
gating = fused_moe[j % real_weight_layers].gating
|
gating = fused_moe[j % real_weight_layers].gating
|
||||||
gating.weight.set_value(paddle.rand(gating.weight.shape, dtype=paddle.float32))
|
fc1_latent_proj = fused_moe[j % real_weight_layers].fc1_latent_proj
|
||||||
|
fc2_latent_proj = fused_moe[j % real_weight_layers].fc2_latent_proj
|
||||||
out = fused_moe[j % real_weight_layers].fused_moe(
|
out = fused_moe[j % real_weight_layers].fused_moe(
|
||||||
cache_hidden_states[idx], gating, forward_meta=MockForwardMeta()
|
cache_hidden_states[idx], gating, MockForwardMeta(), None, fc1_latent_proj, fc2_latent_proj
|
||||||
)
|
)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|||||||
Reference in New Issue
Block a user