mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[OP] support deepgemm for sm103 (#7073)
* support deepgemm for sm103 * add assert * modify code style * add assert * modify sm version condition * remove assert
This commit is contained in:
@@ -67,7 +67,7 @@ class BlockWiseFP8Config(QuantConfigBase):
|
||||
self.quant_round_type = 1
|
||||
self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM)
|
||||
self.is_checkpoint_bf16 = is_checkpoint_bf16
|
||||
self.deepgemm_scale_ue8m0 = True if get_sm_version() == 100 else False
|
||||
self.deepgemm_scale_ue8m0 = True if get_sm_version() >= 100 else False
|
||||
|
||||
def name(self) -> str:
|
||||
return "block_wise_fp8"
|
||||
@@ -125,7 +125,8 @@ def deep_gemm_fp8_gemm_nt(
|
||||
layer_output_size: int,
|
||||
bias: paddle.Tensor = None,
|
||||
):
|
||||
if get_sm_version() == 100 and current_platform.is_cuda():
|
||||
sm_version = get_sm_version()
|
||||
if sm_version >= 100 and current_platform.is_cuda():
|
||||
# disable_ue8m0_cast is default False for SM100
|
||||
fp8_gemm_nt(
|
||||
(x, x_scale_tensor),
|
||||
|
||||
@@ -65,7 +65,7 @@ def load_deep_gemm():
|
||||
"""
|
||||
|
||||
if current_platform.is_cuda():
|
||||
if get_sm_version() == 100:
|
||||
if get_sm_version() >= 100:
|
||||
# SM100 should use PFCC DeepGemm
|
||||
paddle.compat.enable_torch_proxy(scope={"deep_gemm"})
|
||||
try:
|
||||
@@ -245,7 +245,7 @@ def fused_stack_transpose_quant(expert_weight_list, use_ue8m0=False):
|
||||
# Blackwell (SM100) GPUs require pow2_scale quantization.
|
||||
# Guard with is_cuda() so non-CUDA environments do not call into
|
||||
# paddle.device.cuda.* and cause a crash.
|
||||
use_pow2_scale = current_platform.is_cuda() and get_sm_version() == 100
|
||||
use_pow2_scale = current_platform.is_cuda() and get_sm_version() >= 100
|
||||
|
||||
w, scale = paddlefleet_ops.fuse_stack_transpose_fp8_quant(
|
||||
expert_weight_list,
|
||||
|
||||
Reference in New Issue
Block a user