mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
[Others] Rename tensor_parallel_degree to tensor_model_parallel_size for paddleformers 0.4.1 (#5727)
This commit is contained in:
@@ -202,7 +202,7 @@ def build_expanded_keys(
|
||||
|
||||
|
||||
def gqa_qkv_split_func(
|
||||
tensor_parallel_degree,
|
||||
tensor_model_parallel_size,
|
||||
tensor_parallel_rank,
|
||||
num_attention_heads,
|
||||
num_key_value_heads,
|
||||
@@ -258,15 +258,17 @@ def gqa_qkv_split_func(
|
||||
else:
|
||||
return np.split(tensor, degree, axis=0)
|
||||
|
||||
q_list = split_tensor(q, tensor_parallel_degree)
|
||||
repeat_kv = num_key_value_heads < tensor_parallel_degree and tensor_parallel_degree % num_key_value_heads == 0
|
||||
repeat_num = tensor_parallel_degree // num_key_value_heads if repeat_kv else 1
|
||||
q_list = split_tensor(q, tensor_model_parallel_size)
|
||||
repeat_kv = (
|
||||
num_key_value_heads < tensor_model_parallel_size and tensor_model_parallel_size % num_key_value_heads == 0
|
||||
)
|
||||
repeat_num = tensor_model_parallel_size // num_key_value_heads if repeat_kv else 1
|
||||
if repeat_kv:
|
||||
k_list = split_tensor(k, num_key_value_heads)
|
||||
v_list = split_tensor(v, num_key_value_heads)
|
||||
else:
|
||||
k_list = split_tensor(k, tensor_parallel_degree)
|
||||
v_list = split_tensor(v, tensor_parallel_degree)
|
||||
k_list = split_tensor(k, tensor_model_parallel_size)
|
||||
v_list = split_tensor(v, tensor_model_parallel_size)
|
||||
|
||||
if tensor_parallel_rank is None:
|
||||
res = []
|
||||
@@ -332,9 +334,9 @@ def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim):
|
||||
|
||||
def fn(weight_list, is_column=True):
|
||||
"""fn"""
|
||||
tensor_parallel_degree = len(weight_list)
|
||||
local_num_attention_heads = num_attention_heads // tensor_parallel_degree
|
||||
local_num_key_value_heads = num_key_value_heads // tensor_parallel_degree
|
||||
tensor_model_parallel_size = len(weight_list)
|
||||
local_num_attention_heads = num_attention_heads // tensor_model_parallel_size
|
||||
local_num_key_value_heads = num_key_value_heads // tensor_model_parallel_size
|
||||
|
||||
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
|
||||
|
||||
@@ -391,7 +393,7 @@ def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim):
|
||||
|
||||
def split_or_merge_qkv_func(
|
||||
is_split,
|
||||
tensor_parallel_degree,
|
||||
tensor_model_parallel_size,
|
||||
tensor_parallel_rank,
|
||||
num_attention_heads,
|
||||
num_key_value_heads,
|
||||
@@ -402,7 +404,7 @@ def split_or_merge_qkv_func(
|
||||
"""
|
||||
if is_split:
|
||||
return gqa_qkv_split_func(
|
||||
tensor_parallel_degree=tensor_parallel_degree,
|
||||
tensor_model_parallel_size=tensor_model_parallel_size,
|
||||
tensor_parallel_rank=tensor_parallel_rank,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
@@ -418,7 +420,7 @@ def split_or_merge_qkv_func(
|
||||
|
||||
def split_or_merge_func_v1(
|
||||
is_split,
|
||||
tensor_parallel_degree,
|
||||
tensor_model_parallel_size,
|
||||
tensor_parallel_rank,
|
||||
num_attention_heads=None,
|
||||
num_key_value_heads=None,
|
||||
@@ -435,14 +437,14 @@ def split_or_merge_func_v1(
|
||||
if is_tp_row_bias:
|
||||
tensor = x[:, ...]
|
||||
if isinstance(tensor, paddle.Tensor):
|
||||
res = tensor / tensor_parallel_degree
|
||||
res = tensor / tensor_model_parallel_size
|
||||
else:
|
||||
res = paddle.to_tensor(tensor, paddle.get_default_dtype()) / tensor_parallel_degree
|
||||
res = paddle.to_tensor(tensor, paddle.get_default_dtype()) / tensor_model_parallel_size
|
||||
return res
|
||||
elif is_gqa:
|
||||
func = split_or_merge_qkv_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=tensor_parallel_degree,
|
||||
tensor_model_parallel_size=tensor_model_parallel_size,
|
||||
tensor_parallel_rank=tensor_parallel_rank,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
@@ -453,7 +455,7 @@ def split_or_merge_func_v1(
|
||||
else:
|
||||
func = split_or_merge_func(
|
||||
is_split=is_split,
|
||||
tensor_model_parallel_size=tensor_parallel_degree,
|
||||
tensor_model_parallel_size=tensor_model_parallel_size,
|
||||
tensor_parallel_rank=tensor_parallel_rank,
|
||||
num_attention_heads=num_attention_heads,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user