mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature][BugFix][OP] Enhance Deterministic Inference Mode with Kernel-level Fixes and Batch-invariant BMM (#6610)
* add fa deter * add ut * add long sentence * fix basic * fix bugs * fix adn * fix first * fix single * fix single * fix single test * refine * add more test * refine comments * add comments of bmm * fix ci * remove probe * add * remove not need * refine tests * fix comments and refine code * refine code * refine test * refine test * mv 4cards tests * fix tests * add * fix comments * fix cover * fix cover --------- Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
@@ -0,0 +1,799 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Test suite for the c16 warp1_4 decoder attention kernel determinism.
|
||||
|
||||
Background:
|
||||
The c16 warp1_4 decoder kernel (multiquery_attention_c16_impl.cuh) had two
|
||||
bugs that caused nondeterministic outputs under FD_DETERMINISTIC_MODE:
|
||||
1. The warp1_4 dispatcher (lines 1164-1175) lacked the force_no_partition
|
||||
check, so it still launched the multi-chunk (split) kernel even when
|
||||
deterministic mode requested the single-chunk (nosplit) path.
|
||||
2. The nosplit kernel template read runtime num_chunks_this_seq instead of
|
||||
compile-time partition_kv, causing out-of-bounds nullptr writes to the
|
||||
partial output buffer (lines 545, 748, 772, 812).
|
||||
|
||||
How the c16 warp1_4 path is triggered:
|
||||
- dim_head=128, blocksize=64, cache_quant_type="none" -> selects c16 kernel
|
||||
- decoder_block_shape_q=16 -> NUM_WARP_Q=1, i.e. warp1_4 configuration
|
||||
- seq_lens_encoder=0, seq_lens_decoder>0 -> decoder mode
|
||||
- FD_DETERMINISTIC_MODE=1 -> forces nosplit path
|
||||
- Small decoder_max_partition_size (e.g. 64) with long prefill (e.g. 256)
|
||||
ensures num_chunks > 1, which is the scenario that exposed the bugs.
|
||||
|
||||
Test items:
|
||||
1. test_short_kv_nosplit
|
||||
- Short KV (num_chunks=1): basic nosplit path with partition_kv=false.
|
||||
- Verifies both correctness (vs naive reference) and determinism (10 runs).
|
||||
|
||||
2. test_long_kv_multi_chunk
|
||||
- Long KV (num_chunks=4, prefill=256, partition=64): the exact scenario
|
||||
the fix addresses. partition_kv=true template but grid_chunks=1.
|
||||
- Verifies correctness and determinism.
|
||||
|
||||
3. test_multi_batch
|
||||
- Multiple batches (batch_size=4) with multi-chunk decoder.
|
||||
- Ensures the fix works across batch elements, not just single-batch.
|
||||
|
||||
4. test_float16
|
||||
- Float16 dtype with multi-chunk decoder.
|
||||
- Ensures the fix is dtype-agnostic (not only bfloat16).
|
||||
|
||||
5. test_unaligned_seq_len
|
||||
- prefill_seq_len not divisible by blocksize (100 % 64 != 0).
|
||||
- Catches off-by-one bugs in block/chunk boundary calculations.
|
||||
|
||||
6. test_mha_no_gqa
|
||||
- MHA config: q_num_head == kv_num_head (no GQA grouping).
|
||||
- Ensures the fix is not GQA-specific.
|
||||
|
||||
7. test_nosplit_vs_split_consistency
|
||||
- Cross-path check: deterministic nosplit vs non-deterministic split.
|
||||
- Both paths should produce numerically close results (rtol/atol=1e-2)
|
||||
and both should match the naive attention reference.
|
||||
|
||||
8. test_partition_boundary
|
||||
- Edge case: prefill_seq_len equals partition_size (boundary condition).
|
||||
- Tests chunk calculation when num_chunks is exactly an integer.
|
||||
|
||||
9. test_empty_kv
|
||||
- Edge case: decoder-only mode (no prefill, empty KV cache).
|
||||
- Tests scenario with no encoder prefill phase.
|
||||
|
||||
Run:
|
||||
python -m pytest tests/deterministic/test_c16_warp1_4_determinism.py -v
|
||||
"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
os.environ["FD_DETERMINISTIC_MODE"] = "1"
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.ops import ( # noqa: E402
|
||||
append_attention,
|
||||
get_block_shape_and_split_kv_block,
|
||||
)
|
||||
|
||||
# Fixed RNG seed so repeated runs see identical random inputs.
SEED = 42

# Query-dimension tile shapes handed to the attention kernels; the decoder
# value of 16 selects the NUM_WARP_Q=1 ("warp1_4") kernel configuration
# that this test suite targets (see module docstring).
ENCODER_BLOCK_SHAPE_Q = 64
DECODER_BLOCK_SHAPE_Q = 16
|
||||
|
||||
|
||||
def _assert_deterministic_and_correct(results, ref):
|
||||
"""
|
||||
Helper to verify determinism and correctness.
|
||||
|
||||
Args:
|
||||
results: List of output arrays from repeated runs
|
||||
ref: Reference output array from naive attention
|
||||
"""
|
||||
# Verify determinism: all runs should produce identical output
|
||||
for i in range(1, len(results)):
|
||||
np.testing.assert_array_equal(
|
||||
results[0],
|
||||
results[i],
|
||||
err_msg=f"Determinism failure: run 0 vs run {i}",
|
||||
)
|
||||
# Verify correctness: output should match naive reference
|
||||
np.testing.assert_allclose(results[0], ref, rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
def make_rope_emb(max_seq_len, dim_head, base=10000):
    """Build an interleaved rotary table of shape (2, 1, max_seq_len, 1, dim_head // 2).

    Index 0 holds cos(angle), index 1 holds sin(angle), one angle per
    (position, frequency) pair.
    """
    half = dim_head // 2
    positions = paddle.arange(max_seq_len).reshape((1, -1)).cast("float32")
    inv_freq = base ** (-paddle.arange(0, dim_head, 2, dtype="float32") / dim_head)
    angles = paddle.einsum("ij,k->ijk", positions, inv_freq)
    angles = angles.reshape((1, max_seq_len, half)).unsqueeze(2)
    table = paddle.zeros((2, 1, max_seq_len, 1, half), dtype="float32")
    table[0] = paddle.cos(angles)
    table[1] = paddle.sin(angles)
    return table
|
||||
|
||||
|
||||
def apply_rope(x, rope_emb, positions):
    """
    Apply rotary position embedding (non-neox interleaved style).

    Args:
        x: (batch, heads, seq_len, dim_head) tensor.
        rope_emb: (2, 1, max_seq_len, 1, dim_head//2) cos/sin table.
        positions: list of int, one absolute position per seq index.

    Returns:
        Rotated tensor with the same shape and dtype as ``x``.
    """
    head_dim = x.shape[-1]
    half_dim = head_dim // 2
    x32 = x.cast("float32")
    rotated = x32.clone()

    for idx, position in enumerate(positions):
        cos_vec = rope_emb[0, 0, position, 0, :]  # (half_dim,)
        sin_vec = rope_emb[1, 0, position, 0, :]

        token = x32[:, :, idx, :]  # (batch, heads, head_dim)
        # Group consecutive element pairs: (..., half_dim, 2).
        pairs = token.reshape(list(token.shape[:-1]) + [half_dim, 2])
        even = pairs[..., 0]  # elements 0, 2, 4, ...
        odd = pairs[..., 1]   # elements 1, 3, 5, ...

        # Standard 2D rotation of each (even, odd) pair by the position angle.
        rot_even = even * cos_vec - odd * sin_vec
        rot_odd = even * sin_vec + odd * cos_vec

        rotated[:, :, idx, :] = paddle.stack([rot_even, rot_odd], axis=-1).reshape(token.shape)

    return rotated.cast(x.dtype)
|
||||
|
||||
|
||||
def get_padding_offset(bsz, max_seq_len, seq_lens_this_time):
    """
    Build the token->batch mapping and cumulative query offsets for padded layout.

    Args:
        bsz: Batch size.
        max_seq_len: Padded sequence length per batch element.
        seq_lens_this_time: (bsz,) int32 tensor of actual lengths.

    Returns:
        (batch_id_per_token, cu_seqlens_q) int32 tensors.
    """
    padding_cumsum = paddle.cumsum(max_seq_len - seq_lens_this_time, dtype="int32")
    cum_offsets = paddle.zeros(shape=(bsz + 1,), dtype="int32")
    cum_offsets[1:] = padding_cumsum
    total_tokens = int(paddle.sum(seq_lens_this_time))
    batch_id_per_token = paddle.zeros(shape=(total_tokens,), dtype="int32")
    cu_seqlens_q = paddle.zeros(shape=(bsz + 1,), dtype="int32")
    for b in range(bsz):
        length = int(seq_lens_this_time[b])
        offset = int(cum_offsets[b])
        # Map each real (unpadded) token of batch b back to its batch id.
        for t in range(length):
            batch_id_per_token[b * max_seq_len - offset + t] = b
        cu_seqlens_q[b + 1] = (b + 1) * max_seq_len - int(cum_offsets[b + 1])
    return batch_id_per_token, cu_seqlens_q
|
||||
|
||||
|
||||
def naive_attention_impl(query, key, value, cache_k, cache_v, scale):
    """Reference attention: softmax(Q @ K^T * scale) @ V, with GQA head expansion."""
    batch, q_heads, seq_len, head_dim = query.shape
    kv_heads = key.shape[1]
    group = q_heads // kv_heads

    def _expand_heads(t, length):
        # Repeat each KV head `group` times so it lines up with the Q heads.
        t = t.reshape([batch, kv_heads, 1, length, head_dim])
        return paddle.tile(t, [1, 1, group, 1, 1]).reshape([batch, q_heads, length, head_dim])

    key = _expand_heads(key, seq_len)
    value = _expand_heads(value, seq_len)

    # Prepend cached (historical) K/V if present; -1 lets reshape infer length.
    if cache_k is not None:
        key = paddle.concat([_expand_heads(cache_k, -1), key], axis=2)
    if cache_v is not None:
        value = paddle.concat([_expand_heads(cache_v, -1), value], axis=2)

    scores = paddle.matmul(query, key, transpose_y=True) * scale
    probs = paddle.nn.functional.softmax(scores, -1)
    return paddle.matmul(probs.cast(value.dtype), value)
|
||||
|
||||
|
||||
def block_cache_to_naive(cache_k, cache_v, bsz, block_tables, seq_len):
    """Gather the paged (blocked) KV cache into dense (bsz, heads, seq_len, dim) tensors."""
    _, num_head, blocksize, dim_head = cache_k.shape
    dense_k = paddle.zeros([bsz, num_head, seq_len, dim_head], dtype=cache_k.dtype)
    dense_v = paddle.zeros([bsz, num_head, seq_len, dim_head], dtype=cache_v.dtype)
    for b in range(bsz):
        for pos in range(seq_len):
            # Resolve logical position -> (physical block, slot within block).
            block_id = block_tables[b, pos // blocksize]
            slot = pos % blocksize
            dense_k[b, :, pos, :] = cache_k[block_id, :, slot, :]
            dense_v[b, :, pos, :] = cache_v[block_id, :, slot, :]
    return dense_k, dense_v
|
||||
|
||||
|
||||
def run_c16_warp14_decoder_test(
    batch_size,
    q_num_head,
    kv_num_head,
    dim_head,
    blocksize,
    prefill_seq_len,
    max_dec_len,
    dtype,
    decoder_max_partition_size,
    num_decode_runs,
):
    """
    Run one encoder (prefill) pass followed by ``num_decode_runs`` identical
    decoder passes, and build a naive-attention reference for the decode step.

    Args:
        batch_size: Number of sequences in the batch.
        q_num_head: Number of query heads.
        kv_num_head: Number of KV heads (GQA when < q_num_head).
        dim_head: Head dimension.
        blocksize: Paged-KV block size.
        prefill_seq_len: Encoder (prefill) length; 0 means decoder-only.
        max_dec_len: Reserved decode length (sizes buffers/block tables).
        dtype: "bfloat16" or "float16".
        decoder_max_partition_size: KV partition size for the decoder path;
            small values vs. prefill length force num_chunks > 1.
        num_decode_runs: How many times to repeat the identical decode.

    Returns:
        (list of decoder output numpy arrays, naive reference numpy array).
    """
    np.random.seed(SEED)
    paddle.seed(SEED)

    max_seq_len = prefill_seq_len + max_dec_len
    block_per_seq = (max_seq_len + blocksize - 1) // blocksize
    max_block_num = block_per_seq * batch_size
    scale = 1.0 / np.sqrt(dim_head)
    group_size = q_num_head // kv_num_head
    compute_type = "bf16" if dtype == "bfloat16" else "fp16"

    rope_emb = make_rope_emb(max_seq_len, dim_head)

    # Block tables: hand out physical blocks in ascending order per sequence.
    free_list = list(range(max_block_num - 1, -1, -1))
    block_tables = paddle.zeros((batch_size, block_per_seq), dtype="int32")
    for i in range(batch_size):
        for j in range(block_per_seq):
            block_tables[i, j] = free_list.pop()

    cache_k = paddle.zeros((max_block_num, kv_num_head, blocksize, dim_head), dtype=dtype)
    cache_v = paddle.zeros((max_block_num, kv_num_head, blocksize, dim_head), dtype=dtype)

    # Tile metadata buffers, sized per allocate_launch_related_buffer() formula
    gqa_ratio = q_num_head // kv_num_head
    decode_tile_size = int(1024 * batch_size * np.ceil((2 * gqa_ratio) / DECODER_BLOCK_SHAPE_Q))
    encode_tile_size = max(batch_size, batch_size * (max_seq_len * gqa_ratio // ENCODER_BLOCK_SHAPE_Q))
    kv_tile_size = max(batch_size, batch_size * (max_seq_len // blocksize))

    dec_batch_ids = paddle.full([decode_tile_size], 0, dtype="int32")
    dec_tile_ids = paddle.full([decode_tile_size], 0, dtype="int32")
    dec_nblocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
    dec_nblocks_dev = paddle.full([1], 0, dtype="int32")
    dec_chunk_dev = paddle.full([1], decoder_max_partition_size, dtype="int32")
    max_len_cpu = paddle.full([8], 0, dtype="int32").cpu()
    enc_batch_ids = paddle.full([encode_tile_size], 0, dtype="int32")
    enc_tile_ids = paddle.full([encode_tile_size], 0, dtype="int32")
    enc_nblocks_cpu = paddle.full([1], 0, dtype="int32").cpu()
    kv_batch_ids = paddle.full([kv_tile_size], 0, dtype="int32")
    kv_tile_ids = paddle.full([kv_tile_size], 0, dtype="int32")
    kv_nblocks_cpu = paddle.full([1], 0, dtype="int32").cpu()

    # ===== Encoder phase =====
    seq_enc = paddle.full([batch_size], prefill_seq_len, dtype="int32")
    seq_dec = paddle.full([batch_size], 0, dtype="int32")
    seq_this = copy.deepcopy(seq_enc)
    bid_enc, cu_enc = get_padding_offset(batch_size, prefill_seq_len, seq_this)
    token_num = batch_size * prefill_seq_len

    q_np = np.random.random([batch_size, q_num_head, prefill_seq_len, dim_head]).astype("float32") / 10
    k_np = np.random.random([batch_size, kv_num_head, prefill_seq_len, dim_head]).astype("float32") / 10
    v_np = np.random.random([batch_size, kv_num_head, prefill_seq_len, dim_head]).astype("float32") / 10

    q = paddle.to_tensor(q_np, dtype=dtype)
    k = paddle.to_tensor(k_np, dtype=dtype)
    v = paddle.to_tensor(v_np, dtype=dtype)
    # Fused QKV layout: (token_num, (q_heads + 2*kv_heads) * dim_head).
    qkv = paddle.concat(
        [
            q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * dim_head]),
            k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
            v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
        ],
        axis=1,
    )

    # Use large partition size for encoder to avoid issues in prefill
    encoder_partition_size = 32768

    get_block_shape_and_split_kv_block(
        seq_enc,
        seq_dec,
        seq_this,
        dec_batch_ids,
        dec_tile_ids,
        dec_nblocks_cpu,
        dec_nblocks_dev,
        dec_chunk_dev,
        max_len_cpu,
        enc_batch_ids,
        enc_tile_ids,
        enc_nblocks_cpu,
        kv_batch_ids,
        kv_tile_ids,
        kv_nblocks_cpu,
        ENCODER_BLOCK_SHAPE_Q,
        DECODER_BLOCK_SHAPE_Q,
        group_size,
        blocksize,
    )

    append_attention(
        qkv,
        cache_k,
        cache_v,
        seq_enc,
        seq_dec,
        seq_this,
        bid_enc,
        cu_enc,
        block_tables,
        enc_batch_ids,
        enc_tile_ids,
        enc_nblocks_cpu,
        kv_batch_ids,
        kv_tile_ids,
        kv_nblocks_cpu,
        dec_batch_ids,
        dec_tile_ids,
        dec_nblocks_cpu,
        max_len_cpu,
        rope_emb,
        # NOTE(review): the following 16 None args are optional inputs
        # (e.g. masks / quantization scales) unused by this test — confirm
        # against the append_attention op signature.
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        1e-6,
        compute_type,
        "none",  # cache_quant_type "none" -> selects the c16 kernel
        False,
        False,
        max_seq_len,
        0.0,
        0.0,
        -1,
        ENCODER_BLOCK_SHAPE_Q,
        DECODER_BLOCK_SHAPE_Q,
        encoder_partition_size,
        encoder_partition_size,
        2,
        True,
        False,
        0,
    )
    paddle.device.synchronize()

    # Extract naive KV cache for reference (already has RoPE applied by kernel)
    naive_ck, naive_cv = block_cache_to_naive(
        cache_k,
        cache_v,
        batch_size,
        block_tables,
        prefill_seq_len,
    )

    # ===== Decoder phase =====
    # seq_lens_encoder=0 and seq_lens_decoder>0 puts the kernel in decoder mode.
    seq_enc_d = paddle.full([batch_size], 0, dtype="int32")
    seq_dec_d = paddle.full([batch_size], prefill_seq_len, dtype="int32")
    seq_this_d = paddle.full([batch_size], 1, dtype="int32")
    bid_dec, cu_dec = get_padding_offset(batch_size, 1, seq_this_d)

    dq_np = np.random.random([batch_size, q_num_head, 1, dim_head]).astype("float32") / 10
    dk_np = np.random.random([batch_size, kv_num_head, 1, dim_head]).astype("float32") / 10
    dv_np = np.random.random([batch_size, kv_num_head, 1, dim_head]).astype("float32") / 10
    dq = paddle.to_tensor(dq_np, dtype=dtype)
    dk = paddle.to_tensor(dk_np, dtype=dtype)
    dv = paddle.to_tensor(dv_np, dtype=dtype)
    dec_qkv = paddle.concat(
        [
            dq.transpose([0, 2, 1, 3]).reshape([batch_size, q_num_head * dim_head]),
            dk.transpose([0, 2, 1, 3]).reshape([batch_size, kv_num_head * dim_head]),
            dv.transpose([0, 2, 1, 3]).reshape([batch_size, kv_num_head * dim_head]),
        ],
        axis=1,
    )

    # Warmup: first decoder call on multi-chunk path may return zeros
    # due to kernel JIT compilation. Run once and discard.
    get_block_shape_and_split_kv_block(
        seq_enc_d,
        seq_dec_d,
        seq_this_d,
        dec_batch_ids,
        dec_tile_ids,
        dec_nblocks_cpu,
        dec_nblocks_dev,
        dec_chunk_dev,
        max_len_cpu,
        enc_batch_ids,
        enc_tile_ids,
        enc_nblocks_cpu,
        kv_batch_ids,
        kv_tile_ids,
        kv_nblocks_cpu,
        ENCODER_BLOCK_SHAPE_Q,
        DECODER_BLOCK_SHAPE_Q,
        group_size,
        blocksize,
    )
    append_attention(
        dec_qkv.clone(),
        cache_k.clone(),
        cache_v.clone(),
        seq_enc_d,
        seq_dec_d,
        seq_this_d,
        bid_dec,
        cu_dec,
        block_tables,
        enc_batch_ids,
        enc_tile_ids,
        enc_nblocks_cpu,
        kv_batch_ids,
        kv_tile_ids,
        kv_nblocks_cpu,
        dec_batch_ids,
        dec_tile_ids,
        dec_nblocks_cpu,
        max_len_cpu,
        rope_emb,
        # 16 unused optional inputs (see note on the encoder call above).
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        1e-6,
        compute_type,
        "none",
        False,
        False,
        max_seq_len,
        0.0,
        0.0,
        -1,
        ENCODER_BLOCK_SHAPE_Q,
        DECODER_BLOCK_SHAPE_Q,
        decoder_max_partition_size,
        encoder_partition_size,
        2,
        True,
        False,
        0,
    )
    paddle.device.synchronize()

    results = []
    for _ in range(num_decode_runs):
        # Clone inputs each iteration so every run sees identical state.
        cache_k_c = cache_k.clone()
        cache_v_c = cache_v.clone()
        qkv_c = dec_qkv.clone()

        get_block_shape_and_split_kv_block(
            seq_enc_d,
            seq_dec_d,
            seq_this_d,
            dec_batch_ids,
            dec_tile_ids,
            dec_nblocks_cpu,
            dec_nblocks_dev,
            dec_chunk_dev,
            max_len_cpu,
            enc_batch_ids,
            enc_tile_ids,
            enc_nblocks_cpu,
            kv_batch_ids,
            kv_tile_ids,
            kv_nblocks_cpu,
            ENCODER_BLOCK_SHAPE_Q,
            DECODER_BLOCK_SHAPE_Q,
            group_size,
            blocksize,
        )

        out = append_attention(
            qkv_c,
            cache_k_c,
            cache_v_c,
            seq_enc_d,
            seq_dec_d,
            seq_this_d,
            bid_dec,
            cu_dec,
            block_tables,
            enc_batch_ids,
            enc_tile_ids,
            enc_nblocks_cpu,
            kv_batch_ids,
            kv_tile_ids,
            kv_nblocks_cpu,
            dec_batch_ids,
            dec_tile_ids,
            dec_nblocks_cpu,
            max_len_cpu,
            rope_emb,
            # 16 unused optional inputs (see note on the encoder call above).
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            1e-6,
            compute_type,
            "none",
            False,
            False,
            max_seq_len,
            0.0,
            0.0,
            -1,
            ENCODER_BLOCK_SHAPE_Q,
            DECODER_BLOCK_SHAPE_Q,
            decoder_max_partition_size,
            encoder_partition_size,
            2,
            True,
            False,
            0,
        )
        paddle.device.synchronize()
        results.append(out.numpy().copy())

    # Naive reference: apply RoPE to decoder Q/K at position prefill_seq_len
    # (cached K/V already have RoPE applied by the kernel during encoder phase)
    dq_rope = apply_rope(dq, rope_emb, [prefill_seq_len])
    dk_rope = apply_rope(dk, rope_emb, [prefill_seq_len])
    ref = naive_attention_impl(dq_rope, dk_rope, dv, naive_ck, naive_cv, scale)
    ref_np = ref.transpose([0, 2, 1, 3]).reshape([batch_size, q_num_head * dim_head]).numpy()

    return results, ref_np
|
||||
|
||||
|
||||
class TestC16Warp14Determinism(unittest.TestCase):
    """
    Test the c16 warp1_4 decoder kernel under FD_DETERMINISTIC_MODE=1.

    Verifies:
    1. Correctness: output matches naive attention reference (rtol/atol=1e-2)
    2. Determinism: repeated runs with identical input -> bitwise-identical output
    """

    def _run_and_check(self, **config):
        # Shared driver: run the kernel scenario, then assert determinism and
        # agreement with the naive attention reference.
        results, ref = run_c16_warp14_decoder_test(**config)
        _assert_deterministic_and_correct(results, ref)

    def test_short_kv_nosplit(self):
        """num_chunks=1 (short KV): basic nosplit path, partition_kv=false template."""
        self._run_and_check(
            batch_size=1,
            q_num_head=16,
            kv_num_head=2,
            dim_head=128,
            blocksize=64,
            prefill_seq_len=64,
            max_dec_len=32,
            dtype="bfloat16",
            decoder_max_partition_size=32768,
            num_decode_runs=10,
        )

    def test_long_kv_multi_chunk(self):
        """
        num_chunks=4 (prefill=256, partition=64): the exact scenario the fix addresses.
        partition_kv=true template but grid_chunks=1 (deterministic).
        """
        self._run_and_check(
            batch_size=1,
            q_num_head=16,
            kv_num_head=2,
            dim_head=128,
            blocksize=64,
            prefill_seq_len=256,
            max_dec_len=32,
            dtype="bfloat16",
            decoder_max_partition_size=64,
            num_decode_runs=10,
        )

    def test_multi_batch(self):
        """Multiple batches with multi-chunk decoder."""
        self._run_and_check(
            batch_size=4,
            q_num_head=8,
            kv_num_head=2,
            dim_head=128,
            blocksize=64,
            prefill_seq_len=256,
            max_dec_len=32,
            dtype="bfloat16",
            decoder_max_partition_size=64,
            num_decode_runs=10,
        )

    def test_float16(self):
        """Float16 dtype with multi-chunk decoder (fix must be dtype-agnostic)."""
        self._run_and_check(
            batch_size=1,
            q_num_head=16,
            kv_num_head=2,
            dim_head=128,
            blocksize=64,
            prefill_seq_len=256,
            max_dec_len=32,
            dtype="float16",
            decoder_max_partition_size=64,
            num_decode_runs=10,
        )

    def test_unaligned_seq_len(self):
        """prefill_seq_len not divisible by blocksize (100 % 64 != 0)."""
        self._run_and_check(
            batch_size=1,
            q_num_head=16,
            kv_num_head=2,
            dim_head=128,
            blocksize=64,
            prefill_seq_len=100,
            max_dec_len=32,
            dtype="bfloat16",
            decoder_max_partition_size=64,
            num_decode_runs=10,
        )

    def test_mha_no_gqa(self):
        """MHA: q_num_head == kv_num_head (no GQA grouping)."""
        self._run_and_check(
            batch_size=1,
            q_num_head=8,
            kv_num_head=8,
            dim_head=128,
            blocksize=64,
            prefill_seq_len=128,
            max_dec_len=32,
            dtype="bfloat16",
            decoder_max_partition_size=64,
            num_decode_runs=10,
        )

    def test_nosplit_vs_split_consistency(self):
        """
        Cross-path check: force_no_partition (deterministic) vs split (partitioned)
        should produce numerically close results.

        The two paths differ only in floating-point accumulation order:
        - nosplit: single chunk, sequential accumulation
        - split: multi-chunk parallel, then merge
        Difference should be within low-precision rounding (rtol/atol=1e-2 for bf16).

        Note: envs.FD_DETERMINISTIC_MODE is lazily evaluated via __getattr__,
        so runtime changes to os.environ["FD_DETERMINISTIC_MODE"] take effect.
        """
        scenario = dict(
            batch_size=1,
            q_num_head=16,
            kv_num_head=2,
            dim_head=128,
            blocksize=64,
            prefill_seq_len=256,
            max_dec_len=32,
            dtype="bfloat16",
            decoder_max_partition_size=64,
            num_decode_runs=1,
        )

        # Run once in deterministic mode (already set at module level)
        det_results, det_ref = run_c16_warp14_decoder_test(**scenario)

        # Run once in non-deterministic mode (split/partitioned path)
        os.environ["FD_DETERMINISTIC_MODE"] = "0"
        try:
            split_results, split_ref = run_c16_warp14_decoder_test(**scenario)
        finally:
            os.environ["FD_DETERMINISTIC_MODE"] = "1"

        # Both paths should match the naive reference
        np.testing.assert_allclose(
            det_results[0], det_ref, rtol=1e-2, atol=1e-2, err_msg="Deterministic path vs naive reference"
        )
        np.testing.assert_allclose(
            split_results[0], split_ref, rtol=1e-2, atol=1e-2, err_msg="Split path vs naive reference"
        )

        # Cross-path: nosplit vs split should be close
        np.testing.assert_allclose(
            det_results[0],
            split_results[0],
            rtol=1e-2,
            atol=1e-2,
            err_msg="Deterministic (nosplit) vs split path divergence",
        )

    def test_partition_boundary(self):
        """
        Edge case: prefill_seq_len equals partition_size (boundary condition).

        This tests the scenario where num_chunks = prefill_seq_len / partition_size
        is exactly an integer, which is a boundary condition for chunk calculation.
        """
        self._run_and_check(
            batch_size=1,
            q_num_head=16,
            kv_num_head=2,
            dim_head=128,
            blocksize=64,
            prefill_seq_len=128,
            max_dec_len=32,
            dtype="bfloat16",
            decoder_max_partition_size=128,
            num_decode_runs=10,
        )

    def test_empty_kv(self):
        """
        Edge case: decoder-only mode (no prefill, empty KV cache).

        This tests the scenario where there's no encoder prefill phase,
        only decoder with empty KV cache.
        """
        results, _ = run_c16_warp14_decoder_test(
            batch_size=1,
            q_num_head=16,
            kv_num_head=2,
            dim_head=128,
            blocksize=64,
            prefill_seq_len=0,
            max_dec_len=32,
            dtype="bfloat16",
            decoder_max_partition_size=64,
            num_decode_runs=10,
        )
        # For empty KV, we only check determinism as naive reference may not be valid
        baseline = results[0]
        for run_idx in range(1, len(results)):
            np.testing.assert_array_equal(
                baseline,
                results[run_idx],
                err_msg=f"Determinism failure: run 0 vs run {run_idx}",
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,360 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Determinism offline inference tests using LLM.generate
|
||||
|
||||
Test scenarios:
|
||||
1. Same-prompt repeatability (FD_DETERMINISTIC_MODE=1)
|
||||
2. Batch invariance (single vs. batch, different positions)
|
||||
3. Different batch sizes consistency
|
||||
4. Sampling-parameter combinations (temperature x top_p, parametrized)
|
||||
5. Long sequence generation (512-1024 tokens)
|
||||
6. Long input prompt handling
|
||||
7. Minimal output (max_tokens=1, early stop)
|
||||
8. Special characters & multi-language prompts
|
||||
9. Multi-turn conversation
|
||||
10. State isolation (interleaved / interference prompts)
|
||||
11. Non-deterministic validation (proves tests are effective)
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 pytest tests/deterministic/test_determinism_offline.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
# Every test in this module requires a GPU.
pytestmark = pytest.mark.gpu

# Model location: $MODEL_PATH (fallback ./models) joined with the model name.
DEFAULT_MODEL_DIR = "./models"
MODEL_NAME = "Qwen2-7B-Instruct"

# Environment variable names saved/restored by the fixtures in this module.
_ENV_CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
_ENV_FD_DETERMINISTIC_MODE = "FD_DETERMINISTIC_MODE"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
def _module_env():
    """Set env vars before importing fastdeploy (must happen first)."""
    # Remember the original values so teardown can restore them exactly.
    saved = {
        _ENV_CUDA_VISIBLE_DEVICES: os.environ.get(_ENV_CUDA_VISIBLE_DEVICES),
        _ENV_FD_DETERMINISTIC_MODE: os.environ.get(_ENV_FD_DETERMINISTIC_MODE),
    }

    os.environ.setdefault(_ENV_CUDA_VISIBLE_DEVICES, "0")
    os.environ[_ENV_FD_DETERMINISTIC_MODE] = "1"

    # Import only after the env vars are in place; bind at module scope so
    # tests can reference LLM / SamplingParams directly.
    global LLM, SamplingParams  # noqa: PLW0603
    from fastdeploy import LLM, SamplingParams

    yield

    for key, previous in saved.items():
        if previous is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = previous
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_deterministic_mode():
    """Ensure every test starts (and leaves) with deterministic mode ON."""
    flag_on = "1"
    os.environ[_ENV_FD_DETERMINISTIC_MODE] = flag_on
    yield
    # Re-assert after the test in case it toggled the flag.
    os.environ[_ENV_FD_DETERMINISTIC_MODE] = flag_on
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def model_path():
    """Resolve the model directory from $MODEL_PATH (or the default) + model name."""
    base_dir = os.getenv("MODEL_PATH", DEFAULT_MODEL_DIR)
    return os.path.join(base_dir, MODEL_NAME)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def llm(model_path, _module_env):
    """Module-scoped LLM instance; prefix caching disabled so runs are comparable."""
    engine_kwargs = dict(
        model=model_path,
        tensor_parallel_size=1,
        max_model_len=8192,
        enable_prefix_caching=False,
    )
    return LLM(**engine_kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _generate_text(llm, prompt, sp):
|
||||
"""Generate once, return (text, token_ids)."""
|
||||
out = llm.generate([prompt], sp)[0]
|
||||
return out.outputs.text, out.outputs.token_ids
|
||||
|
||||
|
||||
def _assert_deterministic(llm, prompt, sp, runs=2):
    """Run *runs* times and assert all outputs are identical."""
    outputs = [_generate_text(llm, prompt, sp) for _ in range(runs)]
    first_text, first_ids = outputs[0]
    assert all(text == first_text for text, _ in outputs), "Text outputs differ across runs"
    assert all(ids == first_ids for _, ids in outputs), "Token IDs differ across runs"
    return first_text, first_ids
|
||||
|
||||
|
||||
# ===================== Core determinism tests =====================
|
||||
|
||||
|
||||
def test_deterministic_same_prompt(llm):
    """Same prompt + same seed produces identical output across 5 runs."""
    prompt = "Please introduce artificial intelligence in one sentence."
    params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=50, seed=123)
    _assert_deterministic(llm, prompt, params, runs=5)
|
||||
|
||||
|
||||
def test_deterministic_batch_invariance(llm):
    """Target prompt produces identical output regardless of batch position."""
    target = "What kind of programming language is Python?"
    params = SamplingParams(temperature=0.5, max_tokens=40, seed=456)

    # Single-request reference output.
    baseline, _ = _generate_text(llm, target, params)

    # The target prompt placed at the front, middle, and tail of batches of
    # various sizes, surrounded by filler requests.
    batches = [
        [target, "Filler question 1"],
        ["Filler question 2", target, "Filler question 3"],
        ["Filler question 4", "Filler question 5", target],
        ["Filler 6", "Filler 7", "Filler 8", target],
    ]

    for cfg, batch in enumerate(batches):
        generated = llm.generate(batch, params)
        pos = batch.index(target)
        text_at_pos = generated[pos].outputs.text
        assert text_at_pos == baseline, f"Batch config {cfg} (pos {pos}): result differs from single-request baseline"
|
||||
|
||||
|
||||
def test_deterministic_different_batch_sizes(llm):
    """Same prompt is consistent across batch sizes 1 / 2 / 4 / 8."""
    question = "What is machine learning?"
    params = SamplingParams(temperature=0.5, max_tokens=30, seed=789)

    # Batch-size-1 reference.
    reference, _ = _generate_text(llm, question, params)

    # Replicate the prompt to fill larger batches; slot 0 must match bs=1.
    for batch_size in (2, 4, 8):
        batch_outputs = llm.generate([question] * batch_size, params)
        assert batch_outputs[0].outputs.text == reference, f"Batch size {batch_size} differs from bs=1"
|
||||
|
||||
|
||||
# ===================== Sampling-parameter combinations =====================
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "temp,top_p,seed",
    [
        (0.0, 1.0, 300),  # greedy, no top_p filter
        (0.0, 0.0, 301),  # double-greedy
        (0.3, 0.9, 302),  # low temp, moderate top_p
        (0.8, 0.0, 303),  # medium temp, greedy top_p
        (0.8, 1.0, 304),  # medium temp, no top_p filter
        (0.8, 0.5, 305),  # medium temp, strict top_p
        (1.0, 0.95, 306),  # high temp
        (1.5, 0.9, 307),  # very high temp
    ],
)
def test_deterministic_param_combos(llm, temp, top_p, seed):
    """Determinism holds across various (temperature, top_p) combinations."""
    # Each combination gets its own seed so the cases are independent.
    sp = SamplingParams(temperature=temp, top_p=top_p, max_tokens=30, seed=seed)
    _assert_deterministic(llm, "What is a neural network?", sp)
|
||||
|
||||
|
||||
# ===================== Long sequence tests =====================
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "temp,seed",
    [
        (0.0, 100),
        (0.3, 130),
        (0.5, 150),
        (0.7, 170),
    ],
)
@pytest.mark.skip(reason="Potential non-determinism in long sequences, will be fixed by gongweibao in next PR")
def test_deterministic_long_sequence(llm, temp, seed):
    """Long generation (512+ tokens) stays deterministic at various temperatures."""
    prompt = "Please describe the history of AI in detail, including major milestones and key technical breakthroughs."
    sp = SamplingParams(temperature=temp, top_p=0.95, max_tokens=512, seed=seed)

    text, token_ids = _assert_deterministic(llm, prompt, sp)
    # Sanity check that the model actually produced a long generation.
    assert len(token_ids) >= 100, f"Expected >= 100 tokens, got {len(token_ids)}"
|
||||
|
||||
|
||||
def test_deterministic_long_prompt(llm):
    """Long input prompt (prefill-heavy) stays deterministic."""
    sentence = "This is a description about natural language processing. "
    # 50 repetitions makes the prefill phase dominate the request.
    prefill_heavy_prompt = sentence * 50 + "Please summarize the above."
    params = SamplingParams(temperature=0.5, max_tokens=100, seed=2024)
    _assert_deterministic(llm, prefill_heavy_prompt, params)
|
||||
|
||||
|
||||
# ===================== Minimal / boundary output tests =====================
|
||||
|
||||
|
||||
def test_deterministic_max_tokens_one(llm):
    """Single-token output is deterministic."""
    params = SamplingParams(temperature=0.1, max_tokens=1, seed=700)

    _, ids = _assert_deterministic(llm, "What color is the sky?", params)
    assert len(ids) == 1, f"Expected 1 token, got {len(ids)}"
|
||||
|
||||
|
||||
def test_deterministic_early_stop(llm):
    """Early stopping via stop sequences is deterministic."""
    params = SamplingParams(temperature=0.7, max_tokens=100, stop=["\u3002", "."], seed=800)

    _, ids = _assert_deterministic(llm, "Please list three colors:", params)
    # The stop strings should trigger before max_tokens is reached.
    assert len(ids) < 100, f"Expected early stop, got {len(ids)} tokens"
|
||||
|
||||
|
||||
# ===================== Special input tests =====================
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "prompt,seed",
    [
        ("What is AI? \U0001f52c\U0001f9e0", 900),  # emoji
        ("Math: E = mc\u00b2", 901),  # superscript
        ("Code: def hello(): return 'world'", 902),  # code
        ("Symbols: @#$%^&*()", 903),  # special symbols
    ],
)
def test_deterministic_special_chars(llm, prompt, seed):
    """Prompts containing emoji / superscripts / code / symbols stay deterministic."""
    sp = SamplingParams(temperature=0.5, max_tokens=30, seed=seed)
    _assert_deterministic(llm, prompt, sp)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "lang,prompt,seed",
    [
        ("Chinese", "Please introduce artificial intelligence in one sentence.", 1000),
        ("English", "What is artificial intelligence in one sentence?", 1001),
        (
            "Japanese",
            "\u4eba\u5de5\u77e5\u80fd\u306b\u3064\u3044\u3066\u4e00\u8a00\u3067\u8aac\u660e\u3057\u3066\u304f\u3060\u3055\u3044\u3002",
            1002,
        ),
        ("Spanish", "\u00bfQu\u00e9 es la inteligencia artificial en una frase?", 1003),
    ],
)
def test_deterministic_multi_language(llm, lang, prompt, seed):
    """Prompts in several languages stay deterministic (``lang`` is a label only)."""
    sp = SamplingParams(temperature=0.5, max_tokens=30, seed=seed)
    _assert_deterministic(llm, prompt, sp)
|
||||
|
||||
|
||||
# ===================== Multi-turn conversation test =====================
|
||||
|
||||
|
||||
def test_deterministic_multi_turn(llm):
    """Multi-turn chat maintains determinism."""
    sp = SamplingParams(temperature=0.5, max_tokens=50, seed=1100)

    base_messages = [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi! How can I help you?"},
        {"role": "user", "content": "Please introduce yourself."},
    ]

    def run_conversation():
        # Turn 1, then feed the reply back for turn 2.
        turn1 = llm.chat(base_messages, sp)[0].outputs.text
        followup = base_messages + [
            {"role": "assistant", "content": turn1},
            {"role": "user", "content": "What can you do?"},
        ]
        turn2 = llm.chat(followup, sp)[0].outputs.text
        return turn1, turn2

    # Run the full two-turn conversation twice with the same seed.
    first = run_conversation()
    second = run_conversation()

    assert first[0] == second[0], "Multi-turn: turn-1 outputs differ"
    assert first[1] == second[1], "Multi-turn: turn-2 outputs differ"
|
||||
|
||||
|
||||
# ===================== State isolation test =====================
|
||||
|
||||
|
||||
def test_deterministic_state_isolation(llm):
    """Interference prompts and interleaving do not break determinism."""
    first_prompt = "What is Python?"
    second_prompt = "What is JavaScript?"
    params_a = SamplingParams(temperature=0.5, max_tokens=30, seed=1200)
    params_b = SamplingParams(temperature=0.5, max_tokens=30, seed=1201)

    # Round 1: record reference outputs for both prompts.
    ref_a, _ = _generate_text(llm, first_prompt, params_a)
    ref_b, _ = _generate_text(llm, second_prompt, params_b)

    # Run unrelated interference requests between the two rounds.
    for noise in ["Explain reinforcement learning.", "What is NLP?", "List 3 fruits."]:
        llm.generate([noise], SamplingParams(temperature=0.7, max_tokens=20, seed=999))

    # Round 2: outputs must be unchanged by the interference.
    again_a, _ = _generate_text(llm, first_prompt, params_a)
    again_b, _ = _generate_text(llm, second_prompt, params_b)

    assert ref_a == again_a, "Prompt A: output changed after interference"
    assert ref_b == again_b, "Prompt B: output changed after interference"
|
||||
|
||||
|
||||
# ===================== Non-deterministic validation =====================
|
||||
|
||||
|
||||
def test_non_deterministic_validation(llm):
    """
    Prove that tests are effective:
    - Without seed + without mode: outputs vary
    - With explicit seed: outputs are consistent
    """
    prompt = "Please explain deep learning in one sentence."

    # Part 1: no mode, no seed -> outputs should differ.
    # The autouse fixture re-enables deterministic mode after this test,
    # so popping the env var here cannot leak into other tests.
    os.environ.pop("FD_DETERMINISTIC_MODE", None)
    results_no_seed = []
    for _ in range(5):
        # Fresh SamplingParams per run, deliberately without a seed.
        sp = SamplingParams(temperature=0.7, max_tokens=30)
        results_no_seed.append(llm.generate([prompt], sp)[0].outputs.text)

    # Probabilistic, skip if all outputs are the same
    if len(set(results_no_seed)) == 1:
        pytest.skip("Sampling produced identical outputs (probabilistic case)")

    # Part 2: explicit seed -> outputs must be consistent
    sp_seeded = SamplingParams(temperature=0.7, max_tokens=30, seed=999)
    results_seeded = [llm.generate([prompt], sp_seeded)[0].outputs.text for _ in range(5)]
    assert len(set(results_seeded)) == 1, "With explicit seed: expected consistent outputs"
|
||||
|
||||
|
||||
# Allow running this file directly without invoking the pytest CLI.
if __name__ == "__main__":
    pytest.main(["-sv", __file__])
|
||||
@@ -0,0 +1,129 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Single-GPU determinism offline inference tests for coverage.
|
||||
|
||||
Simplified from tests/e2e/4cards_cases/test_determinism_offline.py
|
||||
for single-GPU coverage testing.
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 pytest tests/deterministic/test_determinism_offline_single_gpu.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.gpu
|
||||
|
||||
DEFAULT_MODEL_DIR = "./models"
|
||||
MODEL_NAME = "Qwen2-7B-Instruct"
|
||||
|
||||
|
||||
@contextmanager
def env_override(mapping):
    """Temporarily set env vars, restoring original values on exit."""
    # Snapshot current values (None marks "was not set").
    previous = {key: os.environ.get(key) for key in mapping}
    os.environ.update(mapping)
    try:
        yield
    finally:
        # Restore the snapshot: delete vars that did not exist, reset the rest.
        for key, original in previous.items():
            if original is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = original
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def model_path():
    """Resolve the model directory: $MODEL_PATH (or the default dir) joined with the model name."""
    model_dir = os.getenv("MODEL_PATH", DEFAULT_MODEL_DIR)
    return os.path.join(model_dir, MODEL_NAME)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_deterministic_mode():
    """Ensure every test starts with deterministic mode ON.

    The env var is set both before and after each test so no test can leak a
    disabled deterministic mode into the next one.
    """
    os.environ["FD_DETERMINISTIC_MODE"] = "1"
    yield
    os.environ["FD_DETERMINISTIC_MODE"] = "1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
def _module_env():
    """Set env vars before importing fastdeploy (must happen first)."""
    with env_override(
        {
            "CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES", "0"),
            "FD_DETERMINISTIC_MODE": "1",
            "FD_CUSTOM_AR_MAX_SIZE_MB": "64",
        }
    ):
        # Lazy import: env vars must be set before importing fastdeploy.
        # The names are published as module globals so tests can use them.
        global LLM, SamplingParams  # noqa: PLW0603
        from fastdeploy import LLM, SamplingParams

        # Keep the overridden environment alive for the whole module.
        yield
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def llm(model_path, _module_env):
    """Build the single-GPU LLM instance shared by this module's tests.

    CUDA graphs are disabled via graph_optimization_config; prefix caching is
    off as well — presumably so caching cannot mask determinism differences
    (TODO confirm).
    """
    return LLM(
        model=model_path,
        tensor_parallel_size=1,  # Single GPU
        max_model_len=4096,
        enable_prefix_caching=False,
        graph_optimization_config={"use_cudagraph": False},
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _generate_text(llm, prompt, sp):
|
||||
"""Generate once, return (text, token_ids)."""
|
||||
out = llm.generate([prompt], sp)[0]
|
||||
return out.outputs.text, list(out.outputs.token_ids)
|
||||
|
||||
|
||||
def _assert_deterministic(llm, prompt, sp, runs=2):
|
||||
"""Run *runs* times and assert all outputs are identical."""
|
||||
results = [_generate_text(llm, prompt, sp) for _ in range(runs)]
|
||||
texts = [r[0] for r in results]
|
||||
token_ids = [r[1] for r in results]
|
||||
assert all(t == texts[0] for t in texts), "Text outputs differ across runs"
|
||||
assert all(t == token_ids[0] for t in token_ids), "Token IDs differ across runs"
|
||||
return texts[0], token_ids[0]
|
||||
|
||||
|
||||
# ===================== Core determinism tests =====================
|
||||
|
||||
|
||||
def test_deterministic_same_prompt(llm):
    """Same prompt + same seed produces identical output across 3 runs."""
    params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=30, seed=123)
    _assert_deterministic(llm, "What is AI?", params, runs=3)
|
||||
|
||||
|
||||
# Allow running this file directly without invoking the pytest CLI.
if __name__ == "__main__":
    pytest.main(["-sv", __file__])
|
||||
@@ -0,0 +1,216 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Unit tests for flash_attn_func deterministic mode.
|
||||
|
||||
Verifies that flash_attn_func passes correct deterministic parameters
|
||||
(e.g. num_splits=1 for FA3) when FD_DETERMINISTIC_MODE=1.
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 pytest tests/deterministic/test_flash_attn_determinism.py -v
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.gpu
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
NUM_HEADS = 8
|
||||
KV_NUM_HEADS = 8
|
||||
HEAD_DIM = 128
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_sm_version():
    """Return the GPU SM version as major*10 + minor (e.g. 90 for SM 9.0)."""
    props = paddle.device.cuda.get_device_properties()
    return 10 * props.major + props.minor
|
||||
|
||||
|
||||
def _reload_flash_attn_backend():
    """Reload flash_attn_backend so env-var changes take effect.

    The backend presumably reads FD_DETERMINISTIC_MODE at import time
    (TODO confirm), so tests must reload it after changing the env var.
    """
    import fastdeploy.model_executor.layers.attention.flash_attn_backend as mod

    importlib.reload(mod)
    return mod
|
||||
|
||||
|
||||
def _make_tensors(seq_lens, num_heads=NUM_HEADS, head_dim=HEAD_DIM):
    """Create Q/K/V tensors and cu_seqlens for a batch of sequences.

    Returns:
        (q, k, v, cu_seqlens, max_seqlen): q/k/v are bfloat16 tensors of
        shape [sum(seq_lens), num_heads, head_dim]; cu_seqlens is the int32
        cumulative-length prefix [0, len0, len0+len1, ...].
    """
    total_tokens = sum(seq_lens)
    q = paddle.randn([total_tokens, num_heads, head_dim], dtype="bfloat16")
    k = paddle.randn([total_tokens, num_heads, head_dim], dtype="bfloat16")
    v = paddle.randn([total_tokens, num_heads, head_dim], dtype="bfloat16")
    cu_seqlens = paddle.to_tensor(np.array([0] + list(np.cumsum(seq_lens))), dtype="int32")
    max_seqlen = max(seq_lens)
    return q, k, v, cu_seqlens, max_seqlen
|
||||
|
||||
|
||||
def _call_flash_attn_func(mod, q, k, v, cu_seqlens, max_seqlen, version=None):
    """Call flash_attn_func and return the output tensor."""
    raw = mod.flash_attn_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        causal=True,
        num_heads=NUM_HEADS,
        kv_num_heads=KV_NUM_HEADS,
        head_dim=HEAD_DIM,
        version=version,
    )
    # Some versions return a tuple (out, ...); keep only the output tensor.
    return raw[0] if isinstance(raw, tuple) else raw
