mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Unify fp8 block_wise quant ops (#5991)

* quant stash
* blockwise_quant
* precommit
* rm tensor.cut
* tp ok
* add swiglu
* rm outdated code
* fix activate ut
* change baseline
* fix baseline error
This commit is contained in:
@@ -236,19 +236,12 @@ def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Ten
         dtype=x.dtype,
     )
     x_padded[:m, :n] = x
-    x_view = paddle.view(
-        x_padded,
-        (-1, block_size[0], x_padded.shape[1] // block_size[1], block_size[1]),
-    )
-
-    x_abs = paddle.abs(x_view).astype(paddle.float32)
-    x_amax = paddle.amax(x_abs, axis=(1, 3), keepdim=True)
-    x_amax = paddle.clip(x_amax, min=1e-4)
-    x_scaled = (x_view * (448.0 / x_amax)).astype(paddle.float8_e4m3fn)
-
-    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
-        paddle.view(x_amax / 448.0, (x_view.shape[0], x_view.shape[2]))
-    )
+    from paddle.incubate.nn.functional.fp8 import fp8_quant_blockwise
+
+    x_q, scale = fp8_quant_blockwise(
+        x_padded, quant_method="128x128", input_transpose=False, output_scale_transpose=False, using_pow2_scale=False
+    )
+    return x_q[:m, :n].contiguous(), scale


 def per_token_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:
||||
Reference in New Issue
Block a user