add m_grouped_bf16_gemm_nn_contiguous, del paddle.batch_gemm (#7527)

This commit is contained in:
chen
2026-04-21 20:19:46 +08:00
committed by GitHub
parent 8883757bad
commit c618a39562
@@ -53,6 +53,12 @@ from fastdeploy.model_executor.utils import (
)
def m_grouped_bf16_gemm_nn_contiguous(x, y, expert_idx_per_token):
out = paddle.empty([x.shape[0], y.shape[-1]], dtype=x.dtype)
paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, y, out, expert_idx_per_token)
return out
class CutlassMoEMethod(UnquantizedFusedMoEMethod):
"""
Use Cutlass Group Gemm to compute Fused MoE.
@@ -156,31 +162,46 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
# --- moe_permute / moe_unpermute path ---
recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
(permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
hidden_states=recv_x,
scale=None,
expert_routemap_topk=recv_topk_idx_i32,
expert_prob_topk=recv_topk_weights,
num_experts=layer.num_local_experts,
tokens_per_expert=[],
padding_alignment=128,
override_buffer_size=token_all_num,
(permute_input, permute_indices_per_token, dst_weights, _scale_out, m_indices) = (
paddle.nn.functional.moe_permute(
hidden_states=recv_x,
scale=None,
expert_routemap_topk=recv_topk_idx_i32,
expert_prob_topk=recv_topk_weights,
num_experts=layer.num_local_experts,
tokens_per_expert=[],
padding_alignment=128,
override_buffer_size=token_all_num,
return_expert_indices=True,
)
)
out = paddle.incubate.nn.functional.batched_gemm(
permute_input,
getattr(layer, self.added_weight_attrs[0]),
recv_num_tokens_per_expert_list,
)
if paddlefleet_ops is not None:
out = m_grouped_bf16_gemm_nn_contiguous(
permute_input, getattr(layer, self.added_weight_attrs[0]), m_indices
)
else:
out = paddle.incubate.nn.functional.batched_gemm(
permute_input,
getattr(layer, self.added_weight_attrs[0]),
recv_num_tokens_per_expert_list,
)
if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE:
out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights)
else:
out = paddle.incubate.nn.functional.swiglu(out)
ffn_out = paddle.incubate.nn.functional.batched_gemm(
out,
getattr(layer, self.added_weight_attrs[1]),
recv_num_tokens_per_expert_list,
)
if paddlefleet_ops is not None:
ffn_out = m_grouped_bf16_gemm_nn_contiguous(
out, getattr(layer, self.added_weight_attrs[1]), m_indices
)
else:
ffn_out = paddle.incubate.nn.functional.batched_gemm(
out,
getattr(layer, self.added_weight_attrs[1]),
recv_num_tokens_per_expert_list,
)
tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
hidden_states_unzipped=ffn_out,