mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[Optimization] Refine row parallel bias and nranks and moe all_reduce (#5247)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
* rename nranks to tp_size and fix bias in v1 loader * fix * update
This commit is contained in:
@@ -79,7 +79,6 @@ class UnquantizedLinearMethod(QuantMethodBase):
|
||||
layer.weight.set_value(weights)
|
||||
|
||||
def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor:
|
||||
|
||||
linear_out = paddle.matmul(x, layer.weight)
|
||||
if layer.with_bias:
|
||||
linear_out = paddle.add(linear_out, layer.bias)
|
||||
@@ -423,9 +422,9 @@ class ColumnParallelLinear(LinearBase):
|
||||
skip_quant (bool): Whether to skip quantization. Defaults to False.
|
||||
"""
|
||||
self.fd_config = fd_config
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_size
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_size
|
||||
self.input_size = input_size
|
||||
self.output_size = divide(output_size, self.nranks) # Split the output_size using TP inference.
|
||||
self.output_size = divide(output_size, self.tp_size) # Split the output_size using TP inference.
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
|
||||
super().__init__(
|
||||
@@ -449,7 +448,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
model_format=fd_config.model_config.model_format,
|
||||
)
|
||||
|
||||
if self.nranks > 0:
|
||||
if self.tp_size > 0:
|
||||
if self.with_bias:
|
||||
# col parallel
|
||||
_set_var_distributed(self.bias, split_axis=1)
|
||||
@@ -492,7 +491,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
"""
|
||||
self.activation = activation
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_size
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_size
|
||||
self.output_size = output_size
|
||||
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
|
||||
|
||||
@@ -522,8 +521,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
# Loaded weight is already fused on disk.
|
||||
shard_offsets = [
|
||||
# (shard_id, shard_offset, shard_size)
|
||||
("gate", 0, output_size * self.nranks // 2),
|
||||
("up", output_size * self.nranks // 2, output_size * self.nranks // 2),
|
||||
("gate", 0, output_size * self.tp_size // 2),
|
||||
("up", output_size * self.tp_size // 2, output_size * self.tp_size // 2),
|
||||
]
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
loaded_weight_shard = slice_fn(
|
||||
@@ -537,13 +536,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
# Tensor parallelism splits the weight along the output_dim
|
||||
if self.nranks > 1 and output_dim is not None:
|
||||
if self.tp_size > 1 and output_dim is not None:
|
||||
dim = -1 if output_dim else 0
|
||||
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
|
||||
size = loaded_weight.shape[dim]
|
||||
else:
|
||||
size = loaded_weight.get_shape()[dim]
|
||||
block_size = size // self.nranks
|
||||
block_size = size // self.tp_size
|
||||
shard_offset = self.local_rank * block_size
|
||||
shard_size = (self.local_rank + 1) * block_size
|
||||
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
|
||||
@@ -635,15 +634,15 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
self.kv_num_heads = fd_config.model_config.num_key_value_heads if kv_num_heads is None else kv_num_heads
|
||||
self.hidden_size = fd_config.model_config.hidden_size if hidden_size is None else hidden_size
|
||||
self.head_dim = fd_config.model_config.head_dim if head_dim is None else head_dim
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_size
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_size
|
||||
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
|
||||
self.num_heads_per_rank = divide(self.num_heads, self.nranks)
|
||||
if self.kv_num_heads < self.nranks and self.nranks % self.kv_num_heads == 0:
|
||||
self.num_heads_per_rank = divide(self.num_heads, self.tp_size)
|
||||
if self.kv_num_heads < self.tp_size and self.tp_size % self.kv_num_heads == 0:
|
||||
self.kv_num_heads_per_rank = 1
|
||||
self.num_kv_head_replicas = divide(self.nranks, self.kv_num_heads)
|
||||
output_size = (self.num_heads + 2 * self.nranks) * self.head_dim
|
||||
self.num_kv_head_replicas = divide(self.tp_size, self.kv_num_heads)
|
||||
output_size = (self.num_heads + 2 * self.tp_size) * self.head_dim
|
||||
else:
|
||||
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
|
||||
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.tp_size)
|
||||
self.num_kv_head_replicas = 1
|
||||
output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
|
||||
input_size = self.hidden_size
|
||||
@@ -697,7 +696,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
# Tensor parallelism splits the weight along the output_dim
|
||||
if self.nranks > 1 and output_dim is not None:
|
||||
if self.tp_size > 1 and output_dim is not None:
|
||||
block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
|
||||
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
|
||||
shard_offset = shard_id * block_size
|
||||
@@ -750,10 +749,10 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
k_tensor = get_tensor(state_dict.pop(k_weight_key))
|
||||
v_tensor = get_tensor(state_dict.pop(v_weight_key))
|
||||
|
||||
if self.kv_num_heads < self.nranks:
|
||||
if self.kv_num_heads < self.tp_size:
|
||||
sharedkv_index = (
|
||||
self.fd_config.parallel_config.tensor_parallel_rank * self.kv_num_heads
|
||||
) // self.nranks
|
||||
) // self.tp_size
|
||||
sharedkv_start = sharedkv_index * self.head_dim
|
||||
sharedkv_end = sharedkv_start + self.head_dim
|
||||
k_tensor = k_tensor[:, sharedkv_start:sharedkv_end]
|
||||
@@ -767,10 +766,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
)
|
||||
weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0])
|
||||
|
||||
if self.fd_config.quant_config:
|
||||
self.quant_method.process_loaded_weights(self, weight_tensor)
|
||||
else:
|
||||
self.weight.set_value(weight_tensor)
|
||||
self.quant_method.process_loaded_weights(self, weight_tensor)
|
||||
|
||||
def load_state_dict(self, state_dict: dict):
|
||||
"""
|
||||
@@ -846,10 +842,8 @@ class RowParallelLinear(LinearBase):
|
||||
skip_quant (bool): Whether to skip quantization. Defaults to False.
|
||||
"""
|
||||
self.fd_config = fd_config
|
||||
self.skip_quant = False
|
||||
self.ep_size = fd_config.parallel_config.expert_parallel_size
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_size
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_size
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.head_dim = fd_config.model_config.head_dim
|
||||
@@ -863,7 +857,7 @@ class RowParallelLinear(LinearBase):
|
||||
if self.split_token:
|
||||
self.input_size = input_size
|
||||
else:
|
||||
self.input_size = divide(input_size, self.nranks)
|
||||
self.input_size = divide(input_size, self.tp_size)
|
||||
self.output_size = output_size
|
||||
|
||||
super().__init__(
|
||||
@@ -876,8 +870,7 @@ class RowParallelLinear(LinearBase):
|
||||
skip_quant=skip_quant,
|
||||
weight_dtype=weight_dtype,
|
||||
)
|
||||
if add_bias:
|
||||
assert with_bias, "with_bias must be True when add_bias is True."
|
||||
|
||||
assert self.quant_method is not None
|
||||
create_weight_kwargs = dict(
|
||||
layer=self,
|
||||
@@ -887,12 +880,17 @@ class RowParallelLinear(LinearBase):
|
||||
),
|
||||
model_format=fd_config.model_config.model_format,
|
||||
)
|
||||
if self.nranks > 0:
|
||||
if self.tp_size > 1:
|
||||
create_weight_kwargs["split_axis"] = 0
|
||||
create_weight_kwargs["is_distributed"] = True
|
||||
self.quant_method.create_weights(**create_weight_kwargs)
|
||||
|
||||
self.reduce_results = reduce_results
|
||||
self.reduce_results = reduce_results and not self.split_token
|
||||
|
||||
if add_bias:
|
||||
assert with_bias, "with_bias must be True when add_bias is True."
|
||||
if self.tp_size > 1 and self.reduce_results:
|
||||
set_weight_attrs(self.bias, {"tp_row_bias": True})
|
||||
|
||||
def all2all_transpose(self, x: paddle.Tensor) -> paddle.Tensor:
|
||||
token_num = x.shape[0]
|
||||
@@ -912,15 +910,11 @@ class RowParallelLinear(LinearBase):
|
||||
if self.split_token:
|
||||
x = self.all2all_transpose(x)
|
||||
|
||||
if self.fd_config.quant_config:
|
||||
out = self.quant_method.apply(self, x)
|
||||
else:
|
||||
out = paddle.matmul(x, self.weight)
|
||||
out = self.quant_method.apply(self, x)
|
||||
|
||||
if self.reduce_results and self.nranks > 1 and not self.split_token:
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
out = tensor_model_parallel_all_reduce(out, self.tp_group)
|
||||
if not self.fd_config.quant_config and self.add_bias:
|
||||
out = paddle.add(out, self.bias)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -950,16 +944,15 @@ class KVBatchLinear(nn.Layer):
|
||||
qk_nope_head_dim (int): Dimension for Q/K projection (nope part). Defaults to None.
|
||||
v_head_dim (int): Dimension for V projection. Defaults to None.
|
||||
with_bias (bool): Whether to include bias or not. Defaults to False.
|
||||
skip_quant (bool): Whether to skip quantization. Defaults to False.
|
||||
"""
|
||||
super().__init__()
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_size
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_size
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
# Split num_attention_heads when using TP inference.
|
||||
self.num_heads_per_partition = divide(num_attention_heads, self.nranks)
|
||||
self.num_heads_per_partition = divide(num_attention_heads, self.tp_size)
|
||||
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
|
||||
self.fd_config = fd_config
|
||||
self.kv_b_proj = kv_b_proj
|
||||
|
||||
Reference in New Issue
Block a user