mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
3fabba0dc7
* [Feature] Add Triton unified attention kernel for deterministic inference Add a Triton-based unified extend attention kernel that processes both prefix (cached) and extend (new) KV tokens through a single kernel with unified kv_indices, ensuring identical accumulation order regardless of cache hit/miss patterns. Key components: - _fwd_kernel_unified: Triton JIT kernel with online softmax, paged KV cache support, and causal masking for prefix+extend - Index building utilities: triton_cumsum_with_zero_prefix, build_kv_indices_from_block_tables, build_unified_kv_indices, _scatter_extend_kv_indices_kernel (all CUDA Graph compatible) - pre_cache_len_concat_triton: GPU-only replacement for C++ op - Reference implementations (_ref variants) for correctness validation - Comprehensive tests: kernel correctness, split invariance, determinism, production-scale, cross-validation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Vectorize causal mask in test references for ~26x speedup Replace triple Python for-loop with paddle.where vectorized mask in naive_attention and _build_causal_mask. seq4096 test: 2m39s -> 6s. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix cover --------- Co-authored-by: gongweibao <gognweibao@baidu.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
276 lines
11 KiB
Python
276 lines
11 KiB
Python
"""
|
|
Tests for CUDA-Graph-compatible Triton index building kernels.
|
|
|
|
Compares the new Triton-based implementations (no .item() calls) against
|
|
the reference Python for-loop implementations to verify correctness.
|
|
"""
|
|
|
|
import numpy as np
import paddle
import pytest

from fastdeploy.model_executor.layers.attention.triton_ops.unified_extend_attention import (
    _scatter_extend_kv_indices_kernel,
    build_kv_indices_from_block_tables,
    build_kv_indices_from_block_tables_ref,
    pre_cache_len_concat_ref,
    pre_cache_len_concat_triton,
)
|
|
# ---------------------------------------------------------------------------
# Test: build_kv_indices_from_block_tables (Triton) vs _ref (Python for-loop)
# ---------------------------------------------------------------------------


