mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
Support Norm before Rope (#6332)
Co-authored-by: K11OntheBoat <“ruianmaidanglao@163.com”>
This commit is contained in:
@@ -280,6 +280,10 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
|
||||
sliding_window = layer.sliding_window
|
||||
|
||||
norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False)
|
||||
q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
|
||||
k_norm_weight = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None
|
||||
|
||||
if self.rope_3d:
|
||||
assert len(forward_meta.rotary_embs.shape) == 6
|
||||
else:
|
||||
@@ -402,8 +406,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
layer.linear_smooth,
|
||||
forward_meta.attn_mask_offsets,
|
||||
metadata.kv_signal_data_list[layer.layer_id],
|
||||
getattr(layer, "q_norm_weight", None),
|
||||
getattr(layer, "k_norm_weight", None),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
getattr(layer, "sinks", None),
|
||||
getattr(layer, "rms_norm_eps", 1e-6),
|
||||
metadata._fuse_kernel_compute_dtype,
|
||||
@@ -458,8 +462,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
layer.linear_smooth,
|
||||
forward_meta.attn_mask_offsets,
|
||||
metadata.kv_signal_data_list[layer.layer_id],
|
||||
getattr(layer, "q_norm_weight", None),
|
||||
getattr(layer, "k_norm_weight", None),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
getattr(layer, "sinks", None),
|
||||
getattr(layer, "rms_norm_eps", 1e-6),
|
||||
metadata._fuse_kernel_compute_dtype,
|
||||
|
||||
@@ -59,6 +59,7 @@ class Attention(nn.Layer):
|
||||
linear_smooth: paddle.Tensor = None,
|
||||
use_neox_rotary_style: bool = False,
|
||||
use_qk_norm: bool = False,
|
||||
qk_norm_before_rope: bool = False,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
with_sinks: bool = False,
|
||||
) -> None:
|
||||
@@ -76,6 +77,7 @@ class Attention(nn.Layer):
|
||||
linear_shift (Optional[paddle.Tensor], optional): The shift of linear. Defaults to None.
|
||||
linear_smooth (Optional[paddle.Tensor], optional): The smooth of linear. Defaults to None.
|
||||
use_qk_norm (bool, optional): Whether to apply rmsnorm on QK after rope. Defaults to False.
|
||||
qk_norm_before_rope (bool, optional): Whether to apply rmsnorm before rope (e.g., Qwen style). Defaults to False. If True, use_qk_norm should also be True.
|
||||
rms_norm_eps (float, optional): The epsilon of RMSNorm. Defaults to 1e-6.
|
||||
|
||||
Raises:
|
||||
@@ -124,6 +126,7 @@ class Attention(nn.Layer):
|
||||
else:
|
||||
logger.info(f"Attention is running in cache kv {self.quant_method.cache_quant_config.quant_type} mode")
|
||||
self.use_qk_norm = use_qk_norm
|
||||
self.qk_norm_before_rope = qk_norm_before_rope
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
if self.use_qk_norm:
|
||||
self.q_norm_key = f"{self.prefix}.q_norm"
|
||||
|
||||
@@ -330,6 +330,10 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
layer.layer_id + self.start_layer_index,
|
||||
)
|
||||
|
||||
norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False)
|
||||
q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
|
||||
k_norm_weight = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None
|
||||
|
||||
if layer.layer_id == 0:
|
||||
get_block_shape_and_split_kv_block(
|
||||
forward_meta.seq_lens_encoder,
|
||||
@@ -400,8 +404,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.pre_cache_batch_ids,
|
||||
metadata.pre_cache_tile_ids_per_batch,
|
||||
metadata.pre_cache_num_blocks_cpu,
|
||||
getattr(layer, "q_norm_weight", None),
|
||||
getattr(layer, "k_norm_weight", None),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
@@ -466,8 +470,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
layer.linear_smooth,
|
||||
forward_meta.attn_mask_offsets,
|
||||
metadata.kv_signal_data_list[layer.layer_id],
|
||||
getattr(layer, "q_norm_weight", None),
|
||||
getattr(layer, "k_norm_weight", None),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
getattr(layer, "sinks", None),
|
||||
getattr(layer, "rms_norm_eps", 1e-6),
|
||||
metadata._fuse_kernel_compute_dtype,
|
||||
|
||||
@@ -181,6 +181,10 @@ class FlashMaskAttentionBackend(AttentionBackend):
|
||||
):
|
||||
metadata = forward_meta.attention_metadata
|
||||
|
||||
norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False)
|
||||
q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
|
||||
k_norm_weight = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None
|
||||
|
||||
if self.pd_disaggregation_mode == "per_query":
|
||||
metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
@@ -259,8 +263,8 @@ class FlashMaskAttentionBackend(AttentionBackend):
|
||||
forward_meta.pre_cache_batch_ids,
|
||||
forward_meta.pre_cache_tile_ids_per_batch,
|
||||
forward_meta.pre_cache_num_blocks_cpu,
|
||||
getattr(layer, "q_norm_weight", None),
|
||||
getattr(layer, "k_norm_weight", None),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
@@ -324,8 +328,8 @@ class FlashMaskAttentionBackend(AttentionBackend):
|
||||
layer.linear_smooth,
|
||||
forward_meta.attn_mask_offsets,
|
||||
metadata.kv_signal_data_list[layer.layer_id],
|
||||
getattr(layer, "q_norm_weight", None),
|
||||
getattr(layer, "k_norm_weight", None),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
getattr(layer, "sinks", None),
|
||||
getattr(layer, "rms_norm_eps", 1e-6),
|
||||
metadata._fuse_kernel_compute_dtype,
|
||||
|
||||
Reference in New Issue
Block a user