diff --git a/fastdeploy/model_executor/models/glm4_mtp.py b/fastdeploy/model_executor/models/glm4_mtp.py
index d16632c2b4..c28023202d 100644
--- a/fastdeploy/model_executor/models/glm4_mtp.py
+++ b/fastdeploy/model_executor/models/glm4_mtp.py
@@ -28,8 +28,6 @@ from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.graph_optimization.decorator import (
     support_graph_optimization,
 )
-from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
-from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection
 from fastdeploy.model_executor.layers.normalization import RMSNorm
 from fastdeploy.model_executor.models.glm4_moe import Glm4MoeDecoderLayer
@@ -119,12 +117,8 @@ class SharedHead(nn.Module):
             eps=fd_config.model_config.rms_norm_eps,
             prefix=f"{prefix}.shared_head.norm",
         )
-        self.head = ParallelLMHead(
-            fd_config,
-            embedding_dim=fd_config.model_config.hidden_size,
-            num_embeddings=fd_config.model_config.vocab_size,
-            prefix=f"{prefix}.shared_head.head",
-        )
+        if fd_config.speculative_config.sharing_model is not None:
+            self.head = fd_config.speculative_config.sharing_model.lm_head

     def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         # NOTE(wangyanpeng04): Just for compute logits
@@ -216,15 +210,8 @@ class Glm4MTPModel(nn.Layer):

         assert self.num_mtp_layers == 1, f"Currently only supports single MTP layer, but got {self.num_mtp_layers}"

-        self.embed_tokens = VocabParallelEmbedding(
-            fd_config=fd_config,
-            num_embeddings=fd_config.model_config.vocab_size,
-            embedding_dim=fd_config.model_config.hidden_size,
-            params_dtype=paddle.get_default_dtype(),
-            prefix=(
-                f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{self.mtp_start_layer_idx}.embed_tokens"
-            ),
-        )
+        if fd_config.speculative_config.sharing_model is not None:
+            self.embed_tokens = fd_config.speculative_config.sharing_model.model.embed_tokens

         self.layers = nn.LayerDict(
             {
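Note: the glm4_mtp.py change above stops allocating a private VocabParallelEmbedding and ParallelLMHead for the draft model and instead borrows the target model's modules through speculative_config.sharing_model, so the MTP embedding and head share weights with the main model rather than holding duplicates. A minimal sketch of that sharing pattern, with hypothetical stand-in classes (the real code reaches the embedding via sharing_model.model.embed_tokens):

import paddle.nn as nn

class TargetModel(nn.Layer):
    # Stand-in for the target Glm4Moe model that owns the real weights.
    def __init__(self, vocab_size=128, hidden_size=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)

class MTPDraft(nn.Layer):
    # Stand-in for Glm4MTPModel/SharedHead: borrows modules by reference.
    def __init__(self, sharing_model):
        super().__init__()
        if sharing_model is not None:
            # Same objects, not copies: one parameter update serves both.
            self.embed_tokens = sharing_model.embed_tokens
            self.head = sharing_model.lm_head

target = TargetModel()
draft = MTPDraft(target)
assert draft.head is target.lm_head  # shared, not duplicated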
""" super(Glm4MoeForCausalLMRL, self).__init__(fd_config) + self.num_nextn_predict_layers = fd_config.model_config.num_nextn_predict_layers + + if self.num_nextn_predict_layers > 0: + fd_config.parallel_config.tp_group = None + fd_config.parallel_config.ep_group = None + self.mtp_fd_config = copy.deepcopy(fd_config) + fd_config.parallel_config.tp_group = dist.get_group( + fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET + ) + fd_config.parallel_config.ep_group = dist.get_group( + fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET + ) + self.fd_config.parallel_config.tp_group = dist.get_group( + fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET + ) + self.fd_config.parallel_config.ep_group = dist.get_group( + fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET + ) + self.update_mtp_config(self.mtp_fd_config) + self.mtp_layers = Glm4MTPForCausalLMRL(self.mtp_fd_config) @classmethod def name(self) -> str: """name""" return "Glm4MoeForCausalLMRL" + def update_mtp_config(self, mtp_fd_config): + mtp_fd_config.model_config.architectures[0] = mtp_fd_config.model_config.architectures[0].replace("Moe", "MTP") + mtp_fd_config.speculative_config.sharing_model = None + mtp_fd_config.model_config.start_layer_index = mtp_fd_config.model_config.num_hidden_layers + mtp_fd_config.model_config.num_hidden_layers = 1 + mtp_fd_config.model_config.model = mtp_fd_config.speculative_config.model + if mtp_fd_config.speculative_config.quantization != "": + mtp_fd_config.model_config.quantization = mtp_fd_config.speculative_config.quantization + mtp_fd_config.speculative_config.model_type = "mtp" + def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: """Generate mapping between inference and training parameter for RL(donot delete!).""" if self._mappings_built: @@ -757,9 +793,106 @@ class Glm4MoeForCausalLMRL(Glm4MoeForCausalLM, BaseRLModel): _add_layer_mappings(layer_idx) self._complete_missing_mappings() + + # extra for mtp + if self.num_nextn_predict_layers > 0: + mtp_infer_to_train_mapping = self.mtp_layers.get_name_mappings_to_training(trainer_degree) + self.infer_to_train_mapping.update(mtp_infer_to_train_mapping) + infer_to_train_mapping_copy = copy.deepcopy(self.infer_to_train_mapping) for key in infer_to_train_mapping_copy.keys(): if "mlp.experts.gate_correction_bias" in key: self.infer_to_train_mapping.pop(key) return self.infer_to_train_mapping + + +class Glm4MTPForCausalLMRL(Glm4MTPForCausalLM, BaseRLModel): + """ + Glm4MTPForCausalLMRL + """ + + _get_tensor_parallel_mappings = Glm4MTPPretrainedModel._get_tensor_parallel_mappings + + def __init__(self, fd_config: FDConfig): + """ + Args: + fd_config (FDConfig): Configurations for the LLM model. 
+ """ + super(Glm4MTPForCausalLMRL, self).__init__(fd_config) + + @classmethod + def name(self) -> str: + """name""" + return "Glm4MTPForCausalLMRL" + + def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: + """Generate mapping between inference and training parameter for RL(donot delete!).""" + if self._mappings_built: + return self.infer_to_train_mapping + + self.infer_to_train_mapping = {} + self._mappings_built = True + + # Prepare placeholders + place_holders = ["weight"] + + base_name = "model.layers" + + # Helper function to add layer mappings + def _add_layer_mappings(layer_idx: int): + # MTP specific mappings + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.shared_head.head.weight"] = ( + f"{base_name}.{layer_idx}.shared_head.head.weight" + ) + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.shared_head.norm.weight"] = ( + f"{base_name}.{layer_idx}.shared_head.norm.weight" + ) + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.eh_proj.weight"] = ( + f"{base_name}.{layer_idx}.eh_proj.weight" + ) + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.enorm.weight"] = ( + f"{base_name}.{layer_idx}.enorm.weight" + ) + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.hnorm.weight"] = ( + f"{base_name}.{layer_idx}.hnorm.weight" + ) + + # MoE specific mappings + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.weight"] = ( + f"{base_name}.{layer_idx}.mlp.gate.weight" + ) + + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.e_score_correction_bias"] = ( + f"{base_name}.{layer_idx}.mlp.gate.e_score_correction_bias" + ) + + # MoE experts mappings + for expert_idx in range(self.fd_config.model_config.n_routed_experts): + for ph in place_holders: + # up_gate_proj (up_gate_proj) + up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.experts.up_gate_proj_weight" + if up_gate_proj_key not in self.infer_to_train_mapping: + self.infer_to_train_mapping[up_gate_proj_key] = [] + self.infer_to_train_mapping[up_gate_proj_key].append( + f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}" + ) + + # down_proj (down_proj) + down_proj_key = f"{base_name}.{layer_idx}.mlp.experts.down_proj_weight" + if down_proj_key not in self.infer_to_train_mapping: + self.infer_to_train_mapping[down_proj_key] = [] + self.infer_to_train_mapping[down_proj_key].append( + f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}" + ) + + # Process MoE layers + for layer_idx in range( + self.fd_config.model_config.start_layer_index, + self.fd_config.model_config.start_layer_index + self.fd_config.model_config.num_nextn_predict_layers, + ): + _add_layer_mappings(layer_idx) + + self._complete_missing_mappings() + + return self.infer_to_train_mapping diff --git a/tests/ci_use/GLM-45-AIR/baseline.txt b/tests/ci_use/GLM-45-AIR/baseline.txt index bddb29fdac..4ebb05f0ce 100644 --- a/tests/ci_use/GLM-45-AIR/baseline.txt +++ b/tests/ci_use/GLM-45-AIR/baseline.txt @@ -2,12 +2,26 @@ lm_head.linear.weight lm_head.linear.weight:lm_head.weight model.embed_tokens.embeddings.weight model.embed_tokens.embeddings.weight:model.embed_tokens.weight +model.layers.0.eh_proj.linear.weight:model.layers.0.eh_proj.linear.weight +model.layers.0.enorm.weight:model.layers.0.enorm.weight +model.layers.0.hnorm.weight:model.layers.0.hnorm.weight model.layers.0.input_layernorm.weight model.layers.0.input_layernorm.weight:model.layers.0.input_layernorm.weight model.layers.0.mlp.down_proj.weight 
diff --git a/tests/ci_use/GLM-45-AIR/baseline.txt b/tests/ci_use/GLM-45-AIR/baseline.txt
index bddb29fdac..4ebb05f0ce 100644
--- a/tests/ci_use/GLM-45-AIR/baseline.txt
+++ b/tests/ci_use/GLM-45-AIR/baseline.txt
@@ -2,12 +2,26 @@ lm_head.linear.weight
 lm_head.linear.weight:lm_head.weight
 model.embed_tokens.embeddings.weight
 model.embed_tokens.embeddings.weight:model.embed_tokens.weight
+model.layers.0.eh_proj.linear.weight:model.layers.0.eh_proj.linear.weight
+model.layers.0.enorm.weight:model.layers.0.enorm.weight
+model.layers.0.hnorm.weight:model.layers.0.hnorm.weight
 model.layers.0.input_layernorm.weight
 model.layers.0.input_layernorm.weight:model.layers.0.input_layernorm.weight
 model.layers.0.mlp.down_proj.weight
 model.layers.0.mlp.down_proj.weight:model.layers.0.mlp.down_proj.weight
 model.layers.0.mlp.up_gate_proj.weight
 model.layers.0.mlp.up_gate_proj.weight:model.layers.0.mlp.up_gate_proj.weight
+model.layers.0.mtp_block.input_layernorm.weight:model.layers.0.mtp_block.input_layernorm.weight
+model.layers.0.mtp_block.mlp.experts.down_proj_weight:model.layers.0.mtp_block.mlp.experts.down_proj_weight
+model.layers.0.mtp_block.mlp.experts.up_gate_proj_weight:model.layers.0.mtp_block.mlp.experts.up_gate_proj_weight
+model.layers.0.mtp_block.mlp.gate.e_score_correction_bias:model.layers.0.mtp_block.mlp.gate.e_score_correction_bias
+model.layers.0.mtp_block.mlp.gate.weight:model.layers.0.mtp_block.mlp.gate.weight
+model.layers.0.mtp_block.mlp.shared_experts.down_proj.weight:model.layers.0.mtp_block.mlp.shared_experts.down_proj.weight
+model.layers.0.mtp_block.mlp.shared_experts.up_gate_proj.weight:model.layers.0.mtp_block.mlp.shared_experts.up_gate_proj.weight
+model.layers.0.mtp_block.post_attention_layernorm.weight:model.layers.0.mtp_block.post_attention_layernorm.weight
+model.layers.0.mtp_block.self_attn.o_proj.weight:model.layers.0.mtp_block.self_attn.o_proj.weight
+model.layers.0.mtp_block.self_attn.qkv_proj.bias:model.layers.0.mtp_block.self_attn.qkv_proj.bias
+model.layers.0.mtp_block.self_attn.qkv_proj.weight:model.layers.0.mtp_block.self_attn.qkv_proj.weight
 model.layers.0.post_attention_layernorm.weight
 model.layers.0.post_attention_layernorm.weight:model.layers.0.post_attention_layernorm.weight
 model.layers.0.self_attn.o_proj.weight
@@ -16,6 +30,7 @@ model.layers.0.self_attn.qkv_proj.bias
 model.layers.0.self_attn.qkv_proj.bias:model.layers.0.self_attn.qkv_proj.bias
 model.layers.0.self_attn.qkv_proj.weight
 model.layers.0.self_attn.qkv_proj.weight:model.layers.0.self_attn.qkv_proj.weight
+model.layers.0.shared_head.norm.weight:model.layers.0.shared_head.norm.weight
 model.layers.1.input_layernorm.weight
 model.layers.1.input_layernorm.weight:model.layers.1.input_layernorm.weight
 model.layers.1.mlp.experts.down_proj_weight
@@ -39,5 +54,45 @@ model.layers.1.self_attn.qkv_proj.bias
 model.layers.1.self_attn.qkv_proj.bias:model.layers.1.self_attn.qkv_proj.bias
 model.layers.1.self_attn.qkv_proj.weight
 model.layers.1.self_attn.qkv_proj.weight:model.layers.1.self_attn.qkv_proj.weight
+model.layers.2.eh_proj.weight:model.layers.2.eh_proj.weight
+model.layers.2.enorm.weight:model.layers.2.enorm.weight
+model.layers.2.hnorm.weight:model.layers.2.hnorm.weight
+model.layers.2.mlp.experts.down_proj_weight:['model.layers.2.mlp.experts.0.down_proj.weight', 'model.layers.2.mlp.experts.1.down_proj.weight', 'model.layers.2.mlp.experts.2.down_proj.weight', 'model.layers.2.mlp.experts.3.down_proj.weight', 'model.layers.2.mlp.experts.4.down_proj.weight', 'model.layers.2.mlp.experts.5.down_proj.weight', 'model.layers.2.mlp.experts.6.down_proj.weight', 'model.layers.2.mlp.experts.7.down_proj.weight', 'model.layers.2.mlp.experts.8.down_proj.weight', 'model.layers.2.mlp.experts.9.down_proj.weight', 'model.layers.2.mlp.experts.10.down_proj.weight', 'model.layers.2.mlp.experts.11.down_proj.weight', 'model.layers.2.mlp.experts.12.down_proj.weight', 'model.layers.2.mlp.experts.13.down_proj.weight', 'model.layers.2.mlp.experts.14.down_proj.weight', 'model.layers.2.mlp.experts.15.down_proj.weight', 'model.layers.2.mlp.experts.16.down_proj.weight', 'model.layers.2.mlp.experts.17.down_proj.weight', 'model.layers.2.mlp.experts.18.down_proj.weight', 'model.layers.2.mlp.experts.19.down_proj.weight',
'model.layers.2.mlp.experts.20.down_proj.weight', 'model.layers.2.mlp.experts.21.down_proj.weight', 'model.layers.2.mlp.experts.22.down_proj.weight', 'model.layers.2.mlp.experts.23.down_proj.weight', 'model.layers.2.mlp.experts.24.down_proj.weight', 'model.layers.2.mlp.experts.25.down_proj.weight', 'model.layers.2.mlp.experts.26.down_proj.weight', 'model.layers.2.mlp.experts.27.down_proj.weight', 'model.layers.2.mlp.experts.28.down_proj.weight', 'model.layers.2.mlp.experts.29.down_proj.weight', 'model.layers.2.mlp.experts.30.down_proj.weight', 'model.layers.2.mlp.experts.31.down_proj.weight', 'model.layers.2.mlp.experts.32.down_proj.weight', 'model.layers.2.mlp.experts.33.down_proj.weight', 'model.layers.2.mlp.experts.34.down_proj.weight', 'model.layers.2.mlp.experts.35.down_proj.weight', 'model.layers.2.mlp.experts.36.down_proj.weight', 'model.layers.2.mlp.experts.37.down_proj.weight', 'model.layers.2.mlp.experts.38.down_proj.weight', 'model.layers.2.mlp.experts.39.down_proj.weight', 'model.layers.2.mlp.experts.40.down_proj.weight', 'model.layers.2.mlp.experts.41.down_proj.weight', 'model.layers.2.mlp.experts.42.down_proj.weight', 'model.layers.2.mlp.experts.43.down_proj.weight', 'model.layers.2.mlp.experts.44.down_proj.weight', 'model.layers.2.mlp.experts.45.down_proj.weight', 'model.layers.2.mlp.experts.46.down_proj.weight', 'model.layers.2.mlp.experts.47.down_proj.weight', 'model.layers.2.mlp.experts.48.down_proj.weight', 'model.layers.2.mlp.experts.49.down_proj.weight', 'model.layers.2.mlp.experts.50.down_proj.weight', 'model.layers.2.mlp.experts.51.down_proj.weight', 'model.layers.2.mlp.experts.52.down_proj.weight', 'model.layers.2.mlp.experts.53.down_proj.weight', 'model.layers.2.mlp.experts.54.down_proj.weight', 'model.layers.2.mlp.experts.55.down_proj.weight', 'model.layers.2.mlp.experts.56.down_proj.weight', 'model.layers.2.mlp.experts.57.down_proj.weight', 'model.layers.2.mlp.experts.58.down_proj.weight', 'model.layers.2.mlp.experts.59.down_proj.weight', 'model.layers.2.mlp.experts.60.down_proj.weight', 'model.layers.2.mlp.experts.61.down_proj.weight', 'model.layers.2.mlp.experts.62.down_proj.weight', 'model.layers.2.mlp.experts.63.down_proj.weight', 'model.layers.2.mlp.experts.64.down_proj.weight', 'model.layers.2.mlp.experts.65.down_proj.weight', 'model.layers.2.mlp.experts.66.down_proj.weight', 'model.layers.2.mlp.experts.67.down_proj.weight', 'model.layers.2.mlp.experts.68.down_proj.weight', 'model.layers.2.mlp.experts.69.down_proj.weight', 'model.layers.2.mlp.experts.70.down_proj.weight', 'model.layers.2.mlp.experts.71.down_proj.weight', 'model.layers.2.mlp.experts.72.down_proj.weight', 'model.layers.2.mlp.experts.73.down_proj.weight', 'model.layers.2.mlp.experts.74.down_proj.weight', 'model.layers.2.mlp.experts.75.down_proj.weight', 'model.layers.2.mlp.experts.76.down_proj.weight', 'model.layers.2.mlp.experts.77.down_proj.weight', 'model.layers.2.mlp.experts.78.down_proj.weight', 'model.layers.2.mlp.experts.79.down_proj.weight', 'model.layers.2.mlp.experts.80.down_proj.weight', 'model.layers.2.mlp.experts.81.down_proj.weight', 'model.layers.2.mlp.experts.82.down_proj.weight', 'model.layers.2.mlp.experts.83.down_proj.weight', 'model.layers.2.mlp.experts.84.down_proj.weight', 'model.layers.2.mlp.experts.85.down_proj.weight', 'model.layers.2.mlp.experts.86.down_proj.weight', 'model.layers.2.mlp.experts.87.down_proj.weight', 'model.layers.2.mlp.experts.88.down_proj.weight', 'model.layers.2.mlp.experts.89.down_proj.weight', 'model.layers.2.mlp.experts.90.down_proj.weight', 
'model.layers.2.mlp.experts.91.down_proj.weight', 'model.layers.2.mlp.experts.92.down_proj.weight', 'model.layers.2.mlp.experts.93.down_proj.weight', 'model.layers.2.mlp.experts.94.down_proj.weight', 'model.layers.2.mlp.experts.95.down_proj.weight', 'model.layers.2.mlp.experts.96.down_proj.weight', 'model.layers.2.mlp.experts.97.down_proj.weight', 'model.layers.2.mlp.experts.98.down_proj.weight', 'model.layers.2.mlp.experts.99.down_proj.weight', 'model.layers.2.mlp.experts.100.down_proj.weight', 'model.layers.2.mlp.experts.101.down_proj.weight', 'model.layers.2.mlp.experts.102.down_proj.weight', 'model.layers.2.mlp.experts.103.down_proj.weight', 'model.layers.2.mlp.experts.104.down_proj.weight', 'model.layers.2.mlp.experts.105.down_proj.weight', 'model.layers.2.mlp.experts.106.down_proj.weight', 'model.layers.2.mlp.experts.107.down_proj.weight', 'model.layers.2.mlp.experts.108.down_proj.weight', 'model.layers.2.mlp.experts.109.down_proj.weight', 'model.layers.2.mlp.experts.110.down_proj.weight', 'model.layers.2.mlp.experts.111.down_proj.weight', 'model.layers.2.mlp.experts.112.down_proj.weight', 'model.layers.2.mlp.experts.113.down_proj.weight', 'model.layers.2.mlp.experts.114.down_proj.weight', 'model.layers.2.mlp.experts.115.down_proj.weight', 'model.layers.2.mlp.experts.116.down_proj.weight', 'model.layers.2.mlp.experts.117.down_proj.weight', 'model.layers.2.mlp.experts.118.down_proj.weight', 'model.layers.2.mlp.experts.119.down_proj.weight', 'model.layers.2.mlp.experts.120.down_proj.weight', 'model.layers.2.mlp.experts.121.down_proj.weight', 'model.layers.2.mlp.experts.122.down_proj.weight', 'model.layers.2.mlp.experts.123.down_proj.weight', 'model.layers.2.mlp.experts.124.down_proj.weight', 'model.layers.2.mlp.experts.125.down_proj.weight', 'model.layers.2.mlp.experts.126.down_proj.weight', 'model.layers.2.mlp.experts.127.down_proj.weight'] +model.layers.2.mlp.experts.up_gate_proj_weight:['model.layers.2.mlp.experts.0.up_gate_proj.weight', 'model.layers.2.mlp.experts.1.up_gate_proj.weight', 'model.layers.2.mlp.experts.2.up_gate_proj.weight', 'model.layers.2.mlp.experts.3.up_gate_proj.weight', 'model.layers.2.mlp.experts.4.up_gate_proj.weight', 'model.layers.2.mlp.experts.5.up_gate_proj.weight', 'model.layers.2.mlp.experts.6.up_gate_proj.weight', 'model.layers.2.mlp.experts.7.up_gate_proj.weight', 'model.layers.2.mlp.experts.8.up_gate_proj.weight', 'model.layers.2.mlp.experts.9.up_gate_proj.weight', 'model.layers.2.mlp.experts.10.up_gate_proj.weight', 'model.layers.2.mlp.experts.11.up_gate_proj.weight', 'model.layers.2.mlp.experts.12.up_gate_proj.weight', 'model.layers.2.mlp.experts.13.up_gate_proj.weight', 'model.layers.2.mlp.experts.14.up_gate_proj.weight', 'model.layers.2.mlp.experts.15.up_gate_proj.weight', 'model.layers.2.mlp.experts.16.up_gate_proj.weight', 'model.layers.2.mlp.experts.17.up_gate_proj.weight', 'model.layers.2.mlp.experts.18.up_gate_proj.weight', 'model.layers.2.mlp.experts.19.up_gate_proj.weight', 'model.layers.2.mlp.experts.20.up_gate_proj.weight', 'model.layers.2.mlp.experts.21.up_gate_proj.weight', 'model.layers.2.mlp.experts.22.up_gate_proj.weight', 'model.layers.2.mlp.experts.23.up_gate_proj.weight', 'model.layers.2.mlp.experts.24.up_gate_proj.weight', 'model.layers.2.mlp.experts.25.up_gate_proj.weight', 'model.layers.2.mlp.experts.26.up_gate_proj.weight', 'model.layers.2.mlp.experts.27.up_gate_proj.weight', 'model.layers.2.mlp.experts.28.up_gate_proj.weight', 'model.layers.2.mlp.experts.29.up_gate_proj.weight', 
'model.layers.2.mlp.experts.30.up_gate_proj.weight', 'model.layers.2.mlp.experts.31.up_gate_proj.weight', 'model.layers.2.mlp.experts.32.up_gate_proj.weight', 'model.layers.2.mlp.experts.33.up_gate_proj.weight', 'model.layers.2.mlp.experts.34.up_gate_proj.weight', 'model.layers.2.mlp.experts.35.up_gate_proj.weight', 'model.layers.2.mlp.experts.36.up_gate_proj.weight', 'model.layers.2.mlp.experts.37.up_gate_proj.weight', 'model.layers.2.mlp.experts.38.up_gate_proj.weight', 'model.layers.2.mlp.experts.39.up_gate_proj.weight', 'model.layers.2.mlp.experts.40.up_gate_proj.weight', 'model.layers.2.mlp.experts.41.up_gate_proj.weight', 'model.layers.2.mlp.experts.42.up_gate_proj.weight', 'model.layers.2.mlp.experts.43.up_gate_proj.weight', 'model.layers.2.mlp.experts.44.up_gate_proj.weight', 'model.layers.2.mlp.experts.45.up_gate_proj.weight', 'model.layers.2.mlp.experts.46.up_gate_proj.weight', 'model.layers.2.mlp.experts.47.up_gate_proj.weight', 'model.layers.2.mlp.experts.48.up_gate_proj.weight', 'model.layers.2.mlp.experts.49.up_gate_proj.weight', 'model.layers.2.mlp.experts.50.up_gate_proj.weight', 'model.layers.2.mlp.experts.51.up_gate_proj.weight', 'model.layers.2.mlp.experts.52.up_gate_proj.weight', 'model.layers.2.mlp.experts.53.up_gate_proj.weight', 'model.layers.2.mlp.experts.54.up_gate_proj.weight', 'model.layers.2.mlp.experts.55.up_gate_proj.weight', 'model.layers.2.mlp.experts.56.up_gate_proj.weight', 'model.layers.2.mlp.experts.57.up_gate_proj.weight', 'model.layers.2.mlp.experts.58.up_gate_proj.weight', 'model.layers.2.mlp.experts.59.up_gate_proj.weight', 'model.layers.2.mlp.experts.60.up_gate_proj.weight', 'model.layers.2.mlp.experts.61.up_gate_proj.weight', 'model.layers.2.mlp.experts.62.up_gate_proj.weight', 'model.layers.2.mlp.experts.63.up_gate_proj.weight', 'model.layers.2.mlp.experts.64.up_gate_proj.weight', 'model.layers.2.mlp.experts.65.up_gate_proj.weight', 'model.layers.2.mlp.experts.66.up_gate_proj.weight', 'model.layers.2.mlp.experts.67.up_gate_proj.weight', 'model.layers.2.mlp.experts.68.up_gate_proj.weight', 'model.layers.2.mlp.experts.69.up_gate_proj.weight', 'model.layers.2.mlp.experts.70.up_gate_proj.weight', 'model.layers.2.mlp.experts.71.up_gate_proj.weight', 'model.layers.2.mlp.experts.72.up_gate_proj.weight', 'model.layers.2.mlp.experts.73.up_gate_proj.weight', 'model.layers.2.mlp.experts.74.up_gate_proj.weight', 'model.layers.2.mlp.experts.75.up_gate_proj.weight', 'model.layers.2.mlp.experts.76.up_gate_proj.weight', 'model.layers.2.mlp.experts.77.up_gate_proj.weight', 'model.layers.2.mlp.experts.78.up_gate_proj.weight', 'model.layers.2.mlp.experts.79.up_gate_proj.weight', 'model.layers.2.mlp.experts.80.up_gate_proj.weight', 'model.layers.2.mlp.experts.81.up_gate_proj.weight', 'model.layers.2.mlp.experts.82.up_gate_proj.weight', 'model.layers.2.mlp.experts.83.up_gate_proj.weight', 'model.layers.2.mlp.experts.84.up_gate_proj.weight', 'model.layers.2.mlp.experts.85.up_gate_proj.weight', 'model.layers.2.mlp.experts.86.up_gate_proj.weight', 'model.layers.2.mlp.experts.87.up_gate_proj.weight', 'model.layers.2.mlp.experts.88.up_gate_proj.weight', 'model.layers.2.mlp.experts.89.up_gate_proj.weight', 'model.layers.2.mlp.experts.90.up_gate_proj.weight', 'model.layers.2.mlp.experts.91.up_gate_proj.weight', 'model.layers.2.mlp.experts.92.up_gate_proj.weight', 'model.layers.2.mlp.experts.93.up_gate_proj.weight', 'model.layers.2.mlp.experts.94.up_gate_proj.weight', 'model.layers.2.mlp.experts.95.up_gate_proj.weight', 'model.layers.2.mlp.experts.96.up_gate_proj.weight', 
'model.layers.2.mlp.experts.97.up_gate_proj.weight', 'model.layers.2.mlp.experts.98.up_gate_proj.weight', 'model.layers.2.mlp.experts.99.up_gate_proj.weight', 'model.layers.2.mlp.experts.100.up_gate_proj.weight', 'model.layers.2.mlp.experts.101.up_gate_proj.weight', 'model.layers.2.mlp.experts.102.up_gate_proj.weight', 'model.layers.2.mlp.experts.103.up_gate_proj.weight', 'model.layers.2.mlp.experts.104.up_gate_proj.weight', 'model.layers.2.mlp.experts.105.up_gate_proj.weight', 'model.layers.2.mlp.experts.106.up_gate_proj.weight', 'model.layers.2.mlp.experts.107.up_gate_proj.weight', 'model.layers.2.mlp.experts.108.up_gate_proj.weight', 'model.layers.2.mlp.experts.109.up_gate_proj.weight', 'model.layers.2.mlp.experts.110.up_gate_proj.weight', 'model.layers.2.mlp.experts.111.up_gate_proj.weight', 'model.layers.2.mlp.experts.112.up_gate_proj.weight', 'model.layers.2.mlp.experts.113.up_gate_proj.weight', 'model.layers.2.mlp.experts.114.up_gate_proj.weight', 'model.layers.2.mlp.experts.115.up_gate_proj.weight', 'model.layers.2.mlp.experts.116.up_gate_proj.weight', 'model.layers.2.mlp.experts.117.up_gate_proj.weight', 'model.layers.2.mlp.experts.118.up_gate_proj.weight', 'model.layers.2.mlp.experts.119.up_gate_proj.weight', 'model.layers.2.mlp.experts.120.up_gate_proj.weight', 'model.layers.2.mlp.experts.121.up_gate_proj.weight', 'model.layers.2.mlp.experts.122.up_gate_proj.weight', 'model.layers.2.mlp.experts.123.up_gate_proj.weight', 'model.layers.2.mlp.experts.124.up_gate_proj.weight', 'model.layers.2.mlp.experts.125.up_gate_proj.weight', 'model.layers.2.mlp.experts.126.up_gate_proj.weight', 'model.layers.2.mlp.experts.127.up_gate_proj.weight']
+model.layers.2.mlp.gate.e_score_correction_bias:model.layers.2.mlp.gate.e_score_correction_bias
+model.layers.2.mlp.gate.weight:model.layers.2.mlp.gate.weight
+model.layers.2.shared_head.head.weight:model.layers.2.shared_head.head.weight
+model.layers.2.shared_head.norm.weight:model.layers.2.shared_head.norm.weight
 model.norm.weight
 model.norm.weight:model.norm.weight
+mtp_layers.model.layers.0.eh_proj.linear.weight
+mtp_layers.model.layers.0.eh_proj.linear.weight:mtp_layers.model.layers.0.eh_proj.linear.weight
+mtp_layers.model.layers.0.enorm.weight
+mtp_layers.model.layers.0.enorm.weight:mtp_layers.model.layers.0.enorm.weight
+mtp_layers.model.layers.0.hnorm.weight
+mtp_layers.model.layers.0.hnorm.weight:mtp_layers.model.layers.0.hnorm.weight
+mtp_layers.model.layers.0.mtp_block.input_layernorm.weight
+mtp_layers.model.layers.0.mtp_block.input_layernorm.weight:mtp_layers.model.layers.0.mtp_block.input_layernorm.weight
+mtp_layers.model.layers.0.mtp_block.mlp.experts.down_proj_weight
+mtp_layers.model.layers.0.mtp_block.mlp.experts.down_proj_weight:mtp_layers.model.layers.0.mtp_block.mlp.experts.down_proj_weight
+mtp_layers.model.layers.0.mtp_block.mlp.experts.gate_correction_bias
+mtp_layers.model.layers.0.mtp_block.mlp.experts.up_gate_proj_weight
+mtp_layers.model.layers.0.mtp_block.mlp.experts.up_gate_proj_weight:mtp_layers.model.layers.0.mtp_block.mlp.experts.up_gate_proj_weight
+mtp_layers.model.layers.0.mtp_block.mlp.gate.e_score_correction_bias
+mtp_layers.model.layers.0.mtp_block.mlp.gate.e_score_correction_bias:mtp_layers.model.layers.0.mtp_block.mlp.gate.e_score_correction_bias
+mtp_layers.model.layers.0.mtp_block.mlp.gate.weight
+mtp_layers.model.layers.0.mtp_block.mlp.gate.weight:mtp_layers.model.layers.0.mtp_block.mlp.gate.weight
+mtp_layers.model.layers.0.mtp_block.mlp.shared_experts.down_proj.weight
+mtp_layers.model.layers.0.mtp_block.mlp.shared_experts.down_proj.weight:mtp_layers.model.layers.0.mtp_block.mlp.shared_experts.down_proj.weight
+mtp_layers.model.layers.0.mtp_block.mlp.shared_experts.up_gate_proj.weight
+mtp_layers.model.layers.0.mtp_block.mlp.shared_experts.up_gate_proj.weight:mtp_layers.model.layers.0.mtp_block.mlp.shared_experts.up_gate_proj.weight
+mtp_layers.model.layers.0.mtp_block.post_attention_layernorm.weight
+mtp_layers.model.layers.0.mtp_block.post_attention_layernorm.weight:mtp_layers.model.layers.0.mtp_block.post_attention_layernorm.weight
+mtp_layers.model.layers.0.mtp_block.self_attn.o_proj.weight
+mtp_layers.model.layers.0.mtp_block.self_attn.o_proj.weight:mtp_layers.model.layers.0.mtp_block.self_attn.o_proj.weight
+mtp_layers.model.layers.0.mtp_block.self_attn.qkv_proj.bias
+mtp_layers.model.layers.0.mtp_block.self_attn.qkv_proj.bias:mtp_layers.model.layers.0.mtp_block.self_attn.qkv_proj.bias
+mtp_layers.model.layers.0.mtp_block.self_attn.qkv_proj.weight
+mtp_layers.model.layers.0.mtp_block.self_attn.qkv_proj.weight:mtp_layers.model.layers.0.mtp_block.self_attn.qkv_proj.weight
+mtp_layers.model.layers.0.shared_head.norm.weight
+mtp_layers.model.layers.0.shared_head.norm.weight:mtp_layers.model.layers.0.shared_head.norm.weight
diff --git a/tests/rl/test_rollout_model.py b/tests/rl/test_rollout_model.py
index 4d720bcc2a..589362192b 100644
--- a/tests/rl/test_rollout_model.py
+++ b/tests/rl/test_rollout_model.py
@@ -205,6 +205,7 @@ def test_glm4moe_mapping_removes_gate_correction():
             "model.layers.0.mlp.experts.gate_correction_bias",
         ],
     )
+    dummy.num_nextn_predict_layers = 0
    mappings = dummy.get_name_mappings_to_training()
    # Cover gate/experts aggregation and dropping gate_correction_bias
    assert "model.layers.0.mlp.experts.up_gate_proj_weight" in mappings
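Note: baseline.txt stores one entry per parameter, either a bare inference name or infer_name:train_name(s), where list-valued mappings are serialized in Python list form. A hedged sketch of how such a baseline could be checked against get_name_mappings_to_training output (the helper below is an assumption; the CI test's actual comparison may differ):

def check_mappings_against_baseline(baseline_path, mappings):
    # Collect every "infer:train" pair recorded in the baseline file.
    expected = set()
    with open(baseline_path) as f:
        for raw in f:
            line = raw.rstrip("\n")
            if ":" in line:
                expected.add(tuple(line.split(":", 1)))
    # str() on a list value reproduces the bracketed form the baseline stores.
    actual = {(k, v if isinstance(v, str) else str(v)) for k, v in mappings.items()}
    missing = expected - actual
    assert not missing, f"{len(missing)} baseline entries not reproduced by the mappings"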