[RL]moe bf16 ep support paddle batch_gemm (#7337)

* moe bf16 ep support paddle batch_gemm
Author: chen
Date: 2026-04-11 21:51:12 +08:00
Committed by: GitHub
Parent: ba01d7a823
Commit: 4982aa000e
2 changed files with 25 additions and 17 deletions
+11 -5
@@ -40,6 +40,10 @@ from fastdeploy.model_executor.layers import utils as layer_utils
 from fastdeploy.model_executor.layers.moe import fused_moe_cutlass_backend as backend
+
+
+def align(x, y):
+    return (x + y - 1) // y * y
 
 
 class DummyQuantConfig:
     def __init__(self, algo="weight_only_int8", is_quantized=False, is_checkpoint_bf16=False):
         self.algo = algo
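
The align helper added here rounds its first argument up to the nearest multiple of the second; it is used further down to pad per-expert token counts to the requested expert_alignment. A minimal standalone sketch of that behavior (illustrative only, not part of the diff):

def align(x, y):
    return (x + y - 1) // y * y

assert align(0, 8) == 0
assert align(3, 8) == 8
assert align(8, 8) == 8
assert align(9, 8) == 16
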
@@ -752,18 +756,18 @@ class RealMoELayer(paddle.nn.Layer):
         )
 
         paddle.seed(0)
         self.up_gate_proj_weight = self.create_parameter(
-            shape=[num_experts, 2 * moe_intermediate_size, hidden_size],
+            shape=[num_experts, hidden_size, 2 * moe_intermediate_size],
             dtype="bfloat16",
         )
         self.down_proj_weight = self.create_parameter(
-            shape=[num_experts, hidden_size, moe_intermediate_size],
+            shape=[num_experts, moe_intermediate_size, hidden_size],
             dtype="bfloat16",
         )
 
         self.up_gate_proj_weight.set_value(
-            paddle.randn([num_experts, 2 * moe_intermediate_size, hidden_size]).cast("bfloat16") * 0.01
+            paddle.randn([num_experts, hidden_size, 2 * moe_intermediate_size]).cast("bfloat16") * 0.01
         )
         self.down_proj_weight.set_value(
-            paddle.randn([num_experts, hidden_size, moe_intermediate_size]).cast("bfloat16") * 0.01
+            paddle.randn([num_experts, moe_intermediate_size, hidden_size]).cast("bfloat16") * 0.01
         )
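
The shape change above stores each expert's weights in [in_features, out_features] order ([num_experts, hidden_size, 2 * moe_intermediate_size] for up_gate_proj and [num_experts, moe_intermediate_size, hidden_size] for down_proj), which is the operand layout a batched GEMM consumes directly. Below is a minimal sketch of that layout with paddle.bmm, using float32 and illustrative sizes rather than the test's actual configuration; how the backend invokes its batch_gemm is not shown in this diff.

import paddle

num_experts, hidden_size, moe_intermediate_size = 4, 8, 16
tokens_per_expert = 3

# x holds each expert's routed tokens: [E, M, K]; weights are stored as [E, K, N].
x = paddle.randn([num_experts, tokens_per_expert, hidden_size])
w_up_gate = paddle.randn([num_experts, hidden_size, 2 * moe_intermediate_size])

# Batched GEMM over the expert dimension, no per-expert transpose needed.
out = paddle.bmm(x, w_up_gate)
assert out.shape == [num_experts, tokens_per_expert, 2 * moe_intermediate_size]
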
@@ -863,7 +867,9 @@ class TestMoePermuteTrueRealOps:
             # Pass tensors through unchanged — single-rank, no real communication.
             # Compute accurate recv_num_tokens_per_expert_list from topk_idx.
             E = layer.num_local_experts
-            counts = [int((topk_idx == e).sum().item()) for e in range(E)]
+            counts = [
+                align(int((topk_idx == e).sum().item()), kwargs.get("expert_alignment", 1)) for e in range(E)
+            ]
             return (
                 x,
                 topk_idx,