[BugFix] Fix EP compatibility issues & optimize permute operator (#6821)

* fix EP compatibility issues & optimize permute operator

* fix ut

* fix ut
This commit is contained in:
RichardWooSJTU
2026-03-17 10:32:11 +08:00
committed by GitHub
parent a6351dea0b
commit 4ed483d20b
4 changed files with 26 additions and 9 deletions
@@ -43,10 +43,11 @@ __global__ void PrefillPermuteToMaskedGemmKernel(
smem_topk_ids[tidx] = topk_ids[token_idx * TOP_K + tidx];
}
__syncthreads();
bool should_break = true;
for (int slot = 0; slot < TOP_K; slot++) {
int64_t expert_idx = smem_topk_ids[slot];
if (expert_idx != -1) {
should_break = false;
if (tidx == 0) {
smem_offset = atomicAdd(&token_nums_per_expert[expert_idx], 1);
permuted_indice_map[token_idx * TOP_K + slot] = static_cast<int32_t>(
@@ -89,6 +90,9 @@ __global__ void PrefillPermuteToMaskedGemmKernel(
__syncthreads();
}
}
if (should_break) {
break;
}
}
}
@@ -117,10 +121,6 @@ std::vector<paddle::Tensor> PrefillPermuteToMaskedGemmDispatch(
auto permute_x = GetEmptyTensor(
{num_local_experts, max_token_num, hidden}, x.dtype(), place);
// Allocate permute_scale with transposed physical layout [E, S, M]
// but logical shape [E, M, S] with strides [S*M, 1, M]
// This matches the CuteDSL version's paddle.empty([E, S, M]).transpose((0, 2,
// 1))
auto permute_scale =
GetEmptyTensor({num_local_experts, max_token_num, hidden_scale},
{static_cast<int64_t>(hidden_scale) * max_token_num,