mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Models][OP][Optimization] Support DeepSeek-v3.2 model, integrate DSA & Indexer architecture with FlashMLA/DeepGEMM (#6689)
* Support DeepSeek-v3.2 model, integrate DSA & Indexer architecture with FlashMLA/DeepGEMM
This commit is contained in:
@@ -15,8 +15,10 @@
|
||||
"""
|
||||
|
||||
import paddle
|
||||
import triton
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.model_executor.ops.triton_ops import _per_token_group_quant_fp8
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
from ..utils import get_sm_version
|
||||
@@ -130,3 +132,75 @@ def quant_weight_ue8m0(weight_dequant, weight_block_size):
|
||||
)
|
||||
|
||||
return out_w, out_s
|
||||
|
||||
|
||||
def per_token_group_quant_fp8(
    x: paddle.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: paddle.dtype | None = None,
    column_major_scales: bool = False,
    tma_aligned_scales: bool = False,
    out_q: paddle.Tensor | None = None,
    use_ue8m0: bool | None = None,
) -> tuple[paddle.Tensor, paddle.Tensor]:
    """Perform per-token-group FP8 quantization on the input tensor ``x``.

    Each contiguous group of ``group_size`` values along the last axis is
    quantized to signed float8 with one shared float32 scale per group,
    via the ``_per_token_group_quant_fp8`` Triton kernel.

    Args:
        x: The input tensor with ndim >= 2. Its last dimension must be
            divisible by ``group_size`` and groups must be contiguous.
        group_size: The group size used for quantization.
        eps: The minimum per-group scale to avoid dividing by zero.
        dtype: The intended dtype of the output tensor.
            NOTE(review): currently ignored — the output is always
            ``paddle.float8_e4m3fn`` (see override below).
        column_major_scales: Outputs scales in column major.
            NOTE(review): currently unused on this code path.
        tma_aligned_scales: Outputs scales in TMA-aligned layout.
            NOTE(review): currently unused on this code path.
        out_q: Optional output tensor with the same shape as ``x``.
            If not provided, the function allocates one.
        use_ue8m0: Forwarded to the Triton kernel; presumably selects
            UE8M0 (power-of-two) scale rounding — confirm against the
            kernel implementation.

    Returns:
        tuple[paddle.Tensor, paddle.Tensor]: The quantized tensor and the
        per-group scaling factors of shape
        ``x.shape[:-1] + (x.shape[-1] // group_size,)``.
    """
    # The dtype argument is deliberately overridden for now; restore
    # `current_platform.fp8_dtype() if dtype is None else dtype` once
    # additional platforms are supported.
    dtype = paddle.float8_e4m3fn
    assert x.shape[-1] % group_size == 0, (
        f"the last dimension of `x` {x.shape[-1]} must be divisible " f"by `group_size` {group_size}"
    )
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    # Saturation range of float8_e4m3fn used for clamping in the kernel.
    fp8_min, fp8_max = -224.0, 224.0

    assert out_q is None or out_q.shape == x.shape
    x_q = out_q
    if x_q is None:
        x_q = paddle.empty(x.shape, dtype=dtype)

    # One float32 scale per group along the last axis.
    shape = x.shape[:-1] + (x.shape[-1] // group_size,)
    x_s = paddle.empty(shape, dtype=paddle.float32)

    # Total number of groups; the kernel launches one program per group.
    M = x.numel() // group_size
    N = group_size
    BLOCK = triton.next_power_of_2(N)
    # Heuristic: scale warp count with the block width, clamped to [1, 8].
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    _per_token_group_quant_fp8[(M,)](
        x,
        x_q,
        x_s,
        group_size,
        x.shape[1],
        x.stride(0),
        eps,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        use_ue8m0=use_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return x_q, x_s
|
||||
|
||||
@@ -20,7 +20,6 @@ from typing import Optional
|
||||
|
||||
import paddle
|
||||
from paddle.nn.quant import weight_quantize
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.model_executor.layers.linear import (
|
||||
@@ -181,7 +180,7 @@ class WeightOnlyConfig(QuantConfigBase):
|
||||
and check_machete_supports_shape(layer.weight_shape[0], layer.weight_shape[1])
|
||||
):
|
||||
self.group_size = query_machete_supported_group_size(layer.weight_shape[0])
|
||||
logger.info(f"Using Machete kernel for WeightOnlyLinearMethod, group size: {self.group_size}")
|
||||
# logger.info(f"Using Machete kernel for WeightOnlyLinearMethod, group size: {self.group_size}")
|
||||
return MacheteWeightOnlyLinearMethod(self)
|
||||
return GPUWeightOnlyLinearMethod(self)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user