[Others] support import deepgemm/deepep from fleet ops (#6351)

* update paddleformers to v1.0

* only change import fleetpath
This commit is contained in:
JYChen
2026-02-09 11:53:13 +08:00
committed by GitHub
parent 74762b0fb2
commit 9bcd863902
4 changed files with 46 additions and 33 deletions
@@ -42,14 +42,10 @@ from ..utils import get_sm_version, get_tensor, per_block_cast_to_fp8
from .quant_base import QuantConfigBase, QuantMethodBase
if current_platform.is_cuda():
if get_sm_version() == 100:
# SM100 should use PFCC DeepGemm
paddle.compat.enable_torch_proxy(scope={"deep_gemm"})
from deep_gemm import fp8_gemm_nt
else:
from fastdeploy.model_executor.ops.gpu.deep_gemm import (
gemm_fp8_fp8_bf16_nt as fp8_gemm_nt,
)
try:
fp8_gemm_nt = fastdeploy.model_executor.layers.quantization.fp8_utils.deep_gemm.fp8_gemm_nt
except:
fp8_gemm_nt = fastdeploy.model_executor.layers.quantization.fp8_utils.deep_gemm.gemm_fp8_fp8_bf16_nt
else:
fp8_gemm_nt = None
@@ -21,16 +21,36 @@ from fastdeploy.platforms import current_platform
from ..utils import get_sm_version
if current_platform.is_cuda():
if get_sm_version() == 100:
# SM100 should use PFCC DeepGemm
logger.info("Detected sm100, use PFCC DeepGEMM")
paddle.compat.enable_torch_proxy(scope={"deep_gemm"})
import deep_gemm
def load_deep_gemm():
"""
Load DeepGemm module according to FastDeploy env switch.
Returns:
Imported deep_gemm module object.
"""
if current_platform.is_cuda():
if get_sm_version() == 100:
# SM100 should use PFCC DeepGemm
paddle.compat.enable_torch_proxy(scope={"deep_gemm"})
try:
from paddlefleet.ops import deep_gemm
logger.info("Detected sm100, use PaddleFleet DeepGEMM")
except:
import deep_gemm
logger.info("Detected sm100, use PFCC DeepGEMM")
else:
logger.info("use FastDeploy DeepGEMM")
from fastdeploy.model_executor.ops.gpu import deep_gemm
else:
from fastdeploy.model_executor.ops.gpu import deep_gemm
else:
deep_gemm = None
deep_gemm = None
return deep_gemm
deep_gemm = load_deep_gemm()
def ceil_div(x: int, y: int) -> int: