[BugFix] Fix kv cache int8 dynamic quant on flash and flash_mask backend (#7028)

* [BugFix] Fix kv cache int8 dynamic quant on flash and flash_mask backend

* add constexpr and code style clean

* add test

* fix code style

* fix test
Author: Longzhi Wang
Date: 2026-03-30 11:17:15 +08:00 (committed by GitHub)
Parent: 7a20eaebe8
Commit: 2eea6fa97a
5 changed files with 1031 additions and 82 deletions
@@ -354,6 +354,18 @@ class FlashAttentionBackend(AttentionBackend):
         q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
         k_norm_weight = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None
         cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none")
+        if cache_quant_type_str == "block_wise_fp8":
+            cache_k = forward_meta.caches[4 * layer.layer_id]
+            cache_v = forward_meta.caches[4 * layer.layer_id + 1]
+            cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2]
+            cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3]
+        else:
+            cache_k = forward_meta.caches[2 * layer.layer_id]
+            cache_v = forward_meta.caches[2 * layer.layer_id + 1]
+            cache_k_scales = getattr(layer, "cache_k_scale", None)
+            cache_v_scales = getattr(layer, "cache_v_scale", None)
         if layer.layer_id == 0:
             get_block_shape_and_split_kv_block(
                 forward_meta.seq_lens_encoder,
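
The added branch above (repeated verbatim in the FlashMaskAttentionBackend hunk further down) is the heart of the fix: for "block_wise_fp8" the per-layer stride into forward_meta.caches is 4 and the quant scales live in the cache list itself, while every other quant type keeps the stride-2 layout with the scale parameters held on the layer object. Factored into a standalone helper for illustration (a hypothetical refactor; the PR inlines this block in both backends):

    def select_kv_cache_buffers(layer, forward_meta):
        """Pick the KV cache tensors and their quant scales for one layer."""
        quant_type = getattr(layer, "cache_quant_type_str", "none")
        if quant_type == "block_wise_fp8":
            # Four tensors per layer: K, V, K-scales, V-scales.
            base = 4 * layer.layer_id
            return (
                forward_meta.caches[base],      # cache_k
                forward_meta.caches[base + 1],  # cache_v
                forward_meta.caches[base + 2],  # cache_k_scales
                forward_meta.caches[base + 3],  # cache_v_scales
            )
        # Two tensors per layer; scales (possibly None) come from the layer.
        base = 2 * layer.layer_id
        return (
            forward_meta.caches[base],
            forward_meta.caches[base + 1],
            getattr(layer, "cache_k_scale", None),
            getattr(layer, "cache_v_scale", None),
        )
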
@@ -410,8 +422,8 @@ class FlashAttentionBackend(AttentionBackend):
         if use_fa_do_prefill:
             q, k, v, _ = gqa_rope_write_cache(
                 qkv,
-                forward_meta.caches[2 * layer.layer_id],
-                forward_meta.caches[2 * layer.layer_id + 1],
+                cache_k,
+                cache_v,
                 forward_meta.cu_seqlens_q,
                 forward_meta.cu_seqlens_k,
                 forward_meta.rotary_embs,
@@ -428,8 +440,8 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_meta.pre_cache_num_blocks_cpu,
                 q_norm_weight,
                 k_norm_weight,
-                getattr(layer, "cache_k_scale", None),
-                getattr(layer, "cache_v_scale", None),
+                cache_k_scales,
+                cache_v_scales,
                 getattr(layer, "cache_k_out_scale", None),
                 getattr(layer, "cache_v_out_scale", None),
                 getattr(layer, "cache_k_zp", None),
@@ -460,8 +472,8 @@ class FlashAttentionBackend(AttentionBackend):
             res_decoder = append_attention(
                 qkv,
-                forward_meta.caches[2 * layer.layer_id],
-                forward_meta.caches[2 * layer.layer_id + 1],
+                cache_k,
+                cache_v,
                 forward_meta.seq_lens_encoder,
                 forward_meta.seq_lens_decoder,
                 forward_meta.seq_lens_this_time,
@@ -482,8 +494,8 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_meta.attn_mask,
                 layer.qkv_bias,
                 layer.qkv_scale,
-                getattr(layer, "cache_k_scale", None),
-                getattr(layer, "cache_v_scale", None),
+                cache_k_scales,
+                cache_v_scales,
                 getattr(layer, "cache_k_out_scale", None),
                 getattr(layer, "cache_v_out_scale", None),
                 getattr(layer, "cache_k_zp", None),
@@ -187,6 +187,18 @@ class FlashMaskAttentionBackend(AttentionBackend):
         q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
         k_norm_weight = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None
         cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none")
+        if cache_quant_type_str == "block_wise_fp8":
+            cache_k = forward_meta.caches[4 * layer.layer_id]
+            cache_v = forward_meta.caches[4 * layer.layer_id + 1]
+            cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2]
+            cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3]
+        else:
+            cache_k = forward_meta.caches[2 * layer.layer_id]
+            cache_v = forward_meta.caches[2 * layer.layer_id + 1]
+            cache_k_scales = getattr(layer, "cache_k_scale", None)
+            cache_v_scales = getattr(layer, "cache_v_scale", None)
         if self.pd_disaggregation_mode == "per_query":
             metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
                 metadata.kv_signal_metadata,
@@ -249,8 +261,8 @@ class FlashMaskAttentionBackend(AttentionBackend):
         res_encoder = paddle.zeros([qkv.shape[0], self.num_heads * self.head_dim], dtype=qkv.dtype)
         q, k, v, _ = gqa_rope_write_cache(
             qkv,
-            forward_meta.caches[2 * layer.layer_id],
-            forward_meta.caches[2 * layer.layer_id + 1],
+            cache_k,
+            cache_v,
             forward_meta.cu_seqlens_q,
             forward_meta.attn_cu_seqlens_k,
             forward_meta.rotary_embs,
@@ -267,8 +279,8 @@ class FlashMaskAttentionBackend(AttentionBackend):
             forward_meta.pre_cache_num_blocks_cpu,
             q_norm_weight,
             k_norm_weight,
-            getattr(layer, "cache_k_scale", None),
-            getattr(layer, "cache_v_scale", None),
+            cache_k_scales,
+            cache_v_scales,
             getattr(layer, "cache_k_out_scale", None),
             getattr(layer, "cache_v_out_scale", None),
             getattr(layer, "cache_k_zp", None),
@@ -309,8 +321,8 @@ class FlashMaskAttentionBackend(AttentionBackend):
         res_decoder = append_attention(
             qkv,
-            forward_meta.caches[2 * layer.layer_id],
-            forward_meta.caches[2 * layer.layer_id + 1],
+            cache_k,
+            cache_v,
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
@@ -331,8 +343,8 @@ class FlashMaskAttentionBackend(AttentionBackend):
             forward_meta.attn_mask,
             layer.qkv_bias,
             layer.qkv_scale,
-            getattr(layer, "cache_k_scale", None),
-            getattr(layer, "cache_v_scale", None),
+            cache_k_scales,
+            cache_v_scales,
             getattr(layer, "cache_k_out_scale", None),
             getattr(layer, "cache_v_out_scale", None),
             getattr(layer, "cache_k_zp", None),