[Loader] add multi-thread model loading (#6877)

* multi-thread-loader

* fix ut
This commit is contained in:
bukejiyu
2026-04-10 14:40:15 +08:00
committed by GitHub
parent c1fb3112f8
commit 14d46181b8
12 changed files with 105 additions and 7 deletions
+1
View File
@@ -58,6 +58,7 @@ When using FastDeploy to deploy models (including offline inference and service
| ```tool_call_parser``` | `str` | Specify the function call parser to be used for extracting function call content from the model's output. |
| ```tool_parser_plugin``` | `str` | Specify the file path of the tool parser to be registered, so as to register parsers that are not in the code repository. The code format within these parsers must adhere to the format used in the code repository. |
| ```load_choices``` | `str` | Weight loader selection, default: "default_v1". Supports "default", "default_v1", and "dummy". "default_v1" is used for loading torch weights and weight acceleration. "dummy" is used for quickly and randomly initializing weights for testing|
| ```model_loader_extra_config``` | `dict[str]` | Additional configuration options for the model loader. Supports: <br> - `enable_multithread_load` (bool): Enable multi-threaded weight loading. <br> - `num_threads` (int): Number of threads for loading. Defaults to 8. <br> - `disable_mmap` (bool): Disable memory-mapped file access. Useful when mmap is not supported. <br> Example: `'{"enable_multithread_load": true, "num_threads": 8}'` |
| ```max_encoder_cache``` | `int` | Maximum number of tokens in the encoder cache (use 0 to disable), default: -1 (auto-calculated) |
| ```max_processor_cache``` | `float` | Maximum number of bytes(in GiB) in the processor cache (use 0 to disable), default: -1 (auto-calculated) |
| ```api_key``` |`list[str]`| Validate API keys in the service request headers, supporting multiple key inputs. Same effect as environment variable `FD_API_KEY`, with higher priority|
+1
View File
@@ -56,6 +56,7 @@
| ```tool_call_parser``` | `str` | 指定要使用的function call解析器,以便从模型输出中抽取 function call内容|
| ```tool_parser_plugin``` | `str` | 指定要注册的tool parser文件路径,以便注册不在代码库中的parser,parser中代码格式需遵循代码库中格式|
| ```load_choices``` | `str` | 权重加载器选择,默认使用"default_v1"。支持"default"、"default_v1"和"dummy"。"default_v1"用于加载torch权重和权重加速;"dummy"用于快速随机初始化权重以进行测试|
| ```model_loader_extra_config``` | `dict[str]` | 模型加载器额外配置选项。支持:<br> - `enable_multithread_load` (bool): 启用多线程权重加载。<br> - `num_threads` (int): 加载线程数,默认为8。<br> - `disable_mmap` (bool): 禁用内存映射文件访问,当mmap不支持时使用。<br> 示例:`'{"enable_multithread_load": true, "num_threads": 8}'` |
| ```max_encoder_cache``` | `int` | 编码器缓存的最大token数(使用0表示禁用),默认-1(自动计算)|
| ```max_processor_cache``` | `float` | 处理器缓存的最大字节数(以GiB为单位,使用0表示禁用),默认-1(自动计算)|
| ```api_key``` |`list[str]`| 校验服务请求头中的API密钥,支持传入多个密钥;与环境变量`FD_API_KEY`中的值效果相同,且优先级高于环境变量配置|
+1
View File
@@ -1447,6 +1447,7 @@ class LoadConfig:
self.dynamic_load_weight: bool = False
self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal", "rsync"]] = "normal"
self.rsync_config: Optional[Dict[str, Any]] = None
self.model_loader_extra_config: Optional[Dict[str, Any]] = None
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
+16
View File
@@ -496,6 +496,14 @@ class EngineArgs:
- "default": default loader.
- "default_v1": default_v1 loader.
"""
model_loader_extra_config: Optional[Dict[str, Any]] = None
"""
Additional configuration options for the model loader.
Supports:
- enable_multithread_load (bool): Enable multi-threaded weight loading.
- num_threads (int): Number of threads for loading. Defaults to 8.
- disable_mmap (bool): Disable memory-mapped file access.
"""
lm_head_fp32: bool = False
"""
@@ -1091,6 +1099,14 @@ class EngineArgs:
default/default_v1/dummy.",
)
load_group.add_argument(
"--model-loader-extra-config",
type=json.loads,
default=EngineArgs.model_loader_extra_config,
help="Additional configuration for model loader (JSON format). "
'e.g., \'{"enable_multithread_load": true, "num_threads": 8}\'',
)
# CacheConfig parameters group
cache_group = parser.add_argument_group("Cache Configuration")
+1
View File
@@ -2483,6 +2483,7 @@ class EngineService:
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
f" --load_choices {self.cfg.load_config.load_choices}"
f" --model_loader_extra_config '{json.dumps(self.cfg.load_config.model_loader_extra_config)}'"
f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'"
f" --ips {ips}"
f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}"
+1
View File
@@ -613,6 +613,7 @@ class LLMEngine:
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
f" --load_choices {self.cfg.load_config.load_choices}"
f" --model_loader_extra_config '{json.dumps(self.cfg.load_config.model_loader_extra_config)}'"
f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'"
f" --ips {ips}"
f" --max_encoder_cache {self.cfg.cache_config.max_encoder_cache}"
+62 -2
View File
@@ -14,6 +14,7 @@
# limitations under the License.
"""
import concurrent.futures
import contextlib
import copy
import hashlib
@@ -26,8 +27,11 @@ import time
from contextlib import ExitStack
from functools import wraps
from pathlib import Path
from typing import Optional
import paddle
import paddle.distributed as dist
import safetensors
from paddleformers.transformers import PretrainedModel
from paddleformers.transformers.model_utils import load_tp_checkpoint
from paddleformers.utils.log import logger
@@ -36,10 +40,12 @@ from safetensors import safe_open
from tqdm import tqdm
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.config import FDConfig, LoadConfig
from fastdeploy.model_executor.layers.linear import KVBatchLinear
from fastdeploy.model_executor.utils import multi_switch_config_context
DEFAULT_NUM_THREADS = 8
def natural_key(s: str):
    """Return a natural-sort key for *s*: runs of digits compare numerically."""
    key = []
    # re.split with a capturing group keeps the digit runs in the result.
    for token in re.split(r"(\d+)", s):
        key.append(int(token) if token.isdigit() else token)
    return key
@@ -111,9 +117,17 @@ def get_model_path(fd_config: FDConfig):
return model_path
def get_weight_iterator(model_path: str):
def get_weight_iterator(model_path: str, load_config: Optional[LoadConfig] = None):
files_list, ordered_weight_map, use_safetensors, is_layers_are_grouped = get_all_weights_file(model_path)
if use_safetensors:
extra_config = load_config.model_loader_extra_config if load_config else None
if extra_config is not None and extra_config.get("enable_multithread_load", False):
weights_iterator = multi_thread_safetensors_weights_iterator(
files_list,
max_workers=extra_config.get("num_threads", DEFAULT_NUM_THREADS),
disable_mmap=extra_config.get("disable_mmap", False),
)
else:
if is_layers_are_grouped:
weights_iterator = safetensors_weights_iterator(files_list)
else:
@@ -401,6 +415,52 @@ def safetensors_weights_iterator(safe_tensor_list: list[str]):
yield name, param
def multi_thread_safetensors_weights_iterator(safe_tensor_list, max_workers: int = 4, disable_mmap: bool = False):
    """
    Load safetensors shards concurrently and yield their weights.

    Args:
        safe_tensor_list: Paths of the safetensors files to read.
        max_workers: Upper bound on loader threads. Defaults to 4.
        disable_mmap: When True, read each file fully into memory instead of
            memory-mapping it. Useful when mmap is unsupported or problematic.

    Yields:
        Tuple[str, paddle.Tensor]: (weight name, tensor) pairs, emitted in
        shard-completion order.
    """
    # Show the progress bar only on rank 0; if the distributed context is not
    # initialized, assume we are the sole process and show it.
    try:
        show_progress = dist.get_rank() == 0
    except Exception:
        show_progress = True

    def _read_shard(path: str):
        # One shard per task; runs on a worker thread.
        if disable_mmap:
            with open(path, "rb") as fh:
                return safetensors.paddle.load(fh.read())
        return safetensors.paddle.load_file(path, device="cpu")

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(_read_shard, path) for path in safe_tensor_list]
        completed = concurrent.futures.as_completed(pending)
        if show_progress:
            completed = tqdm(
                completed,
                total=len(safe_tensor_list),
                desc="Multi-thread loading shards",
                disable=not show_progress,
            )
        for done in completed:
            for name, tensor in done.result().items():
                yield name, tensor
def safetensors_weights_iterator_ordered(ordered_weight_map: dict[str, str]):
"""
safetensors_weights_iterator_ordered
@@ -57,7 +57,7 @@ class DefaultModelLoaderV1(BaseModelLoader):
@measure_time()
def load_weights(self, model, fd_config: FDConfig, enable_cache: bool = False) -> None:
model_path = get_model_path(fd_config)
weights_iterator = get_weight_iterator(model_path)
weights_iterator = get_weight_iterator(model_path, fd_config.load_config)
if enable_cache:
load_weights_from_cache(model, weights_iterator)
else:
+8
View File
@@ -1028,6 +1028,14 @@ def parse_args():
help="The format of the model weights to load. default/default_v1/dummy.",
)
parser.add_argument(
"--model_loader_extra_config",
type=json.loads,
default=None,
help="Additional configuration for model loader (JSON format). "
'e.g., \'{"enable_multithread_load": true, "num_threads": 8}\'',
)
parser.add_argument(
"--ips",
type=str,
+7 -1
View File
@@ -45,7 +45,13 @@ def _make_cfg(**ov):
cc.enable_prefix_caching = cc.enable_chunked_prefill = False
cc.kv_cache_ratio, cc.kvcache_storage_backend, cc.num_cpu_blocks, cc.max_encoder_cache = 1.0, None, 0, 0
cc.cache_transfer_protocol, cc.total_block_num = "tcp", 100
lc = ns(load_strategy="auto", rsync_config={}, dynamic_load_weight=False, load_choices="auto")
lc = ns(
load_strategy="auto",
rsync_config={},
dynamic_load_weight=False,
load_choices="auto",
model_loader_extra_config={},
)
soc = ns(guided_decoding_backend=None, logits_processors=None, reasoning_parser="none")
soc.disable_any_whitespace = False
cfg = ns(model_config=mc, parallel_config=pc, scheduler_config=sc, cache_config=cc, load_config=lc)
+1
View File
@@ -99,5 +99,6 @@ def test_offline_model(
quantization,
"default_v1",
prompts,
{"enable_multithread_load": True, "num_threads": 2},
),
)
+2
View File
@@ -89,6 +89,7 @@ def form_model_get_output_topp0(
load_choices,
prompts,
speculative_config={},
model_loader_extra_config=None,
result_queue=None,
):
try:
@@ -100,6 +101,7 @@ def form_model_get_output_topp0(
load_choices=load_choices,
quantization=quantization,
speculative_config=speculative_config,
model_loader_extra_config=model_loader_extra_config,
) as fd_model:
fd_outputs = fd_model.generate_topp0(prompts, max_tokens=max_tokens)
result_queue.put(fd_outputs)