[BugFix] Fix kv cache int8 dynamic quant on flash and flash_mask backend (#7028)

* [BugFix] Fix kv cache int8 dynamic quant on flash and flash_mask backend

* add constexpr and code style clean

* add test

* fix code style

* fix test
This commit is contained in:
Longzhi Wang
2026-03-30 11:17:15 +08:00
committed by GitHub
parent 7a20eaebe8
commit 2eea6fa97a
5 changed files with 1031 additions and 82 deletions
@@ -615,11 +615,14 @@ template <typename T,
uint32_t HEAD_DIM,
uint32_t BLOCK_SIZE,
uint32_t NUM_WARPS = 4,
bool IS_FP8 = false>
bool IS_FP8 = false,
bool dynamic_quant = false>
__global__ void append_cache_kv_c8(const CacheT *__restrict__ cache_k,
const CacheT *__restrict__ cache_v,
T *__restrict__ k_out,
T *__restrict__ v_out,
const T *__restrict__ cache_k_quant_scales,
const T *__restrict__ cache_v_quant_scales,
const T *__restrict__ cache_k_dequant_scales,
const T *__restrict__ cache_v_dequant_scales,
const int *__restrict__ seq_lens_this_time,
@@ -653,6 +656,19 @@ __global__ void append_cache_kv_c8(const CacheT *__restrict__ cache_k,
cache_k + block_id * block_stride + kv_head_idx * kv_h_stride;
const CacheT *cur_cache_v =
cache_v + block_id * block_stride + kv_head_idx * kv_h_stride;
const T *cur_cache_k_scales;
const T *cur_cache_v_scales;
T cache_k_scale = 0;
T cache_v_scale = 0;
if constexpr (dynamic_quant) {
cur_cache_k_scales = cache_k_quant_scales +
(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE;
cur_cache_v_scales = cache_v_quant_scales +
(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE;
} else {
cache_k_scale = cache_k_dequant_scales[kv_head_idx];
cache_v_scale = cache_v_dequant_scales[kv_head_idx];
}
// k_out v_out idx
uint32_t kv_t_stride = kv_num_heads * HEAD_DIM;
@@ -663,8 +679,6 @@ __global__ void append_cache_kv_c8(const CacheT *__restrict__ cache_k,
uint32_t k_frag[4], v_frag[4], frag_dq[4];
T *frag_dq_T = reinterpret_cast<T *>(frag_dq);
T cache_k_scale = cache_k_dequant_scales[kv_head_idx];
T cache_v_scale = cache_v_dequant_scales[kv_head_idx];
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / num_elems_per_128b<CacheT>();
@@ -723,23 +737,29 @@ __global__ void append_cache_kv_c8(const CacheT *__restrict__ cache_k,
T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride +
kv_head_idx * HEAD_DIM + col_idx;
T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride;
T cache_k_scale_0 = cache_k_scale;
T cache_k_scale_1 = cache_k_scale;
if constexpr (dynamic_quant) {
cache_k_scale_0 = cur_cache_k_scales[row_idx];
cache_k_scale_1 = cur_cache_k_scales[row_idx + 8];
}
if (row_idx < end_idx) {
convert_c8<T, IS_FP8>(frag_dq_T,
k_frag[2 * i]); // 4 * uint8/fp8 -> 4 * T
k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale;
k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale;
k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale;
k_tile_ptr0[9] = frag_dq_T[3] * cache_k_scale;
k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale_0;
k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale_0;
k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale_0;
k_tile_ptr0[9] = frag_dq_T[3] * cache_k_scale_0;
}
if (row_idx + 8 < end_idx) {
convert_c8<T, IS_FP8>(frag_dq_T + 4,
k_frag[2 * i + 1]); // 4 * uint8/fp8 -> 4 * T
k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale;
k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale;
k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale;
k_tile_ptr1[9] = frag_dq_T[7] * cache_k_scale;
k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale_1;
k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale_1;
k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale_1;
k_tile_ptr1[9] = frag_dq_T[7] * cache_k_scale_1;
}
col_idx += 16;
}
@@ -798,25 +818,36 @@ __global__ void append_cache_kv_c8(const CacheT *__restrict__ cache_k,
T *v_tile_ptr0 = v_write_ptr + kv_idx * kv_t_stride +
kv_head_idx * HEAD_DIM + dim_idx;
T *v_tile_ptr1 = v_tile_ptr0 + 8;
T cache_v_scale_0 = cache_v_scale;
T cache_v_scale_1 = cache_v_scale;
T cache_v_scale_2 = cache_v_scale;
T cache_v_scale_3 = cache_v_scale;
if constexpr (dynamic_quant) {
cache_v_scale_0 = cur_cache_v_scales[kv_idx];
cache_v_scale_1 = cur_cache_v_scales[kv_idx + 1];
cache_v_scale_2 = cur_cache_v_scales[kv_idx + 2];
cache_v_scale_3 = cur_cache_v_scales[kv_idx + 3];
}
convert_c8<T, IS_FP8>(frag_dq_T,
v_frag[2 * i]); // 4 * uint8/fp8 -> 4 * T
convert_c8<T, IS_FP8>(frag_dq_T + 4,
v_frag[2 * i + 1]); // 4 * uint8/fp8 -> 4 * T
if (kv_idx < end_idx) {
v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale;
v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale;
v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale_0;
v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale_0;
}
if (kv_idx + 1 < end_idx) {
v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale;
v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale;
v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale_1;
v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale_1;
}
if (kv_idx + 8 < end_idx) {
v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale;
v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale;
v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale_2;
v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale_2;
}
if (kv_idx + 9 < end_idx) {
v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale;
v_tile_ptr1[9 * kv_t_stride] = frag_dq_T[7] * cache_v_scale;
v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale_3;
v_tile_ptr1[9 * kv_t_stride] = frag_dq_T[7] * cache_v_scale_3;
}
kv_idx += 16;
}
@@ -1154,25 +1185,28 @@ __global__ void append_cache_kv_c4(const CacheT *__restrict__ cache_k,
}
template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>
void AppendCacheKV(const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::Tensor &cache_k_dequant_scales,
const paddle::Tensor &cache_v_dequant_scales,
const paddle::Tensor &cache_k_zp,
const paddle::Tensor &cache_v_zp,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &cu_seqlens_k,
const paddle::Tensor &block_tables,
const paddle::Tensor &cache_batch_ids,
const paddle::Tensor &cache_tile_ids_per_batch,
const paddle::Tensor &cache_num_blocks_x,
const int max_blocks_per_seq,
const int kv_num_heads,
const std::string &cache_quant_type,
paddle::Tensor *k_out,
paddle::Tensor *v_out,
const cudaStream_t &stream) {
void AppendCacheKV(
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor> &cache_k_quant_scales,
const paddle::optional<paddle::Tensor> &cache_v_quant_scales,
const paddle::optional<paddle::Tensor> &cache_k_dequant_scales,
const paddle::optional<paddle::Tensor> &cache_v_dequant_scales,
const paddle::Tensor &cache_k_zp,
const paddle::Tensor &cache_v_zp,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &cu_seqlens_k,
const paddle::Tensor &block_tables,
const paddle::Tensor &cache_batch_ids,
const paddle::Tensor &cache_tile_ids_per_batch,
const paddle::Tensor &cache_num_blocks_x,
const int max_blocks_per_seq,
const int kv_num_heads,
const std::string &cache_quant_type,
paddle::Tensor *k_out,
paddle::Tensor *v_out,
const cudaStream_t &stream) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
constexpr int NUM_WARPS = 4;
int block_num = cache_num_blocks_x.data<int>()[0];
@@ -1206,7 +1240,8 @@ void AppendCacheKV(const paddle::Tensor &cache_k,
max_blocks_per_seq,
kv_num_heads);
} else if (cache_quant_type == "cache_int8" ||
cache_quant_type == "cache_fp8") {
cache_quant_type == "cache_fp8" ||
cache_quant_type == "block_wise_fp8") {
const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2;
auto kernel_func = append_cache_kv_c8<NV_TYPE,
@@ -1214,6 +1249,7 @@ void AppendCacheKV(const paddle::Tensor &cache_k,
HEAD_DIM,
BLOCK_SIZE,
NUM_WARPS,
false,
false>;
if (cache_quant_type == "cache_fp8") {
kernel_func = append_cache_kv_c8<NV_TYPE,
@@ -1221,33 +1257,51 @@ void AppendCacheKV(const paddle::Tensor &cache_k,
HEAD_DIM,
BLOCK_SIZE,
NUM_WARPS,
true,
false>;
} else if (cache_quant_type == "block_wise_fp8") {
kernel_func = append_cache_kv_c8<NV_TYPE,
uint8_t,
HEAD_DIM,
BLOCK_SIZE,
NUM_WARPS,
true,
true>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
launchWithPdlWhenEnabled(kernel_func,
grids,
blocks,
smem_size,
stream,
cache_k.data<uint8_t>(),
cache_v.data<uint8_t>(),
reinterpret_cast<NV_TYPE *>(k_out->data<T>()),
reinterpret_cast<NV_TYPE *>(v_out->data<T>()),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(
cache_k_dequant_scales.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(
cache_v_dequant_scales.data<T>())),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_k.data<int>(),
block_tables.data<int>(),
cache_batch_ids.data<int>(),
cache_tile_ids_per_batch.data<int>(),
max_blocks_per_seq,
kv_num_heads);
launchWithPdlWhenEnabled(
kernel_func,
grids,
blocks,
smem_size,
stream,
cache_k.data<uint8_t>(),
cache_v.data<uint8_t>(),
reinterpret_cast<NV_TYPE *>(k_out->data<T>()),
reinterpret_cast<NV_TYPE *>(v_out->data<T>()),
cache_k_quant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
cache_k_quant_scales.get().data<T>()))
: nullptr,
cache_v_quant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
cache_v_quant_scales.get().data<T>()))
: nullptr,
cache_k_dequant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
cache_k_dequant_scales.get().data<T>()))
: nullptr,
cache_v_dequant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
cache_v_dequant_scales.get().data<T>()))
: nullptr,
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_k.data<int>(),
block_tables.data<int>(),
cache_batch_ids.data<int>(),
cache_tile_ids_per_batch.data<int>(),
max_blocks_per_seq,
kv_num_heads);
} else if (cache_quant_type == "cache_int4_zp") {
const uint32_t smem_size =
BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) + 4 * HEAD_DIM * sizeof(T);
@@ -1270,9 +1324,9 @@ void AppendCacheKV(const paddle::Tensor &cache_k,
reinterpret_cast<NV_TYPE *>(k_out->data<T>()),
reinterpret_cast<NV_TYPE *>(v_out->data<T>()),
reinterpret_cast<NV_TYPE *>(
const_cast<T *>(cache_k_dequant_scales.data<T>())),
const_cast<T *>(cache_k_dequant_scales.get().data<T>())),
reinterpret_cast<NV_TYPE *>(
const_cast<T *>(cache_v_dequant_scales.data<T>())),
const_cast<T *>(cache_v_dequant_scales.get().data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_zp.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_zp.data<T>())),
seq_lens_this_time.data<int>(),
@@ -1457,8 +1511,10 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
if (token_num < kv_token_num) {
AppendCacheKV<data_t, 128, 64>(key_cache,
value_cache,
cache_k_dequant_scales.get(),
cache_v_dequant_scales.get(),
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp.get(),
cache_v_zp.get(),
seq_lens_this_time,