Files
FastDeploy/fastdeploy/model_executor/models/qwen3.py
T
GoldPancake a498720a75 [RL] Add clear_graph_opt_backend for glm4_mtp (#7378)
* add clear_grpah func

* fix spell
2026-04-15 19:44:15 +08:00

411 lines
14 KiB
Python

"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import re
from functools import partial
from typing import Dict
import paddle
from paddle import nn
from paddleformers.transformers import PretrainedModel
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
support_graph_optimization,
)
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.normalization import QKRMSNorm, RMSNorm
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
ModelRegistry,
)
from fastdeploy.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2MLP
from fastdeploy.transformer_utils.config import get_pooling_config
class Qwen3MLP(Qwen2MLP):
    """Feed-forward block for Qwen3.

    Inherits ``Qwen2MLP`` unchanged; the subclass only gives the Qwen3 model
    family its own class name.
    """
class Qwen3Attention(nn.Layer):
    """Self-attention block of a Qwen3 decoder layer.

    The forward pass chains four sub-layers: a fused QKV projection, a
    ``QKRMSNorm`` applied to the projected tensor, the attention backend and
    the output projection.
    """

    def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None:
        """Create the attention sub-layers.

        Args:
            fd_config: Global configuration; the model and parallel settings
                are read from it.
            layer_id: Index of the owning decoder layer, forwarded to the
                sub-layers that accept it.
            prefix: Name prefix from which the sub-layer prefixes are derived.
        """
        super().__init__()
        self.fd_config = fd_config

        model_cfg = fd_config.model_config
        self.head_dim = model_cfg.head_dim
        tp_size = fd_config.parallel_config.tensor_parallel_size

        # Never below 1; exceeds 1 only when tp_size is at least twice the
        # number of KV heads.
        kv_replicas = max(1, tp_size // model_cfg.num_key_value_heads)

        # Per-rank widths of the Q part and of each of the K / V parts.
        self.q_size = model_cfg.num_attention_heads * self.head_dim // tp_size
        self.kv_size = model_cfg.num_key_value_heads * self.head_dim * kv_replicas // tp_size

        # Sub-layers are created in the same order as before so that the
        # registration order inside the layer does not change.
        self.qkv_proj = QKVParallelLinear(
            fd_config,
            prefix=f"{prefix}.qkv_proj",
            with_bias=False,
        )
        self.o_proj = RowParallelLinear(
            fd_config,
            prefix=f"{prefix}.o_proj",
            input_size=model_cfg.head_dim * model_cfg.num_attention_heads,
            output_size=model_cfg.hidden_size,
            layer_id=layer_id,
        )
        self.attn = Attention(
            fd_config,
            layer_id=layer_id,
            prefix=prefix,
            use_neox_rotary_style=True,
        )
        self.qk_norm = QKRMSNorm(
            fd_config,
            head_dim=self.head_dim,
            q_size=self.q_size,
            kv_size=self.kv_size,
            eps=model_cfg.rms_norm_eps,
            prefix=prefix,
            begin_norm_axis=2,
        )

    def load_state_dict(self, state_dict):
        """Pass ``state_dict`` to every sub-layer so each one loads its own weights.

        Args:
            state_dict: Mapping of parameter names to weights; it is handed
                unchanged to each sub-layer's ``load_state_dict``.
        """
        for sublayer in (self.qkv_proj, self.o_proj, self.qk_norm, self.attn):
            sublayer.load_state_dict(state_dict)

    def forward(
        self,
        forward_meta: ForwardMeta,
        hidden_states: paddle.Tensor,
    ):
        """Run attention over ``hidden_states``.

        Args:
            forward_meta: Per-step metadata passed to the norm and attention layers.
            hidden_states: Input activations of the block.

        Returns:
            The result of the output projection applied to the attention output.
        """
        normed_qkv = self.qk_norm(self.qkv_proj(hidden_states), forward_meta)
        attn_result = self.attn(
            qkv=normed_qkv,
            forward_meta=forward_meta,
        )
        return self.o_proj(attn_result)
class Qwen3DecoderLayer(Qwen2DecoderLayer):
    """Qwen2 decoder layer whose ``self_attn`` is a ``Qwen3Attention``."""

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
    ) -> None:
        """Build the layer via the Qwen2 base class, then install Qwen3 attention.

        Args:
            fd_config: Global configuration forwarded to the base class and
                to the attention block.
            prefix: Dotted name of this layer; its last component must be the
                integer layer index (e.g. ``"model.layers.3"``).

        Raises:
            ValueError: If the last component of ``prefix`` is not an integer.
        """
        super().__init__(fd_config, prefix)
        # The layer index is the trailing dotted component of the prefix.
        layer_index = int(prefix.rsplit(".", 1)[-1])
        # Assigned after the base initializer so this attribute ends up
        # holding the Qwen3 attention block.
        self.self_attn = Qwen3Attention(
            fd_config=fd_config,
            layer_id=layer_index,
            prefix=f"{prefix}.self_attn",
        )
