[BugFix] Enable moe_gate_fp32 using FD_ENABLE_RL (#7130)

* rl gate fp32

* clean
This commit is contained in:
sunxin
2026-04-07 12:07:38 +08:00
committed by GitHub
parent 18f012457d
commit ae2f9f4d22
4 changed files with 6 additions and 6 deletions
+2
View File
@@ -624,6 +624,8 @@ class EngineArgs:
raise NotImplementedError(
f"not support model_impl: '{self.model_impl}'. " f"Must be one of: {', '.join(valid_model_impls)}"
)
if envs.FD_ENABLE_RL == 1:
self.moe_gate_fp32 = True
self.post_init_all_ports()
+2
View File
@@ -266,6 +266,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool(
int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1"))
),
# Whether to align RoPE and moe gate precision with training
"FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")),
}
+1 -3
View File
@@ -151,9 +151,7 @@ class Glm4Moe(nn.Layer):
output_size=fd_config.model_config.n_routed_experts,
with_bias=False,
skip_quant=True,
weight_dtype=(
"float32" if fd_config.load_config.dynamic_load_weight or fd_config.model_config.moe_gate_fp32 else ""
),
weight_dtype=("float32" if fd_config.model_config.moe_gate_fp32 else ""),
)
self.gate.e_score_correction_bias = self.create_parameter(
shape=[1, fd_config.model_config.n_routed_experts],
+1 -3
View File
@@ -77,9 +77,7 @@ class Qwen3MoeBlock(nn.Layer):
output_size=fd_config.model_config.num_experts,
with_bias=False,
skip_quant=True,
weight_dtype=(
"float32" if fd_config.load_config.dynamic_load_weight or fd_config.model_config.moe_gate_fp32 else ""
),
weight_dtype=("float32" if fd_config.model_config.moe_gate_fp32 else ""),
)
def forward(self, x, forward_meta):