add m_grouped_bf16_gemm_nn_contiguous, del paddle.batch_gemm (#7527)

commit c618a39562
parent 8883757bad
Author: chen
Date:   2026-04-21 20:19:46 +08:00
Committed by GitHub
@@ -53,6 +53,12 @@ from fastdeploy.model_executor.utils import (
 )
+
+
+def m_grouped_bf16_gemm_nn_contiguous(x, y, expert_idx_per_token):
+    out = paddle.empty([x.shape[0], y.shape[-1]], dtype=x.dtype)
+    paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, y, out, expert_idx_per_token)
+    return out
 
 
 class CutlassMoEMethod(UnquantizedFusedMoEMethod):
     """
     Use Cutlass Group Gemm to compute Fused MoE.
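
For reference, the snippet below is an unfused, pure-Paddle sketch of what the new helper computes, assuming `x` holds permuted activations with each expert's rows contiguous, `y` stacks one [k, n] weight matrix per expert, and `expert_idx_per_token` gives the expert id of each row; the per-expert loop is illustrative only, not the fused kernel's implementation.

import paddle

def grouped_gemm_nn_reference(x, y, expert_idx_per_token):
    # x: [m, k] bf16 activations, rows grouped contiguously by expert
    # y: [num_experts, k, n] bf16 per-expert weights (NN layout: no transpose)
    # expert_idx_per_token: [m] int32 expert id for each row of x
    out = paddle.empty([x.shape[0], y.shape[-1]], dtype=x.dtype)
    for e in range(y.shape[0]):
        rows = paddle.nonzero(expert_idx_per_token == e).flatten()
        if rows.shape[0] == 0:
            continue  # no rows routed to this expert
        seg = paddle.index_select(x, rows)
        out = paddle.scatter(out, rows, paddle.matmul(seg, y[e]))
    return out

The contiguity of each expert's rows is what lets the fused kernel replace this per-expert loop with a single grouped GEMM launch.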
@@ -156,7 +162,8 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
             # --- moe_permute / moe_unpermute path ---
             recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
-            (permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
+            (permute_input, permute_indices_per_token, dst_weights, _scale_out, m_indices) = (
+                paddle.nn.functional.moe_permute(
                 hidden_states=recv_x,
                 scale=None,
                 expert_routemap_topk=recv_topk_idx_i32,
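
The next hunk consumes the new `m_indices` output enabled by `return_expert_indices=True`. Intuitively, `m_indices` labels each row of the permuted buffer with its expert id; a hypothetical construction from per-expert token counts, for illustration only and not moe_permute's internals:

import paddle

# If three experts receive 3, 0, and 2 tokens respectively, the permuted
# rows are [e0, e0, e0, e2, e2], and m_indices labels each row accordingly.
tokens_per_expert = paddle.to_tensor([3, 0, 2], dtype="int32")
expert_ids = paddle.arange(3, dtype="int32")
m_indices = paddle.repeat_interleave(expert_ids, tokens_per_expert)
print(m_indices.numpy())  # [0 0 0 2 2]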
@@ -165,17 +172,31 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
                 tokens_per_expert=[],
                 padding_alignment=128,
                 override_buffer_size=token_all_num,
+                return_expert_indices=True,
+                )
             )
-            out = paddle.incubate.nn.functional.batched_gemm(
-                permute_input,
-                getattr(layer, self.added_weight_attrs[0]),
-                recv_num_tokens_per_expert_list,
-            )
+            if paddlefleet_ops is not None:
+                out = m_grouped_bf16_gemm_nn_contiguous(
+                    permute_input, getattr(layer, self.added_weight_attrs[0]), m_indices
+                )
+            else:
+                out = paddle.incubate.nn.functional.batched_gemm(
+                    permute_input,
+                    getattr(layer, self.added_weight_attrs[0]),
+                    recv_num_tokens_per_expert_list,
+                )
             if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE:
                 out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights)
             else:
                 out = paddle.incubate.nn.functional.swiglu(out)
-            ffn_out = paddle.incubate.nn.functional.batched_gemm(
-                out,
-                getattr(layer, self.added_weight_attrs[1]),
+            if paddlefleet_ops is not None:
+                ffn_out = m_grouped_bf16_gemm_nn_contiguous(
+                    out, getattr(layer, self.added_weight_attrs[1]), m_indices
+                )
+            else:
+                ffn_out = paddle.incubate.nn.functional.batched_gemm(
+                    out,
+                    getattr(layer, self.added_weight_attrs[1]),
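
The `paddlefleet_ops is not None` checks above rely on the optional-import convention; a minimal sketch of that convention, assuming the guard lives near the top of the module (the guard itself is not part of this diff):

# Sketch of the optional-import guard assumed by the dispatch above; when the
# package is unavailable, paddlefleet_ops stays None and both GEMMs fall back
# to paddle.incubate.nn.functional.batched_gemm.
try:
    import paddlefleet_ops
except ImportError:
    paddlefleet_ops = None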