@support_graph_optimization
class Qwen3Model(nn.Layer):
    """Qwen3 backbone: token embedding, a stack of decoder layers and a final RMSNorm."""

    def __init__(
        self,
        fd_config: FDConfig = None,
    ):
        """Build the embedding, the decoder stack and the final norm.

        Args:
            fd_config: Global configuration. Although the default is ``None``,
                a real configuration is required because its ``model_config``
                is read immediately.
        """
        super().__init__()

        model_cfg = fd_config.model_config
        self.num_layers = model_cfg.num_hidden_layers

        # The prefix is written onto the shared config and then read back
        # from it for all the names below.
        model_cfg.pretrained_config.prefix_name = "model"
        root = model_cfg.pretrained_config.prefix_name

        self.embed_tokens = VocabParallelEmbedding(
            fd_config=fd_config,
            num_embeddings=model_cfg.vocab_size,
            embedding_dim=model_cfg.hidden_size,
            # NOTE(review): the function object itself is passed here, not the
            # value it returns -- confirm this is what VocabParallelEmbedding expects.
            params_dtype=paddle.get_default_dtype,
            prefix=(f"{root}.embed_tokens"),
        )

        decoder_stack = [
            Qwen3DecoderLayer(fd_config=fd_config, prefix=f"{root}.layers.{idx}")
            for idx in range(self.num_layers)
        ]
        self.layers = nn.LayerList(decoder_stack)

        self.norm = RMSNorm(
            fd_config,
            hidden_size=model_cfg.hidden_size,
            eps=model_cfg.rms_norm_eps,
            prefix=f"{root}.norm",
        )

    def load_state_dict(self, state_dict):
        """Load the embedding, the final norm and every decoder layer.

        Args:
            state_dict (dict[str, np.ndarray | paddle.Tensor]):
                Mapping of parameter names to weights; it is handed unchanged
                to each sub-layer.
        """
        self.embed_tokens.load_state_dict(state_dict)
        self.norm.load_state_dict(state_dict)
        for i, decoder_layer in enumerate(self.layers):
            logger.info(f"Start load layer {i}")
            decoder_layer.load_state_dict(state_dict)

    def forward(
        self,
        ids_remove_padding: paddle.Tensor,
        forward_meta: ForwardMeta,
    ):
        """Embed the token ids and run them through the decoder stack.

        Args:
            ids_remove_padding: Token ids passed to the embedding layer.
            forward_meta: Per-step metadata passed to the embedding and to
                every decoder layer.

        Returns:
            The first element of the final norm's output.
        """
        hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)

        # Each layer returns an updated (hidden_states, residual) pair; the
        # first layer is called with no residual.
        residual = None
        for layer_idx in range(self.num_layers):
            decoder_layer = self.layers[layer_idx]
            hidden_states, residual = decoder_layer(forward_meta, hidden_states, residual)

        normed = self.norm(hidden_states, residual)
        return normed[0]
