mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
fix dsa (#7252)
This commit is contained in:
@@ -36,24 +36,26 @@ class BaseTestRadixTopk(unittest.TestCase):
|
||||
def get_reference_topk(self, input_pd, lengths_pd, offsets_pd, top_k, q_num_heads):
|
||||
"""
|
||||
使用 paddle.topk 生成参考结果
|
||||
注意:算子输出的索引是 0-based 相对索引(不包含 offset)
|
||||
注意:算子输出的索引是相对于 offsets 的偏移量(0-based 相对索引)
|
||||
|
||||
Args:
|
||||
input_pd: (num_rows, max_len)
|
||||
lengths_pd: (batch_size,) - 每个batch的长度
|
||||
offsets_pd: (num_rows,) - 每一行的偏移基点(未使用,仅保留参数兼容性)
|
||||
offsets_pd: (num_rows,) - 每一行的偏移基点
|
||||
top_k: k值
|
||||
q_num_heads: query head数量
|
||||
|
||||
Returns:
|
||||
ref_indices: (num_rows, top_k) - 参考索引(0-based 相对索引),长度不足的部分用-1填充
|
||||
ref_indices: (num_rows, top_k) - 参考索引(相对于 offset 的偏移),长度不足的部分用-1填充
|
||||
"""
|
||||
num_rows = input_pd.shape[0]
|
||||
ref_indices = paddle.full([num_rows, top_k], -1, dtype="int32")
|
||||
offsets = offsets_pd.numpy()
|
||||
|
||||
for row_idx in range(num_rows):
|
||||
batch_idx = row_idx // q_num_heads
|
||||
length = lengths_pd[batch_idx].item()
|
||||
offset = offsets[row_idx]
|
||||
|
||||
if length == 0:
|
||||
continue
|
||||
@@ -61,13 +63,13 @@ class BaseTestRadixTopk(unittest.TestCase):
|
||||
row_data = input_pd[row_idx, :length]
|
||||
|
||||
if length <= top_k:
|
||||
# 长度不足top_k,按顺序返回所有索引(0-based)
|
||||
ref_indices[row_idx, :length] = paddle.arange(0, length, dtype="int32")
|
||||
# 长度不足top_k,按顺序返回所有索引(相对于 offset)
|
||||
ref_indices[row_idx, :length] = paddle.arange(offset, offset + length, dtype="int32")
|
||||
else:
|
||||
# 长度足够,使用 paddle.topk 获取最大的top_k个值的索引
|
||||
topk_vals, topk_inds = paddle.topk(row_data, top_k)
|
||||
# 直接使用 topk 返回的索引(0-based)
|
||||
ref_indices[row_idx, :top_k] = topk_inds
|
||||
# 加上 offset 作为基点
|
||||
ref_indices[row_idx, :top_k] = topk_inds + offset
|
||||
|
||||
return ref_indices
|
||||
|
||||
@@ -171,41 +173,48 @@ class TestDecodeMode(BaseTestRadixTopk):
|
||||
paddle.seed(2025)
|
||||
|
||||
batch_size = 2
|
||||
q_num_heads = 4
|
||||
num_rows = batch_size * q_num_heads
|
||||
kv_head = 1 # decode 模式下,每个 batch 只有一个新 token
|
||||
num_rows = batch_size * kv_head # = batch_size
|
||||
max_len = 1024
|
||||
top_k = 8
|
||||
|
||||
# 使用 paddle 构造数据
|
||||
input_pd = paddle.randn([num_rows, max_len], dtype="float32")
|
||||
offsets_pd = paddle.arange(num_rows, dtype="int32")
|
||||
lengths_pd = paddle.full([num_rows], 0, dtype="int32")
|
||||
|
||||
# 生成 cu_seqlens_q: 每个 batch 在打平的 query 中的偏移量
|
||||
# 在 decode 模式下,每个 batch 只有一个新 token,所以 cu_seqlens_q = [0, 1, 2, ..., batch_size]
|
||||
cu_seqlens_q_pd = paddle.concat(
|
||||
[
|
||||
paddle.zeros([1], dtype="int32"),
|
||||
paddle.cumsum(paddle.ones([batch_size], dtype="int32")).astype("int32"),
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
|
||||
lengths_pd = paddle.full([num_rows], 0, dtype="int32") # unused
|
||||
seq_len_decoder_pd = paddle.randint(16, 128, [batch_size], dtype="int32")
|
||||
|
||||
# 生成 batch_id_per_token
|
||||
batch_id_per_token_pd = paddle.arange(num_rows, dtype="int32") // q_num_heads
|
||||
|
||||
# 调用算子
|
||||
# 调用算子(不使用 block_tables,让它按照 prefill 模式类似的逻辑工作)
|
||||
output_indices = paddle.full([num_rows, top_k], -1, dtype="int32")
|
||||
radix_topk_ragged_transform(
|
||||
input_pd,
|
||||
output_indices,
|
||||
offsets_pd,
|
||||
lengths_pd,
|
||||
cu_seqlens_q_pd,
|
||||
lengths_pd, # unused
|
||||
seq_len_decoder_pd,
|
||||
batch_id_per_token_pd,
|
||||
None,
|
||||
None,
|
||||
0,
|
||||
None, # batch_id_per_token
|
||||
None, # block_tables
|
||||
None, # buffer
|
||||
0, # max_block_per_seq
|
||||
top_k,
|
||||
q_num_heads,
|
||||
kv_head,
|
||||
)
|
||||
|
||||
# Decode 模式下,长度 = seq_len_decoder + 1
|
||||
decode_lengths = seq_len_decoder_pd + 1
|
||||
|
||||
# 获取参考结果
|
||||
ref_indices = self.get_reference_topk(input_pd, decode_lengths, offsets_pd, top_k, q_num_heads)
|
||||
# 获取参考结果(注意:num_rows = batch_size * kv_head)
|
||||
ref_indices = self.get_reference_topk(input_pd, decode_lengths, cu_seqlens_q_pd, top_k, kv_head)
|
||||
|
||||
# 对比结果
|
||||
result = self.compare_indices(output_indices, ref_indices)
|
||||
|
||||
Reference in New Issue
Block a user