[Optimize] optimize mask_quant & swiglu (#6222)

* optimize mask_quant op speed up 1.5

* fix calculate sequence

* add fused

* rm log

* push kernel code

* add ut

* accuracy ok

* add ue8m0

* add ut

* add merge develop

* rm ut of mask_per_token_quant
This commit is contained in:
fxyfxy777
2026-02-02 13:52:38 +08:00
committed by GitHub
parent 25656455ee
commit 2ada119a38
7 changed files with 555 additions and 452 deletions
+10 -7
View File
@@ -303,10 +303,12 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
const int block_size); const int block_size);
std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input, std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
const int block_size); const int block_size);
std::vector<paddle::Tensor> MaskedPerTokenQuant(
std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
paddle::Tensor& input, paddle::Tensor& input,
paddle::Tensor& recv_expert_count, paddle::Tensor& token_nums_per_expert,
const int block_size); const int block_size,
const bool use_ue8m0);
std::vector<paddle::Tensor> EPMoeExpertCombine( std::vector<paddle::Tensor> EPMoeExpertCombine(
const paddle::Tensor& ffn_out, const paddle::Tensor& ffn_out,
@@ -1267,12 +1269,13 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("routed_scaling_factor"), py::arg("routed_scaling_factor"),
"ep moe export combine function"); "ep moe export combine function");
m.def("masked_per_token_quant", m.def("fused_mask_swiglu_fp8_quant",
&MaskedPerTokenQuant, &FusedMaskSwigluFP8Quant,
py::arg("input"), py::arg("input"),
py::arg("recv_expert_count"), py::arg("token_nums_per_expert"),
py::arg("block_size"), py::arg("block_size"),
"per token per block quant"); py::arg("use_ue8m0") = false,
"fused mask swiglu and fp8 quant");
#ifdef ENABLE_MACHETE #ifdef ENABLE_MACHETE
/*machete/machete_mm.cu /*machete/machete_mm.cu
@@ -0,0 +1,248 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
constexpr float kEpsilon = 1e-10;
constexpr float kFP8Max = 448.f;
__host__ __device__ __forceinline__ int ceil_div(int x, int y) {
return (x + y - 1) / y;
}
__host__ __device__ __forceinline__ int align(int x, int y) {
return ceil_div(x, y) * y;
}
#ifndef BOOL_SWITCH
#define BOOL_SWITCH(cond, name, ...) \
if (cond) { \
constexpr bool name = true; \
__VA_ARGS__(); \
} else { \
constexpr bool name = false; \
__VA_ARGS__(); \
}
#endif
template <typename T, typename index_t, typename ScaleT, bool UseUE8M0>
__global__ void fused_swiglu_fp8_quant_kernel(
const T* __restrict__ input, // [group, max_tokens, hidden*2]
const index_t* __restrict__ token_nums_per_expert,
phi::dtype::float8_e4m3fn* __restrict__ out_fp8,
ScaleT* __restrict__ out_scale,
int group_num,
int group_size,
int hidden_size,
int hidden_size_scale,
bool use_finegrained_range) {
constexpr int BLOCK = 128;
int tid = threadIdx.x;
int lane = tid & 31;
int warp = tid >> 5;
int num_warps = blockDim.x >> 5;
int block_id = static_cast<int64_t>(blockIdx.x);
using VecBF16 = AlignedVector<T, 4>;
VecBF16 x1_vec, x2_vec;
using VecFP8 = AlignedVector<phi::dtype::float8_e4m3fn, 4>;
VecFP8 q_vec;
while (true) {
// ================= token mapping =================
int expert = -1;
int token_in_expert = -1;
if (lane == 0) {
int cumsum = 0;
for (int i = 0; i < group_num; ++i) {
int cnt = token_nums_per_expert[i];
if (block_id >= cumsum && block_id < cumsum + cnt) {
expert = i;
token_in_expert = block_id - cumsum;
break;
}
cumsum += cnt;
}
}
expert = __shfl_sync(0xffffffff, expert, 0);
token_in_expert = __shfl_sync(0xffffffff, token_in_expert, 0);
if (expert < 0 || token_in_expert >= group_size) break;
// ================= base pointers =================
int token = expert * group_size + token_in_expert;
const T* in = input + token * hidden_size * 2;
auto* out = out_fp8 + token * hidden_size;
int num_iters = hidden_size / BLOCK;
// ================= main loop =================
for (int iter = warp; iter < num_iters; iter += num_warps) {
int base = iter * BLOCK + lane * 4;
// vec load
Load(in + base, &x1_vec);
Load(in + base + hidden_size, &x2_vec);
float v[4];
float amax = -5e4;
#pragma unroll
for (int i = 0; i < 4; ++i) {
float x1 = static_cast<float>(x1_vec[i]);
float x2 = static_cast<float>(x2_vec[i]);
float y = x2 * x1 / (1.f + expf(-x1));
float y_r = static_cast<float>(
static_cast<T>(y)); // To simulate the data transformation before
// the fusion of swiglu and quant operators
v[i] = y_r;
amax = max(amax, abs(y_r));
}
// ---------- warp reduce amax ----------
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
amax = max(amax, __shfl_down_sync(0xffffffff, amax, offset));
amax = __shfl_sync(0xffffffff, amax, 0);
amax = max(amax, kEpsilon);
if (use_finegrained_range) amax *= 7.f;
float scale = amax / kFP8Max;
// ---------- quantize ----------
if constexpr (UseUE8M0) {
scale = exp2f(ceilf(log2f(fmaxf(scale, kEpsilon))));
#pragma unroll
for (int i = 0; i < 4; ++i) {
float q = v[i] / scale;
q_vec[i] = static_cast<phi::dtype::float8_e4m3fn>(q);
}
// ---------- store scale ----------
if (lane == 0) {
// 1. extract exponent
const int exp = (__float_as_int(scale) >> 23) & 0xFF;
// 2. pack information
const int pack_idx = iter >> 2; // iter / 4
const int byte_idx = iter & 3; // iter % 4
// 3. layout parameters
const int pack_num = ceil_div(hidden_size_scale, 4);
const int token_stride = align(group_size, 4);
// 4. base pointer (int32 pack)
auto* scale_pack = reinterpret_cast<int32_t*>(out_scale);
// 5. column-major offset:
// [expert][pack][token]
const int base_idx = expert * pack_num * token_stride +
pack_idx * token_stride + token_in_expert;
// 6. write one byte into pack
reinterpret_cast<uint8_t*>(&scale_pack[base_idx])[byte_idx] =
static_cast<uint8_t>(exp);
}
} else {
#pragma unroll
for (int i = 0; i < 4; i++) {
float q = v[i] * kFP8Max / amax;
q_vec[i] = static_cast<phi::dtype::float8_e4m3fn>(q);
}
// ---------- store scale ----------
if (lane == 0) {
out_scale[expert * hidden_size_scale * group_size +
iter * group_size + token_in_expert] = scale;
}
}
Store(q_vec, out + base);
}
block_id += gridDim.x;
}
}
std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
paddle::Tensor& input,
paddle::Tensor& token_nums_per_expert,
const int block_size,
const bool use_ue8m0) {
auto dim = input.dims();
const int group_num = token_nums_per_expert.shape()[0];
const int group_size = dim[1];
const int hidden_size = dim[2] / 2;
const int hidden_size_scale = hidden_size / block_size;
const int token_num = group_num * group_size;
auto out_fp8 = GetEmptyTensor({group_num, group_size, hidden_size},
paddle::DataType::FLOAT8_E4M3FN,
input.place());
auto out_scale =
GetEmptyTensor({group_num, group_size, hidden_size_scale},
{hidden_size_scale * group_size, 1, group_size},
paddle::DataType::FLOAT32,
input.place());
if (use_ue8m0) {
int hidden_size_scale_pack = ceil_div(hidden_size_scale, 4);
out_scale = GetEmptyTensor({group_num, group_size, hidden_size_scale_pack},
{hidden_size_scale_pack * align(group_size, 4),
1,
align(group_size, 4)},
paddle::DataType::INT32,
input.place());
}
int sm_count = 0;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, 0);
constexpr int BLOCKS_PER_SM = 2;
int gridx = std::min(sm_count * BLOCKS_PER_SM, token_num);
int blockx = std::min(1024, hidden_size / 128 * 32);
bool use_finegrained_range = false;
if (auto* env = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE"))
use_finegrained_range = static_cast<bool>(std::stoi(env));
if (input.dtype() == paddle::DataType::BFLOAT16) {
BOOL_SWITCH(use_ue8m0, UseUE8M0, [&] {
using ScaleT = std::conditional_t<UseUE8M0, int, float>;
fused_swiglu_fp8_quant_kernel<paddle::bfloat16, int, ScaleT, UseUE8M0>
<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::bfloat16>(),
token_nums_per_expert.data<int>(),
out_fp8.data<phi::dtype::float8_e4m3fn>(),
out_scale.data<ScaleT>(),
group_num,
group_size,
hidden_size,
hidden_size_scale,
use_finegrained_range);
});
} else {
PD_THROW("Only BF16 supported");
}
return {out_fp8, out_scale};
}
PD_BUILD_STATIC_OP(fused_mask_swiglu_fp8_quant)
.Inputs({"input", "token_nums_per_expert"})
.Outputs({"out_fp8", "output_scale"})
.Attrs({"block_size: int", "use_ue8m0: bool"})
.SetKernelFn(PD_KERNEL(FusedMaskSwigluFP8Quant));
-172
View File
@@ -1,172 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
constexpr float epsilon = 1e-10;
template <typename T>
__global__ void masked_quant_per_token_per_block(
const T *input,
const int *recv_expert_count,
phi::dtype::float8_e4m3fn *quanted_res,
float *quanted_scale,
const int token_num,
const int hidden_size,
const int hidden_size_scale,
const int num_max_tokens_per_expert,
const bool use_finegrained_range) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
const int warp_id = tid / 32;
const int lane_id = tid % 32;
const int num_warp = blockDim.x / 32;
static constexpr int NUM_PER_THREADS = 128 / 32; // 4
static constexpr float MAX_VALUE = 448.f;
const int end_iter = hidden_size / 128; // warp_iter_num
AlignedVector<T, NUM_PER_THREADS> load_vec;
AlignedVector<float, NUM_PER_THREADS> load_vec_float;
AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert;
const auto expert_id = token_idx / num_max_tokens_per_expert;
if (token_idx_in_expert >= recv_expert_count[expert_id]) {
auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
auto num_iters_to_next_expert =
(next_expert_start_idx - token_idx - 1) / gridDim.x;
token_idx += num_iters_to_next_expert * gridDim.x;
continue;
}
const T *input_now = input + token_idx * hidden_size;
phi::dtype::float8_e4m3fn *quanted_res_now =
quanted_res + token_idx * hidden_size;
// deal a block per warp
for (int iter = warp_id; iter < end_iter; iter += num_warp) {
float *quanted_scale_now =
quanted_scale +
expert_id * hidden_size_scale * num_max_tokens_per_expert +
iter * num_max_tokens_per_expert + token_idx_in_expert;
const int start_offset = iter * 128;
Load<T, NUM_PER_THREADS>(
input_now + start_offset + lane_id * NUM_PER_THREADS, &load_vec);
// get max value per thread
float max_value_thread = -5e4;
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
load_vec_float[vid] = static_cast<float>(load_vec[vid]);
max_value_thread = max(abs(load_vec_float[vid]), max_value_thread);
}
// get max value per warp
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16),
max_value_thread);
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8),
max_value_thread);
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4),
max_value_thread);
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2),
max_value_thread);
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1),
max_value_thread);
// broadcast max_value
max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
max_value_thread = max(max_value_thread, epsilon);
if (use_finegrained_range) {
max_value_thread *= 7.0f;
}
float scale_to_store = max_value_thread / MAX_VALUE;
// quant
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] * MAX_VALUE / max_value_thread);
}
// store
Store<phi::dtype::float8_e4m3fn, NUM_PER_THREADS>(
res_vec, quanted_res_now + start_offset + lane_id * NUM_PER_THREADS);
if (lane_id == 0) {
*quanted_scale_now = scale_to_store;
}
}
}
}
std::vector<paddle::Tensor> MaskedPerTokenQuant(
paddle::Tensor &input,
paddle::Tensor &recv_expert_count,
const int block_size) {
auto input_dim = input.dims();
const int num_local_expert = input_dim[0];
const int num_max_tokens_per_expert = input_dim[1];
const int hidden_size = input_dim[2];
const int hidden_size_scale = hidden_size / block_size;
const int token_num = num_local_expert * num_max_tokens_per_expert;
auto quanted_x =
GetEmptyTensor({num_local_expert, num_max_tokens_per_expert, hidden_size},
paddle::DataType::FLOAT8_E4M3FN,
input.place());
auto quanted_scale = GetEmptyTensor(
{num_local_expert, num_max_tokens_per_expert, hidden_size_scale},
{hidden_size_scale * num_max_tokens_per_expert,
1,
num_max_tokens_per_expert},
paddle::DataType::FLOAT32,
input.place());
const int gridx = min(132 * 2, token_num);
const int blockx = min(1024, hidden_size / 128 * 32);
bool use_finegrained_range = false;
char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
if (env_var) {
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
}
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
masked_quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::bfloat16>(),
recv_expert_count.data<int>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<float>(),
token_num,
hidden_size,
hidden_size_scale,
num_max_tokens_per_expert,
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
masked_quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
input.data<paddle::float16>(),
recv_expert_count.data<int>(),
quanted_x.data<phi::dtype::float8_e4m3fn>(),
quanted_scale.data<float>(),
token_num,
hidden_size,
hidden_size_scale,
num_max_tokens_per_expert,
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
}
return {quanted_x, quanted_scale};
}
PD_BUILD_STATIC_OP(masked_per_token_quant)
.Inputs({"input", "recv_expert_count"})
.Outputs({"output", "output_scale"})
.Attrs({"block_size: int"})
.SetKernelFn(PD_KERNEL(MaskedPerTokenQuant));
+1 -1
View File
@@ -293,7 +293,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/step_system_cache.cu", "gpu_ops/step_system_cache.cu",
"gpu_ops/cpp_extensions.cc", "gpu_ops/cpp_extensions.cc",
"gpu_ops/share_external_data.cu", "gpu_ops/share_external_data.cu",
"gpu_ops/per_token_quant_fp8.cu", "gpu_ops/fused_mask_swiglu_fp8_quant_kernel.cu",
"gpu_ops/update_split_fuse_input.cu", "gpu_ops/update_split_fuse_input.cu",
"gpu_ops/text_image_index_out.cu", "gpu_ops/text_image_index_out.cu",
"gpu_ops/text_image_gather_scatter.cu", "gpu_ops/text_image_gather_scatter.cu",
@@ -413,12 +413,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
expected_m, expected_m,
) )
act_out = fastdeploy.model_executor.ops.gpu.group_swiglu_with_masked(up_gate_proj_out, token_nums_per_expert) act_out_fp8, scale = fastdeploy.model_executor.ops.gpu.fused_mask_swiglu_fp8_quant(
up_gate_proj_out, token_nums_per_expert, use_ue8m0=False
act_out_fp8, scale = fastdeploy.model_executor.ops.gpu.masked_per_token_quant(
act_out,
token_nums_per_expert,
self.quant_config.weight_block_size[0],
) )
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
@@ -0,0 +1,294 @@
import os
import unittest
import numpy as np
import paddle
paddle.seed(2026)
def ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y
def align(x: int, y: int) -> int:
return ceil_div(x, y) * y
def get_tma_aligned_size(x: int, element_size: int) -> int:
"""
Align x to TMA-required size.
Args:
x: size in elements
element_size: size of each element in bytes
Returns:
Aligned size in elements
"""
kNumTMAAlignmentBytes = 16
assert kNumTMAAlignmentBytes % element_size == 0
return align(x, kNumTMAAlignmentBytes // element_size)
def _get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(
x: paddle.Tensor,
):
assert x.dtype == paddle.float and x.dim() in (2, 3)
ue8m0_tensor = (x.view(paddle.int) >> 23).to(paddle.uint8)
mn, k = x.shape[-2], x.shape[-1]
remove_dim = False
if x.dim() == 2:
x, remove_dim = x.unsqueeze(0), True
b = x.shape[0]
aligned_mn = get_tma_aligned_size(mn, 4)
aligned_k = align(k, 4)
padded = paddle.zeros((b, aligned_mn, aligned_k), device=x.device, dtype=paddle.uint8)
padded[:, :mn, :k] = ue8m0_tensor
padded = padded.view(-1).view(dtype=paddle.int).view(b, aligned_mn, aligned_k // 4)
transposed = paddle.zeros((b, aligned_k // 4, aligned_mn), device=x.device, dtype=paddle.int).mT
transposed[:, :, :] = padded
aligned_x = transposed[:, :mn, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x
def transform_scale_ue8m0(sf, mn, weight_block_size=None):
get_mn_major_tma_aligned_packed_ue8m0_tensor = _get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl
if weight_block_size:
assert weight_block_size == [128, 128]
sf = sf.index_select(-2, paddle.arange(mn, device=sf.device) // 128)
sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
return sf
def ceil_to_ue8m0_paddle(x: paddle.Tensor):
"""
x > 0
return 2 ^ ceil(log2(x))
"""
# log2(x)
log2_x = paddle.log(x) / paddle.log(paddle.to_tensor(2.0, dtype=x.dtype))
# ceil
ceil_log2_x = paddle.ceil(log2_x)
# 2^k
return paddle.pow(paddle.to_tensor(2.0, dtype=x.dtype), ceil_log2_x)
def masked_per_token_quant_ref(input_tensor, recv_expert_count, block_size, use_ue8m0):
"""
Paddle API implementation of masked_per_token_quant
Args:
input_tensor: Input tensor with shape [num_local_expert, num_max_tokens_per_expert, hidden_size]
recv_expert_count: Expert token count tensor with shape [num_local_expert]
block_size: Quantization block size
Returns:
Tuple of (quantized_tensor, scale_tensor)
"""
MAX_VALUE = 448.0
epsilon = 1e-10
# Get dimensions
input_shape = input_tensor.shape
num_local_expert = input_shape[0]
num_max_tokens_per_expert = input_shape[1]
hidden_size = input_shape[2]
# CUDA kernel uses: hidden_size_scale = hidden_size / block_size (integer division)
# This assumes hidden_size is divisible by block_size
hidden_size_scale = hidden_size // block_size
# Check environment variable for fine-grained range
use_finegrained_range = False
env_var = os.getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE")
if env_var:
use_finegrained_range = bool(int(env_var))
# Create mask for valid tokens based on recv_expert_count
token_indices = paddle.arange(num_max_tokens_per_expert, dtype="int32").unsqueeze(
0
) # [1, num_max_tokens_per_expert]
expert_counts = recv_expert_count.unsqueeze(1) # [num_local_expert, 1]
valid_mask = token_indices < expert_counts # [num_local_expert, num_max_tokens_per_expert]
# Reshape input for block-wise processing
# [num_local_expert, num_max_tokens_per_expert, hidden_size_scale, block_size]
reshaped_input = paddle.reshape(
input_tensor, [num_local_expert, num_max_tokens_per_expert, hidden_size_scale, block_size]
).astype("float32")
# Calculate max absolute values per block
max_abs_val = paddle.max(
paddle.abs(reshaped_input), axis=-1, keepdim=True
) # [num_local_expert, num_max_tokens_per_expert, hidden_size_scale, 1]
max_abs_val = paddle.clip(max_abs_val, min=epsilon)
# Apply valid mask - set invalid tokens' max values to epsilon
valid_mask_expanded = valid_mask.unsqueeze(2).unsqueeze(3) # [num_local_expert, num_max_tokens_per_expert, 1, 1]
max_abs_val = paddle.where(valid_mask_expanded, max_abs_val, paddle.to_tensor(epsilon))
# Apply fine-grained range if enabled
if use_finegrained_range:
max_abs_val *= 7.0
# Calculate scale
scale = max_abs_val / MAX_VALUE
if use_ue8m0:
scale = ceil_to_ue8m0_paddle(scale)
# Quantize
quanted_value = reshaped_input / scale
# Convert to float8_e4m3fn and reshape back
quanted_x_reshaped = quanted_value.astype("float8_e4m3fn")
quanted_x = paddle.reshape(quanted_x_reshaped, [num_local_expert, num_max_tokens_per_expert, hidden_size])
# Apply valid mask to quantized output - convert to float32 first, then back to float8_e4m3fn
valid_mask_full = valid_mask.unsqueeze(2) # [num_local_expert, num_max_tokens_per_expert, 1]
quanted_x_float32 = quanted_x.astype("float32")
quanted_x_masked_float32 = paddle.where(valid_mask_full, quanted_x_float32, paddle.zeros_like(quanted_x_float32))
quanted_x = quanted_x_masked_float32.astype("float8_e4m3fn")
# Prepare scale output - squeeze the last dimension
quanted_scale = paddle.squeeze(scale, axis=-1) # [num_local_expert, num_max_tokens_per_expert, hidden_size_scale]
# Apply valid mask to scale
valid_mask_scale = valid_mask.unsqueeze(2) # [num_local_expert, num_max_tokens_per_expert, 1]
quanted_scale = paddle.where(valid_mask_scale, quanted_scale, paddle.zeros_like(quanted_scale))
if use_ue8m0:
quanted_scale = transform_scale_ue8m0(quanted_scale, mn=quanted_x.shape[-2])
return quanted_x, quanted_scale
def run_fused(x, token_nums, block_size, use_ue8m0=False):
import fastdeploy.model_executor.ops.gpu as ops
return ops.fused_mask_swiglu_fp8_quant(x, token_nums, block_size, use_ue8m0)
def run_separate(x, token_nums, block_size, use_ue8m0=False):
"""Run separate operations (FastDeploy non-fused kernels)"""
from fastdeploy.model_executor.ops.gpu import group_swiglu_with_masked
swiglu = group_swiglu_with_masked(x, token_nums)
q, scale = masked_per_token_quant_ref(swiglu, token_nums, block_size, use_ue8m0)
return q, scale
# ------------------------------------------------------------
# Test case
# ------------------------------------------------------------
def benchmark_cuda(fn, warmup=10, repeat=10):
"""
Benchmark a CUDA function using paddle.device.Event
fn: callable with no return dependency on CPU
"""
# warmup
for _ in range(warmup):
fn()
paddle.device.synchronize()
start = paddle.device.Event(enable_timing=True)
end = paddle.device.Event(enable_timing=True)
start.record()
for _ in range(repeat):
fn()
end.record()
end.synchronize()
elapsed_ms = start.elapsed_time(end) # ms
return elapsed_ms / repeat
class TestFusedSwigluFP8Quant(unittest.TestCase):
def setUp(self):
paddle.set_device("gpu")
# 10, 2048, 7168
self.group_num = 10
self.group_size = 2048
self.hidden_dim = 7168
self.block_size = 128
self.x = paddle.randn(
[self.group_num, self.group_size, self.hidden_dim * 2],
dtype="bfloat16",
)
self.token_nums = paddle.to_tensor([50, 51, 50, 50, 50, 50, 50, 49, 51, 51], dtype="int32")
def fused_vs_separate_exact_match(self, use_ue8m0=False):
"""
Test fused kernel vs separate operations - should be exact match
This compares FastDeploy's fused kernel vs FastDeploy's separate kernels
"""
# Run separate operations
q_ref, s_ref = run_separate(self.x, self.token_nums, self.block_size, use_ue8m0)
# Run fused kernel
q_fused, s_fused = run_fused(self.x, self.token_nums, self.block_size, use_ue8m0)
def run_sep():
run_separate(self.x, self.token_nums, self.block_size)
def run_fus():
run_fused(self.x, self.token_nums, self.block_size)
t_sep = benchmark_cuda(run_sep)
t_fus = benchmark_cuda(run_fus)
print("\n====== Fused vs Separate Benchmark ======")
print(f"Separate: {t_sep:.3f} ms")
print(f"Fused : {t_fus:.3f} ms")
print(f"Speedup : {t_sep / t_fus:.2f}x")
# ---------------- valid mask ----------------
arange = paddle.arange(self.group_size, dtype="int32")
valid = arange < self.token_nums.unsqueeze(1) # [G, S]
valid_flat = valid.reshape([-1])
# ---------------- FP8 output ----------------
q_ref_flat = q_ref.reshape([-1, q_ref.shape[-1]]).astype("float32")
q_fused_flat = q_fused.reshape([-1, q_fused.shape[-1]]).astype("float32")
# ---------------- scale ----------------
s_ref_flat = s_ref.reshape([-1, s_ref.shape[-1]])
s_fused_flat = s_fused.reshape([-1, s_fused.shape[-1]])
np.testing.assert_allclose(
s_ref_flat[valid_flat].numpy(),
s_fused_flat[valid_flat].numpy(),
rtol=1e-06,
err_msg="**scale mismatch**",
)
np.testing.assert_allclose(
q_ref_flat[valid_flat].numpy(),
q_fused_flat[valid_flat].numpy(),
equal_nan=True,
rtol=0.5,
err_msg="**quant_x mismatch**",
)
def test_fused(self):
self.fused_vs_separate_exact_match(use_ue8m0=True)
self.fused_vs_separate_exact_match(use_ue8m0=False)
if __name__ == "__main__":
unittest.main()
@@ -1,266 +0,0 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import masked_per_token_quant
def masked_per_token_quant_ref(input_tensor, recv_expert_count, block_size):
"""
Paddle API implementation of masked_per_token_quant
Args:
input_tensor: Input tensor with shape [num_local_expert, num_max_tokens_per_expert, hidden_size]
recv_expert_count: Expert token count tensor with shape [num_local_expert]
block_size: Quantization block size
Returns:
Tuple of (quantized_tensor, scale_tensor)
"""
MAX_VALUE = 448.0
epsilon = 1e-10
# Get dimensions
input_shape = input_tensor.shape
num_local_expert = input_shape[0]
num_max_tokens_per_expert = input_shape[1]
hidden_size = input_shape[2]
# CUDA kernel uses: hidden_size_scale = hidden_size / block_size (integer division)
# This assumes hidden_size is divisible by block_size
hidden_size_scale = hidden_size // block_size
# Check environment variable for fine-grained range
use_finegrained_range = False
env_var = os.getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE")
if env_var:
use_finegrained_range = bool(int(env_var))
# Create mask for valid tokens based on recv_expert_count
token_indices = paddle.arange(num_max_tokens_per_expert, dtype="int32").unsqueeze(
0
) # [1, num_max_tokens_per_expert]
expert_counts = recv_expert_count.unsqueeze(1) # [num_local_expert, 1]
valid_mask = token_indices < expert_counts # [num_local_expert, num_max_tokens_per_expert]
# Reshape input for block-wise processing
# [num_local_expert, num_max_tokens_per_expert, hidden_size_scale, block_size]
reshaped_input = paddle.reshape(
input_tensor, [num_local_expert, num_max_tokens_per_expert, hidden_size_scale, block_size]
).astype("float32")
# Calculate max absolute values per block
max_abs_val = paddle.max(
paddle.abs(reshaped_input), axis=-1, keepdim=True
) # [num_local_expert, num_max_tokens_per_expert, hidden_size_scale, 1]
max_abs_val = paddle.clip(max_abs_val, min=epsilon)
# Apply valid mask - set invalid tokens' max values to epsilon
valid_mask_expanded = valid_mask.unsqueeze(2).unsqueeze(3) # [num_local_expert, num_max_tokens_per_expert, 1, 1]
max_abs_val = paddle.where(valid_mask_expanded, max_abs_val, paddle.to_tensor(epsilon))
# Apply fine-grained range if enabled
if use_finegrained_range:
max_abs_val *= 7.0
# Calculate scale
scale = max_abs_val / MAX_VALUE
# Quantize
quanted_value = reshaped_input / scale
# Convert to float8_e4m3fn and reshape back
quanted_x_reshaped = quanted_value.astype("float8_e4m3fn")
quanted_x = paddle.reshape(quanted_x_reshaped, [num_local_expert, num_max_tokens_per_expert, hidden_size])
# Apply valid mask to quantized output - convert to float32 first, then back to float8_e4m3fn
valid_mask_full = valid_mask.unsqueeze(2) # [num_local_expert, num_max_tokens_per_expert, 1]
quanted_x_float32 = quanted_x.astype("float32")
quanted_x_masked_float32 = paddle.where(valid_mask_full, quanted_x_float32, paddle.zeros_like(quanted_x_float32))
quanted_x = quanted_x_masked_float32.astype("float8_e4m3fn")
# Prepare scale output - squeeze the last dimension
quanted_scale = paddle.squeeze(scale, axis=-1) # [num_local_expert, num_max_tokens_per_expert, hidden_size_scale]
# Apply valid mask to scale
valid_mask_scale = valid_mask.unsqueeze(2) # [num_local_expert, num_max_tokens_per_expert, 1]
quanted_scale = paddle.where(valid_mask_scale, quanted_scale, paddle.zeros_like(quanted_scale))
return quanted_x, quanted_scale
class TestMaskedPerTokenQuant(unittest.TestCase):
def setUp(self) -> None:
paddle.seed(2024)
self.num_local_expert = 2
self.num_max_tokens_per_expert = 4
self.hidden_size = 256
self.block_size = 128
self.dtype = paddle.bfloat16
self.input_tensor = paddle.randn(
[self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype
)
self.recv_expert_count = paddle.to_tensor([3, 2], dtype="int32")
# Get reference results from paddle implementation
self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref(
self.input_tensor, self.recv_expert_count, self.block_size
)
def _mask_invalid_tokens(self, quanted_x, quanted_scale, recv_expert_count):
"""Apply mask to zero out invalid tokens"""
token_indices = paddle.arange(self.num_max_tokens_per_expert, dtype="int32").unsqueeze(0)
expert_counts = recv_expert_count.unsqueeze(1)
valid_mask = token_indices < expert_counts
# Apply mask to quantized values - convert to float32 first
valid_mask_full = valid_mask.unsqueeze(2)
quanted_x_float32 = quanted_x.astype("float32")
quanted_x_masked_float32 = paddle.where(
valid_mask_full, quanted_x_float32, paddle.zeros_like(quanted_x_float32)
)
quanted_x_masked = quanted_x_masked_float32.astype("float8_e4m3fn")
# Apply mask to scale values
valid_mask_scale = valid_mask.unsqueeze(2)
quanted_scale_masked = paddle.where(valid_mask_scale, quanted_scale, paddle.zeros_like(quanted_scale))
return quanted_x_masked, quanted_scale_masked
def test_masked_per_token_quant_basic(self):
"""Test basic functionality against CUDA kernel"""
quanted_x_cuda, quanted_scale_cuda = masked_per_token_quant(
self.input_tensor, self.recv_expert_count, self.block_size
)
quanted_x_cuda_masked, quanted_scale_cuda_masked = self._mask_invalid_tokens(
quanted_x_cuda, quanted_scale_cuda, self.recv_expert_count
)
# Check output shapes
self.assertEqual(quanted_x_cuda.shape, self.quanted_x_ref.shape)
self.assertEqual(quanted_scale_cuda.shape, self.quanted_scale_ref.shape)
# Check dtypes
self.assertEqual(quanted_x_cuda.dtype, paddle.float8_e4m3fn)
self.assertEqual(quanted_scale_cuda.dtype, paddle.float32)
# Compare scale values (using masked versions)
np.testing.assert_allclose(
self.quanted_scale_ref.numpy(), quanted_scale_cuda_masked.numpy(), rtol=1e-5, atol=1e-6
)
# Compare quantized values (convert to float32 for comparison, using masked versions)
quant_diff = paddle.mean(
paddle.abs(quanted_x_cuda_masked.astype("float32") - self.quanted_x_ref.astype("float32"))
) / paddle.mean(paddle.abs(self.quanted_x_ref.astype("float32")) + 1e-9)
diff_val = float(quant_diff.numpy().item())
self.assertLess(diff_val, 0.01, msg="Quantized values should be close")
class TestMaskedPerTokenQuantCase1(TestMaskedPerTokenQuant):
"""Test with float16 input"""
def setUp(self) -> None:
paddle.seed(2024)
self.num_local_expert = 3
self.num_max_tokens_per_expert = 6
self.hidden_size = 512
self.block_size = 128
self.dtype = paddle.float16
self.input_tensor = paddle.randn(
[self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype
)
self.recv_expert_count = paddle.to_tensor([4, 2, 5], dtype="int32")
self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref(
self.input_tensor, self.recv_expert_count, self.block_size
)
class TestMaskedPerTokenQuantCase2(TestMaskedPerTokenQuant):
"""Test with different hidden size"""
def setUp(self) -> None:
paddle.seed(2024)
self.num_local_expert = 4
self.num_max_tokens_per_expert = 8
self.hidden_size = 384 # 3 * 128
self.block_size = 128
self.dtype = paddle.bfloat16
self.input_tensor = paddle.randn(
[self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype
)
self.recv_expert_count = paddle.to_tensor([6, 3, 7, 1], dtype="int32")
self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref(
self.input_tensor, self.recv_expert_count, self.block_size
)
class TestMaskedPerTokenQuantCase3(TestMaskedPerTokenQuant):
"""Test with all experts having max tokens"""
def setUp(self) -> None:
paddle.seed(2024)
self.num_local_expert = 2
self.num_max_tokens_per_expert = 4
self.hidden_size = 256
self.block_size = 128
self.dtype = paddle.bfloat16
self.input_tensor = paddle.randn(
[self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype
)
# All experts use all tokens
self.recv_expert_count = paddle.to_tensor([4, 4], dtype="int32")
self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref(
self.input_tensor, self.recv_expert_count, self.block_size
)
class TestMaskedPerTokenQuantEdgeCases(unittest.TestCase):
"""Test edge cases"""
def test_zero_tokens_expert(self):
"""Test expert with zero tokens"""
paddle.seed(2024)
input_tensor = paddle.randn([2, 4, 256], dtype="bfloat16")
recv_expert_count = paddle.to_tensor([0, 2], dtype="int32") # First expert has no tokens
quanted_x_ref, quanted_scale_ref = masked_per_token_quant_ref(input_tensor, recv_expert_count, 128)
# First expert should be all zeros - convert to float32 for comparison
expert_0_quanted = quanted_x_ref[0].astype("float32")
self.assertTrue(paddle.all(expert_0_quanted == 0), "Expert with zero tokens should be all zero")
self.assertTrue(paddle.all(quanted_scale_ref[0] == 0), "Expert with zero tokens should have zero scales")
# Second expert should have valid values - convert to float32 for comparison
expert_1_quanted = quanted_x_ref[1, :2].astype("float32")
self.assertTrue(paddle.any(expert_1_quanted != 0), "Expert with tokens should have non-zero values")
if __name__ == "__main__":
unittest.main()