[RL] support qkrmsnorm use proxy-norm (#6862)

* support qkrmsnorm use paddle.nn.functional.rms_norm

* remove flags in fd
This commit is contained in:
JYChen
2026-03-19 14:27:26 +08:00
committed by GitHub
parent 1a05744c4e
commit f95d8ca7df
@@ -339,8 +339,9 @@ class QKRMSNorm(nn.Layer):
self,
qkv_out,
forward_meta,
+        proxy_rmsnorm=None,
) -> paddle.Tensor:
-        if self.qk_norm_fused and forward_meta.step_use_cudagraph:
+        if proxy_rmsnorm is None and self.qk_norm_fused and forward_meta.step_use_cudagraph:
qkv_out = qk_rmsnorm_fused(
qkv_out,
self.q_norm.weight,
@@ -354,11 +355,11 @@ class QKRMSNorm(nn.Layer):
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
q_by_head = q.reshape([*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim])
-            q_by_head = self.q_norm(q_by_head)[0]
+            q_by_head = self.q_norm(q_by_head, proxy_rmsnorm=proxy_rmsnorm)[0]
q = q_by_head.reshape(q.shape)
k_by_head = k.reshape([*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim])
-            k_by_head = self.k_norm(k_by_head)[0]
+            k_by_head = self.k_norm(k_by_head, proxy_rmsnorm=proxy_rmsnorm)[0]
k = k_by_head.reshape(k.shape)
qkv_out = paddle.concat([q, k, v], axis=-1)