Support Norm before Rope (#6332)

Co-authored-by: K11OntheBoat <“ruianmaidanglao@163.com”>
This commit is contained in:
K11OntheBoat
2026-02-05 15:28:52 +08:00
committed by GitHub
parent 29a313a402
commit 116e2aea7a
4 changed files with 27 additions and 12 deletions
@@ -330,6 +330,10 @@ class FlashAttentionBackend(AttentionBackend):
layer.layer_id + self.start_layer_index,
)
norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False)
q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
k_norm_weight = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None
if layer.layer_id == 0:
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
@@ -400,8 +404,8 @@ class FlashAttentionBackend(AttentionBackend):
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
q_norm_weight,
k_norm_weight,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
@@ -466,8 +470,8 @@ class FlashAttentionBackend(AttentionBackend):
layer.linear_smooth,
forward_meta.attn_mask_offsets,
metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
q_norm_weight,
k_norm_weight,
getattr(layer, "sinks", None),
getattr(layer, "rms_norm_eps", 1e-6),
metadata._fuse_kernel_compute_dtype,