[Optimization] Accelerate Qwen3 QK RMSNorm via Fused Triton Kernel (#5880)

* qk rmsnorm fused

* inplace

* glm

* fix

* add qknorm layer

* fix

* update

* fix qwen3 vl

* update rl baseline

* fix qwen3 vl moe

* test

* fix qwen vl moe rl

* fix
This commit is contained in:
sunxin
2026-01-12 21:10:21 +08:00
committed by GitHub
parent 1aa7e82924
commit 2533836dbb
12 changed files with 733 additions and 387 deletions
+12 -17
View File
@@ -41,7 +41,7 @@ from fastdeploy.model_executor.layers.linear import (
)
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.layers.normalization import QKRMSNorm, RMSNorm
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
@@ -205,18 +205,13 @@ class Glm4MoeAttention(nn.Layer):
rms_norm_eps=fd_config.model_config.rms_norm_eps,
)
if self.use_qk_norm:
self.q_norm = RMSNorm(
self.qk_norm = QKRMSNorm(
fd_config,
hidden_size=self.head_dim,
head_dim=self.head_dim,
q_size=self.q_size,
kv_size=self.kv_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.q_norm",
begin_norm_axis=2,
)
self.k_norm = RMSNorm(
fd_config,
hidden_size=self.head_dim,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.k_norm",
prefix=prefix,
begin_norm_axis=2,
)
@@ -227,13 +222,8 @@ class Glm4MoeAttention(nn.Layer):
):
""" """
qkv_out = self.qkv_proj(hidden_states)
if self.use_qk_norm:
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
q = self.q_norm(q.reshape([-1, self.num_heads, self.head_dim]))[0].reshape(q.shape)
k = self.k_norm(k.reshape([-1, self.num_kv_heads, self.head_dim]))[0].reshape(k.shape)
qkv_out = paddle.concat([q, k, v], axis=-1)
qkv_out = self.qk_norm(qkv_out)
atten_out = self.attn(
qkv=qkv_out,
forward_meta=forward_meta,
@@ -435,6 +425,11 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
("lm_head.linear", "lm_head", None),
("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
]
if self.fd_config.model_config.use_qk_norm:
stacked_params_mapping.append(("qk_norm.q_norm", "q_norm", None))
stacked_params_mapping.append(("qk_norm.k_norm", "k_norm", None))
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
num_experts=self.fd_config.model_config.n_routed_experts,
+14 -31
View File
@@ -33,7 +33,7 @@ from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.layers.normalization import QKRMSNorm, RMSNorm
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
@@ -57,6 +57,10 @@ class Qwen3Attention(nn.Layer):
self.fd_config = fd_config
self.head_dim = fd_config.model_config.head_dim
tp_size = fd_config.parallel_config.tensor_parallel_size
num_kv_heads_replicas = max(1, tp_size // fd_config.model_config.num_key_value_heads)
self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // tp_size
self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim * num_kv_heads_replicas // tp_size
self.qkv_proj = QKVParallelLinear(fd_config, prefix=f"{prefix}.qkv_proj", with_bias=False)
@@ -75,32 +79,21 @@ class Qwen3Attention(nn.Layer):
use_neox_rotary_style=True,
)
self.q_norm = RMSNorm(
self.qk_norm = QKRMSNorm(
fd_config,
hidden_size=self.head_dim,
head_dim=self.head_dim,
q_size=self.q_size,
kv_size=self.kv_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.q_norm",
prefix=prefix,
begin_norm_axis=2,
)
self.k_norm = RMSNorm(
fd_config,
hidden_size=self.head_dim,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.k_norm",
begin_norm_axis=2,
)
tp_size = fd_config.parallel_config.tensor_parallel_size
num_kv_heads_replicas = max(1, tp_size // fd_config.model_config.num_key_value_heads)
self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // tp_size
self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim * num_kv_heads_replicas // tp_size
def load_state_dict(self, state_dict):
""" """
self.qkv_proj.load_state_dict(state_dict)
self.o_proj.load_state_dict(state_dict)
self.q_norm.load_state_dict(state_dict)
self.k_norm.load_state_dict(state_dict)
self.qk_norm.load_state_dict(state_dict)
self.attn.load_state_dict(state_dict)
def forward(
@@ -110,19 +103,7 @@ class Qwen3Attention(nn.Layer):
):
""" """
qkv_out = self.qkv_proj(hidden_states)
# origin_qkv_out = qkv_out
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
q_by_head = q.reshape([*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim])
q_by_head = self.q_norm(q_by_head)[0]
q = q_by_head.reshape(q.shape)
k_by_head = k.reshape([*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim])
k_by_head = self.k_norm(k_by_head)[0]
k = k_by_head.reshape(k.shape)
qkv_out = paddle.concat([q, k, v], axis=-1)
qkv_out = self.qk_norm(qkv_out)
atten_out = self.attn(
qkv=qkv_out,
forward_meta=forward_meta,
@@ -280,6 +261,8 @@ class Qwen3ForCausalLM(ModelForCasualLM):
("up_gate_proj", "up_proj", "up"),
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
("qk_norm.q_norm", "q_norm", None),
("qk_norm.k_norm", "k_norm", None),
]
params_dict = dict(self.named_parameters())
@@ -207,6 +207,8 @@ class Qwen3VLForConditionalGeneration(ModelForCasualLM):
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
("visual", "model.visual", None),
("qk_norm.q_norm", "q_norm", None),
("qk_norm.k_norm", "k_norm", None),
]
params_dict = dict(self.named_parameters())
@@ -227,6 +227,8 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
("visual", "model.visual", None),
("qk_norm.q_norm", "q_norm", None),
("qk_norm.k_norm", "k_norm", None),
]
expert_params_mapping = self.get_expert_mapping() # Not actually used
@@ -358,6 +358,8 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
("up_gate_proj", "up_proj", "up"),
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
("qk_norm.q_norm", "q_norm", None),
("qk_norm.k_norm", "k_norm", None),
]
expert_params_mapping = self.get_expert_mapping()
params_dict = dict(self.named_parameters())