Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
gongweibao
2026-03-04 21:55:31 +08:00
committed by GitHub
parent 5c8f5184d9
commit ddb06ff83f
306 changed files with 40627 additions and 34418 deletions
+36 -31
View File
@@ -27,7 +27,7 @@ struct AppendAttnMetaData {
int head_dims;
int head_dims_v;
int max_blocks_per_seq;
const int *mask_offset = nullptr;
const int* mask_offset = nullptr;
};
__forceinline__ __host__ __device__ int div_up(int a, int b) {
@@ -110,29 +110,33 @@ __device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x,
/******************************FASTER CAST*********************************/
// Converts the four FP8 (e4m3) values packed in `source` into four bfloat16
// values written to result[0..3].
//
// The 32-bit word is split into its low and high 16-bit halves. Each half is
// fed to `cvt.rn.f16x2.e4m3x2`, which yields a packed pair of fp16 values;
// each pair is then widened to float2 and rounded to a bfloat16 pair.
// result[0..1] come from the low half, result[2..3] from the high half.
//
// NOTE(review): the stores go through nv_bfloat162*, so `result` presumably
// has to be suitably aligned for nv_bfloat162 and have room for four
// elements -- confirm against callers.
//
// Only available when compiling for __CUDA_ARCH__ >= 890. On any other
// target the function prints a message and executes `trap`.
inline __device__ static void convert_fp8(__nv_bfloat16* result,
                                          const uint32_t& source) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
  uint32_t dest0;  // packed fp16 pair decoded from the low 16 bits
  uint32_t dest1;  // packed fp16 pair decoded from the high 16 bits
  asm volatile(
      "{\n"
      ".reg .b16 lo, hi;\n"
      "mov.b32 {lo, hi}, %2;\n"
      "cvt.rn.f16x2.e4m3x2 %0, lo;\n"
      "cvt.rn.f16x2.e4m3x2 %1, hi;\n"
      "}\n"
      : "=r"(dest0), "=r"(dest1)
      : "r"(source));

  // Reinterpret each 32-bit register as half2, then go half2 -> float2 ->
  // bfloat162 (round-to-nearest on the final step).
  nv_bfloat162* out = reinterpret_cast<nv_bfloat162*>(result);
  out[0] =
      __float22bfloat162_rn(__half22float2(*reinterpret_cast<half2*>(&dest0)));
  out[1] =
      __float22bfloat162_rn(__half22float2(*reinterpret_cast<half2*>(&dest1)));
#else
  printf("Do not support fp8 in arch < 890\n");
  asm("trap;");
#endif
}
// FP8 -> half overload. This conversion is not implemented: the function
// never writes to `result`.
//
// It prints a diagnostic and then executes `trap`, matching how the
// __nv_bfloat16 overload handles its unsupported path. Without the trap the
// caller would continue with whatever `result` previously held.
inline __device__ static void convert_fp8(half* result,
                                          const uint32_t& source) {
  printf("Do not support fp8 to half although it's very easy.\n");
  // Abort the kernel instead of returning with `result` unwritten.
  asm("trap;");
}
@@ -301,8 +305,8 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 64: { \
constexpr size_t HEAD_DIM = 64; \
case 64: { \
constexpr size_t HEAD_DIM = 64; \
__VA_ARGS__ \
break; \
} \
@@ -385,9 +389,8 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
PD_THROW("not support the cache_type: ", cache_type); \
}
#define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) \
if (deal_each_time == 32) { \
if (deal_each_time == 32) { \
constexpr size_t DEAL_EACH_TIME = 32; \
__VA_ARGS__ \
} else if (deal_each_time == 64) { \
@@ -404,7 +407,7 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} else if (num_threads == 256) { \
constexpr size_t NUM_THREADS = 256; \
__VA_ARGS__ \
} else { \
} else { \
PD_THROW("not support the num_threads", num_threads); \
}
@@ -456,7 +459,7 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
}
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
} else if (group_size == 16) { \
@@ -538,9 +541,11 @@ inline HOSTDEVICE T roundWithTiesToEven(T x) {
: xUpper);
}
template <typename T, bool is_need_kv_quant, bool IsFP8, int RoundType = 0>
__host__ __device__ __forceinline__ uint8_t QuantToC8(const T scale, const T value, const float max_bound, const float min_bound) {
__host__ __device__ __forceinline__ uint8_t QuantToC8(const T scale,
const T value,
const float max_bound,
const float min_bound) {
uint8_t eight_bits;
float quant_value;
if constexpr (is_need_kv_quant) {
@@ -572,8 +577,8 @@ __host__ __device__ __forceinline__ uint8_t QuantToC8(const T scale, const T val
return eight_bits;
}
template <typename T, bool IsFP8>inline __device__ static void convert_c8(T * result, const uint32_t& source){
template <typename T, bool IsFP8>
inline __device__ static void convert_c8(T* result, const uint32_t& source) {
if constexpr (IsFP8) {
convert_fp8(result, source);
} else {
@@ -583,12 +588,12 @@ template <typename T, bool IsFP8>inline __device__ static void convert_c8(T * re
constexpr int kWarpSize = 32;
// Adds `b_m2` into the accumulator pointed to by `m2` (in place).
//
// T must support `operator+=`. `m2` must be a valid, writable pointer; it is
// dereferenced unconditionally.
template <typename T>
inline __device__ void WelfordCombine1(T b_m2, T* m2) {
  *m2 += b_m2;
}
template<typename T, int thread_group_width = kWarpSize>
template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) {
*m2 = thread_m2;
for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
@@ -597,7 +602,7 @@ __inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) {
}
}
// Thin wrapper that forwards to WelfordWarpReduce<T, thread_group_width>,
// passing `thread_m2` as this thread's input and `m2` as the output pointer.
//
// thread_group_width defaults to kWarpSize.
// NOTE(review): whether every participating lane ends up with the same value
// in *m2 depends on the shuffle used inside WelfordWarpReduce -- confirm
// there before relying on all-reduce semantics.
template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) {
  WelfordWarpReduce<T, thread_group_width>(thread_m2, m2);
}