mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 09:44:10 +08:00
[Feature] Support redundant expert for eplb (#5918)
* [BugFix] support redundant expert for eplb * support redundant expert for eplb * support redundant expert for eplb * update * fix ci eplb
This commit is contained in:
@@ -467,13 +467,18 @@ class FusedMoE(nn.Layer):
|
||||
"""
|
||||
logical_expert_ids = [
|
||||
i
|
||||
% (
|
||||
self.fd_config.model_config.moe_num_experts[0]
|
||||
if isinstance(self.fd_config.model_config.moe_num_experts, list)
|
||||
else self.fd_config.model_config.moe_num_experts
|
||||
)
|
||||
for i in range(
|
||||
self.expert_id_offset,
|
||||
self.expert_id_offset + self.num_local_experts,
|
||||
)
|
||||
]
|
||||
ep_rank_to_expert_id_list = [i for i in range(self.num_experts)]
|
||||
if self.redundant_table_manger is not None and is_rearrange is True:
|
||||
if self.redundant_table_manger is not None:
|
||||
(
|
||||
ep_rank_to_expert_id_list,
|
||||
expert_id_to_ep_rank_array,
|
||||
@@ -487,10 +492,7 @@ class FusedMoE(nn.Layer):
|
||||
down_proj_weights = []
|
||||
if isinstance(state_dict, list):
|
||||
state_dict = dict(state_dict)
|
||||
is_ffn_merged = (
|
||||
up_gate_proj_expert_weight_key.format(logical_expert_ids[0] if is_rearrange else self.expert_id_offset)
|
||||
in state_dict
|
||||
)
|
||||
is_ffn_merged = up_gate_proj_expert_weight_key.format(logical_expert_ids[0]) in state_dict
|
||||
if is_ffn_merged:
|
||||
for expert_idx in logical_expert_ids:
|
||||
down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
|
||||
@@ -498,7 +500,7 @@ class FusedMoE(nn.Layer):
|
||||
up_gate_proj_weights.append(
|
||||
get_tensor(
|
||||
(
|
||||
state_dict.pop(up_gate_proj_expert_weight_key_name)
|
||||
state_dict[up_gate_proj_expert_weight_key_name]
|
||||
if up_gate_proj_expert_weight_key_name in state_dict
|
||||
else up_gate_proj_expert_weight_key_name
|
||||
),
|
||||
@@ -508,7 +510,7 @@ class FusedMoE(nn.Layer):
|
||||
down_proj_weights.append(
|
||||
get_tensor(
|
||||
(
|
||||
state_dict.pop(down_proj_expert_weight_key_name)
|
||||
state_dict[down_proj_expert_weight_key_name]
|
||||
if down_proj_expert_weight_key_name in state_dict
|
||||
else down_proj_expert_weight_key_name
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user