Files
FastDeploy/tests/deterministic/test_c16_warp1_4_determinism.py
T
gongweibao 30f9f33f34 [Feature][BugFix][OP] Enhance Deterministic Inference Mode with Kernel-level Fixes and Batch-invariant BMM (#6610)
* add fa deter

* add ut

* add long sentence

* fix basic

* fix bugs

* fix adn

* fix first

* fix single

* fix single

* fix single test

* refine

* add more test

* refine comments

* add comments of bmm

* fix ci

* remove probe

* add

* remove not need

* refine tests

* fix comments and refine code

* refine code

* refine test

* refine test

* mv 4cards tests

* fix tests

* add

* fix comments

* fix cover

* fix cover

---------

Co-authored-by: gongweibao <gognweibao@baidu.com>
2026-03-09 10:27:53 +08:00

800 lines
26 KiB
Python

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test suite for the c16 warp1_4 decoder attention kernel determinism.
Background:
The c16 warp1_4 decoder kernel (multiquery_attention_c16_impl.cuh) had two
bugs that caused nondeterministic outputs under FD_DETERMINISTIC_MODE:
1. The warp1_4 dispatcher (lines 1164-1175) lacked the force_no_partition
check, so it still launched the multi-chunk (split) kernel even when
deterministic mode requested the single-chunk (nosplit) path.
2. The nosplit kernel template read runtime num_chunks_this_seq instead of
compile-time partition_kv, causing out-of-bounds nullptr writes to the
partial output buffer (lines 545, 748, 772, 812).
How the c16 warp1_4 path is triggered:
- dim_head=128, blocksize=64, cache_quant_type="none" -> selects c16 kernel
- decoder_block_shape_q=16 -> NUM_WARP_Q=1, i.e. warp1_4 configuration
- seq_lens_encoder=0, seq_lens_decoder>0 -> decoder mode
- FD_DETERMINISTIC_MODE=1 -> forces nosplit path
- Small decoder_max_partition_size (e.g. 64) with long prefill (e.g. 256)
ensures num_chunks > 1, which is the scenario that exposed the bugs.
Test items:
1. test_short_kv_nosplit
- Short KV (num_chunks=1): basic nosplit path with partition_kv=false.
- Verifies both correctness (vs naive reference) and determinism (10 runs).
2. test_long_kv_multi_chunk
- Long KV (num_chunks=4, prefill=256, partition=64): the exact scenario
the fix addresses. partition_kv=true template but grid_chunks=1.
- Verifies correctness and determinism.
3. test_multi_batch
- Multiple batches (batch_size=4) with multi-chunk decoder.
- Ensures the fix works across batch elements, not just single-batch.
4. test_float16
- Float16 dtype with multi-chunk decoder.
- Ensures the fix is dtype-agnostic (not only bfloat16).
5. test_unaligned_seq_len
- prefill_seq_len not divisible by blocksize (100 % 64 != 0).
- Catches off-by-one bugs in block/chunk boundary calculations.
6. test_mha_no_gqa
- MHA config: q_num_head == kv_num_head (no GQA grouping).
- Ensures the fix is not GQA-specific.
7. test_nosplit_vs_split_consistency
- Cross-path check: deterministic nosplit vs non-deterministic split.
- Both paths should produce numerically close results (rtol/atol=1e-2)
and both should match the naive attention reference.
8. test_partition_boundary
- Edge case: prefill_seq_len equals partition_size (boundary condition).
- Tests chunk calculation when num_chunks is exactly an integer.
9. test_empty_kv
- Edge case: decoder-only mode (no prefill, empty KV cache).
- Tests scenario with no encoder prefill phase.
Run:
python -m pytest tests/deterministic/test_c16_warp1_4_determinism.py -v
"""
import copy
import os
import unittest
import numpy as np
import paddle
# Enable deterministic mode before fastdeploy is imported (hence the E402
# suppression below). NOTE: per the comment in
# TestC16Warp14Determinism.test_nosplit_vs_split_consistency, the flag is also
# re-read lazily at call time, so later os.environ changes take effect too.
os.environ["FD_DETERMINISTIC_MODE"] = "1"
from fastdeploy.model_executor.layers.attention.ops import (  # noqa: E402
    append_attention,
    get_block_shape_and_split_kv_block,
)

# Seed shared by numpy and paddle so every run builds identical inputs.
SEED = 42
# Q-tile shapes passed to get_block_shape_and_split_kv_block; the decoder
# value of 16 is what selects the warp1_4 kernel configuration (see module
# docstring).
ENCODER_BLOCK_SHAPE_Q = 64
DECODER_BLOCK_SHAPE_Q = 16
def _assert_deterministic_and_correct(results, ref):
"""
Helper to verify determinism and correctness.
Args:
results: List of output arrays from repeated runs
ref: Reference output array from naive attention
"""
# Verify determinism: all runs should produce identical output
for i in range(1, len(results)):
np.testing.assert_array_equal(
results[0],
results[i],
err_msg=f"Determinism failure: run 0 vs run {i}",
)
# Verify correctness: output should match naive reference
np.testing.assert_allclose(results[0], ref, rtol=1e-2, atol=1e-2)
def make_rope_emb(max_seq_len, dim_head, base=10000):
    """
    Build a rotary-embedding table of shape (2, 1, max_seq_len, 1, dim_head//2).

    Index 0 along the first axis holds cosines, index 1 holds sines, matching
    the layout expected by append_attention and apply_rope.
    """
    positions = paddle.arange(max_seq_len).reshape((1, -1)).cast("float32")
    inv_freq = base ** (-paddle.arange(0, dim_head, 2, dtype="float32") / dim_head)
    angles = paddle.einsum("ij,k->ijk", positions, inv_freq)
    angles = angles.reshape((1, max_seq_len, dim_head // 2)).unsqueeze(2)
    table = paddle.zeros((2, 1, max_seq_len, 1, dim_head // 2), dtype="float32")
    table[0] = paddle.cos(angles)
    table[1] = paddle.sin(angles)
    return table
def apply_rope(x, rope_emb, positions):
    """
    Apply rotary position embedding (non-neox interleaved style).

    Args:
        x: (batch, heads, seq_len, dim_head) tensor (query or key).
        rope_emb: (2, 1, max_seq_len, 1, dim_head//2) cos/sin table, as
            produced by make_rope_emb.
        positions: list of int, one absolute position per seq index.

    Returns:
        Rotated tensor with the same shape and dtype as ``x``.
    """
    dim_head = x.shape[-1]
    half = dim_head // 2
    # Rotate in float32 and cast back at the end to limit precision loss.
    x_f32 = x.cast("float32")
    out = x_f32.clone()
    for seq_idx, pos in enumerate(positions):
        cos_p = rope_emb[0, 0, pos, 0, :]  # (dim_head//2,)
        sin_p = rope_emb[1, 0, pos, 0, :]
        x_slice = x_f32[:, :, seq_idx, :]  # (batch, heads, dim_head)
        # Interleaved pairing: element 2k rotates together with element 2k+1.
        x_pairs = x_slice.reshape(list(x_slice.shape[:-1]) + [half, 2])
        x0 = x_pairs[..., 0]  # (batch, heads, half)
        x1 = x_pairs[..., 1]
        # Standard 2-D rotation by the per-position angle.
        out0 = x0 * cos_p - x1 * sin_p
        out1 = x0 * sin_p + x1 * cos_p
        out[:, :, seq_idx, :] = paddle.stack([out0, out1], axis=-1).reshape(x_slice.shape)
    return out.cast(x.dtype)
def get_padding_offset(bsz, max_seq_len, seq_lens_this_time):
    """
    Compute per-token batch ids and cumulative query sequence lengths.

    Mirrors the padded-token layout used by the kernels: each batch element
    occupies a fixed stride of ``max_seq_len`` slots, of which only the first
    ``seq_lens_this_time[b]`` are real tokens.

    Returns:
        (batch_id_per_token, cu_seqlens_q) as int32 paddle tensors.
    """
    pad_prefix = paddle.cumsum(max_seq_len - seq_lens_this_time, dtype="int32")
    offsets = paddle.zeros(shape=(bsz + 1,), dtype="int32")
    offsets[1:] = pad_prefix
    total_tokens = int(paddle.sum(seq_lens_this_time))
    batch_id_per_token = paddle.zeros(shape=(total_tokens,), dtype="int32")
    cu_seqlens_q = paddle.zeros(shape=(bsz + 1,), dtype="int32")
    for b in range(bsz):
        # First compact slot of batch b = stride position minus padding so far.
        start = b * max_seq_len - int(offsets[b])
        for t in range(int(seq_lens_this_time[b])):
            batch_id_per_token[start + t] = b
        cu_seqlens_q[b + 1] = (b + 1) * max_seq_len - int(offsets[b + 1])
    return batch_id_per_token, cu_seqlens_q
def naive_attention_impl(query, key, value, cache_k, cache_v, scale):
    """Reference: Q @ K^T * scale -> softmax -> @ V, with GQA head expansion."""
    batch, heads, _, head_dim = query.shape
    kv_head = key.shape[1]
    group = heads // kv_head

    def expand(t):
        # Repeat each kv head `group` times so kv tensors match the q head count.
        t = t.reshape([batch, kv_head, 1, -1, head_dim])
        return paddle.tile(t, [1, 1, group, 1, 1]).reshape([batch, heads, -1, head_dim])

    key = expand(key)
    value = expand(value)
    # Prepend cached K/V (already head-expanded) along the sequence axis.
    if cache_k is not None:
        key = paddle.concat([expand(cache_k), key], axis=2)
    if cache_v is not None:
        value = paddle.concat([expand(cache_v), value], axis=2)
    scores = paddle.matmul(query, key, transpose_y=True) * scale
    probs = paddle.nn.functional.softmax(scores, -1)
    return paddle.matmul(probs.cast(value.dtype), value)
def block_cache_to_naive(cache_k, cache_v, bsz, block_tables, seq_len):
    """
    Gather the paged block cache into contiguous tensors.

    Returns K and V of shape [bsz, num_head, seq_len, dim_head], following the
    per-sequence block_tables mapping.
    """
    _, num_head, blocksize, dim_head = cache_k.shape
    out_k = paddle.zeros([bsz, num_head, seq_len, dim_head], dtype=cache_k.dtype)
    out_v = paddle.zeros([bsz, num_head, seq_len, dim_head], dtype=cache_v.dtype)
    for b in range(bsz):
        for pos in range(seq_len):
            blk = block_tables[b, pos // blocksize]
            off = pos % blocksize
            out_k[b, :, pos, :] = cache_k[blk, :, off, :]
            out_v[b, :, pos, :] = cache_v[blk, :, off, :]
    return out_k, out_v
def run_c16_warp14_decoder_test(
    batch_size,
    q_num_head,
    kv_num_head,
    dim_head,
    blocksize,
    prefill_seq_len,
    max_dec_len,
    dtype,
    decoder_max_partition_size,
    num_decode_runs,
):
    """
    Run encoder prefill + N decoder runs.

    Args:
        batch_size: Number of sequences in the batch.
        q_num_head: Number of query heads.
        kv_num_head: Number of key/value heads (GQA grouping when < q_num_head).
        dim_head: Per-head dimension (128 selects the c16 kernel path per the
            module docstring).
        blocksize: KV-cache block size.
        prefill_seq_len: Encoder (prefill) sequence length; may be 0.
        max_dec_len: Decode-length headroom reserved in the cache.
        dtype: "bfloat16" or "float16".
        decoder_max_partition_size: Max KV partition (chunk) size for the
            decoder path; small values with long prefill force num_chunks > 1.
        num_decode_runs: Repeated decoder runs used to check determinism.

    Returns:
        (list of decoder output numpy arrays, naive reference numpy array).
    """
    np.random.seed(SEED)
    paddle.seed(SEED)
    max_seq_len = prefill_seq_len + max_dec_len
    block_per_seq = (max_seq_len + blocksize - 1) // blocksize  # ceil-div
    max_block_num = block_per_seq * batch_size
    scale = 1.0 / np.sqrt(dim_head)
    group_size = q_num_head // kv_num_head
    compute_type = "bf16" if dtype == "bfloat16" else "fp16"
    rope_emb = make_rope_emb(max_seq_len, dim_head)
    # Block tables: hand out physical cache blocks to each sequence in order.
    free_list = list(range(max_block_num - 1, -1, -1))
    block_tables = paddle.zeros((batch_size, block_per_seq), dtype="int32")
    for i in range(batch_size):
        for j in range(block_per_seq):
            block_tables[i, j] = free_list.pop()
    cache_k = paddle.zeros((max_block_num, kv_num_head, blocksize, dim_head), dtype=dtype)
    cache_v = paddle.zeros((max_block_num, kv_num_head, blocksize, dim_head), dtype=dtype)
    # Tile metadata buffers, sized per allocate_launch_related_buffer() formula
    gqa_ratio = q_num_head // kv_num_head
    decode_tile_size = int(1024 * batch_size * np.ceil((2 * gqa_ratio) / DECODER_BLOCK_SHAPE_Q))
    encode_tile_size = max(batch_size, batch_size * (max_seq_len * gqa_ratio // ENCODER_BLOCK_SHAPE_Q))
    kv_tile_size = max(batch_size, batch_size * (max_seq_len // blocksize))
    dec_batch_ids = paddle.full([decode_tile_size], 0, dtype="int32")
    dec_tile_ids = paddle.full([decode_tile_size], 0, dtype="int32")
    # NOTE(review): dec_nblocks_cpu uses pinned memory while the other *_cpu
    # buffers use plain .cpu() — presumably required by the kernel; confirm.
    dec_nblocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
    dec_nblocks_dev = paddle.full([1], 0, dtype="int32")
    dec_chunk_dev = paddle.full([1], decoder_max_partition_size, dtype="int32")
    max_len_cpu = paddle.full([8], 0, dtype="int32").cpu()
    enc_batch_ids = paddle.full([encode_tile_size], 0, dtype="int32")
    enc_tile_ids = paddle.full([encode_tile_size], 0, dtype="int32")
    enc_nblocks_cpu = paddle.full([1], 0, dtype="int32").cpu()
    kv_batch_ids = paddle.full([kv_tile_size], 0, dtype="int32")
    kv_tile_ids = paddle.full([kv_tile_size], 0, dtype="int32")
    kv_nblocks_cpu = paddle.full([1], 0, dtype="int32").cpu()
    # ===== Encoder phase =====
    # seq_lens_encoder > 0, seq_lens_decoder == 0 -> prefill mode.
    seq_enc = paddle.full([batch_size], prefill_seq_len, dtype="int32")
    seq_dec = paddle.full([batch_size], 0, dtype="int32")
    seq_this = copy.deepcopy(seq_enc)
    bid_enc, cu_enc = get_padding_offset(batch_size, prefill_seq_len, seq_this)
    token_num = batch_size * prefill_seq_len
    q_np = np.random.random([batch_size, q_num_head, prefill_seq_len, dim_head]).astype("float32") / 10
    k_np = np.random.random([batch_size, kv_num_head, prefill_seq_len, dim_head]).astype("float32") / 10
    v_np = np.random.random([batch_size, kv_num_head, prefill_seq_len, dim_head]).astype("float32") / 10
    q = paddle.to_tensor(q_np, dtype=dtype)
    k = paddle.to_tensor(k_np, dtype=dtype)
    v = paddle.to_tensor(v_np, dtype=dtype)
    # Pack Q/K/V into the fused [token_num, (q+k+v heads) * dim_head] layout.
    qkv = paddle.concat(
        [
            q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * dim_head]),
            k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
            v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
        ],
        axis=1,
    )
    # Use large partition size for encoder to avoid issues in prefill
    encoder_partition_size = 32768
    get_block_shape_and_split_kv_block(
        seq_enc,
        seq_dec,
        seq_this,
        dec_batch_ids,
        dec_tile_ids,
        dec_nblocks_cpu,
        dec_nblocks_dev,
        dec_chunk_dev,
        max_len_cpu,
        enc_batch_ids,
        enc_tile_ids,
        enc_nblocks_cpu,
        kv_batch_ids,
        kv_tile_ids,
        kv_nblocks_cpu,
        ENCODER_BLOCK_SHAPE_Q,
        DECODER_BLOCK_SHAPE_Q,
        group_size,
        blocksize,
    )
    append_attention(
        qkv,
        cache_k,
        cache_v,
        seq_enc,
        seq_dec,
        seq_this,
        bid_enc,
        cu_enc,
        block_tables,
        enc_batch_ids,
        enc_tile_ids,
        enc_nblocks_cpu,
        kv_batch_ids,
        kv_tile_ids,
        kv_nblocks_cpu,
        dec_batch_ids,
        dec_tile_ids,
        dec_nblocks_cpu,
        max_len_cpu,
        rope_emb,
        # Optional tensor inputs not exercised by this test: all None.
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        1e-6,
        compute_type,
        "none",  # cache_quant_type "none" -> selects the c16 kernel
        False,
        False,
        max_seq_len,
        0.0,
        0.0,
        -1,
        ENCODER_BLOCK_SHAPE_Q,
        DECODER_BLOCK_SHAPE_Q,
        encoder_partition_size,
        encoder_partition_size,
        2,
        True,
        False,
        0,
    )
    paddle.device.synchronize()
    # Extract naive KV cache for reference (already has RoPE applied by kernel)
    naive_ck, naive_cv = block_cache_to_naive(
        cache_k,
        cache_v,
        batch_size,
        block_tables,
        prefill_seq_len,
    )
    # ===== Decoder phase =====
    # seq_lens_encoder == 0, seq_lens_decoder > 0 -> decode one token per batch.
    seq_enc_d = paddle.full([batch_size], 0, dtype="int32")
    seq_dec_d = paddle.full([batch_size], prefill_seq_len, dtype="int32")
    seq_this_d = paddle.full([batch_size], 1, dtype="int32")
    bid_dec, cu_dec = get_padding_offset(batch_size, 1, seq_this_d)
    dq_np = np.random.random([batch_size, q_num_head, 1, dim_head]).astype("float32") / 10
    dk_np = np.random.random([batch_size, kv_num_head, 1, dim_head]).astype("float32") / 10
    dv_np = np.random.random([batch_size, kv_num_head, 1, dim_head]).astype("float32") / 10
    dq = paddle.to_tensor(dq_np, dtype=dtype)
    dk = paddle.to_tensor(dk_np, dtype=dtype)
    dv = paddle.to_tensor(dv_np, dtype=dtype)
    dec_qkv = paddle.concat(
        [
            dq.transpose([0, 2, 1, 3]).reshape([batch_size, q_num_head * dim_head]),
            dk.transpose([0, 2, 1, 3]).reshape([batch_size, kv_num_head * dim_head]),
            dv.transpose([0, 2, 1, 3]).reshape([batch_size, kv_num_head * dim_head]),
        ],
        axis=1,
    )
    # Warmup: first decoder call on multi-chunk path may return zeros
    # due to kernel JIT compilation. Run once and discard.
    get_block_shape_and_split_kv_block(
        seq_enc_d,
        seq_dec_d,
        seq_this_d,
        dec_batch_ids,
        dec_tile_ids,
        dec_nblocks_cpu,
        dec_nblocks_dev,
        dec_chunk_dev,
        max_len_cpu,
        enc_batch_ids,
        enc_tile_ids,
        enc_nblocks_cpu,
        kv_batch_ids,
        kv_tile_ids,
        kv_nblocks_cpu,
        ENCODER_BLOCK_SHAPE_Q,
        DECODER_BLOCK_SHAPE_Q,
        group_size,
        blocksize,
    )
    # Clone inputs so the warmup cannot mutate the tensors used by the
    # measured runs below.
    append_attention(
        dec_qkv.clone(),
        cache_k.clone(),
        cache_v.clone(),
        seq_enc_d,
        seq_dec_d,
        seq_this_d,
        bid_dec,
        cu_dec,
        block_tables,
        enc_batch_ids,
        enc_tile_ids,
        enc_nblocks_cpu,
        kv_batch_ids,
        kv_tile_ids,
        kv_nblocks_cpu,
        dec_batch_ids,
        dec_tile_ids,
        dec_nblocks_cpu,
        max_len_cpu,
        rope_emb,
        # Optional tensor inputs not exercised by this test: all None.
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        1e-6,
        compute_type,
        "none",
        False,
        False,
        max_seq_len,
        0.0,
        0.0,
        -1,
        ENCODER_BLOCK_SHAPE_Q,
        DECODER_BLOCK_SHAPE_Q,
        decoder_max_partition_size,
        encoder_partition_size,
        2,
        True,
        False,
        0,
    )
    paddle.device.synchronize()
    results = []
    # Repeated identical decoder runs; outputs are compared bitwise by callers.
    for _ in range(num_decode_runs):
        # Fresh clones each iteration so every run sees identical inputs.
        cache_k_c = cache_k.clone()
        cache_v_c = cache_v.clone()
        qkv_c = dec_qkv.clone()
        get_block_shape_and_split_kv_block(
            seq_enc_d,
            seq_dec_d,
            seq_this_d,
            dec_batch_ids,
            dec_tile_ids,
            dec_nblocks_cpu,
            dec_nblocks_dev,
            dec_chunk_dev,
            max_len_cpu,
            enc_batch_ids,
            enc_tile_ids,
            enc_nblocks_cpu,
            kv_batch_ids,
            kv_tile_ids,
            kv_nblocks_cpu,
            ENCODER_BLOCK_SHAPE_Q,
            DECODER_BLOCK_SHAPE_Q,
            group_size,
            blocksize,
        )
        out = append_attention(
            qkv_c,
            cache_k_c,
            cache_v_c,
            seq_enc_d,
            seq_dec_d,
            seq_this_d,
            bid_dec,
            cu_dec,
            block_tables,
            enc_batch_ids,
            enc_tile_ids,
            enc_nblocks_cpu,
            kv_batch_ids,
            kv_tile_ids,
            kv_nblocks_cpu,
            dec_batch_ids,
            dec_tile_ids,
            dec_nblocks_cpu,
            max_len_cpu,
            rope_emb,
            # Optional tensor inputs not exercised by this test: all None.
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            1e-6,
            compute_type,
            "none",
            False,
            False,
            max_seq_len,
            0.0,
            0.0,
            -1,
            ENCODER_BLOCK_SHAPE_Q,
            DECODER_BLOCK_SHAPE_Q,
            decoder_max_partition_size,
            encoder_partition_size,
            2,
            True,
            False,
            0,
        )
        paddle.device.synchronize()
        results.append(out.numpy().copy())
    # Naive reference: apply RoPE to decoder Q/K at position prefill_seq_len
    # (cached K/V already have RoPE applied by the kernel during encoder phase)
    dq_rope = apply_rope(dq, rope_emb, [prefill_seq_len])
    dk_rope = apply_rope(dk, rope_emb, [prefill_seq_len])
    ref = naive_attention_impl(dq_rope, dk_rope, dv, naive_ck, naive_cv, scale)
    ref_np = ref.transpose([0, 2, 1, 3]).reshape([batch_size, q_num_head * dim_head]).numpy()
    return results, ref_np
class TestC16Warp14Determinism(unittest.TestCase):
    """
    Exercise the c16 warp1_4 decoder kernel under FD_DETERMINISTIC_MODE=1.

    Every test verifies:
      1. Correctness: output matches the naive attention reference
         (rtol/atol=1e-2).
      2. Determinism: repeated runs with identical input produce
         bitwise-identical output.
    """

    # Configuration shared by most tests; individual tests override fields.
    _BASE_CFG = dict(
        batch_size=1,
        q_num_head=16,
        kv_num_head=2,
        dim_head=128,
        blocksize=64,
        prefill_seq_len=256,
        max_dec_len=32,
        dtype="bfloat16",
        decoder_max_partition_size=64,
        num_decode_runs=10,
    )

    def _run(self, **overrides):
        """Run the decoder scenario with the base config plus overrides."""
        cfg = dict(self._BASE_CFG)
        cfg.update(overrides)
        return run_c16_warp14_decoder_test(**cfg)

    def test_short_kv_nosplit(self):
        """num_chunks=1 (short KV): basic nosplit path, partition_kv=false template."""
        results, ref = self._run(prefill_seq_len=64, decoder_max_partition_size=32768)
        _assert_deterministic_and_correct(results, ref)

    def test_long_kv_multi_chunk(self):
        """
        num_chunks=4 (prefill=256, partition=64): the exact scenario the fix addresses.
        partition_kv=true template but grid_chunks=1 (deterministic).
        """
        results, ref = self._run()
        _assert_deterministic_and_correct(results, ref)

    def test_multi_batch(self):
        """Multiple batches with multi-chunk decoder."""
        results, ref = self._run(batch_size=4, q_num_head=8)
        _assert_deterministic_and_correct(results, ref)

    def test_float16(self):
        """Float16 dtype with multi-chunk decoder."""
        results, ref = self._run(dtype="float16")
        _assert_deterministic_and_correct(results, ref)

    def test_unaligned_seq_len(self):
        """prefill_seq_len not divisible by blocksize (100 % 64 != 0)."""
        results, ref = self._run(prefill_seq_len=100)
        _assert_deterministic_and_correct(results, ref)

    def test_mha_no_gqa(self):
        """MHA: q_num_head == kv_num_head (no GQA grouping)."""
        results, ref = self._run(q_num_head=8, kv_num_head=8, prefill_seq_len=128)
        _assert_deterministic_and_correct(results, ref)

    def test_nosplit_vs_split_consistency(self):
        """
        Cross-path check: force_no_partition (deterministic) vs split (partitioned)
        should produce numerically close results.

        The two paths differ only in floating-point accumulation order
        (nosplit: single-chunk sequential; split: multi-chunk parallel then
        merge), so they should agree within bf16 rounding (rtol/atol=1e-2).

        Note: envs.FD_DETERMINISTIC_MODE is lazily evaluated via __getattr__,
        so runtime changes to os.environ["FD_DETERMINISTIC_MODE"] take effect.
        """
        # Deterministic mode is already on at module level.
        det_results, det_ref = self._run(num_decode_runs=1)
        # Flip to the non-deterministic split/partitioned path for one run.
        os.environ["FD_DETERMINISTIC_MODE"] = "0"
        try:
            split_results, split_ref = self._run(num_decode_runs=1)
        finally:
            os.environ["FD_DETERMINISTIC_MODE"] = "1"
        # Both paths must individually match the naive reference.
        np.testing.assert_allclose(
            det_results[0], det_ref, rtol=1e-2, atol=1e-2, err_msg="Deterministic path vs naive reference"
        )
        np.testing.assert_allclose(
            split_results[0], split_ref, rtol=1e-2, atol=1e-2, err_msg="Split path vs naive reference"
        )
        # And agree with each other.
        np.testing.assert_allclose(
            det_results[0],
            split_results[0],
            rtol=1e-2,
            atol=1e-2,
            err_msg="Deterministic (nosplit) vs split path divergence",
        )

    def test_partition_boundary(self):
        """
        Edge case: prefill_seq_len equals partition_size, so
        num_chunks = prefill_seq_len / partition_size is exactly an integer —
        a boundary condition for the chunk calculation.
        """
        results, ref = self._run(prefill_seq_len=128, decoder_max_partition_size=128)
        _assert_deterministic_and_correct(results, ref)

    def test_empty_kv(self):
        """
        Edge case: decoder-only mode (no prefill, empty KV cache).
        There is no encoder prefill phase at all.
        """
        results, _ = self._run(prefill_seq_len=0)
        # The naive reference is not meaningful for an empty KV cache, so
        # only determinism is verified here.
        baseline = results[0]
        for i, run in enumerate(results[1:], start=1):
            np.testing.assert_array_equal(
                baseline,
                run,
                err_msg=f"Determinism failure: run 0 vs run {i}",
            )
# Allow running this file directly (outside the pytest runner).
if __name__ == "__main__":
    unittest.main()