[RL] Resolve shape mismatch problems in RL-related modules (#5032)

* RL fix

* update
This commit is contained in:
bukejiyu
2025-11-19 11:12:48 +08:00
committed by GitHub
parent 4694ed2a43
commit a82f25ea7b
12 changed files with 61 additions and 87 deletions
@@ -62,10 +62,13 @@ def load_weights_from_cache(model, weights_iterator):
logger.info(f"{loaded_weight_name} is not in model parameters.")
continue
param = params_dict[loaded_weight_name]
if param.shape != loaded_weight.shape:
raise ValueError(
f"Shape mismatch between loaded weight {loaded_weight_name}: {loaded_weight.shape}, expected shape: {param.shape}"
)
param.copy_(loaded_weight, False)
if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False):
model.lm_head.linear.weight.set_value(loaded_weight)
model.lm_head.process_weights_after_loading()
model.lm_head.linear.weight.set_value(loaded_weight.transpose([1, 0]))
for _, model_sublayer in model.named_sublayers():
if isinstance(model_sublayer, KVBatchLinear):
model_sublayer.process_weights_after_loading()
@@ -107,7 +110,6 @@ def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
weight_cache_context = multi_switch_config_context(
(fd_config.quant_config, "is_checkpoint_bf16", False),
(fd_config.model_config, "model_format", "paddle"),
)
return enable_cache, weight_cache_dir, weight_cache_context