Files
FastDeploy/custom_ops/gpu_ops/sparse_indexer/utils.cuh
T
AIbin c3aceb6bdc [Models][OP][Optimization] Support DeepSeek-v3.2 model, integrate DSA & Indexer architecture with FlashMLA/DeepGEMM (#6689)
* Support DeepSeek-v3.2 model, integrate DSA & Indexer architecture with FlashMLA/DeepGEMM
2026-03-10 15:05:14 +08:00

535 lines
24 KiB
Plaintext

/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_UTILS_CUH_
#define FLASHINFER_UTILS_CUH_
#include <cuda_bf16.h>
#include <cuda_device_runtime_api.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <atomic>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "exception.h"
// Two-level stringification: STR(x) first expands macro arguments, then
// turns the result into a string literal (used by FLASHINFER_CUDA_CALL to
// print the failing expression).
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)
// macro to turn off fp16 qk reduction to reduce binary
#ifndef FLASHINFER_ALWAYS_DISUSE_FP16_QK_REDUCTION
#define FLASHINFER_ALWAYS_DISUSE_FP16_QK_REDUCTION 0
#endif
#ifndef NDEBUG
// Evaluate a CUDA runtime call; on failure, log a diagnostic (error string,
// numeric code, file/line, and the stringified expression) to stderr and
// `return` the cudaError_t from the ENCLOSING function. Wrapped in
// do { } while (0) so the macro expands to a single statement and is safe
// inside unbraced if/else chains (the bare `{ } ;` form is not).
#define FLASHINFER_CUDA_CALL(func, ...)                                 \
  do {                                                                  \
    cudaError_t e = (func);                                             \
    if (e != cudaSuccess) {                                             \
      std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \
                << ") " << __FILE__ << ": line " << __LINE__            \
                << " at function " << STR(func) << std::endl;           \
      return e;                                                         \
    }                                                                   \
  } while (0)
#else
// Release build: same early-return-on-error semantics, no logging.
#define FLASHINFER_CUDA_CALL(func, ...) \
  do {                                  \
    cudaError_t e = (func);             \
    if (e != cudaSuccess) {             \
      return e;                         \
    }                                   \
  } while (0)
#endif
// Dispatch the runtime flag `use_fp16_qk_reduction` to the compile-time
// constant USE_FP16_QK_REDUCTION visible to __VA_ARGS__. FP16 QK reduction
// is compiled out in this build (see FLASHINFER_ALWAYS_DISUSE_FP16_QK_REDUCTION
// above), so requesting it at runtime always raises FLASHINFER_ERROR.
#define DISPATCH_USE_FP16_QK_REDUCTION(                             \
    use_fp16_qk_reduction, USE_FP16_QK_REDUCTION, ...)              \
  if (use_fp16_qk_reduction) {                                      \
    FLASHINFER_ERROR("FP16_QK_REDUCTION disabled at compile time"); \
  } else {                                                          \
    constexpr bool USE_FP16_QK_REDUCTION = false;                   \
    __VA_ARGS__                                                     \
  }
// Dispatch the runtime value `num_mma_q` (only 1 or 2 are instantiated) to
// the compile-time constant NUM_MMA_Q; any other value raises an error.
#define DISPATCH_NUM_MMA_Q(num_mma_q, NUM_MMA_Q, ...)  \
  if (num_mma_q == 1) {                                \
    constexpr size_t NUM_MMA_Q = 1;                    \
    __VA_ARGS__                                        \
  } else if (num_mma_q == 2) {                         \
    constexpr size_t NUM_MMA_Q = 2;                    \
    __VA_ARGS__                                        \
  } else {                                             \
    std::ostringstream err_msg;                        \
    err_msg << "Unsupported num_mma_q: " << num_mma_q; \
    FLASHINFER_ERROR(err_msg.str());                   \
  }