|
||||
|
||||
|
||||
def _run_determinism_check(mod, seq_lens, runs, version, test_name):
    """Run flash_attn_func multiple times and verify deterministic output."""
    q, k, v, cu_seqlens, max_seqlen = _make_tensors(seq_lens)

    # Identical inputs every run: any output difference is kernel nondeterminism.
    collected = [
        _call_flash_attn_func(mod, q, k, v, cu_seqlens, max_seqlen, version=version).numpy() for _ in range(runs)
    ]

    reference = collected[0]
    for run_idx, out in enumerate(collected[1:], start=1):
        assert np.array_equal(reference, out), f"{test_name}: run {run_idx} differs from run 0"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _clean_env():
    """Save and restore determinism-related env vars around every test."""
    keys = ["FD_DETERMINISTIC_MODE", "FD_DETERMINISTIC_DEBUG"]
    # Snapshot before the test (None marks "was not set").
    saved = {k: os.environ.get(k) for k in keys}
    yield
    # Restore the snapshot: remove vars that were unset, reset the rest.
    for k, v in saved.items():
        if v is None:
            os.environ.pop(k, None)
        else:
            os.environ[k] = v
|
||||
|
||||
|
||||
@pytest.fixture
def _deterministic_mode_enabled():
    """Enable deterministic mode and return reloaded module."""
    os.environ["FD_DETERMINISTIC_MODE"] = "1"
    # Reload so the backend re-reads the env var.
    return _reload_flash_attn_backend()
