mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Loader] add multi-thread model loading (#6877)

* multi-thread-loader
* fix ut
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import concurrent.futures
|
||||
import contextlib
|
||||
import copy
|
||||
import hashlib
|
||||
@@ -26,8 +27,11 @@ import time
|
||||
from contextlib import ExitStack
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
import safetensors
|
||||
from paddleformers.transformers import PretrainedModel
|
||||
from paddleformers.transformers.model_utils import load_tp_checkpoint
|
||||
from paddleformers.utils.log import logger
|
||||
@@ -36,10 +40,12 @@ from safetensors import safe_open
|
||||
from tqdm import tqdm
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.config import FDConfig, LoadConfig
|
||||
from fastdeploy.model_executor.layers.linear import KVBatchLinear
|
||||
from fastdeploy.model_executor.utils import multi_switch_config_context
|
||||
|
||||
DEFAULT_NUM_THREADS = 8
|
||||
|
||||
|
||||
def natural_key(s: str):
    """Build a sort key that orders embedded integers numerically (natural sort)."""
    key = []
    for token in re.split(r"(\d+)", s):
        key.append(int(token) if token.isdigit() else token)
    return key
|
||||
@@ -111,13 +117,21 @@ def get_model_path(fd_config: FDConfig):
|
||||
return model_path
|
||||
|
||||
|
||||
def get_weight_iterator(model_path: str):
|
||||
def get_weight_iterator(model_path: str, load_config: Optional[LoadConfig] = None):
|
||||
files_list, ordered_weight_map, use_safetensors, is_layers_are_grouped = get_all_weights_file(model_path)
|
||||
if use_safetensors:
|
||||
if is_layers_are_grouped:
|
||||
weights_iterator = safetensors_weights_iterator(files_list)
|
||||
extra_config = load_config.model_loader_extra_config if load_config else None
|
||||
if extra_config is not None and extra_config.get("enable_multithread_load", False):
|
||||
weights_iterator = multi_thread_safetensors_weights_iterator(
|
||||
files_list,
|
||||
max_workers=extra_config.get("num_threads", DEFAULT_NUM_THREADS),
|
||||
disable_mmap=extra_config.get("disable_mmap", False),
|
||||
)
|
||||
else:
|
||||
weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map)
|
||||
if is_layers_are_grouped:
|
||||
weights_iterator = safetensors_weights_iterator(files_list)
|
||||
else:
|
||||
weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map)
|
||||
else:
|
||||
weights_iterator = pdparams_weight_iterator(files_list)
|
||||
|
||||
@@ -401,6 +415,52 @@ def safetensors_weights_iterator(safe_tensor_list: list[str]):
|
||||
yield name, param
|
||||
|
||||
|
||||
def multi_thread_safetensors_weights_iterator(
    safe_tensor_list: list[str], max_workers: int = 4, disable_mmap: bool = False
):
    """
    Iterate over safetensors weights using multi-threaded loading.

    Args:
        safe_tensor_list: List of safetensors file paths to load.
        max_workers: Maximum number of threads for concurrent loading. Defaults to 4.
        disable_mmap: If True, load files into memory directly instead of using memory-mapped
            files. Useful when mmap is not supported or causes issues.

    Yields:
        Tuple[str, paddle.Tensor]: Weight name and corresponding tensor.
    """
    # Show the progress bar only on rank 0; get_rank() may raise when no
    # distributed context is initialized, in which case we show it anyway.
    try:
        enable_tqdm = dist.get_rank() == 0
    except Exception:
        enable_tqdm = True

    def _load_file(st_file: str):
        # Load a single shard onto CPU. Reading the whole file into memory
        # first bypasses mmap when the caller asked for that.
        if disable_mmap:
            with open(st_file, "rb") as f:
                return safetensors.paddle.load(f.read())
        return safetensors.paddle.load_file(st_file, device="cpu")

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_load_file, st_file) for st_file in safe_tensor_list]

        # tqdm's `disable` flag already covers the non-rank-0 case, so a
        # separate branch duplicating the as_completed() call is unnecessary.
        futures_iter = tqdm(
            concurrent.futures.as_completed(futures),
            total=len(safe_tensor_list),
            desc="Multi-thread loading shards",
            disable=not enable_tqdm,
        )

        # NOTE: shards are yielded in completion order, not input order.
        for future in futures_iter:
            state_dict = future.result()
            for name, param in state_dict.items():
                yield name, param
|
||||
|
||||
|
||||
def safetensors_weights_iterator_ordered(ordered_weight_map: dict[str, str]):
|
||||
"""
|
||||
safetensors_weights_iterator_ordered
|
||||
|
||||
@@ -57,7 +57,7 @@ class DefaultModelLoaderV1(BaseModelLoader):
|
||||
@measure_time()
|
||||
def load_weights(self, model, fd_config: FDConfig, enable_cache: bool = False) -> None:
|
||||
model_path = get_model_path(fd_config)
|
||||
weights_iterator = get_weight_iterator(model_path)
|
||||
weights_iterator = get_weight_iterator(model_path, fd_config.load_config)
|
||||
if enable_cache:
|
||||
load_weights_from_cache(model, weights_iterator)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user