mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
add m_grouped_bf16_gemm_nn_contiguous, del paddle.batch_gemm (#7527)
This commit is contained in:
@@ -53,6 +53,12 @@ from fastdeploy.model_executor.utils import (
|
||||
)
|
||||
|
||||
|
||||
def m_grouped_bf16_gemm_nn_contiguous(x, y, expert_idx_per_token):
|
||||
out = paddle.empty([x.shape[0], y.shape[-1]], dtype=x.dtype)
|
||||
paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, y, out, expert_idx_per_token)
|
||||
return out
|
||||
|
||||
|
||||
class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
"""
|
||||
Use Cutlass Group Gemm to compute Fused MoE.
|
||||
@@ -156,31 +162,46 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
|
||||
# --- moe_permute / moe_unpermute path ---
|
||||
recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
|
||||
(permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
|
||||
hidden_states=recv_x,
|
||||
scale=None,
|
||||
expert_routemap_topk=recv_topk_idx_i32,
|
||||
expert_prob_topk=recv_topk_weights,
|
||||
num_experts=layer.num_local_experts,
|
||||
tokens_per_expert=[],
|
||||
padding_alignment=128,
|
||||
override_buffer_size=token_all_num,
|
||||
(permute_input, permute_indices_per_token, dst_weights, _scale_out, m_indices) = (
|
||||
paddle.nn.functional.moe_permute(
|
||||
hidden_states=recv_x,
|
||||
scale=None,
|
||||
expert_routemap_topk=recv_topk_idx_i32,
|
||||
expert_prob_topk=recv_topk_weights,
|
||||
num_experts=layer.num_local_experts,
|
||||
tokens_per_expert=[],
|
||||
padding_alignment=128,
|
||||
override_buffer_size=token_all_num,
|
||||
return_expert_indices=True,
|
||||
)
|
||||
)
|
||||
|
||||
out = paddle.incubate.nn.functional.batched_gemm(
|
||||
permute_input,
|
||||
getattr(layer, self.added_weight_attrs[0]),
|
||||
recv_num_tokens_per_expert_list,
|
||||
)
|
||||
if paddlefleet_ops is not None:
|
||||
out = m_grouped_bf16_gemm_nn_contiguous(
|
||||
permute_input, getattr(layer, self.added_weight_attrs[0]), m_indices
|
||||
)
|
||||
else:
|
||||
out = paddle.incubate.nn.functional.batched_gemm(
|
||||
permute_input,
|
||||
getattr(layer, self.added_weight_attrs[0]),
|
||||
recv_num_tokens_per_expert_list,
|
||||
)
|
||||
|
||||
if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE:
|
||||
out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights)
|
||||
else:
|
||||
out = paddle.incubate.nn.functional.swiglu(out)
|
||||
ffn_out = paddle.incubate.nn.functional.batched_gemm(
|
||||
out,
|
||||
getattr(layer, self.added_weight_attrs[1]),
|
||||
recv_num_tokens_per_expert_list,
|
||||
)
|
||||
|
||||
if paddlefleet_ops is not None:
|
||||
ffn_out = m_grouped_bf16_gemm_nn_contiguous(
|
||||
out, getattr(layer, self.added_weight_attrs[1]), m_indices
|
||||
)
|
||||
else:
|
||||
ffn_out = paddle.incubate.nn.functional.batched_gemm(
|
||||
out,
|
||||
getattr(layer, self.added_weight_attrs[1]),
|
||||
recv_num_tokens_per_expert_list,
|
||||
)
|
||||
|
||||
tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
|
||||
hidden_states_unzipped=ffn_out,
|
||||
|
||||
Reference in New Issue
Block a user