mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
fix dsa (#7252)
This commit is contained in:
@@ -674,14 +674,13 @@ class Indexer(nn.Layer):
|
||||
self.indexer_cache, k_fp8_cache, k_scale_cache, forward_meta.block_tables, forward_meta.cu_seqlens_k
|
||||
)
|
||||
|
||||
k_scale_cache = k_scale_cache.flatten()[: k.shape[0]]
|
||||
k_cache = k_fp8_cache.view(paddle.float8_e4m3fn), k_scale_cache
|
||||
k_scale_cache_real = k_scale_cache.flatten()[: k.shape[0]].contiguous()
|
||||
k_cache = k_fp8_cache.view(paddle.float8_e4m3fn), k_scale_cache_real
|
||||
|
||||
# TODO(changwenbin): Constructed using maskoffset
|
||||
# ks,ke = forward_meta.attn_mask_offsets[::2].contiguous(),forward_meta.attn_mask_offsets[1::2].contiguous()
|
||||
num_tokens = q_fp8.shape[0]
|
||||
ks = paddle.zeros(num_tokens, dtype=paddle.int32)
|
||||
ks_topk = paddle.zeros(num_tokens, dtype=paddle.int32)
|
||||
ke = paddle.zeros(num_tokens, dtype=paddle.int32)
|
||||
|
||||
bsz = forward_meta.seq_lens_this_time.shape[0]
|
||||
@@ -696,20 +695,13 @@ class Indexer(nn.Layer):
|
||||
|
||||
logits = deep_gemm.fp8_mqa_logits(
|
||||
q_fp8, k_cache, weights, ks, ke, max_seqlen_k=max_seqlen_k, clean_logits=False
|
||||
)
|
||||
|
||||
# To save GPU global memory usage
|
||||
assert logits.size() == (num_tokens, max_seqlen_k)
|
||||
tmp = paddle.full((num_tokens, num_tokens), float("-inf"))
|
||||
for i in range(num_tokens):
|
||||
tmp[i, ks[i] : ke[i]] = logits[i, : ke[i] - ks[i]]
|
||||
logits = tmp
|
||||
).contiguous()
|
||||
|
||||
radix_topk_ragged_transform(
|
||||
logits.contiguous(),
|
||||
logits,
|
||||
indexer_top_k,
|
||||
ks_topk, # self.offsets,
|
||||
ke - ks + 1, # mask.contiguous(),#self.lengths,
|
||||
ks, # self.offsets, # initial offset along the K direction
|
||||
ke - ks, # self.lengths, # how many k positions the current q attends to
|
||||
None, # forward_meta.seq_lens_decoder,
|
||||
None, # forward_meta.batch_id_per_token,
|
||||
None,
|
||||
@@ -740,12 +732,12 @@ class Indexer(nn.Layer):
|
||||
schedule_metadata,
|
||||
self.max_model_len,
|
||||
clean_logits=True,
|
||||
)
|
||||
).contiguous()
|
||||
|
||||
radix_topk_ragged_transform(
|
||||
logits.contiguous(),
|
||||
logits,
|
||||
indexer_top_k,
|
||||
self.offsets, # unused
|
||||
forward_meta.cu_seqlens_q,
|
||||
self.lengths, # unused
|
||||
cache_seqlens,
|
||||
forward_meta.batch_id_per_token,
|
||||
@@ -753,7 +745,7 @@ class Indexer(nn.Layer):
|
||||
None, # self.buffer
|
||||
forward_meta.block_tables.shape[1],
|
||||
self.index_topk,
|
||||
1, # q_head
|
||||
1, # kv_head
|
||||
)
|
||||
|
||||
return indexer_top_k
|
||||
|
||||
Reference in New Issue
Block a user