[Feature] Support EP prefill with num_worst_tokens (#6574)

* support num worst tokens

* support num worst tokens

* fix build error

* support num worst tokens: fix errors

* support num worst tokens: fix field

* support num worst tokens: delete requirements

* replace permute and depermute op by pure cuda

* replace permute and depermute op by pure cuda

* fix ci

* fix op

* fix nan

* fix code style

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
RichardWooSJTU
2026-03-11 17:09:07 +08:00
committed by GitHub
parent 0466c7e8a8
commit 9f0778f991
21 changed files with 1775 additions and 166 deletions
+35 -2
@@ -313,9 +313,11 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
const int token_nums_this_rank_padded);
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
const int block_size);
const int block_size,
const bool use_ue8m0);
std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
const int block_size);
const int block_size,
const bool use_ue8m0);
std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
paddle::Tensor& input,
@@ -1175,6 +1177,19 @@ std::vector<paddle::Tensor> get_attn_mask_q(
const paddle::optional<paddle::Tensor>& attn_mask_kv,
const int kv_token_num);
std::vector<paddle::Tensor> PrefillPermuteToMaskedGemm(
const paddle::Tensor& x,
const paddle::Tensor& scale,
const paddle::Tensor& topk_ids,
const int num_local_experts,
const int max_token_num);
std::vector<paddle::Tensor> DepermutePrefillCombine(
const paddle::Tensor& x,
const paddle::Tensor& indice_map,
const paddle::Tensor& topk_weights,
const int num_worst_tokens);
void RadixTopkRaggedTransform(
paddle::Tensor& input,
paddle::Tensor& output_indices,
@@ -1367,12 +1382,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&PerTokenQuant,
py::arg("input"),
py::arg("block_size"),
py::arg("use_ue8m0"),
"per token per block quant");
m.def("per_token_quant_padding",
&PerTokenQuantPadding,
py::arg("input"),
py::arg("block_size"),
py::arg("use_ue8m0"),
"per token per block quant and padding transpose scale");
m.def("fused_mask_swiglu_fp8_quant",
@@ -1826,6 +1843,22 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("custom_numpy_to_tensor",
&CustomNumpyToTensor,
"custom_numpy_to_tensor function");
m.def("prefill_permute_to_masked_gemm",
&PrefillPermuteToMaskedGemm,
py::arg("x"),
py::arg("scale"),
py::arg("topk_ids"),
py::arg("num_local_experts"),
py::arg("max_token_num"),
"Prefill permute to masked GEMM for MoE");
m.def("depermute_prefill_combine",
&DepermutePrefillCombine,
py::arg("x"),
py::arg("indice_map"),
py::arg("topk_weights"),
py::arg("num_worst_tokens"),
"Depermute and combine expert outputs for MoE prefill");
m.def("radix_topk_ragged_transform",
&RadixTopkRaggedTransform,
@@ -20,10 +20,18 @@ __host__ __device__ __forceinline__ int ceil_div(int x, int y) {
return (x + y - 1) / y;
}
__host__ __device__ __forceinline__ int64_t ceil_div(int64_t x, int64_t y) {
return (x + y - 1) / y;
}
__host__ __device__ __forceinline__ int align(int x, int y) {
return ceil_div(x, y) * y;
}
__host__ __device__ __forceinline__ int64_t align(int64_t x, int64_t y) {
return ceil_div(x, y) * y;
}
#ifndef BOOL_SWITCH
#define BOOL_SWITCH(cond, name, ...) \
if (cond) { \
@@ -41,10 +49,10 @@ __global__ void fused_swiglu_fp8_quant_kernel(
const index_t* __restrict__ token_nums_per_expert,
phi::dtype::float8_e4m3fn* __restrict__ out_fp8,
ScaleT* __restrict__ out_scale,
int group_num,
int group_size,
int hidden_size,
int hidden_size_scale,
int64_t group_num,
int64_t group_size,
int64_t hidden_size,
int64_t hidden_size_scale,
bool use_finegrained_range) {
constexpr int BLOCK = 128;
@@ -53,7 +61,7 @@ __global__ void fused_swiglu_fp8_quant_kernel(
int warp = tid >> 5;
int num_warps = blockDim.x >> 5;
int block_id = static_cast<int64_t>(blockIdx.x);
int64_t block_id = static_cast<int64_t>(blockIdx.x);
using VecBF16 = AlignedVector<T, 4>;
VecBF16 x1_vec, x2_vec;
@@ -62,13 +70,13 @@ __global__ void fused_swiglu_fp8_quant_kernel(
while (true) {
// ================= token mapping =================
int expert = -1;
int token_in_expert = -1;
int64_t expert = -1;
int64_t token_in_expert = -1;
if (lane == 0) {
int cumsum = 0;
for (int i = 0; i < group_num; ++i) {
int cnt = token_nums_per_expert[i];
int64_t cumsum = 0;
for (int64_t i = 0; i < group_num; ++i) {
int64_t cnt = static_cast<int64_t>(token_nums_per_expert[i]);
if (block_id >= cumsum && block_id < cumsum + cnt) {
expert = i;
token_in_expert = block_id - cumsum;
@@ -84,17 +92,17 @@ __global__ void fused_swiglu_fp8_quant_kernel(
if (expert < 0 || token_in_expert >= group_size) break;
// ================= base pointers =================
int token = expert * group_size + token_in_expert;
int64_t token = expert * group_size + token_in_expert;
const T* in = input + token * hidden_size * 2;
auto* out = out_fp8 + token * hidden_size;
int num_iters = hidden_size / BLOCK;
int64_t num_iters = hidden_size / BLOCK;
// ================= main loop =================
for (int iter = warp; iter < num_iters; iter += num_warps) {
int base = iter * BLOCK + lane * 4;
for (int64_t iter = warp; iter < num_iters; iter += num_warps) {
int64_t base = iter * BLOCK + lane * 4;
// vec load
Load(in + base, &x1_vec);
@@ -141,20 +149,20 @@ __global__ void fused_swiglu_fp8_quant_kernel(
const int exp = (__float_as_int(scale) >> 23) & 0xFF;
// 2. pack information
const int pack_idx = iter >> 2; // iter / 4
const int byte_idx = iter & 3; // iter % 4
const int64_t pack_idx = iter >> 2; // iter / 4
const int64_t byte_idx = iter & 3; // iter % 4
// 3. layout parameters
const int pack_num = ceil_div(hidden_size_scale, 4);
const int token_stride = align(group_size, 4);
const int64_t pack_num = ceil_div(hidden_size_scale, (int64_t)4);
const int64_t token_stride = align(group_size, (int64_t)4);
// 4. base pointer (int32 pack)
auto* scale_pack = reinterpret_cast<int32_t*>(out_scale);
// 5. column-major offset:
// [expert][pack][token]
const int base_idx = expert * pack_num * token_stride +
pack_idx * token_stride + token_in_expert;
const int64_t base_idx = expert * pack_num * token_stride +
pack_idx * token_stride + token_in_expert;
// 6. write one byte into pack
reinterpret_cast<uint8_t*>(&scale_pack[base_idx])[byte_idx] =
static_cast<uint8_t>(exp);
@@ -184,11 +192,11 @@ std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
const int block_size,
const bool use_ue8m0) {
auto dim = input.dims();
const int group_num = token_nums_per_expert.shape()[0];
const int group_size = dim[1];
const int hidden_size = dim[2] / 2;
const int hidden_size_scale = hidden_size / block_size;
const int token_num = group_num * group_size;
const int64_t group_num = token_nums_per_expert.shape()[0];
const int64_t group_size = dim[1];
const int64_t hidden_size = dim[2] / 2;
const int64_t hidden_size_scale = hidden_size / block_size;
const int64_t token_num = group_num * group_size;
auto out_fp8 = GetEmptyTensor({group_num, group_size, hidden_size},
paddle::DataType::FLOAT8_E4M3FN,
@@ -200,21 +208,22 @@ std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
paddle::DataType::FLOAT32,
input.place());
if (use_ue8m0) {
int hidden_size_scale_pack = ceil_div(hidden_size_scale, 4);
out_scale = GetEmptyTensor({group_num, group_size, hidden_size_scale_pack},
{hidden_size_scale_pack * align(group_size, 4),
1,
align(group_size, 4)},
paddle::DataType::INT32,
input.place());
int64_t hidden_size_scale_pack = ceil_div(hidden_size_scale, (int64_t)4);
int64_t group_size_aligned = align(group_size, (int64_t)4);
out_scale = GetEmptyTensor(
{group_num, group_size, hidden_size_scale_pack},
{hidden_size_scale_pack * group_size_aligned, 1, group_size_aligned},
paddle::DataType::INT32,
input.place());
}
int sm_count = 0;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, 0);
constexpr int BLOCKS_PER_SM = 2;
int gridx = std::min(sm_count * BLOCKS_PER_SM, token_num);
int blockx = std::min(1024, hidden_size / 128 * 32);
int gridx =
std::min(static_cast<int64_t>(sm_count * BLOCKS_PER_SM), token_num);
int blockx = std::min(1024L, hidden_size / 128 * 32);
bool use_finegrained_range = false;
if (auto* env = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE"))
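A hedged NumPy sketch (not part of the patch) of how the packed UE8M0 scales written above could be read back: per expert, the kernel stores one biased-exponent byte per 128-wide block in a column-major [expert][pack][token] layout, four bytes per int32. The helper name is illustrative; the input is assumed to be the raw physical int32 buffer.
import numpy as np
def unpack_fused_swiglu_ue8m0_scales(packed_phys, group_size, hidden_size_scale):
    # packed_phys: raw int32 buffer, physical shape
    # [group_num, pack_num, align(group_size, 4)], pack_num = ceil(hidden_size_scale / 4);
    # each int32 packs 4 exponent bytes for 4 consecutive scale blocks of one token.
    E, pack_num, M_pad = packed_phys.shape
    exps = packed_phys.view(np.uint8).reshape(E, pack_num, M_pad, 4)
    exps = exps.transpose(0, 2, 1, 3).reshape(E, M_pad, pack_num * 4)
    exps = exps[:, :group_size, :hidden_size_scale].astype(np.float32)
    return np.exp2(exps - 127.0)  # power-of-two scale = 2**(biased_exp - 127)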
+7
@@ -243,6 +243,13 @@ class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
typedef __nv_fp8_e4m3 DataType;
typedef paddle::float8_e4m3fn data_t;
};
template <>
class PDTraits<paddle::DataType::INT32> {
public:
typedef int32_t DataType;
typedef int32_t data_t;
};
#endif
template <typename T, int Size>
@@ -0,0 +1,218 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
constexpr int DEPERMUTE_BLOCK_THREADS = 512;
// Depermute and combine expert outputs back to token-major layout.
//
// For each output token, this kernel:
// 1. Loads indice_map and topk_weights from shared memory
// 2. For each valid expert slot (indice >= 0), reads the expert output row,
// scales by topk_weight, and accumulates in float32
// 3. Writes the combined result back in the original dtype
//
// x: [num_experts, max_tokens_per_expert, hidden] - expert outputs
// indice_map: [num_worst_tokens, topk] int32 - flat indices (expert_idx * M + offset)
// topk_weights: [num_worst_tokens, topk] float32 - combination weights
// depermuted_x: [num_worst_tokens, hidden] - output (same dtype as x)
template <typename T, int VecSize, int TOP_K>
__global__ void DepermutePrefillCombineKernel(
T* __restrict__ depermuted_x,
const T* __restrict__ x,
const int32_t* __restrict__ indice_map,
const float* __restrict__ topk_weights,
const int num_worst_tokens,
const int hidden,
const int max_tokens_per_expert) {
__shared__ int32_t smem_indices[TOP_K];
__shared__ float smem_weights[TOP_K];
const int tidx = threadIdx.x;
const int num_vecs = hidden / VecSize;
for (int token_idx = blockIdx.x; token_idx < num_worst_tokens;
token_idx += gridDim.x) {
// Threads [0, TOP_K) load indice_map; threads [TOP_K, 2*TOP_K) load topk_weights
if (tidx < TOP_K) {
smem_indices[tidx] = indice_map[token_idx * TOP_K + tidx];
}
if (tidx >= TOP_K && tidx < 2 * TOP_K) {
int k = tidx - TOP_K;
smem_weights[k] = topk_weights[token_idx * TOP_K + k];
}
__syncthreads();
// Check if any expert slot is valid
bool need_store = false;
#pragma unroll
for (int k = 0; k < TOP_K; k++) {
if (smem_indices[k] >= 0) {
need_store = true;
break;
}
}
if (need_store) {
// Each thread processes a subset of hidden vectors
for (int v = tidx; v < num_vecs; v += DEPERMUTE_BLOCK_THREADS) {
// Initialize accumulator in float32
float acc[VecSize];
#pragma unroll
for (int i = 0; i < VecSize; i++) {
acc[i] = 0.0f;
}
// Accumulate weighted contributions from each expert
for (int k = 0; k < TOP_K; k++) {
int32_t indice = smem_indices[k];
if (indice >= 0) {
float weight = smem_weights[k];
int64_t expert_idx =
static_cast<int64_t>(indice) / max_tokens_per_expert;
int64_t offset =
static_cast<int64_t>(indice) % max_tokens_per_expert;
const T* src = x + expert_idx * max_tokens_per_expert * hidden +
offset * hidden;
AlignedVector<T, VecSize> vec;
Load<T, VecSize>(src + v * VecSize, &vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
acc[i] += static_cast<float>(vec[i]) * weight;
}
}
}
// Cast back and store
AlignedVector<T, VecSize> out_vec;
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(acc[i]);
}
Store<T, VecSize>(out_vec,
depermuted_x +
static_cast<int64_t>(token_idx) * hidden +
v * VecSize);
}
}
__syncthreads();
}
}
template <paddle::DataType D, int TOP_K>
std::vector<paddle::Tensor> DepermutePrefillCombineDispatch(
const paddle::Tensor& x,
const paddle::Tensor& indice_map,
const paddle::Tensor& topk_weights,
const int num_worst_tokens) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
const int hidden = x.shape()[2];
const int max_tokens_per_expert = x.shape()[1];
auto place = x.place();
auto stream = x.stream();
auto depermuted_x =
GetEmptyTensor({num_worst_tokens, hidden}, x.dtype(), place);
constexpr int VecSize = 16 / sizeof(DataType_);
int dev;
cudaGetDevice(&dev);
int sm_count;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
int num_blocks = min(sm_count * 2, num_worst_tokens);
DepermutePrefillCombineKernel<DataType_, VecSize, TOP_K>
<<<num_blocks, DEPERMUTE_BLOCK_THREADS, 0, stream>>>(
reinterpret_cast<DataType_*>(depermuted_x.data<data_t>()),
reinterpret_cast<const DataType_*>(x.data<data_t>()),
indice_map.data<int32_t>(),
topk_weights.data<float>(),
num_worst_tokens,
hidden,
max_tokens_per_expert);
return {depermuted_x};
}
std::vector<paddle::Tensor> DepermutePrefillCombine(
const paddle::Tensor& x,
const paddle::Tensor& indice_map,
const paddle::Tensor& topk_weights,
const int num_worst_tokens) {
const int topk = indice_map.shape()[1];
#define DISPATCH_TOPK(DTYPE, TOPK_VAL) \
case TOPK_VAL: \
return DepermutePrefillCombineDispatch<DTYPE, TOPK_VAL>( \
x, indice_map, topk_weights, num_worst_tokens);
switch (x.dtype()) {
case paddle::DataType::FLOAT8_E4M3FN: {
switch (topk) {
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 4)
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
}
}
case paddle::DataType::BFLOAT16: {
switch (topk) {
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 4)
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
}
}
default:
PD_THROW("Unsupported dtype, must be float8_e4m3fn or bfloat16");
}
#undef DISPATCH_TOPK
}
std::vector<std::vector<int64_t>> DepermutePrefillCombineInferShape(
const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& indice_map_shape,
const std::vector<int64_t>& topk_weights_shape,
const int num_worst_tokens) {
int64_t hidden = x_shape[2];
return {{num_worst_tokens, hidden}};
}
std::vector<paddle::DataType> DepermutePrefillCombineInferDtype(
const paddle::DataType& x_dtype,
const paddle::DataType& indice_map_dtype,
const paddle::DataType& topk_weights_dtype,
const int num_worst_tokens) {
return {x_dtype};
}
PD_BUILD_STATIC_OP(depermute_prefill_combine)
.Inputs({"x", "indice_map", "topk_weights"})
.Outputs({"depermuted_x"})
.Attrs({"num_worst_tokens: int"})
.SetKernelFn(PD_KERNEL(DepermutePrefillCombine))
.SetInferShapeFn(PD_INFER_SHAPE(DepermutePrefillCombineInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(DepermutePrefillCombineInferDtype));
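For intuition, a hedged NumPy reference (not part of the patch) of what depermute_prefill_combine computes; the function name below is illustrative, and rows with no valid slot are zero-filled here whereas the kernel simply leaves them unwritten.
import numpy as np
def depermute_prefill_combine_ref(x, indice_map, topk_weights, num_worst_tokens):
    # x: [num_experts, max_tokens_per_expert, hidden] expert outputs
    # indice_map: [num_worst_tokens, topk] int32, expert_idx * M + offset, -1 = empty slot
    # topk_weights: [num_worst_tokens, topk] float32 combination weights
    E, M, H = x.shape
    out = np.zeros((num_worst_tokens, H), dtype=np.float32)
    for t in range(num_worst_tokens):
        for k in range(indice_map.shape[1]):
            idx = indice_map[t, k]
            if idx >= 0:
                out[t] += x[idx // M, idx % M].astype(np.float32) * topk_weights[t, k]
    return out.astype(x.dtype)  # accumulate in fp32, cast back to x dtype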
@@ -0,0 +1,283 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
constexpr int BLOCK_THREADS = 512;
template <typename T, typename ScaleT, int VecSize, int TOP_K>
__global__ void PrefillPermuteToMaskedGemmKernel(
T* __restrict__ permute_x,
ScaleT* __restrict__ permute_scale,
int32_t* __restrict__ permuted_indice_map,
int32_t* __restrict__ token_nums_per_expert,
const T* __restrict__ x,
const ScaleT* __restrict__ scale,
const int64_t* __restrict__ topk_ids,
const int num_tokens,
const int hidden,
const int hidden_scale,
const int max_tokens_per_expert) {
__shared__ int32_t smem_offset;
__shared__ int64_t smem_topk_ids[TOP_K];
const int tidx = threadIdx.x;
const int x_num_vecs = hidden / VecSize;
constexpr int ScaleVecSize = 16 / sizeof(float); // 4
const int scale_num_vecs = hidden_scale / ScaleVecSize;
for (int token_idx = blockIdx.x; token_idx < num_tokens;
token_idx += gridDim.x) {
if (tidx < TOP_K) {
smem_topk_ids[tidx] = topk_ids[token_idx * TOP_K + tidx];
}
__syncthreads();
for (int slot = 0; slot < TOP_K; slot++) {
int64_t expert_idx = smem_topk_ids[slot];
if (expert_idx != -1) {
if (tidx == 0) {
smem_offset = atomicAdd(&token_nums_per_expert[expert_idx], 1);
permuted_indice_map[token_idx * TOP_K + slot] = static_cast<int32_t>(
expert_idx * max_tokens_per_expert + smem_offset);
}
__syncthreads();
int offset = smem_offset;
// Vectorized copy of x[token_idx, :] -> permute_x[expert_idx, offset, :]
const T* src_x = x + static_cast<int64_t>(token_idx) * hidden;
T* dst_x =
permute_x +
static_cast<int64_t>(expert_idx) * max_tokens_per_expert * hidden +
static_cast<int64_t>(offset) * hidden;
AlignedVector<T, VecSize> vec_x;
for (int v = tidx; v < x_num_vecs; v += BLOCK_THREADS) {
Load<T, VecSize>(src_x + v * VecSize, &vec_x);
Store<T, VecSize>(vec_x, dst_x + v * VecSize);
}
// Copy scale[token_idx, :] -> permute_scale with transposed layout
// Physical layout is [E, S, M], accessed as [E, M, S] via strides [S*M, 1, M].
// So permute_scale[expert_idx, offset, s] -> physical addr:
//   expert_idx * (S * M) + offset + s * M
const ScaleT* src_scale =
scale + static_cast<int64_t>(token_idx) * hidden_scale;
ScaleT* dst_scale_base = permute_scale +
static_cast<int64_t>(expert_idx) *
hidden_scale * max_tokens_per_expert +
offset;
for (int s = tidx; s < hidden_scale; s += BLOCK_THREADS) {
dst_scale_base[static_cast<int64_t>(s) * max_tokens_per_expert] =
src_scale[s];
}
__syncthreads();
}
}
}
}
template <paddle::DataType D, paddle::DataType ScaleD, int TOP_K>
std::vector<paddle::Tensor> PrefillPermuteToMaskedGemmDispatch(
const paddle::Tensor& x,
const paddle::Tensor& scale,
const paddle::Tensor& topk_ids,
const int num_local_experts,
const int max_token_num) {
typedef PDTraits<D> traits_;
typedef PDTraits<ScaleD> scale_traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
typedef typename scale_traits_::DataType ScaleDataType_;
typedef typename scale_traits_::data_t scale_data_t;
const int num_tokens = x.shape()[0];
const int hidden = x.shape()[1];
const int hidden_scale = scale.shape()[1];
const int topk = topk_ids.shape()[1];
auto place = x.place();
auto stream = x.stream();
auto permute_x = GetEmptyTensor(
{num_local_experts, max_token_num, hidden}, x.dtype(), place);
// Allocate permute_scale with transposed physical layout [E, S, M]
// but logical shape [E, M, S] with strides [S*M, 1, M]
// This matches the CuteDSL version's paddle.empty([E, S, M]).transpose((0, 2, 1))
auto permute_scale =
GetEmptyTensor({num_local_experts, max_token_num, hidden_scale},
{static_cast<int64_t>(hidden_scale) * max_token_num,
1,
static_cast<int64_t>(max_token_num)},
ScaleD,
place);
auto permuted_indice_map =
GetEmptyTensor({num_tokens, topk}, paddle::DataType::INT32, place);
auto token_nums_per_expert =
GetEmptyTensor({num_local_experts, 1}, paddle::DataType::INT32, place);
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(token_nums_per_expert.data<int32_t>(),
0,
num_local_experts * sizeof(int32_t),
stream));
// memset 0xFF for int32 produces -1
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(permuted_indice_map.data<int32_t>(),
0xFF,
num_tokens * topk * sizeof(int32_t),
stream));
constexpr int VecSize = 16 / sizeof(DataType_);
int dev;
cudaGetDevice(&dev);
int sm_count;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
int num_blocks = sm_count * 2;
PrefillPermuteToMaskedGemmKernel<DataType_, ScaleDataType_, VecSize, TOP_K>
<<<num_blocks, BLOCK_THREADS, 0, stream>>>(
reinterpret_cast<DataType_*>(permute_x.data<data_t>()),
reinterpret_cast<ScaleDataType_*>(
permute_scale.template data<scale_data_t>()),
permuted_indice_map.data<int32_t>(),
token_nums_per_expert.data<int32_t>(),
reinterpret_cast<const DataType_*>(x.data<data_t>()),
reinterpret_cast<const ScaleDataType_*>(
scale.template data<scale_data_t>()),
topk_ids.data<int64_t>(),
num_tokens,
hidden,
hidden_scale,
max_token_num);
return {permute_x, permute_scale, permuted_indice_map, token_nums_per_expert};
}
std::vector<paddle::Tensor> PrefillPermuteToMaskedGemm(
const paddle::Tensor& x,
const paddle::Tensor& scale,
const paddle::Tensor& topk_ids,
const int num_local_experts,
const int max_token_num) {
const int topk = topk_ids.shape()[1];
#define DISPATCH_TOPK(DTYPE, SCALE_DTYPE, TOPK_VAL) \
case TOPK_VAL: \
return PrefillPermuteToMaskedGemmDispatch<DTYPE, SCALE_DTYPE, TOPK_VAL>( \
x, scale, topk_ids, num_local_experts, max_token_num);
switch (x.dtype()) {
case paddle::DataType::FLOAT8_E4M3FN: {
switch (scale.dtype()) {
case paddle::DataType::FLOAT32: {
switch (topk) {
DISPATCH_TOPK(
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32, 4)
DISPATCH_TOPK(
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
}
}
case paddle::DataType::INT32: {
switch (topk) {
DISPATCH_TOPK(
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32, 4)
DISPATCH_TOPK(
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
}
}
}
}
case paddle::DataType::BFLOAT16: {
switch (scale.dtype()) {
case paddle::DataType::FLOAT32: {
switch (topk) {
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 4)
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
}
}
case paddle::DataType::INT32: {
switch (topk) {
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::INT32, 4)
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::INT32, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
}
}
}
}
default:
PD_THROW("Unsupported dtype, must be float8_e4m3fn or bfloat16");
}
#undef DISPATCH_TOPK
}
std::vector<std::vector<int64_t>> PrefillPermuteToMaskedGemmInferShape(
const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& scale_shape,
const std::vector<int64_t>& topk_ids_shape,
const int num_local_experts,
const int max_token_num) {
int64_t num_tokens = x_shape[0];
int64_t hidden = x_shape[1];
int64_t hidden_scale = scale_shape[1];
int64_t topk = topk_ids_shape[1];
return {
{num_local_experts, max_token_num, hidden},
{num_local_experts, max_token_num, hidden_scale},
{num_tokens, topk},
{num_local_experts, 1},
};
}
std::vector<paddle::DataType> PrefillPermuteToMaskedGemmInferDtype(
const paddle::DataType& x_dtype,
const paddle::DataType& scale_dtype,
const paddle::DataType& topk_ids_dtype,
const int num_local_experts,
const int max_token_num) {
return {
x_dtype, scale_dtype, paddle::DataType::INT32, paddle::DataType::INT32};
}
PD_BUILD_STATIC_OP(prefill_permute_to_masked_gemm)
.Inputs({"x", "scale", "topk_ids"})
.Outputs({"permute_x",
"permute_scale",
"permuted_indice_map",
"token_nums_per_expert"})
.Attrs({"num_local_experts: int", "max_token_num: int"})
.SetKernelFn(PD_KERNEL(PrefillPermuteToMaskedGemm))
.SetInferShapeFn(PD_INFER_SHAPE(PrefillPermuteToMaskedGemmInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PrefillPermuteToMaskedGemmInferDtype));
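A hedged NumPy reference (not part of the patch) of the permutation performed by prefill_permute_to_masked_gemm; slot order within an expert may differ from the kernel, which assigns offsets via atomicAdd, and this reference keeps permute_scale in its logical [E, M, S] layout rather than the transposed physical layout described above.
import numpy as np
def prefill_permute_ref(x, scale, topk_ids, num_local_experts, max_token_num):
    # x: [num_tokens, hidden], scale: [num_tokens, hidden_scale]
    # topk_ids: [num_tokens, topk] int64, -1 marks an empty routing slot
    T, H = x.shape
    S = scale.shape[1]
    permute_x = np.zeros((num_local_experts, max_token_num, H), dtype=x.dtype)
    permute_scale = np.zeros((num_local_experts, max_token_num, S), dtype=scale.dtype)
    indice_map = np.full(topk_ids.shape, -1, dtype=np.int32)
    counts = np.zeros(num_local_experts, dtype=np.int32)
    for t in range(T):
        for k in range(topk_ids.shape[1]):
            e = topk_ids[t, k]
            if e != -1:
                off = counts[e]
                counts[e] += 1
                permute_x[e, off] = x[t]
                permute_scale[e, off] = scale[t]
                indice_map[t, k] = e * max_token_num + off
    return permute_x, permute_scale, indice_map, counts.reshape(-1, 1)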
+235 -82
@@ -16,11 +16,19 @@
constexpr float epsilon = 1e-10;
template <typename T>
__host__ __device__ __forceinline__ int ceil_div(int x, int y) {
return (x + y - 1) / y;
}
__host__ __device__ __forceinline__ int align(int x, int y) {
return ceil_div(x, y) * y;
}
template <typename T, typename ScaleT, bool UseUE8M0>
__global__ void quant_per_token_per_block(
const T *input,
phi::dtype::float8_e4m3fn *quanted_res,
float *quanted_scale,
ScaleT *quanted_scale,
const int token_num,
const int hidden_size,
const int hidden_size_scale,
@@ -38,10 +46,11 @@ __global__ void quant_per_token_per_block(
AlignedVector<float, NUM_PER_THREADS> load_vec_float;
AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
const T *input_now = input + token_idx * hidden_size;
const T *input_now = input + static_cast<int64_t>(token_idx) * hidden_size;
phi::dtype::float8_e4m3fn *quanted_res_now =
quanted_res + token_idx * hidden_size;
float *quanted_scale_now = quanted_scale + token_idx * hidden_size_scale;
quanted_res + static_cast<int64_t>(token_idx) * hidden_size;
float *quanted_scale_now = reinterpret_cast<float *>(quanted_scale) +
token_idx * hidden_size_scale;
// handle one block per warp
for (int iter = warp_id; iter < end_iter; iter += num_warp) {
const int start_offset = iter * 128;
@@ -83,11 +92,22 @@ __global__ void quant_per_token_per_block(
}
float scale_to_store = max_value_thread / MAX_VALUE;
// quant
if constexpr (UseUE8M0) {
scale_to_store =
exp2f(ceilf(log2f(fmaxf(scale_to_store, epsilon) + 5e-7f)));
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] * MAX_VALUE / max_value_thread);
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] / scale_to_store);
}
} else {
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] * MAX_VALUE / max_value_thread);
}
}
// store
if (is_valid_data)
@@ -95,14 +115,26 @@ __global__ void quant_per_token_per_block(
res_vec,
quanted_res_now + start_offset + lane_id * NUM_PER_THREADS);
if (lane_id == 0) {
quanted_scale_now[iter] = scale_to_store;
if constexpr (UseUE8M0) {
int exp = (reinterpret_cast<int &>(scale_to_store) >> 23) & 0xFF;
const int pack_idx = iter >> 2;
const int byte_idx = iter & 3;
const int pack_num = ceil_div(hidden_size_scale, 4);
int32_t *scale_now = quanted_scale;
const int base_idx = token_idx * pack_num + pack_idx;
reinterpret_cast<uint8_t *>(&scale_now[base_idx])[byte_idx] =
static_cast<uint8_t>(exp);
} else {
quanted_scale_now[iter] = scale_to_store;
}
}
}
}
}
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
const int block_size) {
const int block_size,
const bool use_ue8m0) {
auto input_dim = input.dims();
const int token_num = input_dim[0];
const int hidden_size = input_dim[1];
@@ -112,8 +144,7 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
auto quanted_x = GetEmptyTensor(
{token_num, hidden_size}, paddle::DataType::FLOAT8_E4M3FN, input.place());
auto quanted_scale = GetEmptyTensor(
{token_num, hidden_size_scale}, paddle::DataType::FLOAT32, input.place());
const int gridx = min(132 * 8, token_num);
const int blockx = min(1024, hidden_size / 128 * 32);
@@ -123,31 +154,70 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
}
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::bfloat16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<float>(),
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::float16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<float>(),
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
if (use_ue8m0) {
auto quanted_scale =
GetEmptyTensor({token_num, ceil_div(hidden_size_scale, 4)},
paddle::DataType::INT32,
input.place());
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
quant_per_token_per_block<paddle::bfloat16, int32_t, true>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::bfloat16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<int32_t>(),
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block<paddle::float16, int32_t, true>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::float16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<int32_t>(),
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
}
return {quanted_x, quanted_scale};
} else {
auto quanted_scale = GetEmptyTensor({token_num, hidden_size_scale},
paddle::DataType::FLOAT32,
input.place());
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
quant_per_token_per_block<paddle::bfloat16, float, false>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::bfloat16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<float>(),
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block<paddle::float16, float, false>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::float16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<float>(),
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
}
return {quanted_x, quanted_scale};
}
return {quanted_x, quanted_scale};
}
std::vector<std::vector<int64_t>> PerTokenQuantInferShape(
@@ -155,19 +225,26 @@ std::vector<std::vector<int64_t>> PerTokenQuantInferShape(
const int token_num = input_shape[0];
const int hidden_size = input_shape[1];
const int hidden_size_scale = (hidden_size + block_size - 1) / block_size;
if (GetSMVersion() >= 100) {
return {{token_num, hidden_size},
{token_num, ceil_div(hidden_size_scale, 4)}};
}
return {{token_num, hidden_size}, {token_num, hidden_size_scale}};
}
std::vector<paddle::DataType> PerTokenQuantInferDtype(
paddle::DataType input_dtype, const int block_size) {
if (GetSMVersion() >= 100) {
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32};
}
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32};
}
template <typename T>
template <typename T, typename ScaleT, bool UseUE8M0>
__global__ void quant_per_token_per_block_padding(
const T *input,
phi::dtype::float8_e4m3fn *quanted_res,
float *quanted_scale,
ScaleT *quanted_scale,
const int token_num,
const int padded_token_num,
const int hidden_size,
@@ -185,13 +262,11 @@ __global__ void quant_per_token_per_block_padding(
AlignedVector<float, NUM_PER_THREADS> load_vec_float;
AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
const T *input_now = input + token_idx * hidden_size;
const T *input_now = input + static_cast<int64_t>(token_idx) * hidden_size;
phi::dtype::float8_e4m3fn *quanted_res_now =
quanted_res + token_idx * hidden_size;
quanted_res + static_cast<int64_t>(token_idx) * hidden_size;
// handle one block per warp
for (int iter = warp_id; iter < end_iter; iter += num_warp) {
float *quanted_scale_now =
quanted_scale + iter * padded_token_num + token_idx;
const int start_offset = iter * 128;
Load<T, NUM_PER_THREADS>(
input_now + start_offset + lane_id * NUM_PER_THREADS, &load_vec);
@@ -222,26 +297,58 @@ __global__ void quant_per_token_per_block_padding(
}
float scale_to_store = max_value_thread / MAX_VALUE;
// quant
if constexpr (UseUE8M0) {
scale_to_store =
exp2f(ceilf(log2f(fmaxf(scale_to_store, epsilon) + 5e-7f)));
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] * MAX_VALUE / max_value_thread);
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] / scale_to_store);
}
} else {
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] * MAX_VALUE / max_value_thread);
}
}
// store
Store<phi::dtype::float8_e4m3fn, NUM_PER_THREADS>(
res_vec, quanted_res_now + start_offset + lane_id * NUM_PER_THREADS);
if (lane_id == 0) {
*quanted_scale_now = scale_to_store;
if constexpr (UseUE8M0) {
// exp
int exp = (reinterpret_cast<int &>(scale_to_store) >> 23) & 0xFF;
const int pack_idx = iter >> 2;
const int byte_idx = iter & 3;
// pack
const int pack_num = align(hidden_size_scale, 4) >> 2;
// column-major base index
int32_t *scale_now = quanted_scale;
const int base_idx = token_idx + pack_idx * padded_token_num;
// ---------------- store exp ----------------
reinterpret_cast<uint8_t *>(&scale_now[base_idx])[byte_idx] =
static_cast<uint8_t>(exp);
} else {
float *scale_now =
quanted_scale + iter * padded_token_num + token_idx;
*scale_now = scale_to_store;
}
}
}
}
}
std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
const int block_size) {
const int block_size,
const bool use_ue8m0) {
using ScaleDtype = float;
auto input_dim = input.dims();
const int token_num = input_dim[0];
const int hidden_size = input_dim[1];
@@ -256,13 +363,11 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
const int tma_alignment_bytes = 16;
const int tma_alignment_elements = tma_alignment_bytes / sizeof(ScaleDtype);
const int padded_token_num =
((token_num + tma_alignment_elements - 1) / tma_alignment_elements) *
tma_alignment_elements;
auto quanted_scale = GetEmptyTensor({padded_token_num, hidden_size_scale},
{1, padded_token_num},
paddle::DataType::FLOAT32,
input.place());
const int gridx = min(132 * 8, token_num);
const int blockx = min(1024, hidden_size / 128 * 32);
@@ -271,34 +376,76 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
if (env_var) {
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
}
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::bfloat16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<ScaleDtype>(),
token_num,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::float16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<ScaleDtype>(),
token_num,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
if (use_ue8m0) {
auto quanted_scale =
GetEmptyTensor({padded_token_num, ceil_div(hidden_size_scale, 4)},
{1, padded_token_num},
paddle::DataType::INT32,
input.place());
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
quant_per_token_per_block_padding<paddle::bfloat16, int32_t, true>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::bfloat16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<int32_t>(),
token_num,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block_padding<paddle::float16, int32_t, true>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::float16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<int32_t>(),
token_num,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
}
return {quanted_x, quanted_scale};
} else {
auto quanted_scale = GetEmptyTensor({padded_token_num, hidden_size_scale},
{1, padded_token_num},
paddle::DataType::FLOAT32,
input.place());
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
quant_per_token_per_block_padding<paddle::bfloat16, float, false>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::bfloat16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<float>(),
token_num,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block_padding<paddle::float16, float, false>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::float16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<float>(),
token_num,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
}
return {quanted_x, quanted_scale};
}
return {quanted_x, quanted_scale};
}
std::vector<std::vector<int64_t>> PerTokenQuantPaddingInferShape(
@@ -314,19 +461,25 @@ std::vector<std::vector<int64_t>> PerTokenQuantPaddingInferShape(
const int padded_token_num =
((token_num + tma_alignment_elements - 1) / tma_alignment_elements) *
tma_alignment_elements;
if (GetSMVersion() >= 100) {
return {{token_num, hidden_size},
{padded_token_num, ceil_div(hidden_size_scale, 4)}};
}
return {{token_num, hidden_size}, {padded_token_num, hidden_size_scale}};
}
std::vector<paddle::DataType> PerTokenQuantPaddingInferDtype(
paddle::DataType input_dtype) {
if (GetSMVersion() >= 100) {
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32};
}
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32};
}
PD_BUILD_STATIC_OP(per_token_quant)
.Inputs({"input"})
.Outputs({"output", "output_scale"})
.Attrs({"block_size: int"})
.Attrs({"block_size: int", "use_ue8m0: bool"})
.SetKernelFn(PD_KERNEL(PerTokenQuant))
.SetInferShapeFn(PD_INFER_SHAPE(PerTokenQuantInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PerTokenQuantInferDtype));
@@ -334,7 +487,7 @@ PD_BUILD_STATIC_OP(per_token_quant)
PD_BUILD_STATIC_OP(per_token_quant_padding)
.Inputs({"input"})
.Outputs({"output", "output_scale"})
.Attrs({"block_size: int"})
.Attrs({"block_size: int", "use_ue8m0: bool"})
.SetKernelFn(PD_KERNEL(PerTokenQuantPadding))
.SetInferShapeFn(PD_INFER_SHAPE(PerTokenQuantPaddingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PerTokenQuantPaddingInferDtype));
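To make the new use_ue8m0 path concrete, a hedged NumPy sketch (not part of the patch) of the per-128-block quantization above: the block scale is rounded up to a power of two and only its biased 8-bit exponent is kept (four exponents per packed int32). MAX_VALUE is assumed to be the FP8 E4M3 maximum of 448, and the kernel's small +5e-7 guard is omitted.
import numpy as np
FP8_E4M3_MAX = 448.0  # assumed value of MAX_VALUE
def quant_block_ue8m0(block):
    # block: 128 values of one token's 128-wide quantization block
    raw_scale = np.abs(block).max() / FP8_E4M3_MAX
    scale = np.exp2(np.ceil(np.log2(max(raw_scale, 1e-10))))   # round scale up to 2^k
    q = np.clip(block / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)    # values stored as float8_e4m3
    exp_byte = int(round(np.log2(scale))) + 127                # the byte packed into int32
    return q, exp_byte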
+2
@@ -643,6 +643,8 @@ class ParallelConfig:
self.disable_sequence_parallel_moe: bool = False
# shutdown comm group if worker idle
self.shutdown_comm_group_if_worker_idle: bool = None
# use the worst-case number of tokens for EP prefill
self.ep_prefill_use_worst_num_tokens: bool = False
self.pod_ip: str = None
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
+11
@@ -546,6 +546,11 @@ class EngineArgs:
Flag to enable entropy output. Default is False (disabled).
"""
ep_prefill_use_worst_num_tokens: bool = False
"""
Flag to enable EP prefill with the worst-case number of tokens. Default is False (disabled).
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -1060,6 +1065,12 @@ class EngineArgs:
default=EngineArgs.shutdown_comm_group_if_worker_idle,
help="Shutdown communication group when worker is idle.",
)
parallel_group.add_argument(
"--ep-prefill-use-worst-num-tokens",
action="store_true",
default=EngineArgs.ep_prefill_use_worst_num_tokens,
help="Enable prefill use worst num tokens for EP.",
)
# Load group
load_group = parser.add_argument_group("Load Configuration")
+1
@@ -643,6 +643,7 @@ class LLMEngine:
"moe_gate_fp32": self.cfg.model_config.moe_gate_fp32,
"shutdown_comm_group_if_worker_idle": self.cfg.parallel_config.shutdown_comm_group_if_worker_idle,
"enable_entropy": self.cfg.model_config.enable_entropy,
"ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
}
for worker_flag, value in worker_store_true_flag.items():
+2
@@ -865,6 +865,7 @@ class RequestMetrics:
llm_engine_recv_req_timestamp: Optional[float] = None
llm_engine_send_req_to_engine_timestamp: Optional[float] = None
llm_engine_recv_latest_token_timestamp: Optional[float] = None
llm_engine_recv_token_timestamp: Optional[float] = None
speculate_metrics: Optional[SpeculateMetrics] = None
@@ -933,6 +934,7 @@ class RequestMetrics:
# for compatibility with old metrics
self.llm_engine_recv_req_timestamp = self.engine_get_req_time
self.llm_engine_send_req_to_engine_timestamp = self.inference_start_time
self.llm_engine_recv_token_timestamp = self.engine_recv_first_token_time
def get(self, key: str, default_value=None):
if hasattr(self, key):
@@ -319,7 +319,7 @@ class FlashAttentionBackend(AttentionBackend):
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
forward_meta.attention_metadata = metadata
self.attention_metadata = metadata
def forward_mixed(
self,
@@ -332,7 +332,7 @@ class FlashAttentionBackend(AttentionBackend):
layer: Attention,
forward_meta: ForwardMeta,
):
metadata = forward_meta.attention_metadata
metadata = self.attention_metadata
if self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
@@ -340,6 +340,14 @@ class FlashAttentionBackend(AttentionBackend):
layer.layer_id + self.start_layer_index,
)
if int(os.getenv("USE_TBO", "0")) == 1:
if hasattr(forward_meta, "tbo_microbatch_id"):
# here we only let the last microbatch invoke cache kv transfer
if forward_meta.tbo_microbatch_id == 0:
os.environ["FLAGS_fmt_write_cache_completed_signal"] = "0"
elif forward_meta.tbo_microbatch_id == 1:
os.environ["FLAGS_fmt_write_cache_completed_signal"] = "1"
norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False)
q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
k_norm_weight = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None
@@ -369,15 +377,15 @@ class FlashAttentionBackend(AttentionBackend):
if forward_meta.max_len_tensor_cpu[1].item() > 0:
metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
metadata.max_len_tensor_cpu_decoder[1] = 0
forward_meta.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
forward_meta.max_len_tensor_cpu_decoder[1] = 0
(
metadata.cu_seqlens_k,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
metadata.kv_token_num_cpu,
forward_meta.cu_seqlens_k,
forward_meta.pre_cache_batch_ids,
forward_meta.pre_cache_tile_ids_per_batch,
forward_meta.pre_cache_num_blocks_cpu,
forward_meta.kv_token_num_cpu,
) = pre_cache_len_concat(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
@@ -386,12 +394,14 @@ class FlashAttentionBackend(AttentionBackend):
self.block_size,
)
if FLASH_ATTN_VERSION == 4 or forward_meta.attn_mask_offsets is not None:
metadata.attn_mask_q = get_attn_mask_q(
forward_meta.attn_mask_q = get_attn_mask_q(
cu_seqlens_q=forward_meta.cu_seqlens_q,
cu_seqlens_k=metadata.cu_seqlens_k,
cu_seqlens_k=forward_meta.cu_seqlens_k,
attn_mask_kv=forward_meta.attn_mask_offsets,
kv_token_num=metadata.kv_token_num_cpu[0].item(),
kv_token_num=forward_meta.kv_token_num_cpu[0].item(),
)
else:
forward_meta.attn_mask_q = None
use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0
@@ -401,7 +411,7 @@ class FlashAttentionBackend(AttentionBackend):
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
forward_meta.cu_seqlens_q,
metadata.cu_seqlens_k,
forward_meta.cu_seqlens_k,
forward_meta.rotary_embs,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_encoder,
@@ -411,9 +421,9 @@ class FlashAttentionBackend(AttentionBackend):
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
forward_meta.pre_cache_batch_ids,
forward_meta.pre_cache_tile_ids_per_batch,
forward_meta.pre_cache_num_blocks_cpu,
q_norm_weight,
k_norm_weight,
getattr(layer, "cache_k_scale", None),
@@ -423,7 +433,7 @@ class FlashAttentionBackend(AttentionBackend):
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
metadata.kv_signal_data_list[layer.layer_id],
metadata.kv_token_num_cpu[0].item(),
forward_meta.kv_token_num_cpu[0].item(),
self.max_seq_len,
getattr(layer, "rms_norm_eps", 1e-6),
layer.use_neox_rotary_style,
@@ -435,11 +445,11 @@ class FlashAttentionBackend(AttentionBackend):
q,
k,
v,
forward_meta.cu_seqlens_q[: metadata.cu_seqlens_k.shape[0]],
metadata.cu_seqlens_k,
forward_meta.cu_seqlens_q[: forward_meta.cu_seqlens_k.shape[0]],
forward_meta.cu_seqlens_k,
max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
attn_mask_q=metadata.attn_mask_q,
attn_mask_q=forward_meta.attn_mask_q,
causal=self.causal,
num_heads=self.num_heads,
kv_num_heads=self.kv_num_heads,
@@ -465,7 +475,7 @@ class FlashAttentionBackend(AttentionBackend):
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
metadata.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu,
forward_meta.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu,
forward_meta.rotary_embs,
forward_meta.attn_mask,
layer.qkv_bias,
@@ -164,7 +164,7 @@ class FlashMaskAttentionBackend(AttentionBackend):
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
forward_meta.attention_metadata = metadata
self.attention_metadata = metadata
def forward_mixed(
self,
@@ -177,7 +177,7 @@ class FlashMaskAttentionBackend(AttentionBackend):
layer: Attention,
forward_meta: ForwardMeta,
):
metadata = forward_meta.attention_metadata
metadata = self.attention_metadata
norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False)
q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
+8 -5
@@ -58,12 +58,9 @@ def load_deep_ep() -> ModuleType:
return deep_ep
except Exception as e:
logger.error(
"import deep_ep failed! FD_USE_PFCC_DEEP_EP=%s. type=%s, err=%s",
envs.FD_USE_PFCC_DEEP_EP,
type(e).__name__,
e,
f"import deep_ep failed! FD_USE_PFCC_DEEP_EP={envs.FD_USE_PFCC_DEEP_EP}. type={type(e).__name__}, err={e}"
)
logger.error("Traceback:\n%s", traceback.format_exc())
logger.error(f"Traceback:{traceback.format_exc()}")
raise
@@ -586,6 +583,7 @@ class EPPrefillRunner(EPRunner):
moe_phase: MoEPhase = MoEPhase("prefill"),
ep_group=None,
use_internode_ll_two_stage: bool = False,
prefill_num_worst_tokens: int = 0,
):
super().__init__(
top_k,
@@ -600,6 +598,8 @@ class EPPrefillRunner(EPRunner):
ep_group=ep_group,
use_internode_ll_two_stage=use_internode_ll_two_stage,
)
self.num_worst_tokens = prefill_num_worst_tokens
logger.info(f"prefill_num_worst_tokens {prefill_num_worst_tokens}")
def set_allocate_on_comm_stream(allocate_on_comm_stream: bool = False):
if EPPrefillRunner.allocate_on_comm_stream == allocate_on_comm_stream:
@@ -650,6 +650,8 @@ class EPPrefillRunner(EPRunner):
"expert_alignment": expert_alignment,
"allocate_on_comm_stream": EPPrefillRunner.allocate_on_comm_stream,
"previous_event": event,
"num_worst_tokens": self.num_worst_tokens,
"skip_x_record_stream": self.num_worst_tokens > 0,
}
return buffer.dispatch(**dispatch_args)
@@ -672,6 +674,7 @@ class EPPrefillRunner(EPRunner):
"topk_weights": recv_topk_weights,
"previous_event": event,
"allocate_on_comm_stream": EPPrefillRunner.allocate_on_comm_stream,
"skip_x_record_stream": self.num_worst_tokens > 0,
}
fused_moe_out, _, event = buffer.combine(**combine_args)
return fused_moe_out, event
@@ -14,6 +14,7 @@
# limitations under the License.
"""
import os
from abc import abstractmethod
from typing import Callable
@@ -92,6 +93,18 @@ class MoEMethodBase(QuantMethodBase):
splitwise_role = config.scheduler_config.splitwise_role
load_strategy = config.load_config.load_strategy
if config.parallel_config.ep_prefill_use_worst_num_tokens:
token_split_factor = 2 if int(os.getenv("USE_TBO", "0")) == 1 else 1
prefill_num_worst_tokens = (
config.scheduler_config.max_num_batched_tokens
// config.parallel_config.tensor_parallel_size
* layer.ep_size
* layer.top_k
// token_split_factor
)
else:
prefill_num_worst_tokens = 0
# For "mixed" splitwise role: conditionally initialize both or none
if splitwise_role == "mixed":
if load_strategy == "meta":
@@ -102,6 +115,7 @@ class MoEMethodBase(QuantMethodBase):
self.ep_prefill_runner = self.EPPrefillRunner(
**common_args,
use_internode_ll_two_stage=layer.fd_config.parallel_config.use_internode_ll_two_stage,
prefill_num_worst_tokens=prefill_num_worst_tokens,
)
self.ep_decoder_runner = self.EPDecoderRunner(
**common_args,
@@ -119,6 +133,7 @@ class MoEMethodBase(QuantMethodBase):
self.ep_prefill_runner = self.EPPrefillRunner(
**common_args,
use_internode_ll_two_stage=layer.fd_config.parallel_config.use_internode_ll_two_stage,
prefill_num_worst_tokens=prefill_num_worst_tokens,
)
else:
self.ep_decoder_runner = self.EPDecoderRunner(
@@ -14,6 +14,8 @@
# limitations under the License.
"""
import os
import threading
from typing import Callable
import paddle
@@ -24,7 +26,11 @@ import fastdeploy
from fastdeploy.model_executor.layers.moe.ep import deep_ep
from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func
from fastdeploy.model_executor.ops.gpu import (
count_tokens_per_expert_func,
depermute_prefill_combine,
prefill_permute_to_masked_gemm,
)
from fastdeploy.platforms import current_platform
from fastdeploy.utils import register_custom_python_op
from fastdeploy.worker.tbo import let_another_thread_run
@@ -43,6 +49,60 @@ else:
m_grouped_fp8_gemm_nt_contiguous = None
m_grouped_fp8_gemm_nt_masked = None
global_values = {}
def call_prefill_permute_to_masked_gemm(
x: paddle.Tensor,
scale: paddle.Tensor,
topk_ids: paddle.Tensor,
num_local_experts: int,
max_token_num: int,
):
"""
Permute input tokens and scales from token-major to expert-major layout
for MoE masked GEMM operations.
Args:
x: Input hidden states [num_tokens, hidden].
scale: Input scales [num_tokens, hidden_scale].
topk_ids: Expert routing indices [num_tokens, topk] (int64 or int32).
num_local_experts: Number of local experts on this device.
max_token_num: Maximum tokens per expert buffer.
Returns:
tuple: (permute_x, permute_scale, permuted_indice_map, token_nums_per_expert)
"""
if topk_ids.dtype != paddle.int64:
topk_ids = topk_ids.cast(paddle.int64)
results = prefill_permute_to_masked_gemm(x, scale, topk_ids, num_local_experts, max_token_num)
return results[0], results[1], results[2], results[3]
def call_depermute_prefill_combine(
x: paddle.Tensor,
indice_map: paddle.Tensor,
topk_weights: paddle.Tensor,
num_worst_tokens: int,
):
"""
Depermute and combine expert outputs back to token-major layout.
Args:
x: Expert outputs [num_local_experts, max_tokens_per_expert, hidden].
indice_map: Flat index tensor [num_worst_tokens, topk] (int32).
topk_weights: Combination weights [num_worst_tokens, topk] (float32).
num_worst_tokens: Number of output tokens to produce.
Returns:
depermuted_x: Combined output [num_worst_tokens, hidden].
"""
results = depermute_prefill_combine(x, indice_map, topk_weights, num_worst_tokens)
return results
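A hedged usage sketch (not part of the patch) of the two wrappers above: with per-token weights summing to one, a permute followed by depermute should approximately reproduce each routed token. Shapes and values are illustrative assumptions, the intermediate masked GEMMs are skipped, and paddle is already imported in this module.
x = paddle.randn([16, 256]).astype("bfloat16")
scale = paddle.ones([16, 2], dtype="float32")
topk_ids = paddle.randint(0, 4, [16, 8], dtype="int64")  # every slot routed
permute_x, permute_scale, indice_map, counts = call_prefill_permute_to_masked_gemm(
    x, scale, topk_ids, num_local_experts=4, max_token_num=64
)
weights = paddle.full([16, 8], 1.0 / 8, dtype="float32")
combined = call_depermute_prefill_combine(permute_x, indice_map, weights, num_worst_tokens=16)
# each row of combined should be close to the corresponding row of x (bf16 rounding aside)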
def m_grouped_fp8_gemm_nt_contiguous_custom_python_op_infermeta(
permute_input: "paddle.static.MetaTensor",
@@ -108,7 +168,7 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
# down_proj
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, quant_config_weight_block_size_0
ffn_out, quant_config_weight_block_size_0, not disable_ue8m0_cast
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous()
@@ -277,7 +337,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
# 2. Dynamic compute blockwise quantization scales
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
x, self.quant_config.weight_block_size[0]
x, self.quant_config.weight_block_size[0], self.quant_config.deepgemm_scale_ue8m0
)
else:
x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
@@ -293,7 +353,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
)
event = deep_ep.Buffer.capture()
let_another_thread_run()
if self.ep_prefill_runner.num_worst_tokens <= 0:
let_another_thread_run()
# 3. EP Dispatch
(
recv_x,
@@ -306,9 +368,34 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128, previous_event=event
)
if self.ep_prefill_runner.num_worst_tokens > 0:
let_another_thread_run()
thread_name = threading.current_thread().name
if self.ep_prefill_runner.ep_engine.async_finish:
event.current_stream_wait()
global global_values
if thread_name not in global_values:
global_values[thread_name] = {}
(recv_x_value, recv_x_scale) = recv_x
(recv_x_value, recv_x_scale) = recv_x
global_values[thread_name]["x"] = x
global_values[thread_name]["topk_idx"] = topk_idx
global_values[thread_name]["topk_weights"] = topk_weights
global_values[thread_name]["x_scale_tensor"] = x_scale_tensor
global_values[thread_name]["recv_x_value"] = recv_x_value
global_values[thread_name]["recv_x_scale"] = recv_x_scale
global_values[thread_name]["recv_topk_idx"] = recv_topk_idx
global_values[thread_name]["recv_topk_weights"] = recv_topk_weights
global_values[thread_name]["handle"] = handle
global_values[thread_name]["recv_num_tokens_per_expert_list"] = recv_num_tokens_per_expert_list
token_all_num = sum(recv_num_tokens_per_expert_list)
# Note(ZKK):
@@ -317,9 +404,88 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
# so here we manually del a var as soon as it's not used.
# 4. Compute ffn
if token_all_num > 0:
if self.ep_prefill_runner.num_worst_tokens > 0:
token_split_factor = 2 if int(os.getenv("USE_TBO", "0")) == 1 else 1
max_tokens_per_rank = (
layer.fd_config.scheduler_config.max_num_batched_tokens
// layer.fd_config.parallel_config.tensor_parallel_size
// token_split_factor
)
expected_m = max_tokens_per_rank
logger.debug(f"max_tokens_per_rank {max_tokens_per_rank}")
permute_input, permute_scale, permuted_indice_map, token_nums_per_expert = (
call_prefill_permute_to_masked_gemm(
x=recv_x_value,
scale=recv_x_scale,
topk_ids=recv_topk_idx,
num_local_experts=layer.num_local_experts,
max_token_num=layer.ep_size * max_tokens_per_rank,
)
)
up_gate_proj_out = paddle.empty(
[
layer.num_local_experts,
layer.ep_size * max_tokens_per_rank,
layer.moe_intermediate_size * 2,
],
dtype=paddle.bfloat16,
)
m_grouped_fp8_gemm_nt_masked(
(permute_input, permute_scale),
(
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_scale_attrs[0]),
),
up_gate_proj_out,
token_nums_per_expert,
expected_m,
disable_ue8m0_cast=not self.quant_config.deepgemm_scale_ue8m0,
)
act_out_fp8, scale = fastdeploy.model_executor.ops.gpu.fused_mask_swiglu_fp8_quant(
up_gate_proj_out,
token_nums_per_expert,
self.quant_config.weight_block_size[0],
use_ue8m0=self.quant_config.deepgemm_scale_ue8m0,
)
if layer.hidden_size == layer.moe_intermediate_size * 2:
ffn_out = up_gate_proj_out
else:
ffn_out = paddle.empty(
[
layer.num_local_experts,
layer.ep_size * max_tokens_per_rank,
layer.hidden_size,
],
dtype=paddle.bfloat16,
)
m_grouped_fp8_gemm_nt_masked(
(act_out_fp8, scale),
(
getattr(layer, self.added_weight_attrs[1]),
getattr(layer, self.added_scale_attrs[1]),
),
ffn_out,
token_nums_per_expert,
expected_m,
disable_ue8m0_cast=not self.quant_config.deepgemm_scale_ue8m0,
)
tmp_ffn_out = call_depermute_prefill_combine(
x=ffn_out,
indice_map=permuted_indice_map,
topk_weights=recv_topk_weights,
num_worst_tokens=recv_x_value.shape[0],
)
elif token_all_num > 0:
logger.debug(f"token_all_num {token_all_num}")
(recv_x, recv_x_scale) = recv_x
token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
@@ -327,14 +493,14 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
permute_input,
permute_scale,
permute_indices_per_token,
recv_num_tokens_per_expert_list_cumsum,
recv_num_tokens_per_expert_list_padded_cumsum,
_,
_,
dst_weights,
dst_indices,
cumsum_idx_gpu,
_,
m_indices,
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
recv_x,
recv_x_value,
recv_x_scale,
recv_topk_idx,
recv_topk_weights,
@@ -353,7 +519,6 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
(token_all_num, getattr(layer, self.added_weight_attrs[0]).shape[1]),
dtype=paddle.bfloat16,
)
# disable_ue8m0_cast is False for SM100
m_grouped_fp8_gemm_nt_contiguous(
(permute_input, permute_scale),
(getattr(layer, self.added_weight_attrs[0]), getattr(layer, self.added_scale_attrs[0])),
@@ -367,7 +532,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
# down_proj
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, self.quant_config.weight_block_size[0]
ffn_out, self.quant_config.weight_block_size[0], self.quant_config.deepgemm_scale_ue8m0
)
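# The added flag mirrors deepgemm_scale_ue8m0 on the weights, so the activation
# scales are emitted in the UE8M0 encoding the grouped GEMM expects.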
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous().transpose([1, 0])
else:
@@ -382,7 +547,6 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
(token_all_num, getattr(layer, self.added_weight_attrs[1]).shape[1]),
dtype=paddle.bfloat16,
)
# disable_ue8m0_cast is False for SM100
m_grouped_fp8_gemm_nt_contiguous(
(ffn_in_x, ffn_in_x_scale_tensor),
(getattr(layer, self.added_weight_attrs[1]), getattr(layer, self.added_scale_attrs[1])),
@@ -405,12 +569,20 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
# 5. EP combine
event = deep_ep.Buffer.capture()
let_another_thread_run()
if self.ep_prefill_runner.num_worst_tokens <= 0:
let_another_thread_run()
global_values[thread_name]["combine_in"] = tmp_ffn_out
tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights, event)
if self.ep_prefill_runner.num_worst_tokens > 0:
let_another_thread_run()
if self.ep_prefill_runner.ep_engine.async_finish:
event.current_stream_wait()
global_values[thread_name]["combine_out"] = tmp_ffn_out
return tmp_ffn_out
def apply_ep_decode(
@@ -528,7 +700,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, 128)
recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
x, 128, self.quant_config.deepgemm_scale_ue8m0
)
else:
recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
@@ -1236,7 +1236,7 @@ def python_op_fused_moe_kernel_paddle(
from .triton_moe_kernels import fused_moe_kernel_paddle
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0])
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0], False)
else:
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
x, using_pow2_scale=False, output_scale_transpose=False
@@ -1293,7 +1293,7 @@ def python_op_fused_moe_kernel_paddle(
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * ceil_div(hidden_size, config["BLOCK_SIZE_N"]),)
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
intermediate_cache2, quant_config.weight_block_size[0]
intermediate_cache2, quant_config.weight_block_size[0], False
)
else:
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
@@ -326,8 +326,9 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
return linear_out
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding(
x, self.quant_config.weight_block_size[0]
x, self.quant_config.weight_block_size[0], self.quant_config.deepgemm_scale_ue8m0
)
x_scale_tensor = x_scale_tensor[: x.shape[0], ...]
else:
x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
x,
+6
View File
@@ -1090,6 +1090,12 @@ def parse_args():
help="Enable overlap schedule",
)
parser.add_argument(
"--ep_prefill_use_worst_num_tokens",
action="store_true",
help="enable to avoid cpu sync",
)

args = parser.parse_args()
return args
@@ -0,0 +1,334 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import depermute_prefill_combine
def call_depermute_prefill_combine(
x: paddle.Tensor,
indice_map: paddle.Tensor,
topk_weights: paddle.Tensor,
num_worst_tokens: int,
):
"""
Depermute and combine expert outputs back to token-major layout.
Args:
x: Expert outputs [num_local_experts, max_tokens_per_expert, hidden].
indice_map: Flat index tensor [num_worst_tokens, topk] (int32).
topk_weights: Combination weights [num_worst_tokens, topk] (float32).
num_worst_tokens: Number of output tokens to produce.
Returns:
depermuted_x: Combined output [num_worst_tokens, hidden].
"""
results = depermute_prefill_combine(x, indice_map, topk_weights, num_worst_tokens)
return results
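# A minimal NumPy-style sketch of the expected semantics (the exact reference
# used by these tests is _compute_reference below): every entry of indice_map
# is a flat index expert * max_tokens_per_expert + slot into x, or -1 for an
# unused slot, and each output row is the weighted sum over valid slots:
#   out[t] = sum_k topk_weights[t, k] * x[idx // max_tokens_per_expert,
#                                          idx % max_tokens_per_expert]
# with idx = indice_map[t, k] and idx >= 0.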
class TestDepermutePrefillCombine(unittest.TestCase):
"""
Test cases for depermute_prefill_combine kernel.
"""
def setUp(self):
paddle.seed(2024)
np.random.seed(2024)
paddle.set_device("gpu")
def _compute_reference(
self,
x_np,
indice_map_np,
topk_weights_np,
num_worst_tokens,
max_num_tokens_per_expert,
):
hidden = x_np.shape[2]
topk = indice_map_np.shape[1]
depermuted_x = np.zeros((num_worst_tokens, hidden), dtype=np.float32)
for token_idx in range(num_worst_tokens):
for k in range(topk):
indice = indice_map_np[token_idx, k]
if indice >= 0:
expert_idx = indice // max_num_tokens_per_expert
offset = indice % max_num_tokens_per_expert
weight = topk_weights_np[token_idx, k]
depermuted_x[token_idx, :] += x_np[expert_idx, offset, :] * weight
return depermuted_x
def _run_and_verify(
self,
num_worst_tokens,
num_local_experts,
max_num_tokens_per_expert,
hidden,
topk,
x_dtype=paddle.bfloat16,
sparsity=0.2,
rtol=1e-2,
atol=1e-2,
):
x_np = np.random.randn(num_local_experts, max_num_tokens_per_expert, hidden).astype(np.float32)
if x_dtype == paddle.bfloat16:
x = paddle.to_tensor(x_np).cast(paddle.bfloat16)
elif x_dtype == paddle.float8_e4m3fn:
x_np = np.clip(x_np, -448, 448)
x = paddle.to_tensor(x_np).cast(paddle.float8_e4m3fn)
else:
x = paddle.to_tensor(x_np)
indice_map_np = np.zeros((num_worst_tokens, topk), dtype=np.int32)
for token_idx in range(num_worst_tokens):
used_positions = {}
has_valid_index = False
for k in range(topk):
if k == topk - 1 and not has_valid_index:
should_be_invalid = False
else:
should_be_invalid = np.random.rand() < sparsity
if should_be_invalid:
indice_map_np[token_idx, k] = -1
else:
expert_idx = np.random.randint(0, num_local_experts)
if expert_idx not in used_positions:
used_positions[expert_idx] = []
offset = np.random.randint(0, max_num_tokens_per_expert)
attempts = 0
while offset in used_positions.get(expert_idx, []) and attempts < 10:
offset = np.random.randint(0, max_num_tokens_per_expert)
attempts += 1
used_positions[expert_idx].append(offset)
indice_map_np[token_idx, k] = expert_idx * max_num_tokens_per_expert + offset
has_valid_index = True
indice_map = paddle.to_tensor(indice_map_np).cast(paddle.int32)
topk_weights_np = np.random.rand(num_worst_tokens, topk).astype(np.float32)
row_sums = topk_weights_np.sum(axis=1, keepdims=True)
topk_weights_np = topk_weights_np / (row_sums + 1e-6)
topk_weights_np[indice_map_np == -1] = 0.0
topk_weights = paddle.to_tensor(topk_weights_np).cast(paddle.float32)
depermuted_x = call_depermute_prefill_combine(
x=x,
indice_map=indice_map,
topk_weights=topk_weights,
num_worst_tokens=num_worst_tokens,
)
x_ref_np = x.cast(paddle.float32).numpy()
expected = self._compute_reference(
x_np=x_ref_np,
indice_map_np=indice_map_np,
topk_weights_np=topk_weights_np,
num_worst_tokens=num_worst_tokens,
max_num_tokens_per_expert=max_num_tokens_per_expert,
)
result = depermuted_x.cast(paddle.float32).numpy()
self.assertEqual(result.shape, (num_worst_tokens, hidden))
np.testing.assert_allclose(result, expected, rtol=rtol, atol=atol, err_msg="Depermuted output mismatch")
return True
def test_basic_topk4(self):
self._run_and_verify(
num_worst_tokens=64,
num_local_experts=8,
max_num_tokens_per_expert=128,
hidden=7168,
topk=4,
sparsity=0.2,
)
def test_basic_topk8(self):
self._run_and_verify(
num_worst_tokens=64,
num_local_experts=8,
max_num_tokens_per_expert=128,
hidden=7168,
topk=8,
sparsity=0.2,
)
def test_small_tokens(self):
self._run_and_verify(
num_worst_tokens=4,
num_local_experts=4,
max_num_tokens_per_expert=32,
hidden=1024,
topk=4,
sparsity=0.1,
)
def test_large_tokens(self):
self._run_and_verify(
num_worst_tokens=512,
num_local_experts=16,
max_num_tokens_per_expert=256,
hidden=4096,
topk=4,
sparsity=0.3,
)
def test_high_sparsity(self):
self._run_and_verify(
num_worst_tokens=128,
num_local_experts=8,
max_num_tokens_per_expert=64,
hidden=2048,
topk=4,
sparsity=0.7,
)
def test_no_sparsity(self):
self._run_and_verify(
num_worst_tokens=64,
num_local_experts=8,
max_num_tokens_per_expert=128,
hidden=2048,
topk=4,
sparsity=0.0,
)
def test_single_expert(self):
self._run_and_verify(
num_worst_tokens=32,
num_local_experts=1,
max_num_tokens_per_expert=64,
hidden=1024,
topk=4,
sparsity=0.0,
)
def test_many_experts(self):
self._run_and_verify(
num_worst_tokens=128,
num_local_experts=32,
max_num_tokens_per_expert=64,
hidden=2048,
topk=8,
sparsity=0.3,
)
def test_small_hidden(self):
self._run_and_verify(
num_worst_tokens=64,
num_local_experts=8,
max_num_tokens_per_expert=64,
hidden=256,
topk=4,
sparsity=0.2,
)
def test_large_hidden(self):
self._run_and_verify(
num_worst_tokens=32,
num_local_experts=8,
max_num_tokens_per_expert=64,
hidden=14336,
topk=4,
sparsity=0.2,
)
def test_all_minus_one(self):
num_worst_tokens = 32
num_local_experts = 4
max_num_tokens_per_expert = 64
hidden = 1024
topk = 4
x_np = np.random.randn(num_local_experts, max_num_tokens_per_expert, hidden).astype(np.float32)
x = paddle.to_tensor(x_np).cast(paddle.bfloat16)
indice_map = paddle.full([num_worst_tokens, topk], -1, dtype=paddle.int32)
topk_weights = paddle.zeros([num_worst_tokens, topk], dtype=paddle.float32)
depermuted_x = call_depermute_prefill_combine(
x=x,
indice_map=indice_map,
topk_weights=topk_weights,
num_worst_tokens=num_worst_tokens,
)
result = depermuted_x.cast(paddle.float32).numpy()
self.assertEqual(result.shape, (num_worst_tokens, hidden))
def test_single_token(self):
self._run_and_verify(
num_worst_tokens=1,
num_local_experts=4,
max_num_tokens_per_expert=32,
hidden=1024,
topk=4,
sparsity=0.0,
)
def test_uniform_weights(self):
num_worst_tokens = 64
num_local_experts = 8
max_num_tokens_per_expert = 64
hidden = 2048
topk = 4
x_np = np.random.randn(num_local_experts, max_num_tokens_per_expert, hidden).astype(np.float32)
x = paddle.to_tensor(x_np).cast(paddle.bfloat16)
indice_map_np = np.zeros((num_worst_tokens, topk), dtype=np.int32)
for token_idx in range(num_worst_tokens):
for k in range(topk):
expert_idx = k % num_local_experts
offset = token_idx % max_num_tokens_per_expert
indice_map_np[token_idx, k] = expert_idx * max_num_tokens_per_expert + offset
indice_map = paddle.to_tensor(indice_map_np).cast(paddle.int32)
topk_weights_np = np.ones((num_worst_tokens, topk), dtype=np.float32) / topk
topk_weights = paddle.to_tensor(topk_weights_np)
depermuted_x = call_depermute_prefill_combine(
x=x,
indice_map=indice_map,
topk_weights=topk_weights,
num_worst_tokens=num_worst_tokens,
)
x_ref_np = x.cast(paddle.float32).numpy()
expected = self._compute_reference(
x_np=x_ref_np,
indice_map_np=indice_map_np,
topk_weights_np=topk_weights_np,
num_worst_tokens=num_worst_tokens,
max_num_tokens_per_expert=max_num_tokens_per_expert,
)
result = depermuted_x.cast(paddle.float32).numpy()
np.testing.assert_allclose(result, expected, rtol=1e-2, atol=1e-2)
if __name__ == "__main__":
unittest.main()
+2 -2
View File
@@ -89,7 +89,7 @@ class TestPerTokenQuant(unittest.TestCase):
def test_per_token_quant(self):
paddle_output, paddle_output_scale = per_token_quant_paddle(self.input_tensor, self.block_size)
output, output_scale = per_token_quant(self.input_tensor, self.block_size)
output, output_scale = per_token_quant(self.input_tensor, self.block_size, False)
np.testing.assert_allclose(paddle_output_scale.numpy(), output_scale.numpy(), rtol=1e-6)
@@ -139,7 +139,7 @@ class TestPerTokenQuantPadding(TestPerTokenQuant):
paddle_output, paddle_output_scale = per_token_quant_padding_paddle(
self.input_tensor, self.block_size, self.dtype
)
output, output_scale = per_token_quant_padding(self.input_tensor, self.block_size)
output, output_scale = per_token_quant_padding(self.input_tensor, self.block_size, False)
self.assertEqual(paddle_output_scale.shape, output_scale.shape)
np.testing.assert_allclose(
@@ -0,0 +1,347 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import prefill_permute_to_masked_gemm
def call_prefill_permute_to_masked_gemm(
x: paddle.Tensor,
scale: paddle.Tensor,
topk_ids: paddle.Tensor,
num_local_experts: int,
max_token_num: int,
):
"""
Permute input tokens and scales from token-major to expert-major layout
for MoE masked GEMM operations.
Args:
x: Input hidden states [num_tokens, hidden].
scale: Input scales [num_tokens, hidden_scale].
topk_ids: Expert routing indices [num_tokens, topk] (int64 or int32).
num_local_experts: Number of local experts on this device.
max_token_num: Maximum tokens per expert buffer.
Returns:
tuple: (permute_x, permute_scale, permuted_indice_map, token_nums_per_expert)
"""
if topk_ids.dtype != paddle.int64:
topk_ids = topk_ids.cast(paddle.int64)
results = prefill_permute_to_masked_gemm(x, scale, topk_ids, num_local_experts, max_token_num)
return results[0], results[1], results[2], results[3]
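# A minimal sketch of the expected layout (matched against the checks in
# _run_and_verify below): for every routed pair (token t, slot k) with
# topk_ids[t, k] != -1, the token's row of x and of scale is copied into the
# expert-major output buffers, and
#   permuted_indice_map[t, k] = expert * max_token_num + row_offset
# records where it landed; dropped slots keep a negative entry, and
# token_nums_per_expert counts the valid rows written for each expert.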
class TestPrefillPermuteToMaskedGemm(unittest.TestCase):
"""
Test cases for prefill_permute_to_masked_gemm kernel.
"""
def setUp(self):
paddle.seed(2024)
np.random.seed(2024)
def _get_expected_tokens_per_expert(self, x, scale, topk_ids, num_local_experts):
num_tokens = x.shape[0]
_, topk = topk_ids.shape
expert_to_tokens = {i: [] for i in range(num_local_experts)}
token_nums_per_expert = np.zeros(num_local_experts, dtype=np.int32)
for token_idx in range(num_tokens):
for k in range(topk):
expert_idx = topk_ids[token_idx, k]
if expert_idx != -1:
expert_to_tokens[expert_idx].append((x[token_idx, :].copy(), scale[token_idx, :].copy()))
token_nums_per_expert[expert_idx] += 1
return expert_to_tokens, token_nums_per_expert
def _run_and_verify(
self,
num_tokens,
hidden_size,
hidden_scale,
num_local_experts,
max_token_num,
topk,
x_dtype=paddle.float8_e4m3fn,
scale_dtype=paddle.int32,
sparsity=0.3,
):
if x_dtype == paddle.float8_e4m3fn:
x_np = np.random.randn(num_tokens, hidden_size).astype(np.float32)
x_np = np.clip(x_np, -448, 448)
x = paddle.to_tensor(x_np).cast(paddle.float8_e4m3fn)
elif x_dtype == paddle.bfloat16:
x_np = np.random.randn(num_tokens, hidden_size).astype(np.float32)
x = paddle.to_tensor(x_np).cast(paddle.bfloat16)
else:
x_np = np.random.randn(num_tokens, hidden_size).astype(np.float32)
x = paddle.to_tensor(x_np)
scale_np = np.random.rand(num_tokens, hidden_scale).astype(np.float32)
scale = paddle.to_tensor(scale_np).cast(scale_dtype).contiguous()
topk_ids_np = np.zeros((num_tokens, topk), dtype=np.int64)
for i in range(num_tokens):
experts = np.random.choice(num_local_experts, size=min(topk, num_local_experts), replace=False)
if len(experts) < topk:
topk_ids_np[i, : len(experts)] = experts
topk_ids_np[i, len(experts) :] = -1
else:
topk_ids_np[i, :] = experts
mask = np.random.rand(num_tokens, topk) < sparsity
topk_ids_np[mask] = -1
topk_ids = paddle.to_tensor(topk_ids_np).cast(paddle.int64)
permute_x, permute_scale, permuted_indice_map, token_nums_per_expert = call_prefill_permute_to_masked_gemm(
x=x,
scale=scale,
topk_ids=topk_ids,
num_local_experts=num_local_experts,
max_token_num=max_token_num,
)
permute_x_result = permute_x.cast(paddle.float32).numpy()
permute_scale_result = permute_scale.numpy()
permuted_indice_map_result = permuted_indice_map.numpy()
token_nums_result = token_nums_per_expert.numpy().flatten()
x_ref_np = x.cast(paddle.float32).numpy()
scale_ref_np = scale.numpy()
expert_to_tokens, token_nums_ref = self._get_expected_tokens_per_expert(
x=x_ref_np,
scale=scale_ref_np,
topk_ids=topk_ids_np,
num_local_experts=num_local_experts,
)
np.testing.assert_array_equal(
token_nums_result,
token_nums_ref,
err_msg=f"Token counts mismatch: kernel={token_nums_result}, ref={token_nums_ref}",
)
for expert_idx in range(num_local_experts):
num_tokens_for_expert = token_nums_ref[expert_idx]
if num_tokens_for_expert > 0:
kernel_x_rows = permute_x_result[expert_idx, :num_tokens_for_expert, :]
kernel_scale_rows = permute_scale_result[expert_idx, :num_tokens_for_expert, :]
expected_tokens = expert_to_tokens[expert_idx]
expected_x_rows = np.array([t[0] for t in expected_tokens])
expected_scale_rows = np.array([t[1] for t in expected_tokens])
kernel_x_sums = np.sort(np.sum(kernel_x_rows, axis=1))
expected_x_sums = np.sort(np.sum(expected_x_rows, axis=1))
np.testing.assert_allclose(
kernel_x_sums,
expected_x_sums,
rtol=1e-2,
atol=1e-1,
err_msg=f"Expert {expert_idx}: permute_x row sums mismatch",
)
kernel_scale_sums = np.sort(np.sum(kernel_scale_rows, axis=1))
expected_scale_sums = np.sort(np.sum(expected_scale_rows, axis=1))
np.testing.assert_allclose(
kernel_scale_sums,
expected_scale_sums,
rtol=1e-2,
atol=1e-2,
err_msg=f"Expert {expert_idx}: permute_scale row sums mismatch",
)
x_ref_np = x.cast(paddle.float32).numpy()
for token_idx in range(num_tokens):
for expert_slot in range(topk):
permuted_idx = permuted_indice_map_result[token_idx, expert_slot]
if permuted_idx >= 0:
expert_idx = permuted_idx // max_token_num
offset = permuted_idx % max_token_num
permuted_data = permute_x_result[expert_idx, offset, :]
original_data = x_ref_np[token_idx, :]
np.testing.assert_allclose(
permuted_data,
original_data,
rtol=1e-2,
atol=1e-1,
err_msg=f"Token {token_idx}: permuted_indice_map points to wrong data",
)
return True
def test_basic_topk4(self):
self._run_and_verify(
num_tokens=64,
hidden_size=7168,
hidden_scale=56,
num_local_experts=8,
max_token_num=128,
topk=4,
sparsity=0.2,
)
def test_basic_topk8(self):
self._run_and_verify(
num_tokens=64,
hidden_size=7168,
hidden_scale=56,
num_local_experts=8,
max_token_num=128,
topk=8,
sparsity=0.2,
)
def test_small_tokens(self):
self._run_and_verify(
num_tokens=4, hidden_size=1024, hidden_scale=8, num_local_experts=4, max_token_num=32, topk=4, sparsity=0.1
)
def test_large_tokens(self):
self._run_and_verify(
num_tokens=512,
hidden_size=4096,
hidden_scale=32,
num_local_experts=16,
max_token_num=256,
topk=4,
sparsity=0.3,
)
def test_high_sparsity(self):
self._run_and_verify(
num_tokens=128,
hidden_size=2048,
hidden_scale=16,
num_local_experts=8,
max_token_num=64,
topk=4,
sparsity=0.7,
)
def test_no_sparsity(self):
self._run_and_verify(
num_tokens=64,
hidden_size=2048,
hidden_scale=16,
num_local_experts=8,
max_token_num=128,
topk=4,
sparsity=0.0,
)
def test_single_expert(self):
self._run_and_verify(
num_tokens=32,
hidden_size=1024,
hidden_scale=8,
num_local_experts=1,
max_token_num=64,
topk=4,
sparsity=0.0,
)
def test_many_experts(self):
self._run_and_verify(
num_tokens=128,
hidden_size=2048,
hidden_scale=16,
num_local_experts=32,
max_token_num=64,
topk=8,
sparsity=0.3,
)
def test_bfloat16_input(self):
self._run_and_verify(
num_tokens=64,
hidden_size=2048,
hidden_scale=16,
num_local_experts=8,
max_token_num=128,
topk=4,
x_dtype=paddle.bfloat16,
sparsity=0.2,
)
def test_very_large_tokens(self):
self._run_and_verify(
num_tokens=65536,
hidden_size=7168,
hidden_scale=56,
num_local_experts=20,
max_token_num=16384,
topk=4,
sparsity=0.3,
)
def test_very_large_tokens_with_fp32_scale(self):
self._run_and_verify(
num_tokens=65536,
hidden_size=7168,
hidden_scale=56,
num_local_experts=20,
max_token_num=16384,
topk=4,
sparsity=0.3,
scale_dtype=paddle.float32,
)
def test_all_minus_one(self):
num_tokens = 32
hidden_size = 1024
hidden_scale = 8
num_local_experts = 4
max_token_num = 64
topk = 4
x_np = np.random.randn(num_tokens, hidden_size).astype(np.float32)
x_np = np.clip(x_np, -448, 448)
x = paddle.to_tensor(x_np).cast(paddle.float8_e4m3fn)
scale_np = np.random.rand(num_tokens, hidden_scale).astype(np.float32)
scale = paddle.to_tensor(scale_np)
topk_ids = paddle.full([num_tokens, topk], -1, dtype=paddle.int32)
permute_x, permute_scale, permuted_indice_map, token_nums_per_expert = call_prefill_permute_to_masked_gemm(
x=x,
scale=scale,
topk_ids=topk_ids,
num_local_experts=num_local_experts,
max_token_num=max_token_num,
)
token_nums_result = token_nums_per_expert.numpy().flatten()
expected = np.zeros(num_local_experts, dtype=np.int32)
np.testing.assert_array_equal(token_nums_result, expected)
self.assertEqual(permuted_indice_map.shape, [num_tokens, topk])
if __name__ == "__main__":
unittest.main()