""" # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ from __future__ import annotations import re from functools import partial from typing import Dict import paddle from paddle import nn from paddleformers.transformers import PretrainedModel from paddleformers.utils.log import logger from fastdeploy.config import FDConfig from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.graph_optimization.decorator import ( support_graph_optimization, ) from fastdeploy.model_executor.layers.activation import SiluAndMul from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding from fastdeploy.model_executor.layers.linear import ( MergedColumnParallelLinear, ReplicatedLinear, RowParallelLinear, ) from fastdeploy.model_executor.layers.lm_head import ParallelLMHead from fastdeploy.model_executor.layers.moe.moe import FusedMoE from fastdeploy.model_executor.layers.normalization import RMSNorm from fastdeploy.model_executor.models.model_base import ( ModelCategory, ModelForCasualLM, ModelRegistry, ) from fastdeploy.model_executor.models.qwen3 import Qwen3Attention class Qwen3MoeBlock(nn.Layer): def __init__( self, fd_config: FDConfig, layer_id: int, prefix: str = "", ) -> None: super().__init__() weight_key_map = { "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight", "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", } self.experts = FusedMoE( fd_config, hidden_size=fd_config.model_config.hidden_size, moe_intermediate_size=fd_config.model_config.moe_intermediate_size, num_experts=fd_config.model_config.num_experts, top_k=fd_config.model_config.num_experts_per_tok, layer_idx=layer_id, weight_key_map=weight_key_map, ) self.gate = ReplicatedLinear( fd_config=fd_config, prefix=f"{prefix}.gate", input_size=fd_config.model_config.hidden_size, output_size=fd_config.model_config.num_experts, with_bias=False, skip_quant=True, weight_dtype=("float32" if fd_config.model_config.moe_gate_fp32 else ""), ) def forward(self, x, forward_meta): return self.experts(x, self.gate, forward_meta) def load_state_dict(self, state_dict): """ """ self.gate.load_state_dict(state_dict) self.experts.load_state_dict(state_dict) class Qwen3MLP(nn.Layer): """ """ def __init__( self, fd_config: FDConfig, prefix: str = "", ) -> None: super().__init__() self.up_gate_proj = MergedColumnParallelLinear( fd_config, prefix=f"{prefix}.up_gate_proj", input_size=fd_config.model_config.hidden_size, output_size=fd_config.model_config.intermediate_size * 2, with_bias=False, activation=fd_config.model_config.hidden_act, ) self.down_proj = RowParallelLinear( fd_config, prefix=f"{prefix}.down_proj", input_size=fd_config.model_config.intermediate_size, output_size=fd_config.model_config.hidden_size, with_bias=False, ) self.act_fn = SiluAndMul( fd_config, bias=getattr(self.up_gate_proj, "bias", None), act_method=fd_config.model_config.hidden_act, ) def load_state_dict(self, state_dict): """ """ self.up_gate_proj.load_state_dict(state_dict) self.down_proj.load_state_dict(state_dict) def forward(self, x, forward_meta): """ """ gate_up_out = self.up_gate_proj(x) act_out = self.act_fn(gate_up_out) down_out = self.down_proj(act_out) return down_out class Qwen3DecoderLayer(nn.Layer): """ """ def __init__( self, fd_config: FDConfig, prefix: str = "", ) -> None: super().__init__() layer_id = int(prefix.split(sep=".")[-1]) self.self_attn = Qwen3Attention( fd_config=fd_config, layer_id=layer_id, prefix=f"{prefix}.self_attn", ) mlp_only_layers = ( [] if not hasattr(fd_config.model_config, "mlp_only_layers") else fd_config.model_config.mlp_only_layers ) if (layer_id not in mlp_only_layers) and ( fd_config.model_config.num_experts > 0 and (layer_id + 1) % fd_config.model_config.decoder_sparse_step == 0 ): self.mlp = Qwen3MoeBlock(fd_config, layer_id, prefix=f"{prefix}.mlp") else: self.mlp = Qwen3MLP( fd_config, prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm( fd_config, hidden_size=fd_config.model_config.hidden_size, eps=fd_config.model_config.rms_norm_eps, prefix=f"{prefix}.input_layernorm", layer_id=layer_id, ) self.post_attention_layernorm = RMSNorm( fd_config, hidden_size=fd_config.model_config.hidden_size, eps=fd_config.model_config.rms_norm_eps, prefix=f"{prefix}.post_attention_layernorm", layer_id=layer_id, ) def load_state_dict(self, state_dict): """ """ self.self_attn.load_state_dict(state_dict) self.mlp.load_state_dict(state_dict) self.input_layernorm.load_state_dict(state_dict) self.post_attention_layernorm.load_state_dict(state_dict) def forward( self, forward_meta: ForwardMeta, hidden_states: paddle.Tensor, residual: paddle.Tensor = None, ): """ """ hidden_states, residual = self.input_layernorm( hidden_states, residual_input=residual, forward_meta=forward_meta ) hidden_states = self.self_attn( hidden_states=hidden_states, forward_meta=forward_meta, ) # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states, forward_meta) return hidden_states, residual @support_graph_optimization class Qwen3MoeModel(nn.Layer): """ """ def __init__( self, fd_config: FDConfig = None, ): """ Initializer for the Qwen2Model class. Args: """ super().__init__() self.num_layers = fd_config.model_config.num_hidden_layers fd_config.model_config.pretrained_config.prefix_name = "model" self.embed_tokens = VocabParallelEmbedding( fd_config, num_embeddings=fd_config.model_config.vocab_size, embedding_dim=fd_config.model_config.hidden_size, params_dtype=paddle.get_default_dtype, prefix=(f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens"), ) self.layers = nn.LayerList( [ Qwen3DecoderLayer( fd_config, prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}", ) for i in range(self.num_layers) ] ) self.norm = RMSNorm( fd_config, hidden_size=fd_config.model_config.hidden_size, eps=fd_config.model_config.rms_norm_eps, prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.norm", ) def load_state_dict(self, state_dict): """ Load model parameters from a given state dictionary. Args: state_dict (dict[str, np.ndarray | paddle.Tensor]): A dictionary containing model parameters, where keys are parameter names and values are NumPy arrays or PaddlePaddle tensors. """ self.embed_tokens.load_state_dict(state_dict) self.norm.load_state_dict(state_dict) for i in range(self.num_layers): logger.info(f"Start load layer {i}") self.layers[i].load_state_dict(state_dict) def forward( self, ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta, ): hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta) residual = None for i in range(self.num_layers): hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual) out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0] if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe: out = self.norm.allgather(out, forward_meta.ids_remove_padding.shape[0]) return out @ModelRegistry.register_model_class( architecture="Qwen3MoeForCausalLM", module_name="qwen3moe", category=ModelCategory.TEXT_GENERATION, primary_use=ModelCategory.TEXT_GENERATION, ) class Qwen3MoeForCausalLM(ModelForCasualLM): """ Qwen3MoeForCausalLM """ def __init__(self, fd_config: FDConfig): """ Args: fd_config (FDConfig): Configurations for the LLM model. """ super(Qwen3MoeForCausalLM, self).__init__(fd_config) self.model = Qwen3MoeModel(fd_config) self.ori_vocab_size = fd_config.model_config.ori_vocab_size self.lm_head = ParallelLMHead( fd_config, embedding_dim=fd_config.model_config.hidden_size, num_embeddings=fd_config.model_config.vocab_size, prefix="lm_head", ) @classmethod def name(self): """ """ return "Qwen3MoeForCausalLM" def get_expert_mapping( self, ) -> list[tuple[str, str, int, str]]: # (param_name, weight_name, expert_id, shard_id) return FusedMoE.make_expert_params_mapping( num_experts=self.fd_config.model_config.num_experts, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", param_gate_up_proj_name="experts.up_gate_proj_", param_down_proj_name="experts.down_proj_", ) @paddle.no_grad() def load_weights(self, weights_iterator) -> None: """ Load model parameters from a given weights_iterator object. Args: weights_iterator (Iterator): An iterator yielding (name, weight) pairs. """ from fastdeploy.model_executor.utils import ( default_weight_loader, process_weights_after_loading, ) stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ("up_gate_proj", "gate_proj", "gate"), ("up_gate_proj", "up_proj", "up"), ("embed_tokens.embeddings", "embed_tokens", None), ("lm_head.linear", "lm_head", None), ("qk_norm.q_norm", "q_norm", None), ("qk_norm.k_norm", "k_norm", None), ] expert_params_mapping = self.get_expert_mapping() params_dict = dict(self.named_parameters()) process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config) for loaded_weight_name, loaded_weight in weights_iterator: logger.debug(f"Loading weight: {loaded_weight_name}") for param_name, weight_name, shard_id in stacked_params_mapping: # weight_name是存储在 model.safetensors.index.json 中的key的部分字段! if weight_name not in loaded_weight_name: continue # 专家权重需要特殊处理,先continue! if "mlp.experts" in loaded_weight_name: continue # 这里需要将 weight_name部分替换成param_name! model_param_name = loaded_weight_name.replace(weight_name, param_name) if model_param_name not in params_dict: continue param = params_dict[model_param_name] weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) weight_loader(param, loaded_weight, shard_id) break else: for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in loaded_weight_name: continue model_param_name = loaded_weight_name.replace(weight_name, param_name) if model_param_name not in params_dict: continue param = params_dict[model_param_name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id) break else: model_param_name = loaded_weight_name if model_param_name not in params_dict: continue param = params_dict[model_param_name] weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) weight_loader(param, loaded_weight) # 这个代码的作用是将 param_name 换成 sublayer_name,从而找到layer! model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name) process_weights_after_loading_fn(model_sublayer_name, param) @paddle.no_grad() def set_state_dict(self, state_dict): """ Load model parameters from a given state dictionary. Args: state_dict (dict[str, np.ndarray | paddle.Tensor]): A dictionary containing model parameters, where keys are parameter names and values are NumPy arrays or PaddlePaddle tensors. """ self.model.load_state_dict(state_dict) self.lm_head.load_state_dict(state_dict) def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta = None): """ """ logits = self.lm_head(hidden_states) logits = logits.astype(paddle.float32) logits[:, self.ori_vocab_size :] = -float("inf") return logits def empty_input_forward(self, forward_meta): """ empty_input_forward """ fake_hidden_states = paddle.empty( shape=[0, self.fd_config.model_config.hidden_size], dtype=paddle.get_default_dtype(), ) for i in range( self.fd_config.model_config.moe_layer_start_index, self.fd_config.model_config.num_hidden_layers, ): self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate, forward_meta) def forward( self, inputs: Dict, forward_meta: ForwardMeta, ): ids_remove_padding = inputs["ids_remove_padding"] hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta) return hidden_states def clear_grpah_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" self.model.clear_grpah_opt_backend(fd_config=self.fd_config) class Qwen3MoePretrainedModel(PretrainedModel): """ Qwen3MoePretrainedModel """ config_class = FDConfig def _init_weight(self, layer): """ _init_weight """ return None @classmethod def arch_name(self): return "Qwen3MoeForCausalLM" @classmethod def _get_tensor_parallel_mappings(cls, config, is_split=True): # TODO not support TP split now, next PR will support TP. from paddleformers.transformers.conversion_utils import split_or_merge_func fn = split_or_merge_func( is_split=is_split, tensor_model_parallel_size=config.tensor_model_parallel_size, tensor_parallel_rank=config.tensor_parallel_rank, num_attention_heads=config.num_attention_heads, ) def get_tensor_parallel_split_mappings(num_layers, num_experts): final_actions = {} base_actions = { "lm_head.weight": partial(fn, is_column=True), # Row Linear "embed_tokens.weight": partial(fn, is_column=False), "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), } # Column Linear config.fuse_attention_qkv = False if config.fuse_attention_qkv: base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True) else: base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True) # if we have enough num_key_value_heads to split, then split it. if config.num_key_value_heads % config.tensor_model_parallel_size == 0: base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True) base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True) for key, action in base_actions.items(): if "layers.0." in key: for i in range(num_layers): final_actions[key.replace("layers.0.", f"layers.{i}.")] = action final_actions[key] = action base_actions = { "layers.0.mlp.experts.0.gate_proj.weight": partial(fn, is_column=True), "layers.0.mlp.experts.0.down_proj.weight": partial(fn, is_column=False), "layers.0.mlp.experts.0.up_proj.weight": partial(fn, is_column=True), } for key, action in base_actions.items(): for i in range(num_layers): newkey = key.replace("layers.0.", f"layers.{i}.") for j in range(num_experts): newkey2 = newkey.replace("experts.0.", f"experts.{j}.") final_actions[newkey2] = action return final_actions num_experts = 0 if isinstance(config.num_experts, list): num_experts = sum(config.num_experts) elif isinstance(config.num_experts, int): num_experts = config.num_experts else: raise ValueError(f"Not support type of num_experts [{type(config.num_experts)}]") mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers, num_experts) return mappings