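Update Indexer to size ks/ke, the logits check, and the padding buffer by q_fp8's token count (q_fp8.shape[0]) instead of forward_meta.seq_lens_encoder (#6791)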

This commit is contained in:
AIbin
2026-03-13 14:11:39 +08:00
committed by GitHub
parent d935752be7
commit 2b8a5b0d81
@@ -606,8 +606,9 @@ class Indexer(nn.Layer):
# ===================================== cache =============================================
# ks,ke = forward_meta.attn_mask_offsets[::2].contiguous(),forward_meta.attn_mask_offsets[1::2].contiguous()
ks = paddle.zeros(forward_meta.seq_lens_encoder, dtype=paddle.int32)
ke = paddle.arange(forward_meta.seq_lens_encoder, dtype=paddle.int32) + 1 # + (seq_len_kv - seq_len)
num_tokens = q_fp8.shape[0]
ks = paddle.zeros(num_tokens, dtype=paddle.int32)
ke = paddle.arange(num_tokens, dtype=paddle.int32) + 1 # + (seq_len_kv - seq_len)
max_seqlen_k = (ke - ks).max().item()
logits = deep_gemm.fp8_mqa_logits(
@@ -615,12 +616,12 @@ class Indexer(nn.Layer):
)
# To save GPU global memory usage
assert logits.size() == (forward_meta.seq_lens_encoder, max_seqlen_k)
assert logits.size() == (num_tokens, max_seqlen_k)
tmp = paddle.full(
(forward_meta.seq_lens_encoder, forward_meta.seq_lens_encoder),
(num_tokens, num_tokens),
float("-inf"),
)
for i in range(forward_meta.seq_lens_encoder):
for i in range(num_tokens):
tmp[i, ks[i] : ke[i]] = logits[i, : ke[i] - ks[i]]
logits = tmp