Support Norm before Rope (#6332)

Co-authored-by: K11OntheBoat <"ruianmaidanglao@163.com">
This commit is contained in:
K11OntheBoat
2026-02-05 15:28:52 +08:00
committed by GitHub
parent 29a313a402
commit 116e2aea7a
4 changed files with 27 additions and 12 deletions
@@ -59,6 +59,7 @@ class Attention(nn.Layer):
linear_smooth: paddle.Tensor = None,
use_neox_rotary_style: bool = False,
use_qk_norm: bool = False,
qk_norm_before_rope: bool = False,
rms_norm_eps: float = 1e-6,
with_sinks: bool = False,
) -> None:
@@ -76,6 +77,7 @@ class Attention(nn.Layer):
linear_shift (Optional[paddle.Tensor], optional): The shift of linear. Defaults to None.
linear_smooth (Optional[paddle.Tensor], optional): The smooth of linear. Defaults to None.
use_qk_norm (bool, optional): Whether to apply rmsnorm on QK after rope. Defaults to False.
qk_norm_before_rope (bool, optional): Whether to apply rmsnorm before rope (e.g., Qwen style). Defaults to False. If True, use_qk_norm must also be True.
rms_norm_eps (float, optional): The epsilon of RMSNorm. Defaults to 1e-6.
Raises:
@@ -124,6 +126,7 @@ class Attention(nn.Layer):
else:
logger.info(f"Attention is running in cache kv {self.quant_method.cache_quant_config.quant_type} mode")
self.use_qk_norm = use_qk_norm
self.qk_norm_before_rope = qk_norm_before_rope
self.rms_norm_eps = rms_norm_eps
if self.use_qk_norm:
self.q_norm_key = f"{self.prefix}.q_norm"