mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] Fix kv cache int8 dynamic quant on flash and flash_mask backend (#7028)
* [BugFix] Fix kv cache int8 dynamic quant on flash and flash_mask backend
* add constexpr and code style clean
* add test
* fix code style
* fix test
This commit is contained in:
@@ -615,11 +615,14 @@ template <typename T,
|
||||
uint32_t HEAD_DIM,
|
||||
uint32_t BLOCK_SIZE,
|
||||
uint32_t NUM_WARPS = 4,
|
||||
bool IS_FP8 = false>
|
||||
bool IS_FP8 = false,
|
||||
bool dynamic_quant = false>
|
||||
__global__ void append_cache_kv_c8(const CacheT *__restrict__ cache_k,
|
||||
const CacheT *__restrict__ cache_v,
|
||||
T *__restrict__ k_out,
|
||||
T *__restrict__ v_out,
|
||||
const T *__restrict__ cache_k_quant_scales,
|
||||
const T *__restrict__ cache_v_quant_scales,
|
||||
const T *__restrict__ cache_k_dequant_scales,
|
||||
const T *__restrict__ cache_v_dequant_scales,
|
||||
const int *__restrict__ seq_lens_this_time,
|
||||
@@ -653,6 +656,19 @@ __global__ void append_cache_kv_c8(const CacheT *__restrict__ cache_k,
|
||||
cache_k + block_id * block_stride + kv_head_idx * kv_h_stride;
|
||||
const CacheT *cur_cache_v =
|
||||
cache_v + block_id * block_stride + kv_head_idx * kv_h_stride;
|
||||
const T *cur_cache_k_scales;
|
||||
const T *cur_cache_v_scales;
|
||||
T cache_k_scale = 0;
|
||||
T cache_v_scale = 0;
|
||||
if constexpr (dynamic_quant) {
|
||||
cur_cache_k_scales = cache_k_quant_scales +
|
||||
(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE;
|
||||
cur_cache_v_scales = cache_v_quant_scales +
|
||||
(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE;
|
||||
} else {
|
||||
cache_k_scale = cache_k_dequant_scales[kv_head_idx];
|
||||
cache_v_scale = cache_v_dequant_scales[kv_head_idx];
|
||||
}
|
||||
|
||||
// k_out v_out idx
|
||||
uint32_t kv_t_stride = kv_num_heads * HEAD_DIM;
|
||||
@@ -663,8 +679,6 @@ __global__ void append_cache_kv_c8(const CacheT *__restrict__ cache_k,
|
||||
|
||||
uint32_t k_frag[4], v_frag[4], frag_dq[4];
|
||||
T *frag_dq_T = reinterpret_cast<T *>(frag_dq);
|
||||
T cache_k_scale = cache_k_dequant_scales[kv_head_idx];
|
||||
T cache_v_scale = cache_v_dequant_scales[kv_head_idx];
|
||||
|
||||
constexpr uint32_t num_vecs_per_head_k =
|
||||
HEAD_DIM / num_elems_per_128b<CacheT>();
|
||||
@@ -723,23 +737,29 @@ __global__ void append_cache_kv_c8(const CacheT *__restrict__ cache_k,
|
||||
T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride +
|
||||
kv_head_idx * HEAD_DIM + col_idx;
|
||||
T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride;
|
||||
T cache_k_scale_0 = cache_k_scale;
|
||||
T cache_k_scale_1 = cache_k_scale;
|
||||
if constexpr (dynamic_quant) {
|
||||
cache_k_scale_0 = cur_cache_k_scales[row_idx];
|
||||
cache_k_scale_1 = cur_cache_k_scales[row_idx + 8];
|
||||
}
|
||||
|
||||
if (row_idx < end_idx) {
|
||||
convert_c8<T, IS_FP8>(frag_dq_T,
|
||||
k_frag[2 * i]); // 4 * uint8/fp8 -> 4 * T
|
||||
k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale;
|
||||
k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale;
|
||||
k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale;
|
||||
k_tile_ptr0[9] = frag_dq_T[3] * cache_k_scale;
|
||||
k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale_0;
|
||||
k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale_0;
|
||||
k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale_0;
|
||||
k_tile_ptr0[9] = frag_dq_T[3] * cache_k_scale_0;
|
||||
}
|
||||
|
||||
if (row_idx + 8 < end_idx) {
|
||||
convert_c8<T, IS_FP8>(frag_dq_T + 4,
|
||||
k_frag[2 * i + 1]); // 4 * uint8/fp8 -> 4 * T
|
||||
k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale;
|
||||
k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale;
|
||||
k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale;
|
||||
k_tile_ptr1[9] = frag_dq_T[7] * cache_k_scale;
|
||||
k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale_1;
|
||||
k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale_1;
|
||||
k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale_1;
|
||||
k_tile_ptr1[9] = frag_dq_T[7] * cache_k_scale_1;
|
||||
}
|
||||
col_idx += 16;
|
||||
}
|
||||
@@ -798,25 +818,36 @@ __global__ void append_cache_kv_c8(const CacheT *__restrict__ cache_k,
|
||||
T *v_tile_ptr0 = v_write_ptr + kv_idx * kv_t_stride +
|
||||
kv_head_idx * HEAD_DIM + dim_idx;
|
||||
T *v_tile_ptr1 = v_tile_ptr0 + 8;
|
||||
T cache_v_scale_0 = cache_v_scale;
|
||||
T cache_v_scale_1 = cache_v_scale;
|
||||
T cache_v_scale_2 = cache_v_scale;
|
||||
T cache_v_scale_3 = cache_v_scale;
|
||||
|
||||
if constexpr (dynamic_quant) {
|
||||
cache_v_scale_0 = cur_cache_v_scales[kv_idx];
|
||||
cache_v_scale_1 = cur_cache_v_scales[kv_idx + 1];
|
||||
cache_v_scale_2 = cur_cache_v_scales[kv_idx + 2];
|
||||
cache_v_scale_3 = cur_cache_v_scales[kv_idx + 3];
|
||||
}
|
||||
convert_c8<T, IS_FP8>(frag_dq_T,
|
||||
v_frag[2 * i]); // 4 * uint8/fp8 -> 4 * T
|
||||
convert_c8<T, IS_FP8>(frag_dq_T + 4,
|
||||
v_frag[2 * i + 1]); // 4 * uint8/fp8 -> 4 * T
|
||||
if (kv_idx < end_idx) {
|
||||
v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale;
|
||||
v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale;
|
||||
v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale_0;
|
||||
v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale_0;
|
||||
}
|
||||
if (kv_idx + 1 < end_idx) {
|
||||
v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale;
|
||||
v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale;
|
||||
v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale_1;
|
||||
v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale_1;
|
||||
}
|
||||
if (kv_idx + 8 < end_idx) {
|
||||
v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale;
|
||||
v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale;
|
||||
v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale_2;
|
||||
v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale_2;
|
||||
}
|
||||
if (kv_idx + 9 < end_idx) {
|
||||
v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale;
|
||||
v_tile_ptr1[9 * kv_t_stride] = frag_dq_T[7] * cache_v_scale;
|
||||
v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale_3;
|
||||
v_tile_ptr1[9 * kv_t_stride] = frag_dq_T[7] * cache_v_scale_3;
|
||||
}
|
||||
kv_idx += 16;
|
||||
}
|
||||
@@ -1154,25 +1185,28 @@ __global__ void append_cache_kv_c4(const CacheT *__restrict__ cache_k,
|
||||
}
|
||||
|
||||
template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>
|
||||
void AppendCacheKV(const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::Tensor &cache_k_dequant_scales,
|
||||
const paddle::Tensor &cache_v_dequant_scales,
|
||||
const paddle::Tensor &cache_k_zp,
|
||||
const paddle::Tensor &cache_v_zp,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &cu_seqlens_k,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &cache_batch_ids,
|
||||
const paddle::Tensor &cache_tile_ids_per_batch,
|
||||
const paddle::Tensor &cache_num_blocks_x,
|
||||
const int max_blocks_per_seq,
|
||||
const int kv_num_heads,
|
||||
const std::string &cache_quant_type,
|
||||
paddle::Tensor *k_out,
|
||||
paddle::Tensor *v_out,
|
||||
const cudaStream_t &stream) {
|
||||
void AppendCacheKV(
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor> &cache_k_quant_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_v_quant_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_k_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_v_dequant_scales,
|
||||
const paddle::Tensor &cache_k_zp,
|
||||
const paddle::Tensor &cache_v_zp,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &cu_seqlens_k,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &cache_batch_ids,
|
||||
const paddle::Tensor &cache_tile_ids_per_batch,
|
||||
const paddle::Tensor &cache_num_blocks_x,
|
||||
const int max_blocks_per_seq,
|
||||
const int kv_num_heads,
|
||||
const std::string &cache_quant_type,
|
||||
paddle::Tensor *k_out,
|
||||
paddle::Tensor *v_out,
|
||||
const cudaStream_t &stream) {
|
||||
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
|
||||
constexpr int NUM_WARPS = 4;
|
||||
int block_num = cache_num_blocks_x.data<int>()[0];
|
||||
@@ -1206,7 +1240,8 @@ void AppendCacheKV(const paddle::Tensor &cache_k,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads);
|
||||
} else if (cache_quant_type == "cache_int8" ||
|
||||
cache_quant_type == "cache_fp8") {
|
||||
cache_quant_type == "cache_fp8" ||
|
||||
cache_quant_type == "block_wise_fp8") {
|
||||
const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2;
|
||||
|
||||
auto kernel_func = append_cache_kv_c8<NV_TYPE,
|
||||
@@ -1214,6 +1249,7 @@ void AppendCacheKV(const paddle::Tensor &cache_k,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
NUM_WARPS,
|
||||
false,
|
||||
false>;
|
||||
if (cache_quant_type == "cache_fp8") {
|
||||
kernel_func = append_cache_kv_c8<NV_TYPE,
|
||||
@@ -1221,33 +1257,51 @@ void AppendCacheKV(const paddle::Tensor &cache_k,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
NUM_WARPS,
|
||||
true,
|
||||
false>;
|
||||
} else if (cache_quant_type == "block_wise_fp8") {
|
||||
kernel_func = append_cache_kv_c8<NV_TYPE,
|
||||
uint8_t,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
NUM_WARPS,
|
||||
true,
|
||||
true>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
launchWithPdlWhenEnabled(kernel_func,
|
||||
grids,
|
||||
blocks,
|
||||
smem_size,
|
||||
stream,
|
||||
cache_k.data<uint8_t>(),
|
||||
cache_v.data<uint8_t>(),
|
||||
reinterpret_cast<NV_TYPE *>(k_out->data<T>()),
|
||||
reinterpret_cast<NV_TYPE *>(v_out->data<T>()),
|
||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
cache_k_dequant_scales.data<T>())),
|
||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
cache_v_dequant_scales.data<T>())),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
block_tables.data<int>(),
|
||||
cache_batch_ids.data<int>(),
|
||||
cache_tile_ids_per_batch.data<int>(),
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads);
|
||||
launchWithPdlWhenEnabled(
|
||||
kernel_func,
|
||||
grids,
|
||||
blocks,
|
||||
smem_size,
|
||||
stream,
|
||||
cache_k.data<uint8_t>(),
|
||||
cache_v.data<uint8_t>(),
|
||||
reinterpret_cast<NV_TYPE *>(k_out->data<T>()),
|
||||
reinterpret_cast<NV_TYPE *>(v_out->data<T>()),
|
||||
cache_k_quant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
cache_k_quant_scales.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_quant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
cache_v_quant_scales.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_dequant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
cache_k_dequant_scales.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_dequant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
cache_v_dequant_scales.get().data<T>()))
|
||||
: nullptr,
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
block_tables.data<int>(),
|
||||
cache_batch_ids.data<int>(),
|
||||
cache_tile_ids_per_batch.data<int>(),
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads);
|
||||
} else if (cache_quant_type == "cache_int4_zp") {
|
||||
const uint32_t smem_size =
|
||||
BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) + 4 * HEAD_DIM * sizeof(T);
|
||||
@@ -1270,9 +1324,9 @@ void AppendCacheKV(const paddle::Tensor &cache_k,
|
||||
reinterpret_cast<NV_TYPE *>(k_out->data<T>()),
|
||||
reinterpret_cast<NV_TYPE *>(v_out->data<T>()),
|
||||
reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(cache_k_dequant_scales.data<T>())),
|
||||
const_cast<T *>(cache_k_dequant_scales.get().data<T>())),
|
||||
reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(cache_v_dequant_scales.data<T>())),
|
||||
const_cast<T *>(cache_v_dequant_scales.get().data<T>())),
|
||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_zp.data<T>())),
|
||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_zp.data<T>())),
|
||||
seq_lens_this_time.data<int>(),
|
||||
@@ -1457,8 +1511,10 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
if (token_num < kv_token_num) {
|
||||
AppendCacheKV<data_t, 128, 64>(key_cache,
|
||||
value_cache,
|
||||
cache_k_dequant_scales.get(),
|
||||
cache_v_dequant_scales.get(),
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_dequant_scales,
|
||||
cache_v_dequant_scales,
|
||||
cache_k_zp.get(),
|
||||
cache_v_zp.get(),
|
||||
seq_lens_this_time,
|
||||
|
||||
Reference in New Issue
Block a user