mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
@@ -76,6 +76,20 @@ def load_weights_from_cache(model, weights_iterator):
|
||||
model_sublayer.process_weights_after_loading()
|
||||
|
||||
|
||||
def get_model_path(fd_config: FDConfig):
    """Resolve the on-disk model directory for the current process.

    When the checkpoint directory holds multiple ``rank*`` sub-directories the
    model is pre-sharded: the configured tensor-parallel size must match the
    number of shards, ``load_config.is_pre_sharded`` is flagged, and this
    rank's own sub-directory is returned. Otherwise the configured model
    directory is returned unchanged.
    """
    base_dir = fd_config.model_config.model
    shard_dirs = []
    for entry in os.listdir(base_dir):
        # A shard directory is named "rank<N>" and must actually be a directory.
        if entry.startswith("rank") and os.path.isdir(os.path.join(base_dir, entry)):
            shard_dirs.append(entry)
    num_shards = len(shard_dirs)
    if num_shards <= 1:
        # Not pre-sharded (or a single shard): load from the base directory.
        return base_dir
    if fd_config.parallel_config.tensor_parallel_size != num_shards:
        raise ValueError(f"Your model only supports loading with tp{num_shards}")
    fd_config.load_config.is_pre_sharded = True
    return os.path.join(base_dir, f"rank{fd_config.parallel_config.tensor_parallel_rank}")
||||
|
||||
|
||||
def get_weight_iterator(model_path: str):
|
||||
files_list, ordered_weight_map, use_safetensors, is_key_ordered = get_all_weights_file(model_path)
|
||||
if use_safetensors:
|
||||
def load_pre_sharded_checkpoint(model_path: str, local_rank: int):
    """Load this rank's shard of a pre-sharded checkpoint into a state dict.

    Args:
        model_path: Base checkpoint directory containing ``rank*`` sub-directories.
        local_rank: Tensor-parallel rank whose ``rank{local_rank}`` shard to load.

    Returns:
        dict mapping weight names to cloned tensors for this rank's shard.
    """
    # Fix: the previous version first built an iterator via
    # get_all_weights_file + safetensors_weights_iterator and then immediately
    # rebound `weights_iterator` to get_weight_iterator(...), leaving the first
    # computation as dead code. Only the get_weight_iterator path is kept.
    state_dict = {}
    rank_dir = os.path.join(model_path, f"rank{local_rank}")
    weights_iterator = get_weight_iterator(rank_dir)
    for name, weight in weights_iterator:
        # Clone so the returned tensors do not alias the loader's buffers.
        state_dict[name] = weight.clone()
    return state_dict
||||
|
||||
Reference in New Issue
Block a user