mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Optimization][BugFix]Optimize Deepseek networking code (#6861)
* update dsk model * update dsk model
This commit is contained in:
@@ -1496,6 +1496,145 @@ cudaError_t RadixTopKMaskLogitsMultiCTA(DType* logits,
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
/*!
 * \brief Launch multi-CTA Radix Top-K with Page Table Transform kernel.
 *
 * Performs top-k selection and gathers indices through a page table.
 * Used for sparse attention's second stage in prefill mode.
 *
 * \param input Input scores tensor [num_rows, max_len]
 * \param output_page_table Output page table entries [num_rows, top_k]
 * \param src_page_table Source page table [batch_size, max_len]
 * \param src_stride Stride of source page table (typically max_len)
 * \param row_to_batch Mapping from row index to batch index [num_rows], or
 *                     nullptr if 1:1
 * \param lengths Sequence lengths per row [num_rows]
 * \param num_rows Number of rows to process
 * \param seq_len_decoder Per-batch decoder sequence lengths, forwarded to the
 *                        unified kernel (assumed [batch_size] — confirm against
 *                        RadixTopKKernel_Unified)
 * \param batch_id_per_token Per-token batch-id mapping, forwarded to the
 *                           unified kernel
 * \param top_k_val Number of top elements to select
 * \param q_num_heads Number of query heads, forwarded to the unified kernel
 * \param max_len Maximum sequence length (input stride)
 * \param row_states_buffer Buffer for inter-CTA synchronization
 * \param stream CUDA stream
 */
template <typename DType, typename IdType>
cudaError_t RadixTopKPageTableTransformMultiCTA(
    DType* input,
    IdType* output_page_table,
    const IdType* src_page_table,
    int64_t src_stride,
    const IdType* row_to_batch,
    IdType* lengths,
    uint32_t num_rows,
    const IdType* seq_len_decoder,
    const IdType* batch_id_per_token,
    uint32_t top_k_val,
    uint32_t q_num_heads,
    uint32_t max_len,
    RadixRowState* row_states_buffer,
    cudaStream_t stream = 0) {
  using OrderedType = typename RadixTopKTraits<DType>::OrderedType;
  constexpr uint32_t BLOCK_THREADS = 1024;
  // Widest vector load that both the 16-byte alignment and max_len permit.
  const uint32_t vec_size = std::gcd(16 / sizeof(DType), max_len);

  int device;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&device));
  int num_sms;
  FLASHINFER_CUDA_CALL(
      cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device));
  int max_smem_per_block;
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(
      &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));

  // Fixed shared-memory footprint: two 256-entry histograms plus 5 scalars,
  // rounded up to a 16-byte boundary; the rest holds the per-chunk
  // OrderedType staging buffer.
  constexpr size_t fixed_smem_size = sizeof(uint32_t) * (256 + 256 + 5);
  constexpr size_t fixed_smem_aligned = round_up(fixed_smem_size, 16);

  const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned;
  uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
  max_chunk_elements = round_down(max_chunk_elements, vec_size);
  const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
  max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);

  // Split each row into ctas_per_group chunks of (vec-aligned) chunk_size.
  uint32_t ctas_per_group = ceil_div(max_len, max_chunk_elements);
  uint32_t chunk_size = ceil_div(max_len, ctas_per_group);
  chunk_size = round_up(chunk_size, vec_size);
  chunk_size = std::min(chunk_size, max_chunk_elements);

  const bool single_cta = (ctas_per_group == 1);
  const uint32_t smem_size =
      fixed_smem_aligned + chunk_size * sizeof(OrderedType);

  // Persistent-grid sizing: as many row-groups as the SMs can host at once,
  // capped by num_rows; at least one group so tiny launches still run.
  uint32_t num_groups =
      std::min(static_cast<uint32_t>(num_sms) / ctas_per_group, num_rows);
  if (num_groups == 0) num_groups = 1;
  uint32_t total_ctas = num_groups * ctas_per_group;

  // Unified kernel parameters
  DType* output_values = nullptr;  // Not used in PageTableTransform mode

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    // Both instantiations differ only in the SINGLE_CTA template flag and
    // share one function-pointer type, so a single launch path serves both
    // (previously this launch block was duplicated verbatim per branch).
    auto kernel =
        single_cta
            ? &RadixTopKKernel_Unified<BLOCK_THREADS,
                                       VEC_SIZE,
                                       true,
                                       RadixTopKMode::PageTableTransform,
                                       DType,
                                       IdType>
            : &RadixTopKKernel_Unified<BLOCK_THREADS,
                                       VEC_SIZE,
                                       false,
                                       RadixTopKMode::PageTableTransform,
                                       DType,
                                       IdType>;
    // Opt in to >48KB dynamic shared memory when chunk_size requires it.
    FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    dim3 nblks(total_ctas);
    dim3 nthrs(BLOCK_THREADS);
    void* args[] = {&input,
                    &output_page_table,
                    &output_values,
                    &src_page_table,
                    &lengths,
                    &row_to_batch,
                    &src_stride,
                    &seq_len_decoder,
                    &batch_id_per_token,
                    &top_k_val,
                    &q_num_heads,
                    &max_len,
                    &num_rows,
                    &row_states_buffer,
                    &chunk_size,
                    &ctas_per_group};
    FLASHINFER_CUDA_CALL(cudaLaunchKernel(
        (void*)kernel, nblks, nthrs, args, smem_size, stream));
  });

  return cudaSuccess;
}
|
||||
|
||||
/*!
|
||||
* \brief Launch multi-CTA Radix Top-K with Ragged Index Transform kernel.
|
||||
*
|
||||
@@ -2422,6 +2561,50 @@ inline bool ShouldUseFilteredTopK(uint32_t num_rows,
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch functions with heuristics
//
// Routes the page-table top-k to one of two backends based on the
// ShouldUseFilteredTopK heuristic: the filtered implementation (which does
// not consume the decoder-length / token-to-batch arguments) or the
// multi-CTA radix implementation. Forwards the caller's arguments unchanged.
template <typename DType, typename IdType>
cudaError_t TopKPageTableTransformDispatch(DType* input,
                                           IdType* output_page_table,
                                           const IdType* src_page_table,
                                           int64_t src_stride,
                                           const IdType* row_to_batch,
                                           IdType* lengths,
                                           uint32_t num_rows,
                                           const IdType* seq_len_decoder,
                                           const IdType* batch_id_per_token,
                                           uint32_t top_k_val,
                                           uint32_t q_num_heads,
                                           uint32_t max_len,
                                           RadixRowState* row_states_buffer,
                                           cudaStream_t stream = 0) {
  const bool use_filtered =
      ShouldUseFilteredTopK<DType>(num_rows, top_k_val, max_len);
  if (!use_filtered) {
    // Radix path additionally consumes seq_len_decoder, batch_id_per_token,
    // q_num_heads, and the inter-CTA row-state buffer.
    return RadixTopKPageTableTransformMultiCTA<DType, IdType>(
        input, output_page_table, src_page_table, src_stride, row_to_batch,
        lengths, num_rows, seq_len_decoder, batch_id_per_token, top_k_val,
        q_num_heads, max_len, row_states_buffer, stream);
  }
  return FilteredTopKPageTableTransform<DType, IdType>(
      input, output_page_table, src_page_table, src_stride, row_to_batch,
      lengths, num_rows, top_k_val, max_len, stream);
}
|
||||
|
||||
template <typename DType, typename IdType>
|
||||
cudaError_t TopKRaggedTransformDispatch(DType* input,
|
||||
IdType* output_indices,
|
||||
|
||||
Reference in New Issue
Block a user