mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
@@ -556,7 +556,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
# Tensor parallelism splits the weight along the output_dim
|
||||
- if self.tp_size > 1 and output_dim is not None:
|
||||
+ if self.tp_size > 1 and output_dim is not None and not self.fd_config.load_config.is_pre_sharded:
|
||||
dim = -1 if output_dim else 0
|
||||
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
|
||||
size = loaded_weight.shape[dim]
|
||||
@@ -713,7 +713,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
# Tensor parallelism splits the weight along the output_dim
|
||||
- if self.tp_size > 1 and output_dim is not None:
|
||||
+ if self.tp_size > 1 and output_dim is not None and not self.fd_config.load_config.is_pre_sharded:
|
||||
block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
|
||||
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
|
||||
shard_offset = shard_id * block_size
|
||||
|
||||
Reference in New Issue
Block a user