fix fa4 test (#6408)

This commit is contained in:
chen
2026-02-10 10:57:21 +08:00
committed by GitHub
parent 3ce842b55b
commit a8ffcaa068
3 changed files with 7 additions and 2 deletions
@@ -118,9 +118,9 @@ def flash_attn_func(
head_dim: int = 128,
version: Optional[int] = None,
):
if FLASH_ATTN_VERSION is None:
init_flash_attn_version()
if version is None:
if FLASH_ATTN_VERSION is None:
init_flash_attn_version()
version = FLASH_ATTN_VERSION
if version == 4:
assert (