mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix]fix v1 loader lm head fp32 (#5270)
This commit is contained in:
@@ -68,7 +68,9 @@ def load_weights_from_cache(model, weights_iterator):
|
||||
)
|
||||
param.copy_(loaded_weight, False)
|
||||
if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False):
|
||||
model.lm_head.linear.weight.set_value(loaded_weight.transpose([1, 0]))
|
||||
model.lm_head.linear.weight.set_value(
|
||||
loaded_weight.transpose([1, 0]).astype(model.lm_head.linear.weight.dtype)
|
||||
)
|
||||
for _, model_sublayer in model.named_sublayers():
|
||||
if isinstance(model_sublayer, KVBatchLinear):
|
||||
model_sublayer.process_weights_after_loading()
|
||||
|
||||
Reference in New Issue
Block a user