[BugFix] Fix v1 loader lm head fp32 (#5270)

This commit is contained in:
chen
2025-11-27 20:12:56 +08:00
committed by GitHub
parent b52ec268f7
commit 35f85baf09
8 changed files with 24 additions and 9 deletions
@@ -600,7 +600,9 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
self.lm_head.linear.weight.set_value(
self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype)
)
def compute_logits(self, hidden_states: paddle.Tensor):
logits = self.lm_head(hidden_states)