[RL] change rms norm for glm (#7269)

* change rms norm for glm

* refine code

* refine code

* refine code
Author: zhangbo9674
Date: 2026-04-10 16:02:37 +08:00
Committed by: GitHub
Parent: 870dbac370
Commit: 627f0d9cc8
2 changed files with 15 additions and 2 deletions
@@ -219,6 +219,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))),
     # Whether to use phi MOE permute,if 1,use paddle op.
     "FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))),
+    # Whether to use phi rms_norm,if 1,use paddle op.
+    "FD_USE_PHI_RMSNORM": lambda: bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0"))),
     # Control class SiluAndMul to use swiglu or fusid_bias_act operator in the forward_cuda function
     "FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))),
     # Reserve output blocks for decoding requests when schedule new prefill requests
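Note on how this flag is consumed: the new entry follows the same bool(int(os.getenv(...))) pattern as the other FD_* switches, so setting the variable to "1" before launch enables the Paddle rms_norm path. A minimal sketch (the variable name comes from the diff; the snippet itself is illustrative and not part of the commit):

import os

os.environ["FD_USE_PHI_RMSNORM"] = "1"  # typically exported in the shell before launching FastDeploy
use_phi_rmsnorm = bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0")))
print(use_phi_rmsnorm)  # True; unset or "0" yields False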
@@ -25,6 +25,7 @@ from paddle import nn
 from paddleformers.transformers import PretrainedModel
 from paddleformers.utils.log import logger

+import fastdeploy
 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.forward_meta import ForwardMeta
@@ -264,6 +265,14 @@ class Glm4MoeAttention(nn.Layer):
         return output


+def rms_norm_func(x, weight, eps):
+    rms_norm_out = paddle.nn.functional.rms_norm(x, x.shape[-1:], weight, eps)
+    if isinstance(rms_norm_out, (tuple, list)):
+        return rms_norm_out[0].astype(weight.dtype)
+    else:
+        return rms_norm_out.astype(weight.dtype)
+
+
 class Glm4MoeDecoderLayer(nn.Layer):
     """ """
@@ -317,8 +326,10 @@ class Glm4MoeDecoderLayer(nn.Layer):
         residual: paddle.Tensor = None,
     ):
         """ """
+        proxy_rmsnorm = rms_norm_func if fastdeploy.envs.FD_USE_PHI_RMSNORM else None
+
         hidden_states, residual = self.input_layernorm(
-            hidden_states, residual_input=residual, forward_meta=forward_meta
+            hidden_states, residual_input=residual, forward_meta=forward_meta, proxy_rmsnorm=proxy_rmsnorm
         )

         hidden_states = self.self_attn(
@@ -327,7 +338,7 @@
         )

         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual, proxy_rmsnorm=proxy_rmsnorm)

         hidden_states = self.mlp(hidden_states, forward_meta)
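The RMSNorm layers that receive proxy_rmsnorm are not shown in this diff. As a hedged sketch of the dispatch pattern implied by the call sites above (class name, parameter handling, and fallback math are assumptions, not FastDeploy's actual implementation):

# Illustrative only: this class is NOT part of the diff; it sketches how a
# residual-add RMSNorm layer might honor a proxy_rmsnorm callable when one is passed.
import paddle
from paddle import nn

class ProxyAwareRMSNorm(nn.Layer):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = self.create_parameter(
            shape=[hidden_size], default_initializer=nn.initializer.Constant(1.0)
        )

    def forward(self, x, residual_input=None, forward_meta=None, proxy_rmsnorm=None):
        if residual_input is not None:
            x = x + residual_input  # residual add before normalization
        residual = x
        if proxy_rmsnorm is not None:
            # e.g. rms_norm_func from this commit, selected via FD_USE_PHI_RMSNORM
            return proxy_rmsnorm(x, self.weight, self.eps), residual
        # fallback: plain RMSNorm in float32
        norm = x.astype("float32")
        norm = norm * paddle.rsqrt(norm.pow(2).mean(axis=-1, keepdim=True) + self.eps)
        return (norm * self.weight.astype("float32")).astype(self.weight.dtype), residual

When FD_USE_PHI_RMSNORM is off, proxy_rmsnorm is None and the layer falls back to whatever kernel it normally uses, so the default behavior of the decoder layer is unchanged.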