[Optimization][BugFix] Optimize DeepSeek networking code (#6861)

* update dsk model

* update dsk model
This commit is contained in:
AIbin
2026-03-16 16:52:43 +08:00
committed by GitHub
parent bb925c605f
commit c9f7f5234e
2 changed files with 204 additions and 45 deletions
@@ -1496,6 +1496,145 @@ cudaError_t RadixTopKMaskLogitsMultiCTA(DType* logits,
return cudaSuccess;
}
/*!
 * \brief Launch multi-CTA Radix Top-K with Page Table Transform kernel.
 *
 * Performs top-k selection and gathers indices through a page table.
 * Used for sparse attention's second stage in prefill mode.
 *
 * \param input Input scores tensor [num_rows, max_len]
 * \param output_page_table Output page table entries [num_rows, top_k]
 * \param src_page_table Source page table [batch_size, max_len]
 * \param src_stride Stride of source page table (typically max_len)
 * \param row_to_batch Mapping from row index to batch index [num_rows], or
 * nullptr if 1:1
 * \param lengths Sequence lengths per row [num_rows]
 * \param num_rows Number of rows to process
 * \param seq_len_decoder Decoder sequence lengths, forwarded to the kernel
 * (semantics defined by RadixTopKKernel_Unified — verify against kernel)
 * \param batch_id_per_token Per-token batch id, forwarded to the kernel
 * \param top_k_val Number of top elements to select
 * \param q_num_heads Number of query heads, forwarded to the kernel
 * \param max_len Maximum sequence length (input stride)
 * \param row_states_buffer Buffer for inter-CTA synchronization
 * \param stream CUDA stream
 */
template <typename DType, typename IdType>
cudaError_t RadixTopKPageTableTransformMultiCTA(
    DType* input,
    IdType* output_page_table,
    const IdType* src_page_table,
    int64_t src_stride,
    const IdType* row_to_batch,
    IdType* lengths,
    uint32_t num_rows,
    const IdType* seq_len_decoder,
    const IdType* batch_id_per_token,
    uint32_t top_k_val,
    uint32_t q_num_heads,
    uint32_t max_len,
    RadixRowState* row_states_buffer,
    cudaStream_t stream = 0) {
  using OrderedType = typename RadixTopKTraits<DType>::OrderedType;
  constexpr uint32_t BLOCK_THREADS = 1024;
  // Widest 16-byte-aligned vector load that still divides max_len evenly.
  const uint32_t vec_size = std::gcd(16 / sizeof(DType), max_len);
  int device;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&device));
  int num_sms;
  FLASHINFER_CUDA_CALL(
      cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device));
  int max_smem_per_block;
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(
      &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
  // Fixed shared-memory footprint: two 256-entry radix histograms + 5 scalars,
  // padded to 16 bytes so the OrderedType buffer behind it stays aligned.
  constexpr size_t fixed_smem_size = sizeof(uint32_t) * (256 + 256 + 5);
  constexpr size_t fixed_smem_aligned = round_up(fixed_smem_size, 16);
  const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned;
  uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
  max_chunk_elements = round_down(max_chunk_elements, vec_size);
  // Each CTA must cover at least one vector per thread.
  const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
  max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);
  // Split each row into ctas_per_group chunks of chunk_size elements.
  uint32_t ctas_per_group = ceil_div(max_len, max_chunk_elements);
  uint32_t chunk_size = ceil_div(max_len, ctas_per_group);
  chunk_size = round_up(chunk_size, vec_size);
  chunk_size = std::min(chunk_size, max_chunk_elements);
  const bool single_cta = (ctas_per_group == 1);
  const uint32_t smem_size =
      fixed_smem_aligned + chunk_size * sizeof(OrderedType);
  uint32_t num_groups =
      std::min(static_cast<uint32_t>(num_sms) / ctas_per_group, num_rows);
  if (num_groups == 0) num_groups = 1;
  uint32_t total_ctas = num_groups * ctas_per_group;
  // Unified kernel parameters
  DType* output_values = nullptr;  // Not used in PageTableTransform mode
  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    // Only the SINGLE_CTA template flag differs between the two variants;
    // select the instantiation once and share a single launch path
    // (previously the whole argument/launch sequence was duplicated).
    void* kernel =
        single_cta
            ? (void*)RadixTopKKernel_Unified<BLOCK_THREADS,
                                             VEC_SIZE,
                                             true,
                                             RadixTopKMode::PageTableTransform,
                                             DType,
                                             IdType>
            : (void*)RadixTopKKernel_Unified<BLOCK_THREADS,
                                             VEC_SIZE,
                                             false,
                                             RadixTopKMode::PageTableTransform,
                                             DType,
                                             IdType>;
    // Opt in to >48KB dynamic shared memory where the chunk requires it.
    FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    dim3 nblks(total_ctas);
    dim3 nthrs(BLOCK_THREADS);
    void* args[] = {&input,
                    &output_page_table,
                    &output_values,
                    &src_page_table,
                    &lengths,
                    &row_to_batch,
                    &src_stride,
                    &seq_len_decoder,
                    &batch_id_per_token,
                    &top_k_val,
                    &q_num_heads,
                    &max_len,
                    &num_rows,
                    &row_states_buffer,
                    &chunk_size,
                    &ctas_per_group};
    FLASHINFER_CUDA_CALL(
        cudaLaunchKernel(kernel, nblks, nthrs, args, smem_size, stream));
  });
  return cudaSuccess;
}
/*!
* \brief Launch multi-CTA Radix Top-K with Ragged Index Transform kernel.
*
@@ -2422,6 +2561,50 @@ inline bool ShouldUseFilteredTopK(uint32_t num_rows,
}
}
// Dispatch functions with heuristics
/*!
 * \brief Heuristic dispatch for top-k selection with page-table transform.
 *
 * Routes to the filtered top-k implementation when ShouldUseFilteredTopK
 * approves the problem shape; otherwise uses the multi-CTA radix top-k
 * path. All arguments are forwarded unchanged; note the filtered path does
 * not consume seq_len_decoder, batch_id_per_token, q_num_heads, or
 * row_states_buffer.
 */
template <typename DType, typename IdType>
cudaError_t TopKPageTableTransformDispatch(DType* input,
                                           IdType* output_page_table,
                                           const IdType* src_page_table,
                                           int64_t src_stride,
                                           const IdType* row_to_batch,
                                           IdType* lengths,
                                           uint32_t num_rows,
                                           const IdType* seq_len_decoder,
                                           const IdType* batch_id_per_token,
                                           uint32_t top_k_val,
                                           uint32_t q_num_heads,
                                           uint32_t max_len,
                                           RadixRowState* row_states_buffer,
                                           cudaStream_t stream = 0) {
  const bool use_filtered =
      ShouldUseFilteredTopK<DType>(num_rows, top_k_val, max_len);
  if (!use_filtered) {
    return RadixTopKPageTableTransformMultiCTA<DType, IdType>(
        input, output_page_table, src_page_table, src_stride, row_to_batch,
        lengths, num_rows, seq_len_decoder, batch_id_per_token, top_k_val,
        q_num_heads, max_len, row_states_buffer, stream);
  }
  return FilteredTopKPageTableTransform<DType, IdType>(
      input, output_page_table, src_page_table, src_stride, row_to_batch,
      lengths, num_rows, top_k_val, max_len, stream);
}
template <typename DType, typename IdType>
cudaError_t TopKRaggedTransformDispatch(DType* input,
IdType* output_indices,