add m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_custom_python_op (#5847)

Author: Ryan
Date: 2026-01-07 16:17:55 +08:00
Committed by: GitHub
Parent: eabd01cd21
Commit: 3e74bacc5e
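This commit moves the DeepGEMM FP8 MoE FFN pipeline (up_gate_proj grouped GEMM, swiglu, per-token quantization, down_proj grouped GEMM) out of DeepGemmFusedMoeMethod.apply and into a registered custom Python op, so the whole sequence is visible to the framework as a single op with static-graph shape inference.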
@@ -24,12 +24,94 @@ from paddleformers.utils.log import logger
import fastdeploy
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm
from fastdeploy.utils import register_custom_python_op
from fastdeploy.worker.tbo import let_another_thread_run
from .fused_moe_backend_base import MoEMethodBase
from .fused_moe_triton_backend import BlockWiseFP8MoEMethod


def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_custom_python_op_infermeta(
permute_input: "paddle.static.MetaTensor",
permute_scale: "paddle.static.MetaTensor",
layer_added_weight_attrs_0: "paddle.static.MetaTensor",
layer_added_scale_attrs_0: "paddle.static.MetaTensor",
m_indices: "paddle.static.MetaTensor",
layer_added_weight_attrs_1: "paddle.static.MetaTensor",
layer_added_scale_attrs_1: "paddle.static.MetaTensor",
quant_config_weight_block_size_0: int,
):
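    # Single bf16 output of shape [num_padded_tokens, down_proj_output_dim],
    # i.e. the shape of the tensor the eager pipeline below ultimately returns.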
return paddle.static.MetaTensor(
shape=[permute_input.shape[0], layer_added_weight_attrs_1.shape[1]], dtype=paddle.bfloat16
)


@register_custom_python_op(
name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_custom",
infer_meta=m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_custom_python_op_infermeta,
input_names=[
"permute_input",
"permute_scale",
"layer_added_weight_attrs_0",
"layer_added_scale_attrs_0",
"m_indices",
"layer_added_weight_attrs_1",
"layer_added_scale_attrs_1",
],
output_names=["ffn_new_out"],
inplace_map={},
)
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_custom_python_op(
permute_input: paddle.Tensor,
permute_scale: paddle.Tensor,
layer_added_weight_attrs_0: paddle.Tensor, # getattr(layer, self.added_weight_attrs[0])
layer_added_scale_attrs_0: paddle.Tensor, # getattr(layer, self.added_scale_attrs[0])
m_indices: paddle.Tensor,
layer_added_weight_attrs_1: paddle.Tensor, # getattr(layer, self.added_weight_attrs[1])
layer_added_scale_attrs_1: paddle.Tensor, # getattr(layer, self.added_scale_attrs[1])
quant_config_weight_block_size_0: int, # self.quant_config.weight_block_size[0]
):
# up_gate_proj
ffn_out = paddle.empty(
(permute_input.shape[0], layer_added_weight_attrs_0.shape[1]),
dtype=paddle.bfloat16,
)
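    # Round-trip transpose with contiguous() in between: the logical shape is
    # unchanged, but the scales are re-laid-out column-major in memory, which is
    # the layout the deep_gemm grouped kernels expect for the scale operand.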
permute_scale = permute_scale.transpose([1, 0]).contiguous()
permute_scale = permute_scale.transpose([1, 0])
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(permute_input, permute_scale),
(layer_added_weight_attrs_0, layer_added_scale_attrs_0),
ffn_out,
m_indices,
)
# swiglu
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out)
# down_proj
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, quant_config_weight_block_size_0
)
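    # Same column-major layout trick as above, now for the per-token scales
    # feeding the down_proj grouped GEMM.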
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous()
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
ffn_out = paddle.empty(
(permute_input.shape[0], layer_added_weight_attrs_1.shape[1]),
dtype=paddle.bfloat16,
)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(ffn_in_x, ffn_in_x_scale_tensor),
(layer_added_weight_attrs_1, layer_added_scale_attrs_1),
ffn_out,
m_indices,
)
return ffn_out


class DeepGemmFusedMoeMethod(MoEMethodBase):
"""
    DeepGemmFusedMoeMethod implements the MoEMethodBase interface for the DeepGEMM backend.
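A note on the registration above: register_custom_python_op takes the op name, an infer_meta callback for static-graph shape/dtype inference, the tensor input/output names, and an inplace map. quant_config_weight_block_size_0 is deliberately absent from input_names, since it is a plain int attribute rather than a tensor. A minimal sketch of the same pattern, with hypothetical names and assuming only the decorator signature visible in this diff:

import paddle
from fastdeploy.utils import register_custom_python_op

def scale_by_two_infermeta(x: "paddle.static.MetaTensor"):
    # Output mirrors the input shape; dtype pinned, as in the bf16 op above.
    return paddle.static.MetaTensor(shape=x.shape, dtype=paddle.float32)

@register_custom_python_op(
    name="scale_by_two_custom",  # hypothetical op, for illustration only
    infer_meta=scale_by_two_infermeta,
    input_names=["x"],
    output_names=["out"],
    inplace_map={},
)
def scale_by_two_custom_python_op(x: paddle.Tensor):
    return x * 2.0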
@@ -368,9 +450,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
gate_out = gate(x.cast("float32"))
if layer.topk_method == "noaux_tc":
-            from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
-
-            _, topk_weights, topk_ids = get_moe_scores(
+            _, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
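The only change in this hunk is that get_moe_scores is now reached through its fully qualified module path rather than a function-local import, presumably so the lookup stays on the module object when this code is traced or captured alongside the new custom op.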
@@ -416,41 +496,17 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
-1,
)
-        permute_scale = permute_scale.transpose([1, 0]).contiguous()
-        permute_scale = permute_scale.transpose([1, 0])
-        # up_gate_proj
-        ffn_out = paddle.empty(
-            (permute_input.shape[0], getattr(layer, self.added_weight_attrs[0]).shape[1]),
-            dtype=paddle.bfloat16,
-        )
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (permute_input, permute_scale),
-            (getattr(layer, self.added_weight_attrs[0]), getattr(layer, self.added_scale_attrs[0])),
-            ffn_out,
-            m_indices,
-        )
-        # swiglu
-        ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out)
-        # down_proj
-        ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
-            ffn_out, self.quant_config.weight_block_size[0]
-        )
-        ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous()
-        ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
-        ffn_out = paddle.empty(
-            (ffn_out.shape[0], getattr(layer, self.added_weight_attrs[1]).shape[1]),
-            dtype=paddle.bfloat16,
-        )
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (ffn_in_x, ffn_in_x_scale_tensor),
-            (getattr(layer, self.added_weight_attrs[1]), getattr(layer, self.added_scale_attrs[1])),
-            ffn_out,
-            m_indices,
-        )
+        ffn_out = m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_custom_python_op(
+            permute_input,
+            permute_scale,
+            getattr(layer, self.added_weight_attrs[0]),
+            getattr(layer, self.added_scale_attrs[0]),
+            m_indices,
+            getattr(layer, self.added_weight_attrs[1]),
+            getattr(layer, self.added_scale_attrs[1]),
+            self.quant_config.weight_block_size[0],
+        )
# prmt back per rank
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
ffn_out,
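For intuition, a toy trace of the shapes through the fused op, with hypothetical sizes (and assuming the [num_experts, N, K] weight layout that DeepGEMM's contiguous grouped GEMM uses):

M, K, E, I = 1024, 7168, 8, 2048  # padded tokens, hidden size, experts, intermediate (hypothetical)
up_gate_w_shape = (E, 2 * I, K)   # fused up+gate projection weight
down_w_shape = (E, K, I)          # down projection weight

x_shape = (M, K)                  # permute_input
h_shape = (M, 2 * I)              # after the first grouped GEMM
h_shape = (M, I)                  # after swiglu (one half gates the other)
y_shape = (M, K)                  # after the second grouped GEMM; matches the
                                  # infermeta shape [M, down_w_shape[1]]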