[RL]Resolve shape mismatch problems in RL-related modules (#5032)

* RL fix

* update
This commit is contained in:
bukejiyu
2025-11-19 11:12:48 +08:00
committed by GitHub
parent 4694ed2a43
commit a82f25ea7b
12 changed files with 61 additions and 87 deletions
+1 -1
View File
@@ -320,7 +320,7 @@ class Qwen3ForCausalLM(ModelForCasualLM):
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings and not is_pooling_model:
self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight)
self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
@paddle.no_grad()
def set_state_dict(self, state_dict):