diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
index c758a153ab..58755f52e2 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
@@ -53,6 +53,12 @@ from fastdeploy.model_executor.utils import (
 )
 
 
+def m_grouped_bf16_gemm_nn_contiguous(x, y, expert_idx_per_token):
+    out = paddle.empty([x.shape[0], y.shape[-1]], dtype=x.dtype)
+    paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, y, out, expert_idx_per_token)
+    return out
+
+
 class CutlassMoEMethod(UnquantizedFusedMoEMethod):
     """
     Use Cutlass Group Gemm to compute Fused MoE.
@@ -156,31 +162,46 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
             # --- moe_permute / moe_unpermute path ---
             recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
-            (permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
-                hidden_states=recv_x,
-                scale=None,
-                expert_routemap_topk=recv_topk_idx_i32,
-                expert_prob_topk=recv_topk_weights,
-                num_experts=layer.num_local_experts,
-                tokens_per_expert=[],
-                padding_alignment=128,
-                override_buffer_size=token_all_num,
+            (permute_input, permute_indices_per_token, dst_weights, _scale_out, m_indices) = (
+                paddle.nn.functional.moe_permute(
+                    hidden_states=recv_x,
+                    scale=None,
+                    expert_routemap_topk=recv_topk_idx_i32,
+                    expert_prob_topk=recv_topk_weights,
+                    num_experts=layer.num_local_experts,
+                    tokens_per_expert=[],
+                    padding_alignment=128,
+                    override_buffer_size=token_all_num,
+                    return_expert_indices=True,
+                )
             )
-            out = paddle.incubate.nn.functional.batched_gemm(
-                permute_input,
-                getattr(layer, self.added_weight_attrs[0]),
-                recv_num_tokens_per_expert_list,
-            )
+            if paddlefleet_ops is not None:
+                out = m_grouped_bf16_gemm_nn_contiguous(
+                    permute_input, getattr(layer, self.added_weight_attrs[0]), m_indices
+                )
+            else:
+                out = paddle.incubate.nn.functional.batched_gemm(
+                    permute_input,
+                    getattr(layer, self.added_weight_attrs[0]),
+                    recv_num_tokens_per_expert_list,
+                )
+
             if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE:
                 out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights)
             else:
                 out = paddle.incubate.nn.functional.swiglu(out)
-            ffn_out = paddle.incubate.nn.functional.batched_gemm(
-                out,
-                getattr(layer, self.added_weight_attrs[1]),
-                recv_num_tokens_per_expert_list,
-            )
+
+            if paddlefleet_ops is not None:
+                ffn_out = m_grouped_bf16_gemm_nn_contiguous(
+                    out, getattr(layer, self.added_weight_attrs[1]), m_indices
+                )
+            else:
+                ffn_out = paddle.incubate.nn.functional.batched_gemm(
+                    out,
+                    getattr(layer, self.added_weight_attrs[1]),
+                    recv_num_tokens_per_expert_list,
+                )
             tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
                 hidden_states_unzipped=ffn_out,
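
For context, a minimal standalone sketch of the dispatch pattern this patch introduces: prefer the contiguous grouped BF16 GEMM from `paddlefleet_ops.deep_gemm` when that optional dependency is available, and fall back to `paddle.incubate.nn.functional.batched_gemm` otherwise. The wrapper name `grouped_gemm_or_fallback` and the bare `try/except` import are illustrative assumptions, not part of the patch; in the actual module, `paddlefleet_ops` is resolved elsewhere and may be `None`.

```python
# Illustrative sketch only (not part of the patch). Assumes `paddlefleet_ops`
# resolves to None when the optional dependency is absent; the wrapper name
# `grouped_gemm_or_fallback` is hypothetical.
import paddle

try:
    import paddlefleet_ops
except ImportError:
    paddlefleet_ops = None


def grouped_gemm_or_fallback(x, weight, m_indices, tokens_per_expert):
    """Expert GEMM: DeepGEMM grouped kernel if present, batched GEMM otherwise."""
    if paddlefleet_ops is not None:
        # Contiguous grouped GEMM: each row of `x` is multiplied by the weight
        # slice of the expert given by the per-token index tensor `m_indices`.
        out = paddle.empty([x.shape[0], weight.shape[-1]], dtype=x.dtype)
        paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, weight, out, m_indices)
        return out
    # Fallback: one GEMM per expert, driven by the per-expert token counts.
    return paddle.incubate.nn.functional.batched_gemm(x, weight, tokens_per_expert)
```

Both FFN stages in the patched path use this same shape of dispatch, so the permuted activations flow through either backend unchanged; only the grouping mechanism differs (per-token expert indices vs. per-expert token counts).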