[RL]moe bf16 ep support paddle batch_gemm (#7337)

* moe bf16 ep support paddle batch_gemm
This commit is contained in:
chen
2026-04-11 21:51:12 +08:00
committed by GitHub
parent ba01d7a823
commit 4982aa000e
2 changed files with 25 additions and 17 deletions
@@ -43,6 +43,7 @@ if current_platform.is_cuda():
logger.warning("import w4afp8_gemm_scale_permute Failed!")
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.layers.quantization.fp8_utils import paddlefleet_ops
from fastdeploy.model_executor.utils import (
TensorTracker,
free_tensor,
@@ -166,18 +167,19 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
override_buffer_size=token_all_num,
)
token_nums_per_expert_cumsum = count_tokens_per_expert_func(
recv_topk_idx, layer.num_local_experts, True
)[2].cast(paddle.int64)
ffn_out = self.compute_ffn(
layer,
out = paddle.incubate.nn.functional.batched_gemm(
permute_input,
token_nums_per_expert_cumsum,
None,
False,
-1,
None,
None,
getattr(layer, self.added_weight_attrs[0]),
recv_num_tokens_per_expert_list,
)
if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE:
out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights)
else:
out = paddle.incubate.nn.functional.swiglu(out)
ffn_out = paddle.incubate.nn.functional.batched_gemm(
out,
getattr(layer, self.added_weight_attrs[1]),
recv_num_tokens_per_expert_list,
)
tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
@@ -187,7 +189,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
token_prob_unzipped=dst_weights,
total_zipped_tokens=recv_x.shape[0],
num_experts=layer.num_local_experts,
using_weighted_combine=True,
using_weighted_combine=not fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE,
)
else:
# --- original ep_moe_expert_dispatch / combine path ---