mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
add m_grouped_bf16_gemm_nn_contiguous, del paddle.batch_gemm (#7527)
This commit is contained in:
@@ -53,6 +53,12 @@ from fastdeploy.model_executor.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def m_grouped_bf16_gemm_nn_contiguous(x, y, expert_idx_per_token):
    """
    Call the paddlefleet deep_gemm grouped GEMM kernel and return its output buffer.

    Args:
        x: Left-hand operand; its first dimension gives the number of output rows,
            and its dtype is used for the output.
        y: Right-hand operand; its last dimension gives the number of output columns.
        expert_idx_per_token: Forwarded unchanged to the kernel as its last argument
            (presumably one expert index per row of ``x`` -- confirm against the
            paddlefleet_ops kernel).

    Returns:
        The tensor of shape ``[x.shape[0], y.shape[-1]]`` allocated here and handed
        to the kernel as its third argument.
    """
    num_rows = x.shape[0]
    num_cols = y.shape[-1]
    # The buffer is allocated without initialization; the kernel's return value is
    # ignored, so the result is expected to be written into this buffer in place.
    result = paddle.empty([num_rows, num_cols], dtype=x.dtype)
    paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, y, result, expert_idx_per_token)
    return result
|
||||||
|
|
||||||
|
|
||||||
class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||||
"""
|
"""
|
||||||
Use Cutlass Group Gemm to compute Fused MoE.
|
Use Cutlass Group Gemm to compute Fused MoE.
|
||||||
@@ -156,7 +162,8 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
|
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
|
||||||
# --- moe_permute / moe_unpermute path ---
|
# --- moe_permute / moe_unpermute path ---
|
||||||
recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
|
recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
|
||||||
(permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
|
(permute_input, permute_indices_per_token, dst_weights, _scale_out, m_indices) = (
|
||||||
|
paddle.nn.functional.moe_permute(
|
||||||
hidden_states=recv_x,
|
hidden_states=recv_x,
|
||||||
scale=None,
|
scale=None,
|
||||||
expert_routemap_topk=recv_topk_idx_i32,
|
expert_routemap_topk=recv_topk_idx_i32,
|
||||||
@@ -165,17 +172,31 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
tokens_per_expert=[],
|
tokens_per_expert=[],
|
||||||
padding_alignment=128,
|
padding_alignment=128,
|
||||||
override_buffer_size=token_all_num,
|
override_buffer_size=token_all_num,
|
||||||
|
return_expert_indices=True,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if paddlefleet_ops is not None:
|
||||||
|
out = m_grouped_bf16_gemm_nn_contiguous(
|
||||||
|
permute_input, getattr(layer, self.added_weight_attrs[0]), m_indices
|
||||||
|
)
|
||||||
|
else:
|
||||||
out = paddle.incubate.nn.functional.batched_gemm(
|
out = paddle.incubate.nn.functional.batched_gemm(
|
||||||
permute_input,
|
permute_input,
|
||||||
getattr(layer, self.added_weight_attrs[0]),
|
getattr(layer, self.added_weight_attrs[0]),
|
||||||
recv_num_tokens_per_expert_list,
|
recv_num_tokens_per_expert_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE:
|
if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE:
|
||||||
out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights)
|
out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights)
|
||||||
else:
|
else:
|
||||||
out = paddle.incubate.nn.functional.swiglu(out)
|
out = paddle.incubate.nn.functional.swiglu(out)
|
||||||
|
|
||||||
|
if paddlefleet_ops is not None:
|
||||||
|
ffn_out = m_grouped_bf16_gemm_nn_contiguous(
|
||||||
|
out, getattr(layer, self.added_weight_attrs[1]), m_indices
|
||||||
|
)
|
||||||
|
else:
|
||||||
ffn_out = paddle.incubate.nn.functional.batched_gemm(
|
ffn_out = paddle.incubate.nn.functional.batched_gemm(
|
||||||
out,
|
out,
|
||||||
getattr(layer, self.added_weight_attrs[1]),
|
getattr(layer, self.added_weight_attrs[1]),
|
||||||
|
|||||||
Reference in New Issue
Block a user