mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[BugFix] Fix kv cache int8 dynamic quant on flash and flash_mask backend (#7028)
* [BugFix] Fix kv cache int8 dynamic quant on flash and flash_mask backend
* add constexpr and code style clean
* add test
* fix code style
* fix test
This commit is contained in:
@@ -354,6 +354,18 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
|
||||
k_norm_weight = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None
|
||||
|
||||
cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none")
|
||||
if cache_quant_type_str == "block_wise_fp8":
|
||||
cache_k = forward_meta.caches[4 * layer.layer_id]
|
||||
cache_v = forward_meta.caches[4 * layer.layer_id + 1]
|
||||
cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2]
|
||||
cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3]
|
||||
else:
|
||||
cache_k = forward_meta.caches[2 * layer.layer_id]
|
||||
cache_v = forward_meta.caches[2 * layer.layer_id + 1]
|
||||
cache_k_scales = getattr(layer, "cache_k_scale", None)
|
||||
cache_v_scales = getattr(layer, "cache_v_scale", None)
|
||||
|
||||
if layer.layer_id == 0:
|
||||
get_block_shape_and_split_kv_block(
|
||||
forward_meta.seq_lens_encoder,
|
||||
@@ -410,8 +422,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
if use_fa_do_prefill:
|
||||
q, k, v, _ = gqa_rope_write_cache(
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
cache_k,
|
||||
cache_v,
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.cu_seqlens_k,
|
||||
forward_meta.rotary_embs,
|
||||
@@ -428,8 +440,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
forward_meta.pre_cache_num_blocks_cpu,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
cache_k_scales,
|
||||
cache_v_scales,
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
getattr(layer, "cache_v_out_scale", None),
|
||||
getattr(layer, "cache_k_zp", None),
|
||||
@@ -460,8 +472,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
res_decoder = append_attention(
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
cache_k,
|
||||
cache_v,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
@@ -482,8 +494,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
forward_meta.attn_mask,
|
||||
layer.qkv_bias,
|
||||
layer.qkv_scale,
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
cache_k_scales,
|
||||
cache_v_scales,
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
getattr(layer, "cache_v_out_scale", None),
|
||||
getattr(layer, "cache_k_zp", None),
|
||||
|
||||
@@ -187,6 +187,18 @@ class FlashMaskAttentionBackend(AttentionBackend):
|
||||
q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
|
||||
k_norm_weight = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None
|
||||
|
||||
cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none")
|
||||
if cache_quant_type_str == "block_wise_fp8":
|
||||
cache_k = forward_meta.caches[4 * layer.layer_id]
|
||||
cache_v = forward_meta.caches[4 * layer.layer_id + 1]
|
||||
cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2]
|
||||
cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3]
|
||||
else:
|
||||
cache_k = forward_meta.caches[2 * layer.layer_id]
|
||||
cache_v = forward_meta.caches[2 * layer.layer_id + 1]
|
||||
cache_k_scales = getattr(layer, "cache_k_scale", None)
|
||||
cache_v_scales = getattr(layer, "cache_v_scale", None)
|
||||
|
||||
if self.pd_disaggregation_mode == "per_query":
|
||||
metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
@@ -249,8 +261,8 @@ class FlashMaskAttentionBackend(AttentionBackend):
|
||||
res_encoder = paddle.zeros([qkv.shape[0], self.num_heads * self.head_dim], dtype=qkv.dtype)
|
||||
q, k, v, _ = gqa_rope_write_cache(
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
cache_k,
|
||||
cache_v,
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.attn_cu_seqlens_k,
|
||||
forward_meta.rotary_embs,
|
||||
@@ -267,8 +279,8 @@ class FlashMaskAttentionBackend(AttentionBackend):
|
||||
forward_meta.pre_cache_num_blocks_cpu,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
cache_k_scales,
|
||||
cache_v_scales,
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
getattr(layer, "cache_v_out_scale", None),
|
||||
getattr(layer, "cache_k_zp", None),
|
||||
@@ -309,8 +321,8 @@ class FlashMaskAttentionBackend(AttentionBackend):
|
||||
|
||||
res_decoder = append_attention(
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
cache_k,
|
||||
cache_v,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
@@ -331,8 +343,8 @@ class FlashMaskAttentionBackend(AttentionBackend):
|
||||
forward_meta.attn_mask,
|
||||
layer.qkv_bias,
|
||||
layer.qkv_scale,
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
cache_k_scales,
|
||||
cache_v_scales,
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
getattr(layer, "cache_v_out_scale", None),
|
||||
getattr(layer, "cache_k_zp", None),
|
||||
|
||||
Reference in New Issue
Block a user