mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
cuda13.0, implement changes to CCCL (#6751)
This commit is contained in:
@@ -27,11 +27,16 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "sample_kernels/utils.cuh"
|
||||
#include "../cccl_compat.h" // CCCL 3.0 compatibility
|
||||
|
||||
namespace sampling {
|
||||
|
||||
using namespace cub;
|
||||
|
||||
// cub::Max and cub::Min were removed in CCCL 3.0; use the fd_cub_compat replacements so unqualified Max()/Min() resolve to them
|
||||
using fd_cub_compat::Max;
|
||||
using fd_cub_compat::Min;
|
||||
|
||||
#define DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) \
|
||||
if (compute_capacity.first >= 8) { \
|
||||
constexpr uint32_t BLOCK_THREADS = 1024; \
|
||||
@@ -317,7 +322,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
||||
}
|
||||
int max_valid_index = BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage->block_prim.reduce_int)
|
||||
.Reduce(valid_index, cub::Max());
|
||||
.Reduce(valid_index, Max());
|
||||
if (tx == 0 && max_valid_index != -1) {
|
||||
temp_storage->last_valid_id = max_valid_index;
|
||||
}
|
||||
@@ -636,12 +641,12 @@ __device__ __forceinline__ float GetMaxValue(float* in_data,
|
||||
max_val = max(max_val,
|
||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce)
|
||||
.Reduce(in_data_, cub::Max()));
|
||||
.Reduce(in_data_, Max()));
|
||||
#else
|
||||
max_val = max(max_val,
|
||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce)
|
||||
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
|
||||
.Reduce<VEC_SIZE>(in_data_, Max()));
|
||||
#endif
|
||||
__syncthreads();
|
||||
}
|
||||
@@ -837,11 +842,11 @@ __global__ void TopKRenormProbKernel(DType* probs,
|
||||
}
|
||||
min_gt_low = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce)
|
||||
.Reduce(min_gt_low, cub::Min());
|
||||
.Reduce(min_gt_low, Min());
|
||||
__syncthreads();
|
||||
max_le_high = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce)
|
||||
.Reduce(max_le_high, cub::Max());
|
||||
.Reduce(max_le_high, Max());
|
||||
if (tx == 0) {
|
||||
temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
|
||||
temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
|
||||
|
||||
Reference in New Issue
Block a user