mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
edd31e8849
* add * [tests] Add Paddle attention determinism tests and refactor resource manager Add comprehensive determinism tests for Paddle attention layer and refactor resource manager for deterministic mode support. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * add * add * add * add * add more * add more * fixsome * fixsome * fix bugs * fix bugs * only in gpu * add docs * fix comments * fix some * fix some * fix comments * add more * fix potential problem * remove not need * remove not need * remove no need * fix bug * fix bugs * fix comments * fix comments * Update tests/ce/deterministic/test_determinism_verification.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/inter_communicator/test_ipc_signal.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/layers/test_paddle_attention_determinism.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/engine/test_sampling_params_determinism.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/layers/test_paddle_attention_determinism.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/layers/test_paddle_attention_determinism_standalone.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix comments * fix import error * fix a bug * fix bugs * fix bugs * fix coverage * refine codes * refine code * fix comments * fix comments * fix comments * rm not need * fix allreduce large tensor bug * mv log files * mv log files * add files --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
90 lines
3.6 KiB
Python
90 lines
3.6 KiB
Python
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/test_batch_invariance.py
|
|
|
|
import unittest
|
|
|
|
import paddle
|
|
|
|
from fastdeploy.model_executor.layers.batch_invariant_ops import (
|
|
set_batch_invariant_mode,
|
|
)
|
|
from fastdeploy.model_executor.layers.batch_invariant_ops.batch_invariant_ops import (
|
|
addmm_batch_invariant,
|
|
)
|
|
|
|
|
|
class TestBatchInvariantForAddmm(unittest.TestCase):
    """Verify that ``paddle.addmm`` is batch-size invariant and run-to-run
    deterministic when batch-invariant mode is enabled.

    Adapted from the thinking-machines-lab ``batch_invariant_ops`` test suite.
    """

    def setUp(self):
        """Initialize the test environment: GPU when available, else CPU."""
        device = "gpu" if paddle.is_compiled_with_cuda() else "cpu"
        paddle.set_device(device)

    def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float32):
        """Compare addmm on a single row vs. the same row sliced out of a
        full-batch addmm.

        NOTE(review): this method doubles as a helper (called by
        ``run_iters``) and as a unittest-discovered test; its non-None
        return triggers a DeprecationWarning under the unittest runner on
        Python >= 3.11. Kept as-is for backward compatibility.

        Args:
            B: Number of rows in the left operand ``a``.
            D: Feature dimension (columns of ``a``; ``b`` is ``D x D``).
            dtype: Paddle dtype used for both operands.

        Returns:
            Tuple ``(is_invariant, diff)`` where ``diff`` is the maximum
            absolute element-wise difference between the two computations.
        """
        a = paddle.linspace(-100, 100, B * D, dtype=dtype).reshape(B, D)
        b = paddle.linspace(-100, 100, D * D, dtype=dtype).reshape(D, D)

        # Method 1: Matrix-vector multiplication and add (batch size 1)
        out1 = paddle.addmm(a[:1].squeeze(0), a[:1], b)

        # Method 2: Matrix-matrix multiplication and add, then slice (full batch)
        out2 = paddle.addmm(a[:1].squeeze(0), a, b)[:1]

        # Batch invariance requires the two paths to agree bit-for-bit.
        diff = (out1 - out2).abs().max()
        return diff.item() == 0, diff

    def run_iters(self, iters=10, ass=False):
        """Repeat the invariance check and report run-to-run stability.

        Args:
            iters: Number of repetitions per dtype.
            ass: When True, assert that every iteration produced a zero diff.
        """
        for dtype in [paddle.float32, paddle.bfloat16]:
            is_deterministic = True
            difflist = []
            for _ in range(iters):
                isd, df = self.test_batch_invariance(dtype=dtype)
                is_deterministic = is_deterministic and isd
                # Store a plain Python float rather than a Tensor so the
                # summary f-string below prints scalars (e.g. "0.0/0.0/0.0",
                # matching the sample output documented at the bottom of this
                # file) instead of multi-line Tensor reprs.
                difflist.append(df.item())
            print(
                f"Batch Deterministic: {is_deterministic} run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations"
            )
            if ass:
                assert max(difflist) == 0

    def test_alpha_zero(self):
        """alpha == 0: result should be beta * input broadcast to [M, N]"""
        M, N, K = 32, 64, 128
        for dtype in [paddle.float32, paddle.bfloat16]:
            x = paddle.randn([M, K], dtype=dtype)
            y = paddle.randn([K, N], dtype=dtype)
            bias = paddle.randn([N], dtype=dtype)

            for beta in [0.0, 1.0, 2.5]:
                # With alpha == 0 the matmul term vanishes, so the kernel must
                # return exactly beta * bias replicated across all M rows.
                out = addmm_batch_invariant(bias, x, y, beta=beta, alpha=0.0)
                expected = (beta * bias).expand([M, N])
                # shape must be [M, N]
                assert out.shape == [M, N], f"Expected shape [{M}, {N}], got {out.shape}"
                # cast to float32 for comparison (bfloat16 not supported by isclose)
                diff = (out.cast(paddle.float32) - expected.cast(paddle.float32)).abs().max()
                assert diff.item() == 0, f"dtype={dtype}, beta={beta}, max diff={diff.item()}"

    def test_case(self):
        """Run the invariance loop with and without batch-invariant mode."""
        # Test with standard Paddle (likely to show differences; not asserted)
        print("Standard Paddle:")
        with set_batch_invariant_mode(False):
            self.run_iters(ass=False)
        # Test with batch-invariant operations (must be fully deterministic)
        print("\nBatch-Invariant Mode:")
        with set_batch_invariant_mode(True):
            self.run_iters(ass=True)
|
|
|
|
|
|
# Allow running this test module directly from the command line.
if __name__ == "__main__":
    unittest.main()
|
|
"""
|
|
Standard Paddle:
|
|
Batch Deterministic: False run-to-run max/min/diff 10.7294921875/10.7294921875/0.0 for paddle.float32 in 10 iterations
|
|
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations
|
|
|
|
Batch-Invariant Mode:
|
|
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations
|
|
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations
|
|
"""
|