""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import concurrent.futures import contextlib import copy import hashlib import inspect import json import os import pickle import re import time from contextlib import ExitStack from functools import wraps from pathlib import Path from typing import Optional import paddle import paddle.distributed as dist import safetensors from paddleformers.transformers import PretrainedModel from paddleformers.transformers.model_utils import load_tp_checkpoint from paddleformers.utils.log import logger from paddleformers.utils.safetensors import fast_safe_open from safetensors import safe_open from tqdm import tqdm from fastdeploy import envs from fastdeploy.config import FDConfig, LoadConfig from fastdeploy.model_executor.layers.linear import KVBatchLinear from fastdeploy.model_executor.utils import multi_switch_config_context DEFAULT_NUM_THREADS = 8 def natural_key(s: str): return [int(t) if t.isdigit() else t for t in re.split(r"(\d+)", s)] def layers_are_grouped(keys): seen = set() current_layer = None for k in keys: m = re.search(r"layers\.(\d+)", k) if not m: continue layer = int(m.group(1)) if layer != current_layer: if layer in seen: return False seen.add(layer) current_layer = layer return True def pdparams_weight_iterator(paddle_file_list: list[str]): for pdparams_file in tqdm( paddle_file_list, desc="Loading pdparams checkpoint shards", ): state_dict = paddle.load(pdparams_file) yield from state_dict.items() del state_dict def load_weights_from_cache(model, weights_iterator): params_dict = dict(model.named_parameters()) for loaded_weight_name, loaded_weight in weights_iterator: if loaded_weight_name not in params_dict: logger.info(f"{loaded_weight_name} is not in model parameters.") continue param = params_dict[loaded_weight_name] if param.shape != loaded_weight.shape: raise ValueError( f"Shape mismatch between loaded weight {loaded_weight_name}: {loaded_weight.shape}, expected shape: {param.shape}" ) param.copy_(loaded_weight, False) if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False): model.lm_head.linear.weight.set_value( loaded_weight.transpose([1, 0]).astype(model.lm_head.linear.weight.dtype) ) for _, model_sublayer in model.named_sublayers(): if isinstance(model_sublayer, KVBatchLinear): model_sublayer.process_weights_after_loading() def get_model_path(fd_config: FDConfig): model_path = fd_config.model_config.model rank_dirs = [ f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f)) ] if len(rank_dirs) > 1: local_rank = fd_config.parallel_config.tensor_parallel_rank if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs): raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}") model_path = os.path.join(model_path, f"rank{local_rank}") fd_config.load_config.is_pre_sharded = True return model_path def get_weight_iterator(model_path: str, load_config: Optional[LoadConfig] 
def get_weight_iterator(model_path: str, load_config: Optional[LoadConfig] = None):
    """Yield (name, weight) pairs from the checkpoint under model_path, picking the iterator by file format and load_config."""
    files_list, ordered_weight_map, use_safetensors, is_layers_are_grouped = get_all_weights_file(model_path)
    if use_safetensors:
        extra_config = load_config.model_loader_extra_config if load_config else None
        if extra_config is not None and extra_config.get("enable_multithread_load", False):
            weights_iterator = multi_thread_safetensors_weights_iterator(
                files_list,
                max_workers=extra_config.get("num_threads", DEFAULT_NUM_THREADS),
                disable_mmap=extra_config.get("disable_mmap", False),
            )
        else:
            if is_layers_are_grouped:
                weights_iterator = safetensors_weights_iterator(files_list)
            else:
                weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map)
    else:
        weights_iterator = pdparams_weight_iterator(files_list)
    yield from weights_iterator

    kv_cache_scale_json_path = Path(model_path) / "kv_cache_scale.json"
    if kv_cache_scale_json_path.exists():
        yield from kv_cache_scale_iterator(str(kv_cache_scale_json_path))


def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
    weight_cache_context = contextlib.nullcontext()
    weight_cache_dir = None
    enable_cache = False
    if envs.FD_ENABLE_MODEL_LOAD_CACHE and fd_config.quant_config is not None:
        model_weight_cache_path = os.path.join(fd_config.model_config.model, weight_cache_path)
        # Cache key: model_type + quantization + tp_size + ep_size
        weight_cache_key = "_".join(
            [
                fd_config.model_config.model_type,
                fd_config.quant_config.name(),
                str(fd_config.parallel_config.tensor_parallel_size),
                str(fd_config.parallel_config.expert_parallel_size),
            ]
        )
        # only support tp now
        hash_key = hashlib.md5(pickle.dumps(weight_cache_key)).hexdigest()
        weight_cache_dir = os.path.join(model_weight_cache_path, hash_key)
        if os.path.exists(weight_cache_dir):
            logger.info(
                f"Loading will prioritize cached models. Users are responsible for ensuring the saved model is correct. If any error occurs, deleting the cache at {weight_cache_dir} may resolve it."
            )
            enable_cache = True
            weight_cache_context = multi_switch_config_context(
                (fd_config.quant_config, "is_checkpoint_bf16", False),
            )
    return enable_cache, weight_cache_dir, weight_cache_context
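# A sketch of the on-disk cache layout produced by the key scheme above (the
# concrete key value and model type below are hypothetical):
#
#   weight_cache_key = "ernie4_5_moe_wint4_8_1"   # model_type_quant_tp_ep
#   hash_key = hashlib.md5(pickle.dumps(weight_cache_key)).hexdigest()
#   # -> <model_dir>/.cache/<hash_key>/rank0/cache.pdparams
#   #    <model_dir>/.cache/<hash_key>/rank1/cache.pdparams
#   #    ...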
def save_model(model_arg_name="model", config_arg_name="fd_config"):
    @measure_time("Model saving")
    def _save_model(model_dict, weight_cache_dir):
        # Note: ProcessGroupNCCL does not support the deepcopy protocol, so we patch it here.
        paddle.distributed.communication.group.Group.__deepcopy__ = lambda self, _: self
        paddle.distributed.communication.group.Group.to_json = lambda self: repr(self)
        paddle.save(model_dict, weight_cache_dir)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            sig = inspect.signature(func)
            bound_args = sig.bind(*args, **kwargs)
            bound_args.apply_defaults()
            fd_config = bound_args.arguments.get(config_arg_name, None)
            model = bound_args.arguments.get(model_arg_name, None)
            assert fd_config is not None, "fd_config cannot be None"
            assert model is not None, "model cannot be None"
            enable_cache, weight_cache_dir, _ = is_weight_cache_enabled(fd_config)
            if enable_cache:
                tp_weight_cache_dir = os.path.join(
                    weight_cache_dir, f"rank{str(fd_config.parallel_config.tensor_parallel_rank)}"
                )
                context = multi_switch_config_context((fd_config.model_config, "model", tp_weight_cache_dir))
            else:
                context = contextlib.nullcontext()
            with context:
                result = func(*args, **kwargs)
            if envs.FD_ENABLE_MODEL_LOAD_CACHE:
                if not (
                    fd_config.quant_config is not None
                    and getattr(fd_config.quant_config, "is_checkpoint_bf16", False)
                ):
                    # Save the cache only for dynamic quantization
                    return result
                if weight_cache_dir is None:
                    return result
                tp_weight_cache_dir = os.path.join(
                    weight_cache_dir, f"rank{str(fd_config.parallel_config.tensor_parallel_rank)}"
                )
                if not os.path.exists(tp_weight_cache_dir):
                    logger.info(f"Saving model to {tp_weight_cache_dir}")
                    os.makedirs(
                        tp_weight_cache_dir,
                        exist_ok=True,
                    )
                    _save_model(model.state_dict(), os.path.join(tp_weight_cache_dir, "cache.pdparams"))
            else:
                reason = "weights already cached" if envs.FD_ENABLE_MODEL_LOAD_CACHE else "cache disabled"
                logger.info(f"Skip saving, {reason}")
            return result

        return wrapper

    return decorator


def measure_time(prefix: str = "Model loading"):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            time_before = time.time()
            result = func(*args, **kwargs)
            time_after = time.time()
            logger.info(f"{prefix} took {time_after - time_before:.3f} seconds")
            return result

        return wrapper

    return decorator


def load_reordered_experts(model_path: str, key_name: str):
    with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
        weight_list = json.load(f)["weight_map"]
    safetensor_path = os.path.join(model_path, weight_list[key_name])
    with safe_open(safetensor_path, framework="np", device="cpu") as f:
        if key_name in f.keys():
            weight = f.get_tensor(key_name)
            weight = paddle.Tensor(weight, zero_copy=True)
            weight = weight._copy_to(paddle.framework._current_expected_place(), False)
    return weight
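# For reference, a minimal sketch of the standard "model.safetensors.index.json"
# layout that load_reordered_experts and load_ep_checkpoint below rely on
# (shard file names and keys are hypothetical):
#
#   {
#     "metadata": {"total_size": 123456789},
#     "weight_map": {
#       "ernie.layers.0.mlp.experts.0.up_gate_proj.weight": "model-00001-of-00004.safetensors",
#       "ernie.layers.0.mlp.experts.0.down_proj.weight": "model-00002-of-00004.safetensors"
#     }
#   }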
def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfig, return_numpy: bool = False):
    """
    Load an expert-parallel (EP) checkpoint.
    """
    with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
        weight_list = json.load(f)["weight_map"]
    filtered_map = {k: v for k, v in weight_list.items() if ".experts." not in k}
    num_local_ffn_keys = []

    from itertools import chain

    def get_expert_ranges(fd_config):
        """
        Generate expert index ranges based on configuration parameters.

        This function is primarily used in Mixture-of-Experts (MoE) models to
        generate expert index ranges according to configuration parameters.
        When moe_num_experts is a list in the fd_config, it returns a chained
        combination of two ranges; otherwise it returns a single range.

        Args:
            fd_config: FastDeploy configuration object

        Returns:
            If moe_num_experts is a list: a chained combination (chain object) of two ranges:
                1. Base range: [num_experts_start_offset, num_experts_start_offset + num_experts_per_rank)
                2. Offset range: [base_range.start + moe_num_experts[0], base_range.stop + moe_num_experts[0])
            Else: the single base range:
                [num_experts_start_offset, num_experts_start_offset + num_experts_per_rank)
        """
        base_range = range(
            fd_config.parallel_config.num_experts_start_offset,
            fd_config.parallel_config.num_experts_start_offset + fd_config.parallel_config.num_experts_per_rank,
        )
        if isinstance(fd_config.model_config.moe_num_experts, list):
            return chain(
                base_range,
                range(
                    base_range.start + fd_config.model_config.moe_num_experts[0],
                    base_range.stop + fd_config.model_config.moe_num_experts[0],
                ),
            )
        return base_range

    prefix_layer_name = (
        "mtp_block" if getattr(fd_config.speculative_config, "model_type", "main") == "mtp" else "layers"
    )
    moe_num_experts = fd_config.model_config.moe_num_experts
    if isinstance(moe_num_experts, list):
        moe_num_experts = moe_num_experts[0]
    for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
        for j in get_expert_ranges(fd_config):
            # Map redundant expert IDs back to actual expert IDs for weight loading
            j = j % moe_num_experts
            up_gate_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.weight"
            down_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.weight"
            up_gate_proj_quant_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
            down_proj_quant_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.quant_weight"
            up_gate_proj_scale_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
            down_proj_scale_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.weight_scale"
            down_proj_in_scale_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.activation_scale"
            # A single up_gate_proj.activation_scale is shared by all mlp.experts
            up_gate_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.up_gate_proj.activation_scale"
            num_local_ffn_keys.append(up_gate_proj_key)
            num_local_ffn_keys.append(down_proj_key)
            num_local_ffn_keys.append(up_gate_proj_quant_key)
            num_local_ffn_keys.append(down_proj_quant_key)
            num_local_ffn_keys.append(up_gate_proj_scale_key)
            num_local_ffn_keys.append(down_proj_scale_key)
            num_local_ffn_keys.append(down_proj_in_scale_key)
            num_local_ffn_keys.append(up_gate_proj_in_scale_key)

        # For EP w4a8, we need every expert's activation_scale for up_gate_proj
        num_experts = fd_config.model_config.moe_num_experts
        if isinstance(num_experts, list):
            num_experts = num_experts[0]
        for j in range(num_experts):
            up_gate_proj_in_scale_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.activation_scale"
            num_local_ffn_keys.append(up_gate_proj_in_scale_key)

    for k in num_local_ffn_keys:
        if k in weight_list:
            filtered_map[k] = weight_list[k]

    if fd_config.parallel_config.tensor_parallel_size > 1:
        no_tp_action_keys = copy.deepcopy(num_local_ffn_keys)
        if fd_config.parallel_config.use_sequence_parallel_moe:
            for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
                no_tp_keys = [
                    f"ernie.{prefix_layer_name}.{i}.self_attn.o_proj.weight",
                    f"ernie.{prefix_layer_name}.{i}.self_attn.o_proj.bias",
                ]
                for k in no_tp_keys:
                    if k in weight_list:
                        no_tp_action_keys.append(k)
        tp_actions = cls._get_tensor_parallel_mappings(fd_config.model_config.pretrained_config)
        new_actions = {k: v for k, v in tp_actions.items() if k not in no_tp_action_keys}

    state_dict = {}
    # Get all safetensor file paths that need to be opened
    safetensor_paths = set(filtered_map.values())

    # Open each safetensor file sequentially with a progress bar
    for safetensor_path in tqdm(safetensor_paths, desc="Loading safetensor files", unit="file"):
        with safe_open(
            os.path.join(model_path, safetensor_path),
            framework="np",
            device="cpu",
        ) as f:
            # Check if this file contains keys from filtered_map
            for k in filtered_map:
                if filtered_map[k] == safetensor_path and k in f.keys():
                    weight = f.get_tensor(k)
                    if fd_config.parallel_config.tensor_parallel_size > 1:
                        if k in new_actions:
                            weight = new_actions[k](weight)
                    if not return_numpy:
                        weight = paddle.Tensor(weight, zero_copy=True)
                        weight = weight._copy_to(paddle.framework._current_expected_place(), False)
                    state_dict[k] = weight
    return state_dict
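# A worked example of get_expert_ranges, assuming hypothetical config values
# num_experts_start_offset=8, num_experts_per_rank=4, and
# moe_num_experts=[64, 2] (a list, i.e. two expert pools):
#
#   base range   -> range(8, 12)             # experts 8..11
#   offset range -> range(8 + 64, 12 + 64)   # experts 72..75
#   chained      -> [8, 9, 10, 11, 72, 73, 74, 75]
#
# With a scalar moe_num_experts, only the base range(8, 12) is returned.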
def kv_cache_scale_iterator(kv_cache_scale_json_path):
    """
    Yield (key, scale) pairs from kv_cache_scale.json, multiplied by 448.0 to
    map the scales onto the FP8 (float8_e4m3fn) representable range.
    """
    with open(kv_cache_scale_json_path, "r") as f:
        data = json.load(f)
    for key, value in data.items():
        scale_tensor = paddle.to_tensor(value, dtype=paddle.get_default_dtype()) * 448.0
        yield key, scale_tensor


def safetensors_weights_iterator(safe_tensor_list: list[str]):
    """
    Iterate over all weights in a list of safetensors files, one file at a time.
    """
    for st_file in tqdm(
        safe_tensor_list,
        desc="Loading safetensors checkpoint shards",
    ):
        with safe_open(st_file, framework="paddle", device="cpu") as f:
            for name in f.keys():
                param = f.get_tensor(name)
                yield name, param
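# A minimal usage sketch (the paths and the model_params dict are
# hypothetical): every iterator in this module yields (name, tensor) pairs,
# so callers can stream weights into a model without materializing the whole
# checkpoint in memory.
#
#   for name, tensor in safetensors_weights_iterator(
#       ["/path/to/model-00001-of-00002.safetensors",
#        "/path/to/model-00002-of-00002.safetensors"]
#   ):
#       model_params[name].copy_(tensor, False)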
""" try: enable_tqdm = dist.get_rank() == 0 except Exception: enable_tqdm = True def _load_file(st_file: str): if disable_mmap: with open(st_file, "rb") as f: result = safetensors.paddle.load(f.read()) else: result = safetensors.paddle.load_file(st_file, device="cpu") return result with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [executor.submit(_load_file, st_file) for st_file in safe_tensor_list] if enable_tqdm: futures_iter = tqdm( concurrent.futures.as_completed(futures), total=len(safe_tensor_list), desc="Multi-thread loading shards", disable=not enable_tqdm, ) else: futures_iter = concurrent.futures.as_completed(futures) for future in futures_iter: state_dict = future.result() for name, param in state_dict.items(): yield name, param def safetensors_weights_iterator_ordered(ordered_weight_map: dict[str, str]): """ safetensors_weights_iterator_ordered """ with ExitStack() as stack: current_file = None current_handle = None for key, st_file in tqdm( ordered_weight_map.items(), desc="Loading safetensors weights", ): if st_file != current_file: stack.close() current_handle = stack.enter_context(safe_open(st_file, framework="paddle", device="cpu")) current_file = st_file yield key, current_handle.get_tensor(key) def fast_weights_iterator(safe_tensor_list: list[str]): """ paddleformers' iterator for safetensors """ for st_file in tqdm( safe_tensor_list, desc="Loading safetensors checkpoint shards", ): with fast_safe_open(st_file, framework="np") as f: for name in f.keys(): param_slice = f.get_slice(name) yield name, param_slice def load_pre_sharded_checkpoint(model_path: str, local_rank: int): """ load_pre_sharded_checkpoint """ state_dict = {} weights_iterator = get_weight_iterator(os.path.join(model_path, f"rank{local_rank}")) for name, weight in weights_iterator: state_dict[name] = weight.clone() return state_dict def get_all_weights_file(model_path: str): """ get_all_safetensors """ model_path = Path(model_path) use_safetensors = True files_list = [str(file) for file in model_path.glob("*.pdparams") if file.name != "scheduler.pdparams"] if len(files_list) > 0: ordered_weight_map = {} use_safetensors = False # dont care about the order of the files return files_list, {}, use_safetensors, False else: safe_model_path = model_path / "model.safetensors" if safe_model_path.exists(): with safe_open(safe_model_path, framework="np", device="cpu") as f: key_name_list = sorted(f.keys(), key=natural_key) ordered_weight_map = {key: "model.safetensors" for key in key_name_list} is_layers_are_grouped = True files_list = [str(safe_model_path)] return files_list, ordered_weight_map, use_safetensors, is_layers_are_grouped else: index_file = model_path / "model.safetensors.index.json" with index_file.open("r") as f: weight_map = json.load(f)["weight_map"] keys = list(weight_map.keys()) is_layers_are_grouped = layers_are_grouped(keys) ordered_weight_map = { key: str(model_path / weight_map[key]) for key in sorted(weight_map.keys(), key=natural_key) } weight_files_in_index = {str(model_path / weight_map[name]) for name in weight_map} files_list = sorted(weight_files_in_index) return files_list, ordered_weight_map, use_safetensors, is_layers_are_grouped def deal_state_dict(state_dict): """deal_state_dict""" device = paddle.CUDAPinnedPlace() for name, src in state_dict.items(): if src._is_initialized() and not isinstance(src.place, paddle.CUDAPinnedPlace): dst = src._copy_to(device, True) dst_tensor = dst.value().get_tensor() src_tensor = src.value().get_tensor() 
def deal_state_dict(state_dict):
    """Move initialized weights to CUDA pinned memory and share the storage back into the source tensors."""
    device = paddle.CUDAPinnedPlace()
    for name, src in state_dict.items():
        if src._is_initialized() and not isinstance(src.place, paddle.CUDAPinnedPlace):
            dst = src._copy_to(device, True)
            dst_tensor = dst.value().get_tensor()
            src_tensor = src.value().get_tensor()
            src_tensor._clear()
            src_tensor._share_data_with(dst_tensor)


def load_kv_cache_scale(fd_config, state_dict):
    file_path = fd_config.model_config.kv_cache_quant_scale_path
    prefix_layer_name = fd_config.model_config.prefix_layer_name
    if os.path.exists(file_path):
        with open(file_path, "r") as f:
            data = json.load(f)
        for i in range(fd_config.model_config.num_hidden_layers):
            k_scale_name = f"ernie.{prefix_layer_name}.{i}.self_attn.cachek_matmul.activation_scale"
            v_scale_name = f"ernie.{prefix_layer_name}.{i}.self_attn.cachev_matmul.activation_scale"
            k_scale = data[k_scale_name]
            k_scale_tensor = paddle.to_tensor(k_scale, dtype=paddle.get_default_dtype())
            state_dict[k_scale_name] = k_scale_tensor * 448.0
            v_scale = data[v_scale_name]
            v_scale_tensor = paddle.to_tensor(v_scale, dtype=paddle.get_default_dtype())
            state_dict[v_scale_name] = v_scale_tensor * 448.0
            logger.info(f"Loaded kv cache scales for layer {i}.")
    else:
        logger.warning(f"No kv_cache_scale.json found at {file_path}, skipping...")


def load_composite_checkpoint(
    model_path: str,
    cls: PretrainedModel,
    fd_config: FDConfig,
    return_numpy=True,
):
    """
    Load model weights under one of three parallelism strategies:
    1. Expert Parallel (EP)
    2. Tensor Parallel (TP)
    3. Pre-sharded (pre-split per rank)
    """
    if fd_config.parallel_config.use_ep:
        state_dict = load_ep_checkpoint(cls, model_path, fd_config, return_numpy=True)
    else:
        rank_dirs = [
            f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
        ]
        if len(rank_dirs) > 1:
            if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
                raise ValueError(f"This pre-sharded model only supports loading with tp{len(rank_dirs)}")
            state_dict = load_pre_sharded_checkpoint(
                model_path,
                fd_config.parallel_config.tensor_parallel_rank,
            )
        else:
            fd_config.model_config.pretrained_config.use_sequence_parallel_moe = (
                fd_config.parallel_config.use_sequence_parallel_moe
            )
            # NOTE: for a very big model, the CPU may run out of memory here
            state_dict = load_tp_checkpoint(
                model_path,
                cls,
                fd_config.model_config.pretrained_config,
                return_numpy=return_numpy,
            )
        if not state_dict:
            raise ValueError("weight not found in state_dict !")
    if hasattr(fd_config.quant_config, "kv_cache_quant_type"):
        kv_cache_quant_type = fd_config.quant_config.kv_cache_quant_type
        if kv_cache_quant_type == "float8_e4m3fn":
            load_kv_cache_scale(fd_config, state_dict)
    return state_dict
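# A minimal end-to-end usage sketch of this module's entry points (the model
# class name and the streaming consumer are hypothetical):
#
#   model_path = get_model_path(fd_config)
#   if fd_config.parallel_config.use_ep or fd_config.load_config.is_pre_sharded:
#       state_dict = load_composite_checkpoint(model_path, MyPretrainedModel, fd_config)
#       model.set_state_dict(state_dict)
#   else:
#       for name, weight in get_weight_iterator(model_path, fd_config.load_config):
#           ...  # stream each (name, weight) pair into the model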