Add with_output version AppendAttention (#3302)

* get use_output from fd_config

* add clear TODO description

* add mask_offset para to align with develop

* fix bug

* fix use_output logic

* fix sot bug
This commit is contained in:
Liumengyuan
2025-08-28 17:10:18 +08:00
committed by GitHub
parent 94ded434bd
commit e93d4cfcdd
8 changed files with 1366 additions and 96 deletions
@@ -24,6 +24,7 @@ import paddle
from fastdeploy.model_executor.layers.attention.ops import (
append_attention,
append_attention_with_output,
get_block_shape_and_split_kv_block,
init_kv_signal_per_query,
init_signal_layerwise,
@@ -122,6 +123,7 @@ class AppendAttentionBackend(AttentionBackend):
fd_config.parallel_config.expert_parallel_rank = 0
self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.use_output = not fd_config.graph_opt_config.full_cuda_graph
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
@@ -229,58 +231,149 @@ class AppendAttentionBackend(AttentionBackend):
layer.layer_id + self.start_layer_index,
)
res = append_attention(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.max_len_kv,
metadata.rotary_embs,
metadata.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
layer.linear_shift,
layer.linear_smooth,
metadata.mask_offset,
metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
getattr(layer, "rms_norm_eps", 1e-6),
metadata._fuse_kernel_compute_dtype,
getattr(layer, "cache_quant_type_str", "none"),
layer.use_neox_rotary_style,
self.rope_3d,
self.max_seq_len,
getattr(layer, "quant_max_bound", 0.0),
getattr(layer, "quant_min_bound", 0.0),
getattr(layer, "out_scale", -1.0),
self.encoder_block_shape_q,
self.decoder_block_shape_q,
metadata.max_partition_size,
metadata.encoder_max_partition_size,
self.speculate_max_draft_token_num + 1,
self.causal,
self.speculative_method is not None,
)[0]
if self.use_output:
quant_max_bound = getattr(layer, "quant_max_bound", 0.0)
cache_quant_type = getattr(layer, "cache_quant_type_str", "none")
compute_type = metadata._fuse_kernel_compute_dtype
out_scale = getattr(layer, "out_scale", -1.0)
# 1. get output datatype
qkv_dtype = qkv.dtype
if qkv_dtype == paddle.float16:
D_type = paddle.float16
elif qkv_dtype == paddle.bfloat16:
D_type = paddle.bfloat16
elif qkv_dtype == paddle.int32:
if compute_type == "bf16":
D_type = paddle.bfloat16
elif compute_type == "fp16":
D_type = paddle.float16
else:
raise NotImplementedError("Only supported attr of qkv_type in ['float16', 'bfloat16'].")
else:
raise NotImplementedError("Only supported attr of qkv_type in ['float16', 'bfloat16', 'int32'].")
# 2.Extract related parameters
token_nums = qkv.shape[0]
head_dims = self.head_dim if cache_quant_type != "cache_int4_zp" else self.head_dim * 2
q_num_heads = self.num_heads
# 3. generate output tensor of different dtypes
if out_scale > 0.0:
if abs(quant_max_bound - 127) < 0.000001:
res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="int8").to(qkv.place)
elif abs(quant_max_bound - 448) < 0.000001:
res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="float8_e4m3fn").to(qkv.place)
else:
raise NotImplementedError("Only supported attr of quant_max_bound in ['127', '448'].")
else:
res = paddle.empty([token_nums, q_num_heads * head_dims], dtype=D_type).to(qkv.place)
append_attention_with_output(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.max_len_kv,
res,
metadata.rotary_embs,
metadata.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
layer.linear_shift,
layer.linear_smooth,
metadata.mask_offset,
metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
getattr(layer, "rms_norm_eps", 1e-6),
metadata._fuse_kernel_compute_dtype,
getattr(layer, "cache_quant_type_str", "none"),
layer.use_neox_rotary_style,
self.rope_3d,
self.max_seq_len,
getattr(layer, "quant_max_bound", 0.0),
getattr(layer, "quant_min_bound", 0.0),
getattr(layer, "out_scale", -1.0),
self.encoder_block_shape_q,
self.decoder_block_shape_q,
metadata.max_partition_size,
metadata.encoder_max_partition_size,
self.speculate_max_draft_token_num + 1,
self.causal,
self.speculative_method is not None,
)
else:
res = append_attention(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.max_len_kv,
metadata.rotary_embs,
metadata.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
layer.linear_shift,
layer.linear_smooth,
metadata.mask_offset,
metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
getattr(layer, "rms_norm_eps", 1e-6),
metadata._fuse_kernel_compute_dtype,
getattr(layer, "cache_quant_type_str", "none"),
layer.use_neox_rotary_style,
self.rope_3d,
self.max_seq_len,
getattr(layer, "quant_max_bound", 0.0),
getattr(layer, "quant_min_bound", 0.0),
getattr(layer, "out_scale", -1.0),
self.encoder_block_shape_q,
self.decoder_block_shape_q,
metadata.max_partition_size,
metadata.encoder_max_partition_size,
self.speculate_max_draft_token_num + 1,
self.causal,
self.speculative_method is not None,
)
return res
@@ -378,7 +378,7 @@ class FlashAttentionBackend(AttentionBackend):
self.speculate_max_draft_token_num + 1,
self.causal,
self.speculative_method is not None,
)[0]
)
if metadata.max_len_tensor_cpu[1] > 0:
merge_prefill_decode_output(
@@ -14,7 +14,7 @@
# limitations under the License.
"""
from .append_attention import append_attention
from .append_attention import append_attention, append_attention_with_output
from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
from .gqa_rope_write_cache import gqa_rope_write_cache
from .init_kv_signal_per_query import init_kv_signal_per_query
@@ -25,6 +25,7 @@ from .pre_cache_len_concat import pre_cache_len_concat
__all__ = [
"get_block_shape_and_split_kv_block",
"append_attention",
"append_attention_with_output",
"open_shm_and_get_meta_signal",
"init_signal_layerwise",
"gqa_rope_write_cache",
@@ -24,6 +24,9 @@ if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
append_attention as append_attention_gpu,
)
from fastdeploy.model_executor.ops.gpu import (
append_attention_with_output as append_attention_with_output_gpu,
)
def append_attention(
@@ -141,3 +144,124 @@ def append_attention(
return out
else:
raise NotImplementedError
# TODO: (mengyuan) merge w/o output version append attention after
# finishing developing sub-graph cudagraph capture to reduce
# compilation volume
def append_attention_with_output(
qkv: paddle.Tensor,
key_cache: paddle.Tensor,
value_cache: paddle.Tensor,
seq_lens_encoder: paddle.Tensor,
seq_lens_decoder: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
batch_id_per_token: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
block_tables: paddle.Tensor,
encoder_batch_ids: paddle.Tensor,
encoder_tile_ids_per_batch: paddle.Tensor,
encoder_num_blocks: paddle.Tensor,
kv_batch_ids: paddle.Tensor,
kv_tile_ids_per_batch: paddle.Tensor,
kv_num_blocks: paddle.Tensor,
decoder_batch_ids: paddle.Tensor,
decoder_tile_ids_per_batch: paddle.Tensor,
decoder_num_blocks: paddle.Tensor,
set_max_lengths: paddle.Tensor,
max_len_kv: paddle.Tensor,
out: paddle.tensor, # attention output
rotary_embs: Optional[paddle.Tensor] = None,
attn_mask: Optional[paddle.Tensor] = None,
qkv_bias: Optional[paddle.Tensor] = None,
qkv_scale: Optional[paddle.Tensor] = None,
k_quant_scale: Optional[paddle.Tensor] = None,
v_quant_scale: Optional[paddle.Tensor] = None,
k_dequant_scale: Optional[paddle.Tensor] = None,
v_dequant_scale: Optional[paddle.Tensor] = None,
cache_k_zp: Optional[paddle.Tensor] = None,
cache_v_zp: Optional[paddle.Tensor] = None,
linear_shift: Optional[paddle.Tensor] = None,
linear_smooth: Optional[paddle.Tensor] = None,
mask_offset: Optional[paddle.Tensor] = None,
kv_signal_data: Optional[paddle.Tensor] = None,
q_norm_weight: Optional[paddle.Tensor] = None,
k_norm_weight: Optional[paddle.Tensor] = None,
rms_norm_eps: float = 1e-6,
compute_type: str = "bf16",
cache_quant_type: str = "none",
use_neox_rotary_style: bool = False,
rope_3d: bool = False,
max_input_length: int = 0,
quant_max_bound: float = 0.0,
quant_min_bound: float = 0.0,
out_linear_in_scale: float = -1.0,
encoder_block_shape_q: int = 64,
decoder_block_shape_q: int = 16,
max_partition_size: int = 32768,
encoder_max_partition_size: int = 32768,
speculate_max_draft_token_num: int = 1,
causal: bool = True,
speculate_decoder: bool = False,
) -> None:
"""
append_attention
"""
if current_platform.is_cuda():
append_attention_with_output_gpu(
qkv,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
batch_id_per_token,
cu_seqlens_q,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
set_max_lengths,
max_len_kv,
out,
rotary_embs,
attn_mask,
qkv_bias,
qkv_scale,
k_quant_scale,
v_quant_scale,
k_dequant_scale,
v_dequant_scale,
cache_k_zp,
cache_v_zp,
linear_shift,
linear_smooth,
mask_offset,
kv_signal_data,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
compute_type,
cache_quant_type,
use_neox_rotary_style,
rope_3d,
max_input_length,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
encoder_block_shape_q,
decoder_block_shape_q,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder,
)
else:
raise NotImplementedError