[Others] Rename tensor_parallel_degree to tensor_model_parallel_size for paddleformers 0.4.1 (#5727)

This commit is contained in:
bukejiyu
2025-12-24 15:19:11 +08:00
committed by GitHub
parent a0fed22ddb
commit ba4b7afb3a
11 changed files with 86 additions and 76 deletions
+18 -16
View File
@@ -202,7 +202,7 @@ def build_expanded_keys(
def gqa_qkv_split_func(
tensor_parallel_degree,
tensor_model_parallel_size,
tensor_parallel_rank,
num_attention_heads,
num_key_value_heads,
@@ -258,15 +258,17 @@ def gqa_qkv_split_func(
else:
return np.split(tensor, degree, axis=0)
q_list = split_tensor(q, tensor_parallel_degree)
repeat_kv = num_key_value_heads < tensor_parallel_degree and tensor_parallel_degree % num_key_value_heads == 0
repeat_num = tensor_parallel_degree // num_key_value_heads if repeat_kv else 1
q_list = split_tensor(q, tensor_model_parallel_size)
repeat_kv = (
num_key_value_heads < tensor_model_parallel_size and tensor_model_parallel_size % num_key_value_heads == 0
)
repeat_num = tensor_model_parallel_size // num_key_value_heads if repeat_kv else 1
if repeat_kv:
k_list = split_tensor(k, num_key_value_heads)
v_list = split_tensor(v, num_key_value_heads)
else:
k_list = split_tensor(k, tensor_parallel_degree)
v_list = split_tensor(v, tensor_parallel_degree)
k_list = split_tensor(k, tensor_model_parallel_size)
v_list = split_tensor(v, tensor_model_parallel_size)
if tensor_parallel_rank is None:
res = []
@@ -332,9 +334,9 @@ def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim):
def fn(weight_list, is_column=True):
"""fn"""
tensor_parallel_degree = len(weight_list)
local_num_attention_heads = num_attention_heads // tensor_parallel_degree
local_num_key_value_heads = num_key_value_heads // tensor_parallel_degree
tensor_model_parallel_size = len(weight_list)
local_num_attention_heads = num_attention_heads // tensor_model_parallel_size
local_num_key_value_heads = num_key_value_heads // tensor_model_parallel_size
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
@@ -391,7 +393,7 @@ def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim):
def split_or_merge_qkv_func(
is_split,
tensor_parallel_degree,
tensor_model_parallel_size,
tensor_parallel_rank,
num_attention_heads,
num_key_value_heads,
@@ -402,7 +404,7 @@ def split_or_merge_qkv_func(
"""
if is_split:
return gqa_qkv_split_func(
tensor_parallel_degree=tensor_parallel_degree,
tensor_model_parallel_size=tensor_model_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
@@ -418,7 +420,7 @@ def split_or_merge_qkv_func(
def split_or_merge_func_v1(
is_split,
tensor_parallel_degree,
tensor_model_parallel_size,
tensor_parallel_rank,
num_attention_heads=None,
num_key_value_heads=None,
@@ -435,14 +437,14 @@ def split_or_merge_func_v1(
if is_tp_row_bias:
tensor = x[:, ...]
if isinstance(tensor, paddle.Tensor):
res = tensor / tensor_parallel_degree
res = tensor / tensor_model_parallel_size
else:
res = paddle.to_tensor(tensor, paddle.get_default_dtype()) / tensor_parallel_degree
res = paddle.to_tensor(tensor, paddle.get_default_dtype()) / tensor_model_parallel_size
return res
elif is_gqa:
func = split_or_merge_qkv_func(
is_split=is_split,
tensor_parallel_degree=tensor_parallel_degree,
tensor_model_parallel_size=tensor_model_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
@@ -453,7 +455,7 @@ def split_or_merge_func_v1(
else:
func = split_or_merge_func(
is_split=is_split,
tensor_model_parallel_size=tensor_parallel_degree,
tensor_model_parallel_size=tensor_model_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
num_attention_heads=num_attention_heads,
)