CUDA 13.0: adapt code to CCCL 3.0 changes (#6751)

This commit is contained in:
wangyifei
2026-03-10 16:47:02 +08:00
committed by GitHub
parent 54581b8653
commit b57c960837
13 changed files with 211 additions and 27 deletions
+10 -5
View File
@@ -27,11 +27,16 @@
#include <numeric>
#include "sample_kernels/utils.cuh"
#include "../cccl_compat.h" // CCCL 3.0 compatibility
namespace sampling {
// Make CUB names visible to unqualified lookup inside this namespace
// (presumably this is where the unqualified BlockReduce used in this file
// comes from - confirm against the full header).
using namespace cub;
// Use fd_cub_compat for functors removed in CCCL 3.0
// These using-declarations introduce Max and Min directly into this
// namespace, so an unqualified Max() / Min() resolves to the compat functors
// rather than to any same-named functor made visible by the using-directive
// above. Call sites must therefore spell the functors unqualified (Max(),
// Min()) and not as cub::Max() / cub::Min().
using fd_cub_compat::Max;
using fd_cub_compat::Min;
#define DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) \
if (compute_capacity.first >= 8) { \
constexpr uint32_t BLOCK_THREADS = 1024; \
@@ -317,7 +322,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
}
int max_valid_index = BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage->block_prim.reduce_int)
.Reduce(valid_index, cub::Max());
.Reduce(valid_index, Max());
if (tx == 0 && max_valid_index != -1) {
temp_storage->last_valid_id = max_valid_index;
}
@@ -636,12 +641,12 @@ __device__ __forceinline__ float GetMaxValue(float* in_data,
max_val = max(max_val,
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce)
.Reduce(in_data_, cub::Max()));
.Reduce(in_data_, Max()));
#else
max_val = max(max_val,
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce)
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
.Reduce<VEC_SIZE>(in_data_, Max()));
#endif
__syncthreads();
}
@@ -837,11 +842,11 @@ __global__ void TopKRenormProbKernel(DType* probs,
}
min_gt_low = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce)
.Reduce(min_gt_low, cub::Min());
.Reduce(min_gt_low, Min());
__syncthreads();
max_le_high = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce)
.Reduce(max_le_high, cub::Max());
.Reduce(max_le_high, Max());
if (tx == 0) {
temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;