mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
[RL]moe bf16 ep support paddle batch_gemm (#7337)
* moe bf16 ep support paddle batch_gemm
This commit is contained in:
@@ -43,6 +43,7 @@ if current_platform.is_cuda():
|
||||
logger.warning("import w4afp8_gemm_scale_permute Failed!")
|
||||
|
||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
||||
from fastdeploy.model_executor.layers.quantization.fp8_utils import paddlefleet_ops
|
||||
from fastdeploy.model_executor.utils import (
|
||||
TensorTracker,
|
||||
free_tensor,
|
||||
@@ -166,18 +167,19 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
override_buffer_size=token_all_num,
|
||||
)
|
||||
|
||||
token_nums_per_expert_cumsum = count_tokens_per_expert_func(
|
||||
recv_topk_idx, layer.num_local_experts, True
|
||||
)[2].cast(paddle.int64)
|
||||
ffn_out = self.compute_ffn(
|
||||
layer,
|
||||
out = paddle.incubate.nn.functional.batched_gemm(
|
||||
permute_input,
|
||||
token_nums_per_expert_cumsum,
|
||||
None,
|
||||
False,
|
||||
-1,
|
||||
None,
|
||||
None,
|
||||
getattr(layer, self.added_weight_attrs[0]),
|
||||
recv_num_tokens_per_expert_list,
|
||||
)
|
||||
if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE:
|
||||
out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights)
|
||||
else:
|
||||
out = paddle.incubate.nn.functional.swiglu(out)
|
||||
ffn_out = paddle.incubate.nn.functional.batched_gemm(
|
||||
out,
|
||||
getattr(layer, self.added_weight_attrs[1]),
|
||||
recv_num_tokens_per_expert_list,
|
||||
)
|
||||
|
||||
tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
|
||||
@@ -187,7 +189,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
token_prob_unzipped=dst_weights,
|
||||
total_zipped_tokens=recv_x.shape[0],
|
||||
num_experts=layer.num_local_experts,
|
||||
using_weighted_combine=True,
|
||||
using_weighted_combine=not fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE,
|
||||
)
|
||||
else:
|
||||
# --- original ep_moe_expert_dispatch / combine path ---
|
||||
|
||||
@@ -40,6 +40,10 @@ from fastdeploy.model_executor.layers import utils as layer_utils
|
||||
from fastdeploy.model_executor.layers.moe import fused_moe_cutlass_backend as backend
|
||||
|
||||
|
||||
def align(x, y):
    """Round ``x`` up to a multiple of ``y``.

    For a non-negative integer ``x`` and a positive integer ``y`` this is
    the smallest multiple of ``y`` that is greater than or equal to ``x``
    (so ``x`` is returned unchanged when it is already a multiple).

    Args:
        x: Value to round up.
        y: Alignment step; ``y == 0`` raises ``ZeroDivisionError``.

    Returns:
        ``x`` rounded up to a multiple of ``y``.
    """
    # Adding ``y - 1`` before the floor division turns it into a ceiling
    # division for positive ``y``.
    bumped = x + y - 1
    steps = bumped // y
    return steps * y
|
||||
|
||||
|
||||
class DummyQuantConfig:
|
||||
def __init__(self, algo="weight_only_int8", is_quantized=False, is_checkpoint_bf16=False):
|
||||
self.algo = algo
|
||||
@@ -752,18 +756,18 @@ class RealMoELayer(paddle.nn.Layer):
|
||||
)
|
||||
paddle.seed(0)
|
||||
self.up_gate_proj_weight = self.create_parameter(
|
||||
shape=[num_experts, 2 * moe_intermediate_size, hidden_size],
|
||||
shape=[num_experts, hidden_size, 2 * moe_intermediate_size],
|
||||
dtype="bfloat16",
|
||||
)
|
||||
self.down_proj_weight = self.create_parameter(
|
||||
shape=[num_experts, hidden_size, moe_intermediate_size],
|
||||
shape=[num_experts, moe_intermediate_size, hidden_size],
|
||||
dtype="bfloat16",
|
||||
)
|
||||
self.up_gate_proj_weight.set_value(
|
||||
paddle.randn([num_experts, 2 * moe_intermediate_size, hidden_size]).cast("bfloat16") * 0.01
|
||||
paddle.randn([num_experts, hidden_size, 2 * moe_intermediate_size]).cast("bfloat16") * 0.01
|
||||
)
|
||||
self.down_proj_weight.set_value(
|
||||
paddle.randn([num_experts, hidden_size, moe_intermediate_size]).cast("bfloat16") * 0.01
|
||||
paddle.randn([num_experts, moe_intermediate_size, hidden_size]).cast("bfloat16") * 0.01
|
||||
)
|
||||
|
||||
|
||||
@@ -863,7 +867,9 @@ class TestMoePermuteTrueRealOps:
|
||||
# Pass tensors through unchanged — single-rank, no real communication.
|
||||
# Compute accurate recv_num_tokens_per_expert_list from topk_idx.
|
||||
E = layer.num_local_experts
|
||||
counts = [int((topk_idx == e).sum().item()) for e in range(E)]
|
||||
counts = [
|
||||
align(int((topk_idx == e).sum().item()), kwargs.get("expert_alignment", 1)) for e in range(E)
|
||||
]
|
||||
return (
|
||||
x,
|
||||
topk_idx,
|
||||
|
||||
Reference in New Issue
Block a user