[Cherry-Pick][Loader]Fix bug in MTP weight loading #5744 (#5746)

* fix mtp

* fix

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
bukejiyu
2025-12-25 14:21:23 +08:00
committed by GitHub
parent 47ffaa41b1
commit e627e13808
4 changed files with 23 additions and 46 deletions
@@ -625,7 +625,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
)
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
if getattr(self, "tie_word_embeddings", False):
self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
def compute_logits(self, hidden_states: paddle.Tensor):
@@ -16,7 +16,6 @@
from __future__ import annotations
import re
from functools import partial
from typing import Dict, Union
@@ -356,7 +355,6 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
self.lm_head = fd_config.speculative_config.sharing_model.lm_head
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
@classmethod
def name(self):
@@ -374,11 +372,6 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
and values are NumPy arrays or PaddlePaddle tensors.
"""
self.ernie.load_state_dict(state_dict)
# if self.tie_word_embeddings:
# self.lm_head.linear.weight.set_value(
# self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
# else:
# self.lm_head.load_state_dict(state_dict)
@paddle.no_grad()
def load_weights(self, weights_iterator) -> None:
@@ -388,45 +381,22 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
Args:
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""
from fastdeploy.model_executor.utils import (
default_weight_loader,
process_weights_after_loading,
from fastdeploy.model_executor.models.ernie4_5_moe import (
Ernie4_5_MoeForCausalLM,
)
from fastdeploy.model_executor.utils import remap_weight_keys
all_param_mapping = [
# (param_name, weight_name, expert_id, shard_id)
("embed_tokens.embeddings", "embed_tokens", None, None),
("lm_head.linear", "lm_head", None, None),
("enorm", "mtp_emb_norm.0", None, None),
("hnorm", "mtp_hidden_norm.0", None, None),
("eh_proj.linear", "mtp_linear_proj.0", None, None),
]
params_dict = dict(self.named_parameters())
shard_id = None
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
for loaded_weight_name, loaded_weight in weights_iterator:
for param_name, weight_name, exp_id, shard_id in all_param_mapping:
if weight_name not in loaded_weight_name:
continue
model_param_name = loaded_weight_name.replace(weight_name, param_name)
param = params_dict[model_param_name]
shard_id = shard_id
break
else:
if loaded_weight_name not in params_dict.keys():
continue
model_param_name = loaded_weight_name
param = params_dict[loaded_weight_name]
# Get weight loader from parameter and set weight
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight)
model_sublayer_name = re.sub(
r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name
)
process_weights_after_loading_fn(model_sublayer_name, param)
Ernie4_5_MoeForCausalLM.load_weights(
self,
remap_weight_keys(
weights_iterator,
{
"mtp_emb_norm.0": "enorm",
"mtp_hidden_norm.0": "hnorm",
"mtp_linear_proj.0": "eh_proj.linear",
},
),
)
def compute_logits(self, hidden_states: paddle.Tensor):
"""
@@ -724,7 +724,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name
)
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
if getattr(self, "tie_word_embeddings", False):
# because we use lazy guard and is not initialized by default
if not self.lm_head.linear.weight._is_initialized():
self.lm_head.linear.weight.initialize()
+7
View File
@@ -166,6 +166,13 @@ class WeightsMapper:
return self._map_name(weight_name)
def remap_weight_keys(weights_iterator, mapper: dict):
    """Lazily rename checkpoint keys while streaming (name, weight) pairs.

    For each yielded key, the first mapper entry whose old substring occurs
    in the key is applied via ``str.replace``; keys matching no entry pass
    through unchanged. Weights are never touched.

    Args:
        weights_iterator: Iterable of (weight_name, weight) pairs.
        mapper: Maps old key substrings to their replacements.

    Returns:
        A generator of (possibly renamed, weight) pairs.
    """

    def _rename(name):
        # First matching substring wins, mirroring dict insertion order.
        for old, new in mapper.items():
            if old in name:
                return name.replace(old, new)
        return name

    return ((_rename(name), weight) for name, weight in weights_iterator)
def process_weights_before_loading(
*, skip_prefixes: Optional[List[str]] = None, mapper: Optional[WeightsMapper] = None
):