[RL] Resolve shape mismatch problems in RL-related modules (#5032)

* RL fix

* update
This commit is contained in:
bukejiyu
2025-11-19 11:12:48 +08:00
committed by GitHub
parent 4694ed2a43
commit a82f25ea7b
12 changed files with 61 additions and 87 deletions
+3 -2
View File
@@ -959,7 +959,7 @@ class KVBatchLinear(nn.Layer):
# Split num_attention_heads when using TP inference.
self.num_heads_per_partition = divide(num_attention_heads, self.nranks)
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
self.fd_config = fd_config
self.kv_b_proj = kv_b_proj
self.weight_dtype = self._helper.get_default_dtype()
@@ -968,7 +968,8 @@ class KVBatchLinear(nn.Layer):
self.weight_key = f"{prefix}.weight" # e.g., "kv_b_proj.weight"
def process_weights_after_loading(self):
if self.fd_config.load_config.dynamic_load_weight:
return
w = self.kv_b_proj.weight.reshape(
[
self.kv_lora_rank,