[BugFix] Fix v1 loader lm head fp32 (#5270)

This commit is contained in:
chen
2025-11-27 20:12:56 +08:00
committed by GitHub
parent b52ec268f7
commit 35f85baf09
8 changed files with 24 additions and 9 deletions
+3 -1
View File
@@ -376,7 +376,9 @@ class Qwen2ForCausalLM(ModelForCasualLM):
model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
-        self.lm_head.linear.weight.set_value(self.qwen2.embed_tokens.embeddings.weight.transpose([1, 0]))
+        self.lm_head.linear.weight.set_value(
+            self.qwen2.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype)
+        )
@classmethod
def name(self):