mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
fix rl moe gate type (#7393)
This commit is contained in:
@@ -68,6 +68,7 @@ class RolloutModelConfig:
|
|||||||
routing_replay_config: str = None,
|
routing_replay_config: str = None,
|
||||||
load_choices: str = "default_v1",
|
load_choices: str = "default_v1",
|
||||||
lm_head_fp32: bool = False,
|
lm_head_fp32: bool = False,
|
||||||
|
moe_gate_fp32: bool = True,
|
||||||
):
|
):
|
||||||
# Required parameters
|
# Required parameters
|
||||||
self.model = model_name_or_path
|
self.model = model_name_or_path
|
||||||
@@ -121,6 +122,7 @@ class RolloutModelConfig:
|
|||||||
self.routing_replay_config = routing_replay_config
|
self.routing_replay_config = routing_replay_config
|
||||||
self.load_choices = load_choices
|
self.load_choices = load_choices
|
||||||
self.lm_head_fp32 = lm_head_fp32
|
self.lm_head_fp32 = lm_head_fp32
|
||||||
|
self.moe_gate_fp32 = moe_gate_fp32
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
|
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
|
||||||
|
|||||||
Reference in New Issue
Block a user