mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Optimization] Accelerate Qwen3 QK RMSNorm via Fused Triton Kernel (#5880)
* qk rmsnorm fused
* inplace
* glm
* fix
* add qknorm layer
* fix
* update
* fix qwen3 vl
* update rl baseline
* fix qwen3 vl moe
* test
* fix qwen vl moe rl
* fix
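Before this change, both GLM4-MoE and Qwen3 normalized Q and K by splitting the packed QKV projection output, reshaping per head, applying RMSNorm twice, and concatenating back. The new QKRMSNorm layer folds all of that into a single fused Triton kernel that works in place on the packed tensor. As a minimal unfused reference for what the layer computes — assuming only what the hunks below show (the q_size/kv_size split and per-head RMSNorm); the function name and weight arguments are illustrative, not FastDeploy's API:

    import paddle

    def qk_rms_norm_reference(qkv, q_size, kv_size, head_dim, q_weight, k_weight, eps=1e-6):
        # Split the packed projection output into query, key, value slices.
        q, k, v = qkv.split([q_size, kv_size, kv_size], axis=-1)

        def per_head_rms_norm(x, weight):
            # View the last axis as [num_heads, head_dim] and normalize each head.
            xh = x.reshape([*x.shape[:-1], x.shape[-1] // head_dim, head_dim])
            var = (xh.astype("float32") ** 2).mean(axis=-1, keepdim=True)
            xh = (xh.astype("float32") * paddle.rsqrt(var + eps)).astype(x.dtype)
            return (xh * weight).reshape(x.shape)

        q = per_head_rms_norm(q, q_weight)
        k = per_head_rms_norm(k, k_weight)
        # V passes through untouched; only Q and K are normalized.
        return paddle.concat([q, k, v], axis=-1)

The fused kernel avoids the split/reshape/concat round trips and their intermediate allocations, which is where the speedup comes from.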
@@ -41,7 +41,7 @@ from fastdeploy.model_executor.layers.linear import (
 )
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.moe.moe import FusedMoE
-from fastdeploy.model_executor.layers.normalization import RMSNorm
+from fastdeploy.model_executor.layers.normalization import QKRMSNorm, RMSNorm
 from fastdeploy.model_executor.models.model_base import (
     ModelCategory,
     ModelForCasualLM,
@@ -205,18 +205,13 @@ class Glm4MoeAttention(nn.Layer):
             rms_norm_eps=fd_config.model_config.rms_norm_eps,
         )
         if self.use_qk_norm:
-            self.q_norm = RMSNorm(
+            self.qk_norm = QKRMSNorm(
                 fd_config,
-                hidden_size=self.head_dim,
+                head_dim=self.head_dim,
+                q_size=self.q_size,
+                kv_size=self.kv_size,
                 eps=fd_config.model_config.rms_norm_eps,
-                prefix=f"{prefix}.q_norm",
-                begin_norm_axis=2,
-            )
-            self.k_norm = RMSNorm(
-                fd_config,
-                hidden_size=self.head_dim,
-                eps=fd_config.model_config.rms_norm_eps,
-                prefix=f"{prefix}.k_norm",
+                prefix=prefix,
                 begin_norm_axis=2,
             )
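For reference, the per-head operation that both the old pair of layers and the new fused layer apply is standard RMSNorm over the head_dim axis, with gamma the learned q_norm/k_norm weight and epsilon equal to rms_norm_eps:

    \mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \cdot \gamma, \qquad d = \texttt{head\_dim}

The fusion changes where and how this is computed, not what is computed.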
@@ -227,13 +222,8 @@ class Glm4MoeAttention(nn.Layer):
     ):
         """ """
         qkv_out = self.qkv_proj(hidden_states)

         if self.use_qk_norm:
-            q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
-            q = self.q_norm(q.reshape([-1, self.num_heads, self.head_dim]))[0].reshape(q.shape)
-            k = self.k_norm(k.reshape([-1, self.num_kv_heads, self.head_dim]))[0].reshape(k.shape)
-            qkv_out = paddle.concat([q, k, v], axis=-1)
-
+            qkv_out = self.qk_norm(qkv_out)
         atten_out = self.attn(
             qkv=qkv_out,
             forward_meta=forward_meta,
@@ -435,6 +425,11 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
             ("lm_head.linear", "lm_head", None),
             ("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
         ]
+
+        if self.fd_config.model_config.use_qk_norm:
+            stacked_params_mapping.append(("qk_norm.q_norm", "q_norm", None))
+            stacked_params_mapping.append(("qk_norm.k_norm", "k_norm", None))
+
         # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             num_experts=self.fd_config.model_config.n_routed_experts,
@@ -33,7 +33,7 @@ from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
 from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
-from fastdeploy.model_executor.layers.normalization import RMSNorm
+from fastdeploy.model_executor.layers.normalization import QKRMSNorm, RMSNorm
 from fastdeploy.model_executor.models.model_base import (
     ModelCategory,
     ModelForCasualLM,
@@ -57,6 +57,10 @@ class Qwen3Attention(nn.Layer):

         self.fd_config = fd_config
         self.head_dim = fd_config.model_config.head_dim
+        tp_size = fd_config.parallel_config.tensor_parallel_size
+        num_kv_heads_replicas = max(1, tp_size // fd_config.model_config.num_key_value_heads)
+        self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // tp_size
+        self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim * num_kv_heads_replicas // tp_size

         self.qkv_proj = QKVParallelLinear(fd_config, prefix=f"{prefix}.qkv_proj", with_bias=False)
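These size computations move up in __init__ because QKRMSNorm now needs q_size and kv_size at construction time. To make the arithmetic concrete, a worked example with hypothetical Qwen3-like values (illustrative numbers, not read from any config): num_attention_heads = 32, num_key_value_heads = 8, head_dim = 128, tensor_parallel_size = 4:

    num_kv_heads_replicas = max(1, 4 // 8) = max(1, 0) = 1
    q_size  = 32 * 128 // 4     = 1024   # query columns owned by this rank
    kv_size = 8 * 128 * 1 // 4  = 256    # key (and value) columns per rank

so each rank's qkv_proj output has width q_size + 2 * kv_size = 1536, which is exactly the packed tensor QKRMSNorm must split into Q, K, and V slices.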
@@ -75,32 +79,21 @@ class Qwen3Attention(nn.Layer):
             use_neox_rotary_style=True,
         )

-        self.q_norm = RMSNorm(
+        self.qk_norm = QKRMSNorm(
             fd_config,
-            hidden_size=self.head_dim,
+            head_dim=self.head_dim,
+            q_size=self.q_size,
+            kv_size=self.kv_size,
             eps=fd_config.model_config.rms_norm_eps,
-            prefix=f"{prefix}.q_norm",
+            prefix=prefix,
             begin_norm_axis=2,
         )
-        self.k_norm = RMSNorm(
-            fd_config,
-            hidden_size=self.head_dim,
-            eps=fd_config.model_config.rms_norm_eps,
-            prefix=f"{prefix}.k_norm",
-            begin_norm_axis=2,
-        )
-
-        tp_size = fd_config.parallel_config.tensor_parallel_size
-        num_kv_heads_replicas = max(1, tp_size // fd_config.model_config.num_key_value_heads)
-        self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // tp_size
-        self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim * num_kv_heads_replicas // tp_size

     def load_state_dict(self, state_dict):
         """ """
         self.qkv_proj.load_state_dict(state_dict)
         self.o_proj.load_state_dict(state_dict)
-        self.q_norm.load_state_dict(state_dict)
-        self.k_norm.load_state_dict(state_dict)
+        self.qk_norm.load_state_dict(state_dict)
         self.attn.load_state_dict(state_dict)

     def forward(
@@ -110,19 +103,7 @@ class Qwen3Attention(nn.Layer):
     ):
         """ """
         qkv_out = self.qkv_proj(hidden_states)
-        # origin_qkv_out = qkv_out
-        q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
-
-        q_by_head = q.reshape([*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim])
-        q_by_head = self.q_norm(q_by_head)[0]
-        q = q_by_head.reshape(q.shape)
-
-        k_by_head = k.reshape([*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim])
-        k_by_head = self.k_norm(k_by_head)[0]
-        k = k_by_head.reshape(k.shape)
-
-        qkv_out = paddle.concat([q, k, v], axis=-1)
-
+        qkv_out = self.qk_norm(qkv_out)
         atten_out = self.attn(
             qkv=qkv_out,
             forward_meta=forward_meta,
@@ -280,6 +261,8 @@ class Qwen3ForCausalLM(ModelForCasualLM):
             ("up_gate_proj", "up_proj", "up"),
             ("embed_tokens.embeddings", "embed_tokens", None),
             ("lm_head.linear", "lm_head", None),
+            ("qk_norm.q_norm", "q_norm", None),
+            ("qk_norm.k_norm", "k_norm", None),
         ]

         params_dict = dict(self.named_parameters())
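The stacked_params_mapping entries added here (and in the remaining models below) translate checkpoint weight names into the fused layer's parameter names at load time. A simplified sketch of how one (param_name, weight_name, shard_id) triple rewrites a key — the rename-by-substring logic is the usual pattern for these mappings, assumed here rather than copied from FastDeploy's loader, and the checkpoint key is hypothetical:

    param_name, weight_name, shard_id = ("qk_norm.q_norm", "q_norm", None)
    ckpt_key = "model.layers.0.self_attn.q_norm.weight"   # hypothetical key
    model_key = ckpt_key.replace(weight_name, param_name)
    # -> "model.layers.0.self_attn.qk_norm.q_norm.weight"

This is why no checkpoint changes are needed: q_norm and k_norm weights keep their original names on disk and are routed into the fused layer's sub-parameters.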
@@ -207,6 +207,8 @@ class Qwen3VLForConditionalGeneration(ModelForCasualLM):
             ("embed_tokens.embeddings", "embed_tokens", None),
             ("lm_head.linear", "lm_head", None),
             ("visual", "model.visual", None),
+            ("qk_norm.q_norm", "q_norm", None),
+            ("qk_norm.k_norm", "k_norm", None),
         ]

         params_dict = dict(self.named_parameters())
@@ -227,6 +227,8 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
             ("embed_tokens.embeddings", "embed_tokens", None),
             ("lm_head.linear", "lm_head", None),
             ("visual", "model.visual", None),
+            ("qk_norm.q_norm", "q_norm", None),
+            ("qk_norm.k_norm", "k_norm", None),
         ]

         expert_params_mapping = self.get_expert_mapping()  # Not actually used
@@ -358,6 +358,8 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
             ("up_gate_proj", "up_proj", "up"),
             ("embed_tokens.embeddings", "embed_tokens", None),
             ("lm_head.linear", "lm_head", None),
+            ("qk_norm.q_norm", "q_norm", None),
+            ("qk_norm.k_norm", "k_norm", None),
         ]
         expert_params_mapping = self.get_expert_mapping()
         params_dict = dict(self.named_parameters())