// Map `max_mma_kv` to the LARGEST instantiated value (8/4/2/1) that does
// not exceed it — note the >= comparisons, unlike the exact matches used by
// the other dispatch macros. Values < 1 raise an error.
#define DISPATCH_NUM_MMA_KV(max_mma_kv, NUM_MMA_KV, ...) \
  if (max_mma_kv >= 8) {                                 \
    constexpr size_t NUM_MMA_KV = 8;                     \
    __VA_ARGS__                                          \
  } else if (max_mma_kv >= 4) {                          \
    constexpr size_t NUM_MMA_KV = 4;                     \
    __VA_ARGS__                                          \
  } else if (max_mma_kv >= 2) {                          \
    constexpr size_t NUM_MMA_KV = 2;                     \
    __VA_ARGS__                                          \
  } else if (max_mma_kv >= 1) {                          \
    constexpr size_t NUM_MMA_KV = 1;                     \
    __VA_ARGS__                                          \
  } else {                                               \
    std::ostringstream err_msg;                          \
    err_msg << "Unsupported max_mma_kv: " << max_mma_kv; \
    FLASHINFER_ERROR(err_msg.str());                     \
  }
// Dispatch `cta_tile_q` (query-dimension CTA tile size; 128, 64, or 16) to
// the compile-time constant CTA_TILE_Q.
#define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...)  \
  switch (cta_tile_q) {                                   \
    case 128: {                                           \
      constexpr uint32_t CTA_TILE_Q = 128;                \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    case 64: {                                            \
      constexpr uint32_t CTA_TILE_Q = 64;                 \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    case 16: {                                            \
      constexpr uint32_t CTA_TILE_Q = 16;                 \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    default: {                                            \
      std::ostringstream err_msg;                         \
      err_msg << "Unsupported cta_tile_q: " << cta_tile_q; \
      FLASHINFER_ERROR(err_msg.str());                    \
    }                                                     \
  }
// Dispatch the GQA group size (query heads per KV head) to the compile-time
// constant GROUP_SIZE. Only 1, 2, 3, 4, and 8 are instantiated.
#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
  if (group_size == 1) {                                     \
    constexpr size_t GROUP_SIZE = 1;                         \
    __VA_ARGS__                                              \
  } else if (group_size == 2) {                              \
    constexpr size_t GROUP_SIZE = 2;                         \
    __VA_ARGS__                                              \
  } else if (group_size == 3) {                              \
    constexpr size_t GROUP_SIZE = 3;                         \
    __VA_ARGS__                                              \
  } else if (group_size == 4) {                              \
    constexpr size_t GROUP_SIZE = 4;                         \
    __VA_ARGS__                                              \
  } else if (group_size == 8) {                              \
    constexpr size_t GROUP_SIZE = 8;                         \
    __VA_ARGS__                                              \
  } else {                                                   \
    std::ostringstream err_msg;                              \
    err_msg << "Unsupported group_size: " << group_size;     \
    FLASHINFER_ERROR(err_msg.str());                         \
  }
// Dispatch the runtime MaskMode enum value to the compile-time constant
// MASK_MODE (kNone / kCausal / kCustom / kMultiItemScoring).
#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...)             \
  switch (mask_mode) {                                            \
    case MaskMode::kNone: {                                       \
      constexpr MaskMode MASK_MODE = MaskMode::kNone;             \
      __VA_ARGS__                                                 \
      break;                                                      \
    }                                                             \
    case MaskMode::kCausal: {                                     \
      constexpr MaskMode MASK_MODE = MaskMode::kCausal;           \
      __VA_ARGS__                                                 \
      break;                                                      \
    }                                                             \
    case MaskMode::kCustom: {                                     \
      constexpr MaskMode MASK_MODE = MaskMode::kCustom;           \
      __VA_ARGS__                                                 \
      break;                                                      \
    }                                                             \
    case MaskMode::kMultiItemScoring: {                           \
      constexpr MaskMode MASK_MODE = MaskMode::kMultiItemScoring; \
      __VA_ARGS__                                                 \
      break;                                                      \
    }                                                             \
    default: {                                                    \
      std::ostringstream err_msg;                                 \
      err_msg << "Unsupported mask_mode: " << int(mask_mode);     \
      FLASHINFER_ERROR(err_msg.str());                            \
    }                                                             \
  }
// convert head_dim to compile-time constant
// Supported head dims: 64, 128, 256, 512; others raise FLASHINFER_ERROR.
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...)    \
  switch (head_dim) {                                 \
    case 64: {                                        \
      constexpr size_t HEAD_DIM = 64;                 \
      __VA_ARGS__                                     \
      break;                                          \
    }                                                 \
    case 128: {                                       \
      constexpr size_t HEAD_DIM = 128;                \
      __VA_ARGS__                                     \
      break;                                          \
    }                                                 \
    case 256: {                                       \
      constexpr size_t HEAD_DIM = 256;                \
      __VA_ARGS__                                     \
      break;                                          \
    }                                                 \
    case 512: {                                       \
      constexpr size_t HEAD_DIM = 512;                \
      __VA_ARGS__                                     \
      break;                                          \
    }                                                 \
    default: {                                        \
      std::ostringstream err_msg;                     \
      err_msg << "Unsupported head_dim: " << head_dim; \
      FLASHINFER_ERROR(err_msg.str());                \
    }                                                 \
  }
// convert interleave to compile-time constant
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
  if (interleave) {                                      \
    constexpr bool INTERLEAVE = true;                    \
    __VA_ARGS__                                          \
  } else {                                               \
    constexpr bool INTERLEAVE = false;                   \
    __VA_ARGS__                                          \
  }
// Dispatch the RoPE rotary dimension (16/32/64/128/256) to the compile-time
// constant ROPE_DIM.
#define DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, ...)          \
  switch (rope_dim) {                                       \
    case 16: {                                              \
      constexpr uint32_t ROPE_DIM = 16;                     \
      __VA_ARGS__                                           \
      break;                                                \
    }                                                       \
    case 32: {                                              \
      constexpr uint32_t ROPE_DIM = 32;                     \
      __VA_ARGS__                                           \
      break;                                                \
    }                                                       \
    case 64: {                                              \
      constexpr uint32_t ROPE_DIM = 64;                     \
      __VA_ARGS__                                           \
      break;                                                \
    }                                                       \
    case 128: {                                             \
      constexpr uint32_t ROPE_DIM = 128;                    \
      __VA_ARGS__                                           \
      break;                                                \
    }                                                       \
    case 256: {                                             \
      constexpr uint32_t ROPE_DIM = 256;                    \
      __VA_ARGS__                                           \
      break;                                                \
    }                                                       \
    default: {                                              \
      std::ostringstream err_msg;                           \
      err_msg << "Unsupported ROPE_DIM: " << rope_dim;      \
      err_msg << ". Supported values: 16, 32, 64, 128, 256"; \
      err_msg << " in DISPATCH_ROPE_DIM";                   \
      FLASHINFER_ERROR(err_msg.str());                      \
    }                                                       \
  }
// Dispatch the runtime PosEncodingMode enum value to the compile-time
// constant POS_ENCODING_MODE (kNone / kRoPELlama / kALiBi).
#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) \
  switch (pos_encoding_mode) {                                                \
    case PosEncodingMode::kNone: {                                            \
      constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone;   \
      __VA_ARGS__                                                             \
      break;                                                                  \
    }                                                                         \
    case PosEncodingMode::kRoPELlama: {                                       \
      constexpr PosEncodingMode POS_ENCODING_MODE =                           \
          PosEncodingMode::kRoPELlama;                                        \
      __VA_ARGS__                                                             \
      break;                                                                  \
    }                                                                         \
    case PosEncodingMode::kALiBi: {                                           \
      constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi;  \
      __VA_ARGS__                                                             \
      break;                                                                  \
    }                                                                         \
    default: {                                                                \
      std::ostringstream err_msg;                                             \
      err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \
      FLASHINFER_ERROR(err_msg.str());                                        \
    }                                                                         \
  }
// Dispatch the vectorized-access width in ELEMENTS (16/8/4/2/1) to the
// compile-time constant ALIGNED_VEC_SIZE.
#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...)  \
  switch (aligned_vec_size) {                                               \
    case 16: {                                                              \
      constexpr size_t ALIGNED_VEC_SIZE = 16;                               \
      __VA_ARGS__                                                           \
      break;                                                                \
    }                                                                       \
    case 8: {                                                               \
      constexpr size_t ALIGNED_VEC_SIZE = 8;                                \
      __VA_ARGS__                                                           \
      break;                                                                \
    }                                                                       \
    case 4: {                                                               \
      constexpr size_t ALIGNED_VEC_SIZE = 4;                                \
      __VA_ARGS__                                                           \
      break;                                                                \
    }                                                                       \
    case 2: {                                                               \
      constexpr size_t ALIGNED_VEC_SIZE = 2;                                \
      __VA_ARGS__                                                           \
      break;                                                                \
    }                                                                       \
    case 1: {                                                               \
      constexpr size_t ALIGNED_VEC_SIZE = 1;                                \
      __VA_ARGS__                                                           \
      break;                                                                \
    }                                                                       \
    default: {                                                              \
      std::ostringstream err_msg;                                           \
      err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size;      \
      FLASHINFER_ERROR(err_msg.str());                                      \
    }                                                                       \
  }
// Pick the number of shared-memory pipeline stages for the decode kernel
// from a (major, minor) compute-capability pair: 2 stages on major >= 8,
// otherwise 1.
#define DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM( \
    compute_capacity, NUM_STAGES_SMEM, ...)          \
  if (compute_capacity.first >= 8) {                 \
    constexpr uint32_t NUM_STAGES_SMEM = 2;          \
    __VA_ARGS__                                      \
  } else {                                           \
    constexpr uint32_t NUM_STAGES_SMEM = 1;          \
    __VA_ARGS__                                      \
  }
namespace flashinfer {
// Integer ceiling division: (x + y - 1) / y. Correct for non-negative x and
// positive y (the index-arithmetic precondition throughout this file);
// NOTE(review): not intended for negative operands.
template <typename T1, typename T2>
__forceinline__ __device__ __host__ constexpr T1 ceil_div(const T1 x,
                                                          const T2 y) noexcept {
  return (x + y - 1) / y;
}
// Round x up to the nearest multiple of y (same preconditions as ceil_div:
// non-negative x, positive y).
template <typename T1, typename T2>
__forceinline__ __device__ __host__ constexpr T1 round_up(const T1 x,
                                                          const T2 y) noexcept {
  return ceil_div(x, y) * y;
}
// Round x down to the nearest multiple of y. Uses the truncation identity
// (x / y) * y == x - x % y guaranteed by C++ integer division.
template <typename T1, typename T2>
__forceinline__ __device__ __host__ constexpr T1 round_down(
    const T1 x, const T2 y) noexcept {
  return x - (x % y);
}
// Query the (major, minor) compute capability of the CURRENT CUDA device.
// NOTE(review): CUDA API return codes are ignored here — if the runtime is
// unavailable the function silently returns (0, 0); callers should be aware.
inline std::pair<int, int> GetCudaComputeCapability() {
  int device_id = 0;
  cudaGetDevice(&device_id);
  int major = 0, minor = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id);
  return std::make_pair(major, minor);
}
// This function is thread-safe and caches the SM count after the first
// successful query. It only inspects the current CUDA device, thus assuming
// each process handles a single GPU.
inline int GetCudaMultiProcessorCount() {
  static std::atomic<int> sm_count{0};
  int cached = sm_count.load(std::memory_order_relaxed);
  if (cached == 0) {
    int device_id;
    cudaGetDevice(&device_id);
    // cudaDeviceGetAttribute is far cheaper than cudaGetDeviceProperties,
    // which fills the entire (large) property struct just to read one field.
    // `cached` stays 0 if the query fails, so we retry on the next call
    // instead of caching garbage from an uninitialized struct.
    cudaDeviceGetAttribute(&cached, cudaDevAttrMultiProcessorCount, device_id);
    sm_count.store(cached, std::memory_order_relaxed);
  }
  return cached;
}
// Debug helper: copy `size` elements of type T from device memory to host
// and print them space-separated on one line, preceded by `prefix`.
// Synchronous (plain cudaMemcpy) and slow — debugging only.
template <typename T>
inline void DebugPrintCUDAArray(T* device_ptr,
                                size_t size,
                                std::string prefix = "") {
  std::vector<T> host_array(size);
  std::cout << prefix;
  cudaError_t err = cudaMemcpy(
      host_array.data(), device_ptr, size * sizeof(T), cudaMemcpyDeviceToHost);
  if (err != cudaSuccess) {
    // Original code ignored this error and printed uninitialized host
    // memory; report the failure instead.
    std::cout << "<cudaMemcpy failed: " << cudaGetErrorString(err) << ">"
              << std::endl;
    return;
  }
  for (size_t i = 0; i < size; ++i) {
    std::cout << host_array[i] << " ";
  }
  std::cout << std::endl;
}
// Heuristically choose the CTA tile size along the query dimension for the
// FA2 prefill path, from the average packed query length and the head dim.
// Returns 128, 64, or 16.
inline uint32_t FA2DetermineCtaTileQ(int64_t avg_packed_qo_len,
                                     uint32_t head_dim) {
  // Large tiles only pay off for long queries with moderate head dims.
  if (avg_packed_qo_len > 64 && head_dim < 256) {
    return 128;
  }
  const auto cc = GetCudaComputeCapability();
  if (cc.first < 8) {
    // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
    return 64;
  }
  // Ampere or newer: shrink the tile for short queries (<= 16).
  return avg_packed_qo_len > 16 ? 64u : 16u;
}
// Returns the smallest power of two greater than or equal to x.
// Non-positive inputs map to 1.
inline int UpPowerOfTwo(int x) {
  if (x <= 0) return 1;
  int pow2 = 1;
  while (pow2 < x) {
    pow2 <<= 1;
  }
  return pow2;
}
// Split a descending loop over `iter` into two phases so the mask check can
// be a compile-time constant inside __VA_ARGS__: while COND1 holds, the body
// runs with WITH_MASK = true; then while COND2 holds, with WITH_MASK = false.
// `#pragma unroll 1` keeps both loops rolled to limit code size.
#define LOOP_SPLIT_MASK(iter, COND1, COND2, ...)        \
  {                                                     \
    _Pragma("unroll 1") for (; (COND1); (iter) -= 1) {  \
      constexpr bool WITH_MASK = true;                  \
      __VA_ARGS__                                       \
    }                                                   \
    _Pragma("unroll 1") for (; (COND2); (iter) -= 1) {  \
      constexpr bool WITH_MASK = false;                 \
      __VA_ARGS__                                       \
    }                                                   \
  }
/*!
 * \brief Saturating subtraction: return x - y if x > y, otherwise 0.
 */
__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x,
                                                           uint32_t y) {
  if (x > y) {
    return x - y;
  }
  return 0U;
}
// ======================= PTX Memory Utility Functions =======================
// Non-atomic global memory access with cache streaming hint (cs)
// These are useful for streaming memory access patterns where data is used once
/*!
 * \brief Get the lane ID within a warp (0-31)
 *
 * Reads the %laneid special register via inline PTX. Device-only.
 */
