Mirror of https://github.com/PaddlePaddle/FastDeploy.git — last synced 2026-04-23 00:17:25 +08:00.
[RL] Support GLM MTP RL Model (#6267)
This commit is contained in the following branches:
@@ -1382,12 +1382,6 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
model_loader = get_model_loader(load_config=self.fd_config.load_config)
|
||||
self.model = model_loader.load_model(fd_config=self.fd_config)
|
||||
|
||||
# 1.1 Load RL dynamic model
|
||||
if self.fd_config.load_config.dynamic_load_weight:
|
||||
from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
|
||||
|
||||
self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model, self.local_rank)
|
||||
|
||||
# 2. Load lora model
|
||||
|
||||
# 3. Load drafter model(for speculative decoding)
|
||||
@@ -1395,6 +1389,17 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
# 4. Init proposer for speculative method
|
||||
self._init_speculative_proposer()
|
||||
|
||||
# Load RL dynamic model
|
||||
if self.fd_config.load_config.dynamic_load_weight:
|
||||
from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
|
||||
|
||||
if self.fd_config.speculative_config.method == "mtp":
|
||||
self.dynamic_weight_manager = DynamicWeightManager(
|
||||
self.fd_config, [self.model, self.proposer.model], self.local_rank
|
||||
)
|
||||
else:
|
||||
self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model, self.local_rank)
|
||||
|
||||
def get_model(self) -> nn.Layer:
|
||||
"""Get current model"""
|
||||
return self.model
|
||||
|
||||
Reference in New Issue
Block a user