diff --git a/custom_ops/gpu_ops/moe/prefill_permute_to_masked_gemm.cu b/custom_ops/gpu_ops/moe/prefill_permute_to_masked_gemm.cu index cc831e881c..ca12a12687 100644 --- a/custom_ops/gpu_ops/moe/prefill_permute_to_masked_gemm.cu +++ b/custom_ops/gpu_ops/moe/prefill_permute_to_masked_gemm.cu @@ -43,10 +43,11 @@ __global__ void PrefillPermuteToMaskedGemmKernel( smem_topk_ids[tidx] = topk_ids[token_idx * TOP_K + tidx]; } __syncthreads(); - + bool should_break = true; for (int slot = 0; slot < TOP_K; slot++) { int64_t expert_idx = smem_topk_ids[slot]; if (expert_idx != -1) { + should_break = false; if (tidx == 0) { smem_offset = atomicAdd(&token_nums_per_expert[expert_idx], 1); permuted_indice_map[token_idx * TOP_K + slot] = static_cast( @@ -89,6 +90,9 @@ __global__ void PrefillPermuteToMaskedGemmKernel( __syncthreads(); } } + if (should_break) { + break; + } } } @@ -117,10 +121,6 @@ std::vector PrefillPermuteToMaskedGemmDispatch( auto permute_x = GetEmptyTensor( {num_local_experts, max_token_num, hidden}, x.dtype(), place); - // Allocate permute_scale with transposed physical layout [E, S, M] - // but logical shape [E, M, S] with strides [S*M, 1, M] - // This matches the CuteDSL version's paddle.empty([E, S, M]).transpose((0, 2, - // 1)) auto permute_scale = GetEmptyTensor({num_local_experts, max_token_num, hidden_scale}, {static_cast(hidden_scale) * max_token_num, diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index a21efb0c85..66bb8bfeef 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -650,9 +650,12 @@ class EPPrefillRunner(EPRunner): "expert_alignment": expert_alignment, "allocate_on_comm_stream": EPPrefillRunner.allocate_on_comm_stream, "previous_event": event, - "num_worst_tokens": self.num_worst_tokens, - "skip_x_record_stream": self.num_worst_tokens > 0, } + + if envs.FD_USE_PFCC_DEEP_EP: + dispatch_args["num_worst_tokens"] = self.num_worst_tokens + dispatch_args["skip_x_record_stream"] = self.num_worst_tokens > 0 + return buffer.dispatch(**dispatch_args) def combine( @@ -674,8 +677,11 @@ class EPPrefillRunner(EPRunner): "topk_weights": recv_topk_weights, "previous_event": event, "allocate_on_comm_stream": EPPrefillRunner.allocate_on_comm_stream, - "skip_x_record_stream": self.num_worst_tokens > 0, } + + if envs.FD_USE_PFCC_DEEP_EP: + combine_args["skip_x_record_stream"] = self.num_worst_tokens > 0 + fused_moe_out, _, event = buffer.combine(**combine_args) return fused_moe_out, event diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index d21d88623f..cd24e9cba9 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -521,7 +521,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): cumsum_idx_gpu, m_indices, ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8( - recv_x, + recv_x_value, recv_x_scale, recv_topk_idx, recv_topk_weights, diff --git a/tests/operators/test_permute_prefill_masked_gemm.py b/tests/operators/test_permute_prefill_masked_gemm.py index 3da75901a4..89c28dab74 100644 --- a/tests/operators/test_permute_prefill_masked_gemm.py +++ b/tests/operators/test_permute_prefill_masked_gemm.py @@ -113,6 +113,17 @@ class TestPrefillPermuteToMaskedGemm(unittest.TestCase): topk_ids_np[i, :] = experts mask = np.random.rand(num_tokens, topk) < sparsity topk_ids_np[mask] = -1 + + # The kernel breaks early when a block encounters an all-(-1) row, + # so valid rows must come first in token order. + # Sort rows: rows with at least one valid expert first, all-(-1) rows last. + valid_mask = (topk_ids_np >= 0).any(axis=1) + sorted_idx = np.concatenate([np.where(valid_mask)[0], np.where(~valid_mask)[0]]) + topk_ids_np = topk_ids_np[sorted_idx] + x_np = x_np[sorted_idx] + scale_np = scale_np[sorted_idx] + x = paddle.to_tensor(x_np).cast(x_dtype) + scale = paddle.to_tensor(scale_np).cast(scale_dtype).contiguous() topk_ids = paddle.to_tensor(topk_ids_np).cast(paddle.int64) permute_x, permute_scale, permuted_indice_map, token_nums_per_expert = call_prefill_permute_to_masked_gemm(