|
||||
|
||||
|
||||
@pytest.fixture
def _nondeterministic_mode_enabled():
    """Disable deterministic mode and return reloaded module."""
    os.environ["FD_DETERMINISTIC_MODE"] = "0"
    # Reload so the backend re-reads the env var.
    return _reload_flash_attn_backend()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _is_deterministic_mode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsDeterministicMode:
    """Test the _is_deterministic_mode helper."""

    def test_enabled(self):
        # "1" enables deterministic mode.
        os.environ["FD_DETERMINISTIC_MODE"] = "1"
        mod = _reload_flash_attn_backend()
        assert mod._is_deterministic_mode() is True

    def test_disabled(self):
        # "0" disables it explicitly.
        os.environ["FD_DETERMINISTIC_MODE"] = "0"
        mod = _reload_flash_attn_backend()
        assert mod._is_deterministic_mode() is False

    def test_unset_defaults_false(self):
        # Missing env var must default to non-deterministic.
        os.environ.pop("FD_DETERMINISTIC_MODE", None)
        mod = _reload_flash_attn_backend()
        assert mod._is_deterministic_mode() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: FA3 determinism (requires SM89+, <SM100)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFA3Determinism:
    """Test FA3 deterministic behavior with num_splits control."""

    @pytest.fixture(autouse=True)
    def _require_fa3(self):
        # FA3 requires SM89-99 hardware; skip the whole class elsewhere.
        sm = _get_sm_version()
        if sm < 89 or sm >= 100:
            pytest.skip(f"FA3 requires SM89-99, current SM={sm}")
        paddle.set_flags({"FLAGS_flash_attn_version": 3})

    def test_deterministic_produces_identical_output(self, _deterministic_mode_enabled):
        """num_splits=1 (deterministic) gives bitwise identical results."""
        _run_determinism_check(_deterministic_mode_enabled, [64, 128, 256], 5, 3, "FA3 deterministic")

    def test_long_sequence_determinism(self, _deterministic_mode_enabled):
        """Long sequences (>1024 tokens) remain deterministic with FA3."""
        _run_determinism_check(_deterministic_mode_enabled, [2048], 3, 3, "FA3 long seq")

    def test_mixed_batch_determinism(self, _deterministic_mode_enabled):
        """Mixed batch with varying sequence lengths stays deterministic."""
        _run_determinism_check(_deterministic_mode_enabled, [16, 512, 1024, 64], 3, 3, "FA3 mixed batch")

    def test_nondeterministic_mode_also_works(self, _nondeterministic_mode_enabled):
        """FD_DETERMINISTIC_MODE=0 still works (num_splits=1 is always used)."""
        q, k, v, cu_seqlens, max_seqlen = _make_tensors([256])
        out = _call_flash_attn_func(_nondeterministic_mode_enabled, q, k, v, cu_seqlens, max_seqlen, version=3)
        # Only the output shape is checked here; the mode just must not crash.
        assert out.shape[0] == 256
        assert out.shape[1] == NUM_HEADS
        assert out.shape[2] == HEAD_DIM
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: FA2 determinism
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFA2Determinism:
    """Test FA2 deterministic behavior (inherently deterministic forward)."""

    @pytest.fixture(autouse=True)
    def _set_fa2(self):
        # Force the FA2 kernel for every test in this class.
        paddle.set_flags({"FLAGS_flash_attn_version": 2})

    def test_fa2_deterministic(self, _deterministic_mode_enabled):
        """FA2 forward is inherently deterministic (no split-KV)."""
        _run_determinism_check(_deterministic_mode_enabled, [128, 256], 5, 2, "FA2 deterministic")

    def test_fa2_long_sequence(self, _deterministic_mode_enabled):
        """FA2 with long sequence remains deterministic."""
        _run_determinism_check(_deterministic_mode_enabled, [2048], 3, 2, "FA2 long seq")
