[Optimization] Refine row parallel bias and nranks and moe all_reduce (#5247)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

* rename nranks to tp_size and fix bias in v1 loader

* fix

* update
This commit is contained in:
Yuanle Liu
2025-11-26 21:09:09 +08:00
committed by GitHub
parent bf30f45738
commit cb56d46694
20 changed files with 52 additions and 112 deletions
+32 -39
View File
@@ -79,7 +79,6 @@ class UnquantizedLinearMethod(QuantMethodBase):
layer.weight.set_value(weights)
def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor:
linear_out = paddle.matmul(x, layer.weight)
if layer.with_bias:
linear_out = paddle.add(linear_out, layer.bias)
@@ -423,9 +422,9 @@ class ColumnParallelLinear(LinearBase):
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
self.fd_config = fd_config
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.input_size = input_size
self.output_size = divide(output_size, self.nranks) # Split the output_size using TP inference.
self.output_size = divide(output_size, self.tp_size) # Split the output_size using TP inference.
self.hidden_size = fd_config.model_config.hidden_size
super().__init__(
@@ -449,7 +448,7 @@ class ColumnParallelLinear(LinearBase):
model_format=fd_config.model_config.model_format,
)
if self.nranks > 0:
if self.tp_size > 0:
if self.with_bias:
# col parallel
_set_var_distributed(self.bias, split_axis=1)
@@ -492,7 +491,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"""
self.activation = activation
self.hidden_size = fd_config.model_config.hidden_size
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.output_size = output_size
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
@@ -522,8 +521,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Loaded weight is already fused on disk.
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("gate", 0, output_size * self.nranks // 2),
("up", output_size * self.nranks // 2, output_size * self.nranks // 2),
("gate", 0, output_size * self.tp_size // 2),
("up", output_size * self.tp_size // 2, output_size * self.tp_size // 2),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = slice_fn(
@@ -537,13 +536,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if self.nranks > 1 and output_dim is not None:
if self.tp_size > 1 and output_dim is not None:
dim = -1 if output_dim else 0
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
size = loaded_weight.shape[dim]
else:
size = loaded_weight.get_shape()[dim]
block_size = size // self.nranks
block_size = size // self.tp_size
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
@@ -635,15 +634,15 @@ class QKVParallelLinear(ColumnParallelLinear):
self.kv_num_heads = fd_config.model_config.num_key_value_heads if kv_num_heads is None else kv_num_heads
self.hidden_size = fd_config.model_config.hidden_size if hidden_size is None else hidden_size
self.head_dim = fd_config.model_config.head_dim if head_dim is None else head_dim
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
self.num_heads_per_rank = divide(self.num_heads, self.nranks)
if self.kv_num_heads < self.nranks and self.nranks % self.kv_num_heads == 0:
self.num_heads_per_rank = divide(self.num_heads, self.tp_size)
if self.kv_num_heads < self.tp_size and self.tp_size % self.kv_num_heads == 0:
self.kv_num_heads_per_rank = 1
self.num_kv_head_replicas = divide(self.nranks, self.kv_num_heads)
output_size = (self.num_heads + 2 * self.nranks) * self.head_dim
self.num_kv_head_replicas = divide(self.tp_size, self.kv_num_heads)
output_size = (self.num_heads + 2 * self.tp_size) * self.head_dim
else:
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.tp_size)
self.num_kv_head_replicas = 1
output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
input_size = self.hidden_size
@@ -697,7 +696,7 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if self.nranks > 1 and output_dim is not None:
if self.tp_size > 1 and output_dim is not None:
block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
shard_offset = shard_id * block_size
@@ -750,10 +749,10 @@ class QKVParallelLinear(ColumnParallelLinear):
k_tensor = get_tensor(state_dict.pop(k_weight_key))
v_tensor = get_tensor(state_dict.pop(v_weight_key))
if self.kv_num_heads < self.nranks:
if self.kv_num_heads < self.tp_size:
sharedkv_index = (
self.fd_config.parallel_config.tensor_parallel_rank * self.kv_num_heads
) // self.nranks
) // self.tp_size
sharedkv_start = sharedkv_index * self.head_dim
sharedkv_end = sharedkv_start + self.head_dim
k_tensor = k_tensor[:, sharedkv_start:sharedkv_end]
@@ -767,10 +766,7 @@ class QKVParallelLinear(ColumnParallelLinear):
)
weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0])
if self.fd_config.quant_config:
self.quant_method.process_loaded_weights(self, weight_tensor)
else:
self.weight.set_value(weight_tensor)
self.quant_method.process_loaded_weights(self, weight_tensor)
def load_state_dict(self, state_dict: dict):
"""
@@ -846,10 +842,8 @@ class RowParallelLinear(LinearBase):
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
self.fd_config = fd_config
self.skip_quant = False
self.ep_size = fd_config.parallel_config.expert_parallel_size
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.tp_group = fd_config.parallel_config.tp_group
self.hidden_size = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim
@@ -863,7 +857,7 @@ class RowParallelLinear(LinearBase):
if self.split_token:
self.input_size = input_size
else:
self.input_size = divide(input_size, self.nranks)
self.input_size = divide(input_size, self.tp_size)
self.output_size = output_size
super().__init__(
@@ -876,8 +870,7 @@ class RowParallelLinear(LinearBase):
skip_quant=skip_quant,
weight_dtype=weight_dtype,
)
if add_bias:
assert with_bias, "with_bias must be True when add_bias is True."
assert self.quant_method is not None
create_weight_kwargs = dict(
layer=self,
@@ -887,12 +880,17 @@ class RowParallelLinear(LinearBase):
),
model_format=fd_config.model_config.model_format,
)
if self.nranks > 0:
if self.tp_size > 1:
create_weight_kwargs["split_axis"] = 0
create_weight_kwargs["is_distributed"] = True
self.quant_method.create_weights(**create_weight_kwargs)
self.reduce_results = reduce_results
self.reduce_results = reduce_results and not self.split_token
if add_bias:
assert with_bias, "with_bias must be True when add_bias is True."
if self.tp_size > 1 and self.reduce_results:
set_weight_attrs(self.bias, {"tp_row_bias": True})
def all2all_transpose(self, x: paddle.Tensor) -> paddle.Tensor:
token_num = x.shape[0]
@@ -912,15 +910,11 @@ class RowParallelLinear(LinearBase):
if self.split_token:
x = self.all2all_transpose(x)
if self.fd_config.quant_config:
out = self.quant_method.apply(self, x)
else:
out = paddle.matmul(x, self.weight)
out = self.quant_method.apply(self, x)
if self.reduce_results and self.nranks > 1 and not self.split_token:
if self.reduce_results and self.tp_size > 1:
out = tensor_model_parallel_all_reduce(out, self.tp_group)
if not self.fd_config.quant_config and self.add_bias:
out = paddle.add(out, self.bias)
return out
@@ -950,16 +944,15 @@ class KVBatchLinear(nn.Layer):
qk_nope_head_dim (int): Dimension for Q/K projection (nope part). Defaults to None.
v_head_dim (int): Dimension for V projection. Defaults to None.
with_bias (bool): Whether to include bias or not. Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
super().__init__()
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.kv_lora_rank = kv_lora_rank
self.num_attention_heads = num_attention_heads
self.qk_nope_head_dim = qk_nope_head_dim
self.v_head_dim = v_head_dim
# Split num_attention_heads when using TP inference.
self.num_heads_per_partition = divide(num_attention_heads, self.nranks)
self.num_heads_per_partition = divide(num_attention_heads, self.tp_size)
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
self.fd_config = fd_config
self.kv_b_proj = kv_b_proj