mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
30f9f33f34
* add fa deter * add ut * add long sentence * fix basic * fix bugs * fix adn * fix first * fix single * fix single * fix single test * refine * add more test * refine comments * add comments of bmm * fix ci * remove probe * add * remove not need * refine tests * fix comments and refine code * refine code * refine test * refine test * mv 4cards tests * fix tests * add * fix comments * fix cover * fix cover --------- Co-authored-by: gongweibao <gognweibao@baidu.com>
800 lines
26 KiB
Python
800 lines
26 KiB
Python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Test suite for the c16 warp1_4 decoder attention kernel determinism.
|
|
|
|
Background:
|
|
The c16 warp1_4 decoder kernel (multiquery_attention_c16_impl.cuh) had two
|
|
bugs that caused nondeterministic outputs under FD_DETERMINISTIC_MODE:
|
|
1. The warp1_4 dispatcher (lines 1164-1175) lacked the force_no_partition
|
|
check, so it still launched the multi-chunk (split) kernel even when
|
|
deterministic mode requested the single-chunk (nosplit) path.
|
|
2. The nosplit kernel template read runtime num_chunks_this_seq instead of
|
|
compile-time partition_kv, causing out-of-bounds nullptr writes to the
|
|
partial output buffer (lines 545, 748, 772, 812).
|
|
|
|
How the c16 warp1_4 path is triggered:
|
|
- dim_head=128, blocksize=64, cache_quant_type="none" -> selects c16 kernel
|
|
- decoder_block_shape_q=16 -> NUM_WARP_Q=1, i.e. warp1_4 configuration
|
|
- seq_lens_encoder=0, seq_lens_decoder>0 -> decoder mode
|
|
- FD_DETERMINISTIC_MODE=1 -> forces nosplit path
|
|
- Small decoder_max_partition_size (e.g. 64) with long prefill (e.g. 256)
|
|
ensures num_chunks > 1, which is the scenario that exposed the bugs.
|
|
|
|
Test items:
|
|
1. test_short_kv_nosplit
|
|
- Short KV (num_chunks=1): basic nosplit path with partition_kv=false.
|
|
- Verifies both correctness (vs naive reference) and determinism (10 runs).
|
|
|
|
2. test_long_kv_multi_chunk
|
|
- Long KV (num_chunks=4, prefill=256, partition=64): the exact scenario
|
|
the fix addresses. partition_kv=true template but grid_chunks=1.
|
|
- Verifies correctness and determinism.
|
|
|
|
3. test_multi_batch
|
|
- Multiple batches (batch_size=4) with multi-chunk decoder.
|
|
- Ensures the fix works across batch elements, not just single-batch.
|
|
|
|
4. test_float16
|
|
- Float16 dtype with multi-chunk decoder.
|
|
- Ensures the fix is dtype-agnostic (not only bfloat16).
|
|
|
|
5. test_unaligned_seq_len
|
|
- prefill_seq_len not divisible by blocksize (100 % 64 != 0).
|
|
- Catches off-by-one bugs in block/chunk boundary calculations.
|
|
|
|
6. test_mha_no_gqa
|
|
- MHA config: q_num_head == kv_num_head (no GQA grouping).
|
|
- Ensures the fix is not GQA-specific.
|
|
|
|
7. test_nosplit_vs_split_consistency
|
|
- Cross-path check: deterministic nosplit vs non-deterministic split.
|
|
- Both paths should produce numerically close results (rtol/atol=1e-2)
|
|
and both should match the naive attention reference.
|
|
|
|
8. test_partition_boundary
|
|
- Edge case: prefill_seq_len equals partition_size (boundary condition).
|
|
- Tests chunk calculation when num_chunks is exactly an integer.
|
|
|
|
9. test_empty_kv
|
|
- Edge case: decoder-only mode (no prefill, empty KV cache).
|
|
- Tests scenario with no encoder prefill phase.
|
|
|
|
Run:
|
|
python -m pytest tests/deterministic/test_c16_warp1_4_determinism.py -v
|
|
"""
|
|
|
|
import copy
|
|
import os
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import paddle
|
|
|
|
os.environ["FD_DETERMINISTIC_MODE"] = "1"
|
|
|
|
from fastdeploy.model_executor.layers.attention.ops import ( # noqa: E402
|
|
append_attention,
|
|
get_block_shape_and_split_kv_block,
|
|
)
|
|
|
|
SEED = 42
|
|
|
|
ENCODER_BLOCK_SHAPE_Q = 64
|
|
DECODER_BLOCK_SHAPE_Q = 16
|
|
|
|
|
|
def _assert_deterministic_and_correct(results, ref):
|
|
"""
|
|
Helper to verify determinism and correctness.
|
|
|
|
Args:
|
|
results: List of output arrays from repeated runs
|
|
ref: Reference output array from naive attention
|
|
"""
|
|
# Verify determinism: all runs should produce identical output
|
|
for i in range(1, len(results)):
|
|
np.testing.assert_array_equal(
|
|
results[0],
|
|
results[i],
|
|
err_msg=f"Determinism failure: run 0 vs run {i}",
|
|
)
|
|
# Verify correctness: output should match naive reference
|
|
np.testing.assert_allclose(results[0], ref, rtol=1e-2, atol=1e-2)
|
|
|
|
|
|
def make_rope_emb(max_seq_len, dim_head, base=10000):
    """
    Precompute rotary-embedding tables.

    Returns a float32 tensor of shape (2, 1, max_seq_len, 1, dim_head // 2)
    where index 0 holds cos(angle) and index 1 holds sin(angle), with
    angle[pos, f] = pos * base**(-2f / dim_head).
    """
    half = dim_head // 2
    inv_freq = base ** (-paddle.arange(0, dim_head, 2, dtype="float32") / dim_head)
    positions = paddle.arange(max_seq_len).cast("float32")
    # Outer product position x inverse-frequency via broadcasting.
    angles = positions.reshape((max_seq_len, 1)) * inv_freq.reshape((1, half))
    angles = angles.reshape((1, max_seq_len, 1, half))
    table = paddle.zeros((2, 1, max_seq_len, 1, half), dtype="float32")
    table[0] = paddle.cos(angles)
    table[1] = paddle.sin(angles)
    return table
|
|
|
|
|
|
def apply_rope(x, rope_emb, positions):
    """
    Apply rotary position embedding (non-neox interleaved style).

    Each adjacent channel pair (x[2i], x[2i+1]) is rotated by the angle of
    frequency i at the token's position. Computation is done in float32 and
    cast back to the input dtype.

    Args:
        x: (batch, heads, seq_len, dim_head)
        rope_emb: (2, 1, max_seq_len, 1, dim_head//2) cos/sin table
        positions: list of int, one absolute position per seq index
    """
    head_dim = x.shape[-1]
    half_dim = head_dim // 2
    xf = x.cast("float32")
    rotated = xf.clone()

    for idx, position in enumerate(positions):
        cos_vals = rope_emb[0, 0, position, 0, :]  # (dim_head//2,)
        sin_vals = rope_emb[1, 0, position, 0, :]

        token = xf[:, :, idx, :]  # (batch, heads, dim_head)
        # View channels as interleaved (even, odd) pairs.
        pairs = token.reshape(list(token.shape[:-1]) + [half_dim, 2])
        even = pairs[..., 0]  # (batch, heads, half)
        odd = pairs[..., 1]

        rot_even = even * cos_vals - odd * sin_vals
        rot_odd = even * sin_vals + odd * cos_vals

        rotated[:, :, idx, :] = paddle.stack([rot_even, rot_odd], axis=-1).reshape(token.shape)

    return rotated.cast(x.dtype)
|
|
|
|
|
|
def get_padding_offset(bsz, max_seq_len, seq_lens_this_time):
    """
    Build token->batch mapping and cumulative query offsets for padded layout.

    Returns:
        batch_id_per_token: int32 (token_num,) — owning batch index per token
        cu_seqlens_q: int32 (bsz + 1,) — cumulative query lengths per batch
    """
    # Prefix sums of the padding (unused slots) per batch, shifted by one so
    # pad_prefix[i] is the total padding before batch i.
    pad_counts = paddle.cumsum(max_seq_len - seq_lens_this_time, dtype="int32")
    pad_prefix = paddle.zeros(shape=(bsz + 1,), dtype="int32")
    pad_prefix[1:] = pad_counts
    total_tokens = int(paddle.sum(seq_lens_this_time))
    batch_id_per_token = paddle.zeros(shape=(total_tokens,), dtype="int32")
    cu_seqlens_q = paddle.zeros(shape=(bsz + 1,), dtype="int32")
    for b in range(bsz):
        count = int(seq_lens_this_time[b])
        base = b * max_seq_len - int(pad_prefix[b])
        for t in range(count):
            batch_id_per_token[base + t] = b
        cu_seqlens_q[b + 1] = (b + 1) * max_seq_len - int(pad_prefix[b + 1])
    return batch_id_per_token, cu_seqlens_q
|
|
|
|
|
|
def naive_attention_impl(query, key, value, cache_k, cache_v, scale):
    """Reference: Q @ K^T * scale -> softmax -> @ V, with GQA expansion."""
    batch, num_q_heads, _, head_dim = query.shape
    num_kv_heads = key.shape[1]
    group = num_q_heads // num_kv_heads

    def _expand_heads(t):
        # Replicate each KV head `group` times so it lines up with Q heads.
        t = t.reshape([batch, num_kv_heads, 1, -1, head_dim])
        return paddle.tile(t, [1, 1, group, 1, 1]).reshape([batch, num_q_heads, -1, head_dim])

    key = _expand_heads(key)
    value = _expand_heads(value)

    # Prepend cached (historical) K/V along the sequence axis when present.
    if cache_k is not None:
        key = paddle.concat([_expand_heads(cache_k), key], axis=2)
    if cache_v is not None:
        value = paddle.concat([_expand_heads(cache_v), value], axis=2)

    logits = paddle.matmul(query, key, transpose_y=True) * scale
    weights = paddle.nn.functional.softmax(logits, -1)
    return paddle.matmul(weights.cast(value.dtype), value)
|
|
|
|
|
|
def block_cache_to_naive(cache_k, cache_v, bsz, block_tables, seq_len):
    """Gather the paged (block-table) KV cache into dense [bsz, heads, seq, dim] tensors."""
    _, num_head, blocksize, dim_head = cache_k.shape
    dense_k = paddle.zeros([bsz, num_head, seq_len, dim_head], dtype=cache_k.dtype)
    dense_v = paddle.zeros([bsz, num_head, seq_len, dim_head], dtype=cache_v.dtype)
    for b in range(bsz):
        for pos in range(seq_len):
            # Logical position -> (logical block, offset within block).
            blk, off = divmod(pos, blocksize)
            physical = block_tables[b, blk]
            dense_k[b, :, pos, :] = cache_k[physical, :, off, :]
            dense_v[b, :, pos, :] = cache_v[physical, :, off, :]
    return dense_k, dense_v
|
|
|
|
|
|
def run_c16_warp14_decoder_test(
    batch_size,
    q_num_head,
    kv_num_head,
    dim_head,
    blocksize,
    prefill_seq_len,
    max_dec_len,
    dtype,
    decoder_max_partition_size,
    num_decode_runs,
):
    """
    Run encoder prefill + N decoder runs.
    Returns (list of decoder output numpy arrays, naive reference numpy array).

    Args:
        batch_size: number of sequences in the batch.
        q_num_head: number of query heads.
        kv_num_head: number of key/value heads (GQA when < q_num_head).
        dim_head: per-head hidden size.
        blocksize: KV cache page (block) size, in tokens.
        prefill_seq_len: encoder-phase sequence length (0 = decoder-only).
        max_dec_len: decode budget used to size caches and block tables.
        dtype: "bfloat16" or "float16".
        decoder_max_partition_size: decoder KV chunk size; small values make
            num_chunks > 1, exercising the split/nosplit dispatch under test.
        num_decode_runs: number of identical decoder runs collected for the
            determinism check.
    """
    # Fixed seeds so every call builds identical inputs.
    np.random.seed(SEED)
    paddle.seed(SEED)

    max_seq_len = prefill_seq_len + max_dec_len
    block_per_seq = (max_seq_len + blocksize - 1) // blocksize  # ceil division
    max_block_num = block_per_seq * batch_size
    scale = 1.0 / np.sqrt(dim_head)
    group_size = q_num_head // kv_num_head
    compute_type = "bf16" if dtype == "bfloat16" else "fp16"

    rope_emb = make_rope_emb(max_seq_len, dim_head)

    # Block tables: assign physical blocks 0..max_block_num-1 contiguously
    # (free_list is reversed so pop() hands them out in ascending order).
    free_list = list(range(max_block_num - 1, -1, -1))
    block_tables = paddle.zeros((batch_size, block_per_seq), dtype="int32")
    for i in range(batch_size):
        for j in range(block_per_seq):
            block_tables[i, j] = free_list.pop()

    cache_k = paddle.zeros((max_block_num, kv_num_head, blocksize, dim_head), dtype=dtype)
    cache_v = paddle.zeros((max_block_num, kv_num_head, blocksize, dim_head), dtype=dtype)

    # Tile metadata buffers, sized per allocate_launch_related_buffer() formula
    gqa_ratio = q_num_head // kv_num_head
    decode_tile_size = int(1024 * batch_size * np.ceil((2 * gqa_ratio) / DECODER_BLOCK_SHAPE_Q))
    encode_tile_size = max(batch_size, batch_size * (max_seq_len * gqa_ratio // ENCODER_BLOCK_SHAPE_Q))
    kv_tile_size = max(batch_size, batch_size * (max_seq_len // blocksize))

    dec_batch_ids = paddle.full([decode_tile_size], 0, dtype="int32")
    dec_tile_ids = paddle.full([decode_tile_size], 0, dtype="int32")
    dec_nblocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
    dec_nblocks_dev = paddle.full([1], 0, dtype="int32")
    dec_chunk_dev = paddle.full([1], decoder_max_partition_size, dtype="int32")
    max_len_cpu = paddle.full([8], 0, dtype="int32").cpu()
    enc_batch_ids = paddle.full([encode_tile_size], 0, dtype="int32")
    enc_tile_ids = paddle.full([encode_tile_size], 0, dtype="int32")
    enc_nblocks_cpu = paddle.full([1], 0, dtype="int32").cpu()
    kv_batch_ids = paddle.full([kv_tile_size], 0, dtype="int32")
    kv_tile_ids = paddle.full([kv_tile_size], 0, dtype="int32")
    kv_nblocks_cpu = paddle.full([1], 0, dtype="int32").cpu()

    # ===== Encoder phase =====
    # seq_lens_encoder > 0 and seq_lens_decoder == 0 selects prefill mode.
    seq_enc = paddle.full([batch_size], prefill_seq_len, dtype="int32")
    seq_dec = paddle.full([batch_size], 0, dtype="int32")
    seq_this = copy.deepcopy(seq_enc)
    bid_enc, cu_enc = get_padding_offset(batch_size, prefill_seq_len, seq_this)
    token_num = batch_size * prefill_seq_len

    # Small magnitudes (/10) keep softmax well-conditioned in low precision.
    q_np = np.random.random([batch_size, q_num_head, prefill_seq_len, dim_head]).astype("float32") / 10
    k_np = np.random.random([batch_size, kv_num_head, prefill_seq_len, dim_head]).astype("float32") / 10
    v_np = np.random.random([batch_size, kv_num_head, prefill_seq_len, dim_head]).astype("float32") / 10

    q = paddle.to_tensor(q_np, dtype=dtype)
    k = paddle.to_tensor(k_np, dtype=dtype)
    v = paddle.to_tensor(v_np, dtype=dtype)
    # Fused QKV layout: one row per token, head dims flattened and concatenated.
    qkv = paddle.concat(
        [
            q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * dim_head]),
            k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
            v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
        ],
        axis=1,
    )

    # Use large partition size for encoder to avoid issues in prefill
    encoder_partition_size = 32768

    get_block_shape_and_split_kv_block(
        seq_enc,
        seq_dec,
        seq_this,
        dec_batch_ids,
        dec_tile_ids,
        dec_nblocks_cpu,
        dec_nblocks_dev,
        dec_chunk_dev,
        max_len_cpu,
        enc_batch_ids,
        enc_tile_ids,
        enc_nblocks_cpu,
        kv_batch_ids,
        kv_tile_ids,
        kv_nblocks_cpu,
        ENCODER_BLOCK_SHAPE_Q,
        DECODER_BLOCK_SHAPE_Q,
        group_size,
        blocksize,
    )

    append_attention(
        qkv,
        cache_k,
        cache_v,
        seq_enc,
        seq_dec,
        seq_this,
        bid_enc,
        cu_enc,
        block_tables,
        enc_batch_ids,
        enc_tile_ids,
        enc_nblocks_cpu,
        kv_batch_ids,
        kv_tile_ids,
        kv_nblocks_cpu,
        dec_batch_ids,
        dec_tile_ids,
        dec_nblocks_cpu,
        max_len_cpu,
        rope_emb,
        # NOTE(review): the 16 Nones below are presumably optional inputs
        # (qkv bias/scale, cache quant params, masks, ...) unused by this
        # test -- confirm against the append_attention op signature.
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        1e-6,
        compute_type,
        "none",  # cache_quant_type: required to select the c16 kernel
        False,
        False,
        max_seq_len,
        0.0,
        0.0,
        -1,
        ENCODER_BLOCK_SHAPE_Q,
        DECODER_BLOCK_SHAPE_Q,
        encoder_partition_size,
        encoder_partition_size,
        2,
        True,
        False,
        0,
    )
    paddle.device.synchronize()

    # Extract naive KV cache for reference (already has RoPE applied by kernel)
    naive_ck, naive_cv = block_cache_to_naive(
        cache_k,
        cache_v,
        batch_size,
        block_tables,
        prefill_seq_len,
    )

    # ===== Decoder phase =====
    # seq_lens_encoder == 0 and seq_lens_decoder > 0 selects decode mode,
    # one new token per batch element.
    seq_enc_d = paddle.full([batch_size], 0, dtype="int32")
    seq_dec_d = paddle.full([batch_size], prefill_seq_len, dtype="int32")
    seq_this_d = paddle.full([batch_size], 1, dtype="int32")
    bid_dec, cu_dec = get_padding_offset(batch_size, 1, seq_this_d)

    dq_np = np.random.random([batch_size, q_num_head, 1, dim_head]).astype("float32") / 10
    dk_np = np.random.random([batch_size, kv_num_head, 1, dim_head]).astype("float32") / 10
    dv_np = np.random.random([batch_size, kv_num_head, 1, dim_head]).astype("float32") / 10
    dq = paddle.to_tensor(dq_np, dtype=dtype)
    dk = paddle.to_tensor(dk_np, dtype=dtype)
    dv = paddle.to_tensor(dv_np, dtype=dtype)
    dec_qkv = paddle.concat(
        [
            dq.transpose([0, 2, 1, 3]).reshape([batch_size, q_num_head * dim_head]),
            dk.transpose([0, 2, 1, 3]).reshape([batch_size, kv_num_head * dim_head]),
            dv.transpose([0, 2, 1, 3]).reshape([batch_size, kv_num_head * dim_head]),
        ],
        axis=1,
    )

    # Warmup: first decoder call on multi-chunk path may return zeros
    # due to kernel JIT compilation. Run once and discard.
    get_block_shape_and_split_kv_block(
        seq_enc_d,
        seq_dec_d,
        seq_this_d,
        dec_batch_ids,
        dec_tile_ids,
        dec_nblocks_cpu,
        dec_nblocks_dev,
        dec_chunk_dev,
        max_len_cpu,
        enc_batch_ids,
        enc_tile_ids,
        enc_nblocks_cpu,
        kv_batch_ids,
        kv_tile_ids,
        kv_nblocks_cpu,
        ENCODER_BLOCK_SHAPE_Q,
        DECODER_BLOCK_SHAPE_Q,
        group_size,
        blocksize,
    )
    append_attention(
        # Clones keep the real caches/QKV pristine for the measured runs.
        dec_qkv.clone(),
        cache_k.clone(),
        cache_v.clone(),
        seq_enc_d,
        seq_dec_d,
        seq_this_d,
        bid_dec,
        cu_dec,
        block_tables,
        enc_batch_ids,
        enc_tile_ids,
        enc_nblocks_cpu,
        kv_batch_ids,
        kv_tile_ids,
        kv_nblocks_cpu,
        dec_batch_ids,
        dec_tile_ids,
        dec_nblocks_cpu,
        max_len_cpu,
        rope_emb,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        1e-6,
        compute_type,
        "none",
        False,
        False,
        max_seq_len,
        0.0,
        0.0,
        -1,
        ENCODER_BLOCK_SHAPE_Q,
        DECODER_BLOCK_SHAPE_Q,
        decoder_max_partition_size,
        encoder_partition_size,
        2,
        True,
        False,
        0,
    )
    paddle.device.synchronize()

    results = []
    for _ in range(num_decode_runs):
        # Fresh clones per run so the kernel mutates identical inputs each time.
        cache_k_c = cache_k.clone()
        cache_v_c = cache_v.clone()
        qkv_c = dec_qkv.clone()

        get_block_shape_and_split_kv_block(
            seq_enc_d,
            seq_dec_d,
            seq_this_d,
            dec_batch_ids,
            dec_tile_ids,
            dec_nblocks_cpu,
            dec_nblocks_dev,
            dec_chunk_dev,
            max_len_cpu,
            enc_batch_ids,
            enc_tile_ids,
            enc_nblocks_cpu,
            kv_batch_ids,
            kv_tile_ids,
            kv_nblocks_cpu,
            ENCODER_BLOCK_SHAPE_Q,
            DECODER_BLOCK_SHAPE_Q,
            group_size,
            blocksize,
        )

        out = append_attention(
            qkv_c,
            cache_k_c,
            cache_v_c,
            seq_enc_d,
            seq_dec_d,
            seq_this_d,
            bid_dec,
            cu_dec,
            block_tables,
            enc_batch_ids,
            enc_tile_ids,
            enc_nblocks_cpu,
            kv_batch_ids,
            kv_tile_ids,
            kv_nblocks_cpu,
            dec_batch_ids,
            dec_tile_ids,
            dec_nblocks_cpu,
            max_len_cpu,
            rope_emb,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            1e-6,
            compute_type,
            "none",
            False,
            False,
            max_seq_len,
            0.0,
            0.0,
            -1,
            ENCODER_BLOCK_SHAPE_Q,
            DECODER_BLOCK_SHAPE_Q,
            decoder_max_partition_size,
            encoder_partition_size,
            2,
            True,
            False,
            0,
        )
        paddle.device.synchronize()
        results.append(out.numpy().copy())

    # Naive reference: apply RoPE to decoder Q/K at position prefill_seq_len
    # (cached K/V already have RoPE applied by the kernel during encoder phase)
    dq_rope = apply_rope(dq, rope_emb, [prefill_seq_len])
    dk_rope = apply_rope(dk, rope_emb, [prefill_seq_len])
    ref = naive_attention_impl(dq_rope, dk_rope, dv, naive_ck, naive_cv, scale)
    ref_np = ref.transpose([0, 2, 1, 3]).reshape([batch_size, q_num_head * dim_head]).numpy()

    return results, ref_np
|
|
|
|
|
|
class TestC16Warp14Determinism(unittest.TestCase):
    """
    Test the c16 warp1_4 decoder kernel under FD_DETERMINISTIC_MODE=1.

    Every case checks two properties:
    1. Correctness: output matches the naive attention reference (rtol/atol=1e-2)
    2. Determinism: repeated runs on identical input are bitwise identical
    """

    # Configuration shared by all cases; each test overrides what it varies.
    _BASE_CFG = dict(
        batch_size=1,
        q_num_head=16,
        kv_num_head=2,
        dim_head=128,
        blocksize=64,
        max_dec_len=32,
        dtype="bfloat16",
        decoder_max_partition_size=64,
        num_decode_runs=10,
    )

    def _run(self, **overrides):
        """Run the decoder scenario with _BASE_CFG merged with per-test overrides."""
        cfg = dict(self._BASE_CFG, **overrides)
        return run_c16_warp14_decoder_test(**cfg)

    def test_short_kv_nosplit(self):
        """num_chunks=1 (short KV): basic nosplit path, partition_kv=false template."""
        results, ref = self._run(prefill_seq_len=64, decoder_max_partition_size=32768)
        _assert_deterministic_and_correct(results, ref)

    def test_long_kv_multi_chunk(self):
        """
        num_chunks=4 (prefill=256, partition=64): the exact scenario the fix addresses.
        partition_kv=true template but grid_chunks=1 (deterministic).
        """
        results, ref = self._run(prefill_seq_len=256)
        _assert_deterministic_and_correct(results, ref)

    def test_multi_batch(self):
        """Multiple batches with multi-chunk decoder."""
        results, ref = self._run(batch_size=4, q_num_head=8, prefill_seq_len=256)
        _assert_deterministic_and_correct(results, ref)

    def test_float16(self):
        """Float16 dtype with multi-chunk decoder; fix must be dtype-agnostic."""
        results, ref = self._run(prefill_seq_len=256, dtype="float16")
        _assert_deterministic_and_correct(results, ref)

    def test_unaligned_seq_len(self):
        """prefill_seq_len not divisible by blocksize (100 % 64 != 0)."""
        results, ref = self._run(prefill_seq_len=100)
        _assert_deterministic_and_correct(results, ref)

    def test_mha_no_gqa(self):
        """MHA: q_num_head == kv_num_head (no GQA grouping)."""
        results, ref = self._run(q_num_head=8, kv_num_head=8, prefill_seq_len=128)
        _assert_deterministic_and_correct(results, ref)

    def test_nosplit_vs_split_consistency(self):
        """
        Cross-path check: force_no_partition (deterministic) vs split (partitioned)
        should produce numerically close results.

        The two paths differ only in floating-point accumulation order:
        - nosplit: single chunk, sequential accumulation
        - split: multi-chunk parallel, then merge
        Difference should be within low-precision rounding (rtol/atol=1e-2 for bf16).

        Note: envs.FD_DETERMINISTIC_MODE is lazily evaluated via __getattr__,
        so runtime changes to os.environ["FD_DETERMINISTIC_MODE"] take effect.
        """
        # Deterministic (nosplit) run; module-level env var is already "1".
        det_results, det_ref = self._run(prefill_seq_len=256, num_decode_runs=1)

        # Non-deterministic (split/partitioned) run.
        os.environ["FD_DETERMINISTIC_MODE"] = "0"
        try:
            split_results, split_ref = self._run(prefill_seq_len=256, num_decode_runs=1)
        finally:
            os.environ["FD_DETERMINISTIC_MODE"] = "1"

        # Both paths should match the naive reference.
        np.testing.assert_allclose(
            det_results[0], det_ref, rtol=1e-2, atol=1e-2, err_msg="Deterministic path vs naive reference"
        )
        np.testing.assert_allclose(
            split_results[0], split_ref, rtol=1e-2, atol=1e-2, err_msg="Split path vs naive reference"
        )

        # Cross-path: nosplit vs split should be close.
        np.testing.assert_allclose(
            det_results[0],
            split_results[0],
            rtol=1e-2,
            atol=1e-2,
            err_msg="Deterministic (nosplit) vs split path divergence",
        )

    def test_partition_boundary(self):
        """
        Edge case: prefill_seq_len equals partition_size (boundary condition).

        num_chunks = prefill_seq_len / partition_size is exactly an integer,
        a boundary condition for the chunk calculation.
        """
        results, ref = self._run(prefill_seq_len=128, decoder_max_partition_size=128)
        _assert_deterministic_and_correct(results, ref)

    def test_empty_kv(self):
        """
        Edge case: decoder-only mode (no prefill, empty KV cache).

        There is no encoder prefill phase at all; only decoding against an
        empty cache.
        """
        results, _ = self._run(prefill_seq_len=0)
        # For empty KV, we only check determinism as naive reference may not be valid
        for i in range(1, len(results)):
            np.testing.assert_array_equal(
                results[0],
                results[i],
                err_msg=f"Determinism failure: run 0 vs run {i}",
            )
|
|
|
|
|
|
# Allow running the suite directly (outside pytest): python <this file>
if __name__ == "__main__":
    unittest.main()
|