mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] ep+tp all2all (#4836)
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
@@ -267,8 +268,14 @@ def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfi
|
||||
filtered_map[k] = weight_list[k]
|
||||
|
||||
if fd_config.parallel_config.tensor_parallel_size > 1:
|
||||
no_tp_action_keys = copy.deepcopy(num_local_ffn_keys)
|
||||
if fd_config.parallel_config.ep_tp_strategy == "all_to_all":
|
||||
for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
|
||||
k = f"ernie.layers.{i}.self_attn.o_proj.weight"
|
||||
if k in weight_list:
|
||||
no_tp_action_keys.append(k)
|
||||
tp_actions = cls._get_tensor_parallel_mappings(fd_config.model_config.pretrained_config)
|
||||
new_actions = {k: v for k, v in tp_actions.items() if k not in num_local_ffn_keys}
|
||||
new_actions = {k: v for k, v in tp_actions.items() if k not in no_tp_action_keys}
|
||||
|
||||
state_dict = {}
|
||||
# Get all safetensor file paths that need to be opened
|
||||
|
||||
Reference in New Issue
Block a user