mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
[Optimization][OP]support per_token_group_fp8_quant cuda kernel (#6865)
* support per_token_group_fp8_quant cuda kernel * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * update code --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -21,6 +21,9 @@ from paddleformers.utils.log import logger
|
||||
from fastdeploy.model_executor.ops.triton_ops import _per_token_group_quant_fp8
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import per_token_group_fp8_quant
|
||||
|
||||
from ..utils import get_sm_version
|
||||
|
||||
|
||||
@@ -162,6 +165,7 @@ def per_token_group_quant_fp8(
|
||||
"""
|
||||
|
||||
dtype = paddle.float8_e4m3fn # current_platform.fp8_dtype() if dtype is None else dtype
|
||||
assert x.ndim == 2, f"per_token_group_fp8_quant only supports ndim == 2, but got shape {tuple(x.shape)}"
|
||||
assert x.shape[-1] % group_size == 0, (
|
||||
f"the last dimension of `x` {x.shape[-1]} must be divisible " f"by `group_size` {group_size}"
|
||||
)
|
||||
@@ -177,30 +181,30 @@ def per_token_group_quant_fp8(
|
||||
shape = x.shape[:-1] + (x.shape[-1] // group_size,)
|
||||
x_s = paddle.empty(shape, dtype=paddle.float32)
|
||||
|
||||
# torch.ops._C.per_token_group_fp8_quant(
|
||||
# x.contiguous(), x_q, x_s, group_size, eps, fp8_min, fp8_max, use_ue8m0
|
||||
# )
|
||||
# return x_q, x_s
|
||||
M = x.numel() // group_size
|
||||
N = group_size
|
||||
BLOCK = triton.next_power_of_2(N)
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK // 256, 1), 8)
|
||||
num_stages = 1
|
||||
_per_token_group_quant_fp8[(M,)](
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
group_size,
|
||||
x.shape[1],
|
||||
x.stride(0),
|
||||
eps,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
use_ue8m0=use_ue8m0,
|
||||
BLOCK=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
if current_platform.is_cuda():
|
||||
per_token_group_fp8_quant(x.contiguous(), x_q, x_s, group_size, eps, fp8_min, fp8_max, use_ue8m0)
|
||||
|
||||
else:
|
||||
M = x.numel() // group_size
|
||||
# N: int = group_size
|
||||
BLOCK = triton.next_power_of_2
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK // 256, 1), 8)
|
||||
num_stages = 1
|
||||
_per_token_group_quant_fp8[(M,)](
|
||||
x.contiguous(),
|
||||
x_q,
|
||||
x_s,
|
||||
group_size,
|
||||
x.shape[1],
|
||||
x.stride(0),
|
||||
eps,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
use_ue8m0=use_ue8m0,
|
||||
BLOCK=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
Reference in New Issue
Block a user