[Feature] Support EP prefill with num_worst_tokens (#6574)

* support num worst tokens

* support num worst tokens

* fix build error

* support num worst tokens: fix errors

* support num worst tokens: fix field

* support num worst tokens: delete requirements

* replace permute and depermute op by pure cuda

* replace permute and depermute op by pure cuda

* fix ci

* fix op

* fix nan

* fix code style

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
RichardWooSJTU
2026-03-11 17:09:07 +08:00
committed by GitHub
parent 0466c7e8a8
commit 9f0778f991
21 changed files with 1775 additions and 166 deletions
@@ -313,9 +313,11 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
const int token_nums_this_rank_padded);
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
const int block_size);
const int block_size,
const bool use_ue8m0);
std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
const int block_size);
const int block_size,
const bool use_ue8m0);
std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
paddle::Tensor& input,
@@ -1175,6 +1177,19 @@ std::vector<paddle::Tensor> get_attn_mask_q(
const paddle::optional<paddle::Tensor>& attn_mask_kv,
const int kv_token_num);
std::vector<paddle::Tensor> PrefillPermuteToMaskedGemm(
const paddle::Tensor& x,
const paddle::Tensor& scale,
const paddle::Tensor& topk_ids,
const int num_local_experts,
const int max_token_num);
std::vector<paddle::Tensor> DepermutePrefillCombine(
const paddle::Tensor& x,
const paddle::Tensor& indice_map,
const paddle::Tensor& topk_weights,
const int num_worst_tokens);
void RadixTopkRaggedTransform(
paddle::Tensor& input,
paddle::Tensor& output_indices,
@@ -1367,12 +1382,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&PerTokenQuant,
py::arg("input"),
py::arg("block_size"),
py::arg("use_ue8m0"),
"per token per block quant");
m.def("per_token_quant_padding",
&PerTokenQuantPadding,
py::arg("input"),
py::arg("block_size"),
py::arg("use_ue8m0"),
"per token per block quant and padding transpose scale");
m.def("fused_mask_swiglu_fp8_quant",
@@ -1826,6 +1843,22 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("custom_numpy_to_tensor",
&CustomNumpyToTensor,
"custom_numpy_to_tensor function");
m.def("prefill_permute_to_masked_gemm",
&PrefillPermuteToMaskedGemm,
py::arg("x"),
py::arg("scale"),
py::arg("topk_ids"),
py::arg("num_local_experts"),
py::arg("max_token_num"),
"Prefill permute to masked GEMM for MoE");
m.def("depermute_prefill_combine",
&DepermutePrefillCombine,
py::arg("x"),
py::arg("indice_map"),
py::arg("topk_weights"),
py::arg("num_worst_tokens"),
"Depermute and combine expert outputs for MoE prefill");
m.def("radix_topk_ragged_transform",
&RadixTopkRaggedTransform,
@@ -20,10 +20,18 @@ __host__ __device__ __forceinline__ int ceil_div(int x, int y) {
return (x + y - 1) / y;
}
__host__ __device__ __forceinline__ int64_t ceil_div(int64_t x, int64_t y) {
return (x + y - 1) / y;
}
__host__ __device__ __forceinline__ int align(int x, int y) {
return ceil_div(x, y) * y;
}
__host__ __device__ __forceinline__ int64_t align(int64_t x, int64_t y) {
return ceil_div(x, y) * y;
}
#ifndef BOOL_SWITCH
#define BOOL_SWITCH(cond, name, ...) \
if (cond) { \
@@ -41,10 +49,10 @@ __global__ void fused_swiglu_fp8_quant_kernel(
const index_t* __restrict__ token_nums_per_expert,
phi::dtype::float8_e4m3fn* __restrict__ out_fp8,
ScaleT* __restrict__ out_scale,
int group_num,
int group_size,
int hidden_size,
int hidden_size_scale,
int64_t group_num,
int64_t group_size,
int64_t hidden_size,
int64_t hidden_size_scale,
bool use_finegrained_range) {
constexpr int BLOCK = 128;
@@ -53,7 +61,7 @@ __global__ void fused_swiglu_fp8_quant_kernel(
int warp = tid >> 5;
int num_warps = blockDim.x >> 5;
int block_id = static_cast<int64_t>(blockIdx.x);
int64_t block_id = static_cast<int64_t>(blockIdx.x);
using VecBF16 = AlignedVector<T, 4>;
VecBF16 x1_vec, x2_vec;
@@ -62,13 +70,13 @@ __global__ void fused_swiglu_fp8_quant_kernel(
while (true) {
// ================= token mapping =================
int expert = -1;
int token_in_expert = -1;
int64_t expert = -1;
int64_t token_in_expert = -1;
if (lane == 0) {
int cumsum = 0;
for (int i = 0; i < group_num; ++i) {
int cnt = token_nums_per_expert[i];
int64_t cumsum = 0;
for (int64_t i = 0; i < group_num; ++i) {
int64_t cnt = static_cast<int64_t>(token_nums_per_expert[i]);
if (block_id >= cumsum && block_id < cumsum + cnt) {
expert = i;
token_in_expert = block_id - cumsum;
@@ -84,17 +92,17 @@ __global__ void fused_swiglu_fp8_quant_kernel(
if (expert < 0 || token_in_expert >= group_size) break;
// ================= base pointers =================
int token = expert * group_size + token_in_expert;
int64_t token = expert * group_size + token_in_expert;
const T* in = input + token * hidden_size * 2;
auto* out = out_fp8 + token * hidden_size;
int num_iters = hidden_size / BLOCK;
int64_t num_iters = hidden_size / BLOCK;
// ================= main loop =================
for (int iter = warp; iter < num_iters; iter += num_warps) {
int base = iter * BLOCK + lane * 4;
for (int64_t iter = warp; iter < num_iters; iter += num_warps) {
int64_t base = iter * BLOCK + lane * 4;
// vec load
Load(in + base, &x1_vec);
@@ -141,20 +149,20 @@ __global__ void fused_swiglu_fp8_quant_kernel(
const int exp = (__float_as_int(scale) >> 23) & 0xFF;
// 2. pack information
const int pack_idx = iter >> 2; // iter / 4
const int byte_idx = iter & 3; // iter % 4
const int64_t pack_idx = iter >> 2; // iter / 4
const int64_t byte_idx = iter & 3; // iter % 4
// 3. layout parameters
const int pack_num = ceil_div(hidden_size_scale, 4);
const int token_stride = align(group_size, 4);
const int64_t pack_num = ceil_div(hidden_size_scale, (int64_t)4);
const int64_t token_stride = align(group_size, (int64_t)4);
// 4. base pointer (int32 pack)
auto* scale_pack = reinterpret_cast<int32_t*>(out_scale);
// 5. column-major offset:
// [expert][pack][token]
const int base_idx = expert * pack_num * token_stride +
pack_idx * token_stride + token_in_expert;
const int64_t base_idx = expert * pack_num * token_stride +
pack_idx * token_stride + token_in_expert;
// 6. write one byte into pack
reinterpret_cast<uint8_t*>(&scale_pack[base_idx])[byte_idx] =
static_cast<uint8_t>(exp);
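For reference, the column-major UE8M0 packing above can be checked with a small standalone host snippet (illustrative sizes and coordinates, not part of this commit):

#include <cstdint>
#include <cstdio>

int main() {
  // Example layout parameters: 6 tokens per expert, 9 scale blocks per token.
  const int64_t group_size = 6, hidden_size_scale = 9;
  const int64_t pack_num = (hidden_size_scale + 3) / 4;     // ceil_div(hidden_size_scale, 4) = 3
  const int64_t token_stride = ((group_size + 3) / 4) * 4;  // align(group_size, 4) = 8
  // Example coordinates: expert 1, scale-block iter 5, token 2 within the expert.
  const int64_t expert = 1, iter = 5, token_in_expert = 2;
  const int64_t pack_idx = iter >> 2;  // which int32 pack along the scale dim
  const int64_t byte_idx = iter & 3;   // which byte inside that pack
  const int64_t base_idx =
      expert * pack_num * token_stride + pack_idx * token_stride + token_in_expert;
  // Same single-byte write the kernel performs into the int32 scale buffer.
  int32_t pack = 0;
  const uint8_t exp = 130;  // biased float32 exponent bits of a power-of-two scale (2^3)
  reinterpret_cast<uint8_t*>(&pack)[byte_idx] = exp;
  printf("base_idx = %lld, pack = 0x%08x\n", static_cast<long long>(base_idx), pack);
  return 0;
}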
@@ -184,11 +192,11 @@ std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
const int block_size,
const bool use_ue8m0) {
auto dim = input.dims();
const int group_num = token_nums_per_expert.shape()[0];
const int group_size = dim[1];
const int hidden_size = dim[2] / 2;
const int hidden_size_scale = hidden_size / block_size;
const int token_num = group_num * group_size;
const int64_t group_num = token_nums_per_expert.shape()[0];
const int64_t group_size = dim[1];
const int64_t hidden_size = dim[2] / 2;
const int64_t hidden_size_scale = hidden_size / block_size;
const int64_t token_num = group_num * group_size;
auto out_fp8 = GetEmptyTensor({group_num, group_size, hidden_size},
paddle::DataType::FLOAT8_E4M3FN,
@@ -200,21 +208,22 @@ std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
paddle::DataType::FLOAT32,
input.place());
if (use_ue8m0) {
int hidden_size_scale_pack = ceil_div(hidden_size_scale, 4);
out_scale = GetEmptyTensor({group_num, group_size, hidden_size_scale_pack},
{hidden_size_scale_pack * align(group_size, 4),
1,
align(group_size, 4)},
paddle::DataType::INT32,
input.place());
int64_t hidden_size_scale_pack = ceil_div(hidden_size_scale, (int64_t)4);
int64_t group_size_aligned = align(group_size, (int64_t)4);
out_scale = GetEmptyTensor(
{group_num, group_size, hidden_size_scale_pack},
{hidden_size_scale_pack * group_size_aligned, 1, group_size_aligned},
paddle::DataType::INT32,
input.place());
}
int sm_count = 0;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, 0);
constexpr int BLOCKS_PER_SM = 2;
int gridx = std::min(sm_count * BLOCKS_PER_SM, token_num);
int blockx = std::min(1024, hidden_size / 128 * 32);
int gridx =
std::min(static_cast<int64_t>(sm_count * BLOCKS_PER_SM), token_num);
int blockx = std::min(1024L, hidden_size / 128 * 32);
bool use_finegrained_range = false;
if (auto* env = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE"))
@@ -243,6 +243,13 @@ class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
typedef __nv_fp8_e4m3 DataType;
typedef paddle::float8_e4m3fn data_t;
};
template <>
class PDTraits<paddle::DataType::INT32> {
public:
typedef int32_t DataType;
typedef int32_t data_t;
};
#endif
template <typename T, int Size>
@@ -0,0 +1,218 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
constexpr int DEPERMUTE_BLOCK_THREADS = 512;
// Depermute and combine expert outputs back to token-major layout.
//
// For each output token, this kernel:
//   1. Loads indice_map and topk_weights into shared memory
//   2. For each valid expert slot (indice >= 0), reads the expert output row,
//      scales it by topk_weight, and accumulates in float32
//   3. Writes the combined result back in the original dtype
//
// x: [num_experts, max_tokens_per_expert, hidden] - expert outputs
// indice_map: [num_worst_tokens, topk] int32 - flat indices (expert_idx * M + offset)
// topk_weights: [num_worst_tokens, topk] float32 - combination weights
// depermuted_x: [num_worst_tokens, hidden] - output (same dtype as x)
// (A host-side reference sketch follows the kernel below.)
template <typename T, int VecSize, int TOP_K>
__global__ void DepermutePrefillCombineKernel(
T* __restrict__ depermuted_x,
const T* __restrict__ x,
const int32_t* __restrict__ indice_map,
const float* __restrict__ topk_weights,
const int num_worst_tokens,
const int hidden,
const int max_tokens_per_expert) {
__shared__ int32_t smem_indices[TOP_K];
__shared__ float smem_weights[TOP_K];
const int tidx = threadIdx.x;
const int num_vecs = hidden / VecSize;
for (int token_idx = blockIdx.x; token_idx < num_worst_tokens;
token_idx += gridDim.x) {
// Threads [0, TOP_K) load indice_map; threads [TOP_K, 2*TOP_K) load topk_weights
if (tidx < TOP_K) {
smem_indices[tidx] = indice_map[token_idx * TOP_K + tidx];
}
if (tidx >= TOP_K && tidx < 2 * TOP_K) {
int k = tidx - TOP_K;
smem_weights[k] = topk_weights[token_idx * TOP_K + k];
}
__syncthreads();
// Check if any expert slot is valid
bool need_store = false;
#pragma unroll
for (int k = 0; k < TOP_K; k++) {
if (smem_indices[k] >= 0) {
need_store = true;
break;
}
}
if (need_store) {
// Each thread processes a subset of hidden vectors
for (int v = tidx; v < num_vecs; v += DEPERMUTE_BLOCK_THREADS) {
// Initialize accumulator in float32
float acc[VecSize];
#pragma unroll
for (int i = 0; i < VecSize; i++) {
acc[i] = 0.0f;
}
// Accumulate weighted contributions from each expert
for (int k = 0; k < TOP_K; k++) {
int32_t indice = smem_indices[k];
if (indice >= 0) {
float weight = smem_weights[k];
int64_t expert_idx =
static_cast<int64_t>(indice) / max_tokens_per_expert;
int64_t offset =
static_cast<int64_t>(indice) % max_tokens_per_expert;
const T* src = x + expert_idx * max_tokens_per_expert * hidden +
offset * hidden;
AlignedVector<T, VecSize> vec;
Load<T, VecSize>(src + v * VecSize, &vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
acc[i] += static_cast<float>(vec[i]) * weight;
}
}
}
// Cast back and store
AlignedVector<T, VecSize> out_vec;
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(acc[i]);
}
Store<T, VecSize>(out_vec,
depermuted_x +
static_cast<int64_t>(token_idx) * hidden +
v * VecSize);
}
}
__syncthreads();
}
}
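The per-token behaviour of this kernel can be summarized by a scalar CPU reference (a sketch with float inputs for readability, not part of this commit):

#include <cstdint>
#include <vector>

// Reference combine for a single token: indice_map holds TOP_K flat indices
// (expert * M + offset, or -1 for an unused slot); valid rows of x are scaled
// by their topk weight and accumulated in float32, exactly as the kernel does.
std::vector<float> combine_one_token(const std::vector<float>& x,  // [E * M * H] expert outputs
                                     const int32_t* indice_map,    // [TOP_K]
                                     const float* topk_weights,    // [TOP_K]
                                     int top_k, int M, int H) {
  std::vector<float> out(H, 0.0f);
  for (int k = 0; k < top_k; ++k) {
    const int32_t idx = indice_map[k];
    if (idx < 0) continue;            // slot not routed to any expert
    const int64_t expert = idx / M;
    const int64_t offset = idx % M;
    const float* row = x.data() + (expert * M + offset) * static_cast<int64_t>(H);
    for (int h = 0; h < H; ++h) out[h] += row[h] * topk_weights[k];
  }
  return out;
}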
template <paddle::DataType D, int TOP_K>
std::vector<paddle::Tensor> DepermutePrefillCombineDispatch(
const paddle::Tensor& x,
const paddle::Tensor& indice_map,
const paddle::Tensor& topk_weights,
const int num_worst_tokens) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
const int hidden = x.shape()[2];
const int max_tokens_per_expert = x.shape()[1];
auto place = x.place();
auto stream = x.stream();
auto depermuted_x =
GetEmptyTensor({num_worst_tokens, hidden}, x.dtype(), place);
constexpr int VecSize = 16 / sizeof(DataType_);
int dev;
cudaGetDevice(&dev);
int sm_count;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
int num_blocks = min(sm_count * 2, num_worst_tokens);
DepermutePrefillCombineKernel<DataType_, VecSize, TOP_K>
<<<num_blocks, DEPERMUTE_BLOCK_THREADS, 0, stream>>>(
reinterpret_cast<DataType_*>(depermuted_x.data<data_t>()),
reinterpret_cast<const DataType_*>(x.data<data_t>()),
indice_map.data<int32_t>(),
topk_weights.data<float>(),
num_worst_tokens,
hidden,
max_tokens_per_expert);
return {depermuted_x};
}
std::vector<paddle::Tensor> DepermutePrefillCombine(
const paddle::Tensor& x,
const paddle::Tensor& indice_map,
const paddle::Tensor& topk_weights,
const int num_worst_tokens) {
const int topk = indice_map.shape()[1];
#define DISPATCH_TOPK(DTYPE, TOPK_VAL) \
case TOPK_VAL: \
return DepermutePrefillCombineDispatch<DTYPE, TOPK_VAL>( \
x, indice_map, topk_weights, num_worst_tokens);
switch (x.dtype()) {
case paddle::DataType::FLOAT8_E4M3FN: {
switch (topk) {
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 4)
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
}
}
case paddle::DataType::BFLOAT16: {
switch (topk) {
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 4)
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
}
}
default:
PD_THROW("Unsupported dtype, must be float8_e4m3fn or bfloat16");
}
#undef DISPATCH_TOPK
}
std::vector<std::vector<int64_t>> DepermutePrefillCombineInferShape(
const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& indice_map_shape,
const std::vector<int64_t>& topk_weights_shape,
const int num_worst_tokens) {
int64_t hidden = x_shape[2];
return {{num_worst_tokens, hidden}};
}
std::vector<paddle::DataType> DepermutePrefillCombineInferDtype(
const paddle::DataType& x_dtype,
const paddle::DataType& indice_map_dtype,
const paddle::DataType& topk_weights_dtype,
const int num_worst_tokens) {
return {x_dtype};
}
PD_BUILD_STATIC_OP(depermute_prefill_combine)
.Inputs({"x", "indice_map", "topk_weights"})
.Outputs({"depermuted_x"})
.Attrs({"num_worst_tokens: int"})
.SetKernelFn(PD_KERNEL(DepermutePrefillCombine))
.SetInferShapeFn(PD_INFER_SHAPE(DepermutePrefillCombineInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(DepermutePrefillCombineInferDtype));
@@ -0,0 +1,283 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
constexpr int BLOCK_THREADS = 512;
template <typename T, typename ScaleT, int VecSize, int TOP_K>
__global__ void PrefillPermuteToMaskedGemmKernel(
T* __restrict__ permute_x,
ScaleT* __restrict__ permute_scale,
int32_t* __restrict__ permuted_indice_map,
int32_t* __restrict__ token_nums_per_expert,
const T* __restrict__ x,
const ScaleT* __restrict__ scale,
const int64_t* __restrict__ topk_ids,
const int num_tokens,
const int hidden,
const int hidden_scale,
const int max_tokens_per_expert) {
__shared__ int32_t smem_offset;
__shared__ int64_t smem_topk_ids[TOP_K];
const int tidx = threadIdx.x;
const int x_num_vecs = hidden / VecSize;
constexpr int ScaleVecSize = 16 / sizeof(float); // 4
const int scale_num_vecs = hidden_scale / ScaleVecSize;
for (int token_idx = blockIdx.x; token_idx < num_tokens;
token_idx += gridDim.x) {
if (tidx < TOP_K) {
smem_topk_ids[tidx] = topk_ids[token_idx * TOP_K + tidx];
}
__syncthreads();
for (int slot = 0; slot < TOP_K; slot++) {
int64_t expert_idx = smem_topk_ids[slot];
if (expert_idx != -1) {
if (tidx == 0) {
smem_offset = atomicAdd(&token_nums_per_expert[expert_idx], 1);
permuted_indice_map[token_idx * TOP_K + slot] = static_cast<int32_t>(
expert_idx * max_tokens_per_expert + smem_offset);
}
__syncthreads();
int offset = smem_offset;
// Vectorized copy of x[token_idx, :] -> permute_x[expert_idx, offset, :]
const T* src_x = x + static_cast<int64_t>(token_idx) * hidden;
T* dst_x =
permute_x +
static_cast<int64_t>(expert_idx) * max_tokens_per_expert * hidden +
static_cast<int64_t>(offset) * hidden;
AlignedVector<T, VecSize> vec_x;
for (int v = tidx; v < x_num_vecs; v += BLOCK_THREADS) {
Load<T, VecSize>(src_x + v * VecSize, &vec_x);
Store<T, VecSize>(vec_x, dst_x + v * VecSize);
}
// Copy scale[token_idx, :] -> permute_scale with transposed layout.
// Physical layout is [E, S, M], accessed as [E, M, S] via strides [S*M, 1, M],
// so permute_scale[expert_idx, offset, s] maps to physical address
// expert_idx*(S*M) + offset + s*M.
const ScaleT* src_scale =
scale + static_cast<int64_t>(token_idx) * hidden_scale;
ScaleT* dst_scale_base = permute_scale +
static_cast<int64_t>(expert_idx) *
hidden_scale * max_tokens_per_expert +
offset;
for (int s = tidx; s < hidden_scale; s += BLOCK_THREADS) {
dst_scale_base[static_cast<int64_t>(s) * max_tokens_per_expert] =
src_scale[s];
}
__syncthreads();
}
}
}
}
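The strided scale addressing above is easy to sanity-check on the host (illustrative sizes, not part of this commit):

#include <cstdint>
#include <cstdio>

int main() {
  // permute_scale is physically [E, S, M] but viewed as [E, M, S]
  // with strides {S*M, 1, M}.
  const int64_t M = 8, S = 3;       // max tokens per expert, scale blocks per token
  auto addr = [&](int64_t e, int64_t m, int64_t s) {
    return e * S * M + s * M + m;   // matches dst_scale_base[s * max_tokens_per_expert]
  };
  // Walking s for a fixed (expert, token) advances by M elements each step.
  printf("%lld %lld %lld\n", static_cast<long long>(addr(1, 5, 0)),
         static_cast<long long>(addr(1, 5, 1)), static_cast<long long>(addr(1, 5, 2)));
  return 0;
}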
template <paddle::DataType D, paddle::DataType ScaleD, int TOP_K>
std::vector<paddle::Tensor> PrefillPermuteToMaskedGemmDispatch(
const paddle::Tensor& x,
const paddle::Tensor& scale,
const paddle::Tensor& topk_ids,
const int num_local_experts,
const int max_token_num) {
typedef PDTraits<D> traits_;
typedef PDTraits<ScaleD> scale_traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
typedef typename scale_traits_::DataType ScaleDataType_;
typedef typename scale_traits_::data_t scale_data_t;
const int num_tokens = x.shape()[0];
const int hidden = x.shape()[1];
const int hidden_scale = scale.shape()[1];
const int topk = topk_ids.shape()[1];
auto place = x.place();
auto stream = x.stream();
auto permute_x = GetEmptyTensor(
{num_local_experts, max_token_num, hidden}, x.dtype(), place);
// Allocate permute_scale with transposed physical layout [E, S, M]
// but logical shape [E, M, S] with strides [S*M, 1, M].
// This matches the CuteDSL version's paddle.empty([E, S, M]).transpose((0, 2, 1)).
auto permute_scale =
GetEmptyTensor({num_local_experts, max_token_num, hidden_scale},
{static_cast<int64_t>(hidden_scale) * max_token_num,
1,
static_cast<int64_t>(max_token_num)},
ScaleD,
place);
auto permuted_indice_map =
GetEmptyTensor({num_tokens, topk}, paddle::DataType::INT32, place);
auto token_nums_per_expert =
GetEmptyTensor({num_local_experts, 1}, paddle::DataType::INT32, place);
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(token_nums_per_expert.data<int32_t>(),
0,
num_local_experts * sizeof(int32_t),
stream));
// memset 0xFF for int32 produces -1
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(permuted_indice_map.data<int32_t>(),
0xFF,
num_tokens * topk * sizeof(int32_t),
stream));
constexpr int VecSize = 16 / sizeof(DataType_);
int dev;
cudaGetDevice(&dev);
int sm_count;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
int num_blocks = sm_count * 2;
PrefillPermuteToMaskedGemmKernel<DataType_, ScaleDataType_, VecSize, TOP_K>
<<<num_blocks, BLOCK_THREADS, 0, stream>>>(
reinterpret_cast<DataType_*>(permute_x.data<data_t>()),
reinterpret_cast<ScaleDataType_*>(
permute_scale.template data<scale_data_t>()),
permuted_indice_map.data<int32_t>(),
token_nums_per_expert.data<int32_t>(),
reinterpret_cast<const DataType_*>(x.data<data_t>()),
reinterpret_cast<const ScaleDataType_*>(
scale.template data<scale_data_t>()),
topk_ids.data<int64_t>(),
num_tokens,
hidden,
hidden_scale,
max_token_num);
return {permute_x, permute_scale, permuted_indice_map, token_nums_per_expert};
}
std::vector<paddle::Tensor> PrefillPermuteToMaskedGemm(
const paddle::Tensor& x,
const paddle::Tensor& scale,
const paddle::Tensor& topk_ids,
const int num_local_experts,
const int max_token_num) {
const int topk = topk_ids.shape()[1];
#define DISPATCH_TOPK(DTYPE, SCALE_DTYPE, TOPK_VAL) \
case TOPK_VAL: \
return PrefillPermuteToMaskedGemmDispatch<DTYPE, SCALE_DTYPE, TOPK_VAL>( \
x, scale, topk_ids, num_local_experts, max_token_num);
switch (x.dtype()) {
case paddle::DataType::FLOAT8_E4M3FN: {
switch (scale.dtype()) {
case paddle::DataType::FLOAT32: {
switch (topk) {
DISPATCH_TOPK(
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32, 4)
DISPATCH_TOPK(
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
}
}
case paddle::DataType::INT32: {
switch (topk) {
DISPATCH_TOPK(
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32, 4)
DISPATCH_TOPK(
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
}
}
}
}
case paddle::DataType::BFLOAT16: {
switch (scale.dtype()) {
case paddle::DataType::FLOAT32: {
switch (topk) {
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 4)
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
}
}
case paddle::DataType::INT32: {
switch (topk) {
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::INT32, 4)
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::INT32, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
}
}
}
}
default:
PD_THROW("Unsupported dtype, must be float8_e4m3fn or bfloat16");
}
#undef DISPATCH_TOPK
}
std::vector<std::vector<int64_t>> PrefillPermuteToMaskedGemmInferShape(
const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& scale_shape,
const std::vector<int64_t>& topk_ids_shape,
const int num_local_experts,
const int max_token_num) {
int64_t num_tokens = x_shape[0];
int64_t hidden = x_shape[1];
int64_t hidden_scale = scale_shape[1];
int64_t topk = topk_ids_shape[1];
return {
{num_local_experts, max_token_num, hidden},
{num_local_experts, max_token_num, hidden_scale},
{num_tokens, topk},
{num_local_experts, 1},
};
}
std::vector<paddle::DataType> PrefillPermuteToMaskedGemmInferDtype(
const paddle::DataType& x_dtype,
const paddle::DataType& scale_dtype,
const paddle::DataType& topk_ids_dtype,
const int num_local_experts,
const int max_token_num) {
return {
x_dtype, scale_dtype, paddle::DataType::INT32, paddle::DataType::INT32};
}
PD_BUILD_STATIC_OP(prefill_permute_to_masked_gemm)
.Inputs({"x", "scale", "topk_ids"})
.Outputs({"permute_x",
"permute_scale",
"permuted_indice_map",
"token_nums_per_expert"})
.Attrs({"num_local_experts: int", "max_token_num: int"})
.SetKernelFn(PD_KERNEL(PrefillPermuteToMaskedGemm))
.SetInferShapeFn(PD_INFER_SHAPE(PrefillPermuteToMaskedGemmInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PrefillPermuteToMaskedGemmInferDtype));
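Taken together, the two new ops are intended to bracket the per-expert grouped GEMM during EP prefill. A hedged host-side sketch of that flow (the declarations come from the ops header updated earlier in this diff; the GEMM/activation step in the middle is elided):

#include <vector>
#include "paddle/extension.h"

// Assumed composition, not taken verbatim from this commit.
std::vector<paddle::Tensor> prefill_moe_sketch(const paddle::Tensor& x_fp8,
                                               const paddle::Tensor& x_scale,
                                               const paddle::Tensor& topk_ids,
                                               const paddle::Tensor& topk_weights,
                                               const int num_local_experts,
                                               const int max_token_num,
                                               const int num_worst_tokens) {
  // 1. Scatter token rows (and their block scales) into per-expert buffers.
  auto permuted = PrefillPermuteToMaskedGemm(x_fp8, x_scale, topk_ids,
                                             num_local_experts, max_token_num);
  // permuted = {permute_x, permute_scale, permuted_indice_map, token_nums_per_expert}

  // 2. Placeholder for the masked grouped GEMM / SwiGLU that consumes permute_x.
  paddle::Tensor expert_out = permuted[0];

  // 3. Gather expert rows back to token order, weighting by topk_weights and
  //    padding the output to num_worst_tokens rows.
  return DepermutePrefillCombine(expert_out, permuted[2], topk_weights,
                                 num_worst_tokens);
}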
@@ -16,11 +16,19 @@
constexpr float epsilon = 1e-10;
template <typename T>
__host__ __device__ __forceinline__ int ceil_div(int x, int y) {
return (x + y - 1) / y;
}
__host__ __device__ __forceinline__ int align(int x, int y) {
return ceil_div(x, y) * y;
}
template <typename T, typename ScaleT, bool UseUE8M0>
__global__ void quant_per_token_per_block(
const T *input,
phi::dtype::float8_e4m3fn *quanted_res,
float *quanted_scale,
ScaleT *quanted_scale,
const int token_num,
const int hidden_size,
const int hidden_size_scale,
@@ -38,10 +46,11 @@ __global__ void quant_per_token_per_block(
AlignedVector<float, NUM_PER_THREADS> load_vec_float;
AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
const T *input_now = input + token_idx * hidden_size;
const T *input_now = input + static_cast<int64_t>(token_idx) * hidden_size;
phi::dtype::float8_e4m3fn *quanted_res_now =
quanted_res + token_idx * hidden_size;
float *quanted_scale_now = quanted_scale + token_idx * hidden_size_scale;
quanted_res + static_cast<int64_t>(token_idx) * hidden_size;
float *quanted_scale_now = reinterpret_cast<float *>(quanted_scale) +
token_idx * hidden_size_scale;
// process one 128-wide block per warp
for (int iter = warp_id; iter < end_iter; iter += num_warp) {
const int start_offset = iter * 128;
@@ -83,11 +92,22 @@ __global__ void quant_per_token_per_block(
}
float scale_to_store = max_value_thread / MAX_VALUE;
// quant
if constexpr (UseUE8M0) {
scale_to_store =
exp2f(ceilf(log2f(fmaxf(scale_to_store, epsilon) + 5e-7f)));
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] * MAX_VALUE / max_value_thread);
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] / scale_to_store);
}
} else {
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] * MAX_VALUE / max_value_thread);
}
}
// store
if (is_valid_data)
@@ -95,14 +115,26 @@ __global__ void quant_per_token_per_block(
res_vec,
quanted_res_now + start_offset + lane_id * NUM_PER_THREADS);
if (lane_id == 0) {
quanted_scale_now[iter] = scale_to_store;
if constexpr (UseUE8M0) {
int exp = (reinterpret_cast<int &>(scale_to_store) >> 23) & 0xFF;
const int pack_idx = iter >> 2;
const int byte_idx = iter & 3;
const int pack_num = ceil_div(hidden_size_scale, 4);
int32_t *scale_now = quanted_scale;
const int base_idx = token_idx * pack_num + pack_idx;
reinterpret_cast<uint8_t *>(&scale_now[base_idx])[byte_idx] =
static_cast<uint8_t>(exp);
} else {
quanted_scale_now[iter] = scale_to_store;
}
}
}
}
}
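The UE8M0 branch above rounds the per-block scale up to the nearest power of two before quantizing; the effect is easy to see in isolation (example value, not part of this commit):

#include <cmath>
#include <cstdio>

int main() {
  const float epsilon = 1e-10f;
  const float raw_scale = 0.37f;  // max_value_thread / MAX_VALUE for some block
  // Same expression as the kernel: round the scale up to a power of two.
  const float ue8m0_scale = exp2f(ceilf(log2f(fmaxf(raw_scale, epsilon) + 5e-7f)));
  printf("raw %.3f -> ue8m0 %.3f\n", raw_scale, ue8m0_scale);  // raw 0.370 -> ue8m0 0.500
  return 0;
}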
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
const int block_size) {
const int block_size,
const bool use_ue8m0) {
auto input_dim = input.dims();
const int token_num = input_dim[0];
const int hidden_size = input_dim[1];
@@ -112,8 +144,7 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
auto quanted_x = GetEmptyTensor(
{token_num, hidden_size}, paddle::DataType::FLOAT8_E4M3FN, input.place());
auto quanted_scale = GetEmptyTensor(
{token_num, hidden_size_scale}, paddle::DataType::FLOAT32, input.place());
const int gridx = min(132 * 8, token_num);
const int blockx = min(1024, hidden_size / 128 * 32);
@@ -123,31 +154,70 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
}
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::bfloat16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<float>(),
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::float16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<float>(),
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
if (use_ue8m0) {
auto quanted_scale =
GetEmptyTensor({token_num, ceil_div(hidden_size_scale, 4)},
paddle::DataType::INT32,
input.place());
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
quant_per_token_per_block<paddle::bfloat16, int32_t, true>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::bfloat16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<int32_t>(),
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block<paddle::float16, int32_t, true>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::float16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<int32_t>(),
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
}
return {quanted_x, quanted_scale};
} else {
auto quanted_scale = GetEmptyTensor({token_num, hidden_size_scale},
paddle::DataType::FLOAT32,
input.place());
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
quant_per_token_per_block<paddle::bfloat16, float, false>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::bfloat16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<float>(),
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block<paddle::float16, float, false>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::float16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<float>(),
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
}
return {quanted_x, quanted_scale};
}
return {quanted_x, quanted_scale};
}
std::vector<std::vector<int64_t>> PerTokenQuantInferShape(
@@ -155,19 +225,26 @@ std::vector<std::vector<int64_t>> PerTokenQuantInferShape(
const int token_num = input_shape[0];
const int hidden_size = input_shape[1];
const int hidden_size_scale = (hidden_size + block_size - 1) / block_size;
if (GetSMVersion() >= 100) {
return {{token_num, hidden_size},
{token_num, ceil_div(hidden_size_scale, 4)}};
}
return {{token_num, hidden_size}, {token_num, hidden_size_scale}};
}
std::vector<paddle::DataType> PerTokenQuantInferDtype(
paddle::DataType input_dtype, const int block_size) {
if (GetSMVersion() >= 100) {
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32};
}
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32};
}
template <typename T>
template <typename T, typename ScaleT, bool UseUE8M0>
__global__ void quant_per_token_per_block_padding(
const T *input,
phi::dtype::float8_e4m3fn *quanted_res,
float *quanted_scale,
ScaleT *quanted_scale,
const int token_num,
const int padded_token_num,
const int hidden_size,
@@ -185,13 +262,11 @@ __global__ void quant_per_token_per_block_padding(
AlignedVector<float, NUM_PER_THREADS> load_vec_float;
AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
const T *input_now = input + token_idx * hidden_size;
const T *input_now = input + static_cast<int64_t>(token_idx) * hidden_size;
phi::dtype::float8_e4m3fn *quanted_res_now =
quanted_res + token_idx * hidden_size;
quanted_res + static_cast<int64_t>(token_idx) * hidden_size;
// process one 128-wide block per warp
for (int iter = warp_id; iter < end_iter; iter += num_warp) {
float *quanted_scale_now =
quanted_scale + iter * padded_token_num + token_idx;
const int start_offset = iter * 128;
Load<T, NUM_PER_THREADS>(
input_now + start_offset + lane_id * NUM_PER_THREADS, &load_vec);
@@ -222,26 +297,58 @@ __global__ void quant_per_token_per_block_padding(
}
float scale_to_store = max_value_thread / MAX_VALUE;
// quant
if constexpr (UseUE8M0) {
scale_to_store =
exp2f(ceilf(log2f(fmaxf(scale_to_store, epsilon) + 5e-7f)));
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] * MAX_VALUE / max_value_thread);
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] / scale_to_store);
}
} else {
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] * MAX_VALUE / max_value_thread);
}
}
// store
Store<phi::dtype::float8_e4m3fn, NUM_PER_THREADS>(
res_vec, quanted_res_now + start_offset + lane_id * NUM_PER_THREADS);
if (lane_id == 0) {
*quanted_scale_now = scale_to_store;
if constexpr (UseUE8M0) {
// exp
int exp = (reinterpret_cast<int &>(scale_to_store) >> 23) & 0xFF;
const int pack_idx = iter >> 2;
const int byte_idx = iter & 3;
// pack
const int pack_num = align(hidden_size_scale, 4) >> 2;
// column-major base index
int32_t *scale_now = quanted_scale;
const int base_idx = token_idx + pack_idx * padded_token_num;
// ---------------- store exp ----------------
reinterpret_cast<uint8_t *>(&scale_now[base_idx])[byte_idx] =
static_cast<uint8_t>(exp);
} else {
float *scale_now =
quanted_scale + iter * padded_token_num + token_idx;
*scale_now = scale_to_store;
}
}
}
}
}
std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
const int block_size) {
const int block_size,
const bool use_ue8m0) {
using ScaleDtype = float;
auto input_dim = input.dims();
const int token_num = input_dim[0];
const int hidden_size = input_dim[1];
@@ -256,13 +363,11 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
const int tma_alignment_bytes = 16;
const int tma_alignment_elements = tma_alignment_bytes / sizeof(ScaleDtype);
const int padded_token_num =
((token_num + tma_alignment_elements - 1) / tma_alignment_elements) *
tma_alignment_elements;
auto quanted_scale = GetEmptyTensor({padded_token_num, hidden_size_scale},
{1, padded_token_num},
paddle::DataType::FLOAT32,
input.place());
const int gridx = min(132 * 8, token_num);
const int blockx = min(1024, hidden_size / 128 * 32);
@@ -271,34 +376,76 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
if (env_var) {
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
}
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::bfloat16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<ScaleDtype>(),
token_num,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::float16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<ScaleDtype>(),
token_num,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
if (use_ue8m0) {
auto quanted_scale =
GetEmptyTensor({padded_token_num, ceil_div(hidden_size_scale, 4)},
{1, padded_token_num},
paddle::DataType::INT32,
input.place());
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
quant_per_token_per_block_padding<paddle::bfloat16, int32_t, true>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::bfloat16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<int32_t>(),
token_num,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block_padding<paddle::float16, int32_t, true>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::float16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<int32_t>(),
token_num,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
}
return {quanted_x, quanted_scale};
} else {
auto quanted_scale = GetEmptyTensor({padded_token_num, hidden_size_scale},
{1, padded_token_num},
paddle::DataType::FLOAT32,
input.place());
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
quant_per_token_per_block_padding<paddle::bfloat16, float, false>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::bfloat16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<float>(),
token_num,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block_padding<paddle::float16, float, false>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::float16>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<float>(),
token_num,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
}
return {quanted_x, quanted_scale};
}
return {quanted_x, quanted_scale};
}
std::vector<std::vector<int64_t>> PerTokenQuantPaddingInferShape(
@@ -314,19 +461,25 @@ std::vector<std::vector<int64_t>> PerTokenQuantPaddingInferShape(
const int padded_token_num =
((token_num + tma_alignment_elements - 1) / tma_alignment_elements) *
tma_alignment_elements;
if (GetSMVersion() >= 100) {
return {{token_num, hidden_size},
{padded_token_num, ceil_div(hidden_size_scale, 4)}};
}
return {{token_num, hidden_size}, {padded_token_num, hidden_size_scale}};
}
std::vector<paddle::DataType> PerTokenQuantPaddingInferDtype(
paddle::DataType input_dtype) {
if (GetSMVersion() >= 100) {
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32};
}
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32};
}
PD_BUILD_STATIC_OP(per_token_quant)
.Inputs({"input"})
.Outputs({"output", "output_scale"})
.Attrs({"block_size: int"})
.Attrs({"block_size: int", "use_ue8m0: bool"})
.SetKernelFn(PD_KERNEL(PerTokenQuant))
.SetInferShapeFn(PD_INFER_SHAPE(PerTokenQuantInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PerTokenQuantInferDtype));
@@ -334,7 +487,7 @@ PD_BUILD_STATIC_OP(per_token_quant)
PD_BUILD_STATIC_OP(per_token_quant_padding)
.Inputs({"input"})
.Outputs({"output", "output_scale"})
.Attrs({"block_size: int"})
.Attrs({"block_size: int", "use_ue8m0: bool"})
.SetKernelFn(PD_KERNEL(PerTokenQuantPadding))
.SetInferShapeFn(PD_INFER_SHAPE(PerTokenQuantPaddingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PerTokenQuantPaddingInferDtype));
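As a final reference point, the shape arithmetic behind PerTokenQuantPaddingInferShape on SM >= 100 (int32-packed scales plus TMA-aligned token padding) works out as follows for example sizes (not part of this commit):

#include <cstdio>

int main() {
  const int token_num = 37, hidden_size = 7168, block_size = 128;
  const int hidden_size_scale = (hidden_size + block_size - 1) / block_size;  // 56
  const int tma_alignment_elements = 16 / static_cast<int>(sizeof(float));    // 4
  const int padded_token_num =
      (token_num + tma_alignment_elements - 1) / tma_alignment_elements *
      tma_alignment_elements;                                                 // 40
  const int packed_scale_cols = (hidden_size_scale + 3) / 4;                  // 14
  printf("output_scale shape: [%d, %d] (int32)\n", padded_token_num, packed_scale_cols);
  return 0;
}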