@ModelRegistry.register_model_class(
    architecture="Qwen3ForCausalLM",
    module_name="qwen3",
    category=[ModelCategory.TEXT_GENERATION],
    primary_use=ModelCategory.TEXT_GENERATION,
)
class Qwen3ForCausalLM(ModelForCasualLM):
    """Qwen3 causal language model: a ``Qwen3Model`` backbone plus an LM head."""

    def __init__(self, fd_config: FDConfig):
        """Build the backbone and the LM head.

        Args:
            fd_config (FDConfig): Configurations for the LLM model.
        """
        super().__init__(fd_config)
        self.fd_config = fd_config
        self.model = Qwen3Model(fd_config=fd_config)
        self.ori_vocab_size = fd_config.model_config.ori_vocab_size
        self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
        self.lm_head = ParallelLMHead(
            fd_config=fd_config,
            embedding_dim=fd_config.model_config.hidden_size,
            num_embeddings=fd_config.model_config.vocab_size,
            prefix="lm_head",
        )

    @classmethod
    def name(cls):
        """Return the architecture name of this model class."""
        return "Qwen3ForCausalLM"

    @paddle.no_grad()
    def load_weights(self, weights_iterator) -> None:
        """Load model parameters from a given weights_iterator object.

        Each checkpoint name is first matched against the stacked-parameter
        table, which renames it to the fused or wrapped parameter used by
        this model and supplies a shard id. Names that match no usable table
        entry are looked up as they are. Weights with no corresponding
        parameter are skipped.

        Args:
            weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
        """
        from fastdeploy.model_executor.utils import (
            default_weight_loader,
            process_weights_after_loading,
        )

        is_pooling_model = getattr(self, "is_pooling_model", False)
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("up_gate_proj", "gate_proj", "gate"),
            ("up_gate_proj", "up_proj", "up"),
            ("embed_tokens.embeddings", "embed_tokens", None),
            ("lm_head.linear", "lm_head", None),
            ("qk_norm.q_norm", "q_norm", None),
            ("qk_norm.k_norm", "k_norm", None),
        ]
        params_dict = dict(self.named_parameters())
        model_path = self.fd_config.model_config.model
        revision = self.fd_config.model_config.revision
        if is_pooling_model and get_pooling_config(model_path, revision):
            # Drop the leading "model." from parameter names in this case so
            # lookups use the un-prefixed names.
            params_dict = {
                param_name[6:] if param_name.startswith("model.") else param_name: param
                for param_name, param in params_dict.items()
            }
        process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config)

        # Built once: a getattr() default is evaluated on every call, so
        # constructing the fallback loader inline would repeat the work for
        # every single weight.
        fallback_loader = default_weight_loader(self.fd_config)

        for loaded_weight_name, loaded_weight in weights_iterator:
            logger.debug(f"Loading weight: {loaded_weight_name}")
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in loaded_weight_name:
                    continue
                model_param_name = loaded_weight_name.replace(weight_name, param_name)
                if model_param_name not in params_dict:
                    continue
                param = params_dict[model_param_name]
                weight_loader = getattr(param, "weight_loader", fallback_loader)
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No table entry applied: use the checkpoint name directly.
                model_param_name = loaded_weight_name
                if model_param_name not in params_dict:
                    continue
                param = params_dict[model_param_name]
                weight_loader = getattr(param, "weight_loader", fallback_loader)
                weight_loader(param, loaded_weight)

            # The post-load hook is keyed by sub-layer name, i.e. the
            # parameter name without a trailing ".weight".
            model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
            process_weights_after_loading_fn(model_sublayer_name, param)

        if self.tie_word_embeddings and not is_pooling_model:
            # Tied embeddings: copy the transposed embedding weight into the LM head.
            self.lm_head.linear.weight.set_value(
                self.model.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype)
            )

    @paddle.no_grad()
    def set_state_dict(self, state_dict):
        """Load model parameters from a given state dictionary.

        With tied word embeddings the LM head is loaded from the embedding
        weight instead of from ``state_dict``.

        Args:
            state_dict (dict[str, np.ndarray | paddle.Tensor]):
                A dictionary containing model parameters, where keys are parameter names
                and values are NumPy arrays or PaddlePaddle tensors.
        """
        self.model.load_state_dict(state_dict)
        if self.tie_word_embeddings:
            self.lm_head.load_state_dict({self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight})
        else:
            self.lm_head.load_state_dict(state_dict)

    def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta = None):
        """Project hidden states to float32 logits.

        Args:
            hidden_states: Activations fed to the LM head.
            forward_meta: Accepted but not used by this method.

        Returns:
            The float32 logits, with every column from index
            ``ori_vocab_size`` onwards set to ``-inf``.
        """
        logits = self.lm_head(hidden_states)
        logits = logits.astype(paddle.float32)
        logits[:, self.ori_vocab_size :] = -float("inf")
        return logits

    def forward(
        self,
        inputs: Dict,
        forward_meta: ForwardMeta,
    ):
        """Run the backbone on ``inputs["ids_remove_padding"]``.

        Args:
            inputs: Dictionary that must contain the key ``"ids_remove_padding"``.
            forward_meta: Per-step metadata forwarded to the backbone.

        Returns:
            The backbone's hidden states; logits come from ``compute_logits``.
        """
        ids_remove_padding = inputs["ids_remove_padding"]
        hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
        return hidden_states

    def clear_graph_opt_backend(self):
        """Clear graph optimization backend, the captured cuda graph will be cleaned"""
        self.model.clear_graph_opt_backend(fd_config=self.fd_config)
