Files
FastDeploy/fastdeploy/model_executor/layers/quantization/fp8_utils.py
T
2026-02-27 14:34:29 +08:00

133 lines
4.6 KiB
Python

"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from paddleformers.utils.log import logger
from fastdeploy.platforms import current_platform
from ..utils import get_sm_version
def load_deep_gemm():
    """
    Select and import the DeepGEMM implementation for the current platform.

    Selection (CUDA only):
      * SM100: enable Paddle's torch proxy for the ``deep_gemm`` scope, then try
        ``paddlefleet.ops.deep_gemm``. If that import raises any ordinary
        exception, fall back to the standalone (PFCC) ``deep_gemm`` package.
      * Any other SM version: FastDeploy's bundled
        ``fastdeploy.model_executor.ops.gpu.deep_gemm``.

    Returns:
        The imported deep_gemm module object, or None when the current
        platform is not CUDA.

    Raises:
        Whatever the final import attempt raises (e.g. ImportError) propagates
        unchanged.
    """
    if not current_platform.is_cuda():
        return None

    if get_sm_version() != 100:
        logger.info("use FastDeploy DeepGEMM")
        import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm

        return deep_gemm

    # SM100 should use PFCC DeepGemm
    paddle.compat.enable_torch_proxy(scope={"deep_gemm"})
    try:
        import paddlefleet.ops.deep_gemm as deep_gemm
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are no
        # longer swallowed. Any ordinary failure (package missing, or present
        # but broken) still falls back to the standalone PFCC package.
        import deep_gemm as deep_gemm

        logger.info("Detected sm100, use PFCC DeepGEMM")
    else:
        import logging

        # NOTE(review): presumably importing paddlefleet attaches handlers to
        # the root logger (causing duplicated log lines); they are cleared
        # here. Confirm this is still needed.
        logging.getLogger().handlers.clear()
        logger.info("Detected sm100, use PaddleFleet DeepGEMM")
    return deep_gemm


# Resolved once at import time; None on non-CUDA platforms, so helpers below
# that dereference `deep_gemm` are only usable on CUDA.
deep_gemm = load_deep_gemm()
def ceil_div(x: int, y: int) -> int:
    """Return ``x / y`` rounded up to an integer, for a positive divisor ``y``.

    Example: ``ceil_div(129, 128) == 2`` (129 rows need two 128-row blocks).
    """
    quotient, _remainder = divmod(x + y - 1, y)
    return quotient
def _get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(
    x: paddle.Tensor,
) -> paddle.Tensor:
    """Convert an FP32 scale tensor into an MN-major, TMA-aligned, packed UE8M0 tensor.

    NOTE(review): this appears to mirror DeepGEMM's torch reference helper of a
    similar name. Confirm it matches the DeepGEMM version in use. It relies on
    torch-style tensor APIs (``.dim()``, ``.view(dtype)``, ``.mT``,
    ``device=``), presumably provided by Paddle's torch-compatibility layer.

    Steps:
      1. Keep only the 8 exponent bits of every FP32 value (UE8M0). The 23
         mantissa bits are shifted out, so a positive normal value is rounded
         down to a power of two.
      2. Zero-pad to ``(b, aligned_mn, aligned_k)`` uint8. Reinterpret each run
         of 4 consecutive uint8 entries along K as one int32.
      3. Copy into the transpose of a ``(b, aligned_k // 4, aligned_mn)`` int32
         buffer, so that MN is the contiguous axis.

    Args:
        x: float32 tensor of rank 2 ``(mn, k)`` or rank 3 ``(b, mn, k)``.
            Assumed non-negative (scale factors); the sign bit is dropped.

    Returns:
        int32 tensor of shape ``(mn, aligned_k // 4)`` for 2-D input, or
        ``(b, mn, aligned_k // 4)`` for 3-D input. It is a sliced, transposed
        view of a padded buffer rather than a compact contiguous tensor
        (assuming ``.mT`` returns a strided view as in torch).

    Raises:
        AssertionError: if ``x`` is not float32 or not rank 2/3.
        AttributeError: if ``deep_gemm`` is None (non-CUDA platform).
    """
    # Alignment helpers come from the module-level `deep_gemm` chosen by load_deep_gemm().
    align = deep_gemm.utils.align
    get_tma_aligned_size = deep_gemm.utils.get_tma_aligned_size
    # Input validation: must be FP32 type 2D or 3D tensor
    assert x.dtype == paddle.float and x.dim() in (2, 3)
    # Step 1: FP32 -> UE8M0. Reinterpret the float bits as int32 and shift out
    # the 23 mantissa bits. The uint8 cast keeps the low 8 bits, i.e. the biased
    # exponent (assuming wrap-around conversion drops the sign bit).
    ue8m0_tensor = (x.view(paddle.int) >> 23).to(paddle.uint8)
    # Step 2: pad and pack.
    # Logical sizes taken from the last two dims (before any unsqueeze).
    mn, k = x.shape[-2], x.shape[-1]
    remove_dim = False
    # If it's a 2D tensor, add batch dimension for unified processing
    if x.dim() == 2:
        x, remove_dim = x.unsqueeze(0), True
    b = x.shape[0]
    # aligned_mn: MN padded for TMA with 4-byte (int32) elements. Presumably this is
    # rounded up to a 16-byte boundary, i.e. a multiple of 4 elements;
    # confirm with deep_gemm.utils.get_tma_aligned_size.
    aligned_mn = get_tma_aligned_size(mn, 4)
    # aligned_k: K rounded up to a multiple of 4 (assuming align() rounds up),
    # so each row packs evenly into int32 words.
    aligned_k = align(k, 4)
    # Zero-filled padded buffer. For 2-D input, ue8m0_tensor is still 2-D and
    # broadcasts over the size-1 batch dim on assignment.
    padded = paddle.zeros((b, aligned_mn, aligned_k), device=x.device, dtype=paddle.uint8)
    padded[:, :mn, :k] = ue8m0_tensor
    # Pack uint8 data into int32: flatten, reinterpret every 4 bytes as one
    # int32, then reshape so each row holds aligned_k // 4 packed words.
    padded = padded.view(-1).view(dtype=paddle.int).view(b, aligned_mn, aligned_k // 4)
    # Step 3: allocate (b, aligned_k // 4, aligned_mn) and take .mT (swap the
    # last two dims). The writable view then has shape
    # (b, aligned_mn, aligned_k // 4) while MN is the contiguous axis in memory.
    transposed = paddle.zeros((b, aligned_k // 4, aligned_mn), device=x.device, dtype=paddle.int).mT
    transposed[:, :, :] = padded
    # Drop the MN padding rows. Packed-K padding words (if any) are kept.
    aligned_x = transposed[:, :mn, :]
    # If input was 2D tensor, remove batch dimension
    return aligned_x.squeeze(0) if remove_dim else aligned_x
def transform_scale_ue8m0(sf, mn, weight_block_size=None):
    """
    Convert FP32 scale factors into DeepGEMM's MN-major packed UE8M0 layout.

    Args:
        sf: float32 scale tensor (rank 2 or 3). When ``weight_block_size`` is
            given, its second-to-last axis indexes row blocks.
        mn: Number of rows to expand to when ``weight_block_size`` is given.
            Row ``i`` takes the scale of row block ``i // block_n``.
        weight_block_size: Optional ``[block_n, block_k]``. A list or a tuple
            is accepted; only ``(128, 128)`` is supported. If falsy, ``sf`` is
            packed as-is without row expansion.

    Returns:
        The packed int32 tensor produced by
        ``_get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl``.

    Raises:
        AssertionError: unsupported ``weight_block_size``.
    """
    if weight_block_size:
        # Compare as a tuple so that both [128, 128] and (128, 128) are accepted.
        # The previous `== [128, 128]` check rejected a tuple block size.
        assert tuple(weight_block_size) == (128, 128), f"unsupported weight_block_size={weight_block_size}"
        block_n = weight_block_size[0]
        # Repeat each row-block scale for every row it covers: row i -> block i // block_n.
        row_to_block = paddle.arange(mn, device=sf.device) // block_n
        sf = sf.index_select(-2, row_to_block)
    return _get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(sf)
def quant_weight_ue8m0(weight_dequant, weight_block_size):
    """
    Quantize a bf16 weight to FP8 with per-block UE8M0 scales using DeepGEMM.

    Leading batch dims are flattened into rows and the resulting 2-D matrix is
    passed to ``deep_gemm.utils.math.per_block_cast_to_fp8(..., use_ue8m0=True)``.
    That function presumably tiles the matrix in 128x128 blocks.

    Args:
        weight_dequant: bfloat16 tensor of shape ``(*batch_dims, n, k)``.
        weight_block_size: ``[block_n, block_k]``. A list or a tuple is
            accepted; only ``(128, 128)`` is supported.

    Returns:
        ``(out_w, out_s)``:
          * ``out_w`` is reshaped to ``(*batch_dims, n, k)``.
          * ``out_s`` is reshaped to
            ``(*batch_dims, ceil(n / block_n), ceil(k / block_k))``.

    Raises:
        AssertionError: unsupported block size or non-bf16 input.
        ValueError: batch dims are present and ``n`` is not a multiple of
            ``block_n``.
    """
    assert tuple(weight_block_size) == (128, 128), f"unsupported weight_block_size={weight_block_size}"
    assert weight_dequant.dtype == paddle.bfloat16, f"{weight_dequant.dtype=} {weight_dequant.shape=}"
    block_n, block_k = weight_block_size
    *batch_dims, n, k = weight_dequant.shape
    # Flattening batch dims into rows is only sound when every block_n-row block
    # lies inside a single batch entry. Otherwise one block's scale would mix rows
    # of two experts/batches. Example: batch 2, n=100 still reshapes cleanly, but
    # the result silently yields wrong scales, so reject it explicitly.
    if batch_dims and n % block_n != 0:
        raise ValueError(
            f"quant_weight_ue8m0: with batch dims {batch_dims}, n={n} must be a multiple of "
            f"block_n={block_n} so quantization blocks do not span batch entries"
        )
    weight_dequant_flat = weight_dequant.view((-1, k))
    out_w_flat, out_s_flat = deep_gemm.utils.math.per_block_cast_to_fp8(weight_dequant_flat, use_ue8m0=True)
    out_w = out_w_flat.view((*batch_dims, n, k))
    out_s = out_s_flat.view(
        (
            *batch_dims,
            ceil_div(n, block_n),
            ceil_div(k, block_k),
        )
    )
    return out_w, out_s