mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[OP] support deepgemm for sm103 (#7073)
* support deepgemm for sm103 * add assert * modify code style * add assert * modify sm version condition * remove assert
This commit is contained in:
@@ -65,7 +65,7 @@ def load_deep_gemm():
|
||||
"""
|
||||
|
||||
if current_platform.is_cuda():
|
||||
if get_sm_version() == 100:
|
||||
if get_sm_version() >= 100:
|
||||
# SM100 should use PFCC DeepGemm
|
||||
paddle.compat.enable_torch_proxy(scope={"deep_gemm"})
|
||||
try:
|
||||
@@ -245,7 +245,7 @@ def fused_stack_transpose_quant(expert_weight_list, use_ue8m0=False):
|
||||
# Blackwell (SM100) GPUs require pow2_scale quantization.
|
||||
# Guard with is_cuda() so non-CUDA environments do not call into
|
||||
# paddle.device.cuda.* and cause a crash.
|
||||
use_pow2_scale = current_platform.is_cuda() and get_sm_version() == 100
|
||||
use_pow2_scale = current_platform.is_cuda() and get_sm_version() >= 100
|
||||
|
||||
w, scale = paddlefleet_ops.fuse_stack_transpose_fp8_quant(
|
||||
expert_weight_list,
|
||||
|
||||
Reference in New Issue
Block a user