class Qwen3PretrainedModel(PretrainedModel):
    """
    Qwen3PretrainedModel

    Supplies the tensor-parallel split/merge mappings for Qwen3 checkpoints.
    """

    config_class = FDConfig

    def _init_weight(self, layer):
        """No-op weight initializer; always returns ``None``."""
        return None

    @classmethod
    def arch_name(cls):
        """Return the architecture name handled by this class."""
        return "Qwen3ForCausalLM"

    @classmethod
    def _get_tensor_parallel_mappings(cls, config, is_split=True):
        """Build the per-weight tensor-parallel actions.

        Args:
            config: Configuration providing ``tensor_model_parallel_size``,
                ``tensor_parallel_rank``, ``num_attention_heads``,
                ``num_key_value_heads`` and ``num_hidden_layers``.
            is_split: Forwarded to ``split_or_merge_func``.

        Returns:
            dict: Weight name -> partially-applied action, with each
            ``layers.0.`` template expanded to every layer index.
        """
        from paddleformers.transformers.conversion_utils import split_or_merge_func

        fn = split_or_merge_func(
            is_split=is_split,
            tensor_model_parallel_size=config.tensor_model_parallel_size,
            tensor_parallel_rank=config.tensor_parallel_rank,
            num_attention_heads=config.num_attention_heads,
        )

        def get_tensor_parallel_split_mappings(num_layers):
            """Expand the ``layers.0.`` templates for ``num_layers`` layers."""
            final_actions = {}
            base_actions = {
                # is_column=True
                "lm_head.weight": partial(fn, is_column=True),
                # is_column=False
                "embed_tokens.weight": partial(fn, is_column=False),
                "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
                "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
            }
            # The remaining entries all use is_column=True.
            base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
            # K/V get an action only when the KV head count is divisible by
            # the tensor-parallel size.
            if config.num_key_value_heads % config.tensor_model_parallel_size == 0:
                base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
                base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)

            for key, action in base_actions.items():
                if "layers.0." in key:
                    for i in range(num_layers):
                        final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
                # Always record the template key itself as well; this is what
                # registers the non-layer entries.
                final_actions[key] = action
            return final_actions

        return get_tensor_parallel_split_mappings(config.num_hidden_layers)