mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[RL] change rms norm for glm (#7269)
* change rms norm for glm
* refine code
* refine code
* refine code
This commit is contained in:
@@ -25,6 +25,7 @@ from paddle import nn
|
||||
from paddleformers.transformers import PretrainedModel
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
@@ -264,6 +265,14 @@ class Glm4MoeAttention(nn.Layer):
|
||||
return output
|
||||
|
||||
|
||||
def rms_norm_func(x, weight, eps):
    """Apply RMS normalization over the last dimension of ``x``.

    Thin wrapper around ``paddle.nn.functional.rms_norm`` that also
    normalizes its return value: the call may yield either a bare tensor
    or a tuple/list whose first element is the output (presumably
    version-dependent — verify against the installed paddle). The result
    is cast to ``weight``'s dtype in both cases.
    """
    out = paddle.nn.functional.rms_norm(x, x.shape[-1:], weight, eps)
    # Unwrap a (output, ...) tuple/list return before the dtype cast.
    if isinstance(out, (tuple, list)):
        out = out[0]
    return out.astype(weight.dtype)
|
||||
|
||||
|
||||
class Glm4MoeDecoderLayer(nn.Layer):
|
||||
""" """
|
||||
|
||||
@@ -317,8 +326,10 @@ class Glm4MoeDecoderLayer(nn.Layer):
|
||||
residual: paddle.Tensor = None,
|
||||
):
|
||||
""" """
|
||||
proxy_rmsnorm = rms_norm_func if fastdeploy.envs.FD_USE_PHI_RMSNORM else None
|
||||
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual_input=residual, forward_meta=forward_meta
|
||||
hidden_states, residual_input=residual, forward_meta=forward_meta, proxy_rmsnorm=proxy_rmsnorm
|
||||
)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
@@ -327,7 +338,7 @@ class Glm4MoeDecoderLayer(nn.Layer):
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual, proxy_rmsnorm=proxy_rmsnorm)
|
||||
|
||||
hidden_states = self.mlp(hidden_states, forward_meta)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user