[RL] Support GLM MTP RL Model (#6267)

This commit is contained in:
GoldPancake
2026-02-04 20:14:35 +08:00
committed by GitHub
parent 765df94e6c
commit 183b8d325a
10 changed files with 308 additions and 33 deletions
+11 -6
View File
@@ -1382,12 +1382,6 @@ class GPUModelRunner(ModelRunnerBase):
model_loader = get_model_loader(load_config=self.fd_config.load_config)
self.model = model_loader.load_model(fd_config=self.fd_config)
# 1.1 Load RL dynamic model
if self.fd_config.load_config.dynamic_load_weight:
from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model, self.local_rank)
# 2. Load lora model
# 3. Load drafter model(for speculative decoding)
@@ -1395,6 +1389,17 @@ class GPUModelRunner(ModelRunnerBase):
# 4. Init proposer for speculative method
self._init_speculative_proposer()
# Load RL dynamic model
if self.fd_config.load_config.dynamic_load_weight:
from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
if self.fd_config.speculative_config.method == "mtp":
self.dynamic_weight_manager = DynamicWeightManager(
self.fd_config, [self.model, self.proposer.model], self.local_rank
)
else:
self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model, self.local_rank)
def get_model(self) -> nn.Layer:
    """Return the currently loaded model.

    Returns:
        nn.Layer: the model instance held by this runner (``self.model``),
            as populated by the runner's model-loading step.
    """
    return self.model