mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[RL]Resolve shape mismatch problems in RL-related modules (#5032)
* RL fix * update
This commit is contained in:
@@ -62,10 +62,13 @@ def load_weights_from_cache(model, weights_iterator):
|
||||
logger.info(f"{loaded_weight_name} is not in model parameters.")
|
||||
continue
|
||||
param = params_dict[loaded_weight_name]
|
||||
if param.shape != loaded_weight.shape:
|
||||
raise ValueError(
|
||||
f"Shape mismatch between loaded weight {loaded_weight_name}: {loaded_weight.shape}, expected shape: {param.shape}"
|
||||
)
|
||||
param.copy_(loaded_weight, False)
|
||||
if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False):
|
||||
model.lm_head.linear.weight.set_value(loaded_weight)
|
||||
model.lm_head.process_weights_after_loading()
|
||||
model.lm_head.linear.weight.set_value(loaded_weight.transpose([1, 0]))
|
||||
for _, model_sublayer in model.named_sublayers():
|
||||
if isinstance(model_sublayer, KVBatchLinear):
|
||||
model_sublayer.process_weights_after_loading()
|
||||
@@ -107,7 +110,6 @@ def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
|
||||
|
||||
weight_cache_context = multi_switch_config_context(
|
||||
(fd_config.quant_config, "is_checkpoint_bf16", False),
|
||||
(fd_config.model_config, "model_format", "paddle"),
|
||||
)
|
||||
|
||||
return enable_cache, weight_cache_dir, weight_cache_context
|
||||
|
||||
Reference in New Issue
Block a user