[Optimization] Support FA2/FA3/FA4 with attn_mask_q (#6354)

* support FA4 sm100

* flash attn backend support mask

* flash attn backend run flashmask correct

* add test for flash_attn_backend and flash_attn_func

* check

* add test for fa4

* requirements.txt add fa4 whl

* check test on sm100

* fix CI conflict

* add enable_torch_proxy for flash_mask

* lazy import fa4

* check

* fix tests import

* check test_load_mpt import
This commit is contained in:
chen
2026-02-05 14:39:00 +08:00
committed by GitHub
parent 72edd394d9
commit 29a313a402
22 changed files with 999 additions and 101 deletions
+2 -2
View File
@@ -26,9 +26,9 @@ import paddle.distributed.fleet as fleet
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.models.ernie4_5_mtp import Ernie4_5_MTPForCausalLM
ROOT = Path(__file__).resolve().parents[2]
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from tests.utils import get_default_test_fd_config
from utils import get_default_test_fd_config
strategy = fleet.DistributedStrategy()
fleet.init(strategy=strategy)