mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Support EP prefill with num_worst_tokens (#6574)
* support num worst tokens * support num worst tokens * fix build error * support num worst tokens: fix errors * support num worst tokens: fix feild * support num worst tokens: delete requiements * replace permute and depermute op by pure cuda * replace permute and depermute op by pure cuda * fix ci * fix op * fix nan * fix code style --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -313,9 +313,11 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
|
||||
const int token_nums_this_rank_padded);
|
||||
|
||||
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
|
||||
const int block_size);
|
||||
const int block_size,
|
||||
const bool use_ue8m0);
|
||||
std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
|
||||
const int block_size);
|
||||
const int block_size,
|
||||
const bool use_ue8m0);
|
||||
|
||||
std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
|
||||
paddle::Tensor& input,
|
||||
@@ -1175,6 +1177,19 @@ std::vector<paddle::Tensor> get_attn_mask_q(
|
||||
const paddle::optional<paddle::Tensor>& attn_mask_kv,
|
||||
const int kv_token_num);
|
||||
|
||||
std::vector<paddle::Tensor> PrefillPermuteToMaskedGemm(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& scale,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const int num_local_experts,
|
||||
const int max_token_num);
|
||||
|
||||
std::vector<paddle::Tensor> DepermutePrefillCombine(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& indice_map,
|
||||
const paddle::Tensor& topk_weights,
|
||||
const int num_worst_tokens);
|
||||
|
||||
void RadixTopkRaggedTransform(
|
||||
paddle::Tensor& input,
|
||||
paddle::Tensor& output_indices,
|
||||
@@ -1367,12 +1382,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
&PerTokenQuant,
|
||||
py::arg("input"),
|
||||
py::arg("block_size"),
|
||||
py::arg("use_ue8m0"),
|
||||
"per token per block quant");
|
||||
|
||||
m.def("per_token_quant_padding",
|
||||
&PerTokenQuantPadding,
|
||||
py::arg("input"),
|
||||
py::arg("block_size"),
|
||||
py::arg("use_ue8m0"),
|
||||
"per token per block quant and padding transpose scale");
|
||||
|
||||
m.def("fused_mask_swiglu_fp8_quant",
|
||||
@@ -1826,6 +1843,22 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("custom_numpy_to_tensor",
|
||||
&CustomNumpyToTensor,
|
||||
"custom_numpy_to_tensor function");
|
||||
m.def("prefill_permute_to_masked_gemm",
|
||||
&PrefillPermuteToMaskedGemm,
|
||||
py::arg("x"),
|
||||
py::arg("scale"),
|
||||
py::arg("topk_ids"),
|
||||
py::arg("num_local_experts"),
|
||||
py::arg("max_token_num"),
|
||||
"Prefill permute to masked GEMM for MoE");
|
||||
|
||||
m.def("depermute_prefill_combine",
|
||||
&DepermutePrefillCombine,
|
||||
py::arg("x"),
|
||||
py::arg("indice_map"),
|
||||
py::arg("topk_weights"),
|
||||
py::arg("num_worst_tokens"),
|
||||
"Depermute and combine expert outputs for MoE prefill");
|
||||
|
||||
m.def("radix_topk_ragged_transform",
|
||||
&RadixTopkRaggedTransform,
|
||||
|
||||
@@ -20,10 +20,18 @@ __host__ __device__ __forceinline__ int ceil_div(int x, int y) {
|
||||
return (x + y - 1) / y;
|
||||
}
|
||||
|
||||
__host__ __device__ __forceinline__ int64_t ceil_div(int64_t x, int64_t y) {
|
||||
return (x + y - 1) / y;
|
||||
}
|
||||
|
||||
__host__ __device__ __forceinline__ int align(int x, int y) {
|
||||
return ceil_div(x, y) * y;
|
||||
}
|
||||
|
||||
__host__ __device__ __forceinline__ int64_t align(int64_t x, int64_t y) {
|
||||
return ceil_div(x, y) * y;
|
||||
}
|
||||
|
||||
#ifndef BOOL_SWITCH
|
||||
#define BOOL_SWITCH(cond, name, ...) \
|
||||
if (cond) { \
|
||||
@@ -41,10 +49,10 @@ __global__ void fused_swiglu_fp8_quant_kernel(
|
||||
const index_t* __restrict__ token_nums_per_expert,
|
||||
phi::dtype::float8_e4m3fn* __restrict__ out_fp8,
|
||||
ScaleT* __restrict__ out_scale,
|
||||
int group_num,
|
||||
int group_size,
|
||||
int hidden_size,
|
||||
int hidden_size_scale,
|
||||
int64_t group_num,
|
||||
int64_t group_size,
|
||||
int64_t hidden_size,
|
||||
int64_t hidden_size_scale,
|
||||
bool use_finegrained_range) {
|
||||
constexpr int BLOCK = 128;
|
||||
|
||||
@@ -53,7 +61,7 @@ __global__ void fused_swiglu_fp8_quant_kernel(
|
||||
int warp = tid >> 5;
|
||||
int num_warps = blockDim.x >> 5;
|
||||
|
||||
int block_id = static_cast<int64_t>(blockIdx.x);
|
||||
int64_t block_id = static_cast<int64_t>(blockIdx.x);
|
||||
|
||||
using VecBF16 = AlignedVector<T, 4>;
|
||||
VecBF16 x1_vec, x2_vec;
|
||||
@@ -62,13 +70,13 @@ __global__ void fused_swiglu_fp8_quant_kernel(
|
||||
|
||||
while (true) {
|
||||
// ================= token mapping =================
|
||||
int expert = -1;
|
||||
int token_in_expert = -1;
|
||||
int64_t expert = -1;
|
||||
int64_t token_in_expert = -1;
|
||||
|
||||
if (lane == 0) {
|
||||
int cumsum = 0;
|
||||
for (int i = 0; i < group_num; ++i) {
|
||||
int cnt = token_nums_per_expert[i];
|
||||
int64_t cumsum = 0;
|
||||
for (int64_t i = 0; i < group_num; ++i) {
|
||||
int64_t cnt = static_cast<int64_t>(token_nums_per_expert[i]);
|
||||
if (block_id >= cumsum && block_id < cumsum + cnt) {
|
||||
expert = i;
|
||||
token_in_expert = block_id - cumsum;
|
||||
@@ -84,17 +92,17 @@ __global__ void fused_swiglu_fp8_quant_kernel(
|
||||
if (expert < 0 || token_in_expert >= group_size) break;
|
||||
|
||||
// ================= base pointers =================
|
||||
int token = expert * group_size + token_in_expert;
|
||||
int64_t token = expert * group_size + token_in_expert;
|
||||
|
||||
const T* in = input + token * hidden_size * 2;
|
||||
|
||||
auto* out = out_fp8 + token * hidden_size;
|
||||
|
||||
int num_iters = hidden_size / BLOCK;
|
||||
int64_t num_iters = hidden_size / BLOCK;
|
||||
|
||||
// ================= main loop =================
|
||||
for (int iter = warp; iter < num_iters; iter += num_warps) {
|
||||
int base = iter * BLOCK + lane * 4;
|
||||
for (int64_t iter = warp; iter < num_iters; iter += num_warps) {
|
||||
int64_t base = iter * BLOCK + lane * 4;
|
||||
|
||||
// vec load
|
||||
Load(in + base, &x1_vec);
|
||||
@@ -141,20 +149,20 @@ __global__ void fused_swiglu_fp8_quant_kernel(
|
||||
const int exp = (__float_as_int(scale) >> 23) & 0xFF;
|
||||
|
||||
// 2. pack information
|
||||
const int pack_idx = iter >> 2; // iter / 4
|
||||
const int byte_idx = iter & 3; // iter % 4
|
||||
const int64_t pack_idx = iter >> 2; // iter / 4
|
||||
const int64_t byte_idx = iter & 3; // iter % 4
|
||||
|
||||
// 3. layout parameters
|
||||
const int pack_num = ceil_div(hidden_size_scale, 4);
|
||||
const int token_stride = align(group_size, 4);
|
||||
const int64_t pack_num = ceil_div(hidden_size_scale, (int64_t)4);
|
||||
const int64_t token_stride = align(group_size, (int64_t)4);
|
||||
|
||||
// 4. base pointer (int32 pack)
|
||||
auto* scale_pack = reinterpret_cast<int32_t*>(out_scale);
|
||||
|
||||
// 5. column-major offset:
|
||||
// [expert][pack][token]
|
||||
const int base_idx = expert * pack_num * token_stride +
|
||||
pack_idx * token_stride + token_in_expert;
|
||||
const int64_t base_idx = expert * pack_num * token_stride +
|
||||
pack_idx * token_stride + token_in_expert;
|
||||
// 6. write one byte into pack
|
||||
reinterpret_cast<uint8_t*>(&scale_pack[base_idx])[byte_idx] =
|
||||
static_cast<uint8_t>(exp);
|
||||
@@ -184,11 +192,11 @@ std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
|
||||
const int block_size,
|
||||
const bool use_ue8m0) {
|
||||
auto dim = input.dims();
|
||||
const int group_num = token_nums_per_expert.shape()[0];
|
||||
const int group_size = dim[1];
|
||||
const int hidden_size = dim[2] / 2;
|
||||
const int hidden_size_scale = hidden_size / block_size;
|
||||
const int token_num = group_num * group_size;
|
||||
const int64_t group_num = token_nums_per_expert.shape()[0];
|
||||
const int64_t group_size = dim[1];
|
||||
const int64_t hidden_size = dim[2] / 2;
|
||||
const int64_t hidden_size_scale = hidden_size / block_size;
|
||||
const int64_t token_num = group_num * group_size;
|
||||
|
||||
auto out_fp8 = GetEmptyTensor({group_num, group_size, hidden_size},
|
||||
paddle::DataType::FLOAT8_E4M3FN,
|
||||
@@ -200,21 +208,22 @@ std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
|
||||
paddle::DataType::FLOAT32,
|
||||
input.place());
|
||||
if (use_ue8m0) {
|
||||
int hidden_size_scale_pack = ceil_div(hidden_size_scale, 4);
|
||||
out_scale = GetEmptyTensor({group_num, group_size, hidden_size_scale_pack},
|
||||
{hidden_size_scale_pack * align(group_size, 4),
|
||||
1,
|
||||
align(group_size, 4)},
|
||||
paddle::DataType::INT32,
|
||||
input.place());
|
||||
int64_t hidden_size_scale_pack = ceil_div(hidden_size_scale, (int64_t)4);
|
||||
int64_t group_size_aligned = align(group_size, (int64_t)4);
|
||||
out_scale = GetEmptyTensor(
|
||||
{group_num, group_size, hidden_size_scale_pack},
|
||||
{hidden_size_scale_pack * group_size_aligned, 1, group_size_aligned},
|
||||
paddle::DataType::INT32,
|
||||
input.place());
|
||||
}
|
||||
|
||||
int sm_count = 0;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, 0);
|
||||
|
||||
constexpr int BLOCKS_PER_SM = 2;
|
||||
int gridx = std::min(sm_count * BLOCKS_PER_SM, token_num);
|
||||
int blockx = std::min(1024, hidden_size / 128 * 32);
|
||||
int gridx =
|
||||
std::min(static_cast<int64_t>(sm_count * BLOCKS_PER_SM), token_num);
|
||||
int blockx = std::min(1024L, hidden_size / 128 * 32);
|
||||
|
||||
bool use_finegrained_range = false;
|
||||
if (auto* env = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE"))
|
||||
|
||||
@@ -243,6 +243,13 @@ class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
|
||||
typedef __nv_fp8_e4m3 DataType;
|
||||
typedef paddle::float8_e4m3fn data_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
class PDTraits<paddle::DataType::INT32> {
|
||||
public:
|
||||
typedef int32_t DataType;
|
||||
typedef int32_t data_t;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T, int Size>
|
||||
|
||||
@@ -0,0 +1,218 @@
|
||||
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
constexpr int DEPERMUTE_BLOCK_THREADS = 512;
|
||||
|
||||
// Depermute and combine expert outputs back to token-major layout.
|
||||
//
|
||||
// For each output token, this kernel:
|
||||
// 1. Loads indice_map and topk_weights from shared memory
|
||||
// 2. For each valid expert slot (indice >= 0), reads the expert output row,
|
||||
// scales by topk_weight, and accumulates in float32
|
||||
// 3. Writes the combined result back in the original dtype
|
||||
//
|
||||
// x: [num_experts, max_tokens_per_expert, hidden] - expert outputs
|
||||
// indice_map: [num_worst_tokens, topk] int32 - flat indices (expert_idx * M +
|
||||
// offset) topk_weights: [num_worst_tokens, topk] float32 - combination weights
|
||||
// depermuted_x: [num_worst_tokens, hidden] - output (same dtype as x)
|
||||
|
||||
template <typename T, int VecSize, int TOP_K>
|
||||
__global__ void DepermutePrefillCombineKernel(
|
||||
T* __restrict__ depermuted_x,
|
||||
const T* __restrict__ x,
|
||||
const int32_t* __restrict__ indice_map,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int num_worst_tokens,
|
||||
const int hidden,
|
||||
const int max_tokens_per_expert) {
|
||||
__shared__ int32_t smem_indices[TOP_K];
|
||||
__shared__ float smem_weights[TOP_K];
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int num_vecs = hidden / VecSize;
|
||||
|
||||
for (int token_idx = blockIdx.x; token_idx < num_worst_tokens;
|
||||
token_idx += gridDim.x) {
|
||||
// Thread 0 loads indice_map, thread 32 loads topk_weights
|
||||
if (tidx < TOP_K) {
|
||||
smem_indices[tidx] = indice_map[token_idx * TOP_K + tidx];
|
||||
}
|
||||
if (tidx >= TOP_K && tidx < 2 * TOP_K) {
|
||||
int k = tidx - TOP_K;
|
||||
smem_weights[k] = topk_weights[token_idx * TOP_K + k];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Check if any expert slot is valid
|
||||
bool need_store = false;
|
||||
#pragma unroll
|
||||
for (int k = 0; k < TOP_K; k++) {
|
||||
if (smem_indices[k] >= 0) {
|
||||
need_store = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (need_store) {
|
||||
// Each thread processes a subset of hidden vectors
|
||||
for (int v = tidx; v < num_vecs; v += DEPERMUTE_BLOCK_THREADS) {
|
||||
// Initialize accumulator in float32
|
||||
float acc[VecSize];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
acc[i] = 0.0f;
|
||||
}
|
||||
|
||||
// Accumulate weighted contributions from each expert
|
||||
for (int k = 0; k < TOP_K; k++) {
|
||||
int32_t indice = smem_indices[k];
|
||||
if (indice >= 0) {
|
||||
float weight = smem_weights[k];
|
||||
int64_t expert_idx =
|
||||
static_cast<int64_t>(indice) / max_tokens_per_expert;
|
||||
int64_t offset =
|
||||
static_cast<int64_t>(indice) % max_tokens_per_expert;
|
||||
|
||||
const T* src = x + expert_idx * max_tokens_per_expert * hidden +
|
||||
offset * hidden;
|
||||
|
||||
AlignedVector<T, VecSize> vec;
|
||||
Load<T, VecSize>(src + v * VecSize, &vec);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
acc[i] += static_cast<float>(vec[i]) * weight;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cast back and store
|
||||
AlignedVector<T, VecSize> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(acc[i]);
|
||||
}
|
||||
Store<T, VecSize>(out_vec,
|
||||
depermuted_x +
|
||||
static_cast<int64_t>(token_idx) * hidden +
|
||||
v * VecSize);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
template <paddle::DataType D, int TOP_K>
|
||||
std::vector<paddle::Tensor> DepermutePrefillCombineDispatch(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& indice_map,
|
||||
const paddle::Tensor& topk_weights,
|
||||
const int num_worst_tokens) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
const int hidden = x.shape()[2];
|
||||
const int max_tokens_per_expert = x.shape()[1];
|
||||
|
||||
auto place = x.place();
|
||||
auto stream = x.stream();
|
||||
|
||||
auto depermuted_x =
|
||||
GetEmptyTensor({num_worst_tokens, hidden}, x.dtype(), place);
|
||||
|
||||
constexpr int VecSize = 16 / sizeof(DataType_);
|
||||
|
||||
int dev;
|
||||
cudaGetDevice(&dev);
|
||||
int sm_count;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
|
||||
int num_blocks = min(sm_count * 2, num_worst_tokens);
|
||||
|
||||
DepermutePrefillCombineKernel<DataType_, VecSize, TOP_K>
|
||||
<<<num_blocks, DEPERMUTE_BLOCK_THREADS, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(depermuted_x.data<data_t>()),
|
||||
reinterpret_cast<const DataType_*>(x.data<data_t>()),
|
||||
indice_map.data<int32_t>(),
|
||||
topk_weights.data<float>(),
|
||||
num_worst_tokens,
|
||||
hidden,
|
||||
max_tokens_per_expert);
|
||||
|
||||
return {depermuted_x};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> DepermutePrefillCombine(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& indice_map,
|
||||
const paddle::Tensor& topk_weights,
|
||||
const int num_worst_tokens) {
|
||||
const int topk = indice_map.shape()[1];
|
||||
|
||||
#define DISPATCH_TOPK(DTYPE, TOPK_VAL) \
|
||||
case TOPK_VAL: \
|
||||
return DepermutePrefillCombineDispatch<DTYPE, TOPK_VAL>( \
|
||||
x, indice_map, topk_weights, num_worst_tokens);
|
||||
|
||||
switch (x.dtype()) {
|
||||
case paddle::DataType::FLOAT8_E4M3FN: {
|
||||
switch (topk) {
|
||||
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 4)
|
||||
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 8)
|
||||
default:
|
||||
PD_THROW("Unsupported topk value, must be 4 or 8");
|
||||
}
|
||||
}
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
switch (topk) {
|
||||
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 4)
|
||||
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 8)
|
||||
default:
|
||||
PD_THROW("Unsupported topk value, must be 4 or 8");
|
||||
}
|
||||
}
|
||||
default:
|
||||
PD_THROW("Unsupported dtype, must be float8_e4m3fn or bfloat16");
|
||||
}
|
||||
|
||||
#undef DISPATCH_TOPK
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> DepermutePrefillCombineInferShape(
|
||||
const std::vector<int64_t>& x_shape,
|
||||
const std::vector<int64_t>& indice_map_shape,
|
||||
const std::vector<int64_t>& topk_weights_shape,
|
||||
const int num_worst_tokens) {
|
||||
int64_t hidden = x_shape[2];
|
||||
return {{num_worst_tokens, hidden}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> DepermutePrefillCombineInferDtype(
|
||||
const paddle::DataType& x_dtype,
|
||||
const paddle::DataType& indice_map_dtype,
|
||||
const paddle::DataType& topk_weights_dtype,
|
||||
const int num_worst_tokens) {
|
||||
return {x_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(depermute_prefill_combine)
|
||||
.Inputs({"x", "indice_map", "topk_weights"})
|
||||
.Outputs({"depermuted_x"})
|
||||
.Attrs({"num_worst_tokens: int"})
|
||||
.SetKernelFn(PD_KERNEL(DepermutePrefillCombine))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(DepermutePrefillCombineInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(DepermutePrefillCombineInferDtype));
|
||||
@@ -0,0 +1,283 @@
|
||||
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
constexpr int BLOCK_THREADS = 512;
|
||||
|
||||
template <typename T, typename ScaleT, int VecSize, int TOP_K>
|
||||
__global__ void PrefillPermuteToMaskedGemmKernel(
|
||||
T* __restrict__ permute_x,
|
||||
ScaleT* __restrict__ permute_scale,
|
||||
int32_t* __restrict__ permuted_indice_map,
|
||||
int32_t* __restrict__ token_nums_per_expert,
|
||||
const T* __restrict__ x,
|
||||
const ScaleT* __restrict__ scale,
|
||||
const int64_t* __restrict__ topk_ids,
|
||||
const int num_tokens,
|
||||
const int hidden,
|
||||
const int hidden_scale,
|
||||
const int max_tokens_per_expert) {
|
||||
__shared__ int32_t smem_offset;
|
||||
__shared__ int64_t smem_topk_ids[TOP_K];
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int x_num_vecs = hidden / VecSize;
|
||||
constexpr int ScaleVecSize = 16 / sizeof(float); // 4
|
||||
const int scale_num_vecs = hidden_scale / ScaleVecSize;
|
||||
|
||||
for (int token_idx = blockIdx.x; token_idx < num_tokens;
|
||||
token_idx += gridDim.x) {
|
||||
if (tidx < TOP_K) {
|
||||
smem_topk_ids[tidx] = topk_ids[token_idx * TOP_K + tidx];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int slot = 0; slot < TOP_K; slot++) {
|
||||
int64_t expert_idx = smem_topk_ids[slot];
|
||||
if (expert_idx != -1) {
|
||||
if (tidx == 0) {
|
||||
smem_offset = atomicAdd(&token_nums_per_expert[expert_idx], 1);
|
||||
permuted_indice_map[token_idx * TOP_K + slot] = static_cast<int32_t>(
|
||||
expert_idx * max_tokens_per_expert + smem_offset);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int offset = smem_offset;
|
||||
|
||||
// Vectorized copy of x[token_idx, :] -> permute_x[expert_idx, offset,
|
||||
// :]
|
||||
const T* src_x = x + static_cast<int64_t>(token_idx) * hidden;
|
||||
T* dst_x =
|
||||
permute_x +
|
||||
static_cast<int64_t>(expert_idx) * max_tokens_per_expert * hidden +
|
||||
static_cast<int64_t>(offset) * hidden;
|
||||
|
||||
AlignedVector<T, VecSize> vec_x;
|
||||
for (int v = tidx; v < x_num_vecs; v += BLOCK_THREADS) {
|
||||
Load<T, VecSize>(src_x + v * VecSize, &vec_x);
|
||||
Store<T, VecSize>(vec_x, dst_x + v * VecSize);
|
||||
}
|
||||
|
||||
// Copy scale[token_idx, :] -> permute_scale with transposed layout
|
||||
// Physical layout is [E, S, M], accessed as [E, M, S] via strides [S*M,
|
||||
// 1, M] So permute_scale[expert_idx, offset, s] -> physical addr:
|
||||
// expert_idx*(S*M) + offset + s*M
|
||||
const ScaleT* src_scale =
|
||||
scale + static_cast<int64_t>(token_idx) * hidden_scale;
|
||||
ScaleT* dst_scale_base = permute_scale +
|
||||
static_cast<int64_t>(expert_idx) *
|
||||
hidden_scale * max_tokens_per_expert +
|
||||
offset;
|
||||
|
||||
for (int s = tidx; s < hidden_scale; s += BLOCK_THREADS) {
|
||||
dst_scale_base[static_cast<int64_t>(s) * max_tokens_per_expert] =
|
||||
src_scale[s];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <paddle::DataType D, paddle::DataType ScaleD, int TOP_K>
|
||||
std::vector<paddle::Tensor> PrefillPermuteToMaskedGemmDispatch(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& scale,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const int num_local_experts,
|
||||
const int max_token_num) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef PDTraits<ScaleD> scale_traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
typedef typename scale_traits_::DataType ScaleDataType_;
|
||||
typedef typename scale_traits_::data_t scale_data_t;
|
||||
|
||||
const int num_tokens = x.shape()[0];
|
||||
const int hidden = x.shape()[1];
|
||||
const int hidden_scale = scale.shape()[1];
|
||||
const int topk = topk_ids.shape()[1];
|
||||
|
||||
auto place = x.place();
|
||||
auto stream = x.stream();
|
||||
|
||||
auto permute_x = GetEmptyTensor(
|
||||
{num_local_experts, max_token_num, hidden}, x.dtype(), place);
|
||||
|
||||
// Allocate permute_scale with transposed physical layout [E, S, M]
|
||||
// but logical shape [E, M, S] with strides [S*M, 1, M]
|
||||
// This matches the CuteDSL version's paddle.empty([E, S, M]).transpose((0, 2,
|
||||
// 1))
|
||||
auto permute_scale =
|
||||
GetEmptyTensor({num_local_experts, max_token_num, hidden_scale},
|
||||
{static_cast<int64_t>(hidden_scale) * max_token_num,
|
||||
1,
|
||||
static_cast<int64_t>(max_token_num)},
|
||||
ScaleD,
|
||||
place);
|
||||
|
||||
auto permuted_indice_map =
|
||||
GetEmptyTensor({num_tokens, topk}, paddle::DataType::INT32, place);
|
||||
auto token_nums_per_expert =
|
||||
GetEmptyTensor({num_local_experts, 1}, paddle::DataType::INT32, place);
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
||||
cudaMemsetAsync(token_nums_per_expert.data<int32_t>(),
|
||||
0,
|
||||
num_local_experts * sizeof(int32_t),
|
||||
stream));
|
||||
// memset 0xFF for int32 produces -1
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
||||
cudaMemsetAsync(permuted_indice_map.data<int32_t>(),
|
||||
0xFF,
|
||||
num_tokens * topk * sizeof(int32_t),
|
||||
stream));
|
||||
|
||||
constexpr int VecSize = 16 / sizeof(DataType_);
|
||||
|
||||
int dev;
|
||||
cudaGetDevice(&dev);
|
||||
int sm_count;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
|
||||
int num_blocks = sm_count * 2;
|
||||
|
||||
PrefillPermuteToMaskedGemmKernel<DataType_, ScaleDataType_, VecSize, TOP_K>
|
||||
<<<num_blocks, BLOCK_THREADS, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(permute_x.data<data_t>()),
|
||||
reinterpret_cast<ScaleDataType_*>(
|
||||
permute_scale.template data<scale_data_t>()),
|
||||
permuted_indice_map.data<int32_t>(),
|
||||
token_nums_per_expert.data<int32_t>(),
|
||||
reinterpret_cast<const DataType_*>(x.data<data_t>()),
|
||||
reinterpret_cast<const ScaleDataType_*>(
|
||||
scale.template data<scale_data_t>()),
|
||||
topk_ids.data<int64_t>(),
|
||||
num_tokens,
|
||||
hidden,
|
||||
hidden_scale,
|
||||
max_token_num);
|
||||
|
||||
return {permute_x, permute_scale, permuted_indice_map, token_nums_per_expert};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> PrefillPermuteToMaskedGemm(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& scale,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const int num_local_experts,
|
||||
const int max_token_num) {
|
||||
const int topk = topk_ids.shape()[1];
|
||||
|
||||
#define DISPATCH_TOPK(DTYPE, SCALE_DTYPE, TOPK_VAL) \
|
||||
case TOPK_VAL: \
|
||||
return PrefillPermuteToMaskedGemmDispatch<DTYPE, SCALE_DTYPE, TOPK_VAL>( \
|
||||
x, scale, topk_ids, num_local_experts, max_token_num);
|
||||
|
||||
switch (x.dtype()) {
|
||||
case paddle::DataType::FLOAT8_E4M3FN: {
|
||||
switch (scale.dtype()) {
|
||||
case paddle::DataType::FLOAT32: {
|
||||
switch (topk) {
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32, 4)
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32, 8)
|
||||
default:
|
||||
PD_THROW("Unsupported topk value, must be 4 or 8");
|
||||
}
|
||||
}
|
||||
case paddle::DataType::INT32: {
|
||||
switch (topk) {
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32, 4)
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32, 8)
|
||||
default:
|
||||
PD_THROW("Unsupported topk value, must be 4 or 8");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
switch (scale.dtype()) {
|
||||
case paddle::DataType::FLOAT32: {
|
||||
switch (topk) {
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 4)
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 8)
|
||||
default:
|
||||
PD_THROW("Unsupported topk value, must be 4 or 8");
|
||||
}
|
||||
}
|
||||
case paddle::DataType::INT32: {
|
||||
switch (topk) {
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::BFLOAT16, paddle::DataType::INT32, 4)
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::BFLOAT16, paddle::DataType::INT32, 8)
|
||||
default:
|
||||
PD_THROW("Unsupported topk value, must be 4 or 8");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
PD_THROW("Unsupported dtype, must be float8_e4m3fn or bfloat16");
|
||||
}
|
||||
|
||||
#undef DISPATCH_TOPK
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> PrefillPermuteToMaskedGemmInferShape(
|
||||
const std::vector<int64_t>& x_shape,
|
||||
const std::vector<int64_t>& scale_shape,
|
||||
const std::vector<int64_t>& topk_ids_shape,
|
||||
const int num_local_experts,
|
||||
const int max_token_num) {
|
||||
int64_t num_tokens = x_shape[0];
|
||||
int64_t hidden = x_shape[1];
|
||||
int64_t hidden_scale = scale_shape[1];
|
||||
int64_t topk = topk_ids_shape[1];
|
||||
|
||||
return {
|
||||
{num_local_experts, max_token_num, hidden},
|
||||
{num_local_experts, max_token_num, hidden_scale},
|
||||
{num_tokens, topk},
|
||||
{num_local_experts, 1},
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> PrefillPermuteToMaskedGemmInferDtype(
|
||||
const paddle::DataType& x_dtype,
|
||||
const paddle::DataType& scale_dtype,
|
||||
const paddle::DataType& topk_ids_dtype,
|
||||
const int num_local_experts,
|
||||
const int max_token_num) {
|
||||
return {
|
||||
x_dtype, scale_dtype, paddle::DataType::INT32, paddle::DataType::INT32};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(prefill_permute_to_masked_gemm)
|
||||
.Inputs({"x", "scale", "topk_ids"})
|
||||
.Outputs({"permute_x",
|
||||
"permute_scale",
|
||||
"permuted_indice_map",
|
||||
"token_nums_per_expert"})
|
||||
.Attrs({"num_local_experts: int", "max_token_num: int"})
|
||||
.SetKernelFn(PD_KERNEL(PrefillPermuteToMaskedGemm))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(PrefillPermuteToMaskedGemmInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(PrefillPermuteToMaskedGemmInferDtype));
|
||||
@@ -16,11 +16,19 @@
|
||||
|
||||
constexpr float epsilon = 1e-10;
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ __forceinline__ int ceil_div(int x, int y) {
|
||||
return (x + y - 1) / y;
|
||||
}
|
||||
|
||||
__host__ __device__ __forceinline__ int align(int x, int y) {
|
||||
return ceil_div(x, y) * y;
|
||||
}
|
||||
|
||||
template <typename T, typename ScaleT, bool UseUE8M0>
|
||||
__global__ void quant_per_token_per_block(
|
||||
const T *input,
|
||||
phi::dtype::float8_e4m3fn *quanted_res,
|
||||
float *quanted_scale,
|
||||
ScaleT *quanted_scale,
|
||||
const int token_num,
|
||||
const int hidden_size,
|
||||
const int hidden_size_scale,
|
||||
@@ -38,10 +46,11 @@ __global__ void quant_per_token_per_block(
|
||||
AlignedVector<float, NUM_PER_THREADS> load_vec_float;
|
||||
AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
|
||||
for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
|
||||
const T *input_now = input + token_idx * hidden_size;
|
||||
const T *input_now = input + static_cast<int64_t>(token_idx) * hidden_size;
|
||||
phi::dtype::float8_e4m3fn *quanted_res_now =
|
||||
quanted_res + token_idx * hidden_size;
|
||||
float *quanted_scale_now = quanted_scale + token_idx * hidden_size_scale;
|
||||
quanted_res + static_cast<int64_t>(token_idx) * hidden_size;
|
||||
float *quanted_scale_now = reinterpret_cast<float *>(quanted_scale) +
|
||||
token_idx * hidden_size_scale;
|
||||
// deal a block per warp
|
||||
for (int iter = warp_id; iter < end_iter; iter += num_warp) {
|
||||
const int start_offset = iter * 128;
|
||||
@@ -83,11 +92,22 @@ __global__ void quant_per_token_per_block(
|
||||
}
|
||||
|
||||
float scale_to_store = max_value_thread / MAX_VALUE;
|
||||
|
||||
// quant
|
||||
if constexpr (UseUE8M0) {
|
||||
scale_to_store =
|
||||
exp2f(ceilf(log2f(fmaxf(scale_to_store, epsilon) + 5e-7f)));
|
||||
#pragma unroll
|
||||
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
|
||||
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
|
||||
load_vec_float[vid] * MAX_VALUE / max_value_thread);
|
||||
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
|
||||
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
|
||||
load_vec_float[vid] / scale_to_store);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
|
||||
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
|
||||
load_vec_float[vid] * MAX_VALUE / max_value_thread);
|
||||
}
|
||||
}
|
||||
// store
|
||||
if (is_valid_data)
|
||||
@@ -95,14 +115,26 @@ __global__ void quant_per_token_per_block(
|
||||
res_vec,
|
||||
quanted_res_now + start_offset + lane_id * NUM_PER_THREADS);
|
||||
if (lane_id == 0) {
|
||||
quanted_scale_now[iter] = scale_to_store;
|
||||
if constexpr (UseUE8M0) {
|
||||
int exp = (reinterpret_cast<int &>(scale_to_store) >> 23) & 0xFF;
|
||||
const int pack_idx = iter >> 2;
|
||||
const int byte_idx = iter & 3;
|
||||
const int pack_num = ceil_div(hidden_size_scale, 4);
|
||||
int32_t *scale_now = quanted_scale;
|
||||
const int base_idx = token_idx * pack_num + pack_idx;
|
||||
reinterpret_cast<uint8_t *>(&scale_now[base_idx])[byte_idx] =
|
||||
static_cast<uint8_t>(exp);
|
||||
} else {
|
||||
quanted_scale_now[iter] = scale_to_store;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
|
||||
const int block_size) {
|
||||
const int block_size,
|
||||
const bool use_ue8m0) {
|
||||
auto input_dim = input.dims();
|
||||
const int token_num = input_dim[0];
|
||||
const int hidden_size = input_dim[1];
|
||||
@@ -112,8 +144,7 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
|
||||
|
||||
auto quanted_x = GetEmptyTensor(
|
||||
{token_num, hidden_size}, paddle::DataType::FLOAT8_E4M3FN, input.place());
|
||||
auto quanted_scale = GetEmptyTensor(
|
||||
{token_num, hidden_size_scale}, paddle::DataType::FLOAT32, input.place());
|
||||
|
||||
const int gridx = min(132 * 8, token_num);
|
||||
const int blockx = min(1024, hidden_size / 128 * 32);
|
||||
|
||||
@@ -123,31 +154,70 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
|
||||
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
|
||||
}
|
||||
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::bfloat16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<float>(),
|
||||
token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::float16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<float>(),
|
||||
token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for PerTokenQuant");
|
||||
if (use_ue8m0) {
|
||||
auto quanted_scale =
|
||||
GetEmptyTensor({token_num, ceil_div(hidden_size_scale, 4)},
|
||||
paddle::DataType::INT32,
|
||||
input.place());
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
quant_per_token_per_block<paddle::bfloat16, int32_t, true>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::bfloat16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<int32_t>(),
|
||||
token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
quant_per_token_per_block<paddle::float16, int32_t, true>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::float16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<int32_t>(),
|
||||
token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for PerTokenQuant");
|
||||
}
|
||||
return {quanted_x, quanted_scale};
|
||||
} else {
|
||||
auto quanted_scale = GetEmptyTensor({token_num, hidden_size_scale},
|
||||
paddle::DataType::FLOAT32,
|
||||
input.place());
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
quant_per_token_per_block<paddle::bfloat16, float, false>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::bfloat16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<float>(),
|
||||
token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
quant_per_token_per_block<paddle::float16, float, false>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::float16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<float>(),
|
||||
token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for PerTokenQuant");
|
||||
}
|
||||
return {quanted_x, quanted_scale};
|
||||
}
|
||||
return {quanted_x, quanted_scale};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> PerTokenQuantInferShape(
|
||||
@@ -155,19 +225,26 @@ std::vector<std::vector<int64_t>> PerTokenQuantInferShape(
|
||||
const int token_num = input_shape[0];
|
||||
const int hidden_size = input_shape[1];
|
||||
const int hidden_size_scale = (hidden_size + block_size - 1) / block_size;
|
||||
if (GetSMVersion() >= 100) {
|
||||
return {{token_num, hidden_size},
|
||||
{token_num, ceil_div(hidden_size_scale, 4)}};
|
||||
}
|
||||
return {{token_num, hidden_size}, {token_num, hidden_size_scale}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> PerTokenQuantInferDtype(
|
||||
paddle::DataType input_dtype, const int block_size) {
|
||||
if (GetSMVersion() >= 100) {
|
||||
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32};
|
||||
}
|
||||
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename ScaleT, bool UseUE8M0>
|
||||
__global__ void quant_per_token_per_block_padding(
|
||||
const T *input,
|
||||
phi::dtype::float8_e4m3fn *quanted_res,
|
||||
float *quanted_scale,
|
||||
ScaleT *quanted_scale,
|
||||
const int token_num,
|
||||
const int padded_token_num,
|
||||
const int hidden_size,
|
||||
@@ -185,13 +262,11 @@ __global__ void quant_per_token_per_block_padding(
|
||||
AlignedVector<float, NUM_PER_THREADS> load_vec_float;
|
||||
AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
|
||||
for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
|
||||
const T *input_now = input + token_idx * hidden_size;
|
||||
const T *input_now = input + static_cast<int64_t>(token_idx) * hidden_size;
|
||||
phi::dtype::float8_e4m3fn *quanted_res_now =
|
||||
quanted_res + token_idx * hidden_size;
|
||||
quanted_res + static_cast<int64_t>(token_idx) * hidden_size;
|
||||
// deal a block per warp
|
||||
for (int iter = warp_id; iter < end_iter; iter += num_warp) {
|
||||
float *quanted_scale_now =
|
||||
quanted_scale + iter * padded_token_num + token_idx;
|
||||
const int start_offset = iter * 128;
|
||||
Load<T, NUM_PER_THREADS>(
|
||||
input_now + start_offset + lane_id * NUM_PER_THREADS, &load_vec);
|
||||
@@ -222,26 +297,58 @@ __global__ void quant_per_token_per_block_padding(
|
||||
}
|
||||
|
||||
float scale_to_store = max_value_thread / MAX_VALUE;
|
||||
|
||||
// quant
|
||||
if constexpr (UseUE8M0) {
|
||||
scale_to_store =
|
||||
exp2f(ceilf(log2f(fmaxf(scale_to_store, epsilon) + 5e-7f)));
|
||||
#pragma unroll
|
||||
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
|
||||
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
|
||||
load_vec_float[vid] * MAX_VALUE / max_value_thread);
|
||||
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
|
||||
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
|
||||
load_vec_float[vid] / scale_to_store);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
|
||||
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
|
||||
load_vec_float[vid] * MAX_VALUE / max_value_thread);
|
||||
}
|
||||
}
|
||||
// store
|
||||
Store<phi::dtype::float8_e4m3fn, NUM_PER_THREADS>(
|
||||
res_vec, quanted_res_now + start_offset + lane_id * NUM_PER_THREADS);
|
||||
if (lane_id == 0) {
|
||||
*quanted_scale_now = scale_to_store;
|
||||
if constexpr (UseUE8M0) {
|
||||
// exp
|
||||
int exp = (reinterpret_cast<int &>(scale_to_store) >> 23) & 0xFF;
|
||||
|
||||
const int pack_idx = iter >> 2;
|
||||
const int byte_idx = iter & 3;
|
||||
|
||||
// pack
|
||||
const int pack_num = align(hidden_size_scale, 4) >> 2;
|
||||
|
||||
// column-major base index
|
||||
int32_t *scale_now = quanted_scale;
|
||||
const int base_idx = token_idx + pack_idx * padded_token_num;
|
||||
|
||||
// ---------------- store exp ----------------
|
||||
reinterpret_cast<uint8_t *>(&scale_now[base_idx])[byte_idx] =
|
||||
static_cast<uint8_t>(exp);
|
||||
} else {
|
||||
float *scale_now =
|
||||
quanted_scale + iter * padded_token_num + token_idx;
|
||||
*scale_now = scale_to_store;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
|
||||
const int block_size) {
|
||||
const int block_size,
|
||||
const bool use_ue8m0) {
|
||||
using ScaleDtype = float;
|
||||
|
||||
auto input_dim = input.dims();
|
||||
const int token_num = input_dim[0];
|
||||
const int hidden_size = input_dim[1];
|
||||
@@ -256,13 +363,11 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
|
||||
|
||||
const int tma_alignment_bytes = 16;
|
||||
const int tma_alignment_elements = tma_alignment_bytes / sizeof(ScaleDtype);
|
||||
|
||||
const int padded_token_num =
|
||||
((token_num + tma_alignment_elements - 1) / tma_alignment_elements) *
|
||||
tma_alignment_elements;
|
||||
auto quanted_scale = GetEmptyTensor({padded_token_num, hidden_size_scale},
|
||||
{1, padded_token_num},
|
||||
paddle::DataType::FLOAT32,
|
||||
input.place());
|
||||
|
||||
const int gridx = min(132 * 8, token_num);
|
||||
const int blockx = min(1024, hidden_size / 128 * 32);
|
||||
|
||||
@@ -271,34 +376,76 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
|
||||
if (env_var) {
|
||||
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
|
||||
}
|
||||
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::bfloat16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<ScaleDtype>(),
|
||||
token_num,
|
||||
padded_token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::float16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<ScaleDtype>(),
|
||||
token_num,
|
||||
padded_token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for PerTokenQuant");
|
||||
if (use_ue8m0) {
|
||||
auto quanted_scale =
|
||||
GetEmptyTensor({padded_token_num, ceil_div(hidden_size_scale, 4)},
|
||||
{1, padded_token_num},
|
||||
paddle::DataType::INT32,
|
||||
input.place());
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
quant_per_token_per_block_padding<paddle::bfloat16, int32_t, true>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::bfloat16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<int32_t>(),
|
||||
token_num,
|
||||
padded_token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
quant_per_token_per_block_padding<paddle::float16, int32_t, true>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::float16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<int32_t>(),
|
||||
token_num,
|
||||
padded_token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for PerTokenQuant");
|
||||
}
|
||||
return {quanted_x, quanted_scale};
|
||||
} else {
|
||||
auto quanted_scale = GetEmptyTensor({padded_token_num, hidden_size_scale},
|
||||
{1, padded_token_num},
|
||||
paddle::DataType::FLOAT32,
|
||||
input.place());
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
quant_per_token_per_block_padding<paddle::bfloat16, float, false>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::bfloat16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<float>(),
|
||||
token_num,
|
||||
padded_token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
quant_per_token_per_block_padding<paddle::float16, float, false>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::float16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<float>(),
|
||||
token_num,
|
||||
padded_token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for PerTokenQuant");
|
||||
}
|
||||
return {quanted_x, quanted_scale};
|
||||
}
|
||||
return {quanted_x, quanted_scale};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> PerTokenQuantPaddingInferShape(
|
||||
@@ -314,19 +461,25 @@ std::vector<std::vector<int64_t>> PerTokenQuantPaddingInferShape(
|
||||
const int padded_token_num =
|
||||
((token_num + tma_alignment_elements - 1) / tma_alignment_elements) *
|
||||
tma_alignment_elements;
|
||||
|
||||
if (GetSMVersion() >= 100) {
|
||||
return {{token_num, hidden_size},
|
||||
{padded_token_num, ceil_div(hidden_size_scale, 4)}};
|
||||
}
|
||||
return {{token_num, hidden_size}, {padded_token_num, hidden_size_scale}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> PerTokenQuantPaddingInferDtype(
|
||||
paddle::DataType input_dtype) {
|
||||
if (GetSMVersion() >= 100) {
|
||||
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32};
|
||||
}
|
||||
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(per_token_quant)
|
||||
.Inputs({"input"})
|
||||
.Outputs({"output", "output_scale"})
|
||||
.Attrs({"block_size: int"})
|
||||
.Attrs({"block_size: int", "use_ue8m0: bool"})
|
||||
.SetKernelFn(PD_KERNEL(PerTokenQuant))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(PerTokenQuantInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(PerTokenQuantInferDtype));
|
||||
@@ -334,7 +487,7 @@ PD_BUILD_STATIC_OP(per_token_quant)
|
||||
PD_BUILD_STATIC_OP(per_token_quant_padding)
|
||||
.Inputs({"input"})
|
||||
.Outputs({"output", "output_scale"})
|
||||
.Attrs({"block_size: int"})
|
||||
.Attrs({"block_size: int", "use_ue8m0: bool"})
|
||||
.SetKernelFn(PD_KERNEL(PerTokenQuantPadding))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(PerTokenQuantPaddingInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(PerTokenQuantPaddingInferDtype));
|
||||
|
||||
Reference in New Issue
Block a user