[Loader] add multi-thread model loading (#6877)

* multi-thread-loader

* fix ut
Author:    bukejiyu
Date:      2026-04-10 14:40:15 +08:00
Committed: GitHub
Parent:    c1fb3112f8
Commit:    14d46181b8

12 changed files with 105 additions and 7 deletions
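Reviewer note: the new behavior is opt-in through `load_config.model_loader_extra_config`, and only the three keys below are consulted by this PR. A minimal sketch of enabling it (the surrounding `load_config` object is assumed here, not shown in this diff):

    # Hypothetical wiring; the three keys are the ones this PR reads.
    load_config.model_loader_extra_config = {
        "enable_multithread_load": True,   # off by default
        "num_threads": 8,                  # falls back to DEFAULT_NUM_THREADS if omitted
        "disable_mmap": False,             # True reads whole files instead of mmap
    }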
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+import concurrent.futures
 import contextlib
 import copy
 import hashlib
@@ -26,8 +27,11 @@ import time
 from contextlib import ExitStack
 from functools import wraps
 from pathlib import Path
+from typing import Optional
 
 import paddle
+import paddle.distributed as dist
+import safetensors
 from paddleformers.transformers import PretrainedModel
 from paddleformers.transformers.model_utils import load_tp_checkpoint
 from paddleformers.utils.log import logger
@@ -36,10 +40,12 @@ from safetensors import safe_open
 from tqdm import tqdm
 
 from fastdeploy import envs
-from fastdeploy.config import FDConfig
+from fastdeploy.config import FDConfig, LoadConfig
 from fastdeploy.model_executor.layers.linear import KVBatchLinear
 from fastdeploy.model_executor.utils import multi_switch_config_context
 
+DEFAULT_NUM_THREADS = 8
+
 
 def natural_key(s: str):
     return [int(t) if t.isdigit() else t for t in re.split(r"(\d+)", s)]
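Context note: `natural_key` (unchanged by this PR) splits on digit runs so shard file names sort numerically rather than lexicographically, e.g.:

    sorted(["layer10", "layer2", "layer1"], key=natural_key)
    # -> ["layer1", "layer2", "layer10"]  (a plain sort would give layer1, layer10, layer2)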
@@ -111,13 +117,21 @@ def get_model_path(fd_config: FDConfig):
     return model_path
 
 
-def get_weight_iterator(model_path: str):
+def get_weight_iterator(model_path: str, load_config: Optional[LoadConfig] = None):
     files_list, ordered_weight_map, use_safetensors, is_layers_are_grouped = get_all_weights_file(model_path)
     if use_safetensors:
-        if is_layers_are_grouped:
-            weights_iterator = safetensors_weights_iterator(files_list)
+        extra_config = load_config.model_loader_extra_config if load_config else None
+        if extra_config is not None and extra_config.get("enable_multithread_load", False):
+            weights_iterator = multi_thread_safetensors_weights_iterator(
+                files_list,
+                max_workers=extra_config.get("num_threads", DEFAULT_NUM_THREADS),
+                disable_mmap=extra_config.get("disable_mmap", False),
+            )
         else:
-            weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map)
+            if is_layers_are_grouped:
+                weights_iterator = safetensors_weights_iterator(files_list)
+            else:
+                weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map)
     else:
         weights_iterator = pdparams_weight_iterator(files_list)
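The signature stays backward compatible: callers that omit `load_config` keep the original single-thread paths. A quick sketch (`model_path` and `fd_config` assumed to come from the caller):

    it_old = get_weight_iterator(model_path)                         # original behavior
    it_new = get_weight_iterator(model_path, fd_config.load_config)  # may use the thread pool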
@@ -401,6 +415,52 @@ def safetensors_weights_iterator(safe_tensor_list: list[str]):
             yield name, param
 
+
+def multi_thread_safetensors_weights_iterator(safe_tensor_list, max_workers: int = 4, disable_mmap: bool = False):
+    """
+    Iterate over safetensors weights using multi-threaded loading.
+
+    Args:
+        safe_tensor_list: List of safetensors file paths to load.
+        max_workers: Maximum number of threads for concurrent loading. Defaults to 4.
+        disable_mmap: If True, load files into memory directly instead of using memory-mapped
+            files. Useful when mmap is not supported or causes issues.
+
+    Yields:
+        Tuple[str, paddle.Tensor]: Weight name and corresponding tensor.
+    """
+    try:
+        enable_tqdm = dist.get_rank() == 0
+    except Exception:
+        enable_tqdm = True
+
+    def _load_file(st_file: str):
+        if disable_mmap:
+            with open(st_file, "rb") as f:
+                result = safetensors.paddle.load(f.read())
+        else:
+            result = safetensors.paddle.load_file(st_file, device="cpu")
+        return result
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [executor.submit(_load_file, st_file) for st_file in safe_tensor_list]
+        if enable_tqdm:
+            futures_iter = tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(safe_tensor_list),
+                desc="Multi-thread loading shards",
+                disable=not enable_tqdm,
+            )
+        else:
+            futures_iter = concurrent.futures.as_completed(futures)
+
+        for future in futures_iter:
+            state_dict = future.result()
+            for name, param in state_dict.items():
+                yield name, param
+
 
 def safetensors_weights_iterator_ordered(ordered_weight_map: dict[str, str]):
     """
     safetensors_weights_iterator_ordered
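One property of the new iterator worth flagging in review: shards are yielded in completion order, not submission order, so consumers must not rely on file order; each worker also materializes a whole shard state dict, so peak host memory may grow with `max_workers` (more so with `disable_mmap=True`, which reads entire files into memory). A self-contained, stdlib-only sketch of the same pattern (`load` is a hypothetical stand-in for the safetensors readers):

    import concurrent.futures

    def parallel_weight_iter(paths, load, max_workers=4):
        """Yield (name, tensor) pairs from load(path) dicts as each file finishes."""
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
            futures = [ex.submit(load, p) for p in paths]
            for fut in concurrent.futures.as_completed(futures):
                yield from fut.result().items()  # completion order, not path order

    # Tiny demo with fake "shards":
    shards = {"a.safetensors": {"w1": 1}, "b.safetensors": {"w2": 2}}
    print(dict(parallel_weight_iter(shards, shards.get, max_workers=2)))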
@@ -57,7 +57,7 @@ class DefaultModelLoaderV1(BaseModelLoader):
     @measure_time()
     def load_weights(self, model, fd_config: FDConfig, enable_cache: bool = False) -> None:
         model_path = get_model_path(fd_config)
-        weights_iterator = get_weight_iterator(model_path)
+        weights_iterator = get_weight_iterator(model_path, fd_config.load_config)
         if enable_cache:
             load_weights_from_cache(model, weights_iterator)
         else:
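Given the `fix ut` note in the commit message, a plausible test shape for the new iterator (all names below are hypothetical; `multi_thread_safetensors_weights_iterator` would be imported from the loader module):

    import paddle
    import safetensors.paddle

    def test_multithread_weights_iterator(tmp_path):
        # Write two tiny shards, then load them back through the threaded iterator.
        for i in range(2):
            safetensors.paddle.save_file(
                {f"w{i}": paddle.zeros([2])}, str(tmp_path / f"shard{i}.safetensors")
            )
        files = sorted(str(p) for p in tmp_path.glob("*.safetensors"))
        names = {name for name, _ in multi_thread_safetensors_weights_iterator(files, max_workers=2)}
        assert names == {"w0", "w1"}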