__forceinline__ __device__ int get_lane_id() {
  int lane_id;
  asm("mov.u32 %0, %%laneid;" : "=r"(lane_id));
  return lane_id;
}
/*!
 * \brief Non-atomic global load for int (4 bytes) with cache streaming hint
 *
 * ld.global.cs marks the data "evict-first" in cache — appropriate when the
 * value is read once and not reused.
 */
__forceinline__ __device__ int ld_na_global_v1(const int* addr) {
  int val;
  asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr));
  return val;
}
/*!
 * \brief Non-atomic global load for int2 (8 bytes) with cache streaming hint
 *
 * Vectorized variant of ld_na_global_v1; `addr` must be 8-byte aligned for
 * the v2.b32 load.
 */
__forceinline__ __device__ int2 ld_na_global_v2(const int2* addr) {
  int2 val;
  asm volatile("ld.global.cs.v2.b32 {%0, %1}, [%2];"
               : "=r"(val.x), "=r"(val.y)
               : "l"(addr));
  return val;
}
/*!
 * \brief Non-atomic global store for int (4 bytes) with cache streaming hint
 */
__forceinline__ __device__ void st_na_global_v1(int* addr, int val) {
  asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val));
}
/*!
 * \brief Non-atomic global store for int2 (8 bytes) with cache streaming hint
 *
 * Vectorized variant of st_na_global_v1; `addr` must be 8-byte aligned for
 * the v2.b32 store.
 */
__forceinline__ __device__ void st_na_global_v2(int2* addr, int2 val) {
  asm volatile("st.global.cs.v2.b32 [%0], {%1, %2};" ::"l"(addr),
               "r"(val.x),
               "r"(val.y));
}
/*!
 * \brief Prefetch data to L2 cache
 *
 * Issues a prefetch.global.L2 hint for the cache line containing `addr`;
 * no-op semantics as far as program state is concerned.
 */
template <typename T>
__forceinline__ __device__ void prefetch_L2(const T* addr) {
  asm volatile("prefetch.global.L2 [%0];" ::"l"(addr));
}
// Exchange two uint32_t values in place. Safe even when both references
// alias the same object (unlike an XOR swap).
__device__ __forceinline__ void swap(uint32_t& a, uint32_t& b) {
  const uint32_t old_a = a;
  a = b;
  b = old_a;
}
// Row-major linear offset into a 2-D array with innermost extent dim_a:
// element (idx_b, idx_a).
__device__ __forceinline__ uint32_t dim2_offset(const uint32_t& dim_a,
                                                const uint32_t& idx_b,
                                                const uint32_t& idx_a) {
  return idx_a + idx_b * dim_a;
}
// Row-major linear offset into a 3-D array with extents (*, dim_b, dim_a):
// element (idx_c, idx_b, idx_a). Expanded strided form; identical to the
// nested Horner form under unsigned mod-2^32 arithmetic.
__device__ __forceinline__ uint32_t dim3_offset(const uint32_t& dim_b,
                                                const uint32_t& dim_a,
                                                const uint32_t& idx_c,
                                                const uint32_t& idx_b,
                                                const uint32_t& idx_a) {
  return idx_c * (dim_b * dim_a) + idx_b * dim_a + idx_a;
}
// Row-major linear offset into a 4-D array with extents (*, dim_c, dim_b,
// dim_a): element (idx_d, idx_c, idx_b, idx_a). Precomputed strides;
// identical to the nested Horner form under unsigned mod-2^32 arithmetic.
__device__ __forceinline__ uint32_t dim4_offset(const uint32_t& dim_c,
                                                const uint32_t& dim_b,
                                                const uint32_t& dim_a,
                                                const uint32_t& idx_d,
                                                const uint32_t& idx_c,
                                                const uint32_t& idx_b,
                                                const uint32_t& idx_a) {
  const uint32_t stride_b = dim_a;
  const uint32_t stride_c = dim_b * stride_b;
  const uint32_t stride_d = dim_c * stride_c;
  return idx_d * stride_d + idx_c * stride_c + idx_b * stride_b + idx_a;
}
// SFINAE member-detection generator: DEFINE_HAS_MEMBER(foo) defines the
// trait has_foo<T> (and the variable template has_foo_v<T>) that is true
// iff T has an accessible member named `foo`.
#define DEFINE_HAS_MEMBER(member)                                      \
  template <typename T, typename = void>                               \
  struct has_##member : std::false_type {};                            \
  template <typename T>                                                \
  struct has_##member<T, std::void_t<decltype(std::declval<T>().member)>> \
      : std::true_type {};                                             \
  template <typename T>                                                \
  inline constexpr bool has_##member##_v = has_##member<T>::value;
} // namespace flashinfer
#endif // FLASHINFER_UTILS_CUH_