[V1 Loader] Load safetensors weights in natural key order (#6006)

* sorted safetensor

* update

---------

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
bukejiyu
2026-01-13 13:27:20 +08:00
committed by GitHub
parent ad8d05a8de
commit 8061f74773
+46 -9
View File
@@ -21,7 +21,9 @@ import inspect
import json
import os
import pickle
import re
import time
from contextlib import ExitStack
from functools import wraps
from pathlib import Path
@@ -39,6 +41,10 @@ from fastdeploy.model_executor.layers.linear import KVBatchLinear
from fastdeploy.model_executor.utils import multi_switch_config_context
def natural_key(s: str):
    """Build a sort key that compares embedded digit runs numerically.

    Splitting on ``(\\d+)`` keeps the digit groups, so e.g. "layer2" sorts
    before "layer10" (plain lexicographic order would reverse them).
    """
    key = []
    for token in re.split(r"(\d+)", s):
        key.append(int(token) if token.isdigit() else token)
    return key
def pdparams_weight_iterator(paddle_file_list: list[str]):
for pdparams_file in tqdm(
paddle_file_list,
@@ -71,9 +77,12 @@ def load_weights_from_cache(model, weights_iterator):
def get_weight_iterator(model_path: str):
_, files_list, use_safetensors = get_all_weights_file(model_path)
files_list, ordered_weight_map, use_safetensors, is_key_ordered = get_all_weights_file(model_path)
if use_safetensors:
weights_iterator = safetensors_weights_iterator(files_list)
if is_key_ordered:
weights_iterator = safetensors_weights_iterator(files_list)
else:
weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map)
else:
weights_iterator = pdparams_weight_iterator(files_list)
@@ -354,6 +363,26 @@ def safetensors_weights_iterator(safe_tensor_list: list[str]):
yield name, param
def safetensors_weights_iterator_ordered(ordered_weight_map: dict[str, str]):
    """Yield ``(name, tensor)`` pairs following the key order of *ordered_weight_map*.

    *ordered_weight_map* maps each weight name to the safetensors file that
    holds it. A file handle is kept open while consecutive keys come from the
    same file and is closed before the next file is opened.

    NOTE(review): if keys from different files interleave, each switch
    re-opens a file — presumably the caller groups keys per file; confirm.
    """
    with ExitStack() as stack:
        open_path = None
        handle = None
        progress = tqdm(
            ordered_weight_map.items(),
            desc="Loading safetensors weights",
        )
        for name, path in progress:
            if path != open_path:
                # Release the previously opened file before switching.
                stack.close()
                handle = stack.enter_context(safe_open(path, framework="paddle", device="cpu"))
                open_path = path
            yield name, handle.get_tensor(name)
def fast_weights_iterator(safe_tensor_list: list[str]):
"""
paddleformers' iterator for safetensors
@@ -374,7 +403,7 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int):
"""
state_dict = {}
_, safetensor_files, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}"))
safetensor_files, _, _, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}"))
weights_iterator = safetensors_weights_iterator(safetensor_files)
for name, weight in weights_iterator:
state_dict[name] = weight.clone()
@@ -389,23 +418,31 @@ def get_all_weights_file(model_path: str):
use_safetensors = True
files_list = [str(file) for file in model_path.glob("*.pdparams") if file.name != "scheduler.pdparams"]
if len(files_list) > 0:
key_name_list = []
ordered_weight_map = {}
use_safetensors = False
# dont care about the order of the files
return files_list, {}, use_safetensors, False
else:
safe_model_path = model_path / "model.safetensors"
if safe_model_path.exists():
files_list = [str(safe_model_path)]
with safe_open(safe_model_path, framework="np", device="cpu") as f:
key_name_list = f.keys()
return key_name_list, files_list, use_safetensors
key_name_list = sorted(f.keys(), key=natural_key)
ordered_weight_map = {key: "model.safetensors" for key in key_name_list}
is_key_ordered = True
files_list = [str(safe_model_path)]
return files_list, ordered_weight_map, use_safetensors, is_key_ordered
else:
index_file = model_path / "model.safetensors.index.json"
with index_file.open("r") as f:
weight_map = json.load(f)["weight_map"]
keys = list(weight_map.keys())
is_key_ordered = keys == sorted(keys, key=natural_key)
ordered_weight_map = {
key: str(model_path / weight_map[key]) for key in sorted(weight_map.keys(), key=natural_key)
}
weight_files_in_index = {str(model_path / weight_map[name]) for name in weight_map}
key_name_list = list(weight_map.keys())
files_list = sorted(weight_files_in_index)
return key_name_list, files_list, use_safetensors
return files_list, ordered_weight_map, use_safetensors, is_key_ordered
def deal_state_dict(state_dict):