This commit is contained in:
AIbin
2026-04-08 20:21:38 +08:00
committed by GitHub
parent b262419db1
commit 48d2bbeb74
3 changed files with 55 additions and 49 deletions
@@ -2061,22 +2061,28 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
int batch_id, length;
const IdType* block_table_pre_batch;
IdType* dst;
if (seq_len_decoder != nullptr) { // decode
batch_id = batch_id_per_token[bid / q_num_heads];
// batch_id = batch_id_per_token[bid / q_num_heads];
batch_id = bid / q_num_heads;
if (batch_id == -1) return;
length = (seq_len_decoder[batch_id]); // for pack q k
if (length == 0) return;
if (block_tables != nullptr) {
block_table_pre_batch = block_tables + batch_id * max_block_num;
}
dst = output + aux_input[batch_id] * top_k;
} else { // prefill
// length = (lengths != nullptr) ? lengths[bid] : static_cast<int>(max_len);
length = (lengths != nullptr) ? lengths[bid / q_num_heads]
: static_cast<int>(max_len);
dst = output + bid * top_k;
}
const DType* score = input + bid * max_len;
IdType* dst = output + bid * top_k;
// IdType* dst = output + bid * top_k;
// Mode-specific setup
[[maybe_unused]] const IdType* src_page_entry = nullptr;
@@ -2110,8 +2116,8 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
? static_cast<IdType>(block_ids * 64 + block_offset)
: static_cast<IdType>(-1);
} else {
dst[i] =
(i < length) ? static_cast<IdType>(i) : static_cast<IdType>(-1);
dst[i] = (i < length) ? static_cast<IdType>(i) + offset_val
: static_cast<IdType>(-1);
}
} else { // Plain
if (i < length) {
@@ -2337,10 +2343,9 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
block_idx = idx / 64;
block_ids = block_table_pre_batch[block_idx];
block_offset = idx % 64;
dst[base] =
static_cast<IdType>(block_ids * 64 + block_offset); // + offset_val
dst[base] = static_cast<IdType>(block_ids * 64 + block_offset);
} else {
dst[base] = static_cast<IdType>(idx); //+ offset_val;
dst[base] = static_cast<IdType>(idx) + offset_val;
}
} else { // Plain