add scale_wrapper for per_block_cast_to_fp8 (#6183)

fxyfxy777
2026-01-23 16:37:20 +08:00
committed by GitHub
parent b3a48529ab
commit 79f42209bf
+32 -3
@@ -220,6 +220,35 @@ def group_wise_int4_weight_quantize(weight: paddle.Tensor, group_size: int = 128
     return quant_weight.astype(paddle.int8), weight_scale
+
+
+def scale_wrapper(x_amax: paddle.Tensor, eps: float = 0.0) -> paddle.Tensor:
+    """
+    Paddle implementation of the CUDA ScaleWrapper logic.
+
+    Args:
+        x_amax (paddle.Tensor): amax tensor (float32 recommended)
+        eps (float): epsilon to avoid division by zero
+
+    Returns:
+        paddle.Tensor: scale tensor, same shape as x_amax
+    """
+    fp8_max = 448.0
+    float_max = paddle.finfo(paddle.float32).max
+    amax_mod = paddle.maximum(
+        x_amax,
+        paddle.full_like(x_amax, eps),
+    )
+    scale = fp8_max / amax_mod
+    # An all-zero block has amax == 0; fall back to scale = 1.0 instead of inf.
+    scale = paddle.where(
+        amax_mod == 0,
+        paddle.ones_like(scale),
+        scale,
+    )
+    # A denormal amax can still overflow the division; clamp inf to float32 max.
+    scale = paddle.where(
+        paddle.isinf(scale),
+        paddle.full_like(scale, float_max),
+        scale,
+    )
+    return scale
 def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]:
     """
     Only used for deep_gemm block-wise weight quantization.
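A minimal usage sketch for the new helper (illustrative values; assumes scale_wrapper is importable from the quantization module this diff touches, whose path the hunks do not show):

import paddle

# Three illustrative amax values: a normal block, an all-zero block, and a
# denormal block whose reciprocal overflows float32.
x_amax = paddle.to_tensor([2.0, 0.0, 1e-40], dtype=paddle.float32)
print(scale_wrapper(x_amax).numpy())
# -> [224.0, 1.0, 3.4028235e+38]
# 448.0 / 2.0 = 224.0; a zero amax falls back to 1.0 instead of inf;
# 448.0 / 1e-40 overflows float32 and is clamped to the float32 maximum.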
@@ -244,10 +273,10 @@ def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]:
     x_abs = paddle.abs(x_view).astype(paddle.float32)
     x_amax = paddle.amax(x_abs, axis=(1, 3), keepdim=True)
     x_amax = paddle.clip(x_amax, min=1e-4)
-    x_scaled = (x_view * (448.0 / x_amax)).astype(paddle.float8_e4m3fn)
+    scale = scale_wrapper(x_amax)
+    x_scaled = (x_view * scale).astype(paddle.float8_e4m3fn)
     return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
-        paddle.view(1.0 / (448.0 / x_amax), (x_view.shape[0], x_view.shape[2]))
+        paddle.view(1.0 / scale, (x_view.shape[0], x_view.shape[2]))
     )
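For context, a hedged round-trip sketch of how the returned pair is typically consumed; the dequantization below is illustrative, not part of this diff, and assumes a Paddle build that exposes float8_e4m3fn:

import paddle

m, n = 300, 500
x = paddle.randn([m, n], dtype=paddle.float32)
# x_fp8 is the FP8 tensor; inv_scale holds one inverse scale per 128x128 block.
x_fp8, inv_scale = per_block_cast_to_fp8(x)

# Dequantize by broadcasting each block's inverse scale over its 128x128 tile.
scales_full = inv_scale.repeat_interleave(128, axis=0).repeat_interleave(128, axis=1)[:m, :n]
x_deq = x_fp8.astype(paddle.float32) * scales_full
print((x - x_deq).abs().max())  # small block-wise quantization error expected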