[Feature] Add Triton unified attention kernel for deterministic inference (#6795)

* [Feature] Add Triton unified attention kernel for deterministic inference

Add a Triton-based unified extend attention kernel that processes both
prefix (cached) and extend (new) KV tokens through a single kernel with
unified kv_indices, ensuring identical accumulation order regardless of
cache hit/miss patterns.

Key components:
- _fwd_kernel_unified: Triton JIT kernel with online softmax, paged KV
  cache support, and causal masking for prefix+extend
- Index building utilities: triton_cumsum_with_zero_prefix,
  build_kv_indices_from_block_tables, build_unified_kv_indices,
  _scatter_extend_kv_indices_kernel (all CUDA Graph compatible)
- pre_cache_len_concat_triton: GPU-only replacement for C++ op
- Reference implementations (_ref variants) for correctness validation
- Comprehensive tests: kernel correctness, split invariance,
  determinism, production-scale, cross-validation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Vectorize causal mask in test references for ~26x speedup

Replace triple Python for-loop with paddle.where vectorized mask in
naive_attention and _build_causal_mask. seq4096 test: 2m39s -> 6s.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix cover

---------

Co-authored-by: gongweibao <gognweibao@baidu.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
gongweibao
2026-03-16 14:29:45 +08:00
committed by GitHub
parent 7c8c0a3c02
commit 3fabba0dc7
4 changed files with 2075 additions and 0 deletions
@@ -0,0 +1,19 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from
# https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
# Licensed under Apache License 2.0
"""
@@ -0,0 +1,726 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from
# https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
# Licensed under Apache License 2.0
#
# Modified by FastDeploy team for deterministic mode support with prefix caching.
# Key adaptation: FastDeploy uses paged KV cache [num_blocks, kv_heads, block_size, head_dim],
# while SGLang uses flat KV cache [total_tokens, kv_heads, head_dim].
"""
import paddle
import triton
import triton.language as tl
# ---------------------------------------------------------------------------
# Triton cumsum (CUDA Graph compatible, replaces paddle.cumsum / thrust)
# ---------------------------------------------------------------------------
@triton.jit  # pragma: no cover
def _cumsum_with_zero_prefix_kernel(input_ptr, output_ptr, n, BLOCK: tl.constexpr):
    """
    Write output[0] = 0 and output[1:n+1] = cumsum(input[0:n]).

    Launched as a single program; requires n <= BLOCK so the whole input
    fits in one tl.arange vector (the launcher picks BLOCK = next_power_of_2(n)).
    """
    # Leading zero: makes the result an indptr-style prefix with output[0] == 0.
    tl.store(output_ptr, 0)
    idx = tl.arange(0, BLOCK)
    mask = idx < n
    # Out-of-range lanes load 0 so they do not perturb the cumulative sum.
    val = tl.load(input_ptr + idx, mask=mask, other=0).to(tl.int32)
    cumval = tl.cumsum(val, axis=0)
    # Inclusive cumsum lands at positions 1..n, after the leading zero.
    tl.store(output_ptr + 1 + idx, cumval, mask=mask)
def triton_cumsum_with_zero_prefix(x, n=None, out_buf=None):
    """
    Triton replacement for: paddle.concat([zeros([1], int32), cumsum(x).astype(int32)]).

    Returns an int32 tensor of shape [n+1] whose first element is 0 and whose
    remaining entries are the running sums of x[0:n]. CUDA Graph capture
    compatible.

    Args:
        x: 1-D integer tensor to accumulate.
        n: Number of leading elements of x to use; defaults to x.shape[0].
        out_buf: Optional pre-allocated buffer of size >= [n+1]. If provided,
            writes into out_buf[:n+1] instead of allocating.
    """
    if n is None:
        n = x.shape[0]
    if out_buf is None:
        out = paddle.empty([n + 1], dtype="int32")
    else:
        out = out_buf[: n + 1]
    if n == 0:
        # Nothing to accumulate: only the leading zero is written.
        out[0:1] = paddle.zeros([1], dtype="int32")
        return out
    _cumsum_with_zero_prefix_kernel[(1,)](x, out, n, BLOCK=triton.next_power_of_2(n))
    return out
# ---------------------------------------------------------------------------
# Triton helper kernels (allocation-free elementwise ops for CUDA Graph)
# ---------------------------------------------------------------------------
@triton.jit  # pragma: no cover
def _indptr_to_lens_kernel(indptr_ptr, lens_ptr, n, BLOCK: tl.constexpr):
    """Compute lens[i] = indptr[i+1] - indptr[i] for i in [0, n).

    Single program; requires n <= BLOCK (launcher uses next_power_of_2(n)).
    Inverse of the zero-prefixed cumsum: recovers segment lengths from a
    CSR-style indptr.
    """
    idx = tl.arange(0, BLOCK)
    mask = idx < n
    a = tl.load(indptr_ptr + idx + 1, mask=mask, other=0)
    b = tl.load(indptr_ptr + idx, mask=mask, other=0)
    tl.store(lens_ptr + idx, a - b, mask=mask)
@triton.jit  # pragma: no cover
def _elementwise_add_kernel(a_ptr, b_ptr, out_ptr, n, BLOCK: tl.constexpr):
    """Compute out[i] = a[i] + b[i] for i in [0, n).

    Single program; requires n <= BLOCK. Exists so the add happens on-device
    without allocation, keeping the call CUDA Graph capture compatible.
    """
    idx = tl.arange(0, BLOCK)
    mask = idx < n
    a = tl.load(a_ptr + idx, mask=mask, other=0)
    b = tl.load(b_ptr + idx, mask=mask, other=0)
    tl.store(out_ptr + idx, a + b, mask=mask)
# ---------------------------------------------------------------------------
# Index building utilities
# ---------------------------------------------------------------------------
@triton.jit  # pragma: no cover
def _copy_unified_indices_kernel(
    prefix_kv_indptr,
    prefix_kv_indices,
    extend_start_loc,
    extend_seq_lens,
    extend_kv_indices,
    unified_kv_indptr,
    unified_kv_indices,
    bs,
):
    """
    Copy prefix and extend KV indices into a unified buffer.

    One program per sequence, internal loops for vectorized copy.
    For sequence `pid` the destination layout is:
        unified_kv_indices[unified_kv_indptr[pid] : +prefix_len]   <- prefix part
        unified_kv_indices[... + prefix_len : +extend_len]         <- extend part
    """
    pid = tl.program_id(0)
    if pid >= bs:
        return
    # Per-sequence source ranges and destination offset.
    prefix_start = tl.load(prefix_kv_indptr + pid)
    prefix_end = tl.load(prefix_kv_indptr + pid + 1)
    extend_start = tl.load(extend_start_loc + pid)
    extend_len = tl.load(extend_seq_lens + pid)
    prefix_len = prefix_end - prefix_start
    unified_start = tl.load(unified_kv_indptr + pid)
    BLOCK_SIZE: tl.constexpr = 128
    # Copy prefix indices in BLOCK_SIZE-wide chunks
    for block_start in range(0, prefix_len, BLOCK_SIZE):
        offs = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offs < prefix_len
        vals = tl.load(prefix_kv_indices + prefix_start + offs, mask=mask, other=0)
        tl.store(unified_kv_indices + unified_start + offs, vals, mask=mask)
    # Copy extend indices immediately after the prefix portion
    for block_start in range(0, extend_len, BLOCK_SIZE):
        offs = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offs < extend_len
        vals = tl.load(extend_kv_indices + extend_start + offs, mask=mask, other=0)
        tl.store(unified_kv_indices + unified_start + prefix_len + offs, vals, mask=mask)
def build_unified_kv_indices(
    prefix_kv_indptr,
    prefix_kv_indices,
    extend_start_loc,
    extend_seq_lens,
    extend_kv_indices,
    bs,
    unified_kv_indptr_buf=None,
    unified_kv_indices_buf=None,
    prefix_lens_buf=None,
    unified_lens_buf=None,
):
    """
    Build unified KV indices from prefix and extend parts.

    Uses Triton cumsum (CUDA Graph compatible) for indptr, Triton kernel for
    index copy.

    Args:
        prefix_kv_indptr: [bs+1] CSR indptr of the prefix (cached) indices.
        prefix_kv_indices: flat prefix token indices.
        extend_start_loc: [bs] start offset of each sequence in extend_kv_indices.
        extend_seq_lens: [bs] extend (new-token) length per sequence.
        extend_kv_indices: flat extend token indices.
        bs: batch size.
        *_buf: Optional pre-allocated buffers for CUDA Graph compatibility.
            When None (default), tensors are allocated dynamically
            (backward-compatible).

    Returns:
        (unified_kv_indptr, unified_kv_indices, prefix_lens)
    """
    # prefix_lens[i] = prefix_kv_indptr[i+1] - prefix_kv_indptr[i]
    if prefix_lens_buf is not None and bs > 0:
        prefix_lens = prefix_lens_buf[:bs]
        BLOCK = triton.next_power_of_2(bs)
        _indptr_to_lens_kernel[(1,)](prefix_kv_indptr, prefix_lens, bs, BLOCK=BLOCK)
    else:
        prefix_lens = prefix_kv_indptr[1 : bs + 1] - prefix_kv_indptr[:bs]
    # unified_lens[i] = prefix_lens[i] + extend_seq_lens[i]
    if unified_lens_buf is not None and bs > 0:
        unified_lens = unified_lens_buf[:bs]
        BLOCK = triton.next_power_of_2(bs)
        _elementwise_add_kernel[(1,)](prefix_lens, extend_seq_lens, unified_lens, bs, BLOCK=BLOCK)
    else:
        unified_lens = prefix_lens + extend_seq_lens[:bs]
    unified_kv_indptr = triton_cumsum_with_zero_prefix(unified_lens, bs, out_buf=unified_kv_indptr_buf)
    total_len = prefix_kv_indices.shape[0] + extend_kv_indices.shape[0]
    if unified_kv_indices_buf is not None:
        unified_kv_indices = unified_kv_indices_buf[:total_len]
    else:
        unified_kv_indices = paddle.empty([total_len], dtype="int32")
    # Fix: skip the copy for an empty batch -- a kernel launch needs a positive
    # grid, and a (0,) grid is invalid. This also matches the `bs > 0` guard in
    # build_kv_indices_from_block_tables.
    if bs > 0:
        _copy_unified_indices_kernel[(bs,)](
            prefix_kv_indptr,
            prefix_kv_indices,
            extend_start_loc,
            extend_seq_lens,
            extend_kv_indices,
            unified_kv_indptr,
            unified_kv_indices,
            bs,
        )
    return unified_kv_indptr, unified_kv_indices, prefix_lens
@triton.jit  # pragma: no cover
def _build_kv_indices_kernel(
    block_tables_ptr,
    seq_lens_ptr,
    kv_indptr_ptr,
    kv_indices_ptr,
    block_size: tl.constexpr,
    max_blocks_per_seq,
    BLOCK: tl.constexpr,
):
    """
    Build flat token-level KV indices from block_tables.

    One program per sequence.
    For token at position t in sequence s:
        physical_index = block_tables[s, t // block_size] * block_size + t % block_size
    Results are written to kv_indices[kv_indptr[s] : kv_indptr[s] + seq_lens[s]].
    """
    seq_id = tl.program_id(0)
    slen = tl.load(seq_lens_ptr + seq_id)
    dst_start = tl.load(kv_indptr_ptr + seq_id)
    # Row offset of this sequence in the flattened [bs, max_blocks_per_seq] table.
    row_base = seq_id * max_blocks_per_seq
    for off in range(0, slen, BLOCK):
        t = off + tl.arange(0, BLOCK)
        mask = t < slen
        # Split the token position into (table column, offset within block).
        block_idx = t // block_size
        in_block_off = t % block_size
        blk_id = tl.load(block_tables_ptr + row_base + block_idx, mask=mask, other=0)
        idx = blk_id * block_size + in_block_off
        tl.store(kv_indices_ptr + dst_start + t, idx, mask=mask)
@triton.jit  # pragma: no cover
def _scatter_extend_kv_indices_kernel(
    all_kv_indices_ptr,
    all_kv_indptr_ptr,
    prefix_lens_ptr,
    extend_start_loc_ptr,
    extend_seq_lens_ptr,
    out_ptr,
    BLOCK: tl.constexpr,
):
    """
    Scatter extend KV indices from all_kv_indices into a contiguous extend buffer.

    One program per sequence. Copies the extend portion (skipping the first
    prefix_lens[s] entries of the sequence's range) into
    out[extend_start_loc[s] : extend_start_loc[s] + extend_seq_lens[s]].
    """
    seq_id = tl.program_id(0)
    elen = tl.load(extend_seq_lens_ptr + seq_id)
    plen = tl.load(prefix_lens_ptr + seq_id)
    # Source starts just past the prefix portion of this sequence's indices.
    src_start = tl.load(all_kv_indptr_ptr + seq_id) + plen
    dst_start = tl.load(extend_start_loc_ptr + seq_id)
    for off in range(0, elen, BLOCK):
        idx = off + tl.arange(0, BLOCK)
        mask = idx < elen
        val = tl.load(all_kv_indices_ptr + src_start + idx, mask=mask, other=0)
        tl.store(out_ptr + dst_start + idx, val, mask=mask)
def build_kv_indices_from_block_tables(
    block_tables, seq_lens, block_size, bs, total_kv_len=None, kv_indptr_buf=None, kv_indices_buf=None
):
    """
    Convert FastDeploy's block_tables to flat token-level KV indices.

    CUDA Graph capture compatible (uses Triton cumsum, no thrust) as long as
    total_kv_len is supplied; when it is None, a host-side .item() reduction
    is performed, which is not capture-safe.

    Optional *_buf args: pre-allocated buffers for CUDA Graph compatibility.

    Returns:
        (kv_indptr, kv_indices): [bs+1] CSR indptr and flat token indices.
    """
    kv_indptr = triton_cumsum_with_zero_prefix(seq_lens[:bs], bs, out_buf=kv_indptr_buf)
    if total_kv_len is None:
        # Host-side reduction; only used on the non-captured path.
        total_kv_len = int(paddle.sum(seq_lens[:bs]).item())
    alloc_len = max(total_kv_len, 1)
    if kv_indices_buf is None:
        kv_indices = paddle.empty([alloc_len], dtype="int32")
    else:
        kv_indices = kv_indices_buf[:alloc_len]
    if bs > 0 and total_kv_len > 0:
        _build_kv_indices_kernel[(bs,)](
            block_tables,
            seq_lens,
            kv_indptr,
            kv_indices,
            block_size,
            block_tables.shape[1],
            BLOCK=128,
        )
    return kv_indptr, kv_indices
def build_kv_indices_from_block_tables_ref(block_tables, seq_lens, block_size, bs):
    """
    Reference (Python for-loop) implementation of build_kv_indices_from_block_tables.

    Uses .item() for GPU->CPU transfers, so it is NOT compatible with CUDA
    Graph capture. Kept for correctness validation against the Triton version.
    """
    zero = paddle.zeros([1], dtype="int32")
    kv_indptr = paddle.concat([zero, paddle.cumsum(seq_lens[:bs]).astype("int32")])
    total_kv_len = int(paddle.sum(seq_lens[:bs]).item())
    kv_indices = paddle.empty([max(total_kv_len, 1)], dtype="int32")
    for seq_id in range(bs):
        length = int(seq_lens[seq_id].item())
        if length == 0:
            continue
        dst = int(kv_indptr[seq_id].item())
        positions = paddle.arange(length, dtype="int32")
        # physical_index = block_tables[s, t // block_size] * block_size + t % block_size
        block_ids = block_tables[seq_id, positions // block_size]
        kv_indices[dst : dst + length] = block_ids * block_size + positions % block_size
    return kv_indptr, kv_indices
# ---------------------------------------------------------------------------
# Triton pre_cache_len_concat (CUDA Graph compatible, GPU-only)
# ---------------------------------------------------------------------------
@triton.jit  # pragma: no cover
def _pre_cache_cu_seqlens_kernel(
    seq_lens_encoder_ptr,
    seq_lens_decoder_ptr,
    seq_lens_this_time_ptr,
    cu_seqlens_k_ptr,
    cache_len_ptr,
    loop_times_ptr,
    bsz,
    block_size: tl.constexpr,
    BLOCK_BSZ: tl.constexpr,
):
    """
    Compute cu_seqlens_k, cache_len, and loop_times per batch.

    Single program, vectorized over bsz (requires bsz <= BLOCK_BSZ).
        cache_len[b]    = seq_lens_decoder[b] if seq_lens_encoder[b] > 0 else 0
        cu_seqlens_k    = [0] + inclusive cumsum of (cache_len + seq_lens_this_time)
        loop_times[b]   = ceil(cache_len[b] / block_size)
    """
    # cu_seqlens_k[0] is always 0.
    tl.store(cu_seqlens_k_ptr, 0)
    bid = tl.arange(0, BLOCK_BSZ)
    mask = bid < bsz
    enc = tl.load(seq_lens_encoder_ptr + bid, mask=mask, other=0)
    dec = tl.load(seq_lens_decoder_ptr + bid, mask=mask, other=0)
    qlen = tl.load(seq_lens_this_time_ptr + bid, mask=mask, other=0)
    # Cached prefix is only counted when the batch has encoder tokens.
    cache_len = tl.where(enc > 0, dec, 0)
    total_per_batch = cache_len + qlen
    cu_cumsum = tl.cumsum(total_per_batch, axis=0)
    tl.store(cu_seqlens_k_ptr + 1 + bid, cu_cumsum, mask=mask)
    tl.store(cache_len_ptr + bid, cache_len, mask=mask)
    # Ceiling division: number of block_size tiles covering the cached prefix.
    lt = (cache_len + block_size - 1) // block_size
    tl.store(loop_times_ptr + bid, lt, mask=mask)
@triton.jit  # pragma: no cover
def _pre_cache_scatter_kernel(
    loop_times_ptr,
    gridx_offset_ptr,
    batch_ids_ptr,
    tile_ids_per_batch_ptr,
    BLOCK: tl.constexpr,
):
    """
    Scatter batch_ids and tile_ids for one batch.

    One program per batch (grid = bsz). Writes loop_times[bid] entries
    starting at gridx_offset[bid]: batch_ids receives the batch id,
    tile_ids_per_batch receives 0 .. loop_times[bid]-1.
    """
    bid = tl.program_id(0)
    lt = tl.load(loop_times_ptr + bid)
    offset = tl.load(gridx_offset_ptr + bid)
    for off in range(0, lt, BLOCK):
        idx = off + tl.arange(0, BLOCK)
        m = idx < lt
        tl.store(batch_ids_ptr + offset + idx, bid, mask=m)
        tl.store(tile_ids_per_batch_ptr + offset + idx, idx, mask=m)
def pre_cache_len_concat_triton(
    seq_lens_encoder,
    seq_lens_decoder,
    seq_lens_this_time,
    bsz,
    block_size,
    max_tile_size_per_bs,
    cu_seqlens_k_buf=None,
    batch_ids_buf=None,
    tile_ids_buf=None,
    cache_len_buf=None,
    loop_times_buf=None,
    gridx_offset_buf=None,
):
    """
    GPU-only Triton replacement for the pre_cache_len_concat C++ op.

    No D2H copy; CUDA Graph capture compatible.

    Two-phase approach:
        Phase 1: vectorized kernel computes cu_seqlens_k, cache_len, loop_times.
        Phase 2: per-batch scatter kernel writes batch_ids and tile_ids.

    Optional *_buf args: pre-allocated buffers for CUDA Graph compatibility.

    Returns:
        cu_seqlens_k: [bsz+1] int32, GPU
        batch_ids: [bsz * max_tile_size_per_bs] int32, GPU
        tile_ids_per_batch: [bsz * max_tile_size_per_bs] int32, GPU
    """

    def _view_or_alloc(buf, size):
        # Slice the caller-provided buffer, or allocate when none was given.
        return buf[:size] if buf is not None else paddle.empty([size], dtype="int32")

    cu_seqlens_k = _view_or_alloc(cu_seqlens_k_buf, bsz + 1)
    out_size = max(bsz * max_tile_size_per_bs, 1)
    batch_ids = _view_or_alloc(batch_ids_buf, out_size)
    tile_ids = _view_or_alloc(tile_ids_buf, out_size)
    if bsz == 0:
        # Empty batch: only the leading zero of cu_seqlens_k is meaningful.
        cu_seqlens_k[0:1] = paddle.zeros([1], dtype="int32")
        return cu_seqlens_k, batch_ids, tile_ids
    _cache_len = _view_or_alloc(cache_len_buf, bsz)
    _loop_times = _view_or_alloc(loop_times_buf, bsz)
    # Phase 1: compute cu_seqlens_k, cache_len, loop_times
    _pre_cache_cu_seqlens_kernel[(1,)](
        seq_lens_encoder,
        seq_lens_decoder,
        seq_lens_this_time,
        cu_seqlens_k,
        _cache_len,
        _loop_times,
        bsz=bsz,
        block_size=block_size,
        BLOCK_BSZ=triton.next_power_of_2(bsz),
    )
    # Phase 2: compute gridx_offset (exclusive prefix sum) and scatter
    gridx_offset = triton_cumsum_with_zero_prefix(_loop_times, bsz, out_buf=gridx_offset_buf)
    _pre_cache_scatter_kernel[(bsz,)](
        _loop_times,
        gridx_offset,
        batch_ids,
        tile_ids,
        BLOCK=128,
    )
    return cu_seqlens_k, batch_ids, tile_ids
def pre_cache_len_concat_ref(
    seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, bsz, block_size, max_tile_size_per_bs
):
    """
    Reference (Python for-loop) implementation of pre_cache_len_concat.

    Uses .item(), so it is NOT CUDA Graph compatible. Kept for validation.
    """
    out_size = max(bsz * max_tile_size_per_bs, 1)
    cu_seqlens_k = paddle.zeros([bsz + 1], dtype="int32")
    batch_ids = paddle.empty([out_size], dtype="int32")
    tile_ids = paddle.empty([out_size], dtype="int32")
    write_pos = 0
    running_total = 0
    for b in range(bsz):
        enc_len = int(seq_lens_encoder[b].item())
        dec_len = int(seq_lens_decoder[b].item())
        q_len = int(seq_lens_this_time[b].item())
        # Cached prefix is only counted when the batch has encoder tokens.
        cached = dec_len if enc_len > 0 else 0
        n_tiles = (cached + block_size - 1) // block_size
        for t in range(n_tiles):
            batch_ids[write_pos] = b
            tile_ids[write_pos] = t
            write_pos += 1
        running_total += cached + q_len
        cu_seqlens_k[b + 1] = running_total
    return cu_seqlens_k, batch_ids, tile_ids
# ---------------------------------------------------------------------------
# Triton attention kernel (unified, deterministic)
# ---------------------------------------------------------------------------
@triton.jit  # pragma: no cover
def _fwd_kernel_unified(
    Q,
    O,
    K_Buffer,
    V_Buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    prefix_lens,
    sm_scale,
    kv_group_num,
    stride_qbs,
    stride_qh,
    stride_obs,
    stride_oh,
    # K_Buffer strides: [num_blocks, kv_heads, block_size, head_dim]
    stride_kb,  # dim0: block
    stride_kh,  # dim1: head
    stride_kt,  # dim2: token offset in block
    # V_Buffer strides: [num_blocks, kv_heads, block_size, head_dim]
    stride_vb,
    stride_vh,
    stride_vt,
    Lq: tl.constexpr,
    Lv: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    KV_BLOCK_SIZE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
):
    """
    Unified 1-stage extend attention kernel for deterministic inference.

    Both prefix and extend KV are accessed through unified kv_indices,
    ensuring identical accumulation order regardless of cache hit/miss.

    Program axes: (sequence, query head, BLOCK_M tile of query rows).
    Each program streams its sequence's entire KV range in BLOCK_N chunks,
    maintaining an online softmax: m_i = running row max, l_i = running
    exp-sum, acc = unnormalized output accumulator.
    """
    cur_seq = tl.program_id(0)
    cur_head = tl.program_id(1)
    cur_block_m = tl.program_id(2)
    # GQA/MQA mapping: kv_group_num query heads share one KV head.
    cur_kv_head = cur_head // kv_group_num
    # Load sequence metadata (CSR ranges for Q/O and unified KV)
    cur_seq_q_start = tl.load(qo_indptr + cur_seq)
    cur_seq_q_len = tl.load(qo_indptr + cur_seq + 1) - cur_seq_q_start
    cur_seq_kv_start = tl.load(kv_indptr + cur_seq)
    cur_seq_kv_len = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start
    cur_seq_prefix_len = tl.load(prefix_lens + cur_seq)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # Masks for partially filled tiles (last query tile, padded head dims).
    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_q_len
    mask_d = offs_d < Lq
    mask_dv = offs_dv < Lv
    # Load Q block: Q shape is [num_tokens, num_heads, head_dim]
    offs_q = (
        (cur_seq_q_start + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    q = tl.load(Q + offs_q, mask=mask_m[:, None] & mask_d[None, :], other=0.0)
    # Initialize online softmax accumulators
    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    # Unified loop over all KV (prefix + extend)
    for start_n in range(0, cur_seq_kv_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_seq_kv_len
        # Build mask: bounds + causal
        final_mask = mask_m[:, None] & mask_n[None, :]
        if IS_CAUSAL:
            # Prefix tokens: always visible (no causal mask)
            # Extend tokens: apply standard causal mask
            q_idx = cur_block_m * BLOCK_M + offs_m[:, None]
            k_idx_in_total = start_n + offs_n[None, :]
            k_is_extend = k_idx_in_total >= cur_seq_prefix_len
            # Key position within the extend segment; query row i may attend
            # to extend key j only when i >= j.
            k_idx_in_extend = k_idx_in_total - cur_seq_prefix_len
            causal_mask = tl.where(k_is_extend, q_idx >= k_idx_in_extend, True)
            final_mask &= causal_mask
        # Load KV indices (flat token indices: block_id * block_size + offset)
        offs_kv_loc = tl.load(
            kv_indices + cur_seq_kv_start + start_n + offs_n,
            mask=mask_n,
            other=0,
        )
        # Decompose flat index into (block_id, offset_in_block)
        kv_block_ids = offs_kv_loc // KV_BLOCK_SIZE
        kv_offsets = offs_kv_loc % KV_BLOCK_SIZE
        # Load K: cache shape [num_blocks, kv_heads, block_size, head_dim]
        # addr = block_id * stride_kb + head * stride_kh + offset * stride_kt + d
        offs_buf_k = (
            kv_block_ids[None, :] * stride_kb
            + cur_kv_head * stride_kh
            + kv_offsets[None, :] * stride_kt
            + offs_d[:, None]
        )
        k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :] & mask_d[:, None], other=0.0)
        # QK = Q @ K^T, shape [BLOCK_M, BLOCK_N]
        qk = tl.dot(q.to(k.dtype), k) * sm_scale
        qk = tl.where(final_mask, qk, float("-inf"))
        # Online softmax update
        row_max = tl.max(qk, 1)
        # Avoid -inf in exp: clamp to a large negative value
        row_max_safe = tl.where(row_max == float("-inf"), -1e20, row_max)
        m_new = tl.maximum(m_i, row_max_safe)
        # Rescale factor for values accumulated under the previous max.
        re_scale = tl.exp(m_i - m_new)
        p = tl.exp(qk - m_new[:, None])
        l_i = l_i * re_scale + tl.sum(p, 1)
        # Load V: same 4D layout as K
        offs_buf_v = (
            kv_block_ids[:, None] * stride_vb
            + cur_kv_head * stride_vh
            + kv_offsets[:, None] * stride_vt
            + offs_dv[None, :]
        )
        v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0)
        # Accumulate: rescale old acc, add new P @ V
        p = p.to(v.dtype)
        acc = acc * re_scale[:, None] + tl.dot(p, v)
        m_i = m_new
    # Final output = acc / l_i
    offs_o = (
        (cur_seq_q_start + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_dv[None, :]
    )
    # Avoid division by zero for fully masked rows
    safe_l = tl.where(l_i == 0.0, 1.0, l_i)
    tl.store(O + offs_o, acc / safe_l[:, None], mask=mask_m[:, None] & mask_dv[None, :])
def extend_attention_fwd_unified(
    q,
    o,
    k_buffer,
    v_buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    prefix_lens,
    num_q_heads,
    num_kv_heads,
    head_dim,
    max_len_extend,
    is_causal=True,
    sm_scale=None,
):
    """
    Launch the unified extend attention kernel.

    Args:
        q: [num_tokens, num_q_heads, head_dim]
        o: [num_tokens, num_q_heads, head_dim] (output, will be written)
        k_buffer: KV cache key buffer [num_blocks, kv_heads, block_size, head_dim]
        v_buffer: KV cache value buffer [num_blocks, kv_heads, block_size, head_dim]
        qo_indptr: [bs+1] query/output CSR indptr
        kv_indptr: [bs+1] unified KV CSR indptr
        kv_indices: [total_kv_len] flat token indices into paged cache
        prefix_lens: [bs] prefix length per sequence
        num_q_heads: number of query heads
        num_kv_heads: number of KV heads
        head_dim: head dimension
        max_len_extend: max extend length (for grid sizing)
        is_causal: whether to apply causal mask
        sm_scale: softmax scale; defaults to 1/sqrt(head_dim) when None

    Returns:
        o, with attention results written in place.
    """
    Lq = head_dim
    Lv = head_dim
    BLOCK_DMODEL = triton.next_power_of_2(Lq)
    BLOCK_DV = triton.next_power_of_2(Lv)
    # Choose block sizes based on head_dim: larger head dims get smaller
    # tiles (presumably to bound per-program register/SMEM pressure).
    if Lq <= 128:
        BLOCK_M, BLOCK_N = 64, 128
    elif Lq <= 256:
        BLOCK_M, BLOCK_N = 64, 64
    else:
        BLOCK_M, BLOCK_N = 32, 32
    num_warps = 4 if Lq <= 64 else 8
    # Fix: explicit None check. The previous `sm_scale or default` treated a
    # caller-supplied sm_scale of 0.0 as "unset" and silently replaced it.
    if sm_scale is None:
        sm_scale = 1.0 / (Lq**0.5)
    batch_size = qo_indptr.shape[0] - 1
    # GQA ratio: number of query heads sharing each KV head.
    kv_group_num = num_q_heads // num_kv_heads
    # KV cache block_size: k_buffer shape is [num_blocks, kv_heads, block_size, head_dim]
    kv_block_size = k_buffer.shape[2]
    grid = (batch_size, num_q_heads, triton.cdiv(max_len_extend, BLOCK_M))
    _fwd_kernel_unified[grid](
        q,
        o,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        prefix_lens,
        sm_scale,
        kv_group_num,
        q.strides[0],
        q.strides[1],
        o.strides[0],
        o.strides[1],
        # K strides: [num_blocks, kv_heads, block_size, head_dim]
        k_buffer.strides[0],
        k_buffer.strides[1],
        k_buffer.strides[2],
        # V strides: [num_blocks, kv_heads, block_size, head_dim]
        v_buffer.strides[0],
        v_buffer.strides[1],
        v_buffer.strides[2],
        Lq=Lq,
        Lv=Lv,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DV=BLOCK_DV,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        KV_BLOCK_SIZE=kv_block_size,
        IS_CAUSAL=is_causal,
        num_warps=num_warps,
        num_stages=1,
    )
    return o