[RL] support qkrmsnorm use proxy-norm (#6862)
* support qkrmsnorm use paddle.nn.functional.rms_norm
* remove flags in fd
@@ -339,8 +339,9 @@ class QKRMSNorm(nn.Layer):
         self,
         qkv_out,
         forward_meta,
+        proxy_rmsnorm=None,
     ) -> paddle.Tensor:
-        if self.qk_norm_fused and forward_meta.step_use_cudagraph:
+        if proxy_rmsnorm is None and self.qk_norm_fused and forward_meta.step_use_cudagraph:
             qkv_out = qk_rmsnorm_fused(
                 qkv_out,
                 self.q_norm.weight,
@@ -354,11 +355,11 @@ class QKRMSNorm(nn.Layer):
         q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)

         q_by_head = q.reshape([*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim])
-        q_by_head = self.q_norm(q_by_head)[0]
+        q_by_head = self.q_norm(q_by_head, proxy_rmsnorm=proxy_rmsnorm)[0]
         q = q_by_head.reshape(q.shape)

         k_by_head = k.reshape([*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim])
-        k_by_head = self.k_norm(k_by_head)[0]
+        k_by_head = self.k_norm(k_by_head, proxy_rmsnorm=proxy_rmsnorm)[0]
         k = k_by_head.reshape(k.shape)

         qkv_out = paddle.concat([q, k, v], axis=-1)
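For context, the new proxy_rmsnorm argument is threaded into self.q_norm / self.k_norm, and the fused qk_rmsnorm_fused kernel now runs only when proxy_rmsnorm is None; passing a proxy forces the unfused per-head split / normalize / concat path in the second hunk. Below is a minimal sketch of what such a proxy could look like, assuming the norm layer hands it the per-head tensor plus its own weight and epsilon and expects a tuple back; the exact signature is not shown in this diff, and the helper name and arguments here are illustrative only. The commit message points at paddle.nn.functional.rms_norm as the backing op; the sketch uses elementary Paddle ops instead so it stays self-contained.

import paddle

def proxy_rmsnorm(x, weight, epsilon=1e-6):
    # Hypothetical proxy: plain RMSNorm over the last (head_dim) axis.
    xf = x.astype("float32")
    variance = (xf * xf).mean(axis=-1, keepdim=True)
    x_normed = xf * paddle.rsqrt(variance + epsilon)
    # Return a tuple so the norm(...)[0] call pattern in the diff still holds.
    return (x_normed.astype(x.dtype) * weight,)

When proxy_rmsnorm is left as None, behaviour is unchanged and the CUDA-graph fused path is taken exactly as before.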