Files
FastDeploy/tests/layers/test_flash_attention_versions_determinism.py
T
gongweibao edd31e8849 [Feature] Add Deterministic Inference Support (#6476)
* add

* [tests] Add Paddle attention determinism tests and refactor resource manager

Add comprehensive determinism tests for Paddle attention layer and refactor
resource manager for deterministic mode support.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* add

* add

* add

* add

* add more

* add more

* fixsome

* fixsome

* fix bugs

* fix bugs

* only in gpu

* add docs

* fix comments

* fix some

* fix some

* fix comments

* add more

* fix potential problem

* remove not need

* remove not need

* remove no need

* fix bug

* fix bugs

* fix comments

* fix comments

* Update tests/ce/deterministic/test_determinism_verification.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/inter_communicator/test_ipc_signal.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/layers/test_paddle_attention_determinism.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/engine/test_sampling_params_determinism.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/layers/test_paddle_attention_determinism.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/layers/test_paddle_attention_determinism_standalone.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix comments

* fix import error

* fix a bug

* fix bugs

* fix bugs

* fix coverage

* refine codes

* refine code

* fix comments

* fix comments

* fix comments

* rm not need

* fix allreduce large tensor bug

* mv log files

* mv log files

* add files

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-26 19:31:51 -08:00

178 lines
6.9 KiB
Python

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Flash Attention V2 / V3 determinism tests.
Verify bitwise determinism of flash-backend SDPA when explicitly
selecting FA version via FLAGS_flash_attn_version (2 or 3).
"""
import unittest
import pytest
# Module-level pytest marker: every test collected from this file is tagged
# `gpu`, so CPU-only CI runs can deselect the whole module at once.
pytestmark = pytest.mark.gpu
import paddle
import paddle.nn.functional as F
# --------------- constants ---------------
BATCH_SIZE = 2  # default batch dimension for the q/k/v test tensors
NUM_HEADS = 32  # default number of attention heads
HEAD_DIM = 64  # feature dimension per head
SEQ_LEN = 2048  # default sequence length
NUM_RUNS = 5  # default repeat count used by _assert_deterministic
# --------------- helpers ---------------
def _make_qkv(batch_size, num_heads, seq_len, head_dim, dtype="float16", seed=42):
    """Build a (q, k, v) triple of reproducibly seeded random tensors.

    The global RNG is seeded first, so identical arguments always produce
    bitwise-identical tensors across calls.
    """
    paddle.seed(seed)
    dims = [batch_size, num_heads, seq_len, head_dim]
    # Three sequential draws from the freshly seeded RNG, in q/k/v order.
    q, k, v = (paddle.randn(dims, dtype=dtype) for _ in range(3))
    return q, k, v
def _assert_deterministic(test_case, func, num_runs=NUM_RUNS):
    """Invoke *func* `num_runs` times and fail unless all outputs are bitwise identical."""
    # Clone each result so later invocations cannot alias earlier buffers.
    baseline = func().clone()
    repeats = [func().clone() for _ in range(num_runs - 1)]
    for i, candidate in enumerate(repeats, start=1):
        identical = paddle.equal(baseline, candidate).all().item()
        test_case.assertTrue(identical, f"Run 0 vs Run {i} differ")
# --------------- test class ---------------
class TestFlashAttentionVersionsDeterminism(unittest.TestCase):
    """Bitwise-determinism checks for flash-backend SDPA under FA2 and FA3.

    Each test explicitly selects the flash-attention version through the
    ``FLAGS_flash_attn_version`` flag and verifies the outputs are exactly
    reproducible.
    """

    FA_VERSIONS = [2, 3]

    def setUp(self):
        # Everything in this class needs a CUDA build of Paddle.
        if not paddle.is_compiled_with_cuda():
            self.skipTest("Flash Attention requires CUDA")
        paddle.set_device("gpu")
        # Remember the current FA-version flag so tearDown can put it back
        # and later tests are not polluted by what we set here.
        flags = paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])
        self._saved_version = flags["FLAGS_flash_attn_version"]

    def tearDown(self):
        # Restore the flag captured in setUp.
        paddle.set_flags({"FLAGS_flash_attn_version": self._saved_version})

    def _skip_if_fa3_unsupported(self):
        # FA3 is only exercised on the SM89-SM99 compute-capability window
        # (per the skip condition below) — skip everywhere else.
        props = paddle.device.cuda.get_device_properties()
        sm = 10 * props.major + props.minor
        if not (89 <= sm < 100):
            self.skipTest(f"FA3 requires SM89-SM99, current SM{sm}")

    def _set_fa_version(self, version):
        # Selecting FA3 first checks hardware support; then the flag is set.
        if version == 3:
            self._skip_if_fa3_unsupported()
        paddle.set_flags({"FLAGS_flash_attn_version": version})

    def _flash_sdpa(self, q, k, v, **kwargs):
        """Synchronize the device, then run SDPA with the flash backend."""
        paddle.device.synchronize()
        return F.scaled_dot_product_attention(q, k, v, backend="flash", **kwargs)

    # ==================== tests ====================
    def test_determinism(self):
        """Repeated runs must agree bitwise for FA2/FA3, causal and non-causal."""
        for version in self.FA_VERSIONS:
            for is_causal in (False, True):
                with self.subTest(version=version, is_causal=is_causal):
                    self._set_fa_version(version)
                    q, k, v = _make_qkv(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)

                    def run_once():
                        return self._flash_sdpa(q, k, v, is_causal=is_causal, enable_gqa=False)

                    _assert_deterministic(self, run_once)

    def test_batch_invariance(self):
        """The first sample's output must not depend on the batch size."""
        for version in self.FA_VERSIONS:
            with self.subTest(version=version):
                self._set_fa_version(version)
                max_bs = 8
                q, k, v = _make_qkv(max_bs, NUM_HEADS, SEQ_LEN, HEAD_DIM)
                # Reference: batch of one, sliced from the same tensors.
                ref = self._flash_sdpa(q[:1], k[:1], v[:1], is_causal=False, enable_gqa=False)
                for bs in (2, 4, 8):
                    out = self._flash_sdpa(q[:bs], k[:bs], v[:bs], is_causal=False, enable_gqa=False)
                    same = paddle.equal(ref, out[0:1]).all().item()
                    self.assertTrue(same, f"FA{version} batch invariance failed at bs={bs}")

    def test_seq_length_determinism(self):
        """Two runs must agree bitwise across a range of sequence lengths."""
        for version in self.FA_VERSIONS:
            for seq_len in (1, 2, 4, 8, 16, 64, 128, 256, 512, 1024, 2048, 4096):
                with self.subTest(version=version, seq_len=seq_len):
                    self._set_fa_version(version)
                    q, k, v = _make_qkv(BATCH_SIZE, NUM_HEADS, seq_len, HEAD_DIM)
                    _assert_deterministic(
                        self,
                        lambda: self._flash_sdpa(q, k, v, is_causal=False, enable_gqa=False),
                        num_runs=2,
                    )

    def test_dtype_determinism(self):
        """Repeated runs must agree bitwise for float16 and float32 inputs."""
        for version in self.FA_VERSIONS:
            for dtype in ("float16", "float32"):
                with self.subTest(version=version, dtype=dtype):
                    self._set_fa_version(version)
                    q, k, v = _make_qkv(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, dtype=dtype)
                    _assert_deterministic(
                        self,
                        lambda: self._flash_sdpa(q, k, v, is_causal=False, enable_gqa=False),
                        num_runs=3,
                    )

    def test_head_config_determinism(self):
        """Repeated runs must agree bitwise across head-count configurations."""
        for version in self.FA_VERSIONS:
            for num_heads, head_dim in ((1, 64), (7, 64), (32, 64)):
                with self.subTest(version=version, num_heads=num_heads, head_dim=head_dim):
                    self._set_fa_version(version)
                    q, k, v = _make_qkv(BATCH_SIZE, num_heads, SEQ_LEN, head_dim)
                    _assert_deterministic(
                        self,
                        lambda: self._flash_sdpa(q, k, v, is_causal=False, enable_gqa=False),
                        num_runs=2,
                    )

    def test_gqa_determinism(self):
        """Repeated runs must agree bitwise when grouped-query attention is on."""
        for version in self.FA_VERSIONS:
            with self.subTest(version=version):
                self._set_fa_version(version)
                q, k, v = _make_qkv(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)
                _assert_deterministic(
                    self,
                    lambda: self._flash_sdpa(q, k, v, is_causal=False, enable_gqa=True),
                    num_runs=3,
                )
if __name__ == "__main__":
    # Allow running this file directly (outside pytest) with verbose output.
    unittest.main(verbosity=2)