mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
[OP] support deepgeem for sm103 (#7073)
* support deepgeem for sm103 * add assert * modify code style * add assert * modify sm version condition * remove assert
This commit is contained in:
@@ -67,7 +67,7 @@ class BlockWiseFP8Config(QuantConfigBase):
|
||||
self.quant_round_type = 1
|
||||
self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM)
|
||||
self.is_checkpoint_bf16 = is_checkpoint_bf16
|
||||
self.deepgemm_scale_ue8m0 = True if get_sm_version() == 100 else False
|
||||
self.deepgemm_scale_ue8m0 = True if get_sm_version() >= 100 else False
|
||||
|
||||
def name(self) -> str:
|
||||
return "block_wise_fp8"
|
||||
@@ -125,7 +125,8 @@ def deep_gemm_fp8_gemm_nt(
|
||||
layer_output_size: int,
|
||||
bias: paddle.Tensor = None,
|
||||
):
|
||||
if get_sm_version() == 100 and current_platform.is_cuda():
|
||||
sm_version = get_sm_version()
|
||||
if sm_version >= 100 and current_platform.is_cuda():
|
||||
# disable_ue8m0_cast is default False for SM100
|
||||
fp8_gemm_nt(
|
||||
(x, x_scale_tensor),
|
||||
|
||||
Reference in New Issue
Block a user