mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[New][RL] Support Rollout Routing Replay (#5405)
* [RL] Support Rollout Routing Replay
* add routing indices cache
* fix config bug and moe forward bug
* R3 Support GLM
* support eb4.5
* fix merge bug
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* add routing replay ci
* support glm topk
* support orther top_k
* fix ci bug
* pre-commit
* only support chatcmpl
* Revert "Revert "[RL] Support Rollout Routing Replay (#5321)" (#5402)"
This reverts commit c45e064f3d.
* Fix XPU and NPU bug
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
@@ -38,6 +38,7 @@ from fastdeploy.config import (
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
PlasAttentionConfig,
|
||||
RoutingReplayConfig,
|
||||
SpeculativeConfig,
|
||||
StructuredOutputsConfig,
|
||||
)
|
||||
@@ -885,6 +886,13 @@ def parse_args():
|
||||
help="EPLB Configuration.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--routing_replay_config",
|
||||
type=json.loads,
|
||||
default=None,
|
||||
help="Configation of Rollout Routing Replay.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@@ -944,6 +952,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
eplb_config = EPLBConfig(args.eplb_config)
|
||||
|
||||
structured_outputs_config: StructuredOutputsConfig = StructuredOutputsConfig(args=vars(args))
|
||||
routing_replay_config = RoutingReplayConfig(args.routing_replay_config)
|
||||
|
||||
# Note(tangbinhan): used for load_checkpoint
|
||||
model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
|
||||
@@ -1003,6 +1012,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
plas_attention_config=plas_attention_config,
|
||||
structured_outputs_config=structured_outputs_config,
|
||||
eplb_config=eplb_config,
|
||||
routing_replay_config=routing_replay_config,
|
||||
)
|
||||
update_fd_config_for_mm(fd_config)
|
||||
if fd_config.load_config.load_choices == "default_v1" and not v1_loader_support(fd_config):
|
||||
|
||||
Reference in New Issue
Block a user