[Feature] Unify fp8 block_wise quant ops (#5991)

* quant stash

* blockwise_quant

* precommit

* rm tensor.cut

* tp ok

* add swiglu

* rm outdated code

* fix activation ut

* change baseline

* fix baseline error
This commit is contained in:
fxyfxy777
2026-01-15 21:50:37 +08:00
committed by GitHub
parent d38cd8b40b
commit 4c92035f2d
17 changed files with 55 additions and 571 deletions
+4 -11
View File
@@ -236,19 +236,12 @@ def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Ten
dtype=x.dtype,
)
x_padded[:m, :n] = x
x_view = paddle.view(
x_padded,
(-1, block_size[0], x_padded.shape[1] // block_size[1], block_size[1]),
)
from paddle.incubate.nn.functional.fp8 import fp8_quant_blockwise
x_abs = paddle.abs(x_view).astype(paddle.float32)
x_amax = paddle.amax(x_abs, axis=(1, 3), keepdim=True)
x_amax = paddle.clip(x_amax, min=1e-4)
x_scaled = (x_view * (448.0 / x_amax)).astype(paddle.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
paddle.view(x_amax / 448.0, (x_view.shape[0], x_view.shape[2]))
x_q, scale = fp8_quant_blockwise(
x_padded, quant_method="128x128", input_transpose=False, output_scale_transpose=False, using_pow2_scale=False
)
return x_q[:m, :n].contiguous(), scale
def per_token_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]: