[OP]Unify MoE op with moe_permute path for bf16 GLM (#7164) (#7279)

This commit is contained in:
fxyfxy777
2026-04-09 21:37:42 +08:00
committed by GitHub
parent 921a0ae60b
commit dea9d35171
5 changed files with 444 additions and 69 deletions
+3 -1
View File
@@ -537,7 +537,9 @@ std::vector<paddle::Tensor> TextImageGatherScatter(
const bool is_scatter);
std::vector<paddle::Tensor> count_tokens_per_expert_func(
const paddle::Tensor& topk_ids, int64_t num_experts);
const paddle::Tensor& topk_ids,
int64_t num_experts,
bool compute_padded_cumsum = false);
void GetPositionIdsAndMaskEncoderBatch(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
+52 -18
View File
@@ -15,10 +15,11 @@
#include "helper.h"
#include "paddle/extension.h"
template <typename scalar_t>
template <typename scalar_t, bool kComputeCumsum>
__global__ void cuda_kernel(const scalar_t *__restrict__ topk_ids,
int32_t *__restrict__ res,
int32_t *__restrict__ res_padded,
int32_t *__restrict__ res_padded_cumsum,
size_t numel,
int num_experts) {
extern __shared__ int32_t tokens_per_ep[];
@@ -35,48 +36,81 @@ __global__ void cuda_kernel(const scalar_t *__restrict__ topk_ids,
__syncthreads();
for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) {
res[i] = tokens_per_ep[i];
res_padded[i] = (res[i] + 127) / 128 * 128;
if constexpr (kComputeCumsum) {
if (threadIdx.x == 0) {
int32_t running_sum = 0;
for (int i = 0; i < num_experts; i++) {
int32_t count = tokens_per_ep[i];
int32_t padded = (count + 127) / 128 * 128;
res[i] = count;
res_padded[i] = padded;
running_sum += padded;
res_padded_cumsum[i] = running_sum;
}
}
} else {
for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) {
res[i] = tokens_per_ep[i];
res_padded[i] = (tokens_per_ep[i] + 127) / 128 * 128;
}
}
}
std::vector<paddle::Tensor> count_tokens_per_expert_func(
const paddle::Tensor &topk_ids, int64_t num_experts) {
const paddle::Tensor &topk_ids,
int64_t num_experts,
bool compute_padded_cumsum) {
int topk_ids_numel = topk_ids.shape()[0] * topk_ids.shape()[1];
int64_t num_rows = compute_padded_cumsum ? 3 : 2;
auto token_nums_per_expert = paddle::empty(
{2, num_experts}, paddle::DataType::INT32, topk_ids.place());
{num_rows, num_experts}, paddle::DataType::INT32, topk_ids.place());
auto stream = topk_ids.stream();
using scalar_t = int64_t;
// CUDA_CHECK(cudaGetLastError());
cuda_kernel<<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
topk_ids.data<scalar_t>(),
token_nums_per_expert.data<int32_t>(),
token_nums_per_expert.data<int32_t>() + num_experts,
topk_ids_numel,
num_experts);
if (compute_padded_cumsum) {
cuda_kernel<scalar_t, true>
<<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
topk_ids.data<scalar_t>(),
token_nums_per_expert.data<int32_t>(),
token_nums_per_expert.data<int32_t>() + num_experts,
token_nums_per_expert.data<int32_t>() + 2 * num_experts,
topk_ids_numel,
num_experts);
} else {
cuda_kernel<scalar_t, false>
<<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
topk_ids.data<scalar_t>(),
token_nums_per_expert.data<int32_t>(),
token_nums_per_expert.data<int32_t>() + num_experts,
nullptr,
topk_ids_numel,
num_experts);
}
// CUDA_CHECK(cudaGetLastError());
return {token_nums_per_expert};
}
std::vector<paddle::DataType> count_tokens_per_expert_func_infer_dtype(
const paddle::DataType &topk_ids_dtype, int64_t num_experts) {
const paddle::DataType &topk_ids_dtype,
int64_t num_experts,
bool compute_padded_cumsum) {
return {paddle::DataType::INT32};
}
std::vector<std::vector<int64_t>> count_tokens_per_expert_func_infer_shape(
const std::vector<int64_t> &topk_ids_shape, int64_t num_experts) {
return {{2, num_experts}};
const std::vector<int64_t> &topk_ids_shape,
int64_t num_experts,
bool compute_padded_cumsum) {
int64_t num_rows = compute_padded_cumsum ? 3 : 2;
return {{num_rows, num_experts}};
}
PD_BUILD_STATIC_OP(count_tokens_per_expert_func)
.Inputs({"topk_ids"})
.Outputs({"token_nums_per_expert"})
.Attrs({"num_experts:int64_t"})
.Attrs({"num_experts:int64_t", "compute_padded_cumsum:bool"})
.SetKernelFn(PD_KERNEL(count_tokens_per_expert_func))
.SetInferShapeFn(PD_INFER_SHAPE(count_tokens_per_expert_func_infer_shape))
.SetInferDtypeFn(PD_INFER_DTYPE(count_tokens_per_expert_func_infer_dtype));
@@ -28,7 +28,11 @@ from ..utils import get_tensor, group_wise_int4_weight_quantize, pack, rotate_mo
from .fused_moe_backend_base import UnquantizedFusedMoEMethod
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
from fastdeploy.model_executor.ops.gpu import (
count_tokens_per_expert_func,
moe_expert_dispatch,
moe_expert_reduce,
)
try:
from fastdeploy.model_executor.ops.gpu import (
@@ -126,6 +130,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
# 2. EP Dispatch
dispatch_kwargs = {"expert_alignment": 128} if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE else {}
(
recv_x,
recv_topk_idx,
@@ -133,7 +138,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
recv_num_tokens_per_expert_list,
handle,
event,
) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)
) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights, **dispatch_kwargs)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
@@ -146,54 +151,91 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
# 3. Compute ffn
if token_all_num > 0:
logger.debug(f"token_all_num {token_all_num}")
(
permute_input,
permute_indices_per_token,
recv_num_tokens_per_expert_list_cumsum,
dst_weights,
dst_indices,
cumsum_idx_gpu,
expert_idx_per_token,
dequant_scale,
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
recv_x,
recv_topk_idx,
recv_topk_weights,
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
recv_num_tokens_per_expert_list,
token_all_num,
self.moe_quant_type,
)
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
# only w4a8 and w4afp8 need expert_idx_per_token
# Other need not this tensor, so we make it None.
expert_idx_per_token = None
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
# --- moe_permute / moe_unpermute path ---
recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
(permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
hidden_states=recv_x,
scale=None,
expert_routemap_topk=recv_topk_idx_i32,
expert_prob_topk=recv_topk_weights,
num_experts=layer.num_local_experts,
tokens_per_expert=[],
padding_alignment=128,
override_buffer_size=token_all_num,
)
token_nums_per_expert_cumsum = count_tokens_per_expert_func(
recv_topk_idx, layer.num_local_experts, True
)[2].cast(paddle.int64)
ffn_out = self.compute_ffn(
layer,
permute_input,
token_nums_per_expert_cumsum,
None,
False,
-1,
None,
None,
)
tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
hidden_states_unzipped=ffn_out,
zipped_expertwise_rowmap=permute_indices_per_token,
expert_routemap_topk=recv_topk_idx_i32,
token_prob_unzipped=dst_weights,
total_zipped_tokens=recv_x.shape[0],
num_experts=layer.num_local_experts,
using_weighted_combine=True,
)
else:
expert_idx_per_token = expert_idx_per_token.cast("int64")
# --- original ep_moe_expert_dispatch / combine path ---
(
permute_input,
permute_indices_per_token,
recv_num_tokens_per_expert_list_cumsum,
dst_weights,
dst_indices,
cumsum_idx_gpu,
expert_idx_per_token,
dequant_scale,
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
recv_x,
recv_topk_idx,
recv_topk_weights,
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
recv_num_tokens_per_expert_list,
token_all_num,
self.moe_quant_type,
)
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
expert_idx_per_token = None
else:
expert_idx_per_token = expert_idx_per_token.cast("int64")
if hasattr(layer, "up_gate_proj_in_scale"):
dequant_scale = None
if hasattr(layer, "up_gate_proj_in_scale"):
dequant_scale = None
ffn_out = self.compute_ffn(
layer,
permute_input,
recv_num_tokens_per_expert_list_cumsum,
expert_idx_per_token,
False,
-1,
dequant_scale,
)
ffn_out = self.compute_ffn(
layer,
permute_input,
recv_num_tokens_per_expert_list_cumsum,
expert_idx_per_token,
False,
-1,
dequant_scale,
)
# prmt back per rank
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
ffn_out,
dst_weights,
permute_indices_per_token,
dst_indices,
None, # down_proj_bias,
False, # norm_topk_prob
1.0,
)
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
ffn_out,
dst_weights,
permute_indices_per_token,
dst_indices,
None, # down_proj_bias,
False, # norm_topk_prob
1.0,
)
else:
tmp_ffn_out = recv_x
@@ -276,6 +318,69 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
"""
gate_out = gate(x)
gate_out = gate_out.cast("float32")
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
if layer.topk_method == "noaux_tc":
gate_out, topk_weights, topk_idx = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
else:
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
layer.top_k,
True, # apply_norm_weight
False,
)
topk_idx_i32 = topk_idx.astype(paddle.int32)
override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1)
(permute_input, permute_indices_per_token, dst_weights, _scale_out) = ( # zipped_expertwise_rowmap
paddle.nn.functional.moe_permute(
hidden_states=x,
scale=None,
expert_routemap_topk=topk_idx_i32,
expert_prob_topk=topk_weights,
num_experts=layer.num_experts,
tokens_per_expert=[],
padding_alignment=128,
override_buffer_size=override_buffer_size,
)
)
# Row 2 of count_tokens_per_expert_func is the prefix sum token_nums_per_expert.
token_nums_per_expert_cumsum = count_tokens_per_expert_func(topk_idx, layer.num_experts, True)[2].cast(
paddle.int64
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
ffn_out = self.compute_ffn(
layer,
permute_input,
token_nums_per_expert_cumsum,
None, # expert_idx_per_token not needed for w16a16 without bias
False,
-1,
None, # dequant_scale
None, # max_tokens_per_expert
)
fused_moe_out, _out_probs = paddle.nn.functional.moe_unpermute(
hidden_states_unzipped=ffn_out,
zipped_expertwise_rowmap=permute_indices_per_token,
expert_routemap_topk=topk_idx_i32,
token_prob_unzipped=dst_weights,
total_zipped_tokens=x.shape[0],
num_experts=layer.num_experts,
using_weighted_combine=True,
)
return fused_moe_out
if layer.topk_method == "noaux_tc":
gate_out, topk_weights, topk_idx = get_moe_scores(
gate_out,
@@ -287,6 +392,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
getattr(layer, "renormalize", True),
topk_reduce_func=getattr(layer, "topk_reduce_func", None),
)
(
permute_input,
token_nums_per_expert,
@@ -341,7 +447,6 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
expert_idx_per_token = None
else:
expert_idx_per_token = expert_idx_per_token.cast("int64")
ffn_out = self.compute_ffn(
layer,
permute_input,
@@ -363,7 +468,6 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
norm_topk_prob=False if layer.topk_method == "noaux_tc" else True,
routed_scaling_factor=1.0,
)
return fused_moe_out
@@ -521,7 +521,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
else:
token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts, False)
(
permute_input,
permute_scale,
@@ -805,7 +805,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
else:
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts, False)
(
permute_input,
permute_scale,
@@ -35,6 +35,7 @@ iluvatar_stub.paged_attention = lambda *args, **kwargs: None
iluvatar_stub.prefill_fused_paged_attention = lambda *args, **kwargs: None
sys.modules["fastdeploy.model_executor.ops.iluvatar"] = iluvatar_stub
import fastdeploy # noqa: E402
from fastdeploy.model_executor.layers import utils as layer_utils
from fastdeploy.model_executor.layers.moe import fused_moe_cutlass_backend as backend
@@ -709,3 +710,237 @@ class TestFusedMoeCutlassBackend:
int4_method.create_weights(
int4_layer, num_experts=2, hidden_size=4, moe_intermediate_size=2, model_format="paddle"
)
# ---------------------------------------------------------------------------
# Real-op tests for FD_USE_PHI_MOE_PERMUTE=True (w16a16, moe_permute path)
# ---------------------------------------------------------------------------
from fastdeploy.platforms import current_platform # noqa: E402
_CUDA_AVAILABLE = current_platform.is_cuda()
requires_cuda = pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA required")
class RealMoELayer(paddle.nn.Layer):
"""Minimal bf16 MoE layer with real weights for moe_permute path testing."""
def __init__(self, num_experts=4, hidden_size=64, moe_intermediate_size=32, top_k=2):
super().__init__()
self.fd_config = DummyFDConfig()
self.num_experts = num_experts
self.num_local_experts = num_experts
self.hidden_size = hidden_size
self.moe_intermediate_size = moe_intermediate_size
self.top_k = top_k
self.topk_method = "noaux_tc"
self.n_group = 1
self.topk_group = 1
self.routed_scaling_factor = 1.0
self.with_bias = False
self.ep_size = 1
self.ep_rank = 0
self.layer_idx = 0
self.weight_dtype = "bfloat16"
self.is_quantized = False
self.activation = "swiglu"
self.moe_quant_config = types.SimpleNamespace(moe_dynamic_quant=False, hadamard_block_size=128)
self.gate_correction_bias = self.create_parameter(
shape=[1, num_experts],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
)
paddle.seed(0)
self.up_gate_proj_weight = self.create_parameter(
shape=[num_experts, 2 * moe_intermediate_size, hidden_size],
dtype="bfloat16",
)
self.down_proj_weight = self.create_parameter(
shape=[num_experts, hidden_size, moe_intermediate_size],
dtype="bfloat16",
)
self.up_gate_proj_weight.set_value(
paddle.randn([num_experts, 2 * moe_intermediate_size, hidden_size]).cast("bfloat16") * 0.01
)
self.down_proj_weight.set_value(
paddle.randn([num_experts, hidden_size, moe_intermediate_size]).cast("bfloat16") * 0.01
)
class SimpleLinearGate(paddle.nn.Layer):
def __init__(self, hidden_size, num_experts):
super().__init__()
self.weight = self.create_parameter(shape=[hidden_size, num_experts], dtype="float32")
def forward(self, x):
return paddle.matmul(x.cast("float32"), self.weight)
class TestMoePermuteTrueRealOps:
"""Real-op tests for FD_USE_PHI_MOE_PERMUTE=True on the w16a16 path."""
def _build(self, num_experts=4, hidden_size=64, moe_intermediate_size=32, top_k=2):
layer = RealMoELayer(
num_experts=num_experts,
hidden_size=hidden_size,
moe_intermediate_size=moe_intermediate_size,
top_k=top_k,
)
gate = SimpleLinearGate(hidden_size, num_experts)
method = backend.CutlassMoEMethod(None)
method.moe_quant_type = "w16a16"
return layer, gate, method
@requires_cuda
def test_apply_tp_moe_permute_real_ops(self, monkeypatch):
"""FD_USE_PHI_MOE_PERMUTE=True + w16a16: real moe_permute/moe_unpermute/
count_tokens_per_expert_func/moe_expert_ffn all called end-to-end."""
monkeypatch.setattr(backend.fastdeploy.envs, "FD_USE_PHI_MOE_PERMUTE", True)
num_tokens, hidden_size = 8, 64
layer, gate, method = self._build(hidden_size=hidden_size)
paddle.seed(42)
x = paddle.randn([num_tokens, hidden_size], dtype="bfloat16")
# Spy: confirm moe_permute is called, moe_expert_dispatch is NOT
permute_called = {"v": False}
dispatch_called = {"v": False}
original_permute = paddle.nn.functional.moe_permute
def spy_permute(*args, **kwargs):
permute_called["v"] = True
return original_permute(*args, **kwargs)
monkeypatch.setattr(paddle.nn.functional, "moe_permute", spy_permute)
monkeypatch.setattr(
backend,
"moe_expert_dispatch",
lambda *a, **kw: (_ for _ in ()).throw(AssertionError("moe_expert_dispatch must not be called")),
)
out = method.apply_tp(layer, x, gate)
assert permute_called["v"], "moe_permute was not called"
assert not dispatch_called["v"], "moe_expert_dispatch must not be called"
assert list(out.shape) == [num_tokens, hidden_size], f"wrong output shape: {out.shape}"
assert not paddle.isnan(out).any(), "output contains NaN"
assert not paddle.isinf(out).any(), "output contains Inf"
@requires_cuda
def test_apply_ep_prefill_moe_permute_real_ops(self, monkeypatch):
"""FD_USE_PHI_MOE_PERMUTE=True + w16a16: EP prefill uses real moe_permute /
moe_unpermute / count_tokens_per_expert_func / moe_expert_ffn end-to-end.
The EP dispatch/combine are stubbed (no real NCCL needed).
Use num_tokens=128 and num_experts=4 so each expert gets exactly 64 tokens
(128 * top_k=2 / 4 experts = 64), satisfying moe_expert_ffn alignment."""
monkeypatch.setattr(backend.fastdeploy.envs, "FD_USE_PHI_MOE_PERMUTE", True)
# 128 tokens, top_k=2, 4 experts → 64 tokens/expert (128-aligned after padding)
num_tokens, hidden_size = 128, 64
layer, gate, method = self._build(num_experts=4, hidden_size=hidden_size, top_k=2)
paddle.seed(42)
x = paddle.randn([num_tokens, hidden_size], dtype="bfloat16")
# Stub only the EP communication runner (dispatch/combine).
# All on-device compute (moe_permute, moe_expert_ffn, moe_unpermute) runs for real.
class StubEPRunner:
ep_engine = types.SimpleNamespace(async_finish=False)
def moe_select(self, _layer, gate_out):
n = gate_out.shape[0]
# Route token i to experts (i % E) and ((i+1) % E) so all experts
# get tokens and recv_num_tokens_per_expert_list is accurate.
E = _layer.num_local_experts
idx0 = paddle.arange(n, dtype="int64") % E
idx1 = (paddle.arange(n, dtype="int64") + 1) % E
topk_ids = paddle.stack([idx0, idx1], axis=1)
topk_weights = paddle.ones([n, _layer.top_k], dtype="float32") / _layer.top_k
return topk_ids, topk_weights
def dispatch(self, x, topk_idx, topk_weights, **kwargs):
# Pass tensors through unchanged — single-rank, no real communication.
# Compute accurate recv_num_tokens_per_expert_list from topk_idx.
E = layer.num_local_experts
counts = [int((topk_idx == e).sum().item()) for e in range(E)]
return (
x,
topk_idx,
topk_weights,
counts,
object(),
types.SimpleNamespace(current_stream_wait=lambda: None),
)
def combine(self, ffn_out, handle, recv_topk_weights):
return ffn_out, types.SimpleNamespace(current_stream_wait=lambda: None)
method.ep_prefill_runner = StubEPRunner()
# Spy: confirm moe_permute is called inside ep_prefill
permute_called = {"v": False}
original_permute = paddle.nn.functional.moe_permute
def spy_permute(*args, **kwargs):
permute_called["v"] = True
return original_permute(*args, **kwargs)
monkeypatch.setattr(paddle.nn.functional, "moe_permute", spy_permute)
out = method.apply_ep_prefill(layer, x, gate)
assert permute_called["v"], "moe_permute was not called in ep_prefill path"
assert len(out.shape) == 2, f"wrong output ndim: {out.shape}"
assert out.shape[1] == hidden_size, f"wrong hidden_size: {out.shape}"
assert not paddle.isnan(out).any(), "output contains NaN"
assert not paddle.isinf(out).any(), "output contains Inf"
@requires_cuda
def test_apply_tp_moe_permute_non_noaux_tc(self, monkeypatch):
"""FD_USE_PHI_MOE_PERMUTE=True + w16a16 + topk_method != 'noaux_tc':
the else-branch calls moe_topk_select instead of get_moe_scores,
then proceeds through moe_permute / moe_expert_ffn / moe_unpermute."""
monkeypatch.setattr(backend.fastdeploy.envs, "FD_USE_PHI_MOE_PERMUTE", True)
num_tokens, hidden_size = 8, 64
layer, gate, method = self._build(hidden_size=hidden_size)
# Switch to non-noaux_tc to exercise the else-branch (moe_topk_select)
layer.topk_method = "greedy"
paddle.seed(7)
x = paddle.randn([num_tokens, hidden_size], dtype="bfloat16")
# Spy on which routing function is invoked
get_moe_scores_called = {"v": False}
moe_topk_select_called = {"v": False}
permute_called = {"v": False}
original_get_moe_scores = backend.get_moe_scores
original_moe_topk_select = fastdeploy.model_executor.ops.gpu.moe_topk_select
original_permute = paddle.nn.functional.moe_permute
def spy_get_moe_scores(*args, **kwargs):
get_moe_scores_called["v"] = True
return original_get_moe_scores(*args, **kwargs)
def spy_moe_topk_select(*args, **kwargs):
moe_topk_select_called["v"] = True
return original_moe_topk_select(*args, **kwargs)
def spy_permute(*args, **kwargs):
permute_called["v"] = True
return original_permute(*args, **kwargs)
monkeypatch.setattr(backend, "get_moe_scores", spy_get_moe_scores)
monkeypatch.setattr(fastdeploy.model_executor.ops.gpu, "moe_topk_select", spy_moe_topk_select)
monkeypatch.setattr(paddle.nn.functional, "moe_permute", spy_permute)
out = method.apply_tp(layer, x, gate)
assert not get_moe_scores_called["v"], "get_moe_scores must NOT be called for non-noaux_tc"
assert moe_topk_select_called["v"], "moe_topk_select must be called for non-noaux_tc"
assert permute_called["v"], "moe_permute must be called"
assert list(out.shape) == [num_tokens, hidden_size], f"wrong shape: {out.shape}"
assert not paddle.isnan(out).any(), "output contains NaN"
assert not paddle.isinf(out).any(), "output contains Inf"