mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
perf: avoid unnecessary dtype casting in RMSNorm
Added checks before calling `.astype` in `fastdeploy/model_executor/layers/normalization.py`. In PaddlePaddle, calling `.astype` allocates a new tensor even if the tensor is already of the target dtype; skipping these redundant casts avoids memory allocations and kernel launches on the hot path.
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
## 2026-04-19 - Unnecessary dtype conversions in hot paths
|
||||
**Learning:** In PaddlePaddle, calling `.astype(dtype)` creates a new tensor and dispatches a kernel even when the tensor is already of the target dtype, which can slow down hot paths like RMSNorm.
|
||||
**Action:** Add explicit conditional checks (`if tensor.dtype != target_dtype`) before calling `.astype` in frequently executed methods to avoid unnecessary memory allocations and kernel dispatch overhead.
|
||||
@@ -232,10 +232,12 @@ class RMSNorm(nn.Layer):
|
||||
operations (like linear transformation) on the `residual_input`.
|
||||
"""
|
||||
x_dtype = x.dtype
|
||||
x = x.astype(self.weight.dtype)
|
||||
if x.dtype != self.weight.dtype:
|
||||
x = x.astype(self.weight.dtype)
|
||||
if residual_input is not None:
|
||||
residual_input_dtype = residual_input.dtype
|
||||
residual_input = residual_input.astype(self.weight.dtype)
|
||||
if residual_input.dtype != self.weight.dtype:
|
||||
residual_input = residual_input.astype(self.weight.dtype)
|
||||
|
||||
if residual_input is None:
|
||||
residual_out = x
|
||||
@@ -276,9 +278,13 @@ class RMSNorm(nn.Layer):
|
||||
x = x + residual_input
|
||||
norm_out = proxy_rmsnorm(x, self.weight, self.eps), x
|
||||
|
||||
out = norm_out[0].astype(x_dtype)
|
||||
out = norm_out[0]
|
||||
if out.dtype != x_dtype:
|
||||
out = out.astype(x_dtype)
|
||||
if residual_input is not None:
|
||||
residual_out = norm_out[1].astype(residual_input_dtype)
|
||||
residual_out = norm_out[1]
|
||||
if residual_out.dtype != residual_input_dtype:
|
||||
residual_out = residual_out.astype(residual_input_dtype)
|
||||
|
||||
if self.split_x:
|
||||
assert residual_out is not None
|
||||
|
||||
Reference in New Issue
Block a user