Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2026-04-23 00:17:25 +08:00)
[Feature] Support redundant expert for eplb (#5918)
* [BugFix] support redundant expert for eplb
* support redundant expert for eplb
* support redundant expert for eplb
* update
* fix ci eplb
This commit is contained in:
@@ -927,6 +927,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
parallel_config = ParallelConfig(vars(args))
|
||||
cache_config = CacheConfig(vars(args))
|
||||
scheduler_config = SchedulerConfig(vars(args))
|
||||
eplb_config = EPLBConfig(args.eplb_config)
|
||||
|
||||
parallel_config.tensor_parallel_rank = local_rank % parallel_config.tensor_parallel_size
|
||||
parallel_config.data_parallel_rank = local_rank // parallel_config.tensor_parallel_size
|
||||
@@ -940,9 +941,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
if parallel_config.expert_parallel_size > 1:
|
||||
expert_parallel_rank = int(local_rank % parallel_config.expert_parallel_size)
|
||||
if isinstance(model_config.moe_num_experts, list):
|
||||
num_experts = model_config.moe_num_experts[0]
|
||||
num_experts = model_config.moe_num_experts[0] + eplb_config.redundant_experts_num
|
||||
else:
|
||||
num_experts = model_config.moe_num_experts
|
||||
num_experts = model_config.moe_num_experts + eplb_config.redundant_experts_num
|
||||
num_experts_per_rank = num_experts // parallel_config.expert_parallel_size
|
||||
num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
|
||||
parallel_config.expert_parallel_rank = expert_parallel_rank
|
||||
@@ -958,7 +959,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
plas_attention_config = PlasAttentionConfig(args.plas_attention_config)
|
||||
|
||||
early_stop_config = EarlyStopConfig(args.early_stop_config)
|
||||
eplb_config = EPLBConfig(args.eplb_config)
|
||||
|
||||
structured_outputs_config: StructuredOutputsConfig = StructuredOutputsConfig(args=vars(args))
|
||||
routing_replay_config = RoutingReplayConfig(args.routing_replay_config)
|
||||
|
||||
Reference in New Issue
Block a user