mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
update indexer model (#6791)
This commit is contained in:
@@ -606,8 +606,9 @@ class Indexer(nn.Layer):
|
||||
# ===================================== cache =============================================
|
||||
|
||||
# ks,ke = forward_meta.attn_mask_offsets[::2].contiguous(),forward_meta.attn_mask_offsets[1::2].contiguous()
|
||||
ks = paddle.zeros(forward_meta.seq_lens_encoder, dtype=paddle.int32)
|
||||
ke = paddle.arange(forward_meta.seq_lens_encoder, dtype=paddle.int32) + 1 # + (seq_len_kv - seq_len)
|
||||
num_tokens = q_fp8.shape[0]
|
||||
ks = paddle.zeros(num_tokens, dtype=paddle.int32)
|
||||
ke = paddle.arange(num_tokens, dtype=paddle.int32) + 1 # + (seq_len_kv - seq_len)
|
||||
max_seqlen_k = (ke - ks).max().item()
|
||||
|
||||
logits = deep_gemm.fp8_mqa_logits(
|
||||
@@ -615,12 +616,12 @@ class Indexer(nn.Layer):
|
||||
)
|
||||
|
||||
# To save GPU global memory usage
|
||||
assert logits.size() == (forward_meta.seq_lens_encoder, max_seqlen_k)
|
||||
assert logits.size() == (num_tokens, max_seqlen_k)
|
||||
tmp = paddle.full(
|
||||
(forward_meta.seq_lens_encoder, forward_meta.seq_lens_encoder),
|
||||
(num_tokens, num_tokens),
|
||||
float("-inf"),
|
||||
)
|
||||
for i in range(forward_meta.seq_lens_encoder):
|
||||
for i in range(num_tokens):
|
||||
tmp[i, ks[i] : ke[i]] = logits[i, : ke[i] - ks[i]]
|
||||
logits = tmp
|
||||
|
||||
|
||||
Reference in New Issue
Block a user