class TestBuildKvIndicesFromBlockTables:
    """Check that the Triton KV-index builder matches the Python reference loop."""

    @staticmethod
    def _make_block_tables(bs, max_blocks_per_seq):
        """Build a random block table whose physical block IDs are all distinct."""
        # A permutation of 0..bs*max_blocks-1 guarantees no two slots collide.
        num_slots = bs * max_blocks_per_seq
        perm = np.random.permutation(num_slots).reshape(bs, max_blocks_per_seq)
        return paddle.to_tensor(perm, dtype="int32")

    @pytest.mark.parametrize("block_size", [16, 64])
    @pytest.mark.parametrize(
        "bs, seq_lens_list",
        [
            (1, [10]),  # single sequence
            (1, [64]),  # exactly one block
            (3, [10, 20, 30]),  # multiple sequences
            (4, [0, 15, 0, 7]),  # sequences with zero length
            (2, [128, 256]),  # multi-block sequences
            (1, [1]),  # single token
            (5, [1, 1, 1, 1, 1]),  # all single-token (decode)
        ],
    )
    def test_matches_ref(self, block_size, bs, seq_lens_list):
        """Triton kernel output must be bit-identical to the reference."""
        blocks_needed = [(s + block_size - 1) // block_size for s in seq_lens_list]
        max_blocks_per_seq = max(max(blocks_needed), 1)
        block_tables = self._make_block_tables(bs, max_blocks_per_seq)
        seq_lens = paddle.to_tensor(seq_lens_list, dtype="int32")
        total_kv_len = sum(seq_lens_list)

        # Reference (Python for-loop) result.
        indptr_ref, indices_ref = build_kv_indices_from_block_tables_ref(block_tables, seq_lens, block_size, bs)

        # Triton result with a pre-computed total_kv_len — the CUDA Graph path.
        indptr_new, indices_new = build_kv_indices_from_block_tables(
            block_tables, seq_lens, block_size, bs, total_kv_len=total_kv_len
        )

        np.testing.assert_array_equal(indptr_new.numpy(), indptr_ref.numpy(), err_msg="kv_indptr mismatch")
        if total_kv_len > 0:
            np.testing.assert_array_equal(
                indices_new[:total_kv_len].numpy(),
                indices_ref[:total_kv_len].numpy(),
                err_msg="kv_indices mismatch",
            )

    @pytest.mark.parametrize("block_size", [16, 64])
    def test_auto_total_kv_len(self, block_size):
        """With total_kv_len=None the builder takes the .item() fallback (non-graph path)."""
        bs = 3
        seq_lens_list = [10, 20, 30]
        max_blocks_per_seq = max((s + block_size - 1) // block_size for s in seq_lens_list)
        block_tables = self._make_block_tables(bs, max_blocks_per_seq)
        seq_lens = paddle.to_tensor(seq_lens_list, dtype="int32")

        indptr_ref, indices_ref = build_kv_indices_from_block_tables_ref(block_tables, seq_lens, block_size, bs)
        indptr_new, indices_new = build_kv_indices_from_block_tables(
            block_tables, seq_lens, block_size, bs, total_kv_len=None
        )

        total = sum(seq_lens_list)
        np.testing.assert_array_equal(indptr_new.numpy(), indptr_ref.numpy())
        np.testing.assert_array_equal(indices_new[:total].numpy(), indices_ref[:total].numpy())
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Test: _scatter_extend_kv_indices_kernel
# ---------------------------------------------------------------------------


class TestScatterExtendKvIndices:
    """Exercise the Triton scatter kernel and check it against a pure-Python loop."""

    @staticmethod
    def _scatter_ref(all_kv_indices, all_kv_indptr, prefix_lens, extend_start_loc, extend_seq_lens, bs):
        """Reference scatter: copy each sequence's extend slice to its output slot."""
        total_extend = int(paddle.sum(extend_seq_lens).item())
        out = paddle.empty([max(total_extend, 1)], dtype="int32")
        for seq_id in range(bs):
            prefix_len = int(prefix_lens[seq_id].item())
            extend_len = int(extend_seq_lens[seq_id].item())
            if extend_len == 0:
                continue
            # Source starts just past the cached prefix inside the unified index array.
            src = int(all_kv_indptr[seq_id].item()) + prefix_len
            dst = int(extend_start_loc[seq_id].item())
            out[dst : dst + extend_len] = all_kv_indices[src : src + extend_len]
        return out

    @pytest.mark.parametrize(
        "bs, prefix_list, extend_list",
        [
            (1, [10], [5]),  # single seq
            (3, [10, 20, 30], [5, 3, 8]),  # multi seq
            (4, [0, 15, 0, 7], [3, 2, 0, 1]),  # mixed zero/non-zero
            (2, [100, 200], [1, 1]),  # decode-like (extend=1)
            (5, [0, 0, 0, 0, 0], [10, 20, 30, 40, 50]),  # all prefill, no prefix
        ],
    )
    def test_matches_ref(self, bs, prefix_list, extend_list):
        """The Triton scatter kernel must reproduce the Python reference exactly."""
        prefix_lens = paddle.to_tensor(prefix_list, dtype="int32")
        extend_seq_lens = paddle.to_tensor(extend_list, dtype="int32")
        total_seq_lens = prefix_lens + extend_seq_lens

        # Fake monotonically increasing KV indices: 0, 1, 2, ...
        zero = paddle.zeros([1], dtype="int32")
        all_kv_indptr = paddle.concat([zero, paddle.cumsum(total_seq_lens).astype("int32")])
        total_all = int(paddle.sum(total_seq_lens).item())
        all_kv_indices = paddle.arange(total_all, dtype="int32")

        # Exclusive prefix-sum of extend lengths gives each sequence's output offset.
        if bs > 1:
            extend_start_loc = paddle.concat([zero, paddle.cumsum(extend_seq_lens[:-1]).astype("int32")])
        else:
            extend_start_loc = paddle.zeros([1], dtype="int32")

        total_extend = sum(extend_list)

        # Reference result.
        expected = self._scatter_ref(
            all_kv_indices, all_kv_indptr, prefix_lens, extend_start_loc, extend_seq_lens, bs
        )

        # Triton result.
        actual = paddle.empty([max(total_extend, 1)], dtype="int32")
        if bs > 0 and total_extend > 0:
            _scatter_extend_kv_indices_kernel[(bs,)](
                all_kv_indices,
                all_kv_indptr,
                prefix_lens,
                extend_start_loc,
                extend_seq_lens,
                actual,
                BLOCK=128,
            )

        if total_extend > 0:
            np.testing.assert_array_equal(
                actual[:total_extend].numpy(),
                expected[:total_extend].numpy(),
                err_msg="scatter extend kv indices mismatch",
            )
|
|
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): this guard appears before TestPreCacheLenConcat is defined;
    # pytest.main() re-collects this file by path, so that class's tests still
    # run, but moving the guard to the end of the file would be less surprising.
    pytest.main([__file__, "-v", "-x"])
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Test: pre_cache_len_concat_triton vs pre_cache_len_concat_ref
# ---------------------------------------------------------------------------


class TestPreCacheLenConcat:
    """Compare the GPU-only Triton pre_cache_len_concat with the Python reference."""

    @pytest.mark.parametrize("block_size", [16, 64, 128])
    @pytest.mark.parametrize(
        "bsz, enc_list, dec_list, qlen_list",
        [
            (3, [0, 0, 0], [50, 100, 200], [1, 1, 1]),  # pure decode: all enc=0, so cache_len=0
            (2, [10, 20], [0, 0], [10, 20]),  # pure prefill: enc>0, cache_len = dec (chunked prefill)
            (4, [10, 0, 5, 0], [32, 100, 64, 200], [10, 1, 5, 1]),  # mixed prefill/decode
            (1, [1], [128], [1]),  # single batch
            (3, [0, 0, 0], [0, 0, 0], [5, 3, 7]),  # all-zero enc/dec edge case
            (2, [1, 1], [512, 1024], [32, 64]),  # large cache_len spanning many blocks
            (5, [0, 0, 0, 0, 0], [10, 20, 30, 40, 50], [1, 1, 1, 1, 1]),  # single-token decode
            (4, [5, 0, 0, 10], [100, 0, 50, 200], [5, 1, 1, 10]),  # mixed zero and non-zero enc
        ],
    )
    def test_matches_ref(self, block_size, bsz, enc_list, dec_list, qlen_list):
        """Triton pre_cache_len_concat must agree exactly with the reference."""
        seq_lens_encoder = paddle.to_tensor(enc_list, dtype="int32")
        seq_lens_decoder = paddle.to_tensor(dec_list, dtype="int32")
        seq_lens_this_time = paddle.to_tensor(qlen_list, dtype="int32")

        longest_dec = max(dec_list) if dec_list else 0
        max_tile_per_bs = max((longest_dec + block_size - 1) // block_size, 1)

        common_args = (
            seq_lens_encoder,
            seq_lens_decoder,
            seq_lens_this_time,
            bsz,
            block_size,
            max_tile_per_bs,
        )
        # Reference vs Triton, identical arguments.
        cu_ref, batch_ids_ref, tile_ids_ref = pre_cache_len_concat_ref(*common_args)
        cu_tri, batch_ids_tri, tile_ids_tri = pre_cache_len_concat_triton(*common_args)

        # cu_seqlens_k must match exactly.
        np.testing.assert_array_equal(cu_tri.numpy(), cu_ref.numpy(), err_msg="cu_seqlens_k mismatch")

        # Total tile count bounds the valid prefix of batch_ids/tile_ids.
        # Only sequences with enc > 0 carry a cache (cache_len = dec).
        gridx = sum(
            ((dec if enc > 0 else 0) + block_size - 1) // block_size
            for enc, dec in zip(enc_list, dec_list)
        )

        if gridx > 0:
            np.testing.assert_array_equal(
                batch_ids_tri[:gridx].numpy(),
                batch_ids_ref[:gridx].numpy(),
                err_msg="batch_ids mismatch",
            )
            np.testing.assert_array_equal(
                tile_ids_tri[:gridx].numpy(),
                tile_ids_ref[:gridx].numpy(),
                err_msg="tile_ids mismatch",
            )

    def test_empty_batch(self):
        """Edge case: bsz=0 should still produce cu_seqlens_k == [0]."""
        cu, _, _ = pre_cache_len_concat_triton(
            paddle.to_tensor([], dtype="int32"),
            paddle.to_tensor([], dtype="int32"),
            paddle.to_tensor([], dtype="int32"),
            0,
            64,
            1,
        )
        assert cu.shape == [1]
        assert int(cu[0].item()) == 0
|