mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[Optimization][Feature]Supports multiple batches of DSK-DSA. (#6930)
* support DSA_MUTI_BATCH * update test topk * update dsk-dsa
This commit is contained in:
@@ -26,6 +26,8 @@ cudaError_t DispatchTopK(paddle::Tensor& input,
|
||||
uint32_t num_rows,
|
||||
const int32_t* seq_len_decoder,
|
||||
const int32_t* batch_id_per_token,
|
||||
const int32_t* block_tables,
|
||||
uint32_t max_block_num,
|
||||
uint32_t top_k,
|
||||
uint32_t q_num_heads,
|
||||
uint32_t max_len,
|
||||
@@ -45,6 +47,8 @@ cudaError_t DispatchTopK(paddle::Tensor& input,
|
||||
num_rows,
|
||||
seq_len_decoder,
|
||||
batch_id_per_token,
|
||||
block_tables,
|
||||
static_cast<uint32_t>(max_block_num),
|
||||
static_cast<uint32_t>(top_k),
|
||||
static_cast<uint32_t>(q_num_heads),
|
||||
max_len,
|
||||
@@ -60,7 +64,9 @@ void RadixTopkRaggedTransform(
|
||||
paddle::Tensor& lengths,
|
||||
paddle::optional<paddle::Tensor>& seq_len_decoder,
|
||||
paddle::optional<paddle::Tensor>& batch_id_per_token,
|
||||
paddle::optional<paddle::Tensor>& block_tables,
|
||||
paddle::optional<paddle::Tensor>& maybe_row_states_buffer,
|
||||
int max_block_num,
|
||||
int top_k,
|
||||
int q_num_heads = 0) {
|
||||
// CHECK_INPUT(input);
|
||||
@@ -102,6 +108,11 @@ void RadixTopkRaggedTransform(
|
||||
batch_id_per_token_ptr =
|
||||
static_cast<const int32_t*>(tensor_ptr.data<int32_t>());
|
||||
}
|
||||
const int32_t* block_tables_ptr = nullptr;
|
||||
if (block_tables) {
|
||||
auto& tensor_ptr = block_tables.get();
|
||||
block_tables_ptr = static_cast<const int32_t*>(tensor_ptr.data<int32_t>());
|
||||
}
|
||||
|
||||
if (input_dtype == paddle::DataType::BFLOAT16) {
|
||||
status = DispatchTopK<paddle::DataType::BFLOAT16>(input,
|
||||
@@ -111,6 +122,8 @@ void RadixTopkRaggedTransform(
|
||||
num_rows,
|
||||
seq_len_ptr,
|
||||
batch_id_per_token_ptr,
|
||||
block_tables_ptr,
|
||||
max_block_num,
|
||||
top_k,
|
||||
q_num_heads,
|
||||
max_len,
|
||||
@@ -124,6 +137,8 @@ void RadixTopkRaggedTransform(
|
||||
num_rows,
|
||||
seq_len_ptr,
|
||||
batch_id_per_token_ptr,
|
||||
block_tables_ptr,
|
||||
max_block_num,
|
||||
top_k,
|
||||
q_num_heads,
|
||||
max_len,
|
||||
@@ -141,6 +156,7 @@ PD_BUILD_STATIC_OP(radix_topk_ragged_transform)
|
||||
"lengths",
|
||||
paddle::Optional("seq_len_decoder"),
|
||||
paddle::Optional("batch_id_per_token"),
|
||||
paddle::Optional("block_tables"),
|
||||
paddle::Optional("maybe_row_states_buffer")})
|
||||
.Attrs({"top_k : int", "q_num_heads : int"})
|
||||
.Attrs({"top_k : int", "q_num_heads : int", "max_block_num : int"})
|
||||
.SetKernelFn(PD_KERNEL(RadixTopkRaggedTransform));
|
||||
|
||||
@@ -935,6 +935,10 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified(
|
||||
seq_len_decoder, // NOTE (changwenbin) Support FD P/D indexer topk
|
||||
const IdType*
|
||||
batch_id_per_token, // NOTE (changwenbin) Support FD P/D indexer topk
|
||||
const IdType*
|
||||
block_tables, // NOTE (changwenbin) Support FD sparse indexer topk
|
||||
uint32_t
|
||||
max_block_num, // NOTE (changwenbin) Support FD sparse indexer topk
|
||||
uint32_t top_k_val,
|
||||
uint32_t q_num_heads,
|
||||
uint32_t stride,
|
||||
@@ -985,6 +989,7 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified(
|
||||
|
||||
// NOTE (changwenbin) Support FD Metadata
|
||||
int batch_id;
|
||||
const IdType* block_table_pre_batch;
|
||||
if (batch_id_per_token != nullptr) {
|
||||
batch_id = batch_id_per_token[row_idx / 4];
|
||||
if (batch_id == -1) continue;
|
||||
@@ -999,7 +1004,10 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified(
|
||||
} else {
|
||||
// NOTE (changwenbin) decode
|
||||
if (seq_len_decoder != nullptr && batch_id_per_token != nullptr) {
|
||||
length = (seq_len_decoder[batch_id] + 1); // for pack q k
|
||||
length = (seq_len_decoder[batch_id]); // for pack q k
|
||||
if (block_tables != nullptr) {
|
||||
block_table_pre_batch = block_tables + batch_id * max_block_num;
|
||||
}
|
||||
} else {
|
||||
// NOTE (changwenbin) prefill for pack q k
|
||||
// length = lengths[row_idx]; // Per-row length
|
||||
@@ -1065,8 +1073,21 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified(
|
||||
IdType offset = aux_data[row_idx];
|
||||
if (length <= top_k_val) {
|
||||
for (uint32_t i = tx; i < top_k_val; i += BLOCK_THREADS) {
|
||||
row_output[i] = (i < length) ? static_cast<IdType>(i) + offset
|
||||
: static_cast<IdType>(-1);
|
||||
if (seq_len_decoder != nullptr && block_tables != nullptr) {
|
||||
int block_idx, block_ids, block_offset;
|
||||
if (i < length) {
|
||||
block_idx = i / 64;
|
||||
block_ids = block_table_pre_batch[block_idx];
|
||||
block_offset = i % 64;
|
||||
}
|
||||
row_output[i] =
|
||||
(i < length)
|
||||
? static_cast<IdType>(block_ids * 64 + block_offset)
|
||||
: static_cast<IdType>(-1);
|
||||
} else {
|
||||
row_output[i] =
|
||||
(i < length) ? static_cast<IdType>(i) : static_cast<IdType>(-1);
|
||||
}
|
||||
}
|
||||
// Clear histogram for next iteration
|
||||
if constexpr (!SINGLE_CTA) {
|
||||
@@ -1589,8 +1610,9 @@ cudaError_t RadixTopKPageTableTransformMultiCTA(
|
||||
&lengths,
|
||||
&row_to_batch,
|
||||
&src_stride,
|
||||
&seq_len_decoder,
|
||||
&batch_id_per_token,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
&top_k_val,
|
||||
&q_num_heads,
|
||||
&max_len,
|
||||
@@ -1618,8 +1640,9 @@ cudaError_t RadixTopKPageTableTransformMultiCTA(
|
||||
&lengths,
|
||||
&row_to_batch,
|
||||
&src_stride,
|
||||
&seq_len_decoder,
|
||||
&batch_id_per_token,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
&top_k_val,
|
||||
&q_num_heads,
|
||||
&max_len,
|
||||
@@ -1659,6 +1682,8 @@ cudaError_t RadixTopKRaggedTransformMultiCTA(DType* input,
|
||||
uint32_t num_rows,
|
||||
const IdType* seq_len_decoder,
|
||||
const IdType* batch_id_per_token,
|
||||
const IdType* block_tables,
|
||||
uint32_t max_block_num,
|
||||
uint32_t top_k_val,
|
||||
uint32_t q_num_heads,
|
||||
uint32_t max_len,
|
||||
@@ -1727,6 +1752,8 @@ cudaError_t RadixTopKRaggedTransformMultiCTA(DType* input,
|
||||
&aux_stride,
|
||||
&seq_len_decoder,
|
||||
&batch_id_per_token,
|
||||
&block_tables,
|
||||
&max_block_num,
|
||||
&top_k_val,
|
||||
&q_num_heads,
|
||||
&max_len,
|
||||
@@ -1756,6 +1783,8 @@ cudaError_t RadixTopKRaggedTransformMultiCTA(DType* input,
|
||||
&aux_stride,
|
||||
&seq_len_decoder,
|
||||
&batch_id_per_token,
|
||||
&block_tables,
|
||||
&max_block_num,
|
||||
&top_k_val,
|
||||
&q_num_heads,
|
||||
&max_len,
|
||||
@@ -2014,6 +2043,8 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
|
||||
uint32_t num_rows,
|
||||
const IdType* __restrict__ seq_len_decoder,
|
||||
const IdType* __restrict__ batch_id_per_token,
|
||||
const IdType* __restrict__ block_tables,
|
||||
uint32_t max_block_num,
|
||||
uint32_t top_k,
|
||||
uint32_t q_num_heads,
|
||||
uint32_t max_len) {
|
||||
@@ -2028,12 +2059,16 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
|
||||
|
||||
// NOTE:(changwenbin) Support FD Metadata
|
||||
int batch_id, length;
|
||||
const IdType* block_table_pre_batch;
|
||||
|
||||
if (seq_len_decoder != nullptr) { // decode
|
||||
batch_id = batch_id_per_token[bid / q_num_heads];
|
||||
if (batch_id == -1) return;
|
||||
length = (seq_len_decoder[batch_id] + 1); // for pack q k
|
||||
length = (seq_len_decoder[batch_id]); // for pack q k
|
||||
if (length == 0) return;
|
||||
if (block_tables != nullptr) {
|
||||
block_table_pre_batch = block_tables + batch_id * max_block_num;
|
||||
}
|
||||
} else { // prefill
|
||||
// length = (lengths != nullptr) ? lengths[bid] : static_cast<int>(max_len);
|
||||
length = (lengths != nullptr) ? lengths[bid / q_num_heads]
|
||||
@@ -2064,8 +2099,20 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
|
||||
if constexpr (MODE == FilteredTopKMode::PageTable) {
|
||||
dst[i] = (i < length) ? src_page_entry[i] : static_cast<IdType>(-1);
|
||||
} else if constexpr (MODE == FilteredTopKMode::Ragged) {
|
||||
dst[i] = (i < length) ? static_cast<IdType>(i) + offset_val
|
||||
: static_cast<IdType>(-1);
|
||||
if (seq_len_decoder != nullptr && block_tables != nullptr) {
|
||||
int block_idx, block_ids, block_offset;
|
||||
if (i < length) {
|
||||
block_idx = i / 64;
|
||||
block_ids = block_table_pre_batch[block_idx];
|
||||
block_offset = i % 64;
|
||||
}
|
||||
dst[i] = (i < length)
|
||||
? static_cast<IdType>(block_ids * 64 + block_offset)
|
||||
: static_cast<IdType>(-1);
|
||||
} else {
|
||||
dst[i] =
|
||||
(i < length) ? static_cast<IdType>(i) : static_cast<IdType>(-1);
|
||||
}
|
||||
} else { // Plain
|
||||
if (i < length) {
|
||||
dst[i] = static_cast<IdType>(i);
|
||||
@@ -2284,7 +2331,18 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
|
||||
if constexpr (MODE == FilteredTopKMode::PageTable) {
|
||||
dst[base] = src_page_entry[idx];
|
||||
} else if constexpr (MODE == FilteredTopKMode::Ragged) {
|
||||
dst[base] = static_cast<IdType>(idx) + offset_val;
|
||||
// NOTE(changwenbin) support decode paged indexer in Ranged mode.
|
||||
if (seq_len_decoder != nullptr && block_tables != nullptr) {
|
||||
int block_idx, block_ids, block_offset;
|
||||
block_idx = idx / 64;
|
||||
block_ids = block_table_pre_batch[block_idx];
|
||||
block_offset = idx % 64;
|
||||
dst[base] =
|
||||
static_cast<IdType>(block_ids * 64 + block_offset); // + offset_val
|
||||
} else {
|
||||
dst[base] = static_cast<IdType>(idx); //+ offset_val;
|
||||
}
|
||||
|
||||
} else { // Plain
|
||||
dst[base] = static_cast<IdType>(idx);
|
||||
dst_values[base] = score[idx];
|
||||
@@ -2375,6 +2433,8 @@ cudaError_t FilteredTopKRaggedTransform(DType* input,
|
||||
uint32_t num_rows,
|
||||
const IdType* seq_len_decoder,
|
||||
const IdType* batch_id_per_token,
|
||||
const IdType* block_tables,
|
||||
uint32_t max_block_num,
|
||||
uint32_t top_k_val,
|
||||
uint32_t q_num_heads,
|
||||
uint32_t max_len,
|
||||
@@ -2397,6 +2457,8 @@ cudaError_t FilteredTopKRaggedTransform(DType* input,
|
||||
&num_rows,
|
||||
&seq_len_decoder,
|
||||
&batch_id_per_token,
|
||||
&block_tables,
|
||||
&max_block_num,
|
||||
&top_k_val,
|
||||
&q_num_heads,
|
||||
&max_len};
|
||||
@@ -2613,36 +2675,42 @@ cudaError_t TopKRaggedTransformDispatch(DType* input,
|
||||
uint32_t num_rows,
|
||||
const IdType* seq_len_decoder,
|
||||
const IdType* batch_id_per_token,
|
||||
const IdType* block_tables,
|
||||
uint32_t max_block_num,
|
||||
uint32_t top_k_val,
|
||||
uint32_t q_num_heads,
|
||||
uint32_t max_len,
|
||||
RadixRowState* row_states_buffer,
|
||||
cudaStream_t stream = 0) {
|
||||
if (ShouldUseFilteredTopK<DType>(num_rows, top_k_val, max_len)) {
|
||||
return FilteredTopKRaggedTransform<DType, IdType>(input,
|
||||
output_indices,
|
||||
offsets,
|
||||
lengths,
|
||||
num_rows,
|
||||
seq_len_decoder,
|
||||
batch_id_per_token,
|
||||
top_k_val,
|
||||
q_num_heads,
|
||||
max_len,
|
||||
stream);
|
||||
}
|
||||
return RadixTopKRaggedTransformMultiCTA<DType, IdType>(input,
|
||||
output_indices,
|
||||
offsets,
|
||||
lengths,
|
||||
num_rows,
|
||||
seq_len_decoder,
|
||||
batch_id_per_token,
|
||||
top_k_val,
|
||||
q_num_heads,
|
||||
max_len,
|
||||
row_states_buffer,
|
||||
stream);
|
||||
// if (ShouldUseFilteredTopK<DType>(num_rows, top_k_val, max_len)) {
|
||||
return FilteredTopKRaggedTransform<DType, IdType>(input,
|
||||
output_indices,
|
||||
offsets,
|
||||
lengths,
|
||||
num_rows,
|
||||
seq_len_decoder,
|
||||
batch_id_per_token,
|
||||
block_tables,
|
||||
max_block_num,
|
||||
top_k_val,
|
||||
q_num_heads,
|
||||
max_len,
|
||||
stream);
|
||||
// }
|
||||
// return RadixTopKRaggedTransformMultiCTA<DType, IdType>(input,
|
||||
// output_indices,
|
||||
// offsets,
|
||||
// lengths,
|
||||
// num_rows,
|
||||
// seq_len_decoder,
|
||||
// batch_id_per_token,
|
||||
// block_tables,
|
||||
// max_block_num,
|
||||
// top_k_val,
|
||||
// q_num_heads,
|
||||
// max_len,
|
||||
// row_states_buffer,
|
||||
// stream);
|
||||
}
|
||||
|
||||
template <typename DType, typename IdType>
|
||||
|
||||
Reference in New Issue
Block a user