[New][RL] Support Rollout Routing Replay (#5405)

* [RL] Support Rollout Routing Replay

* add routing indices cache

* fix config bug and moe forward bug

* R3 Support GLM

* support eb4.5

* fix merge bug

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* add routing replay ci

* support glm topk

* support orther top_k

* fix ci bug

* pre-commit

* only support chatcmpl

* Revert "Revert "[RL] Support Rollout Routing Replay (#5321)" (#5402)"

This reverts commit c45e064f3d.

* Fix XPU and NPU bug

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
RAM
2025-12-05 22:06:26 +08:00
committed by GitHub
parent c45e064f3d
commit b2908b8e82
26 changed files with 608 additions and 24 deletions
+10
View File
@@ -38,6 +38,7 @@ from fastdeploy.config import (
ModelConfig,
ParallelConfig,
PlasAttentionConfig,
RoutingReplayConfig,
SpeculativeConfig,
StructuredOutputsConfig,
)
@@ -885,6 +886,13 @@ def parse_args():
help="EPLB Configuration.",
)
parser.add_argument(
"--routing_replay_config",
type=json.loads,
default=None,
help="Configation of Rollout Routing Replay.",
)
args = parser.parse_args()
return args
@@ -944,6 +952,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
eplb_config = EPLBConfig(args.eplb_config)
structured_outputs_config: StructuredOutputsConfig = StructuredOutputsConfig(args=vars(args))
routing_replay_config = RoutingReplayConfig(args.routing_replay_config)
# Note(tangbinhan): used for load_checkpoint
model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
@@ -1003,6 +1012,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
plas_attention_config=plas_attention_config,
structured_outputs_config=structured_outputs_config,
eplb_config=eplb_config,
routing_replay_config=routing_replay_config,
)
update_fd_config_for_mm(fd_config)
if fd_config.load_config.load_choices == "default_v1" and not v1_loader_support(fd_config):