mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
@@ -27,7 +27,7 @@ struct AppendAttnMetaData {
|
||||
int head_dims;
|
||||
int head_dims_v;
|
||||
int max_blocks_per_seq;
|
||||
const int *mask_offset = nullptr;
|
||||
const int* mask_offset = nullptr;
|
||||
};
|
||||
|
||||
__forceinline__ __host__ __device__ int div_up(int a, int b) {
|
||||
@@ -110,29 +110,33 @@ __device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x,
|
||||
|
||||
/******************************FASTER CAST*********************************/
|
||||
|
||||
// Converts four packed FP8 (e4m3) bytes in `source` into four bf16 values
// written through `result` (as two nv_bfloat162 stores, i.e. result[0..3]).
// Requires SM89+ for the cvt.rn.f16x2.e4m3x2 instruction; on older
// architectures prints a diagnostic and traps.
// NOTE(review): the diff this file was recovered from contained both the
// pre- and post-format versions of this body interleaved (double declaration
// of dest0/dest1, asm emitted twice); this is the reconstructed single copy.
inline __device__ static void convert_fp8(__nv_bfloat16* result,
                                          const uint32_t& source) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
  uint32_t dest0;
  uint32_t dest1;
  // Split `source` into its two 16-bit halves, then widen each pair of
  // e4m3 bytes to a packed half2 (f16x2) via the SM89 conversion op.
  asm volatile(
      "{\n"
      ".reg .b16 lo, hi;\n"
      "mov.b32 {lo, hi}, %2;\n"
      "cvt.rn.f16x2.e4m3x2 %0, lo;\n"
      "cvt.rn.f16x2.e4m3x2 %1, hi;\n"
      "}\n"
      : "=r"(dest0), "=r"(dest1)
      : "r"(source));
  // There is no direct half2 -> bfloat162 cast, so route through float2.
  ((nv_bfloat162*)(result))[0] =
      __float22bfloat162_rn(__half22float2(((half2*)(&dest0))[0]));
  ((nv_bfloat162*)(result))[1] =
      __float22bfloat162_rn(__half22float2(((half2*)(&dest1))[0]));
#else
  printf("Do not support fp8 in arch < 890\n");
  asm("trap;");
#endif
}
|
||||
|
||||
// Unimplemented overload: FP8 -> half conversion is not supported; prints a
// diagnostic instead. The recovered diff contained two copies of the
// signature (old and new formatting); this is the single reconstructed one.
// NOTE(review): unlike the bf16 overload this does not trap, so a caller
// continues with `result` unwritten — confirm this best-effort behavior is
// intentional before hardening it.
inline __device__ static void convert_fp8(half* result,
                                          const uint32_t& source) {
  printf("Do not support fp8 to half although it's very easy.\n");
}
|
||||
|
||||
@@ -301,8 +305,8 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
|
||||
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
|
||||
switch (head_dim) { \
|
||||
case 64: { \
|
||||
constexpr size_t HEAD_DIM = 64; \
|
||||
case 64: { \
|
||||
constexpr size_t HEAD_DIM = 64; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
@@ -385,9 +389,8 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
PD_THROW("not support the cache_type: ", cache_type); \
|
||||
}
|
||||
|
||||
|
||||
#define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) \
|
||||
if (deal_each_time == 32) { \
|
||||
if (deal_each_time == 32) { \
|
||||
constexpr size_t DEAL_EACH_TIME = 32; \
|
||||
__VA_ARGS__ \
|
||||
} else if (deal_each_time == 64) { \
|
||||
@@ -404,7 +407,7 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
} else if (num_threads == 256) { \
|
||||
constexpr size_t NUM_THREADS = 256; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
} else { \
|
||||
PD_THROW("not support the num_threads", num_threads); \
|
||||
}
|
||||
|
||||
@@ -456,7 +459,7 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
}
|
||||
|
||||
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||
if (group_size == 8) { \
|
||||
if (group_size == 8) { \
|
||||
constexpr size_t GROUP_SIZE = 8; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 16) { \
|
||||
@@ -538,9 +541,11 @@ inline HOSTDEVICE T roundWithTiesToEven(T x) {
|
||||
: xUpper);
|
||||
}
|
||||
|
||||
|
||||
template <typename T, bool is_need_kv_quant, bool IsFP8, int RoundType = 0>
|
||||
__host__ __device__ __forceinline__ uint8_t QuantToC8(const T scale, const T value, const float max_bound, const float min_bound) {
|
||||
__host__ __device__ __forceinline__ uint8_t QuantToC8(const T scale,
|
||||
const T value,
|
||||
const float max_bound,
|
||||
const float min_bound) {
|
||||
uint8_t eight_bits;
|
||||
float quant_value;
|
||||
if constexpr (is_need_kv_quant) {
|
||||
@@ -572,8 +577,8 @@ __host__ __device__ __forceinline__ uint8_t QuantToC8(const T scale, const T val
|
||||
return eight_bits;
|
||||
}
|
||||
|
||||
|
||||
template <typename T, bool IsFP8>inline __device__ static void convert_c8(T * result, const uint32_t& source){
|
||||
template <typename T, bool IsFP8>
|
||||
inline __device__ static void convert_c8(T* result, const uint32_t& source) {
|
||||
if constexpr (IsFP8) {
|
||||
convert_fp8(result, source);
|
||||
} else {
|
||||
@@ -583,12 +588,12 @@ template <typename T, bool IsFP8>inline __device__ static void convert_c8(T * re
|
||||
|
||||
constexpr int kWarpSize = 32;
|
||||
|
||||
// Folds a partial m2 accumulation from another thread/block into the running
// accumulator pointed to by `m2` (a plain sum — presumably the M2 /
// sum-of-squared-deviations term of a Welford reduction, per the name;
// TODO confirm against the callers).
// The recovered diff carried the `template` header twice (old and new
// formatting); this is the single reconstructed definition.
template <typename T>
inline __device__ void WelfordCombine1(T b_m2, T* m2) {
  *m2 += b_m2;
}
|
||||
|
||||
template<typename T, int thread_group_width = kWarpSize>
|
||||
template <typename T, int thread_group_width = kWarpSize>
|
||||
__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) {
|
||||
*m2 = thread_m2;
|
||||
for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
|
||||
@@ -597,7 +602,7 @@ __inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) {
|
||||
}
|
||||
}
|
||||
|
||||
// All-reduce wrapper: forwards the thread's partial m2 to
// WelfordWarpReduce over `thread_group_width` lanes.
// NOTE(review): whether every lane ends up with the full result in *m2
// depends on WelfordWarpReduce's shuffle pattern (xor/butterfly vs
// shfl_down), which is not visible here — confirm before relying on
// non-lane-0 values.
// The recovered diff carried the `template` header twice (old and new
// formatting); this is the single reconstructed definition.
template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) {
  WelfordWarpReduce<T, thread_group_width>(thread_m2, m2);
}
|
||||
|
||||
Reference in New Issue
Block a user