|
||||
@@ -0,0 +1,823 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Unit tests for the `update_repeat_times` kernel in token_penalty_multi_scores.cu.
|
||||
|
||||
Tests verify both **correctness** and **determinism** of the penalty kernel,
|
||||
specifically targeting the race condition that was fixed by splitting into
|
||||
two passes with a __syncthreads() barrier.
|
||||
|
||||
Race condition background:
|
||||
The `update_repeat_times` kernel processes token_ids_all to build a
|
||||
repeat_times array with these semantics:
|
||||
0 -> token absent from both prompt and generated tokens
|
||||
-1 -> token only in the prompt
|
||||
>=1 -> count of occurrences in generated tokens
|
||||
|
||||
The original (buggy) code ran both passes in a single loop without
|
||||
synchronization. When a token appeared in BOTH the prompt and generated
|
||||
portions, three atomic operations could interleave across warps:
|
||||
Pass 1: atomicCAS(&slot, 0, -1) -- mark as prompt-only
|
||||
Pass 2: atomicMax(&slot, 0) -- lift from -1 to 0
|
||||
Pass 2: atomicAdd(&slot, 1) -- count generated occurrence
|
||||
|
||||
Without __syncthreads() between passes, a thread in Pass 2 could execute
|
||||
atomicMax/atomicAdd BEFORE another thread in Pass 1 executed atomicCAS,
|
||||
resulting in repeat_times=0 (should be 1). This caused non-deterministic
|
||||
penalty application.
|
||||
|
||||
The fix: split into two explicit passes with __syncthreads() between them.
|
||||
|
||||
Usage:
|
||||
source /root/paddlejob/workspace/env_run/gongweibao/archfd/fdarchenv/bin/activate
|
||||
FD_DETERMINISTIC_MODE=1 CUDA_VISIBLE_DEVICES=0 pytest tests/deterministic/test_penalty_kernel_determinism.py -v -s
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# Set deterministic mode before any imports that might read it.
|
||||
os.environ["FD_DETERMINISTIC_MODE"] = "1"
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import pytest
|
||||
|
||||
# Import the custom op. This goes through the fastdeploy ops import machinery
|
||||
# which loads the compiled CUDA custom op and exposes it as a Python callable.
|
||||
from fastdeploy.model_executor.ops.gpu import get_token_penalty_multi_scores
|
||||
|
||||
pytestmark = pytest.mark.gpu
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Default penalty parameters (matching production defaults)
|
||||
DEFAULT_ALPHA = 1.2 # repetition_penalty
|
||||
DEFAULT_BETA = 0.5 # frequency_penalty
|
||||
DEFAULT_GAMMA = 0.3 # presence_penalty
|
||||
DEFAULT_TEMP = 1.0 # temperature
|
||||
|
||||
# Token sentinels
|
||||
EOS_TOKEN_ID_QWEN2 = 151643
|
||||
NO_TOKEN_SENTINEL = -1
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: build all tensors needed by the penalty custom op
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_penalty_inputs(
    token_ids_all_np,  # int64 [bs, max_model_len]
    logits_np,  # float32 [bs, vocab_size]
    prompt_lens_np,  # int64 [bs]
    cur_dec_lens_np,  # int64 [bs]
    penalty_scores_np=None,  # float32 [bs, 1] (repetition penalty, default 1.2)
    frequency_scores_np=None,  # float32 [bs, 1] (default 0.5)
    presence_scores_np=None,  # float32 [bs, 1] (default 0.3)
    temperatures_np=None,  # float32 [bs, 1] (default 1.0)
    eos_token_id_np=None,  # int64 [eos_len]
    min_dec_lens_np=None,  # int64 [bs]
    bad_tokens_np=None,  # int64 [bs, bad_words_len]
    bad_tokens_lens_np=None,  # int64 [bs]
):
    """
    Build GPU tensors for the penalty op from numpy arrays.

    All penalty/frequency/presence/temperature tensors are float32, matching
    the production dtype used in input_batch.py. Logits are also float32
    so that update_value_by_repeat_times (which is templated on logit dtype)
    reads penalty scalars correctly. The update_repeat_times kernel -- the
    one that had the race condition -- is NOT templated and exercises the
    same code path regardless of logit dtype.

    Returns:
        Tuple of 12 GPU tensors in the exact positional order expected by
        _run_penalty / get_token_penalty_multi_scores.
    """
    bs = token_ids_all_np.shape[0]

    place = paddle.CUDAPlace(0)

    # Required inputs
    token_ids_all = paddle.to_tensor(token_ids_all_np, dtype="int64", place=place)
    logits = paddle.to_tensor(logits_np, dtype="float32", place=place)
    prompt_lens = paddle.to_tensor(prompt_lens_np, dtype="int64", place=place)
    cur_dec_lens = paddle.to_tensor(cur_dec_lens_np, dtype="int64", place=place)

    # Optional inputs with sensible defaults
    if penalty_scores_np is None:
        penalty_scores_np = np.full([bs, 1], DEFAULT_ALPHA, dtype=np.float32)
    if frequency_scores_np is None:
        frequency_scores_np = np.full([bs, 1], DEFAULT_BETA, dtype=np.float32)
    if presence_scores_np is None:
        presence_scores_np = np.full([bs, 1], DEFAULT_GAMMA, dtype=np.float32)
    if temperatures_np is None:
        temperatures_np = np.full([bs, 1], DEFAULT_TEMP, dtype=np.float32)
    if eos_token_id_np is None:
        eos_token_id_np = np.array([EOS_TOKEN_ID_QWEN2], dtype=np.int64)
    if min_dec_lens_np is None:
        min_dec_lens_np = np.zeros([bs], dtype=np.int64)
    if bad_tokens_np is None:
        # -1 sentinel means "no bad tokens" for each row.
        bad_tokens_np = np.full([bs, 1], NO_TOKEN_SENTINEL, dtype=np.int64)
    if bad_tokens_lens_np is None:
        bad_tokens_lens_np = np.zeros([bs], dtype=np.int64)

    penalty_scores = paddle.to_tensor(penalty_scores_np, dtype="float32", place=place)
    frequency_scores = paddle.to_tensor(frequency_scores_np, dtype="float32", place=place)
    presence_scores = paddle.to_tensor(presence_scores_np, dtype="float32", place=place)
    temperatures = paddle.to_tensor(temperatures_np, dtype="float32", place=place)
    eos_token_id = paddle.to_tensor(eos_token_id_np, dtype="int64", place=place)
    min_dec_lens = paddle.to_tensor(min_dec_lens_np, dtype="int64", place=place)
    bad_tokens = paddle.to_tensor(bad_tokens_np, dtype="int64", place=place)
    bad_tokens_lens = paddle.to_tensor(bad_tokens_lens_np, dtype="int64", place=place)

    return (
        token_ids_all,
        logits,
        penalty_scores,
        frequency_scores,
        presence_scores,
        temperatures,
        bad_tokens,
        bad_tokens_lens,
        prompt_lens,
        cur_dec_lens,
        min_dec_lens,
        eos_token_id,
    )
