This commit is contained in:
AIbin
2026-04-08 20:21:38 +08:00
committed by GitHub
parent b262419db1
commit 48d2bbeb74
3 changed files with 55 additions and 49 deletions
+33 -24
View File
@@ -36,24 +36,26 @@ class BaseTestRadixTopk(unittest.TestCase):
def get_reference_topk(self, input_pd, lengths_pd, offsets_pd, top_k, q_num_heads):
"""
使用 paddle.topk 生成参考结果
注意:算子输出的索引是 0-based 相对索引(不包含 offset)
注意:算子输出的索引是相对于 offsets 的偏移量(0-based 相对索引)
Args:
input_pd: (num_rows, max_len)
lengths_pd: (batch_size,) - 每个batch的长度
offsets_pd: (num_rows,) - 每一行的偏移基点(未使用,仅保留参数兼容性)
offsets_pd: (num_rows,) - 每一行的偏移基点
top_k: k值
q_num_heads: query head数量
Returns:
ref_indices: (num_rows, top_k) - 参考索引(0-based 相对索引),长度不足的部分用-1填充
ref_indices: (num_rows, top_k) - 参考索引(相对于 offset 的偏移),长度不足的部分用-1填充
"""
num_rows = input_pd.shape[0]
ref_indices = paddle.full([num_rows, top_k], -1, dtype="int32")
offsets = offsets_pd.numpy()
for row_idx in range(num_rows):
batch_idx = row_idx // q_num_heads
length = lengths_pd[batch_idx].item()
offset = offsets[row_idx]
if length == 0:
continue
@@ -61,13 +63,13 @@ class BaseTestRadixTopk(unittest.TestCase):
row_data = input_pd[row_idx, :length]
if length <= top_k:
# 长度不足top_k,按顺序返回所有索引(0-based)
ref_indices[row_idx, :length] = paddle.arange(0, length, dtype="int32")
# 长度不足top_k,按顺序返回所有索引(相对于 offset)
ref_indices[row_idx, :length] = paddle.arange(offset, offset + length, dtype="int32")
else:
# 长度足够,使用 paddle.topk 获取最大的top_k个值的索引
topk_vals, topk_inds = paddle.topk(row_data, top_k)
# 直接使用 topk 返回的索引(0-based)
ref_indices[row_idx, :top_k] = topk_inds
# 加上 offset 作为基点
ref_indices[row_idx, :top_k] = topk_inds + offset
return ref_indices
@@ -171,41 +173,48 @@ class TestDecodeMode(BaseTestRadixTopk):
paddle.seed(2025)
batch_size = 2
q_num_heads = 4
num_rows = batch_size * q_num_heads
kv_head = 1 # decode 模式下,每个 batch 只有一个新 token
num_rows = batch_size * kv_head # = batch_size
max_len = 1024
top_k = 8
# 使用 paddle 构造数据
input_pd = paddle.randn([num_rows, max_len], dtype="float32")
offsets_pd = paddle.arange(num_rows, dtype="int32")
lengths_pd = paddle.full([num_rows], 0, dtype="int32")
# 生成 cu_seqlens_q: 每个 batch 在打平的 query 中的偏移量
# 在 decode 模式下,每个 batch 只有一个新 token,所以 cu_seqlens_q = [0, 1, 2, ..., batch_size]
cu_seqlens_q_pd = paddle.concat(
[
paddle.zeros([1], dtype="int32"),
paddle.cumsum(paddle.ones([batch_size], dtype="int32")).astype("int32"),
],
axis=0,
)
lengths_pd = paddle.full([num_rows], 0, dtype="int32") # unused
seq_len_decoder_pd = paddle.randint(16, 128, [batch_size], dtype="int32")
# 生成 batch_id_per_token
batch_id_per_token_pd = paddle.arange(num_rows, dtype="int32") // q_num_heads
# 调用算子
# 调用算子(不使用 block_tables,让它按照 prefill 模式类似的逻辑工作)
output_indices = paddle.full([num_rows, top_k], -1, dtype="int32")
radix_topk_ragged_transform(
input_pd,
output_indices,
offsets_pd,
lengths_pd,
cu_seqlens_q_pd,
lengths_pd, # unused
seq_len_decoder_pd,
batch_id_per_token_pd,
None,
None,
0,
None, # batch_id_per_token
None, # block_tables
None, # buffer
0, # max_block_per_seq
top_k,
q_num_heads,
kv_head,
)
# Decode 模式下,长度 = seq_len_decoder + 1
decode_lengths = seq_len_decoder_pd + 1
# 获取参考结果
ref_indices = self.get_reference_topk(input_pd, decode_lengths, offsets_pd, top_k, q_num_heads)
# 获取参考结果(注意:num_rows = batch_size * kv_head)
ref_indices = self.get_reference_topk(input_pd, decode_lengths, cu_seqlens_q_pd, top_k, kv_head)
# 对比结果
result = self.compare_indices(output_indices, ref_indices)