mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[V1 Loader] Load safetensors weights in natural key order (#6006)
* sorted safetensor * update --------- Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
@@ -21,7 +21,9 @@ import inspect
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import time
|
||||
from contextlib import ExitStack
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
|
||||
@@ -39,6 +41,10 @@ from fastdeploy.model_executor.layers.linear import KVBatchLinear
|
||||
from fastdeploy.model_executor.utils import multi_switch_config_context
|
||||
|
||||
|
||||
def natural_key(s: str):
    """Sort key for "natural" ordering: digit runs compare numerically.

    E.g. sorted(..., key=natural_key) puts "layer2" before "layer10",
    unlike plain lexicographic order.
    """
    key = []
    # Capturing group keeps the digit runs in the split result.
    for token in re.split(r"(\d+)", s):
        key.append(int(token) if token.isdigit() else token)
    return key
|
||||
|
||||
|
||||
def pdparams_weight_iterator(paddle_file_list: list[str]):
|
||||
for pdparams_file in tqdm(
|
||||
paddle_file_list,
|
||||
@@ -71,9 +77,12 @@ def load_weights_from_cache(model, weights_iterator):
|
||||
|
||||
|
||||
def get_weight_iterator(model_path: str):
|
||||
_, files_list, use_safetensors = get_all_weights_file(model_path)
|
||||
files_list, ordered_weight_map, use_safetensors, is_key_ordered = get_all_weights_file(model_path)
|
||||
if use_safetensors:
|
||||
weights_iterator = safetensors_weights_iterator(files_list)
|
||||
if is_key_ordered:
|
||||
weights_iterator = safetensors_weights_iterator(files_list)
|
||||
else:
|
||||
weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map)
|
||||
else:
|
||||
weights_iterator = pdparams_weight_iterator(files_list)
|
||||
|
||||
@@ -354,6 +363,26 @@ def safetensors_weights_iterator(safe_tensor_list: list[str]):
|
||||
yield name, param
|
||||
|
||||
|
||||
def safetensors_weights_iterator_ordered(ordered_weight_map: dict[str, str]):
    """Yield (key, tensor) pairs following the order of *ordered_weight_map*.

    Keeps at most one safetensors file open at a time: the handle for the
    current file is held until a key maps to a different file, at which
    point the old handle is released and the new file is opened.
    """
    with ExitStack() as stack:
        open_file = None
        handle = None

        for weight_key, weight_file in tqdm(
            ordered_weight_map.items(),
            desc="Loading safetensors weights",
        ):
            if weight_file != open_file:
                # Release the previously opened file before switching.
                # NOTE(review): if keys from different files interleave in
                # the map order, the same file may be reopened repeatedly.
                stack.close()
                handle = stack.enter_context(safe_open(weight_file, framework="paddle", device="cpu"))
                open_file = weight_file

            yield weight_key, handle.get_tensor(weight_key)
|
||||
|
||||
|
||||
def fast_weights_iterator(safe_tensor_list: list[str]):
|
||||
"""
|
||||
paddleformers' iterator for safetensors
|
||||
@@ -374,7 +403,7 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int):
|
||||
"""
|
||||
|
||||
state_dict = {}
|
||||
_, safetensor_files, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}"))
|
||||
safetensor_files, _, _, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}"))
|
||||
weights_iterator = safetensors_weights_iterator(safetensor_files)
|
||||
for name, weight in weights_iterator:
|
||||
state_dict[name] = weight.clone()
|
||||
@@ -389,23 +418,31 @@ def get_all_weights_file(model_path: str):
|
||||
use_safetensors = True
|
||||
files_list = [str(file) for file in model_path.glob("*.pdparams") if file.name != "scheduler.pdparams"]
|
||||
if len(files_list) > 0:
|
||||
key_name_list = []
|
||||
ordered_weight_map = {}
|
||||
use_safetensors = False
|
||||
# dont care about the order of the files
|
||||
return files_list, {}, use_safetensors, False
|
||||
else:
|
||||
safe_model_path = model_path / "model.safetensors"
|
||||
if safe_model_path.exists():
|
||||
files_list = [str(safe_model_path)]
|
||||
with safe_open(safe_model_path, framework="np", device="cpu") as f:
|
||||
key_name_list = f.keys()
|
||||
return key_name_list, files_list, use_safetensors
|
||||
key_name_list = sorted(f.keys(), key=natural_key)
|
||||
ordered_weight_map = {key: "model.safetensors" for key in key_name_list}
|
||||
is_key_ordered = True
|
||||
files_list = [str(safe_model_path)]
|
||||
return files_list, ordered_weight_map, use_safetensors, is_key_ordered
|
||||
else:
|
||||
index_file = model_path / "model.safetensors.index.json"
|
||||
with index_file.open("r") as f:
|
||||
weight_map = json.load(f)["weight_map"]
|
||||
keys = list(weight_map.keys())
|
||||
is_key_ordered = keys == sorted(keys, key=natural_key)
|
||||
ordered_weight_map = {
|
||||
key: str(model_path / weight_map[key]) for key in sorted(weight_map.keys(), key=natural_key)
|
||||
}
|
||||
weight_files_in_index = {str(model_path / weight_map[name]) for name in weight_map}
|
||||
key_name_list = list(weight_map.keys())
|
||||
files_list = sorted(weight_files_in_index)
|
||||
return key_name_list, files_list, use_safetensors
|
||||
return files_list, ordered_weight_map, use_safetensors, is_key_ordered
|
||||
|
||||
|
||||
def deal_state_dict(state_dict):
|
||||
|
||||
Reference in New Issue
Block a user