add m_grouped_bf16_gemm_nn_contiguous, del paddle.batch_gemm (#7527)

Author: chen (committed by GitHub)
Date: 2026-04-21 20:19:46 +08:00
Parent: 8883757bad
Commit: c618a39562
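
Summary (inferred from the diff below): add a thin wrapper, `m_grouped_bf16_gemm_nn_contiguous`, over the `paddlefleet_ops.deep_gemm` grouped BF16 GEMM kernel, and use it for the two expert GEMMs on the Cutlass MoE w16a16 `moe_permute` path. `moe_permute` is now called with `return_expert_indices=True` so the per-token expert indices (`m_indices`) can drive the grouped kernel. The previous `paddle.incubate.nn.functional.batched_gemm` call is kept as a fallback for builds where `paddlefleet_ops` is unavailable.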
@@ -53,6 +53,12 @@ from fastdeploy.model_executor.utils import (
 )
 
 
+def m_grouped_bf16_gemm_nn_contiguous(x, y, expert_idx_per_token):
+    out = paddle.empty([x.shape[0], y.shape[-1]], dtype=x.dtype)
+    paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, y, out, expert_idx_per_token)
+    return out
+
+
 class CutlassMoEMethod(UnquantizedFusedMoEMethod):
     """
     Use Cutlass Group Gemm to compute Fused MoE.
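
For context, here is a minimal pure-Paddle sketch of what the wrapped kernel is expected to compute. The shapes and layout are assumptions read off the call sites, not the kernel's documented contract: `x` is `[total_tokens, hidden]` with rows grouped contiguously by expert (the layout `moe_permute` produces), `y` stacks per-expert weights as `[num_experts, hidden, out_dim]`, and `expert_idx_per_token` holds one expert id per row. The real path pads each group to `padding_alignment=128`, which this sketch ignores:

```python
import paddle


def m_grouped_bf16_gemm_nn_contiguous_ref(x, y, expert_idx_per_token):
    # Assumed semantics: out[i] = x[i] @ y[expert_idx_per_token[i]].
    # "grouped ... contiguous": rows routed to the same expert occupy one
    # contiguous slice of x, so each group reduces to a single plain NN GEMM
    # (neither operand transposed).
    counts = paddle.bincount(expert_idx_per_token, minlength=y.shape[0])
    chunks, start = [], 0
    for e, m_e in enumerate(counts.tolist()):
        if m_e == 0:
            continue  # expert received no tokens this step
        chunks.append(paddle.matmul(x[start : start + m_e], y[e]))
        start += m_e
    return paddle.concat(chunks, axis=0)
```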
@@ -156,31 +162,46 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
             # --- moe_permute / moe_unpermute path ---
             recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
-            (permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
-                hidden_states=recv_x,
-                scale=None,
-                expert_routemap_topk=recv_topk_idx_i32,
-                expert_prob_topk=recv_topk_weights,
-                num_experts=layer.num_local_experts,
-                tokens_per_expert=[],
-                padding_alignment=128,
-                override_buffer_size=token_all_num,
+            (permute_input, permute_indices_per_token, dst_weights, _scale_out, m_indices) = (
+                paddle.nn.functional.moe_permute(
+                    hidden_states=recv_x,
+                    scale=None,
+                    expert_routemap_topk=recv_topk_idx_i32,
+                    expert_prob_topk=recv_topk_weights,
+                    num_experts=layer.num_local_experts,
+                    tokens_per_expert=[],
+                    padding_alignment=128,
+                    override_buffer_size=token_all_num,
+                    return_expert_indices=True,
+                )
             )
-            out = paddle.incubate.nn.functional.batched_gemm(
-                permute_input,
-                getattr(layer, self.added_weight_attrs[0]),
-                recv_num_tokens_per_expert_list,
-            )
+            if paddlefleet_ops is not None:
+                out = m_grouped_bf16_gemm_nn_contiguous(
+                    permute_input, getattr(layer, self.added_weight_attrs[0]), m_indices
+                )
+            else:
+                out = paddle.incubate.nn.functional.batched_gemm(
+                    permute_input,
+                    getattr(layer, self.added_weight_attrs[0]),
+                    recv_num_tokens_per_expert_list,
+                )
             if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE:
                 out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights)
             else:
                 out = paddle.incubate.nn.functional.swiglu(out)
-            ffn_out = paddle.incubate.nn.functional.batched_gemm(
-                out,
-                getattr(layer, self.added_weight_attrs[1]),
-                recv_num_tokens_per_expert_list,
-            )
+
+            if paddlefleet_ops is not None:
+                ffn_out = m_grouped_bf16_gemm_nn_contiguous(
+                    out, getattr(layer, self.added_weight_attrs[1]), m_indices
+                )
+            else:
+                ffn_out = paddle.incubate.nn.functional.batched_gemm(
+                    out,
+                    getattr(layer, self.added_weight_attrs[1]),
+                    recv_num_tokens_per_expert_list,
+                )
             tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
                 hidden_states_unzipped=ffn_out,
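
The `paddlefleet_ops is not None` dispatch implies a guarded import along these lines elsewhere in the module (a sketch; the actual import site is outside this hunk), so the w16a16 path degrades to `batched_gemm` rather than failing when the fused kernels are missing:

```python
try:
    import paddlefleet_ops  # provides the deep_gemm grouped kernels used above
except ImportError:
    paddlefleet_ops = None  # fall back to paddle.incubate.nn.functional.batched_gemm
```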