mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Support EP prefill with num_worst_tokens (#6574)
* support num worst tokens * support num worst tokens * fix build error * support num worst tokens: fix errors * support num worst tokens: fix feild * support num worst tokens: delete requiements * replace permute and depermute op by pure cuda * replace permute and depermute op by pure cuda * fix ci * fix op * fix nan * fix code style --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -313,9 +313,11 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
|
||||
const int token_nums_this_rank_padded);
|
||||
|
||||
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
|
||||
const int block_size);
|
||||
const int block_size,
|
||||
const bool use_ue8m0);
|
||||
std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
|
||||
const int block_size);
|
||||
const int block_size,
|
||||
const bool use_ue8m0);
|
||||
|
||||
std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
|
||||
paddle::Tensor& input,
|
||||
@@ -1175,6 +1177,19 @@ std::vector<paddle::Tensor> get_attn_mask_q(
|
||||
const paddle::optional<paddle::Tensor>& attn_mask_kv,
|
||||
const int kv_token_num);
|
||||
|
||||
std::vector<paddle::Tensor> PrefillPermuteToMaskedGemm(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& scale,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const int num_local_experts,
|
||||
const int max_token_num);
|
||||
|
||||
std::vector<paddle::Tensor> DepermutePrefillCombine(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& indice_map,
|
||||
const paddle::Tensor& topk_weights,
|
||||
const int num_worst_tokens);
|
||||
|
||||
void RadixTopkRaggedTransform(
|
||||
paddle::Tensor& input,
|
||||
paddle::Tensor& output_indices,
|
||||
@@ -1367,12 +1382,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
&PerTokenQuant,
|
||||
py::arg("input"),
|
||||
py::arg("block_size"),
|
||||
py::arg("use_ue8m0"),
|
||||
"per token per block quant");
|
||||
|
||||
m.def("per_token_quant_padding",
|
||||
&PerTokenQuantPadding,
|
||||
py::arg("input"),
|
||||
py::arg("block_size"),
|
||||
py::arg("use_ue8m0"),
|
||||
"per token per block quant and padding transpose scale");
|
||||
|
||||
m.def("fused_mask_swiglu_fp8_quant",
|
||||
@@ -1826,6 +1843,22 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("custom_numpy_to_tensor",
|
||||
&CustomNumpyToTensor,
|
||||
"custom_numpy_to_tensor function");
|
||||
m.def("prefill_permute_to_masked_gemm",
|
||||
&PrefillPermuteToMaskedGemm,
|
||||
py::arg("x"),
|
||||
py::arg("scale"),
|
||||
py::arg("topk_ids"),
|
||||
py::arg("num_local_experts"),
|
||||
py::arg("max_token_num"),
|
||||
"Prefill permute to masked GEMM for MoE");
|
||||
|
||||
m.def("depermute_prefill_combine",
|
||||
&DepermutePrefillCombine,
|
||||
py::arg("x"),
|
||||
py::arg("indice_map"),
|
||||
py::arg("topk_weights"),
|
||||
py::arg("num_worst_tokens"),
|
||||
"Depermute and combine expert outputs for MoE prefill");
|
||||
|
||||
m.def("radix_topk_ragged_transform",
|
||||
&RadixTopkRaggedTransform,
|
||||
|
||||
@@ -20,10 +20,18 @@ __host__ __device__ __forceinline__ int ceil_div(int x, int y) {
|
||||
return (x + y - 1) / y;
|
||||
}
|
||||
|
||||
__host__ __device__ __forceinline__ int64_t ceil_div(int64_t x, int64_t y) {
|
||||
return (x + y - 1) / y;
|
||||
}
|
||||
|
||||
__host__ __device__ __forceinline__ int align(int x, int y) {
|
||||
return ceil_div(x, y) * y;
|
||||
}
|
||||
|
||||
__host__ __device__ __forceinline__ int64_t align(int64_t x, int64_t y) {
|
||||
return ceil_div(x, y) * y;
|
||||
}
|
||||
|
||||
#ifndef BOOL_SWITCH
|
||||
#define BOOL_SWITCH(cond, name, ...) \
|
||||
if (cond) { \
|
||||
@@ -41,10 +49,10 @@ __global__ void fused_swiglu_fp8_quant_kernel(
|
||||
const index_t* __restrict__ token_nums_per_expert,
|
||||
phi::dtype::float8_e4m3fn* __restrict__ out_fp8,
|
||||
ScaleT* __restrict__ out_scale,
|
||||
int group_num,
|
||||
int group_size,
|
||||
int hidden_size,
|
||||
int hidden_size_scale,
|
||||
int64_t group_num,
|
||||
int64_t group_size,
|
||||
int64_t hidden_size,
|
||||
int64_t hidden_size_scale,
|
||||
bool use_finegrained_range) {
|
||||
constexpr int BLOCK = 128;
|
||||
|
||||
@@ -53,7 +61,7 @@ __global__ void fused_swiglu_fp8_quant_kernel(
|
||||
int warp = tid >> 5;
|
||||
int num_warps = blockDim.x >> 5;
|
||||
|
||||
int block_id = static_cast<int64_t>(blockIdx.x);
|
||||
int64_t block_id = static_cast<int64_t>(blockIdx.x);
|
||||
|
||||
using VecBF16 = AlignedVector<T, 4>;
|
||||
VecBF16 x1_vec, x2_vec;
|
||||
@@ -62,13 +70,13 @@ __global__ void fused_swiglu_fp8_quant_kernel(
|
||||
|
||||
while (true) {
|
||||
// ================= token mapping =================
|
||||
int expert = -1;
|
||||
int token_in_expert = -1;
|
||||
int64_t expert = -1;
|
||||
int64_t token_in_expert = -1;
|
||||
|
||||
if (lane == 0) {
|
||||
int cumsum = 0;
|
||||
for (int i = 0; i < group_num; ++i) {
|
||||
int cnt = token_nums_per_expert[i];
|
||||
int64_t cumsum = 0;
|
||||
for (int64_t i = 0; i < group_num; ++i) {
|
||||
int64_t cnt = static_cast<int64_t>(token_nums_per_expert[i]);
|
||||
if (block_id >= cumsum && block_id < cumsum + cnt) {
|
||||
expert = i;
|
||||
token_in_expert = block_id - cumsum;
|
||||
@@ -84,17 +92,17 @@ __global__ void fused_swiglu_fp8_quant_kernel(
|
||||
if (expert < 0 || token_in_expert >= group_size) break;
|
||||
|
||||
// ================= base pointers =================
|
||||
int token = expert * group_size + token_in_expert;
|
||||
int64_t token = expert * group_size + token_in_expert;
|
||||
|
||||
const T* in = input + token * hidden_size * 2;
|
||||
|
||||
auto* out = out_fp8 + token * hidden_size;
|
||||
|
||||
int num_iters = hidden_size / BLOCK;
|
||||
int64_t num_iters = hidden_size / BLOCK;
|
||||
|
||||
// ================= main loop =================
|
||||
for (int iter = warp; iter < num_iters; iter += num_warps) {
|
||||
int base = iter * BLOCK + lane * 4;
|
||||
for (int64_t iter = warp; iter < num_iters; iter += num_warps) {
|
||||
int64_t base = iter * BLOCK + lane * 4;
|
||||
|
||||
// vec load
|
||||
Load(in + base, &x1_vec);
|
||||
@@ -141,20 +149,20 @@ __global__ void fused_swiglu_fp8_quant_kernel(
|
||||
const int exp = (__float_as_int(scale) >> 23) & 0xFF;
|
||||
|
||||
// 2. pack information
|
||||
const int pack_idx = iter >> 2; // iter / 4
|
||||
const int byte_idx = iter & 3; // iter % 4
|
||||
const int64_t pack_idx = iter >> 2; // iter / 4
|
||||
const int64_t byte_idx = iter & 3; // iter % 4
|
||||
|
||||
// 3. layout parameters
|
||||
const int pack_num = ceil_div(hidden_size_scale, 4);
|
||||
const int token_stride = align(group_size, 4);
|
||||
const int64_t pack_num = ceil_div(hidden_size_scale, (int64_t)4);
|
||||
const int64_t token_stride = align(group_size, (int64_t)4);
|
||||
|
||||
// 4. base pointer (int32 pack)
|
||||
auto* scale_pack = reinterpret_cast<int32_t*>(out_scale);
|
||||
|
||||
// 5. column-major offset:
|
||||
// [expert][pack][token]
|
||||
const int base_idx = expert * pack_num * token_stride +
|
||||
pack_idx * token_stride + token_in_expert;
|
||||
const int64_t base_idx = expert * pack_num * token_stride +
|
||||
pack_idx * token_stride + token_in_expert;
|
||||
// 6. write one byte into pack
|
||||
reinterpret_cast<uint8_t*>(&scale_pack[base_idx])[byte_idx] =
|
||||
static_cast<uint8_t>(exp);
|
||||
@@ -184,11 +192,11 @@ std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
|
||||
const int block_size,
|
||||
const bool use_ue8m0) {
|
||||
auto dim = input.dims();
|
||||
const int group_num = token_nums_per_expert.shape()[0];
|
||||
const int group_size = dim[1];
|
||||
const int hidden_size = dim[2] / 2;
|
||||
const int hidden_size_scale = hidden_size / block_size;
|
||||
const int token_num = group_num * group_size;
|
||||
const int64_t group_num = token_nums_per_expert.shape()[0];
|
||||
const int64_t group_size = dim[1];
|
||||
const int64_t hidden_size = dim[2] / 2;
|
||||
const int64_t hidden_size_scale = hidden_size / block_size;
|
||||
const int64_t token_num = group_num * group_size;
|
||||
|
||||
auto out_fp8 = GetEmptyTensor({group_num, group_size, hidden_size},
|
||||
paddle::DataType::FLOAT8_E4M3FN,
|
||||
@@ -200,21 +208,22 @@ std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
|
||||
paddle::DataType::FLOAT32,
|
||||
input.place());
|
||||
if (use_ue8m0) {
|
||||
int hidden_size_scale_pack = ceil_div(hidden_size_scale, 4);
|
||||
out_scale = GetEmptyTensor({group_num, group_size, hidden_size_scale_pack},
|
||||
{hidden_size_scale_pack * align(group_size, 4),
|
||||
1,
|
||||
align(group_size, 4)},
|
||||
paddle::DataType::INT32,
|
||||
input.place());
|
||||
int64_t hidden_size_scale_pack = ceil_div(hidden_size_scale, (int64_t)4);
|
||||
int64_t group_size_aligned = align(group_size, (int64_t)4);
|
||||
out_scale = GetEmptyTensor(
|
||||
{group_num, group_size, hidden_size_scale_pack},
|
||||
{hidden_size_scale_pack * group_size_aligned, 1, group_size_aligned},
|
||||
paddle::DataType::INT32,
|
||||
input.place());
|
||||
}
|
||||
|
||||
int sm_count = 0;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, 0);
|
||||
|
||||
constexpr int BLOCKS_PER_SM = 2;
|
||||
int gridx = std::min(sm_count * BLOCKS_PER_SM, token_num);
|
||||
int blockx = std::min(1024, hidden_size / 128 * 32);
|
||||
int gridx =
|
||||
std::min(static_cast<int64_t>(sm_count * BLOCKS_PER_SM), token_num);
|
||||
int blockx = std::min(1024L, hidden_size / 128 * 32);
|
||||
|
||||
bool use_finegrained_range = false;
|
||||
if (auto* env = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE"))
|
||||
|
||||
@@ -243,6 +243,13 @@ class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
|
||||
typedef __nv_fp8_e4m3 DataType;
|
||||
typedef paddle::float8_e4m3fn data_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
class PDTraits<paddle::DataType::INT32> {
|
||||
public:
|
||||
typedef int32_t DataType;
|
||||
typedef int32_t data_t;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T, int Size>
|
||||
|
||||
@@ -0,0 +1,218 @@
|
||||
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
constexpr int DEPERMUTE_BLOCK_THREADS = 512;
|
||||
|
||||
// Depermute and combine expert outputs back to token-major layout.
|
||||
//
|
||||
// For each output token, this kernel:
|
||||
// 1. Loads indice_map and topk_weights from shared memory
|
||||
// 2. For each valid expert slot (indice >= 0), reads the expert output row,
|
||||
// scales by topk_weight, and accumulates in float32
|
||||
// 3. Writes the combined result back in the original dtype
|
||||
//
|
||||
// x: [num_experts, max_tokens_per_expert, hidden] - expert outputs
|
||||
// indice_map: [num_worst_tokens, topk] int32 - flat indices (expert_idx * M +
|
||||
// offset) topk_weights: [num_worst_tokens, topk] float32 - combination weights
|
||||
// depermuted_x: [num_worst_tokens, hidden] - output (same dtype as x)
|
||||
|
||||
// Depermute and weighted-combine expert outputs back to token-major order.
//
// x is indexed as [num_experts, max_tokens_per_expert, hidden] via the flat
// indices in indice_map (expert_idx * max_tokens_per_expert + offset); a
// negative index marks an unused/dropped slot. Accumulation is done in
// float32 regardless of T to avoid precision loss, then cast back to T.
// One block handles one token at a time (grid-stride over tokens).
template <typename T, int VecSize, int TOP_K>
__global__ void DepermutePrefillCombineKernel(
    T* __restrict__ depermuted_x,
    const T* __restrict__ x,
    const int32_t* __restrict__ indice_map,
    const float* __restrict__ topk_weights,
    const int num_worst_tokens,
    const int hidden,
    const int max_tokens_per_expert) {
  __shared__ int32_t smem_indices[TOP_K];
  __shared__ float smem_weights[TOP_K];

  const int tidx = threadIdx.x;
  const int num_vecs = hidden / VecSize;

  for (int token_idx = blockIdx.x; token_idx < num_worst_tokens;
       token_idx += gridDim.x) {
    // Threads [0, TOP_K) stage indice_map; threads [TOP_K, 2*TOP_K) stage
    // topk_weights. (The two ranges load independent arrays in parallel.)
    if (tidx < TOP_K) {
      smem_indices[tidx] = indice_map[token_idx * TOP_K + tidx];
    }
    if (tidx >= TOP_K && tidx < 2 * TOP_K) {
      int k = tidx - TOP_K;
      smem_weights[k] = topk_weights[token_idx * TOP_K + k];
    }
    __syncthreads();

    // Check if any expert slot is valid; a token with all slots < 0
    // contributes nothing and its output row is left untouched.
    bool need_store = false;
#pragma unroll
    for (int k = 0; k < TOP_K; k++) {
      if (smem_indices[k] >= 0) {
        need_store = true;
        break;
      }
    }

    if (need_store) {
      // Each thread processes a subset of hidden vectors
      for (int v = tidx; v < num_vecs; v += DEPERMUTE_BLOCK_THREADS) {
        // Initialize accumulator in float32
        float acc[VecSize];
#pragma unroll
        for (int i = 0; i < VecSize; i++) {
          acc[i] = 0.0f;
        }

        // Accumulate weighted contributions from each expert
        for (int k = 0; k < TOP_K; k++) {
          int32_t indice = smem_indices[k];
          if (indice >= 0) {
            float weight = smem_weights[k];
            // Decode the flat index back into (expert row, offset in expert).
            int64_t expert_idx =
                static_cast<int64_t>(indice) / max_tokens_per_expert;
            int64_t offset =
                static_cast<int64_t>(indice) % max_tokens_per_expert;

            const T* src = x + expert_idx * max_tokens_per_expert * hidden +
                           offset * hidden;

            AlignedVector<T, VecSize> vec;
            Load<T, VecSize>(src + v * VecSize, &vec);

#pragma unroll
            for (int i = 0; i < VecSize; i++) {
              acc[i] += static_cast<float>(vec[i]) * weight;
            }
          }
        }

        // Cast back and store
        AlignedVector<T, VecSize> out_vec;
#pragma unroll
        for (int i = 0; i < VecSize; i++) {
          out_vec[i] = static_cast<T>(acc[i]);
        }
        Store<T, VecSize>(out_vec,
                          depermuted_x +
                              static_cast<int64_t>(token_idx) * hidden +
                              v * VecSize);
      }
    }

    // Barrier before the next iteration overwrites smem_indices/smem_weights.
    __syncthreads();
  }
}
|
||||
|
||||
// Host-side launcher for DepermutePrefillCombineKernel, instantiated for a
// concrete (dtype D, TOP_K) pair. Allocates the [num_worst_tokens, hidden]
// output and launches with up to 2 blocks per SM (capped by the token count).
//
// NOTE(review): num_worst_tokens may exceed the number of real tokens; rows
// beyond the valid indices in indice_map are left uninitialized by the kernel
// (output comes from GetEmptyTensor) — callers presumably mask them. Verify.
template <paddle::DataType D, int TOP_K>
std::vector<paddle::Tensor> DepermutePrefillCombineDispatch(
    const paddle::Tensor& x,
    const paddle::Tensor& indice_map,
    const paddle::Tensor& topk_weights,
    const int num_worst_tokens) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  // x is [num_experts, max_tokens_per_expert, hidden].
  const int hidden = x.shape()[2];
  const int max_tokens_per_expert = x.shape()[1];

  auto place = x.place();
  auto stream = x.stream();

  auto depermuted_x =
      GetEmptyTensor({num_worst_tokens, hidden}, x.dtype(), place);

  // 16-byte vectorized loads/stores: VecSize elements of DataType_ per access.
  constexpr int VecSize = 16 / sizeof(DataType_);

  int dev;
  cudaGetDevice(&dev);
  int sm_count;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
  int num_blocks = min(sm_count * 2, num_worst_tokens);

  DepermutePrefillCombineKernel<DataType_, VecSize, TOP_K>
      <<<num_blocks, DEPERMUTE_BLOCK_THREADS, 0, stream>>>(
          reinterpret_cast<DataType_*>(depermuted_x.data<data_t>()),
          reinterpret_cast<const DataType_*>(x.data<data_t>()),
          indice_map.data<int32_t>(),
          topk_weights.data<float>(),
          num_worst_tokens,
          hidden,
          max_tokens_per_expert);

  return {depermuted_x};
}
|
||||
|
||||
std::vector<paddle::Tensor> DepermutePrefillCombine(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& indice_map,
|
||||
const paddle::Tensor& topk_weights,
|
||||
const int num_worst_tokens) {
|
||||
const int topk = indice_map.shape()[1];
|
||||
|
||||
#define DISPATCH_TOPK(DTYPE, TOPK_VAL) \
|
||||
case TOPK_VAL: \
|
||||
return DepermutePrefillCombineDispatch<DTYPE, TOPK_VAL>( \
|
||||
x, indice_map, topk_weights, num_worst_tokens);
|
||||
|
||||
switch (x.dtype()) {
|
||||
case paddle::DataType::FLOAT8_E4M3FN: {
|
||||
switch (topk) {
|
||||
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 4)
|
||||
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 8)
|
||||
default:
|
||||
PD_THROW("Unsupported topk value, must be 4 or 8");
|
||||
}
|
||||
}
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
switch (topk) {
|
||||
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 4)
|
||||
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 8)
|
||||
default:
|
||||
PD_THROW("Unsupported topk value, must be 4 or 8");
|
||||
}
|
||||
}
|
||||
default:
|
||||
PD_THROW("Unsupported dtype, must be float8_e4m3fn or bfloat16");
|
||||
}
|
||||
|
||||
#undef DISPATCH_TOPK
|
||||
}
|
||||
|
||||
// Static-graph shape inference: the combined output has one row per
// worst-case token and keeps the hidden dimension of x ([E, M, hidden]).
std::vector<std::vector<int64_t>> DepermutePrefillCombineInferShape(
    const std::vector<int64_t>& x_shape,
    const std::vector<int64_t>& indice_map_shape,
    const std::vector<int64_t>& topk_weights_shape,
    const int num_worst_tokens) {
  const int64_t hidden_dim = x_shape[2];
  std::vector<int64_t> out_shape{static_cast<int64_t>(num_worst_tokens),
                                 hidden_dim};
  return {out_shape};
}
|
||||
|
||||
// Static-graph dtype inference: the combined output keeps x's dtype
// (the kernel casts the float32 accumulator back to T on store).
std::vector<paddle::DataType> DepermutePrefillCombineInferDtype(
    const paddle::DataType& x_dtype,
    const paddle::DataType& indice_map_dtype,
    const paddle::DataType& topk_weights_dtype,
    const int num_worst_tokens) {
  return {x_dtype};
}
|
||||
|
||||
// Register the custom op so it is callable as
// fastdeploy_ops.depermute_prefill_combine from Python.
PD_BUILD_STATIC_OP(depermute_prefill_combine)
    .Inputs({"x", "indice_map", "topk_weights"})
    .Outputs({"depermuted_x"})
    .Attrs({"num_worst_tokens: int"})
    .SetKernelFn(PD_KERNEL(DepermutePrefillCombine))
    .SetInferShapeFn(PD_INFER_SHAPE(DepermutePrefillCombineInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(DepermutePrefillCombineInferDtype));
|
||||
@@ -0,0 +1,283 @@
|
||||
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
constexpr int BLOCK_THREADS = 512;
|
||||
|
||||
template <typename T, typename ScaleT, int VecSize, int TOP_K>
|
||||
__global__ void PrefillPermuteToMaskedGemmKernel(
|
||||
T* __restrict__ permute_x,
|
||||
ScaleT* __restrict__ permute_scale,
|
||||
int32_t* __restrict__ permuted_indice_map,
|
||||
int32_t* __restrict__ token_nums_per_expert,
|
||||
const T* __restrict__ x,
|
||||
const ScaleT* __restrict__ scale,
|
||||
const int64_t* __restrict__ topk_ids,
|
||||
const int num_tokens,
|
||||
const int hidden,
|
||||
const int hidden_scale,
|
||||
const int max_tokens_per_expert) {
|
||||
__shared__ int32_t smem_offset;
|
||||
__shared__ int64_t smem_topk_ids[TOP_K];
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int x_num_vecs = hidden / VecSize;
|
||||
constexpr int ScaleVecSize = 16 / sizeof(float); // 4
|
||||
const int scale_num_vecs = hidden_scale / ScaleVecSize;
|
||||
|
||||
for (int token_idx = blockIdx.x; token_idx < num_tokens;
|
||||
token_idx += gridDim.x) {
|
||||
if (tidx < TOP_K) {
|
||||
smem_topk_ids[tidx] = topk_ids[token_idx * TOP_K + tidx];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int slot = 0; slot < TOP_K; slot++) {
|
||||
int64_t expert_idx = smem_topk_ids[slot];
|
||||
if (expert_idx != -1) {
|
||||
if (tidx == 0) {
|
||||
smem_offset = atomicAdd(&token_nums_per_expert[expert_idx], 1);
|
||||
permuted_indice_map[token_idx * TOP_K + slot] = static_cast<int32_t>(
|
||||
expert_idx * max_tokens_per_expert + smem_offset);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int offset = smem_offset;
|
||||
|
||||
// Vectorized copy of x[token_idx, :] -> permute_x[expert_idx, offset,
|
||||
// :]
|
||||
const T* src_x = x + static_cast<int64_t>(token_idx) * hidden;
|
||||
T* dst_x =
|
||||
permute_x +
|
||||
static_cast<int64_t>(expert_idx) * max_tokens_per_expert * hidden +
|
||||
static_cast<int64_t>(offset) * hidden;
|
||||
|
||||
AlignedVector<T, VecSize> vec_x;
|
||||
for (int v = tidx; v < x_num_vecs; v += BLOCK_THREADS) {
|
||||
Load<T, VecSize>(src_x + v * VecSize, &vec_x);
|
||||
Store<T, VecSize>(vec_x, dst_x + v * VecSize);
|
||||
}
|
||||
|
||||
// Copy scale[token_idx, :] -> permute_scale with transposed layout
|
||||
// Physical layout is [E, S, M], accessed as [E, M, S] via strides [S*M,
|
||||
// 1, M] So permute_scale[expert_idx, offset, s] -> physical addr:
|
||||
// expert_idx*(S*M) + offset + s*M
|
||||
const ScaleT* src_scale =
|
||||
scale + static_cast<int64_t>(token_idx) * hidden_scale;
|
||||
ScaleT* dst_scale_base = permute_scale +
|
||||
static_cast<int64_t>(expert_idx) *
|
||||
hidden_scale * max_tokens_per_expert +
|
||||
offset;
|
||||
|
||||
for (int s = tidx; s < hidden_scale; s += BLOCK_THREADS) {
|
||||
dst_scale_base[static_cast<int64_t>(s) * max_tokens_per_expert] =
|
||||
src_scale[s];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Host-side launcher for PrefillPermuteToMaskedGemmKernel, instantiated for a
// concrete (x dtype D, scale dtype ScaleD, TOP_K). Allocates the outputs,
// zero-initializes the per-expert counters, pre-fills the indice map with -1,
// then launches the kernel on x's stream.
template <paddle::DataType D, paddle::DataType ScaleD, int TOP_K>
std::vector<paddle::Tensor> PrefillPermuteToMaskedGemmDispatch(
    const paddle::Tensor& x,
    const paddle::Tensor& scale,
    const paddle::Tensor& topk_ids,
    const int num_local_experts,
    const int max_token_num) {
  typedef PDTraits<D> traits_;
  typedef PDTraits<ScaleD> scale_traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;
  typedef typename scale_traits_::DataType ScaleDataType_;
  typedef typename scale_traits_::data_t scale_data_t;

  // x: [num_tokens, hidden], scale: [num_tokens, hidden_scale],
  // topk_ids: [num_tokens, topk].
  const int num_tokens = x.shape()[0];
  const int hidden = x.shape()[1];
  const int hidden_scale = scale.shape()[1];
  const int topk = topk_ids.shape()[1];

  auto place = x.place();
  auto stream = x.stream();

  auto permute_x = GetEmptyTensor(
      {num_local_experts, max_token_num, hidden}, x.dtype(), place);

  // Allocate permute_scale with transposed physical layout [E, S, M]
  // but logical shape [E, M, S] with strides [S*M, 1, M]
  // This matches the CuteDSL version's paddle.empty([E, S, M]).transpose((0, 2,
  // 1))
  auto permute_scale =
      GetEmptyTensor({num_local_experts, max_token_num, hidden_scale},
                     {static_cast<int64_t>(hidden_scale) * max_token_num,
                      1,
                      static_cast<int64_t>(max_token_num)},
                     ScaleD,
                     place);

  auto permuted_indice_map =
      GetEmptyTensor({num_tokens, topk}, paddle::DataType::INT32, place);
  auto token_nums_per_expert =
      GetEmptyTensor({num_local_experts, 1}, paddle::DataType::INT32, place);

  // Counters must start at zero: the kernel atomically increments them.
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaMemsetAsync(token_nums_per_expert.data<int32_t>(),
                      0,
                      num_local_experts * sizeof(int32_t),
                      stream));
  // memset 0xFF for int32 produces -1
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaMemsetAsync(permuted_indice_map.data<int32_t>(),
                      0xFF,
                      num_tokens * topk * sizeof(int32_t),
                      stream));

  // 16-byte vectorized loads/stores: VecSize elements of DataType_ per access.
  constexpr int VecSize = 16 / sizeof(DataType_);

  int dev;
  cudaGetDevice(&dev);
  int sm_count;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
  int num_blocks = sm_count * 2;

  PrefillPermuteToMaskedGemmKernel<DataType_, ScaleDataType_, VecSize, TOP_K>
      <<<num_blocks, BLOCK_THREADS, 0, stream>>>(
          reinterpret_cast<DataType_*>(permute_x.data<data_t>()),
          reinterpret_cast<ScaleDataType_*>(
              permute_scale.template data<scale_data_t>()),
          permuted_indice_map.data<int32_t>(),
          token_nums_per_expert.data<int32_t>(),
          reinterpret_cast<const DataType_*>(x.data<data_t>()),
          reinterpret_cast<const ScaleDataType_*>(
              scale.template data<scale_data_t>()),
          topk_ids.data<int64_t>(),
          num_tokens,
          hidden,
          hidden_scale,
          max_token_num);

  return {permute_x, permute_scale, permuted_indice_map, token_nums_per_expert};
}
|
||||
|
||||
std::vector<paddle::Tensor> PrefillPermuteToMaskedGemm(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& scale,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const int num_local_experts,
|
||||
const int max_token_num) {
|
||||
const int topk = topk_ids.shape()[1];
|
||||
|
||||
#define DISPATCH_TOPK(DTYPE, SCALE_DTYPE, TOPK_VAL) \
|
||||
case TOPK_VAL: \
|
||||
return PrefillPermuteToMaskedGemmDispatch<DTYPE, SCALE_DTYPE, TOPK_VAL>( \
|
||||
x, scale, topk_ids, num_local_experts, max_token_num);
|
||||
|
||||
switch (x.dtype()) {
|
||||
case paddle::DataType::FLOAT8_E4M3FN: {
|
||||
switch (scale.dtype()) {
|
||||
case paddle::DataType::FLOAT32: {
|
||||
switch (topk) {
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32, 4)
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32, 8)
|
||||
default:
|
||||
PD_THROW("Unsupported topk value, must be 4 or 8");
|
||||
}
|
||||
}
|
||||
case paddle::DataType::INT32: {
|
||||
switch (topk) {
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32, 4)
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32, 8)
|
||||
default:
|
||||
PD_THROW("Unsupported topk value, must be 4 or 8");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
switch (scale.dtype()) {
|
||||
case paddle::DataType::FLOAT32: {
|
||||
switch (topk) {
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 4)
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 8)
|
||||
default:
|
||||
PD_THROW("Unsupported topk value, must be 4 or 8");
|
||||
}
|
||||
}
|
||||
case paddle::DataType::INT32: {
|
||||
switch (topk) {
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::BFLOAT16, paddle::DataType::INT32, 4)
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::BFLOAT16, paddle::DataType::INT32, 8)
|
||||
default:
|
||||
PD_THROW("Unsupported topk value, must be 4 or 8");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
PD_THROW("Unsupported dtype, must be float8_e4m3fn or bfloat16");
|
||||
}
|
||||
|
||||
#undef DISPATCH_TOPK
|
||||
}
|
||||
|
||||
// Static-graph shape inference for prefill_permute_to_masked_gemm.
// Outputs, in order: permute_x [E, M, hidden], permute_scale
// [E, M, hidden_scale], permuted_indice_map [num_tokens, topk],
// token_nums_per_expert [E, 1], where E = num_local_experts and
// M = max_token_num.
std::vector<std::vector<int64_t>> PrefillPermuteToMaskedGemmInferShape(
    const std::vector<int64_t>& x_shape,
    const std::vector<int64_t>& scale_shape,
    const std::vector<int64_t>& topk_ids_shape,
    const int num_local_experts,
    const int max_token_num) {
  const int64_t experts = num_local_experts;
  const int64_t capacity = max_token_num;

  std::vector<std::vector<int64_t>> out_shapes;
  out_shapes.push_back({experts, capacity, x_shape[1]});
  out_shapes.push_back({experts, capacity, scale_shape[1]});
  out_shapes.push_back({x_shape[0], topk_ids_shape[1]});
  out_shapes.push_back({experts, 1});
  return out_shapes;
}
|
||||
|
||||
// Static-graph dtype inference: permute_x keeps x's dtype, permute_scale
// keeps scale's dtype, and both the indice map and the per-expert counters
// are int32 (matching the kernel's buffers).
std::vector<paddle::DataType> PrefillPermuteToMaskedGemmInferDtype(
    const paddle::DataType& x_dtype,
    const paddle::DataType& scale_dtype,
    const paddle::DataType& topk_ids_dtype,
    const int num_local_experts,
    const int max_token_num) {
  return {
      x_dtype, scale_dtype, paddle::DataType::INT32, paddle::DataType::INT32};
}
|
||||
|
||||
// Register the custom op so it is callable as
// fastdeploy_ops.prefill_permute_to_masked_gemm from Python.
PD_BUILD_STATIC_OP(prefill_permute_to_masked_gemm)
    .Inputs({"x", "scale", "topk_ids"})
    .Outputs({"permute_x",
              "permute_scale",
              "permuted_indice_map",
              "token_nums_per_expert"})
    .Attrs({"num_local_experts: int", "max_token_num: int"})
    .SetKernelFn(PD_KERNEL(PrefillPermuteToMaskedGemm))
    .SetInferShapeFn(PD_INFER_SHAPE(PrefillPermuteToMaskedGemmInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(PrefillPermuteToMaskedGemmInferDtype));
|
||||
@@ -16,11 +16,19 @@
|
||||
|
||||
constexpr float epsilon = 1e-10;
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ __forceinline__ int ceil_div(int x, int y) {
|
||||
return (x + y - 1) / y;
|
||||
}
|
||||
|
||||
__host__ __device__ __forceinline__ int align(int x, int y) {
|
||||
return ceil_div(x, y) * y;
|
||||
}
|
||||
|
||||
template <typename T, typename ScaleT, bool UseUE8M0>
|
||||
__global__ void quant_per_token_per_block(
|
||||
const T *input,
|
||||
phi::dtype::float8_e4m3fn *quanted_res,
|
||||
float *quanted_scale,
|
||||
ScaleT *quanted_scale,
|
||||
const int token_num,
|
||||
const int hidden_size,
|
||||
const int hidden_size_scale,
|
||||
@@ -38,10 +46,11 @@ __global__ void quant_per_token_per_block(
|
||||
AlignedVector<float, NUM_PER_THREADS> load_vec_float;
|
||||
AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
|
||||
for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
|
||||
const T *input_now = input + token_idx * hidden_size;
|
||||
const T *input_now = input + static_cast<int64_t>(token_idx) * hidden_size;
|
||||
phi::dtype::float8_e4m3fn *quanted_res_now =
|
||||
quanted_res + token_idx * hidden_size;
|
||||
float *quanted_scale_now = quanted_scale + token_idx * hidden_size_scale;
|
||||
quanted_res + static_cast<int64_t>(token_idx) * hidden_size;
|
||||
float *quanted_scale_now = reinterpret_cast<float *>(quanted_scale) +
|
||||
token_idx * hidden_size_scale;
|
||||
// deal a block per warp
|
||||
for (int iter = warp_id; iter < end_iter; iter += num_warp) {
|
||||
const int start_offset = iter * 128;
|
||||
@@ -83,11 +92,22 @@ __global__ void quant_per_token_per_block(
|
||||
}
|
||||
|
||||
float scale_to_store = max_value_thread / MAX_VALUE;
|
||||
|
||||
// quant
|
||||
if constexpr (UseUE8M0) {
|
||||
scale_to_store =
|
||||
exp2f(ceilf(log2f(fmaxf(scale_to_store, epsilon) + 5e-7f)));
|
||||
#pragma unroll
|
||||
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
|
||||
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
|
||||
load_vec_float[vid] * MAX_VALUE / max_value_thread);
|
||||
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
|
||||
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
|
||||
load_vec_float[vid] / scale_to_store);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
|
||||
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
|
||||
load_vec_float[vid] * MAX_VALUE / max_value_thread);
|
||||
}
|
||||
}
|
||||
// store
|
||||
if (is_valid_data)
|
||||
@@ -95,14 +115,26 @@ __global__ void quant_per_token_per_block(
|
||||
res_vec,
|
||||
quanted_res_now + start_offset + lane_id * NUM_PER_THREADS);
|
||||
if (lane_id == 0) {
|
||||
quanted_scale_now[iter] = scale_to_store;
|
||||
if constexpr (UseUE8M0) {
|
||||
int exp = (reinterpret_cast<int &>(scale_to_store) >> 23) & 0xFF;
|
||||
const int pack_idx = iter >> 2;
|
||||
const int byte_idx = iter & 3;
|
||||
const int pack_num = ceil_div(hidden_size_scale, 4);
|
||||
int32_t *scale_now = quanted_scale;
|
||||
const int base_idx = token_idx * pack_num + pack_idx;
|
||||
reinterpret_cast<uint8_t *>(&scale_now[base_idx])[byte_idx] =
|
||||
static_cast<uint8_t>(exp);
|
||||
} else {
|
||||
quanted_scale_now[iter] = scale_to_store;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
|
||||
const int block_size) {
|
||||
const int block_size,
|
||||
const bool use_ue8m0) {
|
||||
auto input_dim = input.dims();
|
||||
const int token_num = input_dim[0];
|
||||
const int hidden_size = input_dim[1];
|
||||
@@ -112,8 +144,7 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
|
||||
|
||||
auto quanted_x = GetEmptyTensor(
|
||||
{token_num, hidden_size}, paddle::DataType::FLOAT8_E4M3FN, input.place());
|
||||
auto quanted_scale = GetEmptyTensor(
|
||||
{token_num, hidden_size_scale}, paddle::DataType::FLOAT32, input.place());
|
||||
|
||||
const int gridx = min(132 * 8, token_num);
|
||||
const int blockx = min(1024, hidden_size / 128 * 32);
|
||||
|
||||
@@ -123,31 +154,70 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
|
||||
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
|
||||
}
|
||||
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::bfloat16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<float>(),
|
||||
token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::float16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<float>(),
|
||||
token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for PerTokenQuant");
|
||||
if (use_ue8m0) {
|
||||
auto quanted_scale =
|
||||
GetEmptyTensor({token_num, ceil_div(hidden_size_scale, 4)},
|
||||
paddle::DataType::INT32,
|
||||
input.place());
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
quant_per_token_per_block<paddle::bfloat16, int32_t, true>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::bfloat16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<int32_t>(),
|
||||
token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
quant_per_token_per_block<paddle::float16, int32_t, true>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::float16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<int32_t>(),
|
||||
token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for PerTokenQuant");
|
||||
}
|
||||
return {quanted_x, quanted_scale};
|
||||
} else {
|
||||
auto quanted_scale = GetEmptyTensor({token_num, hidden_size_scale},
|
||||
paddle::DataType::FLOAT32,
|
||||
input.place());
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
quant_per_token_per_block<paddle::bfloat16, float, false>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::bfloat16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<float>(),
|
||||
token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
quant_per_token_per_block<paddle::float16, float, false>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::float16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<float>(),
|
||||
token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for PerTokenQuant");
|
||||
}
|
||||
return {quanted_x, quanted_scale};
|
||||
}
|
||||
return {quanted_x, quanted_scale};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> PerTokenQuantInferShape(
|
||||
@@ -155,19 +225,26 @@ std::vector<std::vector<int64_t>> PerTokenQuantInferShape(
|
||||
const int token_num = input_shape[0];
|
||||
const int hidden_size = input_shape[1];
|
||||
const int hidden_size_scale = (hidden_size + block_size - 1) / block_size;
|
||||
if (GetSMVersion() >= 100) {
|
||||
return {{token_num, hidden_size},
|
||||
{token_num, ceil_div(hidden_size_scale, 4)}};
|
||||
}
|
||||
return {{token_num, hidden_size}, {token_num, hidden_size_scale}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> PerTokenQuantInferDtype(
|
||||
paddle::DataType input_dtype, const int block_size) {
|
||||
if (GetSMVersion() >= 100) {
|
||||
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32};
|
||||
}
|
||||
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename ScaleT, bool UseUE8M0>
|
||||
__global__ void quant_per_token_per_block_padding(
|
||||
const T *input,
|
||||
phi::dtype::float8_e4m3fn *quanted_res,
|
||||
float *quanted_scale,
|
||||
ScaleT *quanted_scale,
|
||||
const int token_num,
|
||||
const int padded_token_num,
|
||||
const int hidden_size,
|
||||
@@ -185,13 +262,11 @@ __global__ void quant_per_token_per_block_padding(
|
||||
AlignedVector<float, NUM_PER_THREADS> load_vec_float;
|
||||
AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
|
||||
for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
|
||||
const T *input_now = input + token_idx * hidden_size;
|
||||
const T *input_now = input + static_cast<int64_t>(token_idx) * hidden_size;
|
||||
phi::dtype::float8_e4m3fn *quanted_res_now =
|
||||
quanted_res + token_idx * hidden_size;
|
||||
quanted_res + static_cast<int64_t>(token_idx) * hidden_size;
|
||||
// deal a block per warp
|
||||
for (int iter = warp_id; iter < end_iter; iter += num_warp) {
|
||||
float *quanted_scale_now =
|
||||
quanted_scale + iter * padded_token_num + token_idx;
|
||||
const int start_offset = iter * 128;
|
||||
Load<T, NUM_PER_THREADS>(
|
||||
input_now + start_offset + lane_id * NUM_PER_THREADS, &load_vec);
|
||||
@@ -222,26 +297,58 @@ __global__ void quant_per_token_per_block_padding(
|
||||
}
|
||||
|
||||
float scale_to_store = max_value_thread / MAX_VALUE;
|
||||
|
||||
// quant
|
||||
if constexpr (UseUE8M0) {
|
||||
scale_to_store =
|
||||
exp2f(ceilf(log2f(fmaxf(scale_to_store, epsilon) + 5e-7f)));
|
||||
#pragma unroll
|
||||
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
|
||||
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
|
||||
load_vec_float[vid] * MAX_VALUE / max_value_thread);
|
||||
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
|
||||
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
|
||||
load_vec_float[vid] / scale_to_store);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
|
||||
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
|
||||
load_vec_float[vid] * MAX_VALUE / max_value_thread);
|
||||
}
|
||||
}
|
||||
// store
|
||||
Store<phi::dtype::float8_e4m3fn, NUM_PER_THREADS>(
|
||||
res_vec, quanted_res_now + start_offset + lane_id * NUM_PER_THREADS);
|
||||
if (lane_id == 0) {
|
||||
*quanted_scale_now = scale_to_store;
|
||||
if constexpr (UseUE8M0) {
|
||||
// exp
|
||||
int exp = (reinterpret_cast<int &>(scale_to_store) >> 23) & 0xFF;
|
||||
|
||||
const int pack_idx = iter >> 2;
|
||||
const int byte_idx = iter & 3;
|
||||
|
||||
// pack
|
||||
const int pack_num = align(hidden_size_scale, 4) >> 2;
|
||||
|
||||
// column-major base index
|
||||
int32_t *scale_now = quanted_scale;
|
||||
const int base_idx = token_idx + pack_idx * padded_token_num;
|
||||
|
||||
// ---------------- store exp ----------------
|
||||
reinterpret_cast<uint8_t *>(&scale_now[base_idx])[byte_idx] =
|
||||
static_cast<uint8_t>(exp);
|
||||
} else {
|
||||
float *scale_now =
|
||||
quanted_scale + iter * padded_token_num + token_idx;
|
||||
*scale_now = scale_to_store;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
|
||||
const int block_size) {
|
||||
const int block_size,
|
||||
const bool use_ue8m0) {
|
||||
using ScaleDtype = float;
|
||||
|
||||
auto input_dim = input.dims();
|
||||
const int token_num = input_dim[0];
|
||||
const int hidden_size = input_dim[1];
|
||||
@@ -256,13 +363,11 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
|
||||
|
||||
const int tma_alignment_bytes = 16;
|
||||
const int tma_alignment_elements = tma_alignment_bytes / sizeof(ScaleDtype);
|
||||
|
||||
const int padded_token_num =
|
||||
((token_num + tma_alignment_elements - 1) / tma_alignment_elements) *
|
||||
tma_alignment_elements;
|
||||
auto quanted_scale = GetEmptyTensor({padded_token_num, hidden_size_scale},
|
||||
{1, padded_token_num},
|
||||
paddle::DataType::FLOAT32,
|
||||
input.place());
|
||||
|
||||
const int gridx = min(132 * 8, token_num);
|
||||
const int blockx = min(1024, hidden_size / 128 * 32);
|
||||
|
||||
@@ -271,34 +376,76 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
|
||||
if (env_var) {
|
||||
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
|
||||
}
|
||||
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::bfloat16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<ScaleDtype>(),
|
||||
token_num,
|
||||
padded_token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::float16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<ScaleDtype>(),
|
||||
token_num,
|
||||
padded_token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for PerTokenQuant");
|
||||
if (use_ue8m0) {
|
||||
auto quanted_scale =
|
||||
GetEmptyTensor({padded_token_num, ceil_div(hidden_size_scale, 4)},
|
||||
{1, padded_token_num},
|
||||
paddle::DataType::INT32,
|
||||
input.place());
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
quant_per_token_per_block_padding<paddle::bfloat16, int32_t, true>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::bfloat16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<int32_t>(),
|
||||
token_num,
|
||||
padded_token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
quant_per_token_per_block_padding<paddle::float16, int32_t, true>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::float16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<int32_t>(),
|
||||
token_num,
|
||||
padded_token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for PerTokenQuant");
|
||||
}
|
||||
return {quanted_x, quanted_scale};
|
||||
} else {
|
||||
auto quanted_scale = GetEmptyTensor({padded_token_num, hidden_size_scale},
|
||||
{1, padded_token_num},
|
||||
paddle::DataType::FLOAT32,
|
||||
input.place());
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
quant_per_token_per_block_padding<paddle::bfloat16, float, false>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::bfloat16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<float>(),
|
||||
token_num,
|
||||
padded_token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
quant_per_token_per_block_padding<paddle::float16, float, false>
|
||||
<<<gridx, blockx, 0, input.stream()>>>(
|
||||
input.data<paddle::float16>(),
|
||||
quanted_x.data<phi::dtype::float8_e4m3fn>(),
|
||||
quanted_scale.data<float>(),
|
||||
token_num,
|
||||
padded_token_num,
|
||||
hidden_size,
|
||||
hidden_size_scale,
|
||||
use_finegrained_range);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for PerTokenQuant");
|
||||
}
|
||||
return {quanted_x, quanted_scale};
|
||||
}
|
||||
return {quanted_x, quanted_scale};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> PerTokenQuantPaddingInferShape(
|
||||
@@ -314,19 +461,25 @@ std::vector<std::vector<int64_t>> PerTokenQuantPaddingInferShape(
|
||||
const int padded_token_num =
|
||||
((token_num + tma_alignment_elements - 1) / tma_alignment_elements) *
|
||||
tma_alignment_elements;
|
||||
|
||||
if (GetSMVersion() >= 100) {
|
||||
return {{token_num, hidden_size},
|
||||
{padded_token_num, ceil_div(hidden_size_scale, 4)}};
|
||||
}
|
||||
return {{token_num, hidden_size}, {padded_token_num, hidden_size_scale}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> PerTokenQuantPaddingInferDtype(
|
||||
paddle::DataType input_dtype) {
|
||||
if (GetSMVersion() >= 100) {
|
||||
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::INT32};
|
||||
}
|
||||
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(per_token_quant)
|
||||
.Inputs({"input"})
|
||||
.Outputs({"output", "output_scale"})
|
||||
.Attrs({"block_size: int"})
|
||||
.Attrs({"block_size: int", "use_ue8m0: bool"})
|
||||
.SetKernelFn(PD_KERNEL(PerTokenQuant))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(PerTokenQuantInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(PerTokenQuantInferDtype));
|
||||
@@ -334,7 +487,7 @@ PD_BUILD_STATIC_OP(per_token_quant)
|
||||
PD_BUILD_STATIC_OP(per_token_quant_padding)
|
||||
.Inputs({"input"})
|
||||
.Outputs({"output", "output_scale"})
|
||||
.Attrs({"block_size: int"})
|
||||
.Attrs({"block_size: int", "use_ue8m0: bool"})
|
||||
.SetKernelFn(PD_KERNEL(PerTokenQuantPadding))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(PerTokenQuantPaddingInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(PerTokenQuantPaddingInferDtype));
|
||||
|
||||
@@ -643,6 +643,8 @@ class ParallelConfig:
|
||||
self.disable_sequence_parallel_moe: bool = False
|
||||
# shutdown comm group if worker idle
|
||||
self.shutdown_comm_group_if_worker_idle: bool = None
|
||||
# ep_prefill_use_worst_num_tokens
|
||||
self.ep_prefill_use_worst_num_tokens: bool = False
|
||||
|
||||
self.pod_ip: str = None
|
||||
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
|
||||
|
||||
@@ -546,6 +546,11 @@ class EngineArgs:
|
||||
Flag to enable entropy output. Default is False (disabled).
|
||||
"""
|
||||
|
||||
ep_prefill_use_worst_num_tokens: bool = False
|
||||
"""
|
||||
Flag to enable prefill_use_worst_num_tokens. Default is False (disabled).
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Post-initialization processing to set default tokenizer if not provided.
|
||||
@@ -1060,6 +1065,12 @@ class EngineArgs:
|
||||
default=EngineArgs.shutdown_comm_group_if_worker_idle,
|
||||
help="Shutdown communication group when worker is idle.",
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--ep-prefill-use-worst-num-tokens",
|
||||
action="store_true",
|
||||
default=EngineArgs.ep_prefill_use_worst_num_tokens,
|
||||
help="Enable prefill use worst num tokens for EP.",
|
||||
)
|
||||
|
||||
# Load group
|
||||
load_group = parser.add_argument_group("Load Configuration")
|
||||
|
||||
@@ -643,6 +643,7 @@ class LLMEngine:
|
||||
"moe_gate_fp32": self.cfg.model_config.moe_gate_fp32,
|
||||
"shutdown_comm_group_if_worker_idle": self.cfg.parallel_config.shutdown_comm_group_if_worker_idle,
|
||||
"enable_entropy": self.cfg.model_config.enable_entropy,
|
||||
"ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens,
|
||||
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
|
||||
}
|
||||
for worker_flag, value in worker_store_true_flag.items():
|
||||
|
||||
@@ -865,6 +865,7 @@ class RequestMetrics:
|
||||
llm_engine_recv_req_timestamp: Optional[float] = None
|
||||
llm_engine_send_req_to_engine_timestamp: Optional[float] = None
|
||||
llm_engine_recv_latest_token_timestamp: Optional[float] = None
|
||||
llm_engine_recv_token_timestamp: Optional[float] = None
|
||||
|
||||
speculate_metrics: Optional[SpeculateMetrics] = None
|
||||
|
||||
@@ -933,6 +934,7 @@ class RequestMetrics:
|
||||
# for compatibility with old metrics
|
||||
self.llm_engine_recv_req_timestamp = self.engine_get_req_time
|
||||
self.llm_engine_send_req_to_engine_timestamp = self.inference_start_time
|
||||
self.llm_engine_recv_token_timestamp = self.engine_recv_first_token_time
|
||||
|
||||
def get(self, key: str, default_value=None):
|
||||
if hasattr(self, key):
|
||||
|
||||
@@ -319,7 +319,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
elif metadata._dtype == "float32":
|
||||
metadata._fuse_kernel_compute_dtype = "fp32"
|
||||
|
||||
forward_meta.attention_metadata = metadata
|
||||
self.attention_metadata = metadata
|
||||
|
||||
def forward_mixed(
|
||||
self,
|
||||
@@ -332,7 +332,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
metadata = forward_meta.attention_metadata
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.pd_disaggregation_mode == "per_query":
|
||||
metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
|
||||
@@ -340,6 +340,14 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
layer.layer_id + self.start_layer_index,
|
||||
)
|
||||
|
||||
if int(os.getenv("USE_TBO", "0")) == 1:
|
||||
if hasattr(forward_meta, "tbo_microbatch_id"):
|
||||
# here we only let the last microbatch invoke cache kv transfer!
|
||||
if forward_meta.tbo_microbatch_id == 0:
|
||||
os.environ["FLAGS_fmt_write_cache_completed_signal"] = "0"
|
||||
elif forward_meta.tbo_microbatch_id == 1:
|
||||
os.environ["FLAGS_fmt_write_cache_completed_signal"] = "1"
|
||||
|
||||
norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False)
|
||||
q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
|
||||
k_norm_weight = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None
|
||||
@@ -369,15 +377,15 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
if forward_meta.max_len_tensor_cpu[1].item() > 0:
|
||||
|
||||
metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
|
||||
metadata.max_len_tensor_cpu_decoder[1] = 0
|
||||
forward_meta.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
|
||||
forward_meta.max_len_tensor_cpu_decoder[1] = 0
|
||||
|
||||
(
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.pre_cache_batch_ids,
|
||||
metadata.pre_cache_tile_ids_per_batch,
|
||||
metadata.pre_cache_num_blocks_cpu,
|
||||
metadata.kv_token_num_cpu,
|
||||
forward_meta.cu_seqlens_k,
|
||||
forward_meta.pre_cache_batch_ids,
|
||||
forward_meta.pre_cache_tile_ids_per_batch,
|
||||
forward_meta.pre_cache_num_blocks_cpu,
|
||||
forward_meta.kv_token_num_cpu,
|
||||
) = pre_cache_len_concat(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
@@ -386,12 +394,14 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
self.block_size,
|
||||
)
|
||||
if FLASH_ATTN_VERSION == 4 or forward_meta.attn_mask_offsets is not None:
|
||||
metadata.attn_mask_q = get_attn_mask_q(
|
||||
forward_meta.attn_mask_q = get_attn_mask_q(
|
||||
cu_seqlens_q=forward_meta.cu_seqlens_q,
|
||||
cu_seqlens_k=metadata.cu_seqlens_k,
|
||||
cu_seqlens_k=forward_meta.cu_seqlens_k,
|
||||
attn_mask_kv=forward_meta.attn_mask_offsets,
|
||||
kv_token_num=metadata.kv_token_num_cpu[0].item(),
|
||||
kv_token_num=forward_meta.kv_token_num_cpu[0].item(),
|
||||
)
|
||||
else:
|
||||
forward_meta.attn_mask_q = None
|
||||
|
||||
use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0
|
||||
|
||||
@@ -401,7 +411,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
forward_meta.cu_seqlens_q,
|
||||
metadata.cu_seqlens_k,
|
||||
forward_meta.cu_seqlens_k,
|
||||
forward_meta.rotary_embs,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.seq_lens_encoder,
|
||||
@@ -411,9 +421,9 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
forward_meta.kv_batch_ids,
|
||||
forward_meta.kv_tile_ids_per_batch,
|
||||
forward_meta.kv_num_blocks_x_cpu,
|
||||
metadata.pre_cache_batch_ids,
|
||||
metadata.pre_cache_tile_ids_per_batch,
|
||||
metadata.pre_cache_num_blocks_cpu,
|
||||
forward_meta.pre_cache_batch_ids,
|
||||
forward_meta.pre_cache_tile_ids_per_batch,
|
||||
forward_meta.pre_cache_num_blocks_cpu,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
@@ -423,7 +433,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
getattr(layer, "cache_k_zp", None),
|
||||
getattr(layer, "cache_v_zp", None),
|
||||
metadata.kv_signal_data_list[layer.layer_id],
|
||||
metadata.kv_token_num_cpu[0].item(),
|
||||
forward_meta.kv_token_num_cpu[0].item(),
|
||||
self.max_seq_len,
|
||||
getattr(layer, "rms_norm_eps", 1e-6),
|
||||
layer.use_neox_rotary_style,
|
||||
@@ -435,11 +445,11 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
forward_meta.cu_seqlens_q[: metadata.cu_seqlens_k.shape[0]],
|
||||
metadata.cu_seqlens_k,
|
||||
forward_meta.cu_seqlens_q[: forward_meta.cu_seqlens_k.shape[0]],
|
||||
forward_meta.cu_seqlens_k,
|
||||
max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
|
||||
max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
|
||||
attn_mask_q=metadata.attn_mask_q,
|
||||
attn_mask_q=forward_meta.attn_mask_q,
|
||||
causal=self.causal,
|
||||
num_heads=self.num_heads,
|
||||
kv_num_heads=self.kv_num_heads,
|
||||
@@ -465,7 +475,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
forward_meta.decoder_batch_ids,
|
||||
forward_meta.decoder_tile_ids_per_batch,
|
||||
forward_meta.decoder_num_blocks_cpu,
|
||||
metadata.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu,
|
||||
forward_meta.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu,
|
||||
forward_meta.rotary_embs,
|
||||
forward_meta.attn_mask,
|
||||
layer.qkv_bias,
|
||||
|
||||
@@ -164,7 +164,7 @@ class FlashMaskAttentionBackend(AttentionBackend):
|
||||
elif metadata._dtype == "float32":
|
||||
metadata._fuse_kernel_compute_dtype = "fp32"
|
||||
|
||||
forward_meta.attention_metadata = metadata
|
||||
self.attention_metadata = metadata
|
||||
|
||||
def forward_mixed(
|
||||
self,
|
||||
@@ -177,7 +177,7 @@ class FlashMaskAttentionBackend(AttentionBackend):
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
metadata = forward_meta.attention_metadata
|
||||
metadata = self.attention_metadata
|
||||
|
||||
norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False)
|
||||
q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
|
||||
|
||||
@@ -58,12 +58,9 @@ def load_deep_ep() -> ModuleType:
|
||||
return deep_ep
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"import deep_ep failed! FD_USE_PFCC_DEEP_EP=%s. type=%s, err=%s",
|
||||
envs.FD_USE_PFCC_DEEP_EP,
|
||||
type(e).__name__,
|
||||
e,
|
||||
f"import deep_ep failed! FD_USE_PFCC_DEEP_EP={envs.FD_USE_PFCC_DEEP_EP}. type={type(e).__name__}, err={e}"
|
||||
)
|
||||
logger.error("Traceback:\n%s", traceback.format_exc())
|
||||
logger.error(f"Traceback:{traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -586,6 +583,7 @@ class EPPrefillRunner(EPRunner):
|
||||
moe_phase: MoEPhase = MoEPhase("prefill"),
|
||||
ep_group=None,
|
||||
use_internode_ll_two_stage: bool = False,
|
||||
prefill_num_worst_tokens: int = 0,
|
||||
):
|
||||
super().__init__(
|
||||
top_k,
|
||||
@@ -600,6 +598,8 @@ class EPPrefillRunner(EPRunner):
|
||||
ep_group=ep_group,
|
||||
use_internode_ll_two_stage=use_internode_ll_two_stage,
|
||||
)
|
||||
self.num_worst_tokens = prefill_num_worst_tokens
|
||||
logger.info(f"prefill_num_worst_tokens {prefill_num_worst_tokens}")
|
||||
|
||||
def set_allocate_on_comm_stream(allocate_on_comm_stream: bool = False):
|
||||
if EPPrefillRunner.allocate_on_comm_stream == allocate_on_comm_stream:
|
||||
@@ -650,6 +650,8 @@ class EPPrefillRunner(EPRunner):
|
||||
"expert_alignment": expert_alignment,
|
||||
"allocate_on_comm_stream": EPPrefillRunner.allocate_on_comm_stream,
|
||||
"previous_event": event,
|
||||
"num_worst_tokens": self.num_worst_tokens,
|
||||
"skip_x_record_stream": self.num_worst_tokens > 0,
|
||||
}
|
||||
return buffer.dispatch(**dispatch_args)
|
||||
|
||||
@@ -672,6 +674,7 @@ class EPPrefillRunner(EPRunner):
|
||||
"topk_weights": recv_topk_weights,
|
||||
"previous_event": event,
|
||||
"allocate_on_comm_stream": EPPrefillRunner.allocate_on_comm_stream,
|
||||
"skip_x_record_stream": self.num_worst_tokens > 0,
|
||||
}
|
||||
fused_moe_out, _, event = buffer.combine(**combine_args)
|
||||
return fused_moe_out, event
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
@@ -92,6 +93,18 @@ class MoEMethodBase(QuantMethodBase):
|
||||
splitwise_role = config.scheduler_config.splitwise_role
|
||||
load_strategy = config.load_config.load_strategy
|
||||
|
||||
if config.parallel_config.ep_prefill_use_worst_num_tokens:
|
||||
token_split_factor = 2 if int(os.getenv("USE_TBO", "0")) == 1 else 1
|
||||
prefill_num_worst_tokens = (
|
||||
config.scheduler_config.max_num_batched_tokens
|
||||
// config.parallel_config.tensor_parallel_size
|
||||
* layer.ep_size
|
||||
* layer.top_k
|
||||
// token_split_factor
|
||||
)
|
||||
else:
|
||||
prefill_num_worst_tokens = 0
|
||||
|
||||
# For "mixed" splitwise role: conditionally initialize both or none
|
||||
if splitwise_role == "mixed":
|
||||
if load_strategy == "meta":
|
||||
@@ -102,6 +115,7 @@ class MoEMethodBase(QuantMethodBase):
|
||||
self.ep_prefill_runner = self.EPPrefillRunner(
|
||||
**common_args,
|
||||
use_internode_ll_two_stage=layer.fd_config.parallel_config.use_internode_ll_two_stage,
|
||||
prefill_num_worst_tokens=prefill_num_worst_tokens,
|
||||
)
|
||||
self.ep_decoder_runner = self.EPDecoderRunner(
|
||||
**common_args,
|
||||
@@ -119,6 +133,7 @@ class MoEMethodBase(QuantMethodBase):
|
||||
self.ep_prefill_runner = self.EPPrefillRunner(
|
||||
**common_args,
|
||||
use_internode_ll_two_stage=layer.fd_config.parallel_config.use_internode_ll_two_stage,
|
||||
prefill_num_worst_tokens=prefill_num_worst_tokens,
|
||||
)
|
||||
else:
|
||||
self.ep_decoder_runner = self.EPDecoderRunner(
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
from typing import Callable
|
||||
|
||||
import paddle
|
||||
@@ -24,7 +26,11 @@ import fastdeploy
|
||||
from fastdeploy.model_executor.layers.moe.ep import deep_ep
|
||||
from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
count_tokens_per_expert_func,
|
||||
depermute_prefill_combine,
|
||||
prefill_permute_to_masked_gemm,
|
||||
)
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.utils import register_custom_python_op
|
||||
from fastdeploy.worker.tbo import let_another_thread_run
|
||||
@@ -43,6 +49,60 @@ else:
|
||||
m_grouped_fp8_gemm_nt_contiguous = None
|
||||
m_grouped_fp8_gemm_nt_masked = None
|
||||
|
||||
global_values = {}
|
||||
|
||||
|
||||
def call_prefill_permute_to_masked_gemm(
|
||||
x: paddle.Tensor,
|
||||
scale: paddle.Tensor,
|
||||
topk_ids: paddle.Tensor,
|
||||
num_local_experts: int,
|
||||
max_token_num: int,
|
||||
):
|
||||
"""
|
||||
Permute input tokens and scales from token-major to expert-major layout
|
||||
for MoE masked GEMM operations.
|
||||
|
||||
Args:
|
||||
x: Input hidden states [num_tokens, hidden].
|
||||
scale: Input scales [num_tokens, hidden_scale].
|
||||
topk_ids: Expert routing indices [num_tokens, topk] (int64 or int32).
|
||||
num_local_experts: Number of local experts on this device.
|
||||
max_token_num: Maximum tokens per expert buffer.
|
||||
|
||||
Returns:
|
||||
tuple: (permute_x, permute_scale, permuted_indice_map, token_nums_per_expert)
|
||||
"""
|
||||
if topk_ids.dtype != paddle.int64:
|
||||
topk_ids = topk_ids.cast(paddle.int64)
|
||||
|
||||
results = prefill_permute_to_masked_gemm(x, scale, topk_ids, num_local_experts, max_token_num)
|
||||
|
||||
return results[0], results[1], results[2], results[3]
|
||||
|
||||
|
||||
def call_depermute_prefill_combine(
|
||||
x: paddle.Tensor,
|
||||
indice_map: paddle.Tensor,
|
||||
topk_weights: paddle.Tensor,
|
||||
num_worst_tokens: int,
|
||||
):
|
||||
"""
|
||||
Depermute and combine expert outputs back to token-major layout.
|
||||
|
||||
Args:
|
||||
x: Expert outputs [num_local_experts, max_tokens_per_expert, hidden].
|
||||
indice_map: Flat index tensor [num_worst_tokens, topk] (int32).
|
||||
topk_weights: Combination weights [num_worst_tokens, topk] (float32).
|
||||
num_worst_tokens: Number of output tokens to produce.
|
||||
|
||||
Returns:
|
||||
depermuted_x: Combined output [num_worst_tokens, hidden].
|
||||
"""
|
||||
results = depermute_prefill_combine(x, indice_map, topk_weights, num_worst_tokens)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def m_grouped_fp8_gemm_nt_contiguous_custom_python_op_infermeta(
|
||||
permute_input: "paddle.static.MetaTensor",
|
||||
@@ -108,7 +168,7 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
|
||||
# down_proj
|
||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
ffn_out, quant_config_weight_block_size_0
|
||||
ffn_out, quant_config_weight_block_size_0, not disable_ue8m0_cast
|
||||
)
|
||||
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous()
|
||||
@@ -277,7 +337,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
# 2. Dynamic compute blockwise quantization scales
|
||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
x, self.quant_config.weight_block_size[0]
|
||||
x, self.quant_config.weight_block_size[0], self.quant_config.deepgemm_scale_ue8m0
|
||||
)
|
||||
else:
|
||||
x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
@@ -293,7 +353,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
)
|
||||
|
||||
event = deep_ep.Buffer.capture()
|
||||
let_another_thread_run()
|
||||
|
||||
if self.ep_prefill_runner.num_worst_tokens <= 0:
|
||||
let_another_thread_run()
|
||||
# 3. EP Dispatch
|
||||
(
|
||||
recv_x,
|
||||
@@ -306,9 +368,34 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128, previous_event=event
|
||||
)
|
||||
|
||||
if self.ep_prefill_runner.num_worst_tokens > 0:
|
||||
let_another_thread_run()
|
||||
|
||||
thread_name = threading.current_thread().name
|
||||
|
||||
if self.ep_prefill_runner.ep_engine.async_finish:
|
||||
event.current_stream_wait()
|
||||
|
||||
global global_values
|
||||
|
||||
if thread_name not in global_values:
|
||||
global_values[thread_name] = {}
|
||||
|
||||
(recv_x_value, recv_x_scale) = recv_x
|
||||
(recv_x_value, recv_x_scale) = recv_x
|
||||
|
||||
global_values[thread_name]["x"] = x
|
||||
global_values[thread_name]["topk_idx"] = topk_idx
|
||||
global_values[thread_name]["topk_weights"] = topk_weights
|
||||
global_values[thread_name]["x_scale_tensor"] = x_scale_tensor
|
||||
|
||||
global_values[thread_name]["recv_x_value"] = recv_x_value
|
||||
global_values[thread_name]["recv_x_scale"] = recv_x_scale
|
||||
global_values[thread_name]["recv_topk_idx"] = recv_topk_idx
|
||||
global_values[thread_name]["recv_topk_weights"] = recv_topk_weights
|
||||
global_values[thread_name]["handle"] = handle
|
||||
global_values[thread_name]["recv_num_tokens_per_expert_list"] = recv_num_tokens_per_expert_list
|
||||
|
||||
token_all_num = sum(recv_num_tokens_per_expert_list)
|
||||
|
||||
# Note(ZKK):
|
||||
@@ -317,9 +404,88 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
# so here we manually del a var as soon as it's not used.
|
||||
|
||||
# 4. Compute ffn
|
||||
if token_all_num > 0:
|
||||
if self.ep_prefill_runner.num_worst_tokens > 0:
|
||||
token_split_factor = 2 if int(os.getenv("USE_TBO", "0")) == 1 else 1
|
||||
max_tokens_per_rank = (
|
||||
layer.fd_config.scheduler_config.max_num_batched_tokens
|
||||
// layer.fd_config.parallel_config.tensor_parallel_size
|
||||
// token_split_factor
|
||||
)
|
||||
expected_m = max_tokens_per_rank
|
||||
|
||||
logger.debug(f"max_tokens_per_rank {max_tokens_per_rank}")
|
||||
|
||||
permute_input, permute_scale, permuted_indice_map, token_nums_per_expert = (
|
||||
call_prefill_permute_to_masked_gemm(
|
||||
x=recv_x_value,
|
||||
scale=recv_x_scale,
|
||||
topk_ids=recv_topk_idx,
|
||||
num_local_experts=layer.num_local_experts,
|
||||
max_token_num=layer.ep_size * max_tokens_per_rank,
|
||||
)
|
||||
)
|
||||
|
||||
up_gate_proj_out = paddle.empty(
|
||||
[
|
||||
layer.num_local_experts,
|
||||
layer.ep_size * max_tokens_per_rank,
|
||||
layer.moe_intermediate_size * 2,
|
||||
],
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
|
||||
m_grouped_fp8_gemm_nt_masked(
|
||||
(permute_input, permute_scale),
|
||||
(
|
||||
getattr(layer, self.added_weight_attrs[0]),
|
||||
getattr(layer, self.added_scale_attrs[0]),
|
||||
),
|
||||
up_gate_proj_out,
|
||||
token_nums_per_expert,
|
||||
expected_m,
|
||||
disable_ue8m0_cast=not self.quant_config.deepgemm_scale_ue8m0,
|
||||
)
|
||||
|
||||
act_out_fp8, scale = fastdeploy.model_executor.ops.gpu.fused_mask_swiglu_fp8_quant(
|
||||
up_gate_proj_out,
|
||||
token_nums_per_expert,
|
||||
self.quant_config.weight_block_size[0],
|
||||
use_ue8m0=self.quant_config.deepgemm_scale_ue8m0,
|
||||
)
|
||||
|
||||
if layer.hidden_size == layer.moe_intermediate_size * 2:
|
||||
ffn_out = up_gate_proj_out
|
||||
else:
|
||||
ffn_out = paddle.empty(
|
||||
[
|
||||
layer.num_local_experts,
|
||||
layer.ep_size * max_tokens_per_rank,
|
||||
layer.hidden_size,
|
||||
],
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
|
||||
m_grouped_fp8_gemm_nt_masked(
|
||||
(act_out_fp8, scale),
|
||||
(
|
||||
getattr(layer, self.added_weight_attrs[1]),
|
||||
getattr(layer, self.added_scale_attrs[1]),
|
||||
),
|
||||
ffn_out,
|
||||
token_nums_per_expert,
|
||||
expected_m,
|
||||
disable_ue8m0_cast=not self.quant_config.deepgemm_scale_ue8m0,
|
||||
)
|
||||
|
||||
tmp_ffn_out = call_depermute_prefill_combine(
|
||||
x=ffn_out,
|
||||
indice_map=permuted_indice_map,
|
||||
topk_weights=recv_topk_weights,
|
||||
num_worst_tokens=recv_x_value.shape[0],
|
||||
)
|
||||
|
||||
elif token_all_num > 0:
|
||||
logger.debug(f"token_all_num {token_all_num}")
|
||||
(recv_x, recv_x_scale) = recv_x
|
||||
|
||||
token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
|
||||
|
||||
@@ -327,14 +493,14 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
permute_input,
|
||||
permute_scale,
|
||||
permute_indices_per_token,
|
||||
recv_num_tokens_per_expert_list_cumsum,
|
||||
recv_num_tokens_per_expert_list_padded_cumsum,
|
||||
_,
|
||||
_,
|
||||
dst_weights,
|
||||
dst_indices,
|
||||
cumsum_idx_gpu,
|
||||
_,
|
||||
m_indices,
|
||||
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
|
||||
recv_x,
|
||||
recv_x_value,
|
||||
recv_x_scale,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
@@ -353,7 +519,6 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
(token_all_num, getattr(layer, self.added_weight_attrs[0]).shape[1]),
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
# disable_ue8m0_cast is False for SM100
|
||||
m_grouped_fp8_gemm_nt_contiguous(
|
||||
(permute_input, permute_scale),
|
||||
(getattr(layer, self.added_weight_attrs[0]), getattr(layer, self.added_scale_attrs[0])),
|
||||
@@ -367,7 +532,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
# down_proj
|
||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
ffn_out, self.quant_config.weight_block_size[0]
|
||||
ffn_out, self.quant_config.weight_block_size[0], self.quant_config.deepgemm_scale_ue8m0
|
||||
)
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous().transpose([1, 0])
|
||||
else:
|
||||
@@ -382,7 +547,6 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
(token_all_num, getattr(layer, self.added_weight_attrs[1]).shape[1]),
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
# disable_ue8m0_cast is False for SM100
|
||||
m_grouped_fp8_gemm_nt_contiguous(
|
||||
(ffn_in_x, ffn_in_x_scale_tensor),
|
||||
(getattr(layer, self.added_weight_attrs[1]), getattr(layer, self.added_scale_attrs[1])),
|
||||
@@ -405,12 +569,20 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
|
||||
# 5. EP combine
|
||||
event = deep_ep.Buffer.capture()
|
||||
let_another_thread_run()
|
||||
if self.ep_prefill_runner.num_worst_tokens <= 0:
|
||||
let_another_thread_run()
|
||||
|
||||
global_values[thread_name]["combine_in"] = tmp_ffn_out
|
||||
tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights, event)
|
||||
|
||||
if self.ep_prefill_runner.num_worst_tokens > 0:
|
||||
let_another_thread_run()
|
||||
|
||||
if self.ep_prefill_runner.ep_engine.async_finish:
|
||||
event.current_stream_wait()
|
||||
|
||||
global_values[thread_name]["combine_out"] = tmp_ffn_out
|
||||
|
||||
return tmp_ffn_out
|
||||
|
||||
def apply_ep_decode(
|
||||
@@ -528,7 +700,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
|
||||
|
||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||
recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, 128)
|
||||
recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
x, 128, self.quant_config.deepgemm_scale_ue8m0
|
||||
)
|
||||
else:
|
||||
|
||||
recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
|
||||
@@ -1236,7 +1236,7 @@ def python_op_fused_moe_kernel_paddle(
|
||||
from .triton_moe_kernels import fused_moe_kernel_paddle
|
||||
|
||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0])
|
||||
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0], False)
|
||||
else:
|
||||
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
x, using_pow2_scale=False, output_scale_transpose=False
|
||||
@@ -1293,7 +1293,7 @@ def python_op_fused_moe_kernel_paddle(
|
||||
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * ceil_div(hidden_size, config["BLOCK_SIZE_N"]),)
|
||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
intermediate_cache2, quant_config.weight_block_size[0]
|
||||
intermediate_cache2, quant_config.weight_block_size[0], False
|
||||
)
|
||||
else:
|
||||
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
|
||||
@@ -326,8 +326,9 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
return linear_out
|
||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding(
|
||||
x, self.quant_config.weight_block_size[0]
|
||||
x, self.quant_config.weight_block_size[0], self.quant_config.deepgemm_scale_ue8m0
|
||||
)
|
||||
x_scale_tensor = x_scale_tensor[: x.shape[0], ...]
|
||||
else:
|
||||
x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
x,
|
||||
|
||||
@@ -1090,6 +1090,12 @@ def parse_args():
|
||||
help="Enable overlap schedule",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ep_prefill_use_worst_num_tokens",
|
||||
action="store_true",
|
||||
help="enable to avoid cpu sync",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
@@ -0,0 +1,334 @@
|
||||
"""
|
||||
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import depermute_prefill_combine
|
||||
|
||||
|
||||
def call_depermute_prefill_combine(
|
||||
x: paddle.Tensor,
|
||||
indice_map: paddle.Tensor,
|
||||
topk_weights: paddle.Tensor,
|
||||
num_worst_tokens: int,
|
||||
):
|
||||
"""
|
||||
Depermute and combine expert outputs back to token-major layout.
|
||||
|
||||
Args:
|
||||
x: Expert outputs [num_local_experts, max_tokens_per_expert, hidden].
|
||||
indice_map: Flat index tensor [num_worst_tokens, topk] (int32).
|
||||
topk_weights: Combination weights [num_worst_tokens, topk] (float32).
|
||||
num_worst_tokens: Number of output tokens to produce.
|
||||
|
||||
Returns:
|
||||
depermuted_x: Combined output [num_worst_tokens, hidden].
|
||||
"""
|
||||
results = depermute_prefill_combine(x, indice_map, topk_weights, num_worst_tokens)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class TestDepermutePrefillCombine(unittest.TestCase):
|
||||
"""
|
||||
Test cases for depermute_prefill_combine kernel.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
paddle.seed(2024)
|
||||
np.random.seed(2024)
|
||||
paddle.set_device("gpu")
|
||||
|
||||
def _compute_reference(
|
||||
self,
|
||||
x_np,
|
||||
indice_map_np,
|
||||
topk_weights_np,
|
||||
num_worst_tokens,
|
||||
max_num_tokens_per_expert,
|
||||
):
|
||||
hidden = x_np.shape[2]
|
||||
topk = indice_map_np.shape[1]
|
||||
depermuted_x = np.zeros((num_worst_tokens, hidden), dtype=np.float32)
|
||||
|
||||
for token_idx in range(num_worst_tokens):
|
||||
for k in range(topk):
|
||||
indice = indice_map_np[token_idx, k]
|
||||
if indice >= 0:
|
||||
expert_idx = indice // max_num_tokens_per_expert
|
||||
offset = indice % max_num_tokens_per_expert
|
||||
weight = topk_weights_np[token_idx, k]
|
||||
depermuted_x[token_idx, :] += x_np[expert_idx, offset, :] * weight
|
||||
|
||||
return depermuted_x
|
||||
|
||||
def _run_and_verify(
|
||||
self,
|
||||
num_worst_tokens,
|
||||
num_local_experts,
|
||||
max_num_tokens_per_expert,
|
||||
hidden,
|
||||
topk,
|
||||
x_dtype=paddle.bfloat16,
|
||||
sparsity=0.2,
|
||||
rtol=1e-2,
|
||||
atol=1e-2,
|
||||
):
|
||||
x_np = np.random.randn(num_local_experts, max_num_tokens_per_expert, hidden).astype(np.float32)
|
||||
if x_dtype == paddle.bfloat16:
|
||||
x = paddle.to_tensor(x_np).cast(paddle.bfloat16)
|
||||
elif x_dtype == paddle.float8_e4m3fn:
|
||||
x_np = np.clip(x_np, -448, 448)
|
||||
x = paddle.to_tensor(x_np).cast(paddle.float8_e4m3fn)
|
||||
else:
|
||||
x = paddle.to_tensor(x_np)
|
||||
|
||||
indice_map_np = np.zeros((num_worst_tokens, topk), dtype=np.int32)
|
||||
for token_idx in range(num_worst_tokens):
|
||||
used_positions = {}
|
||||
has_valid_index = False
|
||||
for k in range(topk):
|
||||
if k == topk - 1 and not has_valid_index:
|
||||
should_be_invalid = False
|
||||
else:
|
||||
should_be_invalid = np.random.rand() < sparsity
|
||||
|
||||
if should_be_invalid:
|
||||
indice_map_np[token_idx, k] = -1
|
||||
else:
|
||||
expert_idx = np.random.randint(0, num_local_experts)
|
||||
if expert_idx not in used_positions:
|
||||
used_positions[expert_idx] = []
|
||||
offset = np.random.randint(0, max_num_tokens_per_expert)
|
||||
attempts = 0
|
||||
while offset in used_positions.get(expert_idx, []) and attempts < 10:
|
||||
offset = np.random.randint(0, max_num_tokens_per_expert)
|
||||
attempts += 1
|
||||
used_positions[expert_idx].append(offset)
|
||||
indice_map_np[token_idx, k] = expert_idx * max_num_tokens_per_expert + offset
|
||||
has_valid_index = True
|
||||
|
||||
indice_map = paddle.to_tensor(indice_map_np).cast(paddle.int32)
|
||||
|
||||
topk_weights_np = np.random.rand(num_worst_tokens, topk).astype(np.float32)
|
||||
row_sums = topk_weights_np.sum(axis=1, keepdims=True)
|
||||
topk_weights_np = topk_weights_np / (row_sums + 1e-6)
|
||||
topk_weights_np[indice_map_np == -1] = 0.0
|
||||
topk_weights = paddle.to_tensor(topk_weights_np).cast(paddle.float32)
|
||||
|
||||
depermuted_x = call_depermute_prefill_combine(
|
||||
x=x,
|
||||
indice_map=indice_map,
|
||||
topk_weights=topk_weights,
|
||||
num_worst_tokens=num_worst_tokens,
|
||||
)
|
||||
|
||||
x_ref_np = x.cast(paddle.float32).numpy()
|
||||
expected = self._compute_reference(
|
||||
x_np=x_ref_np,
|
||||
indice_map_np=indice_map_np,
|
||||
topk_weights_np=topk_weights_np,
|
||||
num_worst_tokens=num_worst_tokens,
|
||||
max_num_tokens_per_expert=max_num_tokens_per_expert,
|
||||
)
|
||||
|
||||
result = depermuted_x.cast(paddle.float32).numpy()
|
||||
|
||||
self.assertEqual(result.shape, (num_worst_tokens, hidden))
|
||||
|
||||
np.testing.assert_allclose(result, expected, rtol=rtol, atol=atol, err_msg="Depermuted output mismatch")
|
||||
|
||||
return True
|
||||
|
||||
def test_basic_topk4(self):
|
||||
self._run_and_verify(
|
||||
num_worst_tokens=64,
|
||||
num_local_experts=8,
|
||||
max_num_tokens_per_expert=128,
|
||||
hidden=7168,
|
||||
topk=4,
|
||||
sparsity=0.2,
|
||||
)
|
||||
|
||||
def test_basic_topk8(self):
|
||||
self._run_and_verify(
|
||||
num_worst_tokens=64,
|
||||
num_local_experts=8,
|
||||
max_num_tokens_per_expert=128,
|
||||
hidden=7168,
|
||||
topk=8,
|
||||
sparsity=0.2,
|
||||
)
|
||||
|
||||
def test_small_tokens(self):
|
||||
self._run_and_verify(
|
||||
num_worst_tokens=4,
|
||||
num_local_experts=4,
|
||||
max_num_tokens_per_expert=32,
|
||||
hidden=1024,
|
||||
topk=4,
|
||||
sparsity=0.1,
|
||||
)
|
||||
|
||||
def test_large_tokens(self):
|
||||
self._run_and_verify(
|
||||
num_worst_tokens=512,
|
||||
num_local_experts=16,
|
||||
max_num_tokens_per_expert=256,
|
||||
hidden=4096,
|
||||
topk=4,
|
||||
sparsity=0.3,
|
||||
)
|
||||
|
||||
def test_high_sparsity(self):
|
||||
self._run_and_verify(
|
||||
num_worst_tokens=128,
|
||||
num_local_experts=8,
|
||||
max_num_tokens_per_expert=64,
|
||||
hidden=2048,
|
||||
topk=4,
|
||||
sparsity=0.7,
|
||||
)
|
||||
|
||||
def test_no_sparsity(self):
|
||||
self._run_and_verify(
|
||||
num_worst_tokens=64,
|
||||
num_local_experts=8,
|
||||
max_num_tokens_per_expert=128,
|
||||
hidden=2048,
|
||||
topk=4,
|
||||
sparsity=0.0,
|
||||
)
|
||||
|
||||
def test_single_expert(self):
|
||||
self._run_and_verify(
|
||||
num_worst_tokens=32,
|
||||
num_local_experts=1,
|
||||
max_num_tokens_per_expert=64,
|
||||
hidden=1024,
|
||||
topk=4,
|
||||
sparsity=0.0,
|
||||
)
|
||||
|
||||
def test_many_experts(self):
|
||||
self._run_and_verify(
|
||||
num_worst_tokens=128,
|
||||
num_local_experts=32,
|
||||
max_num_tokens_per_expert=64,
|
||||
hidden=2048,
|
||||
topk=8,
|
||||
sparsity=0.3,
|
||||
)
|
||||
|
||||
def test_small_hidden(self):
|
||||
self._run_and_verify(
|
||||
num_worst_tokens=64,
|
||||
num_local_experts=8,
|
||||
max_num_tokens_per_expert=64,
|
||||
hidden=256,
|
||||
topk=4,
|
||||
sparsity=0.2,
|
||||
)
|
||||
|
||||
def test_large_hidden(self):
|
||||
self._run_and_verify(
|
||||
num_worst_tokens=32,
|
||||
num_local_experts=8,
|
||||
max_num_tokens_per_expert=64,
|
||||
hidden=14336,
|
||||
topk=4,
|
||||
sparsity=0.2,
|
||||
)
|
||||
|
||||
def test_all_minus_one(self):
|
||||
num_worst_tokens = 32
|
||||
num_local_experts = 4
|
||||
max_num_tokens_per_expert = 64
|
||||
hidden = 1024
|
||||
topk = 4
|
||||
|
||||
x_np = np.random.randn(num_local_experts, max_num_tokens_per_expert, hidden).astype(np.float32)
|
||||
x = paddle.to_tensor(x_np).cast(paddle.bfloat16)
|
||||
|
||||
indice_map = paddle.full([num_worst_tokens, topk], -1, dtype=paddle.int32)
|
||||
topk_weights = paddle.zeros([num_worst_tokens, topk], dtype=paddle.float32)
|
||||
|
||||
depermuted_x = call_depermute_prefill_combine(
|
||||
x=x,
|
||||
indice_map=indice_map,
|
||||
topk_weights=topk_weights,
|
||||
num_worst_tokens=num_worst_tokens,
|
||||
)
|
||||
|
||||
result = depermuted_x.cast(paddle.float32).numpy()
|
||||
self.assertEqual(result.shape, (num_worst_tokens, hidden))
|
||||
|
||||
def test_single_token(self):
|
||||
self._run_and_verify(
|
||||
num_worst_tokens=1,
|
||||
num_local_experts=4,
|
||||
max_num_tokens_per_expert=32,
|
||||
hidden=1024,
|
||||
topk=4,
|
||||
sparsity=0.0,
|
||||
)
|
||||
|
||||
def test_uniform_weights(self):
|
||||
num_worst_tokens = 64
|
||||
num_local_experts = 8
|
||||
max_num_tokens_per_expert = 64
|
||||
hidden = 2048
|
||||
topk = 4
|
||||
|
||||
x_np = np.random.randn(num_local_experts, max_num_tokens_per_expert, hidden).astype(np.float32)
|
||||
x = paddle.to_tensor(x_np).cast(paddle.bfloat16)
|
||||
|
||||
indice_map_np = np.zeros((num_worst_tokens, topk), dtype=np.int32)
|
||||
for token_idx in range(num_worst_tokens):
|
||||
for k in range(topk):
|
||||
expert_idx = k % num_local_experts
|
||||
offset = token_idx % max_num_tokens_per_expert
|
||||
indice_map_np[token_idx, k] = expert_idx * max_num_tokens_per_expert + offset
|
||||
indice_map = paddle.to_tensor(indice_map_np).cast(paddle.int32)
|
||||
|
||||
topk_weights_np = np.ones((num_worst_tokens, topk), dtype=np.float32) / topk
|
||||
topk_weights = paddle.to_tensor(topk_weights_np)
|
||||
|
||||
depermuted_x = call_depermute_prefill_combine(
|
||||
x=x,
|
||||
indice_map=indice_map,
|
||||
topk_weights=topk_weights,
|
||||
num_worst_tokens=num_worst_tokens,
|
||||
)
|
||||
|
||||
x_ref_np = x.cast(paddle.float32).numpy()
|
||||
expected = self._compute_reference(
|
||||
x_np=x_ref_np,
|
||||
indice_map_np=indice_map_np,
|
||||
topk_weights_np=topk_weights_np,
|
||||
num_worst_tokens=num_worst_tokens,
|
||||
max_num_tokens_per_expert=max_num_tokens_per_expert,
|
||||
)
|
||||
|
||||
result = depermuted_x.cast(paddle.float32).numpy()
|
||||
np.testing.assert_allclose(result, expected, rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -89,7 +89,7 @@ class TestPerTokenQuant(unittest.TestCase):
|
||||
|
||||
def test_per_token_quant(self):
|
||||
paddle_output, paddle_output_scale = per_token_quant_paddle(self.input_tensor, self.block_size)
|
||||
output, output_scale = per_token_quant(self.input_tensor, self.block_size)
|
||||
output, output_scale = per_token_quant(self.input_tensor, self.block_size, False)
|
||||
|
||||
np.testing.assert_allclose(paddle_output_scale.numpy(), output_scale.numpy(), rtol=1e-6)
|
||||
|
||||
@@ -139,7 +139,7 @@ class TestPerTokenQuantPadding(TestPerTokenQuant):
|
||||
paddle_output, paddle_output_scale = per_token_quant_padding_paddle(
|
||||
self.input_tensor, self.block_size, self.dtype
|
||||
)
|
||||
output, output_scale = per_token_quant_padding(self.input_tensor, self.block_size)
|
||||
output, output_scale = per_token_quant_padding(self.input_tensor, self.block_size, False)
|
||||
|
||||
self.assertEqual(paddle_output_scale.shape, output_scale.shape)
|
||||
np.testing.assert_allclose(
|
||||
|
||||
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import prefill_permute_to_masked_gemm
|
||||
|
||||
|
||||
def call_prefill_permute_to_masked_gemm(
|
||||
x: paddle.Tensor,
|
||||
scale: paddle.Tensor,
|
||||
topk_ids: paddle.Tensor,
|
||||
num_local_experts: int,
|
||||
max_token_num: int,
|
||||
):
|
||||
"""
|
||||
Permute input tokens and scales from token-major to expert-major layout
|
||||
for MoE masked GEMM operations.
|
||||
|
||||
Args:
|
||||
x: Input hidden states [num_tokens, hidden].
|
||||
scale: Input scales [num_tokens, hidden_scale].
|
||||
topk_ids: Expert routing indices [num_tokens, topk] (int64 or int32).
|
||||
num_local_experts: Number of local experts on this device.
|
||||
max_token_num: Maximum tokens per expert buffer.
|
||||
|
||||
Returns:
|
||||
tuple: (permute_x, permute_scale, permuted_indice_map, token_nums_per_expert)
|
||||
"""
|
||||
if topk_ids.dtype != paddle.int64:
|
||||
topk_ids = topk_ids.cast(paddle.int64)
|
||||
|
||||
results = prefill_permute_to_masked_gemm(x, scale, topk_ids, num_local_experts, max_token_num)
|
||||
|
||||
return results[0], results[1], results[2], results[3]
|
||||
|
||||
|
||||
class TestPrefillPermuteToMaskedGemm(unittest.TestCase):
|
||||
"""
|
||||
Test cases for prefill_permute_to_masked_gemm kernel.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
paddle.seed(2024)
|
||||
np.random.seed(2024)
|
||||
|
||||
def _get_expected_tokens_per_expert(self, x, scale, topk_ids, num_local_experts):
|
||||
num_tokens = x.shape[0]
|
||||
_, topk = topk_ids.shape
|
||||
|
||||
expert_to_tokens = {i: [] for i in range(num_local_experts)}
|
||||
token_nums_per_expert = np.zeros(num_local_experts, dtype=np.int32)
|
||||
|
||||
for token_idx in range(num_tokens):
|
||||
for k in range(topk):
|
||||
expert_idx = topk_ids[token_idx, k]
|
||||
if expert_idx != -1:
|
||||
expert_to_tokens[expert_idx].append((x[token_idx, :].copy(), scale[token_idx, :].copy()))
|
||||
token_nums_per_expert[expert_idx] += 1
|
||||
|
||||
return expert_to_tokens, token_nums_per_expert
|
||||
|
||||
def _run_and_verify(
|
||||
self,
|
||||
num_tokens,
|
||||
hidden_size,
|
||||
hidden_scale,
|
||||
num_local_experts,
|
||||
max_token_num,
|
||||
topk,
|
||||
x_dtype=paddle.float8_e4m3fn,
|
||||
scale_dtype=paddle.int32,
|
||||
sparsity=0.3,
|
||||
):
|
||||
|
||||
if x_dtype == paddle.float8_e4m3fn:
|
||||
x_np = np.random.randn(num_tokens, hidden_size).astype(np.float32)
|
||||
x_np = np.clip(x_np, -448, 448)
|
||||
x = paddle.to_tensor(x_np).cast(paddle.float8_e4m3fn)
|
||||
elif x_dtype == paddle.bfloat16:
|
||||
x_np = np.random.randn(num_tokens, hidden_size).astype(np.float32)
|
||||
x = paddle.to_tensor(x_np).cast(paddle.bfloat16)
|
||||
else:
|
||||
x_np = np.random.randn(num_tokens, hidden_size).astype(np.float32)
|
||||
x = paddle.to_tensor(x_np)
|
||||
|
||||
scale_np = np.random.rand(num_tokens, hidden_scale).astype(np.float32)
|
||||
scale = paddle.to_tensor(scale_np).cast(scale_dtype).contiguous()
|
||||
|
||||
topk_ids_np = np.zeros((num_tokens, topk), dtype=np.int64)
|
||||
for i in range(num_tokens):
|
||||
experts = np.random.choice(num_local_experts, size=min(topk, num_local_experts), replace=False)
|
||||
if len(experts) < topk:
|
||||
topk_ids_np[i, : len(experts)] = experts
|
||||
topk_ids_np[i, len(experts) :] = -1
|
||||
else:
|
||||
topk_ids_np[i, :] = experts
|
||||
mask = np.random.rand(num_tokens, topk) < sparsity
|
||||
topk_ids_np[mask] = -1
|
||||
topk_ids = paddle.to_tensor(topk_ids_np).cast(paddle.int64)
|
||||
|
||||
permute_x, permute_scale, permuted_indice_map, token_nums_per_expert = call_prefill_permute_to_masked_gemm(
|
||||
x=x,
|
||||
scale=scale,
|
||||
topk_ids=topk_ids,
|
||||
num_local_experts=num_local_experts,
|
||||
max_token_num=max_token_num,
|
||||
)
|
||||
|
||||
permute_x_result = permute_x.cast(paddle.float32).numpy()
|
||||
permute_scale_result = permute_scale.numpy()
|
||||
permuted_indice_map_result = permuted_indice_map.numpy()
|
||||
token_nums_result = token_nums_per_expert.numpy().flatten()
|
||||
|
||||
x_ref_np = x.cast(paddle.float32).numpy()
|
||||
scale_ref_np = scale.numpy()
|
||||
|
||||
expert_to_tokens, token_nums_ref = self._get_expected_tokens_per_expert(
|
||||
x=x_ref_np,
|
||||
scale=scale_ref_np,
|
||||
topk_ids=topk_ids_np,
|
||||
num_local_experts=num_local_experts,
|
||||
)
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
token_nums_result,
|
||||
token_nums_ref,
|
||||
err_msg=f"Token counts mismatch: kernel={token_nums_result}, ref={token_nums_ref}",
|
||||
)
|
||||
|
||||
for expert_idx in range(num_local_experts):
|
||||
num_tokens_for_expert = token_nums_ref[expert_idx]
|
||||
if num_tokens_for_expert > 0:
|
||||
kernel_x_rows = permute_x_result[expert_idx, :num_tokens_for_expert, :]
|
||||
kernel_scale_rows = permute_scale_result[expert_idx, :num_tokens_for_expert, :]
|
||||
|
||||
expected_tokens = expert_to_tokens[expert_idx]
|
||||
expected_x_rows = np.array([t[0] for t in expected_tokens])
|
||||
expected_scale_rows = np.array([t[1] for t in expected_tokens])
|
||||
|
||||
kernel_x_sums = np.sort(np.sum(kernel_x_rows, axis=1))
|
||||
expected_x_sums = np.sort(np.sum(expected_x_rows, axis=1))
|
||||
|
||||
np.testing.assert_allclose(
|
||||
kernel_x_sums,
|
||||
expected_x_sums,
|
||||
rtol=1e-2,
|
||||
atol=1e-1,
|
||||
err_msg=f"Expert {expert_idx}: permute_x row sums mismatch",
|
||||
)
|
||||
|
||||
kernel_scale_sums = np.sort(np.sum(kernel_scale_rows, axis=1))
|
||||
expected_scale_sums = np.sort(np.sum(expected_scale_rows, axis=1))
|
||||
|
||||
np.testing.assert_allclose(
|
||||
kernel_scale_sums,
|
||||
expected_scale_sums,
|
||||
rtol=1e-2,
|
||||
atol=1e-2,
|
||||
err_msg=f"Expert {expert_idx}: permute_scale row sums mismatch",
|
||||
)
|
||||
|
||||
x_ref_np = x.cast(paddle.float32).numpy()
|
||||
for token_idx in range(num_tokens):
|
||||
for expert_slot in range(topk):
|
||||
permuted_idx = permuted_indice_map_result[token_idx, expert_slot]
|
||||
if permuted_idx >= 0:
|
||||
expert_idx = permuted_idx // max_token_num
|
||||
offset = permuted_idx % max_token_num
|
||||
permuted_data = permute_x_result[expert_idx, offset, :]
|
||||
original_data = x_ref_np[token_idx, :]
|
||||
np.testing.assert_allclose(
|
||||
permuted_data,
|
||||
original_data,
|
||||
rtol=1e-2,
|
||||
atol=1e-1,
|
||||
err_msg=f"Token {token_idx}: permuted_indice_map points to wrong data",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def test_basic_topk4(self):
|
||||
self._run_and_verify(
|
||||
num_tokens=64,
|
||||
hidden_size=7168,
|
||||
hidden_scale=56,
|
||||
num_local_experts=8,
|
||||
max_token_num=128,
|
||||
topk=4,
|
||||
sparsity=0.2,
|
||||
)
|
||||
|
||||
def test_basic_topk8(self):
|
||||
self._run_and_verify(
|
||||
num_tokens=64,
|
||||
hidden_size=7168,
|
||||
hidden_scale=56,
|
||||
num_local_experts=8,
|
||||
max_token_num=128,
|
||||
topk=8,
|
||||
sparsity=0.2,
|
||||
)
|
||||
|
||||
def test_small_tokens(self):
|
||||
self._run_and_verify(
|
||||
num_tokens=4, hidden_size=1024, hidden_scale=8, num_local_experts=4, max_token_num=32, topk=4, sparsity=0.1
|
||||
)
|
||||
|
||||
def test_large_tokens(self):
|
||||
self._run_and_verify(
|
||||
num_tokens=512,
|
||||
hidden_size=4096,
|
||||
hidden_scale=32,
|
||||
num_local_experts=16,
|
||||
max_token_num=256,
|
||||
topk=4,
|
||||
sparsity=0.3,
|
||||
)
|
||||
|
||||
def test_high_sparsity(self):
|
||||
self._run_and_verify(
|
||||
num_tokens=128,
|
||||
hidden_size=2048,
|
||||
hidden_scale=16,
|
||||
num_local_experts=8,
|
||||
max_token_num=64,
|
||||
topk=4,
|
||||
sparsity=0.7,
|
||||
)
|
||||
|
||||
def test_no_sparsity(self):
|
||||
self._run_and_verify(
|
||||
num_tokens=64,
|
||||
hidden_size=2048,
|
||||
hidden_scale=16,
|
||||
num_local_experts=8,
|
||||
max_token_num=128,
|
||||
topk=4,
|
||||
sparsity=0.0,
|
||||
)
|
||||
|
||||
def test_single_expert(self):
|
||||
self._run_and_verify(
|
||||
num_tokens=32,
|
||||
hidden_size=1024,
|
||||
hidden_scale=8,
|
||||
num_local_experts=1,
|
||||
max_token_num=64,
|
||||
topk=4,
|
||||
sparsity=0.0,
|
||||
)
|
||||
|
||||
def test_many_experts(self):
|
||||
self._run_and_verify(
|
||||
num_tokens=128,
|
||||
hidden_size=2048,
|
||||
hidden_scale=16,
|
||||
num_local_experts=32,
|
||||
max_token_num=64,
|
||||
topk=8,
|
||||
sparsity=0.3,
|
||||
)
|
||||
|
||||
def test_bfloat16_input(self):
|
||||
self._run_and_verify(
|
||||
num_tokens=64,
|
||||
hidden_size=2048,
|
||||
hidden_scale=16,
|
||||
num_local_experts=8,
|
||||
max_token_num=128,
|
||||
topk=4,
|
||||
x_dtype=paddle.bfloat16,
|
||||
sparsity=0.2,
|
||||
)
|
||||
|
||||
def test_very_large_tokens(self):
|
||||
self._run_and_verify(
|
||||
num_tokens=65536,
|
||||
hidden_size=7168,
|
||||
hidden_scale=56,
|
||||
num_local_experts=20,
|
||||
max_token_num=16384,
|
||||
topk=4,
|
||||
sparsity=0.3,
|
||||
)
|
||||
|
||||
def test_very_large_tokens_with_fp32_scale(self):
|
||||
self._run_and_verify(
|
||||
num_tokens=65536,
|
||||
hidden_size=7168,
|
||||
hidden_scale=56,
|
||||
num_local_experts=20,
|
||||
max_token_num=16384,
|
||||
topk=4,
|
||||
sparsity=0.3,
|
||||
scale_dtype=paddle.float32,
|
||||
)
|
||||
|
||||
def test_all_minus_one(self):
|
||||
num_tokens = 32
|
||||
hidden_size = 1024
|
||||
hidden_scale = 8
|
||||
num_local_experts = 4
|
||||
max_token_num = 64
|
||||
topk = 4
|
||||
|
||||
x_np = np.random.randn(num_tokens, hidden_size).astype(np.float32)
|
||||
x_np = np.clip(x_np, -448, 448)
|
||||
x = paddle.to_tensor(x_np).cast(paddle.float8_e4m3fn)
|
||||
|
||||
scale_np = np.random.rand(num_tokens, hidden_scale).astype(np.float32)
|
||||
scale = paddle.to_tensor(scale_np)
|
||||
|
||||
topk_ids = paddle.full([num_tokens, topk], -1, dtype=paddle.int32)
|
||||
|
||||
permute_x, permute_scale, permuted_indice_map, token_nums_per_expert = call_prefill_permute_to_masked_gemm(
|
||||
x=x,
|
||||
scale=scale,
|
||||
topk_ids=topk_ids,
|
||||
num_local_experts=num_local_experts,
|
||||
max_token_num=max_token_num,
|
||||
)
|
||||
|
||||
token_nums_result = token_nums_per_expert.numpy().flatten()
|
||||
expected = np.zeros(num_local_experts, dtype=np.int32)
|
||||
np.testing.assert_array_equal(token_nums_result, expected)
|
||||
self.assertEqual(permuted_indice_map.shape, [num_tokens, topk])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user