support deepgemm without bias input (#7559)

This commit is contained in:
JYChen
2026-04-22 19:49:48 +08:00
committed by GitHub
parent b19754c5f4
commit a6a740f759
@@ -51,6 +51,19 @@ if current_platform.is_cuda():
else:
fp8_gemm_nt = None
# Detect whether fp8_gemm_nt accepts a 'bias' keyword argument.
# Some DeepGEMM builds expose fp8_gemm_nt without bias support; callers
# check this flag and, when False, add the bias manually after the GEMM.
_fp8_gemm_nt_has_bias_kwarg = False
if fp8_gemm_nt is not None:
    import inspect
    try:
        _sig = inspect.signature(fp8_gemm_nt)
        _fp8_gemm_nt_has_bias_kwarg = "bias" in _sig.parameters
    except (ValueError, TypeError):
        # pybind11 functions may not expose signatures via inspect
        # (inspect.signature raises ValueError/TypeError for them).
        # In that case conservatively keep the flag False: callers then
        # invoke fp8_gemm_nt without 'bias' and apply the bias themselves,
        # which is always safe. (No probe call is attempted here.)
        pass
class BlockWiseFP8Config(QuantConfigBase):
"""
@@ -138,14 +151,22 @@ def deep_gemm_fp8_gemm_nt(
sm_version = get_sm_version()
if sm_version >= 100 and current_platform.is_cuda():
# disable_ue8m0_cast is default False for SM100
fp8_gemm_nt(
(x, x_scale_tensor),
(layer_weight, layer_weight_scale_inv),
linear_out,
bias=bias,
)
if _fp8_gemm_nt_has_bias_kwarg:
fp8_gemm_nt(
(x, x_scale_tensor),
(layer_weight, layer_weight_scale_inv),
linear_out,
bias=bias,
)
else:
fp8_gemm_nt(
(x, x_scale_tensor),
(layer_weight, layer_weight_scale_inv),
linear_out,
)
if bias is not None:
linear_out = paddle.add(linear_out, bias)
else:
# disable_ue8m0_cast is default False for SM100
fp8_gemm_nt(
(x, x_scale_tensor),
(layer_weight, layer_weight_scale_inv),