mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
add scale_wrapper for per_block_cast_to_fp8 (#6183)
This commit is contained in:
@@ -220,6 +220,35 @@ def group_wise_int4_weight_quantize(weight: paddle.Tensor, group_size: int = 128
|
||||
return quant_weight.astype(paddle.int8), weight_scale
|
||||
|
||||
|
||||
def scale_wrapper(x_amax: paddle.Tensor, eps: float = 0.0) -> paddle.Tensor:
    """
    Turn per-block absolute maxima into FP8 quantization scales.

    Paddle implementation of the CUDA ScaleWrapper logic: each element of the
    result is ``fp8_max / max(x_amax, eps)`` with ``fp8_max = 448.0``, except
    that a zero denominator yields a scale of 1 and an infinite quotient is
    replaced by the largest finite float32 value.

    Args:
        x_amax (paddle.Tensor): amax tensor (float32 recommended)
        eps (float): lower bound applied to x_amax before dividing, to avoid
            division by zero

    Returns:
        paddle.Tensor: scale tensor, same shape as x_amax
    """
    fp8_max = 448.0
    largest_fp32 = paddle.finfo(paddle.float32).max

    # Clamp amax from below by eps (elementwise) before it becomes a divisor.
    eps_floor = paddle.full_like(x_amax, eps)
    clamped_amax = paddle.maximum(x_amax, eps_floor)

    raw_scale = fp8_max / clamped_amax

    # Where the clamped amax is still exactly zero, fall back to a scale of 1
    # instead of keeping the result of the division by zero.
    unit_scale = paddle.ones_like(raw_scale)
    guarded_scale = paddle.where(clamped_amax == 0, unit_scale, raw_scale)

    # Any remaining infinite scale is saturated to the float32 maximum.
    saturated = paddle.full_like(guarded_scale, largest_fp32)
    return paddle.where(paddle.isinf(guarded_scale), saturated, guarded_scale)
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Only used in deep_gemm block wise quant weight.
|
||||
@@ -244,10 +273,10 @@ def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Ten
|
||||
|
||||
x_abs = paddle.abs(x_view).astype(paddle.float32)
|
||||
x_amax = paddle.amax(x_abs, axis=(1, 3), keepdim=True)
|
||||
x_amax = paddle.clip(x_amax, min=1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).astype(paddle.float8_e4m3fn)
|
||||
scale = scale_wrapper(x_amax)
|
||||
x_scaled = (x_view * scale).astype(paddle.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
|
||||
paddle.view(1.0 / (448.0 / x_amax), (x_view.shape[0], x_view.shape[2]))
|
||||
paddle.view(1.0 / scale, (x_view.shape[0], x_view.shape[2]))
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user