mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
[RL]Resolve shape mismatch problems in RL-related modules (#5032)
* RL fix
* update
This commit is contained in:
@@ -959,7 +959,7 @@ class KVBatchLinear(nn.Layer):
|
||||
# Split num_attention_heads when using TP inference.
|
||||
self.num_heads_per_partition = divide(num_attention_heads, self.nranks)
|
||||
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
|
||||
|
||||
self.fd_config = fd_config
|
||||
self.kv_b_proj = kv_b_proj
|
||||
|
||||
self.weight_dtype = self._helper.get_default_dtype()
|
||||
@@ -968,7 +968,8 @@ class KVBatchLinear(nn.Layer):
|
||||
self.weight_key = f"{prefix}.weight" # e.g., "kv_b_proj.weight"
|
||||
|
||||
def process_weights_after_loading(self):
|
||||
|
||||
if self.fd_config.load_config.dynamic_load_weight:
|
||||
return
|
||||
w = self.kv_b_proj.weight.reshape(
|
||||
[
|
||||
self.kv_lora_rank,
|
||||
|
||||
Reference in New Issue
Block a user