[Feature][BugFix][OP] Enhance Deterministic Inference Mode with Kernel-level Fixes and Batch-invariant BMM (#6610)

* add fa deter

* add ut

* add long sentence

* fix basic

* fix bugs

* fix adn

* fix first

* fix single

* fix single

* fix single test

* refine

* add more test

* refine comments

* add comments of bmm

* fix ci

* remove probe

* add

* remove unneeded code

* refine tests

* fix comments and refine code

* refine code

* refine test

* refine test

* mv 4cards tests

* fix tests

* add

* fix comments

* fix cover

* fix cover

---------

Co-authored-by: gongweibao <gognweibao@baidu.com>
gongweibao
2026-03-09 10:27:53 +08:00
committed by GitHub
parent 3a85ecf3bc
commit 30f9f33f34
23 changed files with 3563 additions and 153 deletions
@@ -0,0 +1,799 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test suite for determinism of the c16 warp1_4 decoder attention kernel.
Background:
The c16 warp1_4 decoder kernel (multiquery_attention_c16_impl.cuh) had two
bugs that caused nondeterministic outputs under FD_DETERMINISTIC_MODE:
1. The warp1_4 dispatcher (lines 1164-1175) lacked the force_no_partition
check, so it still launched the multi-chunk (split) kernel even when
deterministic mode requested the single-chunk (nosplit) path.
2. The nosplit kernel template read runtime num_chunks_this_seq instead of
compile-time partition_kv, causing out-of-bounds nullptr writes to the
partial output buffer (lines 545, 748, 772, 812).
How the c16 warp1_4 path is triggered:
- dim_head=128, blocksize=64, cache_quant_type="none" -> selects c16 kernel
- decoder_block_shape_q=16 -> NUM_WARP_Q=1, i.e. warp1_4 configuration
- seq_lens_encoder=0, seq_lens_decoder>0 -> decoder mode
- FD_DETERMINISTIC_MODE=1 -> forces nosplit path
- Small decoder_max_partition_size (e.g. 64) with long prefill (e.g. 256)
ensures num_chunks > 1, which is the scenario that exposed the bugs.
Test items:
1. test_short_kv_nosplit
- Short KV (num_chunks=1): basic nosplit path with partition_kv=false.
- Verifies both correctness (vs naive reference) and determinism (10 runs).
2. test_long_kv_multi_chunk
- Long KV (num_chunks=4, prefill=256, partition=64): the exact scenario
the fix addresses. partition_kv=true template but grid_chunks=1.
- Verifies correctness and determinism.
3. test_multi_batch
- Multiple batches (batch_size=4) with multi-chunk decoder.
- Ensures the fix works across batch elements, not just single-batch.
4. test_float16
- Float16 dtype with multi-chunk decoder.
- Ensures the fix is dtype-agnostic (not only bfloat16).
5. test_unaligned_seq_len
- prefill_seq_len not divisible by blocksize (100 % 64 != 0).
- Catches off-by-one bugs in block/chunk boundary calculations.
6. test_mha_no_gqa
- MHA config: q_num_head == kv_num_head (no GQA grouping).
- Ensures the fix is not GQA-specific.
7. test_nosplit_vs_split_consistency
- Cross-path check: deterministic nosplit vs non-deterministic split.
- Both paths should produce numerically close results (rtol/atol=1e-2)
and both should match the naive attention reference.
8. test_partition_boundary
- Edge case: prefill_seq_len equals partition_size (boundary condition).
- Tests chunk calculation when num_chunks is exactly an integer.
9. test_empty_kv
- Edge case: decoder-only mode (no prefill, empty KV cache).
- Tests scenario with no encoder prefill phase.
Run:
python -m pytest tests/deterministic/test_c16_warp1_4_determinism.py -v
"""
import copy
import os
import unittest
import numpy as np
import paddle
os.environ["FD_DETERMINISTIC_MODE"] = "1"
from fastdeploy.model_executor.layers.attention.ops import ( # noqa: E402
append_attention,
get_block_shape_and_split_kv_block,
)
SEED = 42
ENCODER_BLOCK_SHAPE_Q = 64
DECODER_BLOCK_SHAPE_Q = 16
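# Illustrative sketch only -- a hypothetical pure-Python mirror of the dispatch
# arithmetic described in the module docstring, not the CUDA kernel itself:
# deterministic mode forces the nosplit path, i.e. a single chunk no matter how
# many partitions the KV length would otherwise be split into.
def _illustrative_num_chunks(kv_len, partition_size, deterministic):
    if deterministic:  # FD_DETERMINISTIC_MODE=1 -> force_no_partition / nosplit
        return 1
    return (kv_len + partition_size - 1) // partition_size  # e.g. prefill=256, partition=64 -> 4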
def _assert_deterministic_and_correct(results, ref):
"""
Helper to verify determinism and correctness.
Args:
results: List of output arrays from repeated runs
ref: Reference output array from naive attention
"""
# Verify determinism: all runs should produce identical output
for i in range(1, len(results)):
np.testing.assert_array_equal(
results[0],
results[i],
err_msg=f"Determinism failure: run 0 vs run {i}",
)
# Verify correctness: output should match naive reference
np.testing.assert_allclose(results[0], ref, rtol=1e-2, atol=1e-2)
def make_rope_emb(max_seq_len, dim_head, base=10000):
    """Build a cos/sin rotary embedding table of shape (2, 1, max_seq_len, 1, dim_head // 2)."""
    pos = paddle.arange(max_seq_len).reshape((1, -1))
inv_freq = base ** (-paddle.arange(0, dim_head, 2, dtype="float32") / dim_head)
freqs = paddle.einsum("ij,k->ijk", pos.cast("float32"), inv_freq)
emb = freqs.reshape((1, max_seq_len, dim_head // 2)).unsqueeze(2)
rope_emb = paddle.zeros((2, 1, max_seq_len, 1, dim_head // 2), dtype="float32")
rope_emb[0] = paddle.cos(emb)
rope_emb[1] = paddle.sin(emb)
return rope_emb
def apply_rope(x, rope_emb, positions):
"""
Apply rotary position embedding (non-neox interleaved style).
x: (batch, heads, seq_len, dim_head)
rope_emb: (2, 1, max_seq_len, 1, dim_head//2)
positions: list of int, one position per seq index
"""
dim_head = x.shape[-1]
half = dim_head // 2
x_f32 = x.cast("float32")
out = x_f32.clone()
for seq_idx, pos in enumerate(positions):
cos_p = rope_emb[0, 0, pos, 0, :] # (dim_head//2,)
sin_p = rope_emb[1, 0, pos, 0, :]
x_slice = x_f32[:, :, seq_idx, :] # (batch, heads, dim_head)
x_pairs = x_slice.reshape(list(x_slice.shape[:-1]) + [half, 2])
x0 = x_pairs[..., 0] # (batch, heads, half)
x1 = x_pairs[..., 1]
out0 = x0 * cos_p - x1 * sin_p
out1 = x0 * sin_p + x1 * cos_p
out[:, :, seq_idx, :] = paddle.stack([out0, out1], axis=-1).reshape(x_slice.shape)
return out.cast(x.dtype)
def get_padding_offset(bsz, max_seq_len, seq_lens_this_time):
    """Return (batch_id_per_token, cu_seqlens_q) for a padded flat token layout."""
    cum_offsets_now = paddle.cumsum(max_seq_len - seq_lens_this_time, dtype="int32")
cum_offsets = paddle.zeros(shape=(bsz + 1,), dtype="int32")
cum_offsets[1:] = cum_offsets_now
token_num = int(paddle.sum(seq_lens_this_time))
batch_id_per_token = paddle.zeros(shape=(token_num,), dtype="int32")
cu_seqlens_q = paddle.zeros(shape=(bsz + 1,), dtype="int32")
for i in range(bsz):
sn = int(seq_lens_this_time[i])
co = int(cum_offsets[i])
for j in range(sn):
batch_id_per_token[i * max_seq_len - co + j] = i
cu_seqlens_q[i + 1] = (i + 1) * max_seq_len - int(cum_offsets[i + 1])
return batch_id_per_token, cu_seqlens_q
def naive_attention_impl(query, key, value, cache_k, cache_v, scale):
"""Reference: Q @ K^T * scale -> softmax -> @ V, with GQA expansion."""
batch, heads, seq_len, head_dim = query.shape
kv_head = key.shape[1]
g = heads // kv_head
key = key.reshape([batch, kv_head, 1, seq_len, head_dim])
key = paddle.tile(key, [1, 1, g, 1, 1]).reshape([batch, heads, seq_len, head_dim])
value = value.reshape([batch, kv_head, 1, seq_len, head_dim])
value = paddle.tile(value, [1, 1, g, 1, 1]).reshape([batch, heads, seq_len, head_dim])
if cache_k is not None:
ck = cache_k.reshape([batch, kv_head, 1, -1, head_dim])
ck = paddle.tile(ck, [1, 1, g, 1, 1]).reshape([batch, heads, -1, head_dim])
key = paddle.concat([ck, key], axis=2)
if cache_v is not None:
cv = cache_v.reshape([batch, kv_head, 1, -1, head_dim])
cv = paddle.tile(cv, [1, 1, g, 1, 1]).reshape([batch, heads, -1, head_dim])
value = paddle.concat([cv, value], axis=2)
qk = paddle.matmul(query, key, transpose_y=True) * scale
attn = paddle.nn.functional.softmax(qk, -1)
return paddle.matmul(attn.cast(value.dtype), value)
def block_cache_to_naive(cache_k, cache_v, bsz, block_tables, seq_len):
    """Gather the paged KV cache back into contiguous (bsz, heads, seq_len, dim) tensors."""
    _, num_head, blocksize, dim_head = cache_k.shape
ok = paddle.zeros([bsz, num_head, seq_len, dim_head], dtype=cache_k.dtype)
ov = paddle.zeros([bsz, num_head, seq_len, dim_head], dtype=cache_v.dtype)
for i in range(bsz):
for j in range(seq_len):
ok[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
ov[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
return ok, ov
def run_c16_warp14_decoder_test(
batch_size,
q_num_head,
kv_num_head,
dim_head,
blocksize,
prefill_seq_len,
max_dec_len,
dtype,
decoder_max_partition_size,
num_decode_runs,
):
"""
Run encoder prefill + N decoder runs.
Returns (list of decoder output numpy arrays, naive reference numpy array).
"""
np.random.seed(SEED)
paddle.seed(SEED)
max_seq_len = prefill_seq_len + max_dec_len
block_per_seq = (max_seq_len + blocksize - 1) // blocksize
max_block_num = block_per_seq * batch_size
scale = 1.0 / np.sqrt(dim_head)
group_size = q_num_head // kv_num_head
compute_type = "bf16" if dtype == "bfloat16" else "fp16"
rope_emb = make_rope_emb(max_seq_len, dim_head)
# Block tables
free_list = list(range(max_block_num - 1, -1, -1))
block_tables = paddle.zeros((batch_size, block_per_seq), dtype="int32")
for i in range(batch_size):
for j in range(block_per_seq):
block_tables[i, j] = free_list.pop()
cache_k = paddle.zeros((max_block_num, kv_num_head, blocksize, dim_head), dtype=dtype)
cache_v = paddle.zeros((max_block_num, kv_num_head, blocksize, dim_head), dtype=dtype)
# Tile metadata buffers, sized per allocate_launch_related_buffer() formula
gqa_ratio = q_num_head // kv_num_head
decode_tile_size = int(1024 * batch_size * np.ceil((2 * gqa_ratio) / DECODER_BLOCK_SHAPE_Q))
encode_tile_size = max(batch_size, batch_size * (max_seq_len * gqa_ratio // ENCODER_BLOCK_SHAPE_Q))
kv_tile_size = max(batch_size, batch_size * (max_seq_len // blocksize))
dec_batch_ids = paddle.full([decode_tile_size], 0, dtype="int32")
dec_tile_ids = paddle.full([decode_tile_size], 0, dtype="int32")
dec_nblocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
dec_nblocks_dev = paddle.full([1], 0, dtype="int32")
dec_chunk_dev = paddle.full([1], decoder_max_partition_size, dtype="int32")
max_len_cpu = paddle.full([8], 0, dtype="int32").cpu()
enc_batch_ids = paddle.full([encode_tile_size], 0, dtype="int32")
enc_tile_ids = paddle.full([encode_tile_size], 0, dtype="int32")
enc_nblocks_cpu = paddle.full([1], 0, dtype="int32").cpu()
kv_batch_ids = paddle.full([kv_tile_size], 0, dtype="int32")
kv_tile_ids = paddle.full([kv_tile_size], 0, dtype="int32")
kv_nblocks_cpu = paddle.full([1], 0, dtype="int32").cpu()
# ===== Encoder phase =====
seq_enc = paddle.full([batch_size], prefill_seq_len, dtype="int32")
seq_dec = paddle.full([batch_size], 0, dtype="int32")
seq_this = copy.deepcopy(seq_enc)
bid_enc, cu_enc = get_padding_offset(batch_size, prefill_seq_len, seq_this)
token_num = batch_size * prefill_seq_len
q_np = np.random.random([batch_size, q_num_head, prefill_seq_len, dim_head]).astype("float32") / 10
k_np = np.random.random([batch_size, kv_num_head, prefill_seq_len, dim_head]).astype("float32") / 10
v_np = np.random.random([batch_size, kv_num_head, prefill_seq_len, dim_head]).astype("float32") / 10
q = paddle.to_tensor(q_np, dtype=dtype)
k = paddle.to_tensor(k_np, dtype=dtype)
v = paddle.to_tensor(v_np, dtype=dtype)
qkv = paddle.concat(
[
q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * dim_head]),
k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
],
axis=1,
)
# Use a large partition size for the encoder so prefill runs in a single chunk
encoder_partition_size = 32768
get_block_shape_and_split_kv_block(
seq_enc,
seq_dec,
seq_this,
dec_batch_ids,
dec_tile_ids,
dec_nblocks_cpu,
dec_nblocks_dev,
dec_chunk_dev,
max_len_cpu,
enc_batch_ids,
enc_tile_ids,
enc_nblocks_cpu,
kv_batch_ids,
kv_tile_ids,
kv_nblocks_cpu,
ENCODER_BLOCK_SHAPE_Q,
DECODER_BLOCK_SHAPE_Q,
group_size,
blocksize,
)
append_attention(
qkv,
cache_k,
cache_v,
seq_enc,
seq_dec,
seq_this,
bid_enc,
cu_enc,
block_tables,
enc_batch_ids,
enc_tile_ids,
enc_nblocks_cpu,
kv_batch_ids,
kv_tile_ids,
kv_nblocks_cpu,
dec_batch_ids,
dec_tile_ids,
dec_nblocks_cpu,
max_len_cpu,
rope_emb,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
1e-6,
compute_type,
"none",
False,
False,
max_seq_len,
0.0,
0.0,
-1,
ENCODER_BLOCK_SHAPE_Q,
DECODER_BLOCK_SHAPE_Q,
encoder_partition_size,
encoder_partition_size,
2,
True,
False,
0,
)
paddle.device.synchronize()
# Extract naive KV cache for reference (already has RoPE applied by kernel)
naive_ck, naive_cv = block_cache_to_naive(
cache_k,
cache_v,
batch_size,
block_tables,
prefill_seq_len,
)
# ===== Decoder phase =====
seq_enc_d = paddle.full([batch_size], 0, dtype="int32")
seq_dec_d = paddle.full([batch_size], prefill_seq_len, dtype="int32")
seq_this_d = paddle.full([batch_size], 1, dtype="int32")
bid_dec, cu_dec = get_padding_offset(batch_size, 1, seq_this_d)
dq_np = np.random.random([batch_size, q_num_head, 1, dim_head]).astype("float32") / 10
dk_np = np.random.random([batch_size, kv_num_head, 1, dim_head]).astype("float32") / 10
dv_np = np.random.random([batch_size, kv_num_head, 1, dim_head]).astype("float32") / 10
dq = paddle.to_tensor(dq_np, dtype=dtype)
dk = paddle.to_tensor(dk_np, dtype=dtype)
dv = paddle.to_tensor(dv_np, dtype=dtype)
dec_qkv = paddle.concat(
[
dq.transpose([0, 2, 1, 3]).reshape([batch_size, q_num_head * dim_head]),
dk.transpose([0, 2, 1, 3]).reshape([batch_size, kv_num_head * dim_head]),
dv.transpose([0, 2, 1, 3]).reshape([batch_size, kv_num_head * dim_head]),
],
axis=1,
)
# Warmup: first decoder call on multi-chunk path may return zeros
# due to kernel JIT compilation. Run once and discard.
get_block_shape_and_split_kv_block(
seq_enc_d,
seq_dec_d,
seq_this_d,
dec_batch_ids,
dec_tile_ids,
dec_nblocks_cpu,
dec_nblocks_dev,
dec_chunk_dev,
max_len_cpu,
enc_batch_ids,
enc_tile_ids,
enc_nblocks_cpu,
kv_batch_ids,
kv_tile_ids,
kv_nblocks_cpu,
ENCODER_BLOCK_SHAPE_Q,
DECODER_BLOCK_SHAPE_Q,
group_size,
blocksize,
)
append_attention(
dec_qkv.clone(),
cache_k.clone(),
cache_v.clone(),
seq_enc_d,
seq_dec_d,
seq_this_d,
bid_dec,
cu_dec,
block_tables,
enc_batch_ids,
enc_tile_ids,
enc_nblocks_cpu,
kv_batch_ids,
kv_tile_ids,
kv_nblocks_cpu,
dec_batch_ids,
dec_tile_ids,
dec_nblocks_cpu,
max_len_cpu,
rope_emb,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
1e-6,
compute_type,
"none",
False,
False,
max_seq_len,
0.0,
0.0,
-1,
ENCODER_BLOCK_SHAPE_Q,
DECODER_BLOCK_SHAPE_Q,
decoder_max_partition_size,
encoder_partition_size,
2,
True,
False,
0,
)
paddle.device.synchronize()
results = []
for _ in range(num_decode_runs):
cache_k_c = cache_k.clone()
cache_v_c = cache_v.clone()
qkv_c = dec_qkv.clone()
get_block_shape_and_split_kv_block(
seq_enc_d,
seq_dec_d,
seq_this_d,
dec_batch_ids,
dec_tile_ids,
dec_nblocks_cpu,
dec_nblocks_dev,
dec_chunk_dev,
max_len_cpu,
enc_batch_ids,
enc_tile_ids,
enc_nblocks_cpu,
kv_batch_ids,
kv_tile_ids,
kv_nblocks_cpu,
ENCODER_BLOCK_SHAPE_Q,
DECODER_BLOCK_SHAPE_Q,
group_size,
blocksize,
)
out = append_attention(
qkv_c,
cache_k_c,
cache_v_c,
seq_enc_d,
seq_dec_d,
seq_this_d,
bid_dec,
cu_dec,
block_tables,
enc_batch_ids,
enc_tile_ids,
enc_nblocks_cpu,
kv_batch_ids,
kv_tile_ids,
kv_nblocks_cpu,
dec_batch_ids,
dec_tile_ids,
dec_nblocks_cpu,
max_len_cpu,
rope_emb,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
1e-6,
compute_type,
"none",
False,
False,
max_seq_len,
0.0,
0.0,
-1,
ENCODER_BLOCK_SHAPE_Q,
DECODER_BLOCK_SHAPE_Q,
decoder_max_partition_size,
encoder_partition_size,
2,
True,
False,
0,
)
paddle.device.synchronize()
results.append(out.numpy().copy())
# Naive reference: apply RoPE to decoder Q/K at position prefill_seq_len
# (cached K/V already have RoPE applied by the kernel during encoder phase)
dq_rope = apply_rope(dq, rope_emb, [prefill_seq_len])
dk_rope = apply_rope(dk, rope_emb, [prefill_seq_len])
ref = naive_attention_impl(dq_rope, dk_rope, dv, naive_ck, naive_cv, scale)
ref_np = ref.transpose([0, 2, 1, 3]).reshape([batch_size, q_num_head * dim_head]).numpy()
return results, ref_np
class TestC16Warp14Determinism(unittest.TestCase):
"""
Test the c16 warp1_4 decoder kernel under FD_DETERMINISTIC_MODE=1.
Verifies:
1. Correctness: output matches naive attention reference (rtol/atol=1e-2)
2. Determinism: repeated runs with identical input -> bitwise-identical output
"""
def test_short_kv_nosplit(self):
"""num_chunks=1 (short KV): basic nosplit path, partition_kv=false template."""
results, ref = run_c16_warp14_decoder_test(
batch_size=1,
q_num_head=16,
kv_num_head=2,
dim_head=128,
blocksize=64,
prefill_seq_len=64,
max_dec_len=32,
dtype="bfloat16",
decoder_max_partition_size=32768,
num_decode_runs=10,
)
_assert_deterministic_and_correct(results, ref)
def test_long_kv_multi_chunk(self):
"""
num_chunks=4 (prefill=256, partition=64): the exact scenario the fix addresses.
partition_kv=true template but grid_chunks=1 (deterministic).
"""
results, ref = run_c16_warp14_decoder_test(
batch_size=1,
q_num_head=16,
kv_num_head=2,
dim_head=128,
blocksize=64,
prefill_seq_len=256,
max_dec_len=32,
dtype="bfloat16",
decoder_max_partition_size=64,
num_decode_runs=10,
)
_assert_deterministic_and_correct(results, ref)
def test_multi_batch(self):
"""Multiple batches with multi-chunk decoder."""
results, ref = run_c16_warp14_decoder_test(
batch_size=4,
q_num_head=8,
kv_num_head=2,
dim_head=128,
blocksize=64,
prefill_seq_len=256,
max_dec_len=32,
dtype="bfloat16",
decoder_max_partition_size=64,
num_decode_runs=10,
)
_assert_deterministic_and_correct(results, ref)
def test_float16(self):
"""Float16 dtype with multi-chunk decoder."""
results, ref = run_c16_warp14_decoder_test(
batch_size=1,
q_num_head=16,
kv_num_head=2,
dim_head=128,
blocksize=64,
prefill_seq_len=256,
max_dec_len=32,
dtype="float16",
decoder_max_partition_size=64,
num_decode_runs=10,
)
_assert_deterministic_and_correct(results, ref)
def test_unaligned_seq_len(self):
"""prefill_seq_len not divisible by blocksize (100 % 64 != 0)."""
results, ref = run_c16_warp14_decoder_test(
batch_size=1,
q_num_head=16,
kv_num_head=2,
dim_head=128,
blocksize=64,
prefill_seq_len=100,
max_dec_len=32,
dtype="bfloat16",
decoder_max_partition_size=64,
num_decode_runs=10,
)
_assert_deterministic_and_correct(results, ref)
def test_mha_no_gqa(self):
"""MHA: q_num_head == kv_num_head (no GQA grouping)."""
results, ref = run_c16_warp14_decoder_test(
batch_size=1,
q_num_head=8,
kv_num_head=8,
dim_head=128,
blocksize=64,
prefill_seq_len=128,
max_dec_len=32,
dtype="bfloat16",
decoder_max_partition_size=64,
num_decode_runs=10,
)
_assert_deterministic_and_correct(results, ref)
def test_nosplit_vs_split_consistency(self):
"""
Cross-path check: force_no_partition (deterministic) vs split (partitioned)
should produce numerically close results.
The two paths differ only in floating-point accumulation order:
- nosplit: single chunk, sequential accumulation
- split: multi-chunk parallel, then merge
Difference should be within low-precision rounding (rtol/atol=1e-2 for bf16).
Note: envs.FD_DETERMINISTIC_MODE is lazily evaluated via __getattr__,
so runtime changes to os.environ["FD_DETERMINISTIC_MODE"] take effect.
"""
# Run once in deterministic mode (already set at module level)
det_results, det_ref = run_c16_warp14_decoder_test(
batch_size=1,
q_num_head=16,
kv_num_head=2,
dim_head=128,
blocksize=64,
prefill_seq_len=256,
max_dec_len=32,
dtype="bfloat16",
decoder_max_partition_size=64,
num_decode_runs=1,
)
# Run once in non-deterministic mode (split/partitioned path)
os.environ["FD_DETERMINISTIC_MODE"] = "0"
try:
split_results, split_ref = run_c16_warp14_decoder_test(
batch_size=1,
q_num_head=16,
kv_num_head=2,
dim_head=128,
blocksize=64,
prefill_seq_len=256,
max_dec_len=32,
dtype="bfloat16",
decoder_max_partition_size=64,
num_decode_runs=1,
)
finally:
os.environ["FD_DETERMINISTIC_MODE"] = "1"
# Both paths should match the naive reference
np.testing.assert_allclose(
det_results[0], det_ref, rtol=1e-2, atol=1e-2, err_msg="Deterministic path vs naive reference"
)
np.testing.assert_allclose(
split_results[0], split_ref, rtol=1e-2, atol=1e-2, err_msg="Split path vs naive reference"
)
# Cross-path: nosplit vs split should be close
np.testing.assert_allclose(
det_results[0],
split_results[0],
rtol=1e-2,
atol=1e-2,
err_msg="Deterministic (nosplit) vs split path divergence",
)
def test_partition_boundary(self):
"""
Edge case: prefill_seq_len equals partition_size (boundary condition).
This tests the scenario where num_chunks = prefill_seq_len / partition_size
is exactly an integer, which is a boundary condition for chunk calculation.
"""
results, ref = run_c16_warp14_decoder_test(
batch_size=1,
q_num_head=16,
kv_num_head=2,
dim_head=128,
blocksize=64,
prefill_seq_len=128,
max_dec_len=32,
dtype="bfloat16",
decoder_max_partition_size=128,
num_decode_runs=10,
)
_assert_deterministic_and_correct(results, ref)
def test_empty_kv(self):
"""
Edge case: decoder-only mode (no prefill, empty KV cache).
This tests the scenario where there's no encoder prefill phase,
only decoder with empty KV cache.
"""
results, _ = run_c16_warp14_decoder_test(
batch_size=1,
q_num_head=16,
kv_num_head=2,
dim_head=128,
blocksize=64,
prefill_seq_len=0,
max_dec_len=32,
dtype="bfloat16",
decoder_max_partition_size=64,
num_decode_runs=10,
)
# For empty KV, we only check determinism as naive reference may not be valid
for i in range(1, len(results)):
np.testing.assert_array_equal(
results[0],
results[i],
err_msg=f"Determinism failure: run 0 vs run {i}",
)
if __name__ == "__main__":
unittest.main()
@@ -1,360 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Determinism offline inference tests using LLM.generate
Test scenarios:
1. Same-prompt repeatability (FD_DETERMINISTIC_MODE=1)
2. Batch invariance (single vs. batch, different positions)
3. Different batch sizes consistency
4. Sampling-parameter combinations (temperature x top_p, parametrized)
5. Long sequence generation (512-1024 tokens)
6. Long input prompt handling
7. Minimal output (max_tokens=1, early stop)
8. Special characters & multi-language prompts
9. Multi-turn conversation
10. State isolation (interleaved / interference prompts)
11. Non-deterministic validation (proves tests are effective)
Usage:
CUDA_VISIBLE_DEVICES=0 pytest tests/deterministic/test_determinism_offline.py -v
"""
import os
import pytest
pytestmark = pytest.mark.gpu
DEFAULT_MODEL_DIR = "./models"
MODEL_NAME = "Qwen2-7B-Instruct"
_ENV_CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
_ENV_FD_DETERMINISTIC_MODE = "FD_DETERMINISTIC_MODE"
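# Illustrative only -- a hypothetical helper, not used by the tests below,
# spelling out the batch-invariance property that scenario 2 verifies: a
# prompt's output must not depend on its batch position or its batch-mates.
def _is_batch_invariant(generate, prompt, batches):
    """generate: list[str] -> list[str]. True iff `prompt` yields the same
    text alone and at its position inside every batch in `batches`."""
    baseline = generate([prompt])[0]
    return all(generate(batch)[batch.index(prompt)] == baseline for batch in batches)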
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module", autouse=True)
def _module_env():
"""Set env vars before importing fastdeploy (must happen first)."""
old_cuda = os.environ.get(_ENV_CUDA_VISIBLE_DEVICES)
old_det = os.environ.get(_ENV_FD_DETERMINISTIC_MODE)
os.environ[_ENV_CUDA_VISIBLE_DEVICES] = os.environ.get(_ENV_CUDA_VISIBLE_DEVICES, "0")
os.environ[_ENV_FD_DETERMINISTIC_MODE] = "1"
global LLM, SamplingParams # noqa: PLW0603
from fastdeploy import LLM, SamplingParams
yield
if old_cuda is None:
os.environ.pop(_ENV_CUDA_VISIBLE_DEVICES, None)
else:
os.environ[_ENV_CUDA_VISIBLE_DEVICES] = old_cuda
if old_det is None:
os.environ.pop(_ENV_FD_DETERMINISTIC_MODE, None)
else:
os.environ[_ENV_FD_DETERMINISTIC_MODE] = old_det
@pytest.fixture(autouse=True)
def _reset_deterministic_mode():
"""Ensure every test starts with deterministic mode ON."""
os.environ[_ENV_FD_DETERMINISTIC_MODE] = "1"
yield
os.environ[_ENV_FD_DETERMINISTIC_MODE] = "1"
@pytest.fixture(scope="module")
def model_path():
model_dir = os.getenv("MODEL_PATH", DEFAULT_MODEL_DIR)
return os.path.join(model_dir, MODEL_NAME)
@pytest.fixture(scope="module")
def llm(model_path, _module_env):
return LLM(
model=model_path,
tensor_parallel_size=1,
max_model_len=8192,
enable_prefix_caching=False,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _generate_text(llm, prompt, sp):
"""Generate once, return (text, token_ids)."""
out = llm.generate([prompt], sp)[0]
return out.outputs.text, out.outputs.token_ids
def _assert_deterministic(llm, prompt, sp, runs=2):
"""Run *runs* times and assert all outputs are identical."""
results = [_generate_text(llm, prompt, sp) for _ in range(runs)]
texts = [r[0] for r in results]
token_ids = [r[1] for r in results]
assert all(t == texts[0] for t in texts), "Text outputs differ across runs"
assert all(t == token_ids[0] for t in token_ids), "Token IDs differ across runs"
return texts[0], token_ids[0]
# ===================== Core determinism tests =====================
def test_deterministic_same_prompt(llm):
"""Same prompt + same seed produces identical output across 5 runs."""
sp = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=50, seed=123)
_assert_deterministic(llm, "Please introduce artificial intelligence in one sentence.", sp, runs=5)
def test_deterministic_batch_invariance(llm):
"""Target prompt produces identical output regardless of batch position."""
prompt = "What kind of programming language is Python?"
sp = SamplingParams(temperature=0.5, max_tokens=40, seed=456)
baseline, _ = _generate_text(llm, prompt, sp)
batch_configs = [
[prompt, "Filler question 1"],
["Filler question 2", prompt, "Filler question 3"],
["Filler question 4", "Filler question 5", prompt],
["Filler 6", "Filler 7", "Filler 8", prompt],
]
for i, batch in enumerate(batch_configs):
outputs = llm.generate(batch, sp)
idx = batch.index(prompt)
assert (
outputs[idx].outputs.text == baseline
), f"Batch config {i} (pos {idx}): result differs from single-request baseline"
def test_deterministic_different_batch_sizes(llm):
"""Same prompt is consistent across batch sizes 1 / 2 / 4 / 8."""
prompt = "What is machine learning?"
sp = SamplingParams(temperature=0.5, max_tokens=30, seed=789)
baseline, _ = _generate_text(llm, prompt, sp)
for bs in [2, 4, 8]:
outputs = llm.generate([prompt] * bs, sp)
assert outputs[0].outputs.text == baseline, f"Batch size {bs} differs from bs=1"
# ===================== Sampling-parameter combinations =====================
@pytest.mark.parametrize(
"temp,top_p,seed",
[
(0.0, 1.0, 300), # greedy, no top_p filter
(0.0, 0.0, 301), # double-greedy
(0.3, 0.9, 302), # low temp, moderate top_p
(0.8, 0.0, 303), # medium temp, greedy top_p
(0.8, 1.0, 304), # medium temp, no top_p filter
(0.8, 0.5, 305), # medium temp, strict top_p
(1.0, 0.95, 306), # high temp
(1.5, 0.9, 307), # very high temp
],
)
def test_deterministic_param_combos(llm, temp, top_p, seed):
"""Determinism holds across various (temperature, top_p) combinations."""
sp = SamplingParams(temperature=temp, top_p=top_p, max_tokens=30, seed=seed)
_assert_deterministic(llm, "What is a neural network?", sp)
# ===================== Long sequence tests =====================
@pytest.mark.parametrize(
"temp,seed",
[
(0.0, 100),
(0.3, 130),
(0.5, 150),
(0.7, 170),
],
)
@pytest.mark.skip(reason="Potential non-determinism in long sequences; to be fixed by gongweibao in a follow-up PR")
def test_deterministic_long_sequence(llm, temp, seed):
"""Long generation (512+ tokens) stays deterministic at various temperatures."""
prompt = "Please describe the history of AI in detail, including major milestones and key technical breakthroughs."
sp = SamplingParams(temperature=temp, top_p=0.95, max_tokens=512, seed=seed)
text, token_ids = _assert_deterministic(llm, prompt, sp)
assert len(token_ids) >= 100, f"Expected >= 100 tokens, got {len(token_ids)}"
def test_deterministic_long_prompt(llm):
"""Long input prompt (prefill-heavy) stays deterministic."""
base = "This is a description about natural language processing. "
long_prompt = (base * 50) + "Please summarize the above."
sp = SamplingParams(temperature=0.5, max_tokens=100, seed=2024)
_assert_deterministic(llm, long_prompt, sp)
# ===================== Minimal / boundary output tests =====================
def test_deterministic_max_tokens_one(llm):
"""Single-token output is deterministic."""
sp = SamplingParams(temperature=0.1, max_tokens=1, seed=700)
text, token_ids = _assert_deterministic(llm, "What color is the sky?", sp)
assert len(token_ids) == 1, f"Expected 1 token, got {len(token_ids)}"
def test_deterministic_early_stop(llm):
"""Early stopping via stop sequences is deterministic."""
sp = SamplingParams(temperature=0.7, max_tokens=100, stop=["\u3002", "."], seed=800)
text, token_ids = _assert_deterministic(llm, "Please list three colors:", sp)
assert len(token_ids) < 100, f"Expected early stop, got {len(token_ids)} tokens"
# ===================== Special input tests =====================
@pytest.mark.parametrize(
"prompt,seed",
[
("What is AI? \U0001f52c\U0001f9e0", 900), # emoji
("Math: E = mc\u00b2", 901), # superscript
("Code: def hello(): return 'world'", 902), # code
("Symbols: @#$%^&*()", 903), # special symbols
],
)
def test_deterministic_special_chars(llm, prompt, seed):
sp = SamplingParams(temperature=0.5, max_tokens=30, seed=seed)
_assert_deterministic(llm, prompt, sp)
@pytest.mark.parametrize(
"lang,prompt,seed",
[
("Chinese", "Please introduce artificial intelligence in one sentence.", 1000),
("English", "What is artificial intelligence in one sentence?", 1001),
(
"Japanese",
"\u4eba\u5de5\u77e5\u80fd\u306b\u3064\u3044\u3066\u4e00\u8a00\u3067\u8aac\u660e\u3057\u3066\u304f\u3060\u3055\u3044\u3002",
1002,
),
("Spanish", "\u00bfQu\u00e9 es la inteligencia artificial en una frase?", 1003),
],
)
def test_deterministic_multi_language(llm, lang, prompt, seed):
sp = SamplingParams(temperature=0.5, max_tokens=30, seed=seed)
_assert_deterministic(llm, prompt, sp)
# ===================== Multi-turn conversation test =====================
def test_deterministic_multi_turn(llm):
"""Multi-turn chat maintains determinism."""
sp = SamplingParams(temperature=0.5, max_tokens=50, seed=1100)
messages1 = [
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi! How can I help you?"},
{"role": "user", "content": "Please introduce yourself."},
]
# First full conversation
r1_turn1 = llm.chat(messages1, sp)[0].outputs.text
msgs2 = messages1 + [
{"role": "assistant", "content": r1_turn1},
{"role": "user", "content": "What can you do?"},
]
r1_turn2 = llm.chat(msgs2, sp)[0].outputs.text
# Second full conversation (same seed)
r2_turn1 = llm.chat(messages1, sp)[0].outputs.text
msgs2_repeat = messages1 + [
{"role": "assistant", "content": r2_turn1},
{"role": "user", "content": "What can you do?"},
]
r2_turn2 = llm.chat(msgs2_repeat, sp)[0].outputs.text
assert r1_turn1 == r2_turn1, "Multi-turn: turn-1 outputs differ"
assert r1_turn2 == r2_turn2, "Multi-turn: turn-2 outputs differ"
# ===================== State isolation test =====================
def test_deterministic_state_isolation(llm):
"""Interference prompts and interleaving do not break determinism."""
prompt_a = "What is Python?"
prompt_b = "What is JavaScript?"
sp_a = SamplingParams(temperature=0.5, max_tokens=30, seed=1200)
sp_b = SamplingParams(temperature=0.5, max_tokens=30, seed=1201)
# Round 1
a1, _ = _generate_text(llm, prompt_a, sp_a)
b1, _ = _generate_text(llm, prompt_b, sp_b)
# Run unrelated interference
for p in ["Explain reinforcement learning.", "What is NLP?", "List 3 fruits."]:
llm.generate([p], SamplingParams(temperature=0.7, max_tokens=20, seed=999))
# Round 2
a2, _ = _generate_text(llm, prompt_a, sp_a)
b2, _ = _generate_text(llm, prompt_b, sp_b)
assert a1 == a2, "Prompt A: output changed after interference"
assert b1 == b2, "Prompt B: output changed after interference"
# ===================== Non-deterministic validation =====================
def test_non_deterministic_validation(llm):
"""
Prove that tests are effective:
- Without seed + without mode: outputs vary
- With explicit seed: outputs are consistent
"""
prompt = "Please explain deep learning in one sentence."
# Part 1: no mode, no seed -> outputs should differ
os.environ.pop("FD_DETERMINISTIC_MODE", None)
results_no_seed = []
for _ in range(5):
sp = SamplingParams(temperature=0.7, max_tokens=30)
results_no_seed.append(llm.generate([prompt], sp)[0].outputs.text)
# Probabilistic, skip if all outputs are the same
if len(set(results_no_seed)) == 1:
pytest.skip("Sampling produced identical outputs (probabilistic case)")
# Part 2: explicit seed -> outputs must be consistent
sp_seeded = SamplingParams(temperature=0.7, max_tokens=30, seed=999)
results_seeded = [llm.generate([prompt], sp_seeded)[0].outputs.text for _ in range(5)]
assert len(set(results_seeded)) == 1, "With explicit seed: expected consistent outputs"
if __name__ == "__main__":
pytest.main(["-sv", __file__])
@@ -0,0 +1,129 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Single-GPU determinism offline inference tests for coverage.
Simplified from tests/e2e/4cards_cases/test_determinism_offline.py
for single-GPU coverage testing.
Usage:
CUDA_VISIBLE_DEVICES=0 pytest tests/deterministic/test_determinism_offline_single_gpu.py -v
"""
import os
from contextlib import contextmanager
import pytest
pytestmark = pytest.mark.gpu
DEFAULT_MODEL_DIR = "./models"
MODEL_NAME = "Qwen2-7B-Instruct"
@contextmanager
def env_override(mapping):
"""Temporarily set env vars, restoring original values on exit."""
old = {k: os.environ.get(k) for k in mapping}
os.environ.update(mapping)
try:
yield
finally:
for k, v in old.items():
if v is None:
os.environ.pop(k, None)
else:
os.environ[k] = v
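# Example usage (this is exactly how the module fixture below applies it):
#     with env_override({"FD_DETERMINISTIC_MODE": "1"}):
#         from fastdeploy import LLM  # env var is visible at import time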
@pytest.fixture(scope="module")
def model_path():
model_dir = os.getenv("MODEL_PATH", DEFAULT_MODEL_DIR)
return os.path.join(model_dir, MODEL_NAME)
@pytest.fixture(autouse=True)
def _reset_deterministic_mode():
"""Ensure every test starts with deterministic mode ON."""
os.environ["FD_DETERMINISTIC_MODE"] = "1"
yield
os.environ["FD_DETERMINISTIC_MODE"] = "1"
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module", autouse=True)
def _module_env():
"""Set env vars before importing fastdeploy (must happen first)."""
with env_override(
{
"CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES", "0"),
"FD_DETERMINISTIC_MODE": "1",
"FD_CUSTOM_AR_MAX_SIZE_MB": "64",
}
):
# Lazy import: env vars must be set before importing fastdeploy
global LLM, SamplingParams # noqa: PLW0603
from fastdeploy import LLM, SamplingParams
yield
@pytest.fixture(scope="module")
def llm(model_path, _module_env):
return LLM(
model=model_path,
tensor_parallel_size=1, # Single GPU
max_model_len=4096,
enable_prefix_caching=False,
graph_optimization_config={"use_cudagraph": False},
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _generate_text(llm, prompt, sp):
"""Generate once, return (text, token_ids)."""
out = llm.generate([prompt], sp)[0]
return out.outputs.text, list(out.outputs.token_ids)
def _assert_deterministic(llm, prompt, sp, runs=2):
"""Run *runs* times and assert all outputs are identical."""
results = [_generate_text(llm, prompt, sp) for _ in range(runs)]
texts = [r[0] for r in results]
token_ids = [r[1] for r in results]
assert all(t == texts[0] for t in texts), "Text outputs differ across runs"
assert all(t == token_ids[0] for t in token_ids), "Token IDs differ across runs"
return texts[0], token_ids[0]
# ===================== Core determinism tests =====================
def test_deterministic_same_prompt(llm):
"""Same prompt + same seed produces identical output across 3 runs."""
sp = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=30, seed=123)
_assert_deterministic(llm, "What is AI?", sp, runs=3)
if __name__ == "__main__":
pytest.main(["-sv", __file__])
@@ -0,0 +1,216 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests for flash_attn_func deterministic mode.
Verifies that flash_attn_func passes correct deterministic parameters
(e.g. num_splits=1 for FA3) when FD_DETERMINISTIC_MODE=1.
Usage:
CUDA_VISIBLE_DEVICES=0 pytest tests/deterministic/test_flash_attn_determinism.py -v
"""
import importlib
import os
import numpy as np
import paddle
import pytest
pytestmark = pytest.mark.gpu
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
NUM_HEADS = 8
KV_NUM_HEADS = 8
HEAD_DIM = 128
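# Illustrative demo -- an assumption about the mechanism, not the FA kernel
# itself: floating-point addition is non-associative, so different split/merge
# orders of a split-KV reduction can change output bits; num_splits=1 pins a
# single accumulation order.
def _float_nonassociativity_demo():
    a, b, c = np.float32(1e8), np.float32(-1e8), np.float32(1.0)
    assert (a + b) + c == np.float32(1.0)  # cancel first, then add the small term
    assert a + (b + c) == np.float32(0.0)  # b + c rounds back to -1e8 in float32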
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_sm_version():
prop = paddle.device.cuda.get_device_properties()
return prop.major * 10 + prop.minor
def _reload_flash_attn_backend():
"""Reload flash_attn_backend so env-var changes take effect."""
import fastdeploy.model_executor.layers.attention.flash_attn_backend as mod
importlib.reload(mod)
return mod
def _make_tensors(seq_lens, num_heads=NUM_HEADS, head_dim=HEAD_DIM):
"""Create Q/K/V tensors and cu_seqlens for a batch of sequences."""
total_tokens = sum(seq_lens)
q = paddle.randn([total_tokens, num_heads, head_dim], dtype="bfloat16")
k = paddle.randn([total_tokens, num_heads, head_dim], dtype="bfloat16")
v = paddle.randn([total_tokens, num_heads, head_dim], dtype="bfloat16")
cu_seqlens = paddle.to_tensor(np.array([0] + list(np.cumsum(seq_lens))), dtype="int32")
max_seqlen = max(seq_lens)
return q, k, v, cu_seqlens, max_seqlen
def _call_flash_attn_func(mod, q, k, v, cu_seqlens, max_seqlen, version=None):
"""Call flash_attn_func and return the output tensor."""
result = mod.flash_attn_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=True,
num_heads=NUM_HEADS,
kv_num_heads=KV_NUM_HEADS,
head_dim=HEAD_DIM,
version=version,
)
if isinstance(result, tuple):
return result[0]
return result
def _run_determinism_check(mod, seq_lens, runs, version, test_name):
"""Run flash_attn_func multiple times and verify deterministic output."""
q, k, v, cu_seqlens, max_seqlen = _make_tensors(seq_lens)
outputs = []
for _ in range(runs):
out = _call_flash_attn_func(mod, q, k, v, cu_seqlens, max_seqlen, version=version)
outputs.append(out.numpy())
for i in range(1, runs):
assert np.array_equal(outputs[0], outputs[i]), f"{test_name}: run {i} differs from run 0"
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _clean_env():
"""Save and restore determinism-related env vars around every test."""
keys = ["FD_DETERMINISTIC_MODE", "FD_DETERMINISTIC_DEBUG"]
saved = {k: os.environ.get(k) for k in keys}
yield
for k, v in saved.items():
if v is None:
os.environ.pop(k, None)
else:
os.environ[k] = v
@pytest.fixture
def _deterministic_mode_enabled():
"""Enable deterministic mode and return reloaded module."""
os.environ["FD_DETERMINISTIC_MODE"] = "1"
return _reload_flash_attn_backend()
@pytest.fixture
def _nondeterministic_mode_enabled():
"""Disable deterministic mode and return reloaded module."""
os.environ["FD_DETERMINISTIC_MODE"] = "0"
return _reload_flash_attn_backend()
# ---------------------------------------------------------------------------
# Tests: _is_deterministic_mode
# ---------------------------------------------------------------------------
class TestIsDeterministicMode:
"""Test the _is_deterministic_mode helper."""
def test_enabled(self):
os.environ["FD_DETERMINISTIC_MODE"] = "1"
mod = _reload_flash_attn_backend()
assert mod._is_deterministic_mode() is True
def test_disabled(self):
os.environ["FD_DETERMINISTIC_MODE"] = "0"
mod = _reload_flash_attn_backend()
assert mod._is_deterministic_mode() is False
def test_unset_defaults_false(self):
os.environ.pop("FD_DETERMINISTIC_MODE", None)
mod = _reload_flash_attn_backend()
assert mod._is_deterministic_mode() is False
# ---------------------------------------------------------------------------
# Tests: FA3 determinism (requires SM89+, <SM100)
# ---------------------------------------------------------------------------
class TestFA3Determinism:
"""Test FA3 deterministic behavior with num_splits control."""
@pytest.fixture(autouse=True)
def _require_fa3(self):
sm = _get_sm_version()
if sm < 89 or sm >= 100:
pytest.skip(f"FA3 requires SM89-99, current SM={sm}")
paddle.set_flags({"FLAGS_flash_attn_version": 3})
def test_deterministic_produces_identical_output(self, _deterministic_mode_enabled):
"""num_splits=1 (deterministic) gives bitwise identical results."""
_run_determinism_check(_deterministic_mode_enabled, [64, 128, 256], 5, 3, "FA3 deterministic")
def test_long_sequence_determinism(self, _deterministic_mode_enabled):
"""Long sequences (>1024 tokens) remain deterministic with FA3."""
_run_determinism_check(_deterministic_mode_enabled, [2048], 3, 3, "FA3 long seq")
def test_mixed_batch_determinism(self, _deterministic_mode_enabled):
"""Mixed batch with varying sequence lengths stays deterministic."""
_run_determinism_check(_deterministic_mode_enabled, [16, 512, 1024, 64], 3, 3, "FA3 mixed batch")
def test_nondeterministic_mode_also_works(self, _nondeterministic_mode_enabled):
"""FD_DETERMINISTIC_MODE=0 still works (num_splits=1 is always used)."""
q, k, v, cu_seqlens, max_seqlen = _make_tensors([256])
out = _call_flash_attn_func(_nondeterministic_mode_enabled, q, k, v, cu_seqlens, max_seqlen, version=3)
assert out.shape[0] == 256
assert out.shape[1] == NUM_HEADS
assert out.shape[2] == HEAD_DIM
# ---------------------------------------------------------------------------
# Tests: FA2 determinism
# ---------------------------------------------------------------------------
class TestFA2Determinism:
"""Test FA2 deterministic behavior (inherently deterministic forward)."""
@pytest.fixture(autouse=True)
def _set_fa2(self):
paddle.set_flags({"FLAGS_flash_attn_version": 2})
def test_fa2_deterministic(self, _deterministic_mode_enabled):
"""FA2 forward is inherently deterministic (no split-KV)."""
_run_determinism_check(_deterministic_mode_enabled, [128, 256], 5, 2, "FA2 deterministic")
def test_fa2_long_sequence(self, _deterministic_mode_enabled):
"""FA2 with long sequence remains deterministic."""
_run_determinism_check(_deterministic_mode_enabled, [2048], 3, 2, "FA2 long seq")
@@ -0,0 +1,823 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests for the `update_repeat_times` kernel in token_penalty_multi_scores.cu.
Tests verify both **correctness** and **determinism** of the penalty kernel,
specifically targeting the race condition that was fixed by splitting into
two passes with a __syncthreads() barrier.
Race condition background:
The `update_repeat_times` kernel processes token_ids_all to build a
repeat_times array with these semantics:
0 -> token absent from both prompt and generated tokens
-1 -> token only in the prompt
>=1 -> count of occurrences in generated tokens
The original (buggy) code ran both passes in a single loop without
synchronization. When a token appeared in BOTH the prompt and generated
portions, three atomic operations could interleave across warps:
Pass 1: atomicCAS(&slot, 0, -1) -- mark as prompt-only
Pass 2: atomicMax(&slot, 0) -- lift from -1 to 0
Pass 2: atomicAdd(&slot, 1) -- count generated occurrence
Without __syncthreads() between passes, a thread in Pass 2 could execute
atomicMax/atomicAdd BEFORE another thread in Pass 1 executed atomicCAS,
resulting in repeat_times=0 (should be 1). This caused non-deterministic
penalty application.
The fix: split into two explicit passes with __syncthreads() between them.
Usage:
FD_DETERMINISTIC_MODE=1 CUDA_VISIBLE_DEVICES=0 pytest tests/deterministic/test_penalty_kernel_determinism.py -v -s
"""
import os
# Set deterministic mode before any imports that might read it.
os.environ["FD_DETERMINISTIC_MODE"] = "1"
import numpy as np
import paddle
import pytest
# Import the custom op. This goes through the fastdeploy ops import machinery
# which loads the compiled CUDA custom op and exposes it as a Python callable.
from fastdeploy.model_executor.ops.gpu import get_token_penalty_multi_scores
pytestmark = pytest.mark.gpu
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Default penalty parameters (matching production defaults)
DEFAULT_ALPHA = 1.2 # repetition_penalty
DEFAULT_BETA = 0.5 # frequency_penalty
DEFAULT_GAMMA = 0.3 # presence_penalty
DEFAULT_TEMP = 1.0 # temperature
# Token sentinels
EOS_TOKEN_ID_QWEN2 = 151643
NO_TOKEN_SENTINEL = -1
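# Pure-Python reference for the repeat_times semantics documented above -- an
# illustrative sketch of the fixed two-pass algorithm, not the CUDA kernel:
def _reference_repeat_times(prompt_tokens, gen_tokens, vocab_size):
    """0 = absent, -1 = prompt-only, >=1 = count of generated occurrences."""
    times = [0] * vocab_size
    for t in prompt_tokens:  # Pass 1: atomicCAS(&slot, 0, -1) analogue
        if times[t] == 0:
            times[t] = -1
    # The __syncthreads() barrier sits between the two passes in the fixed kernel.
    for t in gen_tokens:
        if times[t] == -1:  # Pass 2: atomicMax(&slot, 0) analogue
            times[t] = 0
        times[t] += 1  # Pass 2: atomicAdd(&slot, 1) analogue
    return times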
# ---------------------------------------------------------------------------
# Helper: build all tensors needed by the penalty custom op
# ---------------------------------------------------------------------------
def _make_penalty_inputs(
token_ids_all_np, # int64 [bs, max_model_len]
logits_np, # float32 [bs, vocab_size]
prompt_lens_np, # int64 [bs]
cur_dec_lens_np, # int64 [bs]
penalty_scores_np=None, # float32 [bs, 1] (repetition penalty, default 1.2)
frequency_scores_np=None, # float32 [bs, 1] (default 0.5)
presence_scores_np=None, # float32 [bs, 1] (default 0.3)
temperatures_np=None, # float32 [bs, 1] (default 1.0)
eos_token_id_np=None, # int64 [eos_len]
min_dec_lens_np=None, # int64 [bs]
bad_tokens_np=None, # int64 [bs, bad_words_len]
bad_tokens_lens_np=None, # int64 [bs]
):
"""
Build GPU tensors for the penalty op from numpy arrays.
All penalty/frequency/presence/temperature tensors are float32, matching
the production dtype used in input_batch.py. Logits are also float32
so that update_value_by_repeat_times (which is templated on logit dtype)
reads penalty scalars correctly. The update_repeat_times kernel -- the
one that had the race condition -- is NOT templated and exercises the
same code path regardless of logit dtype.
"""
bs = token_ids_all_np.shape[0]
place = paddle.CUDAPlace(0)
# Required inputs
token_ids_all = paddle.to_tensor(token_ids_all_np, dtype="int64", place=place)
logits = paddle.to_tensor(logits_np, dtype="float32", place=place)
prompt_lens = paddle.to_tensor(prompt_lens_np, dtype="int64", place=place)
cur_dec_lens = paddle.to_tensor(cur_dec_lens_np, dtype="int64", place=place)
# Optional inputs with sensible defaults
if penalty_scores_np is None:
penalty_scores_np = np.full([bs, 1], DEFAULT_ALPHA, dtype=np.float32)
if frequency_scores_np is None:
frequency_scores_np = np.full([bs, 1], DEFAULT_BETA, dtype=np.float32)
if presence_scores_np is None:
presence_scores_np = np.full([bs, 1], DEFAULT_GAMMA, dtype=np.float32)
if temperatures_np is None:
temperatures_np = np.full([bs, 1], DEFAULT_TEMP, dtype=np.float32)
if eos_token_id_np is None:
eos_token_id_np = np.array([EOS_TOKEN_ID_QWEN2], dtype=np.int64)
if min_dec_lens_np is None:
min_dec_lens_np = np.zeros([bs], dtype=np.int64)
if bad_tokens_np is None:
bad_tokens_np = np.full([bs, 1], NO_TOKEN_SENTINEL, dtype=np.int64)
if bad_tokens_lens_np is None:
bad_tokens_lens_np = np.zeros([bs], dtype=np.int64)
penalty_scores = paddle.to_tensor(penalty_scores_np, dtype="float32", place=place)
frequency_scores = paddle.to_tensor(frequency_scores_np, dtype="float32", place=place)
presence_scores = paddle.to_tensor(presence_scores_np, dtype="float32", place=place)
temperatures = paddle.to_tensor(temperatures_np, dtype="float32", place=place)
eos_token_id = paddle.to_tensor(eos_token_id_np, dtype="int64", place=place)
min_dec_lens = paddle.to_tensor(min_dec_lens_np, dtype="int64", place=place)
bad_tokens = paddle.to_tensor(bad_tokens_np, dtype="int64", place=place)
bad_tokens_lens = paddle.to_tensor(bad_tokens_lens_np, dtype="int64", place=place)
return (
token_ids_all,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
bad_tokens_lens,
prompt_lens,
cur_dec_lens,
min_dec_lens,
eos_token_id,
)
def _run_penalty(
token_ids_all,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
bad_tokens_lens,
prompt_lens,
cur_dec_lens,
min_dec_lens,
eos_token_id,
):
"""
Run the penalty op on a CLONE of logits (so the original is not modified)
and return the resulting logits as a numpy array.
"""
logits_clone = logits.clone()
result = get_token_penalty_multi_scores(
token_ids_all,
logits_clone,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
bad_tokens_lens,
prompt_lens,
cur_dec_lens,
min_dec_lens,
eos_token_id,
)
paddle.device.cuda.synchronize()
return result.numpy()
# ---------------------------------------------------------------------------
# Helper: determinism assertion
# ---------------------------------------------------------------------------
def _assert_determinism(
inputs,
num_runs: int,
test_name: str,
verbose: bool = True,
include_diff_details: bool = True,
):
"""
Assert that running the penalty op `num_runs` times produces
bit-identical results.
Args:
inputs: Tuple of tensors for _run_penalty
num_runs: Number of runs to perform
test_name: Name for error messages
verbose: Print success message if True
include_diff_details: Include diff details in error message
Returns:
True if deterministic
"""
reference = _run_penalty(*inputs)
mismatches = []
for i in range(1, num_runs):
result = _run_penalty(*inputs)
if not np.array_equal(reference, result):
diff_mask = reference != result
diff_count = np.sum(diff_mask)
if include_diff_details:
diff_indices = np.argwhere(diff_mask)[:5].tolist()
max_abs_diff = np.max(np.abs(reference[diff_mask] - result[diff_mask]))
mismatches.append((i, diff_count, diff_indices, max_abs_diff))
else:
mismatches.append((i, diff_count))
if mismatches:
error_msg = (
f"{test_name} is NON-DETERMINISTIC: " f"{len(mismatches)}/{num_runs-1} runs differ from reference.\n"
)
if include_diff_details:
error_msg += (
f"First 3 mismatches (run_idx, num_diffs, sample_indices, max_abs_diff): " f"{mismatches[:3]}\n"
)
else:
error_msg += f"First 3 mismatches (run_idx, num_diffs): " f"{mismatches[:3]}\n"
error_msg += (
"This indicates the atomicCAS/atomicMax/atomicAdd race condition " "in update_repeat_times is NOT fixed."
)
raise AssertionError(error_msg)
if verbose:
print(f"\n {test_name}: all {num_runs} runs produced bit-identical results.")
return True
def _print_penalty_summary(actual: float, expected: float, label: str, raw_value: float):
"""Print a formatted summary line for a penalty test."""
print(f" {label}: raw={raw_value:.1f} -> {actual:.6f} (expected {expected:.6f})")
# ---------------------------------------------------------------------------
# Test 1: Correctness -- verify repeat_times semantics and final logits
# ---------------------------------------------------------------------------
class TestPenaltyCorrectness:
"""
Test that the penalty kernel applies the correct transformation for
each repeat_times category:
Token A (id=10): only in prompt -> repeat_times = -1
Token B (id=20): only in generated (x2) -> repeat_times = 2
Token C (id=30): in BOTH prompt + gen -> repeat_times = 1
Token D (id=40): nowhere -> repeat_times = 0
"""
VOCAB_SIZE = 100
MAX_MODEL_LEN = 20
PROMPT_LEN = 5
# Generated tokens occupy positions prompt_len .. max_model_len-1.
# Unused slots are filled with -1 (sentinel for "no token").
def _build_scenario(self):
"""
Build a single-batch scenario:
Prompt tokens (positions 0..4): [10, 30, 50, 51, 52]
Generated tokens (positions 5..): [20, 20, 30, -1, -1, ...]
So:
token 10: prompt only -> repeat_times = -1
token 20: generated x2 -> repeat_times = 2
token 30: prompt AND gen x1 -> repeat_times = 1 (the fixed case!)
token 40: absent -> repeat_times = 0
token 50,51,52: prompt only -> repeat_times = -1
"""
bs = 1
token_ids = np.full([bs, self.MAX_MODEL_LEN], -1, dtype=np.int64)
# Prompt region
token_ids[0, 0] = 10
token_ids[0, 1] = 30
token_ids[0, 2] = 50
token_ids[0, 3] = 51
token_ids[0, 4] = 52
# Generated region
token_ids[0, 5] = 20
token_ids[0, 6] = 20
token_ids[0, 7] = 30
prompt_lens = np.array([self.PROMPT_LEN], dtype=np.int64)
cur_dec_lens = np.array([3], dtype=np.int64) # 3 generated tokens
# Logits: put known positive and negative values at token positions
# to verify the penalty formula direction.
logits = np.zeros([bs, self.VOCAB_SIZE], dtype=np.float32)
logits[0, 10] = 2.0 # Token A (prompt only) -- positive logit
logits[0, 20] = -1.0 # Token B (gen only x2) -- negative logit
logits[0, 30] = 3.0 # Token C (both) -- positive logit
logits[0, 40] = 0.5 # Token D (absent) -- positive logit
penalty_scores = np.array([[DEFAULT_ALPHA]], dtype=np.float32)
frequency_scores = np.array([[DEFAULT_BETA]], dtype=np.float32)
presence_scores = np.array([[DEFAULT_GAMMA]], dtype=np.float32)
temperatures = np.array([[DEFAULT_TEMP]], dtype=np.float32)
return _make_penalty_inputs(
token_ids,
logits,
prompt_lens,
cur_dec_lens,
penalty_scores_np=penalty_scores,
frequency_scores_np=frequency_scores,
presence_scores_np=presence_scores,
temperatures_np=temperatures,
)
def _expected_logit(self, raw_logit, repeat_times):
"""
Compute expected logit after penalty application:
if times != 0:
logit = logit * alpha if logit < 0 else logit / alpha
if times > 0:
logit = logit - times * beta - gamma
logit = logit / temperature
"""
logit = raw_logit
if repeat_times != 0:
if logit < 0:
logit = logit * DEFAULT_ALPHA
else:
logit = logit / DEFAULT_ALPHA
if repeat_times > 0:
logit = logit - repeat_times * DEFAULT_BETA - DEFAULT_GAMMA
logit = logit / DEFAULT_TEMP
return logit
def test_penalty_correctness(self):
"""
Verify that each token category gets the correct penalty applied.
This specifically tests the case where Token C (id=30) appears in
BOTH the prompt and generated regions. Before the __syncthreads()
fix, this could non-deterministically produce repeat_times=0
instead of the correct repeat_times=1.
"""
inputs = self._build_scenario()
result = _run_penalty(*inputs)
# Token A (id=10): repeat_times = -1 (prompt only)
# Penalty: only repetition (times != 0), no frequency/presence (times <= 0)
expected_10 = self._expected_logit(2.0, repeat_times=-1)
actual_10 = result[0, 10]
assert np.isclose(actual_10, expected_10, atol=1e-5), (
f"Token A (prompt only): expected {expected_10:.6f}, got {actual_10:.6f}. " f"repeat_times should be -1."
)
# Token B (id=20): repeat_times = 2 (gen only, 2 occurrences)
# Penalty: repetition + 2*frequency + presence
expected_20 = self._expected_logit(-1.0, repeat_times=2)
actual_20 = result[0, 20]
assert np.isclose(actual_20, expected_20, atol=1e-5), (
f"Token B (gen only x2): expected {expected_20:.6f}, got {actual_20:.6f}. " f"repeat_times should be 2."
)
# Token C (id=30): repeat_times = 1 (BOTH prompt and gen, 1 gen occurrence)
# THIS IS THE KEY TEST CASE for the race condition fix.
# Before the fix, repeat_times could be 0 instead of 1.
expected_30 = self._expected_logit(3.0, repeat_times=1)
actual_30 = result[0, 30]
assert np.isclose(actual_30, expected_30, atol=1e-5), (
f"Token C (both prompt+gen): expected {expected_30:.6f}, got {actual_30:.6f}. "
f"repeat_times should be 1. If this fails intermittently, the "
f"atomicCAS/atomicMax/atomicAdd race in update_repeat_times is "
f"not properly fixed."
)
# Token D (id=40): repeat_times = 0 (absent)
# Only temperature scaling, no penalty.
expected_40 = self._expected_logit(0.5, repeat_times=0)
actual_40 = result[0, 40]
assert np.isclose(actual_40, expected_40, atol=1e-5), (
f"Token D (absent): expected {expected_40:.6f}, got {actual_40:.6f}. " f"repeat_times should be 0."
)
# Print summary for debugging
_print_penalty_summary(actual_10, expected_10, "Token A (id=10, prompt only)", 2.0)
_print_penalty_summary(actual_20, expected_20, "Token B (id=20, gen only x2)", -1.0)
_print_penalty_summary(actual_30, expected_30, "Token C (id=30, both prompt+gen)", 3.0)
_print_penalty_summary(actual_40, expected_40, "Token D (id=40, absent)", 0.5)
# ---------------------------------------------------------------------------
# Test 2: Determinism -- same inputs must produce bit-identical outputs
# ---------------------------------------------------------------------------
class TestPenaltyDeterminism:
"""
Run the penalty op multiple times with the same inputs (including tokens
that appear in both prompt and generated regions) and verify that all
outputs are bit-identical.
"""
VOCAB_SIZE = 1000
MAX_MODEL_LEN = 50
PROMPT_LEN = 15
def _build_overlapping_scenario(self, seed=42):
"""
Build a scenario with multiple tokens appearing in both prompt
and generated regions, which is the pattern that triggers the
race condition.
"""
rng = np.random.RandomState(seed)
bs = 1
token_ids = np.full([bs, self.MAX_MODEL_LEN], NO_TOKEN_SENTINEL, dtype=np.int64)
# Prompt: random tokens from [0, VOCAB_SIZE)
prompt_tokens = rng.randint(0, self.VOCAB_SIZE, size=self.PROMPT_LEN)
token_ids[0, : self.PROMPT_LEN] = prompt_tokens
# Generated: include some tokens from prompt (overlap!) plus new ones
gen_len = 20
gen_tokens = np.concatenate(
[
# First few generated tokens overlap with prompt tokens
prompt_tokens[:5],
prompt_tokens[:3],
# Remaining are new tokens
rng.randint(0, self.VOCAB_SIZE, size=gen_len - 8),
]
)
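# After this concatenation, prompt_tokens[:3] appear at least twice and
# prompt_tokens[3:5] at least once in the generated region, so several
# token IDs live in BOTH regions -- exactly the pattern the fix targets.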
token_ids[0, self.PROMPT_LEN : self.PROMPT_LEN + gen_len] = gen_tokens
prompt_lens = np.array([self.PROMPT_LEN], dtype=np.int64)
cur_dec_lens = np.array([gen_len], dtype=np.int64)
logits = rng.randn(bs, self.VOCAB_SIZE).astype(np.float32)
return _make_penalty_inputs(token_ids, logits, prompt_lens, cur_dec_lens)
def test_penalty_determinism(self):
"""
Run the penalty op 20 times with identical inputs containing
overlapping prompt/generated tokens. All results must be
bit-identical (np.array_equal, not np.allclose).
Before the __syncthreads() fix, this would fail sporadically
because the atomicCAS in Pass 1 could race with atomicMax/atomicAdd
in Pass 2 for overlapping token IDs.
"""
inputs = self._build_overlapping_scenario()
_assert_determinism(inputs, num_runs=20, test_name="Penalty determinism")
# ---------------------------------------------------------------------------
# Test 3: Determinism stress -- large token_ids_all with many overlaps
# ---------------------------------------------------------------------------
class TestPenaltyDeterminismStress:
"""
Stress test with large sequences (prompt_len=500, generated=500)
and many overlapping tokens. Runs 50 times to detect rare races.
"""
VOCAB_SIZE = 32000 # Realistic vocab size (LLaMA/Qwen range)
MAX_MODEL_LEN = 1024
PROMPT_LEN = 500
GEN_LEN = 500
NUM_RUNS = 50
def _build_stress_scenario(self, seed=123):
"""
Large scenario designed to maximize race window:
- 500 prompt tokens, 500 generated tokens
- ~40% of generated tokens overlap with prompt tokens
- Multiple batch elements to increase GPU occupancy
"""
rng = np.random.RandomState(seed)
bs = 4 # Multiple batch elements
token_ids = np.full([bs, self.MAX_MODEL_LEN], NO_TOKEN_SENTINEL, dtype=np.int64)
for b in range(bs):
# Prompt: random tokens
prompt_tokens = rng.randint(0, self.VOCAB_SIZE, size=self.PROMPT_LEN)
token_ids[b, : self.PROMPT_LEN] = prompt_tokens
# Generated: ~40% overlap with prompt
num_overlap = int(self.GEN_LEN * 0.4)
num_new = self.GEN_LEN - num_overlap
# Pick overlap tokens by sampling from prompt tokens with replacement
overlap_tokens = rng.choice(prompt_tokens, size=num_overlap, replace=True)
new_tokens = rng.randint(0, self.VOCAB_SIZE, size=num_new)
# Interleave overlap and new tokens (shuffle to spread races)
gen_tokens = np.concatenate([overlap_tokens, new_tokens])
rng.shuffle(gen_tokens)
token_ids[b, self.PROMPT_LEN : self.PROMPT_LEN + self.GEN_LEN] = gen_tokens
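# Each of the 4 batch rows now has 200 of its 500 generated slots filled
# with IDs resampled (with replacement) from its own prompt, plus 300
# fresh random IDs, shuffled together.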
prompt_lens = np.full([bs], self.PROMPT_LEN, dtype=np.int64)
cur_dec_lens = np.full([bs], self.GEN_LEN, dtype=np.int64)
logits = rng.randn(bs, self.VOCAB_SIZE).astype(np.float32)
return _make_penalty_inputs(token_ids, logits, prompt_lens, cur_dec_lens)
def test_penalty_determinism_stress(self):
"""
Run the penalty op 50 times with large, heavily-overlapping inputs.
All results must be bit-identical.
The large prompt/generated sizes and high overlap ratio (~40%)
maximize the chance of exposing races between the two passes.
With 512 threads per block and 1024 tokens to process, each
thread handles ~2 tokens, creating ample opportunity for
interleaving between Pass 1 (atomicCAS) and Pass 2
(atomicMax + atomicAdd) if the __syncthreads() is missing.
"""
inputs = self._build_stress_scenario()
_assert_determinism(
inputs,
num_runs=self.NUM_RUNS,
test_name=f"Penalty stress (bs=4, prompt_len={self.PROMPT_LEN}, gen_len={self.GEN_LEN})",
)
# ---------------------------------------------------------------------------
# Test 4: Correctness of "both" case repeated many times
# ---------------------------------------------------------------------------
class TestPenaltyBothCaseRepeated:
"""
Targeted test: verify the specific race condition scenario where a token
appears in both prompt and generated regions. Uses distinctive penalty
values so repeat_times=0 and repeat_times=1 produce clearly different outputs.
"""
VOCAB_SIZE = 100
MAX_MODEL_LEN = 20
PROMPT_LEN = 5
def _build_minimal_both_case(self):
"""
Minimal scenario: token 7 appears once in prompt, once in generated.
Uses non-default penalty values to make the outcome more distinct.
"""
bs = 1
token_ids = np.full([bs, self.MAX_MODEL_LEN], NO_TOKEN_SENTINEL, dtype=np.int64)
# Prompt: token 7 at position 0
token_ids[0, 0] = 7
token_ids[0, 1] = 8
token_ids[0, 2] = 9
token_ids[0, 3] = 11
token_ids[0, 4] = 12
# Generated: token 7 at position PROMPT_LEN
token_ids[0, self.PROMPT_LEN] = 7
token_ids[0, self.PROMPT_LEN + 1] = 13
prompt_lens = np.array([self.PROMPT_LEN], dtype=np.int64)
cur_dec_lens = np.array([2], dtype=np.int64)
# Use a distinctive logit value for token 7
logits = np.zeros([bs, self.VOCAB_SIZE], dtype=np.float32)
logits[0, 7] = 5.0
# Use non-trivial penalty values so the output differs significantly
# between repeat_times=0 and repeat_times=1
alpha = 1.5
beta = 0.8
gamma = 0.4
penalty_scores = np.array([[alpha]], dtype=np.float32)
frequency_scores = np.array([[beta]], dtype=np.float32)
presence_scores = np.array([[gamma]], dtype=np.float32)
temperatures = np.array([[1.0]], dtype=np.float32)
inputs = _make_penalty_inputs(
token_ids,
logits,
prompt_lens,
cur_dec_lens,
penalty_scores_np=penalty_scores,
frequency_scores_np=frequency_scores,
presence_scores_np=presence_scores,
temperatures_np=temperatures,
)
# Expected: repeat_times=1 for token 7 (in both prompt and gen)
# logit=5.0 -> positive, so divided by alpha: 5.0/1.5 = 3.333...
# then frequency+presence: 3.333... - 1*0.8 - 0.4 = 2.133...
# then /temperature (1.0): 2.133...
expected = (5.0 / alpha) - 1 * beta - gamma
return inputs, expected
def test_both_case_repeated(self):
"""
Run 100 times, verifying that token 7 always produces consistent output.
Uses an np.isclose run-to-run consistency check rather than comparing
against the expected value; exact correctness is covered by TestPenaltyCorrectness.
Before the fix, about 1-5% of runs would produce a different output
for the overlapping token due to the race condition.
"""
inputs, expected_value = self._build_minimal_both_case()
num_runs = 100
# First run establishes the reference
reference = _run_penalty(*inputs)[0, 7]
# Verify all subsequent runs match the reference
for i in range(1, num_runs):
value = _run_penalty(*inputs)[0, 7]
if not np.isclose(value, reference, atol=1e-5):
raise AssertionError(
f"Token 7 (in both prompt+gen) produced INCONSISTENT values: "
f"first run={reference:.6f}, run {i}={value:.6f}.\n"
f"This indicates the atomicCAS/atomicMax/atomicAdd race condition."
)
print(
f"\nAll {num_runs} runs produced consistent value {reference:.6f} "
f"for the overlapping token (expected ~{expected_value:.6f})."
)
# ---------------------------------------------------------------------------
# Test 5: Edge cases -- boundary conditions
# ---------------------------------------------------------------------------
class TestPenaltyEdgeCases:
"""Test edge cases and boundary conditions."""
VOCAB_SIZE = 100
MAX_MODEL_LEN = 20
def test_empty_generated_tokens(self):
"""Test with no generated tokens (cur_dec_len=0)."""
bs = 1
token_ids = np.full([bs, self.MAX_MODEL_LEN], NO_TOKEN_SENTINEL, dtype=np.int64)
# Only prompt tokens, no generated tokens
token_ids[0, 0] = 10
token_ids[0, 1] = 20
token_ids[0, 2] = 30
prompt_lens = np.array([3], dtype=np.int64)
cur_dec_lens = np.array([0], dtype=np.int64) # Empty generated
logits = np.zeros([bs, self.VOCAB_SIZE], dtype=np.float32)
logits[0, 10] = 1.0
inputs = _make_penalty_inputs(token_ids, logits, prompt_lens, cur_dec_lens)
# Should not crash and should be deterministic
_assert_determinism(inputs, num_runs=10, test_name="Empty generated tokens")
def test_empty_prompt(self):
"""Test with no prompt tokens (prompt_len=0)."""
bs = 1
token_ids = np.full([bs, self.MAX_MODEL_LEN], NO_TOKEN_SENTINEL, dtype=np.int64)
# Only generated tokens
token_ids[0, 0] = 10
token_ids[0, 1] = 20
token_ids[0, 2] = 10
prompt_lens = np.array([0], dtype=np.int64) # Empty prompt
cur_dec_lens = np.array([3], dtype=np.int64)
logits = np.zeros([bs, self.VOCAB_SIZE], dtype=np.float32)
logits[0, 10] = 1.0
inputs = _make_penalty_inputs(token_ids, logits, prompt_lens, cur_dec_lens)
# Should not crash and should be deterministic
_assert_determinism(inputs, num_runs=10, test_name="Empty prompt")
def test_single_token(self):
"""Test with a single generated token."""
bs = 1
token_ids = np.full([bs, self.MAX_MODEL_LEN], NO_TOKEN_SENTINEL, dtype=np.int64)
# Single prompt token
token_ids[0, 0] = 5
# Single generated token (different from prompt)
token_ids[0, 1] = 10
prompt_lens = np.array([1], dtype=np.int64)
cur_dec_lens = np.array([1], dtype=np.int64)
logits = np.zeros([bs, self.VOCAB_SIZE], dtype=np.float32)
logits[0, 10] = 1.0
inputs = _make_penalty_inputs(token_ids, logits, prompt_lens, cur_dec_lens)
# Should not crash and should be deterministic
_assert_determinism(inputs, num_runs=10, test_name="Single token")
def test_no_overlapping_tokens(self):
"""Test with no overlapping tokens between prompt and generated."""
bs = 1
token_ids = np.full([bs, self.MAX_MODEL_LEN], NO_TOKEN_SENTINEL, dtype=np.int64)
# Prompt tokens
token_ids[0, 0] = 10
token_ids[0, 1] = 11
# Generated tokens (no overlap with prompt)
token_ids[0, 2] = 20
token_ids[0, 3] = 21
prompt_lens = np.array([2], dtype=np.int64)
cur_dec_lens = np.array([2], dtype=np.int64)
logits = np.zeros([bs, self.VOCAB_SIZE], dtype=np.float32)
logits[0, 10] = 1.0
logits[0, 20] = -1.0
inputs = _make_penalty_inputs(token_ids, logits, prompt_lens, cur_dec_lens)
# Should not crash and should be deterministic
_assert_determinism(inputs, num_runs=10, test_name="No overlapping tokens")
def test_repeated_same_token_only_generated(self):
"""Test token repeated many times in generated only (no overlap with prompt)."""
bs = 1
token_ids = np.full([bs, self.MAX_MODEL_LEN], NO_TOKEN_SENTINEL, dtype=np.int64)
# Prompt tokens
token_ids[0, 0] = 1
token_ids[0, 1] = 2
token_ids[0, 2] = 3
# Generated: same token repeated 10 times
for i in range(10):
token_ids[0, 3 + i] = 50
prompt_lens = np.array([3], dtype=np.int64)
cur_dec_lens = np.array([10], dtype=np.int64)
logits = np.zeros([bs, self.VOCAB_SIZE], dtype=np.float32)
logits[0, 50] = 2.0
inputs = _make_penalty_inputs(token_ids, logits, prompt_lens, cur_dec_lens)
# Should not crash and should be deterministic
_assert_determinism(inputs, num_runs=10, test_name="Repeated token in generated only")
# ---------------------------------------------------------------------------
# Test 6: Multi-batch determinism
# ---------------------------------------------------------------------------
class TestPenaltyMultiBatch:
"""Test determinism with multiple batch elements."""
VOCAB_SIZE = 500
MAX_MODEL_LEN = 50
BATCH_SIZE = 8
def test_multi_batch_determinism(self):
"""Test with multiple batch elements, each with different patterns."""
rng = np.random.RandomState(42)
bs = self.BATCH_SIZE
token_ids = np.full([bs, self.MAX_MODEL_LEN], NO_TOKEN_SENTINEL, dtype=np.int64)
prompt_lens = []
cur_dec_lens = []
for b in range(bs):
# Each batch element has a different pattern
prompt_len = rng.randint(1, 20)
gen_len = rng.randint(1, 20)
prompt_tokens = rng.randint(0, self.VOCAB_SIZE, size=prompt_len)
token_ids[b, :prompt_len] = prompt_tokens
# Mix of overlapping and new tokens
overlap_count = rng.randint(0, gen_len + 1)
gen_tokens = []
if overlap_count > 0:
gen_tokens.extend(rng.choice(prompt_tokens, size=overlap_count, replace=True))
gen_tokens.extend(rng.randint(0, self.VOCAB_SIZE, size=gen_len - overlap_count))
gen_tokens = gen_tokens[:gen_len] # Ensure correct length
token_ids[b, prompt_len : prompt_len + gen_len] = gen_tokens
prompt_lens.append(prompt_len)
cur_dec_lens.append(gen_len)
prompt_lens = np.array(prompt_lens, dtype=np.int64)
cur_dec_lens = np.array(cur_dec_lens, dtype=np.int64)
logits = rng.randn(bs, self.VOCAB_SIZE).astype(np.float32)
inputs = _make_penalty_inputs(token_ids, logits, prompt_lens, cur_dec_lens)
# Should be deterministic across multiple batches
_assert_determinism(inputs, num_runs=20, test_name="Multi-batch determinism")
if __name__ == "__main__":
pytest.main(["-sv", __file__])
@@ -0,0 +1,159 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit test: isolate sampling determinism from model computation.
This test fixes the logits (model output) and runs only the sampling
pipeline multiple times. If the results differ, the bug is in sampling;
if they are always identical, the non-determinism comes from model
computation (logits differ between runs).
Usage:
CUDA_VISIBLE_DEVICES=0 pytest tests/deterministic/test_sampling_determinism.py -v -s
"""
import paddle
import paddle.nn.functional as F
import pytest
pytestmark = pytest.mark.gpu
VOCAB_SIZE = 151936 # Qwen2 vocab size
BATCH_SIZE = 1
def _make_logits(seed: int = 42):
"""Create reproducible random logits that look like real model output."""
paddle.seed(seed)
# Simulate logits with realistic distribution (not uniform)
logits = paddle.randn([BATCH_SIZE, VOCAB_SIZE], dtype="float32")
# Make it slightly peaked (a few tokens have higher logits)
logits[0, 100] += 5.0
logits[0, 200] += 4.5
logits[0, 300] += 4.0
return logits
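# Because paddle.seed() resets the global generator inside this helper,
# every call with the same seed returns an identical logits tensor (on a
# fixed device and framework build), letting the tests below hold the
# "model output" constant while exercising only the sampling path.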
def _sample_with_top_p(logits, top_p_val, seed_val):
"""Run the same sampling pipeline as sampler.forward_cuda (non-greedy path)."""
probs = F.softmax(logits, axis=-1)
top_p = paddle.to_tensor([top_p_val], dtype="float32")
topp_seed = paddle.to_tensor([[seed_val]], dtype="int64")
_, ids = paddle.tensor.top_p_sampling(probs, top_p, topp_seed=topp_seed, seed=-1, mode="truncated")
return ids.item()
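# Example usage: tok = _sample_with_top_p(_make_logits(), top_p_val=0.95, seed_val=200)
# Repeated calls with identical logits and seed must return the same token
# id; that invariant is what every test below asserts.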
# ---- Test 1: basic repeated sampling on identical logits ----
def test_sampling_determinism_basic():
"""Same logits + same seed -> must produce same token every time."""
logits = _make_logits(seed=42)
results = [_sample_with_top_p(logits, top_p_val=0.95, seed_val=200) for _ in range(20)]
assert len(set(results)) == 1, f"Sampling non-deterministic! Got {len(set(results))} distinct values: {results}"
# ---- Test 2: simulate multi-step decode (seed increments like real runner) ----
def test_sampling_determinism_multistep():
"""Simulate 100 decode steps with seed incrementing by 4 each step."""
logits = _make_logits(seed=42)
def run_steps():
tokens = []
for step in range(100):
seed_val = 200 + step * 4 # real runner increments seed by 4
tok = _sample_with_top_p(logits, top_p_val=0.95, seed_val=seed_val)
tokens.append(tok)
return tokens
run1 = run_steps()
run2 = run_steps()
assert run1 == run2, _diff_msg(run1, run2)
# ---- Test 3: interleave GPU work between sampling calls ----
def test_sampling_determinism_with_gpu_noise():
"""
Insert GPU matmul work between sampling calls to check if
GPU state residuals affect sampling determinism.
"""
logits = _make_logits(seed=42)
def run_steps_with_noise():
tokens = []
for step in range(50):
# Simulate GPU model forward between steps
_ = paddle.matmul(paddle.randn([256, 256]), paddle.randn([256, 256]))
seed_val = 200 + step * 4
tok = _sample_with_top_p(logits, top_p_val=0.95, seed_val=seed_val)
tokens.append(tok)
return tokens
run1 = run_steps_with_noise()
run2 = run_steps_with_noise()
assert run1 == run2, _diff_msg(run1, run2)
# ---- Test 4: flat distribution (temp=1.0 scenario, hardest case) ----
def test_sampling_determinism_flat_distribution():
"""
Flat probability distribution (simulating temp=1.0 with no dominant token).
This is the hardest case for determinism.
"""
paddle.seed(99)
# Logits close to zero -> softmax gives nearly uniform distribution
logits = paddle.randn([BATCH_SIZE, VOCAB_SIZE], dtype="float32") * 0.1
for seed_val in [100, 200, 300, 400, 500]:
results = [_sample_with_top_p(logits, top_p_val=0.95, seed_val=seed_val) for _ in range(10)]
assert len(set(results)) == 1, (
f"seed={seed_val}: sampling non-deterministic on flat dist! "
f"Got {len(set(results))} distinct values: {results}"
)
# ---- Test 5: different top_p values ----
@pytest.mark.parametrize("top_p_val", [0.5, 0.8, 0.95, 1.0])
def test_sampling_determinism_various_top_p(top_p_val):
"""Determinism across different top_p values."""
logits = _make_logits(seed=42)
results = [_sample_with_top_p(logits, top_p_val=top_p_val, seed_val=200) for _ in range(10)]
assert len(set(results)) == 1, (
f"top_p={top_p_val}: non-deterministic! " f"Got {len(set(results))} distinct values: {results}"
)
# ---- Helpers ----
def _diff_msg(run1, run2):
for i, (a, b) in enumerate(zip(run1, run2)):
if a != b:
return f"First diff at step {i}: run1={a}, run2={b}. Total diffs: {sum(1 for x, y in zip(run1, run2) if x != y)}/{len(run1)}"
return "Lengths differ"
if __name__ == "__main__":
pytest.main(["-sv", __file__])