[loader] support wint2 backend (#6139)

* support wint2

* update
This commit is contained in:
bukejiyu
2026-02-09 14:42:36 +08:00
committed by GitHub
parent f18f3b99ed
commit dc5917289d
20 changed files with 86 additions and 11 deletions
+2 -2
View File
@@ -556,7 +556,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if self.tp_size > 1 and output_dim is not None:
if self.tp_size > 1 and output_dim is not None and not self.fd_config.load_config.is_pre_sharded:
dim = -1 if output_dim else 0
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
size = loaded_weight.shape[dim]
@@ -713,7 +713,7 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if self.tp_size > 1 and output_dim is not None:
if self.tp_size > 1 and output_dim is not None and not self.fd_config.load_config.is_pre_sharded:
block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
shard_offset = shard_id * block_size