Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2026-04-23 08:21:53 +08:00
[Others] Rename tensor_parallel_degree to tensor_model_parallel_size for paddleformers 0.4.1 (#5727)
@@ -78,16 +78,16 @@ class VisionFlashAttention2(nn.Layer):
         self,
         dim: int,
         num_heads: int = 16,
-        tensor_parallel_degree: int = 1,
+        tensor_model_parallel_size: int = 1,
         tensor_parallel_rank: int = 0,
         model_format: str = "",
     ) -> None:
         super().__init__()
         self.num_heads = num_heads
-        self.tensor_parallel_degree = tensor_parallel_degree
+        self.tensor_model_parallel_size = tensor_model_parallel_size
         self.tensor_parallel_rank = tensor_parallel_rank

-        if tensor_parallel_degree > 1:
+        if tensor_model_parallel_size > 1:
             self.qkv = ColumnParallelLinear(
                 dim,
                 dim * 3,
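For context on the branch being renamed above: when tensor_model_parallel_size > 1, the qkv projection is built as a ColumnParallelLinear, i.e. each rank holds a column slice of the dim x (dim * 3) weight and computes only its share of the output features. A minimal, self-contained sketch of that idea using plain paddle ops (no distributed setup; shapes and names are illustrative, not the FastDeploy implementation):

```python
import paddle

dim, tp_size = 8, 2
x = paddle.randn([4, dim])                  # [batch, dim]
w = paddle.randn([dim, dim * 3])            # full qkv weight

# Column-parallel idea: each "rank" owns a slice of the output columns.
w_shards = paddle.split(w, tp_size, axis=-1)
partial_outs = [paddle.matmul(x, w_r) for w_r in w_shards]

# Concatenating the per-rank outputs reproduces the full projection.
full = paddle.matmul(x, w)
assert paddle.allclose(paddle.concat(partial_outs, axis=-1), full, atol=1e-5)
```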
@@ -122,7 +122,7 @@ class VisionFlashAttention2(nn.Layer):
         self.head_dim = dim // num_heads  # must added
         self.num_heads = num_heads
         self.hidden_size = dim
-        self.num_heads_per_rank = divide(self.num_heads, self.tensor_parallel_degree)
+        self.num_heads_per_rank = divide(self.num_heads, self.tensor_model_parallel_size)

     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
         weight_need_transpose = getattr(param, "weight_need_transpose", False)
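The divide helper here presumably asserts exact divisibility (as Megatron-style helpers do); the rename only changes which attribute supplies the divisor. A hedged stand-in showing the constraint this implies, namely that num_heads must be a multiple of tensor_model_parallel_size:

```python
def divide(numerator: int, denominator: int) -> int:
    # Assumed behaviour: fail loudly if heads cannot be split evenly across ranks.
    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
    return numerator // denominator

# e.g. 16 heads over tensor_model_parallel_size=2 -> 8 heads per rank
assert divide(16, 2) == 8
```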
@@ -132,7 +132,9 @@ class VisionFlashAttention2(nn.Layer):
         if load_bias:
             head_dim = self.hidden_size // self.num_heads
             shard_weight = loaded_weight[...].reshape([3, self.num_heads, head_dim])
-            shard_weight = paddle.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
+            shard_weight = paddle.split(shard_weight, self.tensor_model_parallel_size, axis=-2)[
+                self.tensor_parallel_rank
+            ]
             shard_weight = shard_weight.reshape([-1])
         else:
             shard_weight = loaded_weight[...].reshape(
@@ -143,7 +145,9 @@ class VisionFlashAttention2(nn.Layer):
                     self.head_dim,
                 ]
             )
-            shard_weight = paddle.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
+            shard_weight = paddle.split(shard_weight, self.tensor_model_parallel_size, axis=-2)[
+                self.tensor_parallel_rank
+            ]
             shard_weight = shard_weight.reshape([self.hidden_size, -1])
         shard_weight = fd_cast(shard_weight, param)
         assert param.shape == shard_weight.shape, (
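The two hunks above implement the same head-wise sharding, now keyed on tensor_model_parallel_size: the packed qkv tensor is reshaped so heads sit on their own axis, split along that axis (axis=-2) into one shard per rank, and the current rank's shard is flattened back. A small standalone sketch of the bias case with toy sizes (not the actual loader code):

```python
import paddle

num_heads, head_dim, tp_size, rank = 4, 2, 2, 0
loaded_bias = paddle.arange(3 * num_heads * head_dim, dtype="float32")  # packed qkv bias

shard = loaded_bias.reshape([3, num_heads, head_dim])          # expose the head axis
shard = paddle.split(shard, tp_size, axis=-2)[rank]            # keep this rank's heads only
shard = shard.reshape([-1])                                    # back to a flat bias

assert shard.shape == [3 * (num_heads // tp_size) * head_dim]
```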
@@ -176,7 +180,7 @@ class VisionFlashAttention2(nn.Layer):
             [
                 seq_length,
                 3,
-                self.num_heads // self.tensor_parallel_degree,
+                self.num_heads // self.tensor_model_parallel_size,
                 -1,
             ]
         )
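The forward-pass reshape above mirrors that sharding: each rank only sees its num_heads // tensor_model_parallel_size heads. For example, with num_heads=16 and tensor_model_parallel_size=2, a rank's qkv output is viewed as 3 x 8 heads (illustrative shapes only):

```python
import paddle

seq_length, num_heads, head_dim, tp_size = 32, 16, 64, 2
rank_out = paddle.randn([seq_length, 3 * (num_heads // tp_size) * head_dim])
qkv = rank_out.reshape([seq_length, 3, num_heads // tp_size, -1])
assert qkv.shape == [seq_length, 3, num_heads // tp_size, head_dim]
```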
@@ -265,13 +269,13 @@ class VisionMlp(nn.Layer):
         hidden_dim: int,
         bias: bool = False,
         hidden_act: str = "gelu",
-        tensor_parallel_degree: int = 1,
+        tensor_model_parallel_size: int = 1,
         model_format: str = "",
     ) -> None:
         super().__init__()
-        self.tensor_parallel_degree = tensor_parallel_degree
+        self.tensor_model_parallel_size = tensor_model_parallel_size

-        if self.tensor_parallel_degree > 1:
+        if self.tensor_model_parallel_size > 1:
             self.gate_proj = ColumnParallelLinear(
                 dim,
                 hidden_dim,
@@ -414,7 +418,7 @@ class DFNRopeVisionBlock(nn.Layer):
         num_heads: int,
         mlp_hidden_dim: int,
         hidden_act: str = "gelu",
-        tensor_parallel_degree: int = 1,
+        tensor_model_parallel_size: int = 1,
         tensor_parallel_rank: int = 0,
         attn_implementation: str = "sdpa",
         model_format: str = "",
@@ -432,7 +436,7 @@ class DFNRopeVisionBlock(nn.Layer):
         self.attn = VisionFlashAttention2(
             dim=dim,
             num_heads=num_heads,
-            tensor_parallel_degree=tensor_parallel_degree,
+            tensor_model_parallel_size=tensor_model_parallel_size,
             tensor_parallel_rank=tensor_parallel_rank,
             model_format=model_format,
         )
@@ -442,7 +446,7 @@ class DFNRopeVisionBlock(nn.Layer):
             hidden_dim=mlp_hidden_dim,
             bias=True,
             hidden_act=hidden_act,
-            tensor_parallel_degree=tensor_parallel_degree,
+            tensor_model_parallel_size=tensor_model_parallel_size,
             model_format=model_format,
         )

@@ -558,7 +562,7 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
             num_heads=config.vision_config.num_heads,
             mlp_hidden_dim=config.vision_config.intermediate_size,
             hidden_act=config.vision_config.hidden_act,
-            tensor_parallel_degree=config.pretrained_config.tensor_model_parallel_size,
+            tensor_model_parallel_size=config.pretrained_config.tensor_model_parallel_size,
             tensor_parallel_rank=config.pretrained_config.tensor_parallel_rank,
             model_format=model_format,
         )

@@ -388,7 +388,7 @@ class Qwen2_5_VLPretrainedModel(PretrainedModel):

         fn = split_or_merge_func_v1(
             is_split=is_split,
-            tensor_parallel_degree=config.tensor_model_parallel_size,
+            tensor_model_parallel_size=config.tensor_model_parallel_size,
             tensor_parallel_rank=config.tensor_parallel_rank,
             num_attention_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
@@ -397,7 +397,7 @@ class Qwen2_5_VLPretrainedModel(PretrainedModel):

         vision_fn = split_or_merge_func_v1(
             is_split=is_split,
-            tensor_parallel_degree=config.tensor_model_parallel_size,
+            tensor_model_parallel_size=config.tensor_model_parallel_size,
             tensor_parallel_rank=config.tensor_parallel_rank,
             num_attention_heads=config.vision_config.get("num_heads"),
             num_key_value_heads=config.vision_config.get("num_heads"),
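Both split_or_merge_func_v1 call sites change only the keyword name; the value passed (config.tensor_model_parallel_size) was already the new-style attribute. Downstream code that must run against both paddleformers releases could detect which keyword the installed version accepts. This is a hypothetical compatibility sketch, not something this commit adds:

```python
import inspect

def tp_size_kwarg(fn) -> str:
    # Return whichever TP-size keyword the callable actually accepts.
    params = inspect.signature(fn).parameters
    return ("tensor_model_parallel_size"
            if "tensor_model_parallel_size" in params
            else "tensor_parallel_degree")

# Demo with a stand-in for an old-style (pre-0.4.1) API:
def old_style(is_split, tensor_parallel_degree, tensor_parallel_rank):
    return tensor_parallel_degree

kw = tp_size_kwarg(old_style)                       # -> "tensor_parallel_degree"
assert old_style(is_split=True, tensor_parallel_rank=0, **{kw: 2}) == 2
```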