diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index b4b284ab87..f2aac041cf 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -193,9 +193,6 @@ class FlashAttentionBackend(AttentionBackend): elif metadata._dtype == "float32": metadata._fuse_kernel_compute_dtype = "fp32" - metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu) - metadata.max_len_tensor_cpu_decoder[1] = 0 - forward_meta.attention_metadata = metadata def forward_mixed( @@ -241,6 +238,10 @@ class FlashAttentionBackend(AttentionBackend): ) if forward_meta.max_len_tensor_cpu[1].item() > 0: + + metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu) + metadata.max_len_tensor_cpu_decoder[1] = 0 + ( metadata.cu_seqlens_k, metadata.pre_cache_batch_ids, @@ -309,7 +310,7 @@ class FlashAttentionBackend(AttentionBackend): qkv, forward_meta.caches[2 * layer.layer_id], forward_meta.caches[2 * layer.layer_id + 1], - self.zero_seq_enc_lens_for_decode if use_fa_do_prefill else forward_meta.seq_lens_encoder, + forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, forward_meta.batch_id_per_token, diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index abc2d1a52c..9cdba35d16 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -166,9 +166,6 @@ class FlashMaskAttentionBackend(AttentionBackend): elif metadata._dtype == "float32": metadata._fuse_kernel_compute_dtype = "fp32" - metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu) - metadata.max_len_tensor_cpu_decoder[1] = 0 - forward_meta.attention_metadata = metadata def forward_mixed( @@ -222,6 +219,10 @@ class FlashMaskAttentionBackend(AttentionBackend): # here we add five members,this is ugly, just for now. if forward_meta.max_len_tensor_cpu[1].item() > 0: + + metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu) + metadata.max_len_tensor_cpu_decoder[1] = 0 + ( forward_meta.attn_cu_seqlens_k, forward_meta.pre_cache_batch_ids, @@ -293,7 +294,7 @@ class FlashMaskAttentionBackend(AttentionBackend): qkv, forward_meta.caches[2 * layer.layer_id], forward_meta.caches[2 * layer.layer_id + 1], - self.zero_seq_enc_lens_for_decode if use_fa_do_prefill else forward_meta.seq_lens_encoder, + forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, forward_meta.batch_id_per_token, diff --git a/fastdeploy/model_executor/models/glm4_mtp.py b/fastdeploy/model_executor/models/glm4_mtp.py index d16632c2b4..c28023202d 100644 --- a/fastdeploy/model_executor/models/glm4_mtp.py +++ b/fastdeploy/model_executor/models/glm4_mtp.py @@ -28,8 +28,6 @@ from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.graph_optimization.decorator import ( support_graph_optimization, ) -from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding -from fastdeploy.model_executor.layers.lm_head import ParallelLMHead from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection from fastdeploy.model_executor.layers.normalization import RMSNorm from fastdeploy.model_executor.models.glm4_moe import Glm4MoeDecoderLayer @@ -119,12 +117,8 @@ class SharedHead(nn.Module): eps=fd_config.model_config.rms_norm_eps, prefix=f"{prefix}.shared_head.norm", ) - self.head = ParallelLMHead( - fd_config, - embedding_dim=fd_config.model_config.hidden_size, - num_embeddings=fd_config.model_config.vocab_size, - prefix=f"{prefix}.shared_head.head", - ) + if fd_config.speculative_config.sharing_model is not None: + self.head = fd_config.speculative_config.sharing_model.lm_head def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: # NOTE(wangyanpeng04): Just for compute logits @@ -216,15 +210,8 @@ class Glm4MTPModel(nn.Layer): assert self.num_mtp_layers == 1, f"Currently only supports single MTP layer, but got {self.num_mtp_layers}" - self.embed_tokens = VocabParallelEmbedding( - fd_config=fd_config, - num_embeddings=fd_config.model_config.vocab_size, - embedding_dim=fd_config.model_config.hidden_size, - params_dtype=paddle.get_default_dtype(), - prefix=( - f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{self.mtp_start_layer_idx}.embed_tokens" - ), - ) + if fd_config.speculative_config.sharing_model is not None: + self.embed_tokens = fd_config.speculative_config.sharing_model.model.embed_tokens self.layers = nn.LayerDict( { diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py index a9c2ed027b..cade135508 100644 --- a/fastdeploy/rl/rollout_config.py +++ b/fastdeploy/rl/rollout_config.py @@ -46,7 +46,7 @@ class RolloutModelConfig: enable_chunked_prefill: bool = False, speculative_method: str = None, speculative_max_draft_token_num: int = 1, - speculative_model_name_or_path: str = "", + speculative_model_name_or_path: str = None, speculative_model_quantization: str = "WINT8", max_num_batched_tokens: int = 2048, enable_prefix_caching: bool = False, @@ -96,7 +96,9 @@ class RolloutModelConfig: self.speculative_config = {} self.speculative_config["method"] = speculative_method self.speculative_config["max_draft_token_num"] = speculative_max_draft_token_num - self.speculative_config["model"] = speculative_model_name_or_path + self.speculative_config["model"] = ( + speculative_model_name_or_path if speculative_model_name_or_path is not None else model_name_or_path + ) self.speculative_config["quantization"] = speculative_model_quantization self.max_num_batched_tokens = max_num_batched_tokens self.enable_prefix_caching = enable_prefix_caching diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py index e4dd3ace5a..3a0ee13ddd 100644 --- a/fastdeploy/rl/rollout_model.py +++ b/fastdeploy/rl/rollout_model.py @@ -18,8 +18,10 @@ import copy from typing import Dict import paddle +import paddle.distributed as dist from paddle import nn +from fastdeploy import envs from fastdeploy.config import FDConfig from fastdeploy.model_executor.model_loader import get_model_loader from fastdeploy.model_executor.models.ernie4_5_moe import ( @@ -34,6 +36,10 @@ from fastdeploy.model_executor.models.glm4_moe import ( Glm4MoeForCausalLM, Glm4MoePretrainedModel, ) +from fastdeploy.model_executor.models.glm4_mtp import ( + Glm4MTPForCausalLM, + Glm4MTPPretrainedModel, +) from fastdeploy.model_executor.models.model_base import ModelRegistry from fastdeploy.model_executor.models.qwen2 import ( Qwen2ForCausalLM, @@ -698,12 +704,52 @@ class Glm4MoeForCausalLMRL(Glm4MoeForCausalLM, BaseRLModel): fd_config (FDConfig): Configurations for the LLM model. """ super(Glm4MoeForCausalLMRL, self).__init__(fd_config) + self.speculative_decoding = fd_config.speculative_config.method is not None + self.speculative_method = fd_config.speculative_config.method + + if self.speculative_decoding and self.speculative_method == "mtp": + fd_config.parallel_config.tp_group = None + fd_config.parallel_config.ep_group = None + self.mtp_fd_config = copy.deepcopy(fd_config) + fd_config.parallel_config.tp_group = dist.get_group( + fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET + ) + fd_config.parallel_config.ep_group = dist.get_group( + fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET + ) + self.fd_config.parallel_config.tp_group = dist.get_group( + fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET + ) + self.fd_config.parallel_config.ep_group = dist.get_group( + fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET + ) + self.update_mtp_config(self.mtp_fd_config) + self.mtp_layers = Glm4MTPForCausalLMRL(self.mtp_fd_config) @classmethod def name(self) -> str: """name""" return "Glm4MoeForCausalLMRL" + def update_mtp_config(self, mtp_fd_config): + mtp_fd_config.model_config.architectures[0] = mtp_fd_config.model_config.architectures[0].replace("Moe", "MTP") + mtp_fd_config.speculative_config.sharing_model = None + mtp_fd_config.model_config.num_hidden_layers = 1 + mtp_fd_config.model_config.model = mtp_fd_config.speculative_config.model + if mtp_fd_config.speculative_config.quantization != "": + mtp_fd_config.model_config.quantization = mtp_fd_config.speculative_config.quantization + mtp_fd_config.model_config.start_layer_index = mtp_fd_config.model_config.num_hidden_layers + mtp_fd_config.speculative_config.model_type = "mtp" + + def state_dict(self): + """state_dict""" + main_state_dict = super().state_dict() + state_dict = {k: v for k, v in main_state_dict.items() if not k.startswith("mtp_layers")} + if self.speculative_decoding and self.speculative_method == "mtp": + mtp_state_dict = self.mtp_layers.state_dict() + state_dict.update(mtp_state_dict) + return state_dict + def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: """Generate mapping between inference and training parameter for RL(donot delete!).""" if self._mappings_built: @@ -757,9 +803,106 @@ class Glm4MoeForCausalLMRL(Glm4MoeForCausalLM, BaseRLModel): _add_layer_mappings(layer_idx) self._complete_missing_mappings() + + # extra for mtp + if self.speculative_decoding and self.speculative_method == "mtp": + mtp_infer_to_train_mapping = self.mtp_layers.get_name_mappings_to_training(trainer_degree) + self.infer_to_train_mapping.update(mtp_infer_to_train_mapping) + infer_to_train_mapping_copy = copy.deepcopy(self.infer_to_train_mapping) for key in infer_to_train_mapping_copy.keys(): if "mlp.experts.gate_correction_bias" in key: self.infer_to_train_mapping.pop(key) return self.infer_to_train_mapping + + +class Glm4MTPForCausalLMRL(Glm4MTPForCausalLM, BaseRLModel): + """ + Glm4MTPForCausalLMRL + """ + + _get_tensor_parallel_mappings = Glm4MTPPretrainedModel._get_tensor_parallel_mappings + + def __init__(self, fd_config: FDConfig): + """ + Args: + fd_config (FDConfig): Configurations for the LLM model. + """ + super(Glm4MTPForCausalLMRL, self).__init__(fd_config) + + @classmethod + def name(self) -> str: + """name""" + return "Glm4MTPForCausalLMRL" + + def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: + """Generate mapping between inference and training parameter for RL(donot delete!).""" + if self._mappings_built: + return self.infer_to_train_mapping + + self.infer_to_train_mapping = {} + self._mappings_built = True + + # Prepare placeholders + place_holders = ["weight"] + + base_name = "model.layers" + + # Helper function to add layer mappings + def _add_layer_mappings(layer_idx: int): + # MTP specific mappings + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.shared_head.head.weight"] = ( + f"{base_name}.{layer_idx}.shared_head.head.weight" + ) + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.shared_head.norm.weight"] = ( + f"{base_name}.{layer_idx}.shared_head.norm.weight" + ) + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.eh_proj.weight"] = ( + f"{base_name}.{layer_idx}.eh_proj.weight" + ) + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.enorm.weight"] = ( + f"{base_name}.{layer_idx}.enorm.weight" + ) + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.hnorm.weight"] = ( + f"{base_name}.{layer_idx}.hnorm.weight" + ) + + # MoE specific mappings + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.weight"] = ( + f"{base_name}.{layer_idx}.mlp.gate.weight" + ) + + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.e_score_correction_bias"] = ( + f"{base_name}.{layer_idx}.mlp.gate.e_score_correction_bias" + ) + + # MoE experts mappings + for expert_idx in range(self.fd_config.model_config.n_routed_experts): + for ph in place_holders: + # up_gate_proj (up_gate_proj) + up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.experts.up_gate_proj_weight" + if up_gate_proj_key not in self.infer_to_train_mapping: + self.infer_to_train_mapping[up_gate_proj_key] = [] + self.infer_to_train_mapping[up_gate_proj_key].append( + f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}" + ) + + # down_proj (down_proj) + down_proj_key = f"{base_name}.{layer_idx}.mlp.experts.down_proj_weight" + if down_proj_key not in self.infer_to_train_mapping: + self.infer_to_train_mapping[down_proj_key] = [] + self.infer_to_train_mapping[down_proj_key].append( + f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}" + ) + + # Process MoE layers + for layer_idx in range( + self.fd_config.model_config.start_layer_index, + self.fd_config.model_config.start_layer_index + self.fd_config.model_config.num_nextn_predict_layers, + ): + _add_layer_mappings(layer_idx) + + self._complete_missing_mappings() + + return self.infer_to_train_mapping diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index eb5e8c12e4..94813313a3 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1382,12 +1382,6 @@ class GPUModelRunner(ModelRunnerBase): model_loader = get_model_loader(load_config=self.fd_config.load_config) self.model = model_loader.load_model(fd_config=self.fd_config) - # 1.1 Load RL dynamic model - if self.fd_config.load_config.dynamic_load_weight: - from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager - - self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model, self.local_rank) - # 2. Load lora model # 3. Load drafter model(for speculative decoding) @@ -1395,6 +1389,17 @@ class GPUModelRunner(ModelRunnerBase): # 4. Init proposer for speculative method self._init_speculative_proposer() + # Load RL dynamic model + if self.fd_config.load_config.dynamic_load_weight: + from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager + + if self.fd_config.speculative_config.method == "mtp": + self.dynamic_weight_manager = DynamicWeightManager( + self.fd_config, [self.model, self.proposer.model], self.local_rank + ) + else: + self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model, self.local_rank) + def get_model(self) -> nn.Layer: """Get current model""" return self.model diff --git a/tests/ci_use/GLM-45-AIR/baseline_mtp.txt b/tests/ci_use/GLM-45-AIR/baseline_mtp.txt new file mode 100644 index 0000000000..7afa09e9e7 --- /dev/null +++ b/tests/ci_use/GLM-45-AIR/baseline_mtp.txt @@ -0,0 +1,79 @@ +lm_head.linear.weight +lm_head.linear.weight:lm_head.weight +model.embed_tokens.embeddings.weight +model.embed_tokens.embeddings.weight:model.embed_tokens.weight +model.layers.0.eh_proj.linear.weight +model.layers.0.eh_proj.linear.weight:model.layers.0.eh_proj.linear.weight +model.layers.0.enorm.weight +model.layers.0.enorm.weight:model.layers.0.enorm.weight +model.layers.0.hnorm.weight +model.layers.0.hnorm.weight:model.layers.0.hnorm.weight +model.layers.0.input_layernorm.weight +model.layers.0.input_layernorm.weight:model.layers.0.input_layernorm.weight +model.layers.0.mlp.down_proj.weight +model.layers.0.mlp.down_proj.weight:model.layers.0.mlp.down_proj.weight +model.layers.0.mlp.up_gate_proj.weight +model.layers.0.mlp.up_gate_proj.weight:model.layers.0.mlp.up_gate_proj.weight +model.layers.0.mtp_block.input_layernorm.weight +model.layers.0.mtp_block.input_layernorm.weight:model.layers.0.mtp_block.input_layernorm.weight +model.layers.0.mtp_block.mlp.experts.down_proj_weight +model.layers.0.mtp_block.mlp.experts.down_proj_weight:model.layers.0.mtp_block.mlp.experts.down_proj_weight +model.layers.0.mtp_block.mlp.experts.gate_correction_bias +model.layers.0.mtp_block.mlp.experts.up_gate_proj_weight +model.layers.0.mtp_block.mlp.experts.up_gate_proj_weight:model.layers.0.mtp_block.mlp.experts.up_gate_proj_weight +model.layers.0.mtp_block.mlp.gate.e_score_correction_bias +model.layers.0.mtp_block.mlp.gate.e_score_correction_bias:model.layers.0.mtp_block.mlp.gate.e_score_correction_bias +model.layers.0.mtp_block.mlp.gate.weight +model.layers.0.mtp_block.mlp.gate.weight:model.layers.0.mtp_block.mlp.gate.weight +model.layers.0.mtp_block.mlp.shared_experts.down_proj.weight +model.layers.0.mtp_block.mlp.shared_experts.down_proj.weight:model.layers.0.mtp_block.mlp.shared_experts.down_proj.weight +model.layers.0.mtp_block.mlp.shared_experts.up_gate_proj.weight +model.layers.0.mtp_block.mlp.shared_experts.up_gate_proj.weight:model.layers.0.mtp_block.mlp.shared_experts.up_gate_proj.weight +model.layers.0.mtp_block.post_attention_layernorm.weight +model.layers.0.mtp_block.post_attention_layernorm.weight:model.layers.0.mtp_block.post_attention_layernorm.weight +model.layers.0.mtp_block.self_attn.o_proj.weight +model.layers.0.mtp_block.self_attn.o_proj.weight:model.layers.0.mtp_block.self_attn.o_proj.weight +model.layers.0.mtp_block.self_attn.qkv_proj.bias +model.layers.0.mtp_block.self_attn.qkv_proj.bias:model.layers.0.mtp_block.self_attn.qkv_proj.bias +model.layers.0.mtp_block.self_attn.qkv_proj.weight +model.layers.0.mtp_block.self_attn.qkv_proj.weight:model.layers.0.mtp_block.self_attn.qkv_proj.weight +model.layers.0.post_attention_layernorm.weight +model.layers.0.post_attention_layernorm.weight:model.layers.0.post_attention_layernorm.weight +model.layers.0.self_attn.o_proj.weight +model.layers.0.self_attn.o_proj.weight:model.layers.0.self_attn.o_proj.weight +model.layers.0.self_attn.qkv_proj.bias +model.layers.0.self_attn.qkv_proj.bias:model.layers.0.self_attn.qkv_proj.bias +model.layers.0.self_attn.qkv_proj.weight +model.layers.0.self_attn.qkv_proj.weight:model.layers.0.self_attn.qkv_proj.weight +model.layers.0.shared_head.norm.weight +model.layers.0.shared_head.norm.weight:model.layers.0.shared_head.norm.weight +model.layers.1.eh_proj.weight:model.layers.1.eh_proj.weight +model.layers.1.enorm.weight:model.layers.1.enorm.weight +model.layers.1.hnorm.weight:model.layers.1.hnorm.weight +model.layers.1.input_layernorm.weight +model.layers.1.input_layernorm.weight:model.layers.1.input_layernorm.weight +model.layers.1.mlp.experts.down_proj_weight +model.layers.1.mlp.experts.down_proj_weight:['model.layers.1.mlp.experts.0.down_proj.weight', 'model.layers.1.mlp.experts.1.down_proj.weight', 'model.layers.1.mlp.experts.2.down_proj.weight', 'model.layers.1.mlp.experts.3.down_proj.weight', 'model.layers.1.mlp.experts.4.down_proj.weight', 'model.layers.1.mlp.experts.5.down_proj.weight', 'model.layers.1.mlp.experts.6.down_proj.weight', 'model.layers.1.mlp.experts.7.down_proj.weight', 'model.layers.1.mlp.experts.8.down_proj.weight', 'model.layers.1.mlp.experts.9.down_proj.weight', 'model.layers.1.mlp.experts.10.down_proj.weight', 'model.layers.1.mlp.experts.11.down_proj.weight', 'model.layers.1.mlp.experts.12.down_proj.weight', 'model.layers.1.mlp.experts.13.down_proj.weight', 'model.layers.1.mlp.experts.14.down_proj.weight', 'model.layers.1.mlp.experts.15.down_proj.weight', 'model.layers.1.mlp.experts.16.down_proj.weight', 'model.layers.1.mlp.experts.17.down_proj.weight', 'model.layers.1.mlp.experts.18.down_proj.weight', 'model.layers.1.mlp.experts.19.down_proj.weight', 'model.layers.1.mlp.experts.20.down_proj.weight', 'model.layers.1.mlp.experts.21.down_proj.weight', 'model.layers.1.mlp.experts.22.down_proj.weight', 'model.layers.1.mlp.experts.23.down_proj.weight', 'model.layers.1.mlp.experts.24.down_proj.weight', 'model.layers.1.mlp.experts.25.down_proj.weight', 'model.layers.1.mlp.experts.26.down_proj.weight', 'model.layers.1.mlp.experts.27.down_proj.weight', 'model.layers.1.mlp.experts.28.down_proj.weight', 'model.layers.1.mlp.experts.29.down_proj.weight', 'model.layers.1.mlp.experts.30.down_proj.weight', 'model.layers.1.mlp.experts.31.down_proj.weight', 'model.layers.1.mlp.experts.32.down_proj.weight', 'model.layers.1.mlp.experts.33.down_proj.weight', 'model.layers.1.mlp.experts.34.down_proj.weight', 'model.layers.1.mlp.experts.35.down_proj.weight', 'model.layers.1.mlp.experts.36.down_proj.weight', 'model.layers.1.mlp.experts.37.down_proj.weight', 'model.layers.1.mlp.experts.38.down_proj.weight', 'model.layers.1.mlp.experts.39.down_proj.weight', 'model.layers.1.mlp.experts.40.down_proj.weight', 'model.layers.1.mlp.experts.41.down_proj.weight', 'model.layers.1.mlp.experts.42.down_proj.weight', 'model.layers.1.mlp.experts.43.down_proj.weight', 'model.layers.1.mlp.experts.44.down_proj.weight', 'model.layers.1.mlp.experts.45.down_proj.weight', 'model.layers.1.mlp.experts.46.down_proj.weight', 'model.layers.1.mlp.experts.47.down_proj.weight', 'model.layers.1.mlp.experts.48.down_proj.weight', 'model.layers.1.mlp.experts.49.down_proj.weight', 'model.layers.1.mlp.experts.50.down_proj.weight', 'model.layers.1.mlp.experts.51.down_proj.weight', 'model.layers.1.mlp.experts.52.down_proj.weight', 'model.layers.1.mlp.experts.53.down_proj.weight', 'model.layers.1.mlp.experts.54.down_proj.weight', 'model.layers.1.mlp.experts.55.down_proj.weight', 'model.layers.1.mlp.experts.56.down_proj.weight', 'model.layers.1.mlp.experts.57.down_proj.weight', 'model.layers.1.mlp.experts.58.down_proj.weight', 'model.layers.1.mlp.experts.59.down_proj.weight', 'model.layers.1.mlp.experts.60.down_proj.weight', 'model.layers.1.mlp.experts.61.down_proj.weight', 'model.layers.1.mlp.experts.62.down_proj.weight', 'model.layers.1.mlp.experts.63.down_proj.weight', 'model.layers.1.mlp.experts.64.down_proj.weight', 'model.layers.1.mlp.experts.65.down_proj.weight', 'model.layers.1.mlp.experts.66.down_proj.weight', 'model.layers.1.mlp.experts.67.down_proj.weight', 'model.layers.1.mlp.experts.68.down_proj.weight', 'model.layers.1.mlp.experts.69.down_proj.weight', 'model.layers.1.mlp.experts.70.down_proj.weight', 'model.layers.1.mlp.experts.71.down_proj.weight', 'model.layers.1.mlp.experts.72.down_proj.weight', 'model.layers.1.mlp.experts.73.down_proj.weight', 'model.layers.1.mlp.experts.74.down_proj.weight', 'model.layers.1.mlp.experts.75.down_proj.weight', 'model.layers.1.mlp.experts.76.down_proj.weight', 'model.layers.1.mlp.experts.77.down_proj.weight', 'model.layers.1.mlp.experts.78.down_proj.weight', 'model.layers.1.mlp.experts.79.down_proj.weight', 'model.layers.1.mlp.experts.80.down_proj.weight', 'model.layers.1.mlp.experts.81.down_proj.weight', 'model.layers.1.mlp.experts.82.down_proj.weight', 'model.layers.1.mlp.experts.83.down_proj.weight', 'model.layers.1.mlp.experts.84.down_proj.weight', 'model.layers.1.mlp.experts.85.down_proj.weight', 'model.layers.1.mlp.experts.86.down_proj.weight', 'model.layers.1.mlp.experts.87.down_proj.weight', 'model.layers.1.mlp.experts.88.down_proj.weight', 'model.layers.1.mlp.experts.89.down_proj.weight', 'model.layers.1.mlp.experts.90.down_proj.weight', 'model.layers.1.mlp.experts.91.down_proj.weight', 'model.layers.1.mlp.experts.92.down_proj.weight', 'model.layers.1.mlp.experts.93.down_proj.weight', 'model.layers.1.mlp.experts.94.down_proj.weight', 'model.layers.1.mlp.experts.95.down_proj.weight', 'model.layers.1.mlp.experts.96.down_proj.weight', 'model.layers.1.mlp.experts.97.down_proj.weight', 'model.layers.1.mlp.experts.98.down_proj.weight', 'model.layers.1.mlp.experts.99.down_proj.weight', 'model.layers.1.mlp.experts.100.down_proj.weight', 'model.layers.1.mlp.experts.101.down_proj.weight', 'model.layers.1.mlp.experts.102.down_proj.weight', 'model.layers.1.mlp.experts.103.down_proj.weight', 'model.layers.1.mlp.experts.104.down_proj.weight', 'model.layers.1.mlp.experts.105.down_proj.weight', 'model.layers.1.mlp.experts.106.down_proj.weight', 'model.layers.1.mlp.experts.107.down_proj.weight', 'model.layers.1.mlp.experts.108.down_proj.weight', 'model.layers.1.mlp.experts.109.down_proj.weight', 'model.layers.1.mlp.experts.110.down_proj.weight', 'model.layers.1.mlp.experts.111.down_proj.weight', 'model.layers.1.mlp.experts.112.down_proj.weight', 'model.layers.1.mlp.experts.113.down_proj.weight', 'model.layers.1.mlp.experts.114.down_proj.weight', 'model.layers.1.mlp.experts.115.down_proj.weight', 'model.layers.1.mlp.experts.116.down_proj.weight', 'model.layers.1.mlp.experts.117.down_proj.weight', 'model.layers.1.mlp.experts.118.down_proj.weight', 'model.layers.1.mlp.experts.119.down_proj.weight', 'model.layers.1.mlp.experts.120.down_proj.weight', 'model.layers.1.mlp.experts.121.down_proj.weight', 'model.layers.1.mlp.experts.122.down_proj.weight', 'model.layers.1.mlp.experts.123.down_proj.weight', 'model.layers.1.mlp.experts.124.down_proj.weight', 'model.layers.1.mlp.experts.125.down_proj.weight', 'model.layers.1.mlp.experts.126.down_proj.weight', 'model.layers.1.mlp.experts.127.down_proj.weight'] +model.layers.1.mlp.experts.gate_correction_bias +model.layers.1.mlp.experts.up_gate_proj_weight +model.layers.1.mlp.experts.up_gate_proj_weight:['model.layers.1.mlp.experts.0.up_gate_proj.weight', 'model.layers.1.mlp.experts.1.up_gate_proj.weight', 'model.layers.1.mlp.experts.2.up_gate_proj.weight', 'model.layers.1.mlp.experts.3.up_gate_proj.weight', 'model.layers.1.mlp.experts.4.up_gate_proj.weight', 'model.layers.1.mlp.experts.5.up_gate_proj.weight', 'model.layers.1.mlp.experts.6.up_gate_proj.weight', 'model.layers.1.mlp.experts.7.up_gate_proj.weight', 'model.layers.1.mlp.experts.8.up_gate_proj.weight', 'model.layers.1.mlp.experts.9.up_gate_proj.weight', 'model.layers.1.mlp.experts.10.up_gate_proj.weight', 'model.layers.1.mlp.experts.11.up_gate_proj.weight', 'model.layers.1.mlp.experts.12.up_gate_proj.weight', 'model.layers.1.mlp.experts.13.up_gate_proj.weight', 'model.layers.1.mlp.experts.14.up_gate_proj.weight', 'model.layers.1.mlp.experts.15.up_gate_proj.weight', 'model.layers.1.mlp.experts.16.up_gate_proj.weight', 'model.layers.1.mlp.experts.17.up_gate_proj.weight', 'model.layers.1.mlp.experts.18.up_gate_proj.weight', 'model.layers.1.mlp.experts.19.up_gate_proj.weight', 'model.layers.1.mlp.experts.20.up_gate_proj.weight', 'model.layers.1.mlp.experts.21.up_gate_proj.weight', 'model.layers.1.mlp.experts.22.up_gate_proj.weight', 'model.layers.1.mlp.experts.23.up_gate_proj.weight', 'model.layers.1.mlp.experts.24.up_gate_proj.weight', 'model.layers.1.mlp.experts.25.up_gate_proj.weight', 'model.layers.1.mlp.experts.26.up_gate_proj.weight', 'model.layers.1.mlp.experts.27.up_gate_proj.weight', 'model.layers.1.mlp.experts.28.up_gate_proj.weight', 'model.layers.1.mlp.experts.29.up_gate_proj.weight', 'model.layers.1.mlp.experts.30.up_gate_proj.weight', 'model.layers.1.mlp.experts.31.up_gate_proj.weight', 'model.layers.1.mlp.experts.32.up_gate_proj.weight', 'model.layers.1.mlp.experts.33.up_gate_proj.weight', 'model.layers.1.mlp.experts.34.up_gate_proj.weight', 'model.layers.1.mlp.experts.35.up_gate_proj.weight', 'model.layers.1.mlp.experts.36.up_gate_proj.weight', 'model.layers.1.mlp.experts.37.up_gate_proj.weight', 'model.layers.1.mlp.experts.38.up_gate_proj.weight', 'model.layers.1.mlp.experts.39.up_gate_proj.weight', 'model.layers.1.mlp.experts.40.up_gate_proj.weight', 'model.layers.1.mlp.experts.41.up_gate_proj.weight', 'model.layers.1.mlp.experts.42.up_gate_proj.weight', 'model.layers.1.mlp.experts.43.up_gate_proj.weight', 'model.layers.1.mlp.experts.44.up_gate_proj.weight', 'model.layers.1.mlp.experts.45.up_gate_proj.weight', 'model.layers.1.mlp.experts.46.up_gate_proj.weight', 'model.layers.1.mlp.experts.47.up_gate_proj.weight', 'model.layers.1.mlp.experts.48.up_gate_proj.weight', 'model.layers.1.mlp.experts.49.up_gate_proj.weight', 'model.layers.1.mlp.experts.50.up_gate_proj.weight', 'model.layers.1.mlp.experts.51.up_gate_proj.weight', 'model.layers.1.mlp.experts.52.up_gate_proj.weight', 'model.layers.1.mlp.experts.53.up_gate_proj.weight', 'model.layers.1.mlp.experts.54.up_gate_proj.weight', 'model.layers.1.mlp.experts.55.up_gate_proj.weight', 'model.layers.1.mlp.experts.56.up_gate_proj.weight', 'model.layers.1.mlp.experts.57.up_gate_proj.weight', 'model.layers.1.mlp.experts.58.up_gate_proj.weight', 'model.layers.1.mlp.experts.59.up_gate_proj.weight', 'model.layers.1.mlp.experts.60.up_gate_proj.weight', 'model.layers.1.mlp.experts.61.up_gate_proj.weight', 'model.layers.1.mlp.experts.62.up_gate_proj.weight', 'model.layers.1.mlp.experts.63.up_gate_proj.weight', 'model.layers.1.mlp.experts.64.up_gate_proj.weight', 'model.layers.1.mlp.experts.65.up_gate_proj.weight', 'model.layers.1.mlp.experts.66.up_gate_proj.weight', 'model.layers.1.mlp.experts.67.up_gate_proj.weight', 'model.layers.1.mlp.experts.68.up_gate_proj.weight', 'model.layers.1.mlp.experts.69.up_gate_proj.weight', 'model.layers.1.mlp.experts.70.up_gate_proj.weight', 'model.layers.1.mlp.experts.71.up_gate_proj.weight', 'model.layers.1.mlp.experts.72.up_gate_proj.weight', 'model.layers.1.mlp.experts.73.up_gate_proj.weight', 'model.layers.1.mlp.experts.74.up_gate_proj.weight', 'model.layers.1.mlp.experts.75.up_gate_proj.weight', 'model.layers.1.mlp.experts.76.up_gate_proj.weight', 'model.layers.1.mlp.experts.77.up_gate_proj.weight', 'model.layers.1.mlp.experts.78.up_gate_proj.weight', 'model.layers.1.mlp.experts.79.up_gate_proj.weight', 'model.layers.1.mlp.experts.80.up_gate_proj.weight', 'model.layers.1.mlp.experts.81.up_gate_proj.weight', 'model.layers.1.mlp.experts.82.up_gate_proj.weight', 'model.layers.1.mlp.experts.83.up_gate_proj.weight', 'model.layers.1.mlp.experts.84.up_gate_proj.weight', 'model.layers.1.mlp.experts.85.up_gate_proj.weight', 'model.layers.1.mlp.experts.86.up_gate_proj.weight', 'model.layers.1.mlp.experts.87.up_gate_proj.weight', 'model.layers.1.mlp.experts.88.up_gate_proj.weight', 'model.layers.1.mlp.experts.89.up_gate_proj.weight', 'model.layers.1.mlp.experts.90.up_gate_proj.weight', 'model.layers.1.mlp.experts.91.up_gate_proj.weight', 'model.layers.1.mlp.experts.92.up_gate_proj.weight', 'model.layers.1.mlp.experts.93.up_gate_proj.weight', 'model.layers.1.mlp.experts.94.up_gate_proj.weight', 'model.layers.1.mlp.experts.95.up_gate_proj.weight', 'model.layers.1.mlp.experts.96.up_gate_proj.weight', 'model.layers.1.mlp.experts.97.up_gate_proj.weight', 'model.layers.1.mlp.experts.98.up_gate_proj.weight', 'model.layers.1.mlp.experts.99.up_gate_proj.weight', 'model.layers.1.mlp.experts.100.up_gate_proj.weight', 'model.layers.1.mlp.experts.101.up_gate_proj.weight', 'model.layers.1.mlp.experts.102.up_gate_proj.weight', 'model.layers.1.mlp.experts.103.up_gate_proj.weight', 'model.layers.1.mlp.experts.104.up_gate_proj.weight', 'model.layers.1.mlp.experts.105.up_gate_proj.weight', 'model.layers.1.mlp.experts.106.up_gate_proj.weight', 'model.layers.1.mlp.experts.107.up_gate_proj.weight', 'model.layers.1.mlp.experts.108.up_gate_proj.weight', 'model.layers.1.mlp.experts.109.up_gate_proj.weight', 'model.layers.1.mlp.experts.110.up_gate_proj.weight', 'model.layers.1.mlp.experts.111.up_gate_proj.weight', 'model.layers.1.mlp.experts.112.up_gate_proj.weight', 'model.layers.1.mlp.experts.113.up_gate_proj.weight', 'model.layers.1.mlp.experts.114.up_gate_proj.weight', 'model.layers.1.mlp.experts.115.up_gate_proj.weight', 'model.layers.1.mlp.experts.116.up_gate_proj.weight', 'model.layers.1.mlp.experts.117.up_gate_proj.weight', 'model.layers.1.mlp.experts.118.up_gate_proj.weight', 'model.layers.1.mlp.experts.119.up_gate_proj.weight', 'model.layers.1.mlp.experts.120.up_gate_proj.weight', 'model.layers.1.mlp.experts.121.up_gate_proj.weight', 'model.layers.1.mlp.experts.122.up_gate_proj.weight', 'model.layers.1.mlp.experts.123.up_gate_proj.weight', 'model.layers.1.mlp.experts.124.up_gate_proj.weight', 'model.layers.1.mlp.experts.125.up_gate_proj.weight', 'model.layers.1.mlp.experts.126.up_gate_proj.weight', 'model.layers.1.mlp.experts.127.up_gate_proj.weight'] +model.layers.1.mlp.gate.e_score_correction_bias +model.layers.1.mlp.gate.e_score_correction_bias:model.layers.1.mlp.gate.e_score_correction_bias +model.layers.1.mlp.gate.weight +model.layers.1.mlp.gate.weight:model.layers.1.mlp.gate.weight +model.layers.1.mlp.shared_experts.down_proj.weight +model.layers.1.mlp.shared_experts.down_proj.weight:model.layers.1.mlp.shared_experts.down_proj.weight +model.layers.1.mlp.shared_experts.up_gate_proj.weight +model.layers.1.mlp.shared_experts.up_gate_proj.weight:model.layers.1.mlp.shared_experts.up_gate_proj.weight +model.layers.1.post_attention_layernorm.weight +model.layers.1.post_attention_layernorm.weight:model.layers.1.post_attention_layernorm.weight +model.layers.1.self_attn.o_proj.weight +model.layers.1.self_attn.o_proj.weight:model.layers.1.self_attn.o_proj.weight +model.layers.1.self_attn.qkv_proj.bias +model.layers.1.self_attn.qkv_proj.bias:model.layers.1.self_attn.qkv_proj.bias +model.layers.1.self_attn.qkv_proj.weight +model.layers.1.self_attn.qkv_proj.weight:model.layers.1.self_attn.qkv_proj.weight +model.layers.1.shared_head.head.weight:model.layers.1.shared_head.head.weight +model.layers.1.shared_head.norm.weight:model.layers.1.shared_head.norm.weight +model.norm.weight +model.norm.weight:model.norm.weight diff --git a/tests/ci_use/GLM-45-AIR/test_rollout_model.py b/tests/ci_use/GLM-45-AIR/test_rollout_model.py index 71c94f66ac..02e2df6210 100644 --- a/tests/ci_use/GLM-45-AIR/test_rollout_model.py +++ b/tests/ci_use/GLM-45-AIR/test_rollout_model.py @@ -64,3 +64,53 @@ def test_rollout_model_with_distributed_launch(): print(stderr) assert return_code in (0, 250), f"Process exited with code {return_code}" + + +def test_rollout_model_with_distributed_launch_mtp(): + """ + test_rollout_model + """ + current_dir = os.path.dirname(os.path.abspath(__file__)) + utils_dir = os.path.join(os.path.dirname(current_dir), "utils") + rollout_script = os.path.join(utils_dir, "rollout_model.py") + baseline_path = os.path.join(current_dir, "baseline_mtp.txt") + + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "GLM-4.5-Air-Fake") + else: + model_path = "./GLM-4.5-Air-Fake" + print(f"model_path = {model_path}") + + command = [ + sys.executable, + "-m", + "paddle.distributed.launch", + "--gpus", + "0,1", + rollout_script, + "--model_path", + model_path, + "--baseline_path", + baseline_path, + "--enable_speculative_decoding", + ] + + print(f"Executing command: {' '.join(command)}") + + process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + try: + stdout, stderr = process.communicate(timeout=300) + return_code = process.returncode + except subprocess.TimeoutExpired: + process.kill() + stdout, stderr = process.communicate() + return_code = -1 + + print("\n" + "=" * 50 + " STDOUT " + "=" * 50) + print(stdout) + print("\n" + "=" * 50 + " STDERR " + "=" * 50) + print(stderr) + + assert return_code in (0, 250), f"Process exited with code {return_code}" diff --git a/tests/ci_use/utils/rollout_model.py b/tests/ci_use/utils/rollout_model.py index dc3b4aef36..ca3c89c5e2 100644 --- a/tests/ci_use/utils/rollout_model.py +++ b/tests/ci_use/utils/rollout_model.py @@ -26,6 +26,9 @@ parser.add_argument("--model_path", type=str, required=True, help="Path to the m parser.add_argument("--baseline_path", type=str, required=True, help="Path to the baseline path") parser.add_argument("--quantization", type=str, default=None, help="Quantization") parser.add_argument("--enable_mm", action="store_true", required=False, help="Flags to enable multi-modal model") +parser.add_argument( + "--enable_speculative_decoding", action="store_true", required=False, help="Flags to enable speculative decoding" +) args = parser.parse_args() # base result @@ -42,6 +45,9 @@ init_kwargs = { } if args.enable_mm: init_kwargs["enable_mm"] = True +if args.enable_speculative_decoding: + init_kwargs["speculative_method"] = "mtp" + init_kwargs["num_nextn_predict_layers"] = 1 rollout_config = RolloutModelConfig(**init_kwargs) diff --git a/tests/rl/test_rollout_model.py b/tests/rl/test_rollout_model.py index 4d720bcc2a..c5a034759b 100644 --- a/tests/rl/test_rollout_model.py +++ b/tests/rl/test_rollout_model.py @@ -205,6 +205,7 @@ def test_glm4moe_mapping_removes_gate_correction(): "model.layers.0.mlp.experts.gate_correction_bias", ], ) + dummy.speculative_decoding = False mappings = dummy.get_name_mappings_to_training() # Cover gate/experts aggregation and dropping gate_correction_bias assert "model.layers.0.mlp.experts.up_gate_proj_weight" in mappings