mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Others] Rename tensor_parallel_degree to tensor_model_parallel_size for paddleformers 0.4.1 (#5727)
This commit is contained in:
@@ -796,7 +796,7 @@ class Ernie4_5_MoePretrainedModel(PretrainedModel):
|
|||||||
|
|
||||||
fn = split_or_merge_func_v1(
|
fn = split_or_merge_func_v1(
|
||||||
is_split=is_split,
|
is_split=is_split,
|
||||||
tensor_parallel_degree=config.tensor_model_parallel_size,
|
tensor_model_parallel_size=config.tensor_model_parallel_size,
|
||||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||||
num_attention_heads=config.num_attention_heads,
|
num_attention_heads=config.num_attention_heads,
|
||||||
num_key_value_heads=config.num_key_value_heads,
|
num_key_value_heads=config.num_key_value_heads,
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ class Ernie4_5_MTPPretrainedModel(PretrainedModel):
|
|||||||
|
|
||||||
def gqa_qkv_split_func(
|
def gqa_qkv_split_func(
|
||||||
weight,
|
weight,
|
||||||
tensor_parallel_degree,
|
tensor_model_parallel_size,
|
||||||
tensor_parallel_rank,
|
tensor_parallel_rank,
|
||||||
num_attention_heads,
|
num_attention_heads,
|
||||||
num_key_value_heads,
|
num_key_value_heads,
|
||||||
@@ -109,9 +109,9 @@ class Ernie4_5_MTPPretrainedModel(PretrainedModel):
|
|||||||
else:
|
else:
|
||||||
return np.split(tensor, degree, axis=-1)
|
return np.split(tensor, degree, axis=-1)
|
||||||
|
|
||||||
q_list = split_tensor(q, tensor_parallel_degree)
|
q_list = split_tensor(q, tensor_model_parallel_size)
|
||||||
k_list = split_tensor(k, tensor_parallel_degree)
|
k_list = split_tensor(k, tensor_model_parallel_size)
|
||||||
v_list = split_tensor(v, tensor_parallel_degree)
|
v_list = split_tensor(v, tensor_model_parallel_size)
|
||||||
|
|
||||||
if tensor_parallel_rank is None:
|
if tensor_parallel_rank is None:
|
||||||
return [np.concatenate([q_i, k_i, v_i], axis=-1) for q_i, k_i, v_i in zip(q_list, k_list, v_list)]
|
return [np.concatenate([q_i, k_i, v_i], axis=-1) for q_i, k_i, v_i in zip(q_list, k_list, v_list)]
|
||||||
@@ -126,9 +126,9 @@ class Ernie4_5_MTPPretrainedModel(PretrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def gqa_qkv_merge_func(weight_list, num_attention_heads, num_key_value_heads, head_dim):
|
def gqa_qkv_merge_func(weight_list, num_attention_heads, num_key_value_heads, head_dim):
|
||||||
tensor_parallel_degree = len(weight_list)
|
tensor_model_parallel_size = len(weight_list)
|
||||||
num_attention_heads = num_attention_heads // tensor_parallel_degree
|
num_attention_heads = num_attention_heads // tensor_model_parallel_size
|
||||||
num_key_value_heads = num_key_value_heads // tensor_parallel_degree
|
num_key_value_heads = num_key_value_heads // tensor_model_parallel_size
|
||||||
|
|
||||||
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
|
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
|
||||||
|
|
||||||
@@ -170,7 +170,7 @@ class Ernie4_5_MTPPretrainedModel(PretrainedModel):
|
|||||||
if is_split:
|
if is_split:
|
||||||
qkv_fn = partial(
|
qkv_fn = partial(
|
||||||
gqa_qkv_split_func,
|
gqa_qkv_split_func,
|
||||||
tensor_parallel_degree=config.tensor_model_parallel_size,
|
tensor_model_parallel_size=config.tensor_model_parallel_size,
|
||||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||||
num_attention_heads=config.num_attention_heads,
|
num_attention_heads=config.num_attention_heads,
|
||||||
num_key_value_heads=config.num_key_value_heads,
|
num_key_value_heads=config.num_key_value_heads,
|
||||||
|
|||||||
@@ -159,15 +159,15 @@ class VisionFlashAttention2(nn.Layer):
|
|||||||
self,
|
self,
|
||||||
dim: int,
|
dim: int,
|
||||||
num_heads: int = 16,
|
num_heads: int = 16,
|
||||||
tensor_parallel_degree: int = 1,
|
tensor_model_parallel_size: int = 1,
|
||||||
tensor_parallel_rank: int = 0,
|
tensor_parallel_rank: int = 0,
|
||||||
model_format: str = "",
|
model_format: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.tensor_parallel_degree = tensor_parallel_degree
|
self.tensor_model_parallel_size = tensor_model_parallel_size
|
||||||
self.tensor_parallel_rank = tensor_parallel_rank
|
self.tensor_parallel_rank = tensor_parallel_rank
|
||||||
if tensor_parallel_degree > 1:
|
if tensor_model_parallel_size > 1:
|
||||||
use_fuse_matmul_bias = False if current_platform.is_maca() or current_platform.is_iluvatar() else True
|
use_fuse_matmul_bias = False if current_platform.is_maca() or current_platform.is_iluvatar() else True
|
||||||
self.qkv = ColumnParallelLinear(
|
self.qkv = ColumnParallelLinear(
|
||||||
dim,
|
dim,
|
||||||
@@ -199,7 +199,7 @@ class VisionFlashAttention2(nn.Layer):
|
|||||||
self.head_dim = dim // num_heads # must added
|
self.head_dim = dim // num_heads # must added
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.hidden_size = dim
|
self.hidden_size = dim
|
||||||
self.num_heads_per_rank = divide(self.num_heads, self.tensor_parallel_degree)
|
self.num_heads_per_rank = divide(self.num_heads, self.tensor_model_parallel_size)
|
||||||
|
|
||||||
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
|
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
|
||||||
weight_need_transpose = getattr(param, "weight_need_transpose", False)
|
weight_need_transpose = getattr(param, "weight_need_transpose", False)
|
||||||
@@ -209,7 +209,9 @@ class VisionFlashAttention2(nn.Layer):
|
|||||||
if load_bias:
|
if load_bias:
|
||||||
head_dim = self.hidden_size // self.num_heads
|
head_dim = self.hidden_size // self.num_heads
|
||||||
shard_weight = loaded_weight[...].reshape([3, self.num_heads, head_dim])
|
shard_weight = loaded_weight[...].reshape([3, self.num_heads, head_dim])
|
||||||
shard_weight = paddle.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
|
shard_weight = paddle.split(shard_weight, self.tensor_model_parallel_size, axis=-2)[
|
||||||
|
self.tensor_parallel_rank
|
||||||
|
]
|
||||||
shard_weight = shard_weight.reshape([-1])
|
shard_weight = shard_weight.reshape([-1])
|
||||||
else:
|
else:
|
||||||
shard_weight = loaded_weight[...].reshape(
|
shard_weight = loaded_weight[...].reshape(
|
||||||
@@ -220,7 +222,9 @@ class VisionFlashAttention2(nn.Layer):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
shard_weight = paddle.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
|
shard_weight = paddle.split(shard_weight, self.tensor_model_parallel_size, axis=-2)[
|
||||||
|
self.tensor_parallel_rank
|
||||||
|
]
|
||||||
shard_weight = shard_weight.reshape([self.hidden_size, -1])
|
shard_weight = shard_weight.reshape([self.hidden_size, -1])
|
||||||
shard_weight = get_tensor(shard_weight)
|
shard_weight = get_tensor(shard_weight)
|
||||||
shard_weight = fd_cast(shard_weight, param)
|
shard_weight = fd_cast(shard_weight, param)
|
||||||
@@ -252,7 +256,7 @@ class VisionFlashAttention2(nn.Layer):
|
|||||||
[
|
[
|
||||||
seq_length,
|
seq_length,
|
||||||
3,
|
3,
|
||||||
self.num_heads // self.tensor_parallel_degree,
|
self.num_heads // self.tensor_model_parallel_size,
|
||||||
-1,
|
-1,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -332,13 +336,13 @@ class VisionMlp(nn.Layer):
|
|||||||
dim: int,
|
dim: int,
|
||||||
hidden_dim: int,
|
hidden_dim: int,
|
||||||
hidden_act: str,
|
hidden_act: str,
|
||||||
tensor_parallel_degree: int = 1,
|
tensor_model_parallel_size: int = 1,
|
||||||
model_format: str = "",
|
model_format: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tensor_parallel_degree = tensor_parallel_degree
|
self.tensor_model_parallel_size = tensor_model_parallel_size
|
||||||
|
|
||||||
if self.tensor_parallel_degree > 1:
|
if self.tensor_model_parallel_size > 1:
|
||||||
self.fc1 = ColumnParallelLinear(
|
self.fc1 = ColumnParallelLinear(
|
||||||
dim,
|
dim,
|
||||||
hidden_dim,
|
hidden_dim,
|
||||||
@@ -418,7 +422,7 @@ class DFNRopeVisionBlock(nn.Layer):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
tensor_parallel_degree: int,
|
tensor_model_parallel_size: int,
|
||||||
tensor_parallel_rank: int,
|
tensor_parallel_rank: int,
|
||||||
attn_implementation: str = "sdpa",
|
attn_implementation: str = "sdpa",
|
||||||
model_format: str = "",
|
model_format: str = "",
|
||||||
@@ -437,7 +441,7 @@ class DFNRopeVisionBlock(nn.Layer):
|
|||||||
self.attn = VisionFlashAttention2(
|
self.attn = VisionFlashAttention2(
|
||||||
config.embed_dim,
|
config.embed_dim,
|
||||||
num_heads=config.num_heads,
|
num_heads=config.num_heads,
|
||||||
tensor_parallel_degree=tensor_parallel_degree,
|
tensor_model_parallel_size=tensor_model_parallel_size,
|
||||||
tensor_parallel_rank=tensor_parallel_rank,
|
tensor_parallel_rank=tensor_parallel_rank,
|
||||||
model_format=model_format,
|
model_format=model_format,
|
||||||
)
|
)
|
||||||
@@ -445,7 +449,7 @@ class DFNRopeVisionBlock(nn.Layer):
|
|||||||
dim=config.embed_dim,
|
dim=config.embed_dim,
|
||||||
hidden_dim=mlp_hidden_dim,
|
hidden_dim=mlp_hidden_dim,
|
||||||
hidden_act=config.hidden_act,
|
hidden_act=config.hidden_act,
|
||||||
tensor_parallel_degree=tensor_parallel_degree,
|
tensor_model_parallel_size=tensor_model_parallel_size,
|
||||||
model_format=model_format,
|
model_format=model_format,
|
||||||
)
|
)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|||||||
@@ -978,7 +978,7 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel):
|
|||||||
|
|
||||||
fn = split_or_merge_func_v1(
|
fn = split_or_merge_func_v1(
|
||||||
is_split=is_split,
|
is_split=is_split,
|
||||||
tensor_parallel_degree=config.tensor_model_parallel_size,
|
tensor_model_parallel_size=config.tensor_model_parallel_size,
|
||||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||||
num_attention_heads=config.num_attention_heads,
|
num_attention_heads=config.num_attention_heads,
|
||||||
num_key_value_heads=config.num_key_value_heads,
|
num_key_value_heads=config.num_key_value_heads,
|
||||||
@@ -986,7 +986,7 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel):
|
|||||||
)
|
)
|
||||||
vision_fn = split_or_merge_func_v1(
|
vision_fn = split_or_merge_func_v1(
|
||||||
is_split=is_split,
|
is_split=is_split,
|
||||||
tensor_parallel_degree=config.tensor_model_parallel_size,
|
tensor_model_parallel_size=config.tensor_model_parallel_size,
|
||||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||||
num_attention_heads=config.vision_config.get("num_heads"),
|
num_attention_heads=config.vision_config.get("num_heads"),
|
||||||
num_key_value_heads=config.vision_config.get("num_heads"),
|
num_key_value_heads=config.vision_config.get("num_heads"),
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ class VariableResolutionResamplerModel(nn.Layer):
|
|||||||
self.temporal_conv_size = temporal_conv_size
|
self.temporal_conv_size = temporal_conv_size
|
||||||
self.use_recompute_resampler = False
|
self.use_recompute_resampler = False
|
||||||
self.use_temporal_conv = True
|
self.use_temporal_conv = True
|
||||||
self.tensor_parallel_degree = config.pretrained_config.tensor_model_parallel_size
|
self.tensor_model_parallel_size = config.pretrained_config.tensor_model_parallel_size
|
||||||
self.prefix_name = prefix_name
|
self.prefix_name = prefix_name
|
||||||
|
|
||||||
# for 空间四合一
|
# for 空间四合一
|
||||||
@@ -174,7 +174,7 @@ class VariableResolutionResamplerModel(nn.Layer):
|
|||||||
has_bias=True,
|
has_bias=True,
|
||||||
fuse_matmul_bias=use_fuse_matmul_bias,
|
fuse_matmul_bias=use_fuse_matmul_bias,
|
||||||
)
|
)
|
||||||
if self.tensor_parallel_degree > 1
|
if self.tensor_model_parallel_size > 1
|
||||||
else nn.Linear(self.spatial_dim, self.spatial_dim)
|
else nn.Linear(self.spatial_dim, self.spatial_dim)
|
||||||
),
|
),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
@@ -206,7 +206,7 @@ class VariableResolutionResamplerModel(nn.Layer):
|
|||||||
out_config.hidden_size = out_dim
|
out_config.hidden_size = out_dim
|
||||||
self.after_norm = RMSNorm(out_config)
|
self.after_norm = RMSNorm(out_config)
|
||||||
|
|
||||||
if self.tensor_parallel_degree > 1:
|
if self.tensor_model_parallel_size > 1:
|
||||||
set_weight_attrs(self.spatial_linear[0].weight, {"output_dim": False})
|
set_weight_attrs(self.spatial_linear[0].weight, {"output_dim": False})
|
||||||
|
|
||||||
def spatial_conv_reshape(self, x, spatial_conv_size):
|
def spatial_conv_reshape(self, x, spatial_conv_size):
|
||||||
@@ -232,17 +232,17 @@ class VariableResolutionResamplerModel(nn.Layer):
|
|||||||
x = self.spatial_conv_reshape(x, self.spatial_conv_size)
|
x = self.spatial_conv_reshape(x, self.spatial_conv_size)
|
||||||
|
|
||||||
num_pad = 0
|
num_pad = 0
|
||||||
if self.tensor_parallel_degree > 1:
|
if self.tensor_model_parallel_size > 1:
|
||||||
num_pad = (
|
num_pad = (
|
||||||
x.shape[0] + self.tensor_parallel_degree - 1
|
x.shape[0] + self.tensor_model_parallel_size - 1
|
||||||
) // self.tensor_parallel_degree * self.tensor_parallel_degree - x.shape[0]
|
) // self.tensor_model_parallel_size * self.tensor_model_parallel_size - x.shape[0]
|
||||||
|
|
||||||
if num_pad > 0:
|
if num_pad > 0:
|
||||||
x = paddle.nn.functional.pad(x, [0, num_pad, 0, 0])
|
x = paddle.nn.functional.pad(x, [0, num_pad, 0, 0])
|
||||||
|
|
||||||
x = self.spatial_linear(x)
|
x = self.spatial_linear(x)
|
||||||
|
|
||||||
if self.tensor_parallel_degree > 1:
|
if self.tensor_model_parallel_size > 1:
|
||||||
x = AllGatherOp.apply(x)
|
x = AllGatherOp.apply(x)
|
||||||
|
|
||||||
if num_pad > 0:
|
if num_pad > 0:
|
||||||
@@ -298,13 +298,13 @@ class VariableResolutionResamplerModel(nn.Layer):
|
|||||||
|
|
||||||
def fwd_temporal(x):
|
def fwd_temporal(x):
|
||||||
num_pad = 0
|
num_pad = 0
|
||||||
if self.tensor_parallel_degree > 1:
|
if self.tensor_model_parallel_size > 1:
|
||||||
num_pad = (
|
num_pad = (
|
||||||
x.shape[0] + self.tensor_parallel_degree - 1
|
x.shape[0] + self.tensor_model_parallel_size - 1
|
||||||
) // self.tensor_parallel_degree * self.tensor_parallel_degree - x.shape[0]
|
) // self.tensor_model_parallel_size * self.tensor_model_parallel_size - x.shape[0]
|
||||||
if num_pad > 0:
|
if num_pad > 0:
|
||||||
x = paddle.nn.functional.pad(x, [0, num_pad, 0, 0])
|
x = paddle.nn.functional.pad(x, [0, num_pad, 0, 0])
|
||||||
if self.tensor_parallel_degree > 1:
|
if self.tensor_model_parallel_size > 1:
|
||||||
x = ScatterOp.apply(x, axis=0)
|
x = ScatterOp.apply(x, axis=0)
|
||||||
x = self.temporal_linear(x)
|
x = self.temporal_linear(x)
|
||||||
|
|
||||||
@@ -316,7 +316,7 @@ class VariableResolutionResamplerModel(nn.Layer):
|
|||||||
def fwd_mlp(x):
|
def fwd_mlp(x):
|
||||||
x = self.mlp(x)
|
x = self.mlp(x)
|
||||||
x = self.after_norm(x)
|
x = self.after_norm(x)
|
||||||
if self.tensor_parallel_degree > 1:
|
if self.tensor_model_parallel_size > 1:
|
||||||
x = AllGatherOp.apply(x)
|
x = AllGatherOp.apply(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|||||||
@@ -549,7 +549,7 @@ class Glm4MoePretrainedModel(PretrainedModel):
|
|||||||
|
|
||||||
fn = split_or_merge_func_v1(
|
fn = split_or_merge_func_v1(
|
||||||
is_split=is_split,
|
is_split=is_split,
|
||||||
tensor_parallel_degree=config.tensor_model_parallel_size,
|
tensor_model_parallel_size=config.tensor_model_parallel_size,
|
||||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||||
num_attention_heads=config.num_attention_heads,
|
num_attention_heads=config.num_attention_heads,
|
||||||
num_key_value_heads=config.num_key_value_heads,
|
num_key_value_heads=config.num_key_value_heads,
|
||||||
|
|||||||
@@ -78,16 +78,16 @@ class VisionFlashAttention2(nn.Layer):
|
|||||||
self,
|
self,
|
||||||
dim: int,
|
dim: int,
|
||||||
num_heads: int = 16,
|
num_heads: int = 16,
|
||||||
tensor_parallel_degree: int = 1,
|
tensor_model_parallel_size: int = 1,
|
||||||
tensor_parallel_rank: int = 0,
|
tensor_parallel_rank: int = 0,
|
||||||
model_format: str = "",
|
model_format: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.tensor_parallel_degree = tensor_parallel_degree
|
self.tensor_model_parallel_size = tensor_model_parallel_size
|
||||||
self.tensor_parallel_rank = tensor_parallel_rank
|
self.tensor_parallel_rank = tensor_parallel_rank
|
||||||
|
|
||||||
if tensor_parallel_degree > 1:
|
if tensor_model_parallel_size > 1:
|
||||||
self.qkv = ColumnParallelLinear(
|
self.qkv = ColumnParallelLinear(
|
||||||
dim,
|
dim,
|
||||||
dim * 3,
|
dim * 3,
|
||||||
@@ -122,7 +122,7 @@ class VisionFlashAttention2(nn.Layer):
|
|||||||
self.head_dim = dim // num_heads # must added
|
self.head_dim = dim // num_heads # must added
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.hidden_size = dim
|
self.hidden_size = dim
|
||||||
self.num_heads_per_rank = divide(self.num_heads, self.tensor_parallel_degree)
|
self.num_heads_per_rank = divide(self.num_heads, self.tensor_model_parallel_size)
|
||||||
|
|
||||||
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
|
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
|
||||||
weight_need_transpose = getattr(param, "weight_need_transpose", False)
|
weight_need_transpose = getattr(param, "weight_need_transpose", False)
|
||||||
@@ -132,7 +132,9 @@ class VisionFlashAttention2(nn.Layer):
|
|||||||
if load_bias:
|
if load_bias:
|
||||||
head_dim = self.hidden_size // self.num_heads
|
head_dim = self.hidden_size // self.num_heads
|
||||||
shard_weight = loaded_weight[...].reshape([3, self.num_heads, head_dim])
|
shard_weight = loaded_weight[...].reshape([3, self.num_heads, head_dim])
|
||||||
shard_weight = paddle.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
|
shard_weight = paddle.split(shard_weight, self.tensor_model_parallel_size, axis=-2)[
|
||||||
|
self.tensor_parallel_rank
|
||||||
|
]
|
||||||
shard_weight = shard_weight.reshape([-1])
|
shard_weight = shard_weight.reshape([-1])
|
||||||
else:
|
else:
|
||||||
shard_weight = loaded_weight[...].reshape(
|
shard_weight = loaded_weight[...].reshape(
|
||||||
@@ -143,7 +145,9 @@ class VisionFlashAttention2(nn.Layer):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
shard_weight = paddle.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
|
shard_weight = paddle.split(shard_weight, self.tensor_model_parallel_size, axis=-2)[
|
||||||
|
self.tensor_parallel_rank
|
||||||
|
]
|
||||||
shard_weight = shard_weight.reshape([self.hidden_size, -1])
|
shard_weight = shard_weight.reshape([self.hidden_size, -1])
|
||||||
shard_weight = fd_cast(shard_weight, param)
|
shard_weight = fd_cast(shard_weight, param)
|
||||||
assert param.shape == shard_weight.shape, (
|
assert param.shape == shard_weight.shape, (
|
||||||
@@ -176,7 +180,7 @@ class VisionFlashAttention2(nn.Layer):
|
|||||||
[
|
[
|
||||||
seq_length,
|
seq_length,
|
||||||
3,
|
3,
|
||||||
self.num_heads // self.tensor_parallel_degree,
|
self.num_heads // self.tensor_model_parallel_size,
|
||||||
-1,
|
-1,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -265,13 +269,13 @@ class VisionMlp(nn.Layer):
|
|||||||
hidden_dim: int,
|
hidden_dim: int,
|
||||||
bias: bool = False,
|
bias: bool = False,
|
||||||
hidden_act: str = "gelu",
|
hidden_act: str = "gelu",
|
||||||
tensor_parallel_degree: int = 1,
|
tensor_model_parallel_size: int = 1,
|
||||||
model_format: str = "",
|
model_format: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tensor_parallel_degree = tensor_parallel_degree
|
self.tensor_model_parallel_size = tensor_model_parallel_size
|
||||||
|
|
||||||
if self.tensor_parallel_degree > 1:
|
if self.tensor_model_parallel_size > 1:
|
||||||
self.gate_proj = ColumnParallelLinear(
|
self.gate_proj = ColumnParallelLinear(
|
||||||
dim,
|
dim,
|
||||||
hidden_dim,
|
hidden_dim,
|
||||||
@@ -414,7 +418,7 @@ class DFNRopeVisionBlock(nn.Layer):
|
|||||||
num_heads: int,
|
num_heads: int,
|
||||||
mlp_hidden_dim: int,
|
mlp_hidden_dim: int,
|
||||||
hidden_act: str = "gelu",
|
hidden_act: str = "gelu",
|
||||||
tensor_parallel_degree: int = 1,
|
tensor_model_parallel_size: int = 1,
|
||||||
tensor_parallel_rank: int = 0,
|
tensor_parallel_rank: int = 0,
|
||||||
attn_implementation: str = "sdpa",
|
attn_implementation: str = "sdpa",
|
||||||
model_format: str = "",
|
model_format: str = "",
|
||||||
@@ -432,7 +436,7 @@ class DFNRopeVisionBlock(nn.Layer):
|
|||||||
self.attn = VisionFlashAttention2(
|
self.attn = VisionFlashAttention2(
|
||||||
dim=dim,
|
dim=dim,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
tensor_parallel_degree=tensor_parallel_degree,
|
tensor_model_parallel_size=tensor_model_parallel_size,
|
||||||
tensor_parallel_rank=tensor_parallel_rank,
|
tensor_parallel_rank=tensor_parallel_rank,
|
||||||
model_format=model_format,
|
model_format=model_format,
|
||||||
)
|
)
|
||||||
@@ -442,7 +446,7 @@ class DFNRopeVisionBlock(nn.Layer):
|
|||||||
hidden_dim=mlp_hidden_dim,
|
hidden_dim=mlp_hidden_dim,
|
||||||
bias=True,
|
bias=True,
|
||||||
hidden_act=hidden_act,
|
hidden_act=hidden_act,
|
||||||
tensor_parallel_degree=tensor_parallel_degree,
|
tensor_model_parallel_size=tensor_model_parallel_size,
|
||||||
model_format=model_format,
|
model_format=model_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -558,7 +562,7 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
|
|||||||
num_heads=config.vision_config.num_heads,
|
num_heads=config.vision_config.num_heads,
|
||||||
mlp_hidden_dim=config.vision_config.intermediate_size,
|
mlp_hidden_dim=config.vision_config.intermediate_size,
|
||||||
hidden_act=config.vision_config.hidden_act,
|
hidden_act=config.vision_config.hidden_act,
|
||||||
tensor_parallel_degree=config.pretrained_config.tensor_model_parallel_size,
|
tensor_model_parallel_size=config.pretrained_config.tensor_model_parallel_size,
|
||||||
tensor_parallel_rank=config.pretrained_config.tensor_parallel_rank,
|
tensor_parallel_rank=config.pretrained_config.tensor_parallel_rank,
|
||||||
model_format=model_format,
|
model_format=model_format,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -388,7 +388,7 @@ class Qwen2_5_VLPretrainedModel(PretrainedModel):
|
|||||||
|
|
||||||
fn = split_or_merge_func_v1(
|
fn = split_or_merge_func_v1(
|
||||||
is_split=is_split,
|
is_split=is_split,
|
||||||
tensor_parallel_degree=config.tensor_model_parallel_size,
|
tensor_model_parallel_size=config.tensor_model_parallel_size,
|
||||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||||
num_attention_heads=config.num_attention_heads,
|
num_attention_heads=config.num_attention_heads,
|
||||||
num_key_value_heads=config.num_key_value_heads,
|
num_key_value_heads=config.num_key_value_heads,
|
||||||
@@ -397,7 +397,7 @@ class Qwen2_5_VLPretrainedModel(PretrainedModel):
|
|||||||
|
|
||||||
vision_fn = split_or_merge_func_v1(
|
vision_fn = split_or_merge_func_v1(
|
||||||
is_split=is_split,
|
is_split=is_split,
|
||||||
tensor_parallel_degree=config.tensor_model_parallel_size,
|
tensor_model_parallel_size=config.tensor_model_parallel_size,
|
||||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||||
num_attention_heads=config.vision_config.get("num_heads"),
|
num_attention_heads=config.vision_config.get("num_heads"),
|
||||||
num_key_value_heads=config.vision_config.get("num_heads"),
|
num_key_value_heads=config.vision_config.get("num_heads"),
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ def build_expanded_keys(
|
|||||||
|
|
||||||
|
|
||||||
def gqa_qkv_split_func(
|
def gqa_qkv_split_func(
|
||||||
tensor_parallel_degree,
|
tensor_model_parallel_size,
|
||||||
tensor_parallel_rank,
|
tensor_parallel_rank,
|
||||||
num_attention_heads,
|
num_attention_heads,
|
||||||
num_key_value_heads,
|
num_key_value_heads,
|
||||||
@@ -258,15 +258,17 @@ def gqa_qkv_split_func(
|
|||||||
else:
|
else:
|
||||||
return np.split(tensor, degree, axis=0)
|
return np.split(tensor, degree, axis=0)
|
||||||
|
|
||||||
q_list = split_tensor(q, tensor_parallel_degree)
|
q_list = split_tensor(q, tensor_model_parallel_size)
|
||||||
repeat_kv = num_key_value_heads < tensor_parallel_degree and tensor_parallel_degree % num_key_value_heads == 0
|
repeat_kv = (
|
||||||
repeat_num = tensor_parallel_degree // num_key_value_heads if repeat_kv else 1
|
num_key_value_heads < tensor_model_parallel_size and tensor_model_parallel_size % num_key_value_heads == 0
|
||||||
|
)
|
||||||
|
repeat_num = tensor_model_parallel_size // num_key_value_heads if repeat_kv else 1
|
||||||
if repeat_kv:
|
if repeat_kv:
|
||||||
k_list = split_tensor(k, num_key_value_heads)
|
k_list = split_tensor(k, num_key_value_heads)
|
||||||
v_list = split_tensor(v, num_key_value_heads)
|
v_list = split_tensor(v, num_key_value_heads)
|
||||||
else:
|
else:
|
||||||
k_list = split_tensor(k, tensor_parallel_degree)
|
k_list = split_tensor(k, tensor_model_parallel_size)
|
||||||
v_list = split_tensor(v, tensor_parallel_degree)
|
v_list = split_tensor(v, tensor_model_parallel_size)
|
||||||
|
|
||||||
if tensor_parallel_rank is None:
|
if tensor_parallel_rank is None:
|
||||||
res = []
|
res = []
|
||||||
@@ -332,9 +334,9 @@ def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim):
|
|||||||
|
|
||||||
def fn(weight_list, is_column=True):
|
def fn(weight_list, is_column=True):
|
||||||
"""fn"""
|
"""fn"""
|
||||||
tensor_parallel_degree = len(weight_list)
|
tensor_model_parallel_size = len(weight_list)
|
||||||
local_num_attention_heads = num_attention_heads // tensor_parallel_degree
|
local_num_attention_heads = num_attention_heads // tensor_model_parallel_size
|
||||||
local_num_key_value_heads = num_key_value_heads // tensor_parallel_degree
|
local_num_key_value_heads = num_key_value_heads // tensor_model_parallel_size
|
||||||
|
|
||||||
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
|
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
|
||||||
|
|
||||||
@@ -391,7 +393,7 @@ def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim):
|
|||||||
|
|
||||||
def split_or_merge_qkv_func(
|
def split_or_merge_qkv_func(
|
||||||
is_split,
|
is_split,
|
||||||
tensor_parallel_degree,
|
tensor_model_parallel_size,
|
||||||
tensor_parallel_rank,
|
tensor_parallel_rank,
|
||||||
num_attention_heads,
|
num_attention_heads,
|
||||||
num_key_value_heads,
|
num_key_value_heads,
|
||||||
@@ -402,7 +404,7 @@ def split_or_merge_qkv_func(
|
|||||||
"""
|
"""
|
||||||
if is_split:
|
if is_split:
|
||||||
return gqa_qkv_split_func(
|
return gqa_qkv_split_func(
|
||||||
tensor_parallel_degree=tensor_parallel_degree,
|
tensor_model_parallel_size=tensor_model_parallel_size,
|
||||||
tensor_parallel_rank=tensor_parallel_rank,
|
tensor_parallel_rank=tensor_parallel_rank,
|
||||||
num_attention_heads=num_attention_heads,
|
num_attention_heads=num_attention_heads,
|
||||||
num_key_value_heads=num_key_value_heads,
|
num_key_value_heads=num_key_value_heads,
|
||||||
@@ -418,7 +420,7 @@ def split_or_merge_qkv_func(
|
|||||||
|
|
||||||
def split_or_merge_func_v1(
|
def split_or_merge_func_v1(
|
||||||
is_split,
|
is_split,
|
||||||
tensor_parallel_degree,
|
tensor_model_parallel_size,
|
||||||
tensor_parallel_rank,
|
tensor_parallel_rank,
|
||||||
num_attention_heads=None,
|
num_attention_heads=None,
|
||||||
num_key_value_heads=None,
|
num_key_value_heads=None,
|
||||||
@@ -435,14 +437,14 @@ def split_or_merge_func_v1(
|
|||||||
if is_tp_row_bias:
|
if is_tp_row_bias:
|
||||||
tensor = x[:, ...]
|
tensor = x[:, ...]
|
||||||
if isinstance(tensor, paddle.Tensor):
|
if isinstance(tensor, paddle.Tensor):
|
||||||
res = tensor / tensor_parallel_degree
|
res = tensor / tensor_model_parallel_size
|
||||||
else:
|
else:
|
||||||
res = paddle.to_tensor(tensor, paddle.get_default_dtype()) / tensor_parallel_degree
|
res = paddle.to_tensor(tensor, paddle.get_default_dtype()) / tensor_model_parallel_size
|
||||||
return res
|
return res
|
||||||
elif is_gqa:
|
elif is_gqa:
|
||||||
func = split_or_merge_qkv_func(
|
func = split_or_merge_qkv_func(
|
||||||
is_split=is_split,
|
is_split=is_split,
|
||||||
tensor_parallel_degree=tensor_parallel_degree,
|
tensor_model_parallel_size=tensor_model_parallel_size,
|
||||||
tensor_parallel_rank=tensor_parallel_rank,
|
tensor_parallel_rank=tensor_parallel_rank,
|
||||||
num_attention_heads=num_attention_heads,
|
num_attention_heads=num_attention_heads,
|
||||||
num_key_value_heads=num_key_value_heads,
|
num_key_value_heads=num_key_value_heads,
|
||||||
@@ -453,7 +455,7 @@ def split_or_merge_func_v1(
|
|||||||
else:
|
else:
|
||||||
func = split_or_merge_func(
|
func = split_or_merge_func(
|
||||||
is_split=is_split,
|
is_split=is_split,
|
||||||
tensor_model_parallel_size=tensor_parallel_degree,
|
tensor_model_parallel_size=tensor_model_parallel_size,
|
||||||
tensor_parallel_rank=tensor_parallel_rank,
|
tensor_parallel_rank=tensor_parallel_rank,
|
||||||
num_attention_heads=num_attention_heads,
|
num_attention_heads=num_attention_heads,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
|
|||||||
def update_fd_config_for_mm(fd_config: FDConfig) -> None:
|
def update_fd_config_for_mm(fd_config: FDConfig) -> None:
|
||||||
architectures = fd_config.model_config.architectures
|
architectures = fd_config.model_config.architectures
|
||||||
if fd_config.model_config.enable_mm and ErnieArchitectures.contains_ernie_arch(architectures):
|
if fd_config.model_config.enable_mm and ErnieArchitectures.contains_ernie_arch(architectures):
|
||||||
fd_config.model_config.tensor_parallel_degree = fd_config.parallel_config.tensor_parallel_size
|
fd_config.model_config.tensor_model_parallel_size = fd_config.parallel_config.tensor_parallel_size
|
||||||
fd_config.model_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
|
fd_config.model_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
|
||||||
fd_config.model_config.vision_config.dtype = fd_config.model_config.dtype
|
fd_config.model_config.vision_config.dtype = fd_config.model_config.dtype
|
||||||
|
|
||||||
|
|||||||
@@ -396,7 +396,7 @@ class BuildExpandedKeysTest(unittest.TestCase):
|
|||||||
class GQATensorOpsTest(unittest.TestCase):
|
class GQATensorOpsTest(unittest.TestCase):
|
||||||
def test_gqa_split_returns_all_partitions(self):
|
def test_gqa_split_returns_all_partitions(self):
|
||||||
func = _tp_utils.gqa_qkv_split_func(
|
func = _tp_utils.gqa_qkv_split_func(
|
||||||
tensor_parallel_degree=2,
|
tensor_model_parallel_size=2,
|
||||||
tensor_parallel_rank=None,
|
tensor_parallel_rank=None,
|
||||||
num_attention_heads=4,
|
num_attention_heads=4,
|
||||||
num_key_value_heads=2,
|
num_key_value_heads=2,
|
||||||
@@ -411,7 +411,7 @@ class GQATensorOpsTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_gqa_split_with_rank_and_repeat_kv(self):
|
def test_gqa_split_with_rank_and_repeat_kv(self):
|
||||||
func = _tp_utils.gqa_qkv_split_func(
|
func = _tp_utils.gqa_qkv_split_func(
|
||||||
tensor_parallel_degree=2,
|
tensor_model_parallel_size=2,
|
||||||
tensor_parallel_rank=1,
|
tensor_parallel_rank=1,
|
||||||
num_attention_heads=2,
|
num_attention_heads=2,
|
||||||
num_key_value_heads=1,
|
num_key_value_heads=1,
|
||||||
@@ -423,7 +423,7 @@ class GQATensorOpsTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_gqa_split_on_matrix_rows(self):
|
def test_gqa_split_on_matrix_rows(self):
|
||||||
func = _tp_utils.gqa_qkv_split_func(
|
func = _tp_utils.gqa_qkv_split_func(
|
||||||
tensor_parallel_degree=2,
|
tensor_model_parallel_size=2,
|
||||||
tensor_parallel_rank=None,
|
tensor_parallel_rank=None,
|
||||||
num_attention_heads=4,
|
num_attention_heads=4,
|
||||||
num_key_value_heads=2,
|
num_key_value_heads=2,
|
||||||
@@ -454,7 +454,7 @@ class GQATensorOpsTest(unittest.TestCase):
|
|||||||
def test_split_or_merge_func_v1_row_bias(self):
|
def test_split_or_merge_func_v1_row_bias(self):
|
||||||
fn = _tp_utils.split_or_merge_func_v1(
|
fn = _tp_utils.split_or_merge_func_v1(
|
||||||
is_split=True,
|
is_split=True,
|
||||||
tensor_parallel_degree=4,
|
tensor_model_parallel_size=4,
|
||||||
tensor_parallel_rank=0,
|
tensor_parallel_rank=0,
|
||||||
)
|
)
|
||||||
bias = np.ones(4, dtype=np.float32)
|
bias = np.ones(4, dtype=np.float32)
|
||||||
@@ -464,7 +464,7 @@ class GQATensorOpsTest(unittest.TestCase):
|
|||||||
def test_split_or_merge_func_v1_gqa_path(self):
|
def test_split_or_merge_func_v1_gqa_path(self):
|
||||||
fn = _tp_utils.split_or_merge_func_v1(
|
fn = _tp_utils.split_or_merge_func_v1(
|
||||||
is_split=True,
|
is_split=True,
|
||||||
tensor_parallel_degree=2,
|
tensor_model_parallel_size=2,
|
||||||
tensor_parallel_rank=None,
|
tensor_parallel_rank=None,
|
||||||
num_attention_heads=4,
|
num_attention_heads=4,
|
||||||
num_key_value_heads=2,
|
num_key_value_heads=2,
|
||||||
@@ -477,7 +477,7 @@ class GQATensorOpsTest(unittest.TestCase):
|
|||||||
def test_split_or_merge_func_v1_default_path(self):
|
def test_split_or_merge_func_v1_default_path(self):
|
||||||
fn = _tp_utils.split_or_merge_func_v1(
|
fn = _tp_utils.split_or_merge_func_v1(
|
||||||
is_split=False,
|
is_split=False,
|
||||||
tensor_parallel_degree=2,
|
tensor_model_parallel_size=2,
|
||||||
tensor_parallel_rank=None,
|
tensor_parallel_rank=None,
|
||||||
num_attention_heads=4,
|
num_attention_heads=4,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user