fix deepgemm import (#6451)

Co-authored-by: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com>
This commit is contained in:
JYChen
2026-02-11 20:10:01 +08:00
committed by GitHub
parent e40fb16912
commit 40c952e7b5
2 changed files with 16 additions and 21 deletions
@@ -62,7 +62,8 @@ def _get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(
):
"""Convert FP32 tensor to TMA-aligned packed UE8M0 format tensor"""
from deep_gemm.utils import align, get_tma_aligned_size
align = deep_gemm.utils.align
get_tma_aligned_size = deep_gemm.utils.get_tma_aligned_size
# Input validation: must be FP32 type 2D or 3D tensor
assert x.dtype == paddle.float and x.dim() in (2, 3)