FastDeploy/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
Longzhi Wang 2eea6fa97a [BugFix] Fix kv cache int8 dynamic quant on flash and flash_mask backend (#7028)
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional
import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded
from paddleformers.utils.log import logger
try:
    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
except Exception as e:
    logger.debug(f"flash_attention_v3_varlen not available: {e}")
    flash_attention_v3_varlen = None

try:
    from paddle.nn.functional.flash_attention import flashmask_attention
except Exception as e:
    logger.debug(f"flashmask_attention not available: {e}")
    flashmask_attention = None
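
# Note: the two imports above are optional. Older Paddle builds may not ship
# flash_attention_v3_varlen or flashmask_attention, so they fall back to None here and the
# callers below check availability before use.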
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend,
    AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.ops import (
    append_attention,
    get_attn_mask_q,
    get_block_shape_and_split_kv_block,
    gqa_rope_write_cache,
    init_kv_signal_per_query,
    init_signal_layerwise,
    open_shm_and_get_meta_signal,
    pre_cache_len_concat,
)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.model_executor.utils import get_sm_version

if TYPE_CHECKING:
    from fastdeploy.model_executor.forward_meta import ForwardMeta

import os

from fastdeploy import envs
from fastdeploy.platforms import current_platform

flashmask_attention_v4 = None

if current_platform.is_cuda():
    from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output
else:
    merge_prefill_decode_output = None

from fastdeploy.spec_decode import SpecMethod

FLASH_ATTN_VERSION = None
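

# Version selection (init_flash_attn_version below): FA4 is chosen on sm >= 100 when the
# flash_mask cute kernels can be imported, FA3 on sm >= 89 when the compiled CUDA archs include
# >= 89, and FA2 otherwise. Non-CUDA platforms are not supported by this backend.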
def init_flash_attn_version():
    """
    Detect the highest flash-attention kernel version usable on the current platform.
    """
    if current_platform.is_cuda():
        global FLASH_ATTN_VERSION
        sm_version = get_sm_version()
        if sm_version >= 100:
            try:
                paddle.compat.enable_torch_proxy(scope={"cutlass"})
                from flash_mask.cute.interface import flashmask_attention as fa4

                global flashmask_attention_v4
                flashmask_attention_v4 = fa4
                FLASH_ATTN_VERSION = 4
                logger.info("The current platform supports Flash Attention V4.")
            except ImportError:
                logger.info(f"The current platform[sm{get_sm_version()}] can't import Flash Attention V4.")
        if FLASH_ATTN_VERSION is None:
            if sm_version >= 89 and any(num >= 89 for num in paddle.version.cuda_archs()):
                FLASH_ATTN_VERSION = 3
                logger.info("The current platform supports Flash Attention V3.")
            else:
                FLASH_ATTN_VERSION = 2
                logger.info("The current platform only supports Flash Attention V2.")
    else:
        logger.info("Flash attention is only supported on CUDA platforms.")


def _is_deterministic_mode():
    """Check if FD_DETERMINISTIC_MODE is enabled."""
    return envs.FD_DETERMINISTIC_MODE


init_flash_attn_version()
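

# flash_attn_func dispatches the prefill attention call on the detected version:
#   * version 4: flashmask_attention_v4 (requires attn_mask_q start/end row indices), which
#     returns (out, softmax_lse);
#   * any version with attn_mask_q set: flashmask_attention;
#   * version 3: flash_attention_v3_varlen; otherwise flash_attn_unpadded.
# All variants consume packed (varlen) q/k/v with cu_seqlens_* prefix sums; e.g. two prefill
# requests of lengths 3 and 5 are assumed to give cu_seqlens_q == [0, 3, 8].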
def flash_attn_func(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    cu_seqlens_q: Optional[paddle.Tensor] = None,
    cu_seqlens_k: Optional[paddle.Tensor] = None,
    max_seqlen_q: Optional[paddle.Tensor] = None,
    max_seqlen_k: Optional[paddle.Tensor] = None,
    attn_mask_q: Optional[paddle.Tensor] = None,
    causal: bool = True,
    num_heads: int = None,
    kv_num_heads: int = None,
    head_dim: int = 128,
    version: Optional[int] = None,
):
    if FLASH_ATTN_VERSION is None:
        init_flash_attn_version()
    if version is None:
        version = FLASH_ATTN_VERSION

    if version == 4:
        assert (
            flashmask_attention_v4 is not None
        ), "Cannot import flashmask_attention from flash_mask.cute.interface, please install it first"
        assert attn_mask_q is not None, "FA4 requires attn_mask_q"
        assert num_heads is not None
        assert kv_num_heads is not None
        original_flash_attn_version = paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])[
            "FLAGS_flash_attn_version"
        ]
        with paddle.no_grad():
            try:
                paddle.set_flags({"FLAGS_flash_attn_version": 4})
                out = flashmask_attention_v4(
                    q.reshape([1, -1, num_heads, head_dim]),
                    k.reshape([1, -1, kv_num_heads, head_dim]),
                    v.reshape([1, -1, kv_num_heads, head_dim]),
                    startend_row_indices=attn_mask_q,
                    causal=False,
                    return_softmax_lse=True,
                    training=True,
                )
            finally:
                paddle.set_flags({"FLAGS_flash_attn_version": original_flash_attn_version})
        return out

    if attn_mask_q is not None:
        assert flashmask_attention is not None
        out = flashmask_attention(
            q.reshape([1, -1, num_heads, head_dim]),
            k.reshape([1, -1, kv_num_heads, head_dim]),
            v.reshape([1, -1, kv_num_heads, head_dim]),
            startend_row_indices=attn_mask_q,
            causal=False,
        )
    else:
        if version == 3:
            out = flash_attention_v3_varlen(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                causal=causal,
            )
        else:
            out = flash_attn_unpadded(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                causal=causal,
                scale=head_dim**-0.5,
                training=False,
            )
    return out
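

# A minimal usage sketch (illustrative values only; in this file flash_attn_func is normally
# driven by FlashAttentionBackend.forward_mixed, with max_seqlen_* coming from
# max_len_tensor_cpu, rather than called directly):
#   cu = paddle.to_tensor([0, 3, 8], dtype="int32")
#   out = flash_attn_func(q, k, v, cu, cu, max_seqlen_q=8, max_seqlen_k=8,
#                         causal=True, num_heads=8, kv_num_heads=1, head_dim=128)

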
@dataclass
class FlashAttentionMetadata(AttentionMetadata):
    """
    FlashAttentionMetadata
    """

    cu_seqlens_k: paddle.Tensor = None
    pre_cache_batch_ids = None
    pre_cache_tile_ids_per_batch = None
    pre_cache_num_blocks_cpu = None
    kv_token_num_cpu = None

    # pd_disaggregation
    kv_signal_metadata: Optional[paddle.Tensor] = None
    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)

    _fuse_kernel_compute_dtype: str = "bf16"
    _dtype: paddle.dtype = paddle.bfloat16

    max_len_tensor_cpu_decoder: paddle.Tensor = None
    attn_mask_q: paddle.Tensor = None
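

# FlashAttentionBackend splits a mixed batch per layer: prefill tokens go through
# gqa_rope_write_cache + flash_attn_func, decode tokens go through append_attention, and the
# two partial outputs are combined by merge_prefill_decode_output (see forward_mixed).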
class FlashAttentionBackend(AttentionBackend):
    """
    Flash-attention based attention backend implementation.
    """

    __infer_dynamic_dims_fields__ = ["attention_metadata"]
    attention_metadata: FlashAttentionMetadata

    def __init__(
        self,
        fd_config: FDConfig,
        kv_num_heads: int,
        num_heads: int,
        head_dim: int,
        encoder_block_shape_q: int = -1,
        decoder_block_shape_q: int = -1,
    ):
        """
        FlashAttentionBackend __init__
        """
        super().__init__()
        self.max_seq_len = fd_config.model_config.max_model_len
        self.causal = getattr(fd_config.model_config, "causal", True)

        self.kv_num_heads = kv_num_heads
        self.num_heads = num_heads
        self.group_size: int = self.num_heads // self.kv_num_heads
        self.head_dim = fd_config.model_config.head_dim
        self.attn_outputsize_tp = self.num_heads * self.head_dim
        self.block_size = fd_config.cache_config.block_size
        self.num_layers: int = fd_config.model_config.num_hidden_layers
        self.encoder_block_shape_q: int = encoder_block_shape_q
        self.decoder_block_shape_q: int = decoder_block_shape_q

        self.speculative_method = fd_config.speculative_config.method
        self.use_speculate = self.speculative_method is not None
        self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
        self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
        self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)

        self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
        self.start_layer_index: int = fd_config.model_config.start_layer_index
        self.rank, self.device_id = init_rank_and_device_id(fd_config)

        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
            fd_config.model_config, "use_3d_rope", False
        )
        if fd_config.speculative_config.model_type != "main":
            self.rope_3d = False

        # Note(ZKK): here must be consistent with append_attn_backend.py
        self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 1024))

        if FLASH_ATTN_VERSION is None:
            init_flash_attn_version()

    def get_attention_meta(self):
        """get_attention_meta"""
        return self.attention_metadata
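
    # get_kv_cache_shape: the paged KV cache is laid out as
    # [max_num_blocks, kv_num_heads, block_size, head_dim]; for "int4_zp" the last dimension is
    # halved, presumably because two 4-bit values are packed into each byte.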
    def get_kv_cache_shape(
        self,
        max_num_blocks: int,
        kv_cache_quant_type: str = None,
    ):
        """
        Calculate kv cache shape
        """
        key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
        if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
            key_cache_shape[-1] = self.head_dim // 2
        value_cache_shape = key_cache_shape
        return key_cache_shape, value_cache_shape
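
    # init_attention_metadata runs once per forward step. Besides the PD-disaggregation
    # KV-signal bookkeeping ("per_chunk" vs "per_query"), it maps the metadata dtype to the
    # string tag ("bf16"/"fp16"/"fp32") that is later passed to the fused append_attention kernel.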
    def init_attention_metadata(self, forward_meta: ForwardMeta):
        metadata = FlashAttentionMetadata()

        # pd_disaggregation
        metadata.kv_signal_data_list = [None] * self.num_layers
        if self.pd_disaggregation_mode == "per_chunk":
            if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run:
                init_kv_signal_per_query(
                    forward_meta.seq_lens_encoder,
                    forward_meta.seq_lens_this_time,
                    forward_meta.seq_lens_decoder,
                    self.rank,
                    self.num_layers + self.num_layers_draft_model,
                )
        elif self.pd_disaggregation_mode == "per_query":
            metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
                self.rank, int(self.device_id), self.keep_pd_step_flag
            )

        if metadata._dtype == "bfloat16":
            metadata._fuse_kernel_compute_dtype = "bf16"
        elif metadata._dtype == "float16":
            metadata._fuse_kernel_compute_dtype = "fp16"
        elif metadata._dtype == "float32":
            metadata._fuse_kernel_compute_dtype = "fp32"

        self.attention_metadata = metadata
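
    # forward_mixed handles one attention layer for a batch that may contain both prefill and
    # decode requests. Per-layer KV cache tensors live in forward_meta.caches: two entries per
    # layer (K, V) normally, four per layer (K, V, K scales, V scales) for "block_wise_fp8".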
    def forward_mixed(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        qkv: paddle.Tensor,
        compressed_kv: paddle.Tensor,
        k_pe: paddle.Tensor,
        layer: Attention,
        forward_meta: ForwardMeta,
    ):
        metadata = self.attention_metadata

        if self.pd_disaggregation_mode == "per_query":
            metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
                metadata.kv_signal_metadata,
                layer.layer_id + self.start_layer_index,
            )

        if int(os.getenv("USE_TBO", "0")) == 1:
            if hasattr(forward_meta, "tbo_microbatch_id"):
                # here we only let the last microbatch invoke cache kv transfer
                if forward_meta.tbo_microbatch_id == 0:
                    os.environ["FLAGS_fmt_write_cache_completed_signal"] = "0"
                elif forward_meta.tbo_microbatch_id == 1:
                    os.environ["FLAGS_fmt_write_cache_completed_signal"] = "1"
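
        # When qk norm is applied after RoPE it is fused into the kernels below, so the layer's
        # norm weights are forwarded; when it runs before RoPE it is assumed to have been applied
        # already, and None is passed instead.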
        norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False)
        q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
        k_norm_weight = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None

        cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none")
        if cache_quant_type_str == "block_wise_fp8":
            cache_k = forward_meta.caches[4 * layer.layer_id]
            cache_v = forward_meta.caches[4 * layer.layer_id + 1]
            cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2]
            cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3]
        else:
            cache_k = forward_meta.caches[2 * layer.layer_id]
            cache_v = forward_meta.caches[2 * layer.layer_id + 1]
            cache_k_scales = getattr(layer, "cache_k_scale", None)
            cache_v_scales = getattr(layer, "cache_v_scale", None)
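
        # The tiling / scheduling metadata below does not depend on the layer, so it is computed
        # once per step on layer 0 and cached on forward_meta for the remaining layers.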
        if layer.layer_id == 0:
            get_block_shape_and_split_kv_block(
                forward_meta.seq_lens_encoder,
                forward_meta.seq_lens_decoder,
                forward_meta.seq_lens_this_time,
                forward_meta.decoder_batch_ids,
                forward_meta.decoder_tile_ids_per_batch,
                forward_meta.decoder_num_blocks_cpu,
                forward_meta.decoder_num_blocks_device,
                forward_meta.decoder_chunk_size_device,
                forward_meta.max_len_tensor_cpu,
                forward_meta.encoder_batch_ids,
                forward_meta.encoder_tile_ids_per_batch,
                forward_meta.encoder_num_blocks_x_cpu,
                forward_meta.kv_batch_ids,
                forward_meta.kv_tile_ids_per_batch,
                forward_meta.kv_num_blocks_x_cpu,
                self.encoder_block_shape_q,
                self.decoder_block_shape_q,
                self.group_size,
                self.block_size,
            )
            if forward_meta.max_len_tensor_cpu[1].item() > 0:
                forward_meta.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
                forward_meta.max_len_tensor_cpu_decoder[1] = 0
                (
                    forward_meta.cu_seqlens_k,
                    forward_meta.pre_cache_batch_ids,
                    forward_meta.pre_cache_tile_ids_per_batch,
                    forward_meta.pre_cache_num_blocks_cpu,
                    forward_meta.kv_token_num_cpu,
                ) = pre_cache_len_concat(
                    forward_meta.seq_lens_encoder,
                    forward_meta.seq_lens_decoder,
                    forward_meta.seq_lens_this_time,
                    forward_meta.max_len_tensor_cpu[2],
                    self.block_size,
                )
                if FLASH_ATTN_VERSION == 4 or forward_meta.attn_mask_offsets is not None:
                    forward_meta.attn_mask_q = get_attn_mask_q(
                        cu_seqlens_q=forward_meta.cu_seqlens_q,
                        cu_seqlens_k=forward_meta.cu_seqlens_k,
                        attn_mask_kv=forward_meta.attn_mask_offsets,
                        kv_token_num=forward_meta.kv_token_num_cpu[0].item(),
                    )
                else:
                    forward_meta.attn_mask_q = None
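
        # max_len_tensor_cpu[1] is assumed to hold the max prefill (encoder) length of this step:
        # when it is positive, prefill tokens take the flash-attention path below while decode
        # tokens are left to append_attention.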
        use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0
        if use_fa_do_prefill:
            q, k, v, _ = gqa_rope_write_cache(
                qkv,
                cache_k,
                cache_v,
                forward_meta.cu_seqlens_q,
                forward_meta.cu_seqlens_k,
                forward_meta.rotary_embs,
                forward_meta.seq_lens_this_time,
                forward_meta.seq_lens_encoder,
                forward_meta.seq_lens_decoder,
                forward_meta.batch_id_per_token,
                forward_meta.block_tables,
                forward_meta.kv_batch_ids,
                forward_meta.kv_tile_ids_per_batch,
                forward_meta.kv_num_blocks_x_cpu,
                forward_meta.pre_cache_batch_ids,
                forward_meta.pre_cache_tile_ids_per_batch,
                forward_meta.pre_cache_num_blocks_cpu,
                q_norm_weight,
                k_norm_weight,
                cache_k_scales,
                cache_v_scales,
                getattr(layer, "cache_k_out_scale", None),
                getattr(layer, "cache_v_out_scale", None),
                getattr(layer, "cache_k_zp", None),
                getattr(layer, "cache_v_zp", None),
                metadata.kv_signal_data_list[layer.layer_id],
                forward_meta.kv_token_num_cpu[0].item(),
                self.max_seq_len,
                getattr(layer, "rms_norm_eps", 1e-6),
                layer.use_neox_rotary_style,
                getattr(layer, "cache_quant_type_str", "none"),
                self.rope_3d,
            )
            res_encoder = flash_attn_func(
                q,
                k,
                v,
                forward_meta.cu_seqlens_q[: forward_meta.cu_seqlens_k.shape[0]],
                forward_meta.cu_seqlens_k,
                max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
                max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
                attn_mask_q=forward_meta.attn_mask_q,
                causal=self.causal,
                num_heads=self.num_heads,
                kv_num_heads=self.kv_num_heads,
                head_dim=self.head_dim,
            )[0].reshape([-1, self.attn_outputsize_tp])
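
        # append_attention always runs: it covers the decode tokens, and when the flash path ran
        # above, max_len_tensor_cpu_decoder (a copy with the prefill max zeroed out) is passed so
        # that the prefill portion is skipped.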
        res_decoder = append_attention(
            qkv,
            cache_k,
            cache_v,
            forward_meta.seq_lens_encoder,
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_this_time,
            forward_meta.batch_id_per_token,
            forward_meta.cu_seqlens_q,
            forward_meta.block_tables,
            forward_meta.encoder_batch_ids,
            forward_meta.encoder_tile_ids_per_batch,
            forward_meta.encoder_num_blocks_x_cpu,
            forward_meta.kv_batch_ids,
            forward_meta.kv_tile_ids_per_batch,
            forward_meta.kv_num_blocks_x_cpu,
            forward_meta.decoder_batch_ids,
            forward_meta.decoder_tile_ids_per_batch,
            forward_meta.decoder_num_blocks_cpu,
            forward_meta.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu,
            forward_meta.rotary_embs,
            forward_meta.attn_mask,
            layer.qkv_bias,
            layer.qkv_scale,
            cache_k_scales,
            cache_v_scales,
            getattr(layer, "cache_k_out_scale", None),
            getattr(layer, "cache_v_out_scale", None),
            getattr(layer, "cache_k_zp", None),
            getattr(layer, "cache_v_zp", None),
            layer.linear_shift,
            layer.linear_smooth,
            forward_meta.attn_mask_offsets,
            metadata.kv_signal_data_list[layer.layer_id],
            q_norm_weight,
            k_norm_weight,
            getattr(layer, "sinks", None),
            getattr(layer, "rms_norm_eps", 1e-6),
            metadata._fuse_kernel_compute_dtype,
            getattr(layer, "cache_quant_type_str", "none"),
            layer.use_neox_rotary_style,
            self.rope_3d,
            self.max_seq_len,
            getattr(layer, "quant_max_bound", 0.0),
            getattr(layer, "quant_min_bound", 0.0),
            getattr(layer, "out_scale", -1.0),
            self.encoder_block_shape_q,
            self.decoder_block_shape_q,
            self.max_partition_size,
            self.max_seq_len,
            self.speculate_max_draft_token_num + 1,
            self.causal,
            self.speculative_method is not None,
        )
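
        # When both paths ran, merge_prefill_decode_output combines the decode results into the
        # prefill output so a single tensor covering the whole batch is returned.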
        if use_fa_do_prefill:
            merge_prefill_decode_output(
                res_encoder,
                res_decoder,
                forward_meta.seq_lens_encoder,
                forward_meta.seq_lens_decoder,
                forward_meta.seq_lens_this_time,
                forward_meta.cu_seqlens_q,
                self.num_heads,
                self.head_dim,
                self.speculate_max_draft_token_num + 1,
            )
            return res_encoder
        else:
            return res_decoder