[Optimization] optimize fused_swiglu_fp8_quant_kernel (#7007)

* use shared memory

* B card test

* fix accuracy error
This commit is contained in:
fxyfxy777
2026-03-27 16:10:16 +08:00
committed by GitHub
parent 6693bcd0e4
commit 8ff8236a6f
@@ -55,41 +55,86 @@ __global__ void fused_swiglu_fp8_quant_kernel(
int64_t hidden_size_scale,
bool use_finegrained_range) {
constexpr int BLOCK = 128;
constexpr int VEC_SIZE = 8; // 8 × bf16 = 16 bytes = 128-bit load
int tid = threadIdx.x;
int lane = tid & 31;
int warp = tid >> 5;
int num_warps = blockDim.x >> 5;
int64_t block_id = static_cast<int64_t>(blockIdx.x);
// Build prefix-sum + per-expert token offset lookup in shared memory.
// Layout: [0..group_num] = cumsum, [group_num+1..2*group_num] =
// expert_of[token_range]
extern __shared__ int smem[];
int* smem_cumsum = smem;
// Build a flat lookup table: for each cumsum bucket, store the expert index.
// Since group_num is small (typically 20-64), this is very compact.
int* smem_expert_lut = smem + group_num + 1;
using VecBF16 = AlignedVector<T, 4>;
if (tid == 0) {
smem_cumsum[0] = 0;
for (int i = 0; i < group_num; ++i) {
smem_cumsum[i + 1] =
smem_cumsum[i] + static_cast<int>(token_nums_per_expert[i]);
}
}
__syncthreads();
int total_tokens = smem_cumsum[group_num];
using VecBF16 = AlignedVector<T, VEC_SIZE>;
VecBF16 x1_vec, x2_vec;
using VecFP8 = AlignedVector<phi::dtype::float8_e4m3fn, 4>;
using VecFP8 = AlignedVector<phi::dtype::float8_e4m3fn, VEC_SIZE>;
VecFP8 q_vec;
while (true) {
// ================= token mapping =================
int64_t expert = -1;
int64_t token_in_expert = -1;
// Pre-compute scale constants outside loop
const float inv_fp8_max = 1.f / kFP8Max;
// Each warp tracks its current expert to avoid repeated binary search.
// When block_id moves to the next token, we check if it's still in the
// same expert range (which is the common case for sequential iteration).
int cached_expert = -1;
int cached_cumsum_lo = 0;
int cached_cumsum_hi = 0;
for (int64_t block_id = static_cast<int64_t>(blockIdx.x);
block_id < total_tokens;
block_id += gridDim.x) {
// ================= token mapping with cached expert =============
int64_t expert, token_in_expert;
if (lane == 0) {
int64_t cumsum = 0;
for (int64_t i = 0; i < group_num; ++i) {
int64_t cnt = static_cast<int64_t>(token_nums_per_expert[i]);
if (block_id >= cumsum && block_id < cumsum + cnt) {
expert = i;
token_in_expert = block_id - cumsum;
break;
int bid = static_cast<int>(block_id);
// Fast path: check if still in same expert range
if (bid >= cached_cumsum_lo && bid < cached_cumsum_hi) {
expert = cached_expert;
token_in_expert = bid - cached_cumsum_lo;
} else {
// Binary search fallback
int lo = 0, hi = static_cast<int>(group_num) + 1;
while (lo < hi) {
int mid = (lo + hi) >> 1;
if (smem_cumsum[mid] <= bid)
lo = mid + 1;
else
hi = mid;
}
cumsum += cnt;
expert = static_cast<int64_t>(lo - 1);
token_in_expert = bid - static_cast<int64_t>(smem_cumsum[lo - 1]);
// Cache for next iteration
cached_expert = static_cast<int>(expert);
cached_cumsum_lo = smem_cumsum[lo - 1];
cached_cumsum_hi = smem_cumsum[lo]; // lo is already the upper bound
}
}
expert = __shfl_sync(0xffffffff, expert, 0);
token_in_expert = __shfl_sync(0xffffffff, token_in_expert, 0);
if (expert < 0 || token_in_expert >= group_size) break;
// Also broadcast cache values so all lanes in the warp have them
// (only lane 0 updates, but we need consistency for the next iteration
// check)
cached_expert = __shfl_sync(0xffffffff, cached_expert, 0);
cached_cumsum_lo = __shfl_sync(0xffffffff, cached_cumsum_lo, 0);
cached_cumsum_hi = __shfl_sync(0xffffffff, cached_cumsum_hi, 0);
// ================= base pointers =================
int64_t token = expert * group_size + token_in_expert;
@@ -98,91 +143,215 @@ __global__ void fused_swiglu_fp8_quant_kernel(
auto* out = out_fp8 + token * hidden_size;
// With VEC_SIZE=8, each lane processes 8 elements, 32 lanes process 256
// elements. We need to process BLOCK=128 elements per scale group. Each
// warp iteration: 32 lanes × 8 elements = 256 elements = 2 scale groups. So
// we process 2 scale groups per warp iteration.
int64_t num_iters = hidden_size / BLOCK;
// ================= main loop =================
for (int64_t iter = warp; iter < num_iters; iter += num_warps) {
int64_t base = iter * BLOCK + lane * 4;
// Process 2 scale groups (2 × 128 = 256 elements) per warp iteration
for (int64_t iter_pair = warp; iter_pair < num_iters / 2;
iter_pair += num_warps) {
int64_t iter0 = iter_pair * 2;
int64_t base = iter0 * BLOCK + lane * VEC_SIZE;
// vec load
// 128-bit vectorized load: 8 × bf16 = 16 bytes
Load(in + base, &x1_vec);
Load(in + base + hidden_size, &x2_vec);
float v[4];
float amax = -5e4;
float v[VEC_SIZE];
float amax0 = 0.f;
float amax1 = 0.f;
#pragma unroll
for (int i = 0; i < 4; ++i) {
for (int i = 0; i < VEC_SIZE; ++i) {
float x1 = static_cast<float>(x1_vec[i]);
float x2 = static_cast<float>(x2_vec[i]);
// SwiGLU: x2 * silu(x1) = x2 * x1 / (1 + exp(-x1))
float y = x2 * x1 / (1.f + expf(-x1));
float y_r = static_cast<float>(
static_cast<T>(y)); // To simulate the data transformation before
// the fusion of swiglu and quant operators
static_cast<T>(y)); // bf16 round-trip to match reference
v[i] = y_r;
amax = max(amax, abs(y_r));
// Split amax for two scale groups:
// Elements 0..3 belong to scale group iter0, elements 4..7 to iter0+1
if (i < 4) {
amax0 = fmaxf(amax0, fabsf(y_r));
} else {
amax1 = fmaxf(amax1, fabsf(y_r));
}
}
// ---------- warp reduce amax ----------
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
amax = max(amax, __shfl_down_sync(0xffffffff, amax, offset));
// ---------- warp reduce amax for group 0 (lanes 0-15 contribute lower
// half) ---------- All lanes have amax0 from elements [0..3], but we need
// to split by 128-element boundary. lane * 8 + [0..3] maps to elements in
// range [lane*8 .. lane*8+3] A 128-element group covers lanes where
// (lane*8)/128 is the same. 128/8 = 16 lanes per group. So lanes 0-15 →
// group iter0, lanes 16-31 → group iter0+1. Merge: lanes 0-15 have group0
// in amax0 and group1 doesn't exist for them (amax1 from elements 4-7 =
// group0 still since 16*8=128 > lane*8+7 for lane<16... wait, let me
// reconsider).
amax = __shfl_sync(0xffffffff, amax, 0);
amax = max(amax, kEpsilon);
// Actually: lane L processes elements at offset base + [0..7] = iter0*128
// + L*8 + [0..7] For L in [0..15]: offsets are in [iter0*128 .. iter0*128
// + 127] → scale group iter0 For L in [16..31]: offsets are in [iter0*128
// + 128 .. iter0*128 + 255] → scale group iter0+1 So: lanes 0-15 all
// contribute to amax of group iter0, lanes 16-31 to group iter0+1.
// Combine amax0 and amax1 per lane (both belong to same group for that
// lane)
float my_amax = fmaxf(amax0, amax1);
// Half-warp reduce for each group
// Lanes 0-15 reduce among themselves, lanes 16-31 reduce among themselves
#pragma unroll
for (int offset = 8; offset > 0; offset >>= 1)
my_amax = fmaxf(my_amax, __shfl_xor_sync(0xffffffff, my_amax, offset));
// Now lane 0 has amax for group iter0, lane 16 has amax for group iter0+1
float group0_amax = __shfl_sync(0xffffffff, my_amax, 0);
float group1_amax = __shfl_sync(0xffffffff, my_amax, 16);
// Select the correct amax for this lane's group
float amax = (lane < 16) ? group0_amax : group1_amax;
amax = fmaxf(amax, kEpsilon);
if (use_finegrained_range) amax *= 7.f;
float scale = amax / kFP8Max;
float scale = amax * inv_fp8_max;
int64_t my_iter = iter0 + (lane >= 16 ? 1 : 0);
// ---------- quantize ----------
if constexpr (UseUE8M0) {
scale = exp2f(ceilf(log2f(fmaxf(scale, kEpsilon))));
float ue8m0_scale = exp2f(ceilf(log2f(fmaxf(scale, kEpsilon))));
float inv_scale = __frcp_rn(ue8m0_scale);
#pragma unroll
for (int i = 0; i < 4; ++i) {
float q = v[i] / scale;
q_vec[i] = static_cast<phi::dtype::float8_e4m3fn>(q);
for (int i = 0; i < VEC_SIZE; ++i) {
q_vec[i] = static_cast<phi::dtype::float8_e4m3fn>(v[i] * inv_scale);
}
// ---------- store scale ----------
// ---------- store scale (lane 0 writes both groups to avoid race)
// ----------
if (lane == 0) {
// 1. extract exponent
const int exp = (__float_as_int(scale) >> 23) & 0xFF;
// 2. pack information
const int64_t pack_idx = iter >> 2; // iter / 4
const int64_t byte_idx = iter & 3; // iter % 4
// 3. layout parameters
const int64_t pack_num = ceil_div(hidden_size_scale, (int64_t)4);
const int64_t token_stride = align(group_size, (int64_t)4);
// 4. base pointer (int32 pack)
auto* scale_pack = reinterpret_cast<int32_t*>(out_scale);
// 5. column-major offset:
// [expert][pack][token]
const int64_t base_idx = expert * pack_num * token_stride +
pack_idx * token_stride + token_in_expert;
// 6. write one byte into pack
reinterpret_cast<uint8_t*>(&scale_pack[base_idx])[byte_idx] =
static_cast<uint8_t>(exp);
// Group 0 scale (from lane 0's own value)
float s0 =
exp2f(ceilf(log2f(fmaxf(group0_amax * inv_fp8_max, kEpsilon))));
const int exp0 = (__float_as_int(s0) >> 23) & 0xFF;
const int64_t pack_idx0 = iter0 >> 2;
const int64_t byte_idx0 = iter0 & 3;
const int64_t base_idx0 = expert * pack_num * token_stride +
pack_idx0 * token_stride + token_in_expert;
reinterpret_cast<uint8_t*>(&scale_pack[base_idx0])[byte_idx0] =
static_cast<uint8_t>(exp0);
// Group 1 scale (from lane 16's value, broadcast earlier)
int64_t iter1 = iter0 + 1;
float s1 =
exp2f(ceilf(log2f(fmaxf(group1_amax * inv_fp8_max, kEpsilon))));
const int exp1 = (__float_as_int(s1) >> 23) & 0xFF;
const int64_t pack_idx1 = iter1 >> 2;
const int64_t byte_idx1 = iter1 & 3;
const int64_t base_idx1 = expert * pack_num * token_stride +
pack_idx1 * token_stride + token_in_expert;
reinterpret_cast<uint8_t*>(&scale_pack[base_idx1])[byte_idx1] =
static_cast<uint8_t>(exp1);
}
} else {
float inv_amax = __frcp_rn(amax);
#pragma unroll
for (int i = 0; i < 4; i++) {
float q = v[i] * kFP8Max / amax;
for (int i = 0; i < VEC_SIZE; i++) {
float q = v[i] * kFP8Max * inv_amax;
q_vec[i] = static_cast<phi::dtype::float8_e4m3fn>(q);
}
// ---------- store scale ----------
if (lane == 0) {
if (lane == 0 || lane == 16) {
out_scale[expert * hidden_size_scale * group_size +
iter * group_size + token_in_expert] = scale;
my_iter * group_size + token_in_expert] = scale;
}
}
Store(q_vec, out + base);
}
block_id += gridDim.x;
// Handle remainder if num_iters is odd
if (num_iters & 1) {
int64_t iter = num_iters - 1;
// Only the last warp handles this
if (warp == (num_iters / 2) % num_warps ||
num_iters / 2 < static_cast<int64_t>(num_warps)) {
// Fall back to vec4 for the remainder
using VecBF16_4 = AlignedVector<T, 4>;
using VecFP8_4 = AlignedVector<phi::dtype::float8_e4m3fn, 4>;
VecBF16_4 rx1, rx2;
VecFP8_4 rq;
int64_t rbase = iter * BLOCK + lane * 4;
if (rbase < hidden_size) {
Load(in + rbase, &rx1);
Load(in + rbase + hidden_size, &rx2);
float rv[4];
float ramax = 0.f;
#pragma unroll
for (int i = 0; i < 4; ++i) {
float x1 = static_cast<float>(rx1[i]);
float x2 = static_cast<float>(rx2[i]);
float y = x2 * x1 / (1.f + expf(-x1));
float y_r = static_cast<float>(
static_cast<T>(y)); // bf16 round-trip to match reference
rv[i] = y_r;
ramax = fmaxf(ramax, fabsf(y_r));
}
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
ramax = fmaxf(ramax, __shfl_down_sync(0xffffffff, ramax, offset));
ramax = __shfl_sync(0xffffffff, ramax, 0);
ramax = fmaxf(ramax, kEpsilon);
if (use_finegrained_range) ramax *= 7.f;
float rscale = ramax * inv_fp8_max;
if constexpr (UseUE8M0) {
float s = exp2f(ceilf(log2f(fmaxf(rscale, kEpsilon))));
float inv_s = __frcp_rn(s);
#pragma unroll
for (int i = 0; i < 4; ++i) {
rq[i] = static_cast<phi::dtype::float8_e4m3fn>(rv[i] * inv_s);
}
if (lane == 0) {
const int exp = (__float_as_int(s) >> 23) & 0xFF;
const int64_t pack_idx = iter >> 2;
const int64_t byte_idx = iter & 3;
const int64_t pack_num = ceil_div(hidden_size_scale, (int64_t)4);
const int64_t token_stride = align(group_size, (int64_t)4);
auto* scale_pack = reinterpret_cast<int32_t*>(out_scale);
const int64_t base_idx = expert * pack_num * token_stride +
pack_idx * token_stride +
token_in_expert;
reinterpret_cast<uint8_t*>(&scale_pack[base_idx])[byte_idx] =
static_cast<uint8_t>(exp);
}
} else {
float inv_ramax = __frcp_rn(ramax);
#pragma unroll
for (int i = 0; i < 4; i++) {
rq[i] = static_cast<phi::dtype::float8_e4m3fn>(rv[i] * kFP8Max *
inv_ramax);
}
if (lane == 0) {
out_scale[expert * hidden_size_scale * group_size +
iter * group_size + token_in_expert] = rscale;
}
}
Store(rq, out + rbase);
}
}
}
}
}
@@ -220,10 +389,11 @@ std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
int sm_count = 0;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, 0);
constexpr int BLOCKS_PER_SM = 2;
constexpr int BLOCKS_PER_SM = 3;
int blockx = std::min(512L, hidden_size / 128 * 32);
int gridx =
std::min(static_cast<int64_t>(sm_count * BLOCKS_PER_SM), token_num);
int blockx = std::min(1024L, hidden_size / 128 * 32);
int smem_bytes = (group_num + 1) * sizeof(int);
bool use_finegrained_range = false;
if (auto* env = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE"))
@@ -233,7 +403,7 @@ std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
BOOL_SWITCH(use_ue8m0, UseUE8M0, [&] {
using ScaleT = std::conditional_t<UseUE8M0, int, float>;
fused_swiglu_fp8_quant_kernel<paddle::bfloat16, int, ScaleT, UseUE8M0>
<<<gridx, blockx, 0, input.stream()>>>(
<<<gridx, blockx, smem_bytes, input.stream()>>>(
input.data<paddle::bfloat16>(),
token_nums_per_expert.data<int>(),
out_fp8.data<phi::dtype::float8_e4m3fn>(),