mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
fix dsa (#7252)
This commit is contained in:
@@ -2061,22 +2061,28 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
|
||||
int batch_id, length;
|
||||
const IdType* block_table_pre_batch;
|
||||
|
||||
IdType* dst;
|
||||
|
||||
if (seq_len_decoder != nullptr) { // decode
|
||||
batch_id = batch_id_per_token[bid / q_num_heads];
|
||||
// batch_id = batch_id_per_token[bid / q_num_heads];
|
||||
batch_id = bid / q_num_heads;
|
||||
if (batch_id == -1) return;
|
||||
length = (seq_len_decoder[batch_id]); // for pack q k
|
||||
if (length == 0) return;
|
||||
if (block_tables != nullptr) {
|
||||
block_table_pre_batch = block_tables + batch_id * max_block_num;
|
||||
}
|
||||
dst = output + aux_input[batch_id] * top_k;
|
||||
|
||||
} else { // prefill
|
||||
// length = (lengths != nullptr) ? lengths[bid] : static_cast<int>(max_len);
|
||||
length = (lengths != nullptr) ? lengths[bid / q_num_heads]
|
||||
: static_cast<int>(max_len);
|
||||
dst = output + bid * top_k;
|
||||
}
|
||||
|
||||
const DType* score = input + bid * max_len;
|
||||
IdType* dst = output + bid * top_k;
|
||||
// IdType* dst = output + bid * top_k;
|
||||
|
||||
// Mode-specific setup
|
||||
[[maybe_unused]] const IdType* src_page_entry = nullptr;
|
||||
@@ -2110,8 +2116,8 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
|
||||
? static_cast<IdType>(block_ids * 64 + block_offset)
|
||||
: static_cast<IdType>(-1);
|
||||
} else {
|
||||
dst[i] =
|
||||
(i < length) ? static_cast<IdType>(i) : static_cast<IdType>(-1);
|
||||
dst[i] = (i < length) ? static_cast<IdType>(i) + offset_val
|
||||
: static_cast<IdType>(-1);
|
||||
}
|
||||
} else { // Plain
|
||||
if (i < length) {
|
||||
@@ -2337,10 +2343,9 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
|
||||
block_idx = idx / 64;
|
||||
block_ids = block_table_pre_batch[block_idx];
|
||||
block_offset = idx % 64;
|
||||
dst[base] =
|
||||
static_cast<IdType>(block_ids * 64 + block_offset); // + offset_val
|
||||
dst[base] = static_cast<IdType>(block_ids * 64 + block_offset);
|
||||
} else {
|
||||
dst[base] = static_cast<IdType>(idx); //+ offset_val;
|
||||
dst[base] = static_cast<IdType>(idx) + offset_val;
|
||||
}
|
||||
|
||||
} else { // Plain
|
||||
|
||||
Reference in New Issue
Block a user