Revert "[Feature] Support Ernie FP8 on sm100 (#5593)" (#6275)

This reverts commit eb80724b71.
This commit is contained in:
JYChen
2026-01-30 11:22:01 +08:00
committed by GitHub
parent 292bab7e6d
commit 6c685c9474
11 changed files with 197 additions and 725 deletions
@@ -25,33 +25,16 @@ from fastdeploy.model_executor.layers.linear import (
QKVParallelLinear,
)
from fastdeploy.model_executor.layers.moe import FusedMoE
from fastdeploy.model_executor.layers.quantization.fp8_utils import (
quant_weight_ue8m0,
transform_scale_ue8m0,
)
from fastdeploy.model_executor.utils import (
TensorTracker,
process_weight_transpose,
set_weight_attrs,
)
from fastdeploy.platforms import current_platform
from fastdeploy.utils import register_custom_python_op
from ..utils import get_sm_version, get_tensor, per_block_cast_to_fp8
from ..utils import get_tensor, per_block_cast_to_fp8
from .quant_base import QuantConfigBase, QuantMethodBase
if current_platform.is_cuda():
if get_sm_version() == 100:
# SM100 should use PFCC DeepGemm
paddle.compat.enable_torch_proxy(scope={"deep_gemm"})
from deep_gemm import fp8_gemm_nt
else:
from fastdeploy.model_executor.ops.gpu.deep_gemm import (
gemm_fp8_fp8_bf16_nt as fp8_gemm_nt,
)
else:
fp8_gemm_nt = None
class BlockWiseFP8Config(QuantConfigBase):
"""
@@ -68,7 +51,6 @@ class BlockWiseFP8Config(QuantConfigBase):
self.quant_round_type = 1
self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM)
self.is_checkpoint_bf16 = is_checkpoint_bf16
self.deepgemm_scale_ue8m0 = True if get_sm_version() == 100 else False
def name(self) -> str:
return "block_wise_fp8"
@@ -99,7 +81,7 @@ class BlockWiseFP8Config(QuantConfigBase):
return BlockWiseFP8LinearMethod(self)
def deep_gemm_fp8_gemm_nt_infer_meta(
def deep_gemm_fp8_fp8_bf16_nt_infer_meta(
x_meta: "paddle.static.MetaTensor",
x_scale_tensor_meta: "paddle.static.MetaTensor",
layer_weight_meta: "paddle.static.MetaTensor",
@@ -111,13 +93,13 @@ def deep_gemm_fp8_gemm_nt_infer_meta(
@register_custom_python_op(
name="deep_gemm_fp8_gemm_nt",
infer_meta=deep_gemm_fp8_gemm_nt_infer_meta,
name="deep_gemm_fp8_fp8_bf16_nt",
infer_meta=deep_gemm_fp8_fp8_bf16_nt_infer_meta,
input_names=["x", "x_scale_tensor", "layer_weight", "layer_weight_scale_inv", "linear_out_empty"],
output_names=["linear_out"],
inplace_map={},
)
def deep_gemm_fp8_gemm_nt(
def deep_gemm_fp8_fp8_bf16_nt(
x: paddle.Tensor,
x_scale_tensor: paddle.Tensor,
layer_weight: paddle.Tensor,
@@ -125,12 +107,14 @@ def deep_gemm_fp8_gemm_nt(
linear_out: paddle.Tensor,
layer_output_size: int,
):
# disable_ue8m0_cast is default False for SM100
fp8_gemm_nt(
from fastdeploy.model_executor.ops.gpu import deep_gemm
deep_gemm.gemm_fp8_fp8_bf16_nt(
(x, x_scale_tensor),
(layer_weight, layer_weight_scale_inv),
linear_out,
)
return linear_out
@@ -225,16 +209,8 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
def process_weights_after_loading(self, layer) -> None:
def _process_quantize():
weight_tensor = layer.weight.transpose([1, 0])
quanted_weight_tensor, weight_block_scale_tensor = per_block_cast_to_fp8(weight_tensor)
if not self.quant_config.deepgemm_scale_ue8m0:
quanted_weight_tensor, weight_block_scale_tensor = per_block_cast_to_fp8(weight_tensor)
else:
quanted_weight_tensor, weight_block_scale_tensor = quant_weight_ue8m0(weight_tensor, [128, 128])
weight_block_scale_tensor = transform_scale_ue8m0(
weight_block_scale_tensor,
mn=quanted_weight_tensor.shape[-2],
weight_block_size=[128, 128],
)
if hasattr(layer.weight, "tensor_track"):
layer.weight.tensor_track = None
layer.weight.value().get_tensor()._clear()
@@ -248,12 +224,13 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
)
layer.weight_scale_inv = layer.create_parameter(
shape=weight_block_scale_tensor.shape,
dtype=weight_block_scale_tensor.dtype,
dtype="float32",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight.copy_(quanted_weight_tensor, False)
layer.weight_scale_inv.data = weight_block_scale_tensor
layer.weight_scale_inv.copy_(weight_block_scale_tensor, False)
if self.quant_config.is_checkpoint_bf16:
if self.model_format == "torch":
@@ -286,24 +263,13 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
layer.weight_scale_inv.set_value(weight_scale)
def apply(self, layer, x):
linear_out = paddle.empty((x.shape[0], layer.output_size), dtype=paddle.bfloat16)
if x.shape[0] == 0:
return linear_out
x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
x,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
output_scale_transpose=True,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
x, using_pow2_scale=False, output_scale_transpose=True
)
x_scale_tensor = x_scale_tensor.T[: x.shape[0], ...]
deep_gemm_fp8_gemm_nt(
x,
x_scale_tensor,
layer.weight,
layer.weight_scale_inv,
linear_out,
layer_output_size=layer.output_size,
x_scale_tensor = x_scale_tensor.T
linear_out = paddle.empty((x.shape[0], layer.output_size), dtype=paddle.bfloat16)
linear_out = deep_gemm_fp8_fp8_bf16_nt(
x, x_scale_tensor, layer.weight, layer.weight_scale_inv, linear_out, layer.output_size
)
if layer.with_bias:
linear_out = paddle.add(linear_out, layer.bias)