AIbin committed 2026-04-08 20:21:38 +08:00 (committed by GitHub)
commit 48d2bbeb74, parent b262419db1
3 changed files with 55 additions and 49 deletions
+10 -18
@@ -674,14 +674,13 @@ class Indexer(nn.Layer):
             self.indexer_cache, k_fp8_cache, k_scale_cache, forward_meta.block_tables, forward_meta.cu_seqlens_k
         )
-        k_scale_cache = k_scale_cache.flatten()[: k.shape[0]]
-        k_cache = k_fp8_cache.view(paddle.float8_e4m3fn), k_scale_cache
+        k_scale_cache_real = k_scale_cache.flatten()[: k.shape[0]].contiguous()
+        k_cache = k_fp8_cache.view(paddle.float8_e4m3fn), k_scale_cache_real
         # TODO(changwenbin): construct using mask offsets
         # ks, ke = forward_meta.attn_mask_offsets[::2].contiguous(), forward_meta.attn_mask_offsets[1::2].contiguous()
         num_tokens = q_fp8.shape[0]
         ks = paddle.zeros(num_tokens, dtype=paddle.int32)
-        ks_topk = paddle.zeros(num_tokens, dtype=paddle.int32)
         ke = paddle.zeros(num_tokens, dtype=paddle.int32)
         bsz = forward_meta.seq_lens_this_time.shape[0]
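Note on the hunk above: the sliced scale tensor is now materialized with .contiguous() before being packed into k_cache, and the no-longer-needed ks_topk buffer is dropped. A minimal sketch of the intent, assuming Paddle 2.6+ stride semantics and hypothetical shapes (the real caches come from the indexer):

import paddle

k_scale_cache = paddle.rand([1024], dtype="float32")  # hypothetical flat scale cache
num_k = 300                                           # stands in for k.shape[0]

# Before: a slice of a flattened view; it may alias the cache's storage.
k_scale_view = k_scale_cache.flatten()[:num_k]

# After: force a dense, self-owned buffer first. Presumably the downstream
# deep_gemm.fp8_mqa_logits kernel reads raw device memory and expects unit stride.
k_scale_real = k_scale_cache.flatten()[:num_k].contiguous()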
@@ -696,20 +695,13 @@ class Indexer(nn.Layer):
                 logits = deep_gemm.fp8_mqa_logits(
                     q_fp8, k_cache, weights, ks, ke, max_seqlen_k=max_seqlen_k, clean_logits=False
-                )
-                # To save GPU global memory usage
-                assert logits.size() == (num_tokens, max_seqlen_k)
-                tmp = paddle.full((num_tokens, num_tokens), float("-inf"))
-                for i in range(num_tokens):
-                    tmp[i, ks[i] : ke[i]] = logits[i, : ke[i] - ks[i]]
-                logits = tmp
+                ).contiguous()
                 radix_topk_ragged_transform(
-                    logits.contiguous(),
+                    logits,
                     indexer_top_k,
-                    ks_topk,  # self.offsets,
-                    ke - ks + 1,  # mask.contiguous(), # self.lengths,
+                    ks,  # self.offsets: initial offset along the K direction
+                    ke - ks,  # self.lengths: how many K entries the current q attends to
                     None,  # forward_meta.seq_lens_decoder,
                     None,  # forward_meta.batch_id_per_token,
                     None,
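This hunk deletes the dense fallback that scattered each row's valid logits into a (num_tokens, num_tokens) matrix of -inf, and instead passes the ragged logits to radix_topk_ragged_transform together with per-row offsets (ks) and lengths (ke - ks). A conceptual sketch of the two paths with hypothetical values; the real transform is a fused kernel, not this Python loop:

import paddle

num_tokens, max_seqlen_k, top_k = 4, 4, 2
ks = [0, 0, 1, 1]            # per-row start of the attended K window
ke = [1, 2, 3, 4]            # per-row end (exclusive); row i is valid in [: ke[i] - ks[i]]
logits = paddle.rand([num_tokens, max_seqlen_k])

# Deleted path: materialize an O(num_tokens^2) -inf matrix before the top-k.
dense = paddle.full([num_tokens, num_tokens], float("-inf"))
for i in range(num_tokens):
    dense[i, ks[i] : ke[i]] = logits[i, : ke[i] - ks[i]]

# Ragged path, conceptually: top-k per row over the valid window only,
# with indices shifted back by ks[i]; the dense buffer is never allocated.
for i in range(num_tokens):
    length = ke[i] - ks[i]
    vals, idx = paddle.topk(logits[i, :length], k=min(top_k, length))
    global_idx = idx + ks[i]  # indices in the full K coordinate system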
@@ -740,12 +732,12 @@ class Indexer(nn.Layer):
                     schedule_metadata,
                     self.max_model_len,
                     clean_logits=True,
-                )
+                ).contiguous()
                 radix_topk_ragged_transform(
-                    logits,
+                    logits,
                     indexer_top_k,
-                    self.offsets,  # unused
-                    self.lengths,  # unused
+                    forward_meta.cu_seqlens_q,
+                    cache_seqlens,
                     forward_meta.batch_id_per_token,
@@ -753,7 +745,7 @@ class Indexer(nn.Layer):
                     None,  # self.buffer
                     forward_meta.block_tables.shape[1],
                     self.index_topk,
-                    1,  # q_head
+                    1,  # kv_head
                 )
                 return indexer_top_k
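In the decode hunks above, the previously unused offset/length placeholders are replaced by the real ragged-batch metadata. For reference, a sketch of the convention these arguments presumably follow (values hypothetical): cu_seqlens_q is the exclusive prefix sum of per-request query lengths, and cache_seqlens is the number of cached K entries visible to each request:

import paddle

# Three decode requests, one query token each.
seq_lens_q = [1, 1, 1]
cu_seqlens_q = paddle.to_tensor([0, 1, 2, 3], dtype="int32")   # prefix sum of seq_lens_q
cache_seqlens = paddle.to_tensor([57, 130, 9], dtype="int32")  # cached K length per request

# Row i of the ragged logits belongs to request r with
# cu_seqlens_q[r] <= i < cu_seqlens_q[r + 1], and its top-k is taken
# over the first cache_seqlens[r] key positions.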