[V1 loader] Qwen25 VL support v1 loader and torch style safetensors load (#4388)

* [BugFix] qwen2.5vl enable_thinking=true and image_patch_id bug fix

* [Docs] offline infer: add apply_chat_template add_generation_prompt parameter

* [Model]qwen2.5VL support --use-cudagraph

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v2

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v3

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v4

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v5

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v6

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v7

* qwen25vl v1 loader

* qwen25vl v1 loader v2

* qwen25vl v1 loader v3

* qwen25vl v1 loader fix tp2 weight PySafeSlice

* qwen25vl v1 loader no test

* qwen25vl v1 loader add unit test

* qwen25vl v1 loader add unit test v2

* qwen25vl v1 loader add torch unit test v3

* qwen25vl v1 loader add torch unit test v4

* qwen25vl v1 loader add torch unit test v5

* qwen25vl v1 loader add torch unit test v6
This commit is contained in:
CSWYF3634076
2025-10-27 10:54:15 +08:00
committed by GitHub
parent 5c6105f4a2
commit acd331780c
8 changed files with 697 additions and 20 deletions
@@ -15,6 +15,7 @@
"""
from functools import partial
from typing import Optional
import numpy as np
import paddle
@@ -30,7 +31,8 @@ from paddle.nn.functional.flash_attention import (
)
from paddleformers.transformers.model_utils import PretrainedModel
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.layers.utils import divide, get_tensor
from fastdeploy.model_executor.utils import set_weight_attrs
from .activation import ACT2FN
from .configuration import DFNRopeVisionTransformerConfig
@@ -74,10 +76,18 @@ class VisionFlashAttention2(nn.Layer):
nn (_type_): _description_
"""
def __init__(self, dim: int, num_heads: int = 16, tensor_parallel_degree: int = 1) -> None:
def __init__(
self,
dim: int,
num_heads: int = 16,
tensor_parallel_degree: int = 1,
tensor_parallel_rank: int = 0,
model_format: str = "",
) -> None:
super().__init__()
self.num_heads = num_heads
self.tensor_parallel_degree = tensor_parallel_degree
self.tensor_parallel_rank = tensor_parallel_rank
if tensor_parallel_degree > 1:
self.qkv = ColumnParallelLinear(
@@ -96,11 +106,52 @@ class VisionFlashAttention2(nn.Layer):
input_is_parallel=True,
has_bias=True,
)
# TODO(wangyafeng) Referring to the current situation of combining ernie vl
# with the framework, it should be possible to optimize it in the future
set_weight_attrs(self.qkv.weight, {"weight_loader": self.weight_loader})
set_weight_attrs(
self.qkv.bias, {"weight_loader": self.weight_loader, "load_bias": True, "output_dim": True}
)
set_weight_attrs(self.proj.weight, {"output_dim": False})
else:
self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
self.proj = nn.Linear(dim, dim, bias_attr=True)
set_weight_attrs(self.qkv.weight, {"weight_need_transpose": model_format == "torch"})
set_weight_attrs(self.proj.weight, {"weight_need_transpose": model_format == "torch"})
self.head_dim = dim // num_heads # must added
self.num_heads = num_heads
self.hidden_size = dim
self.num_heads_per_rank = divide(self.num_heads, self.tensor_parallel_degree)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight).transpose([1, 0])
load_bias = getattr(param, "load_bias", None)
if load_bias:
head_dim = self.hidden_size // self.num_heads
shard_weight = loaded_weight[...].reshape([3, self.num_heads, head_dim])
shard_weight = np.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
shard_weight = shard_weight.reshape([-1])
else:
shard_weight = loaded_weight[...].reshape(
[
self.hidden_size,
3,
self.num_heads,
self.head_dim,
]
)
shard_weight = np.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
shard_weight = shard_weight.reshape([self.hidden_size, -1])
shard_weight = get_tensor(shard_weight)
assert param.shape == shard_weight.shape, (
f" Attempted to load weight ({shard_weight.shape}) " f"into parameter ({param.shape})"
)
param.copy_(shard_weight, False)
def forward(
self,
@@ -216,6 +267,7 @@ class VisionMlp(nn.Layer):
bias: bool = False,
hidden_act: str = "gelu",
tensor_parallel_degree: int = 1,
model_format: str = "",
) -> None:
super().__init__()
self.tensor_parallel_degree = tensor_parallel_degree
@@ -244,11 +296,23 @@ class VisionMlp(nn.Layer):
input_is_parallel=True,
has_bias=bias,
)
set_weight_attrs(self.gate_proj.weight, {"output_dim": True})
set_weight_attrs(self.up_proj.weight, {"output_dim": True})
set_weight_attrs(self.down_proj.weight, {"output_dim": False})
if bias:
set_weight_attrs(self.gate_proj.bias, {"output_dim": True})
set_weight_attrs(self.up_proj.bias, {"output_dim": True})
# set_weight_attrs(self.down_proj.bias, {"output_dim": False})
else:
self.gate_proj = nn.Linear(dim, hidden_dim, bias_attr=bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias_attr=bias)
self.down_proj = nn.Linear(hidden_dim, dim, bias_attr=bias)
set_weight_attrs(self.gate_proj.weight, {"weight_need_transpose": model_format == "torch"})
set_weight_attrs(self.up_proj.weight, {"weight_need_transpose": model_format == "torch"})
set_weight_attrs(self.down_proj.weight, {"weight_need_transpose": model_format == "torch"})
self.act = ACT2FN[hidden_act]
def forward(self, x) -> paddle.Tensor:
@@ -352,7 +416,9 @@ class DFNRopeVisionBlock(nn.Layer):
mlp_hidden_dim: int,
hidden_act: str = "gelu",
tensor_parallel_degree: int = 1,
tensor_parallel_rank: int = 0,
attn_implementation: str = "sdpa",
model_format: str = "",
) -> None:
"""_summary_
@@ -361,7 +427,6 @@ class DFNRopeVisionBlock(nn.Layer):
attn_implementation (str, optional): _description_. Defaults to "sdpa".
"""
super().__init__()
self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
@@ -369,6 +434,8 @@ class DFNRopeVisionBlock(nn.Layer):
dim=dim,
num_heads=num_heads,
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
model_format=model_format,
)
self.mlp = VisionMlp(
@@ -377,6 +444,7 @@ class DFNRopeVisionBlock(nn.Layer):
bias=True,
hidden_act=hidden_act,
tensor_parallel_degree=tensor_parallel_degree,
model_format=model_format,
)
def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_pos_emb) -> paddle.Tensor:
@@ -408,7 +476,13 @@ class PatchMerger(nn.Layer):
nn (_type_): _description_
"""
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
def __init__(
self,
dim: int,
context_dim: int,
spatial_merge_size: int = 2,
model_format: str = "",
) -> None:
"""_summary_
Args:
@@ -425,6 +499,9 @@ class PatchMerger(nn.Layer):
nn.Linear(self.hidden_size, dim, bias_attr=True),
)
set_weight_attrs(self.mlp[0].weight, {"weight_need_transpose": model_format == "torch"})
set_weight_attrs(self.mlp[2].weight, {"weight_need_transpose": model_format == "torch"})
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
"""_summary_
@@ -470,6 +547,8 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
hidden_size=config.vision_config.hidden_size,
)
model_format = getattr(config, "model_format", "")
head_dim = config.vision_config.hidden_size // config.vision_config.num_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
@@ -481,13 +560,17 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
mlp_hidden_dim=config.vision_config.intermediate_size,
hidden_act=config.vision_config.hidden_act,
tensor_parallel_degree=config.pretrained_config.tensor_parallel_degree,
tensor_parallel_rank=config.pretrained_config.tensor_parallel_rank,
model_format=model_format,
)
for _ in range(config.vision_config.depth)
]
)
self.merger = PatchMerger(
dim=config.vision_config.out_hidden_size, context_dim=config.vision_config.hidden_size
dim=config.vision_config.out_hidden_size,
context_dim=config.vision_config.hidden_size,
model_format=model_format,
)
@property
@@ -16,6 +16,7 @@
from __future__ import annotations
import re
from functools import partial
from typing import Dict, Optional, Union
@@ -182,6 +183,61 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
def name(self):
return "Qwen2_5_VLForConditionalGeneration"
@paddle.no_grad()
def load_weights(self, weights_iterator) -> None:
"""
Load model parameters from a given weights_iterator object.
Args:
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""
from fastdeploy.model_executor.utils import (
default_weight_loader,
process_weights_after_loading,
)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
# 参数变量名与权重key不同的要做映射
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("up_gate_proj", "gate_proj", "gate"),
("up_gate_proj", "up_proj", "up"),
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
]
params_dict = dict(self.named_parameters())
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
for loaded_weight_name, loaded_weight in weights_iterator:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in loaded_weight_name:
continue
model_param_name = loaded_weight_name.replace(weight_name, param_name)
if model_param_name not in params_dict:
continue
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight, shard_id)
break
else:
model_param_name = loaded_weight_name
if model_param_name not in params_dict:
continue
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight)
model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
# because we use lazy guard and is not initialized by default
if not self.lm_head.linear.weight._is_initialized():
self.lm_head.linear.weight.initialize()
self.lm_head.load_state_dict({self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight})
@paddle.no_grad()
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
"""
@@ -235,11 +291,12 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
video_mask = ids_remove_padding == self.model.video_token_id
video_token_num = video_mask.sum()
# 由于框架只有 image_features,所以目前不支持图片和视频混合
# TODO(wangyafeng) 后续考虑支持传入 video_features
if image_token_num > 0:
# Due to the fact that the framework only has image_features,
# it currently does not support mixing images and videos
# TODO(wangyafeng) Consider supporting the input of video_features in the future
if image_token_num.item() > 0:
input_embeddings[image_mask] = image_features.cast(self.model._dtype)
if video_token_num > 0:
if video_token_num.item() > 0:
input_embeddings[video_mask] = image_features.cast(self.model._dtype)
return input_embeddings