mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[Optimize] Qwen2.5-VL vision model with merged linear layers and unified normalization (#6037)
* [Optimize] Qwen2.5-VL vision model with merged linear layers and unified normalization * [Optimize] Qwen2.5-VL vision model with merged linear layers and unified normalization
This commit is contained in:
@@ -29,10 +29,16 @@ from paddle.nn.functional.flash_attention import (
|
||||
)
|
||||
from paddleformers.transformers.model_utils import PretrainedModel
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.activation import SiluAndMul
|
||||
from fastdeploy.model_executor.layers.linear import MergedColumnParallelLinear
|
||||
from fastdeploy.model_executor.layers.linear import (
|
||||
RowParallelLinear as FDRowParallelLinear,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.layers.utils import divide, get_tensor
|
||||
from fastdeploy.model_executor.utils import fd_cast, set_weight_attrs
|
||||
|
||||
from .activation import ACT2FN
|
||||
from .configuration import DFNRopeVisionTransformerConfig
|
||||
|
||||
|
||||
@@ -265,58 +271,43 @@ class VisionMlp(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
bias: bool = False,
|
||||
hidden_act: str = "gelu",
|
||||
tensor_model_parallel_size: int = 1,
|
||||
model_format: str = "",
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.tensor_model_parallel_size = tensor_model_parallel_size
|
||||
|
||||
if self.tensor_model_parallel_size > 1:
|
||||
self.gate_proj = ColumnParallelLinear(
|
||||
dim,
|
||||
hidden_dim,
|
||||
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||
gather_output=False,
|
||||
has_bias=bias,
|
||||
)
|
||||
self.up_gate_proj = MergedColumnParallelLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.up_gate_proj",
|
||||
input_size=dim,
|
||||
output_size=hidden_dim * 2,
|
||||
with_bias=bias,
|
||||
activation=hidden_act,
|
||||
)
|
||||
self.down_proj = FDRowParallelLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
input_size=hidden_dim,
|
||||
output_size=dim,
|
||||
with_bias=bias,
|
||||
reduce_results=True,
|
||||
)
|
||||
|
||||
self.up_proj = ColumnParallelLinear(
|
||||
dim,
|
||||
hidden_dim,
|
||||
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||
gather_output=False,
|
||||
has_bias=bias,
|
||||
)
|
||||
if bias:
|
||||
set_weight_attrs(self.up_gate_proj.bias, {"output_dim": True})
|
||||
|
||||
self.down_proj = RowParallelLinear(
|
||||
hidden_dim,
|
||||
dim,
|
||||
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||
input_is_parallel=True,
|
||||
has_bias=bias,
|
||||
)
|
||||
set_weight_attrs(self.gate_proj.weight, {"output_dim": True})
|
||||
set_weight_attrs(self.up_proj.weight, {"output_dim": True})
|
||||
set_weight_attrs(self.down_proj.weight, {"output_dim": False})
|
||||
if bias:
|
||||
set_weight_attrs(self.gate_proj.bias, {"output_dim": True})
|
||||
set_weight_attrs(self.up_proj.bias, {"output_dim": True})
|
||||
# set_weight_attrs(self.down_proj.bias, {"output_dim": False})
|
||||
|
||||
else:
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias_attr=bias)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias_attr=bias)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias_attr=bias)
|
||||
|
||||
set_weight_attrs(self.gate_proj.weight, {"weight_need_transpose": model_format == "torch"})
|
||||
set_weight_attrs(self.up_proj.weight, {"weight_need_transpose": model_format == "torch"})
|
||||
set_weight_attrs(self.down_proj.weight, {"weight_need_transpose": model_format == "torch"})
|
||||
|
||||
self.act = ACT2FN[hidden_act]
|
||||
self.act = SiluAndMul(
|
||||
fd_config=fd_config,
|
||||
bias=None,
|
||||
act_method=hidden_act,
|
||||
)
|
||||
|
||||
def forward(self, x) -> paddle.Tensor:
|
||||
"""_summary_
|
||||
@@ -327,10 +318,9 @@ class VisionMlp(nn.Layer):
|
||||
Returns:
|
||||
paddle.Tensor: _description_
|
||||
"""
|
||||
x_gate = self.gate_proj(x)
|
||||
x_gate = self.act(x_gate)
|
||||
x_up = self.up_proj(x)
|
||||
x_down = self.down_proj(x_gate * x_up)
|
||||
gate_up = self.up_gate_proj(x)
|
||||
x = self.act(gate_up)
|
||||
x_down = self.down_proj(x)
|
||||
return x_down
|
||||
|
||||
|
||||
@@ -397,6 +387,7 @@ class DFNRopeVisionBlock(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_hidden_dim: int,
|
||||
@@ -405,6 +396,7 @@ class DFNRopeVisionBlock(nn.Layer):
|
||||
tensor_parallel_rank: int = 0,
|
||||
attn_implementation: str = "sdpa",
|
||||
model_format: str = "",
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
"""_summary_
|
||||
|
||||
@@ -413,8 +405,21 @@ class DFNRopeVisionBlock(nn.Layer):
|
||||
attn_implementation (str, optional): _description_. Defaults to "sdpa".
|
||||
"""
|
||||
super().__init__()
|
||||
self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
|
||||
self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
|
||||
layer_id = int(prefix.split(sep=".")[-1])
|
||||
self.norm1 = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=dim,
|
||||
eps=1e-6,
|
||||
prefix=f"{prefix}.norm1",
|
||||
layer_id=layer_id,
|
||||
)
|
||||
self.norm2 = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=dim,
|
||||
eps=1e-6,
|
||||
prefix=f"{prefix}.norm2",
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
self.attn = VisionFlashAttention2(
|
||||
dim=dim,
|
||||
@@ -425,12 +430,14 @@ class DFNRopeVisionBlock(nn.Layer):
|
||||
)
|
||||
|
||||
self.mlp = VisionMlp(
|
||||
fd_config=fd_config,
|
||||
dim=dim,
|
||||
hidden_dim=mlp_hidden_dim,
|
||||
bias=True,
|
||||
hidden_act=hidden_act,
|
||||
tensor_model_parallel_size=tensor_model_parallel_size,
|
||||
model_format=model_format,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_pos_emb) -> paddle.Tensor:
|
||||
@@ -446,12 +453,12 @@ class DFNRopeVisionBlock(nn.Layer):
|
||||
"""
|
||||
|
||||
hidden_states = hidden_states + self.attn(
|
||||
self.norm1(hidden_states),
|
||||
self.norm1(hidden_states)[0],
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)[0])
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -464,10 +471,12 @@ class PatchMerger(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
dim: int,
|
||||
context_dim: int,
|
||||
spatial_merge_size: int = 2,
|
||||
model_format: str = "",
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
"""_summary_
|
||||
|
||||
@@ -478,7 +487,12 @@ class PatchMerger(nn.Layer):
|
||||
"""
|
||||
super().__init__()
|
||||
self.hidden_size = context_dim * (spatial_merge_size**2)
|
||||
self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
|
||||
self.ln_q = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=context_dim,
|
||||
eps=1e-6,
|
||||
prefix=f"{prefix}.ln_q",
|
||||
)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(self.hidden_size, self.hidden_size, bias_attr=True),
|
||||
nn.GELU(),
|
||||
@@ -497,7 +511,7 @@ class PatchMerger(nn.Layer):
|
||||
Returns:
|
||||
paddle.Tensor: _description_
|
||||
"""
|
||||
x = self.mlp(self.ln_q(x).reshape([-1, self.hidden_size]))
|
||||
x = self.mlp(self.ln_q(x)[0].reshape([-1, self.hidden_size]))
|
||||
|
||||
return x
|
||||
|
||||
@@ -514,7 +528,8 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
|
||||
|
||||
config_class = DFNRopeVisionTransformerConfig
|
||||
|
||||
def __init__(self, config, prefix_name: str = "") -> None:
|
||||
def __init__(self, fd_config, prefix_name: str = "") -> None:
|
||||
config = fd_config.model_config
|
||||
super().__init__(config.vision_config)
|
||||
self.spatial_merge_size = config.vision_config.spatial_merge_size
|
||||
self.prefix_name = prefix_name
|
||||
@@ -541,6 +556,7 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
|
||||
self.blocks = nn.LayerList(
|
||||
[
|
||||
DFNRopeVisionBlock(
|
||||
fd_config=fd_config,
|
||||
dim=config.vision_config.hidden_size,
|
||||
num_heads=config.vision_config.num_heads,
|
||||
mlp_hidden_dim=config.vision_config.intermediate_size,
|
||||
@@ -548,15 +564,18 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
|
||||
tensor_model_parallel_size=config.pretrained_config.tensor_model_parallel_size,
|
||||
tensor_parallel_rank=config.pretrained_config.tensor_parallel_rank,
|
||||
model_format=model_format,
|
||||
prefix=f"{self.prefix_name}.block.{layer_idx}",
|
||||
)
|
||||
for _ in range(config.vision_config.depth)
|
||||
for layer_idx in range(config.vision_config.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.merger = PatchMerger(
|
||||
fd_config,
|
||||
dim=config.vision_config.out_hidden_size,
|
||||
context_dim=config.vision_config.hidden_size,
|
||||
model_format=model_format,
|
||||
prefix=f"{self.prefix_name}.merger",
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user