[loader] support wint2 backend (#6139)

* support wint2

* update
This commit is contained in:
bukejiyu
2026-02-09 14:42:36 +08:00
committed by GitHub
parent f18f3b99ed
commit dc5917289d
20 changed files with 86 additions and 11 deletions
+15 -3
View File
@@ -76,6 +76,20 @@ def load_weights_from_cache(model, weights_iterator):
model_sublayer.process_weights_after_loading()
def get_model_path(fd_config: FDConfig):
    """Resolve the on-disk weight directory for the current process.

    When the model directory contains multiple ``rank*`` sub-directories the
    checkpoint is pre-sharded per tensor-parallel rank: the configured TP
    degree must match the number of shard directories, the returned path is
    narrowed to this rank's shard, and ``load_config.is_pre_sharded`` is set.

    Args:
        fd_config: Deployment config providing the model path, the tensor
            parallel rank/size, and the load config to flag.

    Returns:
        The directory to load weights from (either the model root or the
        per-rank shard directory).

    Raises:
        ValueError: If the shard count does not match the TP degree.
    """
    base_dir = fd_config.model_config.model
    shard_dirs = []
    for entry in os.listdir(base_dir):
        if entry.startswith("rank") and os.path.isdir(os.path.join(base_dir, entry)):
            shard_dirs.append(entry)
    num_shards = len(shard_dirs)
    # A single (or no) rank directory means the checkpoint is not pre-sharded.
    if num_shards <= 1:
        return base_dir
    if fd_config.parallel_config.tensor_parallel_size != num_shards:
        raise ValueError(f"Your model only supports loading with tp{num_shards}")
    fd_config.load_config.is_pre_sharded = True
    local_rank = fd_config.parallel_config.tensor_parallel_rank
    return os.path.join(base_dir, f"rank{local_rank}")
def get_weight_iterator(model_path: str):
files_list, ordered_weight_map, use_safetensors, is_key_ordered = get_all_weights_file(model_path)
if use_safetensors:
@@ -404,10 +418,8 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int):
"""
load_pre_sharded_checkpoint
"""
state_dict = {}
safetensor_files, _, _, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}"))
weights_iterator = safetensors_weights_iterator(safetensor_files)
weights_iterator = get_weight_iterator(os.path.join(model_path, f"rank{local_rank}"))
for name, weight in weights_iterator:
state_dict[name] = weight.clone()
return state_dict