[v1 loader] qwen offline FP8 (#4036)

* support offline fp8

* update ut

* update ut

* update ut

* fix

* update

* update
This commit is contained in:
bukejiyu
2025-09-15 13:44:11 +08:00
committed by GitHub
parent b1a5b756a3
commit 29ed617f0f
21 changed files with 440 additions and 138 deletions
@@ -527,6 +527,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
from fastdeploy.model_executor.utils import (
default_weight_loader,
process_weights_after_loading,
rename_offline_ckpt_suffix_to_fd_suffix,
)
general_params_mapping = [
@@ -564,15 +565,20 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
param_down_proj_name="experts.down_proj_",
num_experts_start_offset=num_experts_start_offset,
)
all_param_mapping = general_params_mapping + expert_params_mapping
all_param_mapping = [
(param, weight, exp, shard, False) for param, weight, exp, shard in general_params_mapping
] + [(param, weight, exp, shard, True) for param, weight, exp, shard in expert_params_mapping]
checkpoint_to_fd_key_fn = rename_offline_ckpt_suffix_to_fd_suffix(
fd_config=self.fd_config, ckpt_weight_suffix="quant_weight", ckpt_scale_suffix="weight_scale"
)
params_dict = dict(self.named_parameters())
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
for loaded_weight_name, loaded_weight in weights_iterator:
loaded_weight_name = loaded_weight_name.replace("model", "ernie")
for param_name, weight_name, exp_id, shard_id in all_param_mapping:
for param_name, weight_name, exp_id, shard_id, is_moe in all_param_mapping:
loaded_weight_name = checkpoint_to_fd_key_fn(loaded_weight_name, is_moe)
model_param_name = loaded_weight_name.replace(weight_name, param_name)
if model_param_name not in params_dict:
continue
@@ -583,6 +589,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
else:
expert_id = None
shard_id = None
loaded_weight_name = checkpoint_to_fd_key_fn(loaded_weight_name, is_moe=False)
model_param_name = loaded_weight_name
if model_param_name not in params_dict.keys():
continue
@@ -193,16 +193,16 @@ class VisionFlashAttention2(nn.Layer):
self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
self.proj = nn.Linear(dim, dim)
set_weight_attrs(self.qkv.weight, {"model_format": model_format})
set_weight_attrs(self.proj.weight, {"model_format": model_format})
set_weight_attrs(self.qkv.weight, {"weight_need_transpose": model_format == "torch"})
set_weight_attrs(self.proj.weight, {"weight_need_transpose": model_format == "torch"})
self.head_dim = dim // num_heads # must added
self.num_heads = num_heads
self.hidden_size = dim
self.num_heads_per_rank = divide(self.num_heads, self.tensor_parallel_degree)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
model_format = getattr(param, "model_format", "")
if model_format == "torch":
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if weight_need_transpose:
loaded_weight = loaded_weight.transpose([1, 0])
load_bias = getattr(param, "load_bias", None)
if load_bias:
@@ -358,8 +358,8 @@ class VisionMlp(nn.Layer):
self.fc1 = nn.Linear(dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, dim)
set_weight_attrs(self.fc1.weight, {"model_format": model_format})
set_weight_attrs(self.fc2.weight, {"model_format": model_format})
set_weight_attrs(self.fc1.weight, {"weight_need_transpose": model_format == "torch"})
set_weight_attrs(self.fc2.weight, {"weight_need_transpose": model_format == "torch"})
self.act = ACT2FN[hidden_act]
@@ -528,8 +528,10 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
in_channels=config.vision_config.in_channels,
embed_dim=config.vision_config.embed_dim,
)
model_format = getattr(config, "model_format", "")
set_weight_attrs(self.patch_embed.proj.weight, {"model_format": model_format})
set_weight_attrs(self.patch_embed.proj.weight, {"weight_need_transpose": model_format == "torch"})
head_dim = config.vision_config.embed_dim // config.vision_config.num_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
@@ -181,8 +181,8 @@ class VariableResolutionResamplerModel(nn.Layer):
nn.Linear(self.spatial_dim, self.spatial_dim),
nn.LayerNorm(self.spatial_dim, epsilon=1e-6),
)
set_weight_attrs(self.spatial_linear[0].weight, {"model_format": config.model_format})
set_weight_attrs(self.spatial_linear[2].weight, {"model_format": config.model_format})
set_weight_attrs(self.spatial_linear[0].weight, {"weight_need_transpose": config.model_format == "torch"})
set_weight_attrs(self.spatial_linear[2].weight, {"weight_need_transpose": config.model_format == "torch"})
if self.use_temporal_conv:
self.temporal_linear = nn.Sequential(
@@ -191,12 +191,16 @@ class VariableResolutionResamplerModel(nn.Layer):
nn.Linear(self.spatial_dim, self.spatial_dim),
nn.LayerNorm(self.spatial_dim, epsilon=1e-6),
)
set_weight_attrs(self.temporal_linear[0].weight, {"model_format": config.model_format})
set_weight_attrs(self.temporal_linear[2].weight, {"model_format": config.model_format})
set_weight_attrs(
self.temporal_linear[0].weight, {"weight_need_transpose": config.model_format == "torch"}
)
set_weight_attrs(
self.temporal_linear[2].weight, {"weight_need_transpose": config.model_format == "torch"}
)
self.mlp = nn.Linear(self.spatial_dim, self.out_dim)
set_weight_attrs(self.mlp.weight, {"model_format": config.model_format})
set_weight_attrs(self.mlp.weight, {"weight_need_transpose": config.model_format == "torch"})
out_config = deepcopy(config)
out_config.hidden_size = out_dim