|
||||
|
||||
|
||||
def _run_penalty(
    token_ids_all,
    logits,
    penalty_scores,
    frequency_scores,
    presence_scores,
    temperatures,
    bad_tokens,
    bad_tokens_lens,
    prompt_lens,
    cur_dec_lens,
    min_dec_lens,
    eos_token_id,
):
    """
    Run the penalty op on a CLONE of logits (so the original is not modified)
    and return the resulting logits as a numpy array.
    """
    # Clone because the op works on the logits tensor itself; the caller's
    # tensor must stay pristine so repeated runs see identical inputs.
    logits_clone = logits.clone()
    result = get_token_penalty_multi_scores(
        token_ids_all,
        logits_clone,
        penalty_scores,
        frequency_scores,
        presence_scores,
        temperatures,
        bad_tokens,
        bad_tokens_lens,
        prompt_lens,
        cur_dec_lens,
        min_dec_lens,
        eos_token_id,
    )
    # Make sure the kernel has finished before copying to host.
    paddle.device.cuda.synchronize()
    return result.numpy()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: determinism assertion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _assert_determinism(
    inputs,
    num_runs: int,
    test_name: str,
    verbose: bool = True,
    include_diff_details: bool = True,
):
    """
    Run the penalty op ``num_runs`` times on the same inputs and require
    every output to be bit-identical to the first run.

    Args:
        inputs: Tuple of tensors for _run_penalty
        num_runs: Number of runs to perform
        test_name: Name for error messages
        verbose: Print success message if True
        include_diff_details: Include diff details in error message

    Returns:
        True if deterministic
    """
    baseline = _run_penalty(*inputs)
    failures = []

    for run_idx in range(1, num_runs):
        current = _run_penalty(*inputs)
        if np.array_equal(baseline, current):
            continue
        # Bitwise mismatch: record where and how badly this run diverged.
        mask = baseline != current
        n_diff = np.sum(mask)
        if include_diff_details:
            sample_positions = np.argwhere(mask)[:5].tolist()
            worst = np.max(np.abs(baseline[mask] - current[mask]))
            failures.append((run_idx, n_diff, sample_positions, worst))
        else:
            failures.append((run_idx, n_diff))

    if failures:
        error_msg = f"{test_name} is NON-DETERMINISTIC: {len(failures)}/{num_runs - 1} runs differ from reference.\n"
        if include_diff_details:
            error_msg += f"First 3 mismatches (run_idx, num_diffs, sample_indices, max_abs_diff): {failures[:3]}\n"
        else:
            error_msg += f"First 3 mismatches (run_idx, num_diffs): {failures[:3]}\n"
        error_msg += "This indicates the atomicCAS/atomicMax/atomicAdd race condition in update_repeat_times is NOT fixed."
        raise AssertionError(error_msg)

    if verbose:
        print(f"\n {test_name}: all {num_runs} runs produced bit-identical results.")
    return True
|
||||
|
||||
|
||||
def _print_penalty_summary(actual: float, expected: float, label: str, raw_value: float):
|
||||
"""Print a formatted summary line for a penalty test."""
|
||||
print(f" {label}: raw={raw_value:.1f} -> {actual:.6f} (expected {expected:.6f})")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: Correctness -- verify repeat_times semantics and final logits
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPenaltyCorrectness:
    """
    Verify that the penalty kernel applies the correct transformation for
    each repeat_times category:

        Token A (id=10): only in prompt        -> repeat_times = -1
        Token B (id=20): only in generated x2  -> repeat_times = 2
        Token C (id=30): in BOTH prompt + gen  -> repeat_times = 1
        Token D (id=40): nowhere               -> repeat_times = 0
    """

    VOCAB_SIZE = 100
    MAX_MODEL_LEN = 20
    PROMPT_LEN = 5
    # Generated tokens occupy positions prompt_len .. max_model_len-1;
    # unused slots are filled with -1 (sentinel for "no token").

    def _build_scenario(self):
        """
        Single-batch layout:
            prompt (positions 0..4):     [10, 30, 50, 51, 52]
            generated (positions 5..):   [20, 20, 30, -1, -1, ...]

        Giving:
            token 10: prompt only        -> repeat_times = -1
            token 20: generated x2       -> repeat_times = 2
            token 30: prompt AND gen x1  -> repeat_times = 1 (the fixed case!)
            token 40: absent             -> repeat_times = 0
            tokens 50/51/52: prompt only -> repeat_times = -1
        """
        token_ids = np.full([1, self.MAX_MODEL_LEN], -1, dtype=np.int64)
        # Prompt region immediately followed by the generated region.
        layout = [10, 30, 50, 51, 52, 20, 20, 30]
        token_ids[0, : len(layout)] = layout

        # Known logit values at the four probe tokens; mixed signs so the
        # direction of the repetition penalty is observable.
        logits = np.zeros([1, self.VOCAB_SIZE], dtype=np.float32)
        for token_id, value in ((10, 2.0), (20, -1.0), (30, 3.0), (40, 0.5)):
            logits[0, token_id] = value

        return _make_penalty_inputs(
            token_ids,
            logits,
            np.array([self.PROMPT_LEN], dtype=np.int64),
            np.array([3], dtype=np.int64),  # 3 generated tokens
            penalty_scores_np=np.array([[DEFAULT_ALPHA]], dtype=np.float32),
            frequency_scores_np=np.array([[DEFAULT_BETA]], dtype=np.float32),
            presence_scores_np=np.array([[DEFAULT_GAMMA]], dtype=np.float32),
            temperatures_np=np.array([[DEFAULT_TEMP]], dtype=np.float32),
        )

    def _expected_logit(self, raw_logit, repeat_times):
        """
        Host-side mirror of the kernel's penalty formula:
            if times != 0:
                logit = logit * alpha if logit < 0 else logit / alpha
            if times > 0:
                logit = logit - times * beta - gamma
            logit = logit / temperature
        """
        value = raw_logit

        if repeat_times != 0:
            value = value * DEFAULT_ALPHA if value < 0 else value / DEFAULT_ALPHA

        if repeat_times > 0:
            value = value - repeat_times * DEFAULT_BETA - DEFAULT_GAMMA

        return value / DEFAULT_TEMP

    def test_penalty_correctness(self):
        """
        Each token category must receive the correct penalty.

        The key case is Token C (id=30), which appears in BOTH the prompt
        and generated regions. Before the __syncthreads() fix, this could
        non-deterministically produce repeat_times=0 instead of the
        correct repeat_times=1.
        """
        out = _run_penalty(*self._build_scenario())

        # Token A (id=10): repeat_times = -1 (prompt only).
        # Only the repetition penalty applies (times != 0); no
        # frequency/presence terms (times <= 0).
        exp_a = self._expected_logit(2.0, repeat_times=-1)
        got_a = out[0, 10]
        assert np.isclose(got_a, exp_a, atol=1e-5), (
            f"Token A (prompt only): expected {exp_a:.6f}, got {got_a:.6f}. " f"repeat_times should be -1."
        )

        # Token B (id=20): repeat_times = 2 (gen only, 2 occurrences).
        # Repetition + 2*frequency + presence.
        exp_b = self._expected_logit(-1.0, repeat_times=2)
        got_b = out[0, 20]
        assert np.isclose(got_b, exp_b, atol=1e-5), (
            f"Token B (gen only x2): expected {exp_b:.6f}, got {got_b:.6f}. " f"repeat_times should be 2."
        )

        # Token C (id=30): repeat_times = 1 (BOTH prompt and gen).
        # THE key test case for the race-condition fix: before the fix,
        # repeat_times could be 0 instead of 1.
        exp_c = self._expected_logit(3.0, repeat_times=1)
        got_c = out[0, 30]
        assert np.isclose(got_c, exp_c, atol=1e-5), (
            f"Token C (both prompt+gen): expected {exp_c:.6f}, got {got_c:.6f}. "
            f"repeat_times should be 1. If this fails intermittently, the "
            f"atomicCAS/atomicMax/atomicAdd race in update_repeat_times is "
            f"not properly fixed."
        )

        # Token D (id=40): repeat_times = 0 (absent).
        # Only temperature scaling applies.
        exp_d = self._expected_logit(0.5, repeat_times=0)
        got_d = out[0, 40]
        assert np.isclose(got_d, exp_d, atol=1e-5), (
            f"Token D (absent): expected {exp_d:.6f}, got {got_d:.6f}. " f"repeat_times should be 0."
        )

        # Debugging summary of all four probes.
        _print_penalty_summary(got_a, exp_a, "Token A (id=10, prompt only)", 2.0)
        _print_penalty_summary(got_b, exp_b, "Token B (id=20, gen only x2)", -1.0)
        _print_penalty_summary(got_c, exp_c, "Token C (id=30, both prompt+gen)", 3.0)
        _print_penalty_summary(got_d, exp_d, "Token D (id=40, absent)", 0.5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: Determinism -- same inputs must produce bit-identical outputs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPenaltyDeterminism:
    """
    Run the penalty op repeatedly on inputs whose generated region reuses
    prompt tokens, and require every run to be bit-identical.
    """

    VOCAB_SIZE = 1000
    MAX_MODEL_LEN = 50
    PROMPT_LEN = 15

    def _build_overlapping_scenario(self, seed=42):
        """
        Scenario with several tokens present in BOTH the prompt and the
        generated region -- the pattern that triggers the race condition.
        """
        rng = np.random.RandomState(seed)
        token_ids = np.full([1, self.MAX_MODEL_LEN], -1, dtype=np.int64)

        # Prompt: uniform random ids over [0, VOCAB_SIZE).
        prompt = rng.randint(0, self.VOCAB_SIZE, size=self.PROMPT_LEN)
        token_ids[0, : self.PROMPT_LEN] = prompt

        # Generated: 8 deliberate repeats of prompt ids up front, the
        # remainder filled with fresh random ids.
        gen_len = 20
        gen = np.concatenate(
            [
                prompt[:5],
                prompt[:3],
                rng.randint(0, self.VOCAB_SIZE, size=gen_len - 8),
            ]
        )
        token_ids[0, self.PROMPT_LEN : self.PROMPT_LEN + gen_len] = gen

        logits = rng.randn(1, self.VOCAB_SIZE).astype(np.float32)

        return _make_penalty_inputs(
            token_ids,
            logits,
            np.array([self.PROMPT_LEN], dtype=np.int64),
            np.array([gen_len], dtype=np.int64),
        )

    def test_penalty_determinism(self):
        """
        20 identical runs over overlapping prompt/generated tokens must be
        bit-identical (np.array_equal, not np.allclose).

        Before the __syncthreads() fix this failed sporadically because
        the atomicCAS in Pass 1 could race with atomicMax/atomicAdd in
        Pass 2 for overlapping token IDs.
        """
        _assert_determinism(self._build_overlapping_scenario(), num_runs=20, test_name="Penalty determinism")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: Determinism stress -- large token_ids_all with many overlaps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPenaltyDeterminismStress:
    """
    Stress variant: long sequences (prompt_len=500, generated=500) with
    heavy token overlap, repeated 50 times to surface rare races.
    """

    VOCAB_SIZE = 32000  # Realistic vocab size (LLaMA/Qwen range)
    MAX_MODEL_LEN = 1024
    PROMPT_LEN = 500
    GEN_LEN = 500
    NUM_RUNS = 50

    def _build_stress_scenario(self, seed=123):
        """
        Maximize the race window:
          - 500 prompt tokens + 500 generated tokens per batch element
          - ~40% of generated tokens repeat prompt tokens
          - 4 batch elements to raise GPU occupancy
        """
        rng = np.random.RandomState(seed)
        bs = 4
        token_ids = np.full([bs, self.MAX_MODEL_LEN], -1, dtype=np.int64)

        for b in range(bs):
            prompt = rng.randint(0, self.VOCAB_SIZE, size=self.PROMPT_LEN)
            token_ids[b, : self.PROMPT_LEN] = prompt

            # ~40% of the generated ids are resampled (with replacement)
            # from the prompt; the rest are fresh random ids.
            num_overlap = int(self.GEN_LEN * 0.4)
            overlap = rng.choice(prompt, size=num_overlap, replace=True)
            fresh = rng.randint(0, self.VOCAB_SIZE, size=self.GEN_LEN - num_overlap)

            # Shuffle so overlapping ids are spread across the region.
            gen = np.concatenate([overlap, fresh])
            rng.shuffle(gen)
            token_ids[b, self.PROMPT_LEN : self.PROMPT_LEN + self.GEN_LEN] = gen

        logits = rng.randn(bs, self.VOCAB_SIZE).astype(np.float32)

        return _make_penalty_inputs(
            token_ids,
            logits,
            np.full([bs], self.PROMPT_LEN, dtype=np.int64),
            np.full([bs], self.GEN_LEN, dtype=np.int64),
        )

    def test_penalty_determinism_stress(self):
        """
        50 runs with large, heavily-overlapping inputs must all be
        bit-identical.

        With 512 threads per block and 1024 tokens to process, each thread
        handles ~2 tokens, giving ample opportunity for interleaving
        between Pass 1 (atomicCAS) and Pass 2 (atomicMax + atomicAdd) if
        the __syncthreads() is missing.
        """
        _assert_determinism(
            self._build_stress_scenario(),
            num_runs=self.NUM_RUNS,
            test_name=f"Penalty stress (bs=4, prompt_len={self.PROMPT_LEN}, gen_len={self.GEN_LEN})",
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 4: Correctness of "both" case repeated many times
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPenaltyBothCaseRepeated:
    """
    Targeted regression for the race: token 7 appears in both the prompt
    and the generated region, with distinctive penalty values so that
    repeat_times=0 and repeat_times=1 give clearly different outputs.
    """

    VOCAB_SIZE = 100
    MAX_MODEL_LEN = 20
    PROMPT_LEN = 5

    def _build_minimal_both_case(self):
        """
        Minimal layout: token 7 once in the prompt, once in the generated
        region, with non-default alpha/beta/gamma.

        Returns:
            (inputs, expected): op inputs plus the analytically expected
            post-penalty logit for token 7 assuming repeat_times=1.
        """
        token_ids = np.full([1, self.MAX_MODEL_LEN], NO_TOKEN_SENTINEL, dtype=np.int64)
        # Prompt region, with the probed token 7 at position 0.
        token_ids[0, : self.PROMPT_LEN] = [7, 8, 9, 11, 12]
        # Generated region: token 7 again, then one unrelated token.
        token_ids[0, self.PROMPT_LEN] = 7
        token_ids[0, self.PROMPT_LEN + 1] = 13

        # Distinctive logit for the probed token.
        logits = np.zeros([1, self.VOCAB_SIZE], dtype=np.float32)
        logits[0, 7] = 5.0

        # Non-trivial penalty values so repeat_times=0 vs repeat_times=1
        # diverge clearly in the output.
        alpha = 1.5
        beta = 0.8
        gamma = 0.4

        inputs = _make_penalty_inputs(
            token_ids,
            logits,
            np.array([self.PROMPT_LEN], dtype=np.int64),
            np.array([2], dtype=np.int64),
            penalty_scores_np=np.array([[alpha]], dtype=np.float32),
            frequency_scores_np=np.array([[beta]], dtype=np.float32),
            presence_scores_np=np.array([[gamma]], dtype=np.float32),
            temperatures_np=np.array([[1.0]], dtype=np.float32),
        )

        # Expected for repeat_times=1: positive logit divided by alpha
        # (5.0/1.5 = 3.333...), minus frequency+presence (- 0.8 - 0.4),
        # then divided by temperature 1.0 -> 2.133...
        expected = (5.0 / alpha) - 1 * beta - gamma

        return inputs, expected

    def test_both_case_repeated(self):
        """
        100 runs must give a consistent value for token 7.

        Consistency only is checked here (exact value correctness is
        covered by TestPenaltyCorrectness). Before the fix, roughly 1-5%
        of runs produced a different output for the overlapping token.
        """
        inputs, expected_value = self._build_minimal_both_case()
        num_runs = 100

        # First run establishes the reference value.
        reference = _run_penalty(*inputs)[0, 7]

        # Every subsequent run must agree with it.
        for run_idx in range(1, num_runs):
            current = _run_penalty(*inputs)[0, 7]
            if not np.isclose(current, reference, atol=1e-5):
                raise AssertionError(
                    f"Token 7 (in both prompt+gen) produced INCONSISTENT values: "
                    f"first run={reference:.6f}, run {run_idx}={current:.6f}.\n"
                    f"This indicates the atomicCAS/atomicMax/atomicAdd race condition."
                )

        print(
            f"\n All {num_runs} runs produced consistent value {reference:.6f} "
            f"for the overlapping token (expected ~{expected_value:.6f})."
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 5: Edge cases -- boundary conditions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPenaltyEdgeCases:
    """Boundary conditions: empty regions, single tokens, no overlap, heavy repeats."""

    VOCAB_SIZE = 100
    MAX_MODEL_LEN = 20

    def _fresh_token_ids(self):
        # One-batch token grid filled entirely with the no-token sentinel.
        return np.full([1, self.MAX_MODEL_LEN], NO_TOKEN_SENTINEL, dtype=np.int64)

    def _zero_logits(self):
        # All-zero logits for a single batch element.
        return np.zeros([1, self.VOCAB_SIZE], dtype=np.float32)

    def test_empty_generated_tokens(self):
        """Test with no generated tokens (cur_dec_len=0)."""
        token_ids = self._fresh_token_ids()
        token_ids[0, :3] = [10, 20, 30]  # prompt tokens only

        logits = self._zero_logits()
        logits[0, 10] = 1.0

        inputs = _make_penalty_inputs(
            token_ids,
            logits,
            np.array([3], dtype=np.int64),
            np.array([0], dtype=np.int64),  # empty generated region
        )

        # Must neither crash nor vary between runs.
        _assert_determinism(inputs, num_runs=10, test_name="Empty generated tokens")

    def test_empty_prompt(self):
        """Test with no prompt tokens (prompt_len=0)."""
        token_ids = self._fresh_token_ids()
        token_ids[0, :3] = [10, 20, 10]  # generated tokens only

        logits = self._zero_logits()
        logits[0, 10] = 1.0

        inputs = _make_penalty_inputs(
            token_ids,
            logits,
            np.array([0], dtype=np.int64),  # empty prompt
            np.array([3], dtype=np.int64),
        )

        # Must neither crash nor vary between runs.
        _assert_determinism(inputs, num_runs=10, test_name="Empty prompt")

    def test_single_token(self):
        """Test with a single generated token."""
        token_ids = self._fresh_token_ids()
        token_ids[0, 0] = 5  # lone prompt token
        token_ids[0, 1] = 10  # lone generated token, distinct from prompt

        logits = self._zero_logits()
        logits[0, 10] = 1.0

        inputs = _make_penalty_inputs(
            token_ids,
            logits,
            np.array([1], dtype=np.int64),
            np.array([1], dtype=np.int64),
        )

        # Must neither crash nor vary between runs.
        _assert_determinism(inputs, num_runs=10, test_name="Single token")

    def test_no_overlapping_tokens(self):
        """Test with no overlapping tokens between prompt and generated."""
        token_ids = self._fresh_token_ids()
        # Prompt then generated, with fully disjoint token ids.
        token_ids[0, :4] = [10, 11, 20, 21]

        logits = self._zero_logits()
        logits[0, 10] = 1.0
        logits[0, 20] = -1.0

        inputs = _make_penalty_inputs(
            token_ids,
            logits,
            np.array([2], dtype=np.int64),
            np.array([2], dtype=np.int64),
        )

        # Must neither crash nor vary between runs.
        _assert_determinism(inputs, num_runs=10, test_name="No overlapping tokens")

    def test_repeated_same_token_only_generated(self):
        """Test token repeated many times in generated only (no overlap with prompt)."""
        token_ids = self._fresh_token_ids()
        token_ids[0, :3] = [1, 2, 3]  # prompt tokens
        token_ids[0, 3:13] = 50  # token 50 repeated 10x in generated

        logits = self._zero_logits()
        logits[0, 50] = 2.0

        inputs = _make_penalty_inputs(
            token_ids,
            logits,
            np.array([3], dtype=np.int64),
            np.array([10], dtype=np.int64),
        )

        # Must neither crash nor vary between runs.
        _assert_determinism(inputs, num_runs=10, test_name="Repeated token in generated only")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 6: Multi-batch determinism
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPenaltyMultiBatch:
    """Determinism across a batch where every element has its own pattern."""

    VOCAB_SIZE = 500
    MAX_MODEL_LEN = 50
    BATCH_SIZE = 8

    def test_multi_batch_determinism(self):
        """Each of 8 batch elements gets random lengths and overlap; 20 runs must agree."""
        rng = np.random.RandomState(42)
        bs = self.BATCH_SIZE

        token_ids = np.full([bs, self.MAX_MODEL_LEN], NO_TOKEN_SENTINEL, dtype=np.int64)
        prompt_len_list = []
        gen_len_list = []

        for b in range(bs):
            # Independent random prompt/generated lengths per element.
            p_len = rng.randint(1, 20)
            g_len = rng.randint(1, 20)

            prompt = rng.randint(0, self.VOCAB_SIZE, size=p_len)
            token_ids[b, :p_len] = prompt

            # A random share of the generated ids repeat prompt ids.
            n_overlap = rng.randint(0, g_len + 1)
            gen = []
            if n_overlap > 0:
                gen.extend(rng.choice(prompt, size=n_overlap, replace=True))
            gen.extend(rng.randint(0, self.VOCAB_SIZE, size=g_len - n_overlap))
            gen = gen[:g_len]  # Ensure correct length

            token_ids[b, p_len : p_len + g_len] = gen
            prompt_len_list.append(p_len)
            gen_len_list.append(g_len)

        prompt_lens = np.array(prompt_len_list, dtype=np.int64)
        cur_dec_lens = np.array(gen_len_list, dtype=np.int64)
        logits = rng.randn(bs, self.VOCAB_SIZE).astype(np.float32)

        inputs = _make_penalty_inputs(token_ids, logits, prompt_lens, cur_dec_lens)

        # All batch elements must be bit-identical across runs.
        _assert_determinism(inputs, num_runs=20, test_name="Multi-batch determinism")
|
||||
|
||||
|
||||
# Allow running this file directly; forwards to pytest with verbose output.
if __name__ == "__main__":
    pytest.main(["-sv", __file__])
|
||||
@@ -0,0 +1,159 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Unit test: isolate sampling determinism from model computation.
|
||||
|
||||
This test fixes the logits (model output) and runs only the sampling
|
||||
pipeline multiple times. If the results differ, the bug is in sampling;
|
||||
if they are always identical, the non-determinism comes from model
|
||||
computation (logits differ between runs).
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 pytest tests/deterministic/test_sampling_determinism.py -v -s
|
||||
"""
|
||||
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.gpu
|
||||
|
||||
VOCAB_SIZE = 151936 # Qwen2 vocab size
|
||||
BATCH_SIZE = 1
|
||||
|
||||
|
||||
def _make_logits(seed: int = 42):
    """Produce reproducible pseudo-model logits with a few peaked token ids."""
    paddle.seed(seed)
    # Gaussian base gives a realistic (non-uniform) logit distribution.
    logits = paddle.randn([BATCH_SIZE, VOCAB_SIZE], dtype="float32")
    # Boost a handful of ids so the distribution has clear winners.
    for token_id, boost in ((100, 5.0), (200, 4.5), (300, 4.0)):
        logits[0, token_id] += boost
    return logits
|
||||
|
||||
|
||||
def _sample_with_top_p(logits, top_p_val, seed_val):
    """Replicate sampler.forward_cuda's non-greedy path: softmax then top-p sampling."""
    probabilities = F.softmax(logits, axis=-1)
    threshold = paddle.to_tensor([top_p_val], dtype="float32")
    per_request_seed = paddle.to_tensor([[seed_val]], dtype="int64")
    _, sampled = paddle.tensor.top_p_sampling(
        probabilities, threshold, topp_seed=per_request_seed, seed=-1, mode="truncated"
    )
    return sampled.item()
|
||||
|
||||
|
||||
# ---- Test 1: basic repeated sampling on identical logits ----
|
||||
|
||||
|
||||
def test_sampling_determinism_basic():
    """Same logits + same seed -> must produce same token every time."""
    fixed_logits = _make_logits(seed=42)
    drawn = [_sample_with_top_p(fixed_logits, top_p_val=0.95, seed_val=200) for _ in range(20)]
    assert len(set(drawn)) == 1, f"Sampling non-deterministic! Got {len(set(drawn))} distinct values: {drawn}"
|
||||
|
||||
|
||||
# ---- Test 2: simulate multi-step decode (seed increments like real runner) ----
|
||||
|
||||
|
||||
def test_sampling_determinism_multistep():
    """Simulate 100 decode steps with seed incrementing by 4 each step."""
    fixed_logits = _make_logits(seed=42)

    def decode_once():
        # Mirror the real runner, which advances the seed by 4 per step.
        return [
            _sample_with_top_p(fixed_logits, top_p_val=0.95, seed_val=200 + step * 4)
            for step in range(100)
        ]

    first = decode_once()
    second = decode_once()
    assert first == second, _diff_msg(first, second)
|
||||
|
||||
|
||||
# ---- Test 3: interleave GPU work between sampling calls ----
|
||||
|
||||
|
||||
def test_sampling_determinism_with_gpu_noise():
    """
    Insert GPU matmul work between sampling calls to check if
    GPU state residuals affect sampling determinism.
    """
    fixed_logits = _make_logits(seed=42)

    def decode_with_noise():
        tokens = []
        for step in range(50):
            # Simulate a model forward pass between decode steps.
            _ = paddle.matmul(paddle.randn([256, 256]), paddle.randn([256, 256]))
            tokens.append(_sample_with_top_p(fixed_logits, top_p_val=0.95, seed_val=200 + step * 4))
        return tokens

    first = decode_with_noise()
    second = decode_with_noise()
    assert first == second, _diff_msg(first, second)
|
||||
|
||||
|
||||
# ---- Test 4: flat distribution (temp=1.0 scenario, hardest case) ----
|
||||
|
||||
|
||||
def test_sampling_determinism_flat_distribution():
    """
    Flat probability distribution (simulating temp=1.0 with no dominant token).
    This is the hardest case for determinism.
    """
    paddle.seed(99)
    # Small-magnitude logits -> softmax is close to uniform.
    near_flat_logits = paddle.randn([BATCH_SIZE, VOCAB_SIZE], dtype="float32") * 0.1

    draws_by_seed = {}
    for seed_val in [100, 200, 300, 400, 500]:
        draws = [_sample_with_top_p(near_flat_logits, top_p_val=0.95, seed_val=seed_val) for _ in range(10)]
        draws_by_seed[seed_val] = draws
        assert len(set(draws)) == 1, (
            f"seed={seed_val}: sampling non-deterministic on flat dist! "
            f"Got {len(set(draws))} distinct values: {draws}"
        )
|
||||
|
||||
|
||||
# ---- Test 5: different top_p values ----
|
||||
|
||||
|
||||
@pytest.mark.parametrize("top_p_val", [0.5, 0.8, 0.95, 1.0])
def test_sampling_determinism_various_top_p(top_p_val):
    """Determinism across different top_p values."""
    fixed_logits = _make_logits(seed=42)
    draws = [_sample_with_top_p(fixed_logits, top_p_val=top_p_val, seed_val=200) for _ in range(10)]
    assert len(set(draws)) == 1, (
        f"top_p={top_p_val}: non-deterministic! " f"Got {len(set(draws))} distinct values: {draws}"
    )
|
||||
|
||||
|
||||
# ---- Helpers ----
|
||||
|
||||
|
||||
def _diff_msg(run1, run2):
|
||||
for i, (a, b) in enumerate(zip(run1, run2)):
|
||||
if a != b:
|
||||
return f"First diff at step {i}: run1={a}, run2={b}. Total diffs: {sum(1 for x, y in zip(run1, run2) if x != y)}/{len(run1)}"
|
||||
return "Lengths differ"
|
||||
|
||||
|
||||
# Allow running this file directly; forwards to pytest with verbose output.
if __name__ == "__main__":
    pytest.main(["-sv", __file__])
|
||||
Reference in New Issue
Block a user