[BugFix] fix v1 loader lm_head fp32 (#5270)

This commit is contained in:
chen
2025-11-27 20:12:56 +08:00
committed by GitHub
parent b52ec268f7
commit 35f85baf09
8 changed files with 24 additions and 9 deletions
@@ -68,7 +68,9 @@ def load_weights_from_cache(model, weights_iterator):
)
param.copy_(loaded_weight, False)
if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False):
-                model.lm_head.linear.weight.set_value(loaded_weight.transpose([1, 0]))
+                model.lm_head.linear.weight.set_value(
+                    loaded_weight.transpose([1, 0]).astype(model.lm_head.linear.weight.dtype)
+                )
for _, model_sublayer in model.named_sublayers():
if isinstance(model_sublayer, KVBatchLinear):
model_sublayer.process_weights_after_loading()