[Optimization][Feature]Supports multiple batches of DSK-DSA. (#6930)

* support DSA_MUTI_BATCH

* update test topk

* update dsk-dsa
This commit is contained in:
AIbin
2026-03-20 15:59:22 +08:00
committed by GitHub
parent 1c38da2118
commit bf7e2424d0
5 changed files with 262 additions and 61 deletions
@@ -26,6 +26,8 @@ cudaError_t DispatchTopK(paddle::Tensor& input,
uint32_t num_rows,
const int32_t* seq_len_decoder,
const int32_t* batch_id_per_token,
const int32_t* block_tables,
uint32_t max_block_num,
uint32_t top_k,
uint32_t q_num_heads,
uint32_t max_len,
@@ -45,6 +47,8 @@ cudaError_t DispatchTopK(paddle::Tensor& input,
num_rows,
seq_len_decoder,
batch_id_per_token,
block_tables,
static_cast<uint32_t>(max_block_num),
static_cast<uint32_t>(top_k),
static_cast<uint32_t>(q_num_heads),
max_len,
@@ -60,7 +64,9 @@ void RadixTopkRaggedTransform(
paddle::Tensor& lengths,
paddle::optional<paddle::Tensor>& seq_len_decoder,
paddle::optional<paddle::Tensor>& batch_id_per_token,
paddle::optional<paddle::Tensor>& block_tables,
paddle::optional<paddle::Tensor>& maybe_row_states_buffer,
int max_block_num,
int top_k,
int q_num_heads = 0) {
// CHECK_INPUT(input);
@@ -102,6 +108,11 @@ void RadixTopkRaggedTransform(
batch_id_per_token_ptr =
static_cast<const int32_t*>(tensor_ptr.data<int32_t>());
}
const int32_t* block_tables_ptr = nullptr;
if (block_tables) {
auto& tensor_ptr = block_tables.get();
block_tables_ptr = static_cast<const int32_t*>(tensor_ptr.data<int32_t>());
}
if (input_dtype == paddle::DataType::BFLOAT16) {
status = DispatchTopK<paddle::DataType::BFLOAT16>(input,
@@ -111,6 +122,8 @@ void RadixTopkRaggedTransform(
num_rows,
seq_len_ptr,
batch_id_per_token_ptr,
block_tables_ptr,
max_block_num,
top_k,
q_num_heads,
max_len,
@@ -124,6 +137,8 @@ void RadixTopkRaggedTransform(
num_rows,
seq_len_ptr,
batch_id_per_token_ptr,
block_tables_ptr,
max_block_num,
top_k,
q_num_heads,
max_len,
@@ -141,6 +156,7 @@ PD_BUILD_STATIC_OP(radix_topk_ragged_transform)
"lengths",
paddle::Optional("seq_len_decoder"),
paddle::Optional("batch_id_per_token"),
paddle::Optional("block_tables"),
paddle::Optional("maybe_row_states_buffer")})
.Attrs({"top_k : int", "q_num_heads : int"})
.Attrs({"top_k : int", "q_num_heads : int", "max_block_num : int"})
.SetKernelFn(PD_KERNEL(RadixTopkRaggedTransform));
@@ -935,6 +935,10 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified(
seq_len_decoder, // NOTE (changwenbin) Support FD P/D indexer topk
const IdType*
batch_id_per_token, // NOTE (changwenbin) Support FD P/D indexer topk
const IdType*
block_tables, // NOTE (changwenbin) Support FD sparse indexer topk
uint32_t
max_block_num, // NOTE (changwenbin) Support FD sparse indexer topk
uint32_t top_k_val,
uint32_t q_num_heads,
uint32_t stride,
@@ -985,6 +989,7 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified(
// NOTE (changwenbin) Support FD Metadata
int batch_id;
const IdType* block_table_pre_batch;
if (batch_id_per_token != nullptr) {
batch_id = batch_id_per_token[row_idx / 4];
if (batch_id == -1) continue;
@@ -999,7 +1004,10 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified(
} else {
// NOTE (changwenbin) decode
if (seq_len_decoder != nullptr && batch_id_per_token != nullptr) {
length = (seq_len_decoder[batch_id] + 1); // for pack q k
length = (seq_len_decoder[batch_id]); // for pack q k
if (block_tables != nullptr) {
block_table_pre_batch = block_tables + batch_id * max_block_num;
}
} else {
// NOTE (changwenbin) prefill for pack q k
// length = lengths[row_idx]; // Per-row length
@@ -1065,8 +1073,21 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified(
IdType offset = aux_data[row_idx];
if (length <= top_k_val) {
for (uint32_t i = tx; i < top_k_val; i += BLOCK_THREADS) {
row_output[i] = (i < length) ? static_cast<IdType>(i) + offset
: static_cast<IdType>(-1);
if (seq_len_decoder != nullptr && block_tables != nullptr) {
int block_idx, block_ids, block_offset;
if (i < length) {
block_idx = i / 64;
block_ids = block_table_pre_batch[block_idx];
block_offset = i % 64;
}
row_output[i] =
(i < length)
? static_cast<IdType>(block_ids * 64 + block_offset)
: static_cast<IdType>(-1);
} else {
row_output[i] =
(i < length) ? static_cast<IdType>(i) : static_cast<IdType>(-1);
}
}
// Clear histogram for next iteration
if constexpr (!SINGLE_CTA) {
@@ -1589,8 +1610,9 @@ cudaError_t RadixTopKPageTableTransformMultiCTA(
&lengths,
&row_to_batch,
&src_stride,
&seq_len_decoder,
&batch_id_per_token,
nullptr,
nullptr,
nullptr,
&top_k_val,
&q_num_heads,
&max_len,
@@ -1618,8 +1640,9 @@ cudaError_t RadixTopKPageTableTransformMultiCTA(
&lengths,
&row_to_batch,
&src_stride,
&seq_len_decoder,
&batch_id_per_token,
nullptr,
nullptr,
nullptr,
&top_k_val,
&q_num_heads,
&max_len,
@@ -1659,6 +1682,8 @@ cudaError_t RadixTopKRaggedTransformMultiCTA(DType* input,
uint32_t num_rows,
const IdType* seq_len_decoder,
const IdType* batch_id_per_token,
const IdType* block_tables,
uint32_t max_block_num,
uint32_t top_k_val,
uint32_t q_num_heads,
uint32_t max_len,
@@ -1727,6 +1752,8 @@ cudaError_t RadixTopKRaggedTransformMultiCTA(DType* input,
&aux_stride,
&seq_len_decoder,
&batch_id_per_token,
&block_tables,
&max_block_num,
&top_k_val,
&q_num_heads,
&max_len,
@@ -1756,6 +1783,8 @@ cudaError_t RadixTopKRaggedTransformMultiCTA(DType* input,
&aux_stride,
&seq_len_decoder,
&batch_id_per_token,
&block_tables,
&max_block_num,
&top_k_val,
&q_num_heads,
&max_len,
@@ -2014,6 +2043,8 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
uint32_t num_rows,
const IdType* __restrict__ seq_len_decoder,
const IdType* __restrict__ batch_id_per_token,
const IdType* __restrict__ block_tables,
uint32_t max_block_num,
uint32_t top_k,
uint32_t q_num_heads,
uint32_t max_len) {
@@ -2028,12 +2059,16 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
// NOTE:(changwenbin) Support FD Metadata
int batch_id, length;
const IdType* block_table_pre_batch;
if (seq_len_decoder != nullptr) { // decode
batch_id = batch_id_per_token[bid / q_num_heads];
if (batch_id == -1) return;
length = (seq_len_decoder[batch_id] + 1); // for pack q k
length = (seq_len_decoder[batch_id]); // for pack q k
if (length == 0) return;
if (block_tables != nullptr) {
block_table_pre_batch = block_tables + batch_id * max_block_num;
}
} else { // prefill
// length = (lengths != nullptr) ? lengths[bid] : static_cast<int>(max_len);
length = (lengths != nullptr) ? lengths[bid / q_num_heads]
@@ -2064,8 +2099,20 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
if constexpr (MODE == FilteredTopKMode::PageTable) {
dst[i] = (i < length) ? src_page_entry[i] : static_cast<IdType>(-1);
} else if constexpr (MODE == FilteredTopKMode::Ragged) {
dst[i] = (i < length) ? static_cast<IdType>(i) + offset_val
: static_cast<IdType>(-1);
if (seq_len_decoder != nullptr && block_tables != nullptr) {
int block_idx, block_ids, block_offset;
if (i < length) {
block_idx = i / 64;
block_ids = block_table_pre_batch[block_idx];
block_offset = i % 64;
}
dst[i] = (i < length)
? static_cast<IdType>(block_ids * 64 + block_offset)
: static_cast<IdType>(-1);
} else {
dst[i] =
(i < length) ? static_cast<IdType>(i) : static_cast<IdType>(-1);
}
} else { // Plain
if (i < length) {
dst[i] = static_cast<IdType>(i);
@@ -2284,7 +2331,18 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
if constexpr (MODE == FilteredTopKMode::PageTable) {
dst[base] = src_page_entry[idx];
} else if constexpr (MODE == FilteredTopKMode::Ragged) {
dst[base] = static_cast<IdType>(idx) + offset_val;
        // NOTE(changwenbin) support decode paged indexer in Ragged mode.
if (seq_len_decoder != nullptr && block_tables != nullptr) {
int block_idx, block_ids, block_offset;
block_idx = idx / 64;
block_ids = block_table_pre_batch[block_idx];
block_offset = idx % 64;
dst[base] =
static_cast<IdType>(block_ids * 64 + block_offset); // + offset_val
} else {
dst[base] = static_cast<IdType>(idx); //+ offset_val;
}
} else { // Plain
dst[base] = static_cast<IdType>(idx);
dst_values[base] = score[idx];
@@ -2375,6 +2433,8 @@ cudaError_t FilteredTopKRaggedTransform(DType* input,
uint32_t num_rows,
const IdType* seq_len_decoder,
const IdType* batch_id_per_token,
const IdType* block_tables,
uint32_t max_block_num,
uint32_t top_k_val,
uint32_t q_num_heads,
uint32_t max_len,
@@ -2397,6 +2457,8 @@ cudaError_t FilteredTopKRaggedTransform(DType* input,
&num_rows,
&seq_len_decoder,
&batch_id_per_token,
&block_tables,
&max_block_num,
&top_k_val,
&q_num_heads,
&max_len};
@@ -2613,36 +2675,42 @@ cudaError_t TopKRaggedTransformDispatch(DType* input,
uint32_t num_rows,
const IdType* seq_len_decoder,
const IdType* batch_id_per_token,
const IdType* block_tables,
uint32_t max_block_num,
uint32_t top_k_val,
uint32_t q_num_heads,
uint32_t max_len,
RadixRowState* row_states_buffer,
cudaStream_t stream = 0) {
if (ShouldUseFilteredTopK<DType>(num_rows, top_k_val, max_len)) {
return FilteredTopKRaggedTransform<DType, IdType>(input,
output_indices,
offsets,
lengths,
num_rows,
seq_len_decoder,
batch_id_per_token,
top_k_val,
q_num_heads,
max_len,
stream);
}
return RadixTopKRaggedTransformMultiCTA<DType, IdType>(input,
output_indices,
offsets,
lengths,
num_rows,
seq_len_decoder,
batch_id_per_token,
top_k_val,
q_num_heads,
max_len,
row_states_buffer,
stream);
// if (ShouldUseFilteredTopK<DType>(num_rows, top_k_val, max_len)) {
return FilteredTopKRaggedTransform<DType, IdType>(input,
output_indices,
offsets,
lengths,
num_rows,
seq_len_decoder,
batch_id_per_token,
block_tables,
max_block_num,
top_k_val,
q_num_heads,
max_len,
stream);
// }
// return RadixTopKRaggedTransformMultiCTA<DType, IdType>(input,
// output_indices,
// offsets,
// lengths,
// num_rows,
// seq_len_decoder,
// batch_id_per_token,
// block_tables,
// max_block_num,
// top_k_val,
// q_num_heads,
// max_len,
// row_states_buffer,
// stream);
}
template <typename DType, typename IdType>