mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[RL] RoPE without fmad opt (#6901)
* env FD_ENABLE_RL=1 do fmul_rn(a*b) in rope
This commit is contained in:
@@ -108,6 +108,7 @@ void AppendAttentionKernel(
|
||||
static cudaEvent_t decoder_event;
|
||||
static cudaStream_t decoder_stream;
|
||||
static bool init_flag = false;
|
||||
bool enforce_fmul_rn = getEnvEnableRL();
|
||||
if (max_just_dec_len_this_time > 0 && max_enc_len_this_time > 0 &&
|
||||
!init_flag) {
|
||||
cudaEventCreateWithFlags(&main_event, cudaEventDisableTiming);
|
||||
@@ -185,37 +186,41 @@ void AppendAttentionKernel(
|
||||
|
||||
auto dispatch_EncoderWriteCacheWithRopeKernel =
|
||||
[&](auto temp_args) -> void {
|
||||
EncoderWriteCacheWithRopeKernel<data_t, decltype(temp_args)>(
|
||||
meta_data,
|
||||
qkv,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
kv_signal_data,
|
||||
cache_quant_type_str,
|
||||
kv_num_blocks_data,
|
||||
max_input_length,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
main_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, {
|
||||
EncoderWriteCacheWithRopeKernel<data_t,
|
||||
decltype(temp_args),
|
||||
EnforceFmulRN>(
|
||||
meta_data,
|
||||
qkv,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
kv_signal_data,
|
||||
cache_quant_type_str,
|
||||
kv_num_blocks_data,
|
||||
max_input_length,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
main_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
})
|
||||
};
|
||||
|
||||
if (qkv_out_scales) {
|
||||
@@ -284,117 +289,119 @@ void AppendAttentionKernel(
|
||||
} else {
|
||||
exec_stream = main_stream;
|
||||
}
|
||||
if (speculate_decoder) {
|
||||
if (qkv_out_scales) {
|
||||
SpeculateWriteCacheWithRoPEKernel<data_t, int>(
|
||||
meta_data,
|
||||
qkv, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, {
|
||||
if (speculate_decoder) {
|
||||
if (qkv_out_scales) {
|
||||
SpeculateWriteCacheWithRoPEKernel<data_t, int, EnforceFmulRN>(
|
||||
meta_data,
|
||||
qkv, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
SpeculateWriteCacheWithRoPEKernel<data_t, data_t, EnforceFmulRN>(
|
||||
meta_data,
|
||||
qkv_out, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
}
|
||||
} else {
|
||||
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
|
||||
meta_data,
|
||||
qkv_out, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
if (qkv_out_scales) {
|
||||
DecoderWriteCacheWithRoPEKernel<data_t, int, EnforceFmulRN>(
|
||||
meta_data,
|
||||
qkv, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
DecoderWriteCacheWithRoPEKernel<data_t, data_t, EnforceFmulRN>(
|
||||
meta_data,
|
||||
qkv_out, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
DecoderWriteCacheWithRoPEKernel<data_t, int>(
|
||||
meta_data,
|
||||
qkv, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
|
||||
meta_data,
|
||||
qkv_out, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if (out_linear_in_scale > 0.0) {
|
||||
switch (fmha_out.dtype()) {
|
||||
|
||||
@@ -22,7 +22,11 @@
|
||||
// This function is very easy!
|
||||
// just make HeadDim data to be new HeadDim data!
|
||||
|
||||
template <typename T, int VecSize = 8, int HEAD_DIM = 128, int NUM_THREADS = 32>
|
||||
template <typename T,
|
||||
int VecSize = 8,
|
||||
int HEAD_DIM = 128,
|
||||
int NUM_THREADS = 32,
|
||||
bool EnforceFmulRN = false>
|
||||
__device__ __forceinline__ void apply_rope(const T* input,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
@@ -54,15 +58,17 @@ __device__ __forceinline__ void apply_rope(const T* input,
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &output[head_bias]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -154,8 +160,10 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
tmp_vec[2 * i] = tmp1;
|
||||
tmp_vec[2 * i + 1] = tmp2;
|
||||
@@ -206,7 +214,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void append_decode_cache_T_rope_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -289,9 +297,11 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
out_vec[2 * i] = src_vec[2 * i];
|
||||
out_vec[2 * i + 1] = src_vec[2 * i + 1];
|
||||
@@ -319,7 +329,7 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void append_decode_cache_T_quant_rope_kernel(
|
||||
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -417,9 +427,11 @@ __global__ void append_decode_cache_T_quant_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
bias_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
bias_vec[2 * i] = static_cast<T>(input_left);
|
||||
bias_vec[2 * i + 1] = static_cast<T>(input_right);
|
||||
@@ -447,7 +459,7 @@ __global__ void append_decode_cache_T_quant_rope_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void append_decode_cache_T_neox_partial_rope_kernel(
|
||||
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -551,9 +563,11 @@ __global__ void append_decode_cache_T_neox_partial_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
left_bias_vec[i] = static_cast<T>(input_left);
|
||||
right_bias_vec[i] = static_cast<T>(input_right);
|
||||
@@ -591,7 +605,7 @@ __global__ void append_decode_cache_T_neox_partial_rope_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -676,9 +690,11 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
left_bias_vec[i] = static_cast<T>(input_left);
|
||||
right_bias_vec[i] = static_cast<T>(input_right);
|
||||
@@ -710,7 +726,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void append_decode_cache_T_quant_neox_rope_kernel(
|
||||
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -815,9 +831,11 @@ __global__ void append_decode_cache_T_quant_neox_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
left_bias_vec[i] = static_cast<T>(input_left);
|
||||
right_bias_vec[i] = static_cast<T>(input_right);
|
||||
@@ -855,7 +873,8 @@ template <typename T,
|
||||
int HeadDim = 128,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = true,
|
||||
bool IsDynamic = true>
|
||||
bool IsDynamic = true,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void append_decode_cache_T_int8_neox_rope_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -941,8 +960,10 @@ __global__ void append_decode_cache_T_int8_neox_rope_kernel(
|
||||
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec[i] = static_cast<T>(tmp1);
|
||||
out_vec_right[i] = static_cast<T>(tmp2);
|
||||
@@ -1032,9 +1053,11 @@ __global__ void append_decode_cache_T_int8_neox_rope_kernel(
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
float tmp1 = 0;
|
||||
if (head_bias < half_head_size) {
|
||||
tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
} else {
|
||||
tmp1 = input_left * cos_tmp + input_right * sin_tmp;
|
||||
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
}
|
||||
out_vec1[0] = static_cast<T>(tmp1);
|
||||
input_left = static_cast<float>(src_vec1[1]);
|
||||
@@ -1042,9 +1065,11 @@ __global__ void append_decode_cache_T_int8_neox_rope_kernel(
|
||||
cos_tmp = cos_emb_vec1[1];
|
||||
sin_tmp = sin_emb_vec1[1];
|
||||
if (head_bias < half_head_size) {
|
||||
tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
} else {
|
||||
tmp1 = input_left * cos_tmp + input_right * sin_tmp;
|
||||
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
}
|
||||
out_vec1[1] = static_cast<T>(tmp1);
|
||||
} else {
|
||||
@@ -1060,9 +1085,11 @@ __global__ void append_decode_cache_T_int8_neox_rope_kernel(
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
float tmp1 = 0;
|
||||
if (head_bias < half_head_size) {
|
||||
tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
} else {
|
||||
tmp1 = input_left * cos_tmp + input_right * sin_tmp;
|
||||
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
}
|
||||
out_vec2[0] = static_cast<T>(tmp1);
|
||||
input_left = static_cast<float>(src_vec2[1]);
|
||||
@@ -1070,9 +1097,11 @@ __global__ void append_decode_cache_T_int8_neox_rope_kernel(
|
||||
cos_tmp = cos_emb_vec2[1];
|
||||
sin_tmp = sin_emb_vec2[1];
|
||||
if (head_bias < half_head_size) {
|
||||
tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
} else {
|
||||
tmp1 = input_left * cos_tmp + input_right * sin_tmp;
|
||||
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
}
|
||||
out_vec2[1] = static_cast<T>(tmp1);
|
||||
} else {
|
||||
@@ -1164,7 +1193,8 @@ template <typename T,
|
||||
int HeadDim = 128,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = true,
|
||||
bool IsDynamic = true>
|
||||
bool IsDynamic = true,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -1250,8 +1280,10 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec[2 * i] = static_cast<T>(tmp1);
|
||||
out_vec[2 * i + 1] = static_cast<T>(tmp2);
|
||||
@@ -1344,8 +1376,10 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec1[0] = static_cast<T>(tmp1);
|
||||
out_vec1[1] = static_cast<T>(tmp2);
|
||||
@@ -1360,8 +1394,10 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec2[0] = static_cast<T>(tmp1);
|
||||
out_vec2[1] = static_cast<T>(tmp2);
|
||||
@@ -1472,7 +1508,8 @@ template <typename T,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = false>
|
||||
bool IsFP8 = false,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void append_decode_cache_int8_rope_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -1528,11 +1565,11 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
|
||||
uint32_t emb_offset = write_seq_id * half_head_size;
|
||||
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
|
||||
apply_rope<T, VecSize, HeadDim, 32>(qkv_now,
|
||||
cos_emb + emb_offset,
|
||||
sin_emb + emb_offset,
|
||||
qkv_out_now,
|
||||
lane_id);
|
||||
apply_rope<T, VecSize, HeadDim, 32, EnforceFmulRN>(qkv_now,
|
||||
cos_emb + emb_offset,
|
||||
sin_emb + emb_offset,
|
||||
qkv_out_now,
|
||||
lane_id);
|
||||
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
// k
|
||||
@@ -1619,9 +1656,11 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
out_vec1[0] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
out_vec1[1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
out_vec1[0] = src_vec1[0];
|
||||
out_vec1[1] = src_vec1[1];
|
||||
@@ -1631,10 +1670,12 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
out_vec1[0] =
|
||||
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
|
||||
static_cast<T>((fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp)) *
|
||||
float(cache_k_scale_cur[0]));
|
||||
out_vec1[1] =
|
||||
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
|
||||
static_cast<T>((fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp)) *
|
||||
float(cache_k_scale_cur[1]));
|
||||
} else {
|
||||
out_vec1[0] = static_cast<T>(input_left * float(cache_v_scale_cur[0]));
|
||||
@@ -1649,9 +1690,11 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
out_vec2[0] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
out_vec2[1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
out_vec2[0] = src_vec2[0];
|
||||
out_vec2[1] = src_vec2[1];
|
||||
@@ -1661,10 +1704,12 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
out_vec2[0] =
|
||||
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
|
||||
static_cast<T>((fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp)) *
|
||||
float(cache_k_scale_cur[8]));
|
||||
out_vec2[1] =
|
||||
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
|
||||
static_cast<T>((fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp)) *
|
||||
float(cache_k_scale_cur[9]));
|
||||
} else {
|
||||
out_vec2[0] = static_cast<T>(input_left * float(cache_v_scale_cur[8]));
|
||||
@@ -1715,7 +1760,8 @@ template <typename T,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = false>
|
||||
bool IsFP8 = false,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void int_append_decode_cache_int8_rope_kernel(
|
||||
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -1815,9 +1861,11 @@ __global__ void int_append_decode_cache_int8_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
bias_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(bias_vec, &qkv_out_now[bias_idx]);
|
||||
}
|
||||
@@ -1922,9 +1970,11 @@ __global__ void int_append_decode_cache_int8_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
bias_vec1[0] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec1[1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
bias_vec1[0] = static_cast<T>(input_left);
|
||||
bias_vec1[1] = static_cast<T>(input_right);
|
||||
@@ -1934,10 +1984,12 @@ __global__ void int_append_decode_cache_int8_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
bias_vec1[0] =
|
||||
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
|
||||
static_cast<T>((fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp)) *
|
||||
float(cache_k_scale_cur[0]));
|
||||
bias_vec1[1] =
|
||||
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
|
||||
static_cast<T>((fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp)) *
|
||||
float(cache_k_scale_cur[1]));
|
||||
} else {
|
||||
bias_vec1[0] = static_cast<T>(input_left * float(cache_v_scale_cur[0]));
|
||||
@@ -1959,9 +2011,11 @@ __global__ void int_append_decode_cache_int8_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
bias_vec2[0] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec2[1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
bias_vec2[0] = static_cast<T>(input_left);
|
||||
bias_vec2[1] = static_cast<T>(input_right);
|
||||
@@ -1971,10 +2025,12 @@ __global__ void int_append_decode_cache_int8_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
bias_vec2[0] =
|
||||
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
|
||||
static_cast<T>((fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp)) *
|
||||
float(cache_k_scale_cur[8]));
|
||||
bias_vec2[1] =
|
||||
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
|
||||
static_cast<T>((fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp)) *
|
||||
float(cache_k_scale_cur[9]));
|
||||
} else {
|
||||
bias_vec2[0] = static_cast<T>(input_left * float(cache_v_scale_cur[8]));
|
||||
@@ -2033,7 +2089,11 @@ __global__ void int_append_decode_cache_int8_rope_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
|
||||
template <typename T,
|
||||
int VecSize = 4,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -2120,9 +2180,11 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(left_bias_vec, &qkv_out_now[bias_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &qkv_out_now[bias_idx_right]);
|
||||
@@ -2206,18 +2268,22 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[i];
|
||||
float sin_tmp = sin_emb_vec1[i];
|
||||
left_bias_vec1[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec1[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
|
||||
input_left = static_cast<float>(left_src_vec2[i]);
|
||||
input_right = static_cast<float>(right_src_vec2[i]);
|
||||
cos_tmp = cos_emb_vec2[i];
|
||||
sin_tmp = sin_emb_vec2[i];
|
||||
left_bias_vec2[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec2[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
|
||||
float quant_value1 = static_cast<float>(scale * left_bias_vec1[i]);
|
||||
float quant_value2 = static_cast<float>(scale * left_bias_vec2[i]);
|
||||
@@ -2341,7 +2407,11 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
|
||||
template <typename T,
|
||||
int VecSize = 4,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void int_append_decode_cache_int8_neox_rope_kernel(
|
||||
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -2451,9 +2521,11 @@ __global__ void int_append_decode_cache_int8_neox_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(left_bias_vec, &qkv_out_now[bias_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &qkv_out_now[bias_idx_right]);
|
||||
@@ -2564,9 +2636,11 @@ __global__ void int_append_decode_cache_int8_neox_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[i];
|
||||
float sin_tmp = sin_emb_vec1[i];
|
||||
left_bias_vec1[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec1[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
|
||||
input_left = static_cast<float>(left_src_vec2[i]);
|
||||
input_right = static_cast<float>(right_src_vec2[i]);
|
||||
@@ -2579,9 +2653,11 @@ __global__ void int_append_decode_cache_int8_neox_rope_kernel(
|
||||
cos_tmp = cos_emb_vec2[i];
|
||||
sin_tmp = sin_emb_vec2[i];
|
||||
left_bias_vec2[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec2[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
|
||||
float quant_value1 = static_cast<float>(scale * left_bias_vec1[i]);
|
||||
float quant_value2 = static_cast<float>(scale * left_bias_vec2[i]);
|
||||
@@ -2743,7 +2819,11 @@ __global__ void int_append_decode_cache_int8_neox_rope_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
|
||||
template <typename T,
|
||||
int VecSize = 4,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void append_decode_cache_int4_rope_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -2803,11 +2883,11 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
|
||||
uint32_t emb_offset = write_seq_id * half_head_size;
|
||||
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
|
||||
apply_rope<T, VecSize, HeadDim, 32>(qkv_now,
|
||||
cos_emb + emb_offset,
|
||||
sin_emb + emb_offset,
|
||||
qkv_out_now,
|
||||
lane_id);
|
||||
apply_rope<T, VecSize, HeadDim, 32, EnforceFmulRN>(qkv_now,
|
||||
cos_emb + emb_offset,
|
||||
sin_emb + emb_offset,
|
||||
qkv_out_now,
|
||||
lane_id);
|
||||
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
// k
|
||||
@@ -2890,9 +2970,11 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
out_vec1[0] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
out_vec1[1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
out_vec1[0] = src_vec1[0];
|
||||
out_vec1[1] = src_vec1[1];
|
||||
@@ -2904,9 +2986,11 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
out_vec2[0] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
out_vec2[1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
out_vec2[0] = src_vec2[0];
|
||||
out_vec2[1] = src_vec2[1];
|
||||
@@ -3022,7 +3106,11 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
|
||||
template <typename T,
|
||||
int VecSize = 4,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void int_append_decode_cache_int4_rope_kernel(
|
||||
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -3122,9 +3210,11 @@ __global__ void int_append_decode_cache_int4_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
bias_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(bias_vec, &qkv_out_now[bias_idx]);
|
||||
}
|
||||
@@ -3223,9 +3313,11 @@ __global__ void int_append_decode_cache_int4_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
bias_vec1[0] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec1[1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
bias_vec1[0] = static_cast<T>(input_left);
|
||||
bias_vec1[1] = static_cast<T>(input_right);
|
||||
@@ -3243,9 +3335,11 @@ __global__ void int_append_decode_cache_int4_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
bias_vec2[0] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec2[1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
bias_vec2[0] = static_cast<T>(input_left);
|
||||
bias_vec2[1] = static_cast<T>(input_right);
|
||||
@@ -3360,7 +3454,11 @@ __global__ void int_append_decode_cache_int4_rope_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
|
||||
template <typename T,
|
||||
int VecSize = 4,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -3449,9 +3547,11 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_out_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_out_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(left_out_vec, &qkv_out_now[bias_idx_left]);
|
||||
Store<T, VecSize>(right_out_vec, &qkv_out_now[bias_idx_right]);
|
||||
@@ -3549,18 +3649,22 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
left_out_vec1[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_out_vec1[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
|
||||
input_left = static_cast<float>(left_src_vec2[i]);
|
||||
input_right = static_cast<float>(right_src_vec2[i]);
|
||||
cos_tmp = cos_emb_vec2[i];
|
||||
sin_tmp = sin_emb_vec2[i];
|
||||
left_out_vec2[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_out_vec2[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
// quant + write k
|
||||
}
|
||||
LoadKVResT left_cache_vec, right_cache_vec;
|
||||
@@ -3739,7 +3843,11 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
|
||||
template <typename T,
|
||||
int VecSize = 4,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void int_append_decode_cache_int4_neox_rope_kernel(
|
||||
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -3848,9 +3956,11 @@ __global__ void int_append_decode_cache_int4_neox_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(left_bias_vec, &qkv_out_now[bias_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &qkv_out_now[bias_idx_right]);
|
||||
@@ -3977,18 +4087,22 @@ __global__ void int_append_decode_cache_int4_neox_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
left_bias_vec1[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec1[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
|
||||
input_left = static_cast<float>(left_src_vec2[i]);
|
||||
input_right = static_cast<float>(right_src_vec2[i]);
|
||||
cos_tmp = cos_emb_vec2[i];
|
||||
sin_tmp = sin_emb_vec2[i];
|
||||
left_bias_vec2[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec2[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
// quant + write k
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#include "decoder_write_cache_with_rope_kernel.h"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
|
||||
void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
T* key_cache,
|
||||
T* value_cache,
|
||||
@@ -53,7 +53,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
|
||||
launchWithPdlWhenEnabled(
|
||||
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>,
|
||||
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize, EnforceFmulRN>,
|
||||
grid_size,
|
||||
block_dim,
|
||||
0,
|
||||
@@ -81,7 +81,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
rms_norm_eps);
|
||||
}
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
|
||||
void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
T* key_cache,
|
||||
T* value_cache,
|
||||
@@ -117,7 +117,9 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
if (use_neox_style) {
|
||||
if (qkv_out_scales) {
|
||||
launchWithPdlWhenEnabled(
|
||||
append_decode_cache_T_quant_neox_rope_kernel<T, PackSize>,
|
||||
append_decode_cache_T_quant_neox_rope_kernel<T,
|
||||
PackSize,
|
||||
EnforceFmulRN>,
|
||||
grid_size,
|
||||
blocksize,
|
||||
0,
|
||||
@@ -145,7 +147,9 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
} else {
|
||||
if (rotary_dim < dim_head) {
|
||||
auto* kernelFn =
|
||||
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>;
|
||||
append_decode_cache_T_neox_partial_rope_kernel<T,
|
||||
PackSize,
|
||||
EnforceFmulRN>;
|
||||
launchWithPdlWhenEnabled(kernelFn,
|
||||
grid_size,
|
||||
blocksize,
|
||||
@@ -171,7 +175,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
auto* kernelFn = append_decode_cache_T_neox_rope_kernel<T, PackSize>;
|
||||
auto* kernelFn =
|
||||
append_decode_cache_T_neox_rope_kernel<T, PackSize, EnforceFmulRN>;
|
||||
launchWithPdlWhenEnabled(kernelFn,
|
||||
grid_size,
|
||||
blocksize,
|
||||
@@ -200,7 +205,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
launchWithPdlWhenEnabled(
|
||||
append_decode_cache_T_quant_rope_kernel<T, PackSize>,
|
||||
append_decode_cache_T_quant_rope_kernel<T, PackSize, EnforceFmulRN>,
|
||||
grid_size,
|
||||
blocksize,
|
||||
0,
|
||||
@@ -226,7 +231,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
auto* kernelFn = append_decode_cache_T_rope_kernel<T, PackSize>;
|
||||
auto* kernelFn =
|
||||
append_decode_cache_T_rope_kernel<T, PackSize, EnforceFmulRN>;
|
||||
launchWithPdlWhenEnabled(kernelFn,
|
||||
grid_size,
|
||||
blocksize,
|
||||
@@ -257,7 +263,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
template <typename T,
|
||||
typename QKV_TYPE,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = false>
|
||||
bool IsFP8 = false,
|
||||
bool EnforceFmulRN = false>
|
||||
void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* key_cache,
|
||||
uint8_t* value_cache,
|
||||
@@ -289,7 +296,11 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
if (use_neox_style) {
|
||||
if (qkv_out_scales) {
|
||||
launchWithPdlWhenEnabled(
|
||||
int_append_decode_cache_int8_neox_rope_kernel<T, 4>,
|
||||
int_append_decode_cache_int8_neox_rope_kernel<T,
|
||||
4,
|
||||
0,
|
||||
128,
|
||||
EnforceFmulRN>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
@@ -317,31 +328,36 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
launchWithPdlWhenEnabled(append_decode_cache_int8_neox_rope_kernel<T, 4>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
cache_k_scale,
|
||||
cache_v_scale,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
launchWithPdlWhenEnabled(
|
||||
append_decode_cache_int8_neox_rope_kernel<T,
|
||||
4,
|
||||
0,
|
||||
128,
|
||||
EnforceFmulRN>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
cache_k_scale,
|
||||
cache_v_scale,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
@@ -351,7 +367,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
0,
|
||||
128,
|
||||
is_scale_channel_wise,
|
||||
IsFP8>,
|
||||
IsFP8,
|
||||
EnforceFmulRN>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
@@ -385,7 +402,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
0,
|
||||
128,
|
||||
is_scale_channel_wise,
|
||||
IsFP8>,
|
||||
IsFP8,
|
||||
EnforceFmulRN>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
@@ -414,7 +432,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
|
||||
void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* key_cache,
|
||||
uint8_t* value_cache,
|
||||
@@ -448,7 +466,11 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
if (use_neox_style) {
|
||||
if (qkv_out_scales) {
|
||||
launchWithPdlWhenEnabled(
|
||||
int_append_decode_cache_int4_neox_rope_kernel<T, 4>,
|
||||
int_append_decode_cache_int4_neox_rope_kernel<T,
|
||||
4,
|
||||
0,
|
||||
128,
|
||||
EnforceFmulRN>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
@@ -478,97 +500,104 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
launchWithPdlWhenEnabled(append_decode_cache_int4_neox_rope_kernel<T, 4>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
cache_k_scale,
|
||||
cache_v_scale,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
launchWithPdlWhenEnabled(
|
||||
append_decode_cache_int4_neox_rope_kernel<T,
|
||||
4,
|
||||
0,
|
||||
128,
|
||||
EnforceFmulRN>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
cache_k_scale,
|
||||
cache_v_scale,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
launchWithPdlWhenEnabled(int_append_decode_cache_int4_rope_kernel<T, 4>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const int*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales,
|
||||
qkv_biases,
|
||||
cache_k_scale,
|
||||
cache_v_scale,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
launchWithPdlWhenEnabled(
|
||||
int_append_decode_cache_int4_rope_kernel<T, 4, 0, 128, EnforceFmulRN>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const int*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales,
|
||||
qkv_biases,
|
||||
cache_k_scale,
|
||||
cache_v_scale,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
launchWithPdlWhenEnabled(append_decode_cache_int4_rope_kernel<T, 4>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
cache_k_scale,
|
||||
cache_v_scale,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
launchWithPdlWhenEnabled(
|
||||
append_decode_cache_int4_rope_kernel<T, 4, 0, 128, EnforceFmulRN>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
cache_k_scale,
|
||||
cache_v_scale,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename T, typename QKV_TYPE>
|
||||
template <typename T, typename QKV_TYPE, bool EnforceFmulRN>
|
||||
void DecoderWriteCacheWithRoPEKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& qkv,
|
||||
@@ -632,7 +661,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_decode_cache_rope_qk_norm(
|
||||
append_decode_cache_rope_qk_norm<DataType_, QKV_TYPE, EnforceFmulRN>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
@@ -672,7 +701,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
128,
|
||||
false,
|
||||
true,
|
||||
true>,
|
||||
true,
|
||||
EnforceFmulRN>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
@@ -714,7 +744,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
128,
|
||||
false,
|
||||
true,
|
||||
false>,
|
||||
false,
|
||||
EnforceFmulRN>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
@@ -751,7 +782,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
}
|
||||
} else {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_decode_cache_rope(
|
||||
append_decode_cache_rope<DataType_, QKV_TYPE, EnforceFmulRN>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
@@ -784,7 +815,11 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
is_scale_channel_wise = true;
|
||||
}
|
||||
if (is_scale_channel_wise) {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
append_decode_cache_int8_rope<DataType_,
|
||||
QKV_TYPE,
|
||||
true,
|
||||
false,
|
||||
EnforceFmulRN>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
@@ -816,7 +851,11 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
|
||||
append_decode_cache_int8_rope<DataType_,
|
||||
QKV_TYPE,
|
||||
false,
|
||||
false,
|
||||
EnforceFmulRN>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
@@ -849,7 +888,11 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
rope_3d);
|
||||
}
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
|
||||
append_decode_cache_int8_rope<DataType_,
|
||||
QKV_TYPE,
|
||||
false,
|
||||
true,
|
||||
EnforceFmulRN>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
@@ -892,7 +935,9 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
0,
|
||||
128,
|
||||
false,
|
||||
true>,
|
||||
true,
|
||||
true,
|
||||
EnforceFmulRN>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
@@ -927,7 +972,9 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
0,
|
||||
128,
|
||||
false,
|
||||
true>,
|
||||
true,
|
||||
true,
|
||||
EnforceFmulRN>,
|
||||
grids,
|
||||
num_warps * 32,
|
||||
0,
|
||||
@@ -959,7 +1006,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
rms_norm_eps);
|
||||
}
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_decode_cache_int4_rope(
|
||||
append_decode_cache_int4_rope<DataType_, QKV_TYPE, EnforceFmulRN>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
#include "decoder_write_cache_with_rope_impl.cuh"
|
||||
|
||||
template <typename T, typename QKV_TYPE = int>
|
||||
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
|
||||
void DecoderWriteCacheWithRoPEKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor&
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
#include "mma_tensor_op.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void IntVariableLengthRotaryKernel(
|
||||
const int *qkv,
|
||||
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
|
||||
@@ -97,9 +97,11 @@ __global__ void IntVariableLengthRotaryKernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
bias_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
bias_vec[2 * i] = static_cast<T>(input_left);
|
||||
bias_vec[2 * i + 1] = static_cast<T>(input_right);
|
||||
@@ -112,7 +114,7 @@ __global__ void IntVariableLengthRotaryKernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void VariableLengthRotaryKernel(
|
||||
const T *qkv,
|
||||
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
|
||||
@@ -171,9 +173,11 @@ __global__ void VariableLengthRotaryKernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
src_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
src_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
|
||||
}
|
||||
@@ -182,7 +186,7 @@ __global__ void VariableLengthRotaryKernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void IntNeoxVariableLengthRotaryKernel(
|
||||
const int *qkv,
|
||||
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
|
||||
@@ -271,9 +275,11 @@ __global__ void IntNeoxVariableLengthRotaryKernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
left_bias_vec[i] = static_cast<T>(input_left);
|
||||
right_bias_vec[i] = static_cast<T>(input_right);
|
||||
@@ -287,7 +293,7 @@ __global__ void IntNeoxVariableLengthRotaryKernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void NeoxVariableLengthRotaryKernel(
|
||||
const T *qkv,
|
||||
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
|
||||
@@ -352,9 +358,11 @@ __global__ void NeoxVariableLengthRotaryKernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
|
||||
Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
|
||||
@@ -364,7 +372,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void IntGQAVariableLengthRotaryKernel(
|
||||
const int *qkv,
|
||||
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
|
||||
@@ -442,9 +450,11 @@ __global__ void IntGQAVariableLengthRotaryKernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
bias_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
bias_vec[2 * i] = static_cast<T>(input_left);
|
||||
bias_vec[2 * i + 1] = static_cast<T>(input_right);
|
||||
@@ -457,7 +467,7 @@ __global__ void IntGQAVariableLengthRotaryKernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void GQAVariableLengthRotaryQKNormKernel(
|
||||
const T *qkv,
|
||||
const float *cos_emb,
|
||||
@@ -525,8 +535,10 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
|
||||
const float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
|
||||
tmp_vec[2 * i] = tmp1;
|
||||
tmp_vec[2 * i + 1] = tmp2;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
@@ -554,7 +566,7 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void GQAVariableLengthRotaryKernel(const T *qkv,
|
||||
const float *cos_emb,
|
||||
const float *sin_emb,
|
||||
@@ -614,9 +626,11 @@ __global__ void GQAVariableLengthRotaryKernel(const T *qkv,
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
src_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
src_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
|
||||
}
|
||||
@@ -625,7 +639,7 @@ __global__ void GQAVariableLengthRotaryKernel(const T *qkv,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void IntGQAVariableLengthRotaryQuantKVKernel(
|
||||
const int *qkv,
|
||||
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
|
||||
@@ -703,19 +717,23 @@ __global__ void IntGQAVariableLengthRotaryQuantKVKernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
bias_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else if (hi < q_num_head + kv_num_head) {
|
||||
int k_hi = hi - q_num_head;
|
||||
const int scale_idx = k_hi * last_dim + h_bias;
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
bias_vec[2 * i] =
|
||||
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
|
||||
static_cast<T>((fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp)) *
|
||||
float(cache_k_scales[scale_idx + 2 * i]));
|
||||
bias_vec[2 * i + 1] =
|
||||
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
|
||||
static_cast<T>((fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp)) *
|
||||
float(cache_k_scales[scale_idx + 2 * i + 1]));
|
||||
} else {
|
||||
int v_hi = hi - q_num_head - kv_num_head;
|
||||
@@ -733,7 +751,7 @@ __global__ void IntGQAVariableLengthRotaryQuantKVKernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void GQAVariableLengthRotaryQuantKVKernel(
|
||||
const T *qkv,
|
||||
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
|
||||
@@ -803,26 +821,31 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(
|
||||
: static_cast<float>(src_vec[2 * i + 1]);
|
||||
// const float cos_tmp = cos_emb_vec[i];
|
||||
// const float sin_tmp = sin_emb_vec[i];
|
||||
// src_vec[2 * i] = static_cast<T>(input_left * cos_tmp - input_right *
|
||||
// sin_tmp); src_vec[2 * i + 1] = static_cast<T>(input_right * cos_tmp +
|
||||
// input_left * sin_tmp);
|
||||
// src_vec[2 * i] = static_cast<T>(fmul_func<EnforceFmulRN>(input_left,
|
||||
// cos_tmp) - input_right * sin_tmp); src_vec[2 * i + 1] =
|
||||
// static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
// fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
if (hi < q_num_head) { // qk rope
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
src_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
src_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else if (hi < q_num_head + kv_num_head) {
|
||||
int k_hi = hi - q_num_head;
|
||||
const int scale_idx = k_hi * last_dim + h_bias;
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
src_vec[2 * i] =
|
||||
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
|
||||
static_cast<T>((fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp)) *
|
||||
float(cache_k_scales[scale_idx + 2 * i]));
|
||||
src_vec[2 * i + 1] =
|
||||
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
|
||||
static_cast<T>((fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp)) *
|
||||
float(cache_k_scales[scale_idx + 2 * i + 1]));
|
||||
} else {
|
||||
int v_hi = hi - q_num_head - kv_num_head;
|
||||
@@ -840,7 +863,7 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void IntGQANeoxVariableLengthRotaryKernel(
|
||||
const int *qkv,
|
||||
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
|
||||
@@ -926,9 +949,11 @@ __global__ void IntGQANeoxVariableLengthRotaryKernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
left_bias_vec[i] = static_cast<T>(input_left);
|
||||
right_bias_vec[i] = static_cast<T>(input_right);
|
||||
@@ -942,7 +967,7 @@ __global__ void IntGQANeoxVariableLengthRotaryKernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void GQANeoxVariableLengthRotaryKernel(const T *qkv,
|
||||
const float *cos_emb,
|
||||
const float *sin_emb,
|
||||
@@ -1005,9 +1030,11 @@ __global__ void GQANeoxVariableLengthRotaryKernel(const T *qkv,
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
|
||||
Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
|
||||
@@ -1017,7 +1044,7 @@ __global__ void GQANeoxVariableLengthRotaryKernel(const T *qkv,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void GQANeoxVariableLengthPartialRotaryKernel(
|
||||
const T *qkv,
|
||||
const float *cos_emb,
|
||||
@@ -1082,9 +1109,11 @@ __global__ void GQANeoxVariableLengthPartialRotaryKernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
|
||||
Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
|
||||
@@ -1094,7 +1123,7 @@ __global__ void GQANeoxVariableLengthPartialRotaryKernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void cache_kernel(
|
||||
const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -2199,7 +2228,7 @@ __global__ void append_write_cache_kv_c4_qkv(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
|
||||
void rotary_qk_variable(
|
||||
T *qkv_out, // [token_num, 3, num_head, dim_head]
|
||||
const QKV_TYPE *qkv_input, // qkv
|
||||
@@ -2233,94 +2262,98 @@ void rotary_qk_variable(
|
||||
const float *cos_emb = rotary_emb;
|
||||
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
|
||||
if (qkv_out_scales) {
|
||||
launchWithPdlWhenEnabled(IntVariableLengthRotaryKernel<T, PackSize>,
|
||||
grid_size,
|
||||
blocksize,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const int *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
head_num,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d);
|
||||
launchWithPdlWhenEnabled(
|
||||
IntVariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>,
|
||||
grid_size,
|
||||
blocksize,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const int *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
head_num,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d);
|
||||
} else {
|
||||
launchWithPdlWhenEnabled(VariableLengthRotaryKernel<T, PackSize>,
|
||||
grid_size,
|
||||
blocksize,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
head_num,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d);
|
||||
launchWithPdlWhenEnabled(
|
||||
VariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>,
|
||||
grid_size,
|
||||
blocksize,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
head_num,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d);
|
||||
}
|
||||
} else {
|
||||
const float *cos_emb = rotary_emb;
|
||||
const float *sin_emb = rotary_emb + input_output_len * dim_head;
|
||||
if (qkv_out_scales) {
|
||||
launchWithPdlWhenEnabled(IntNeoxVariableLengthRotaryKernel<T, PackSize>,
|
||||
grid_size,
|
||||
blocksize,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const int *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
head_num,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d);
|
||||
launchWithPdlWhenEnabled(
|
||||
IntNeoxVariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>,
|
||||
grid_size,
|
||||
blocksize,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const int *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
head_num,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d);
|
||||
} else {
|
||||
launchWithPdlWhenEnabled(NeoxVariableLengthRotaryKernel<T, PackSize>,
|
||||
grid_size,
|
||||
blocksize,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
head_num,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d);
|
||||
launchWithPdlWhenEnabled(
|
||||
NeoxVariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>,
|
||||
grid_size,
|
||||
blocksize,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
head_num,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
|
||||
void gqa_rotary_qk_norm_variable(
|
||||
T *qkv_out, // [token_num, 3, num_head, dim_head]
|
||||
const QKV_TYPE *qkv_input, // qkv
|
||||
@@ -2362,31 +2395,32 @@ void gqa_rotary_qk_norm_variable(
|
||||
|
||||
const float *cos_emb = rotary_emb;
|
||||
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
|
||||
launchWithPdlWhenEnabled(GQAVariableLengthRotaryQKNormKernel<T, PackSize>,
|
||||
grid_size,
|
||||
Block_Size,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
launchWithPdlWhenEnabled(
|
||||
GQAVariableLengthRotaryQKNormKernel<T, PackSize, EnforceFmulRN>,
|
||||
grid_size,
|
||||
Block_Size,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
}
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
|
||||
void gqa_rotary_qk_variable(
|
||||
T *qkv_out, // [token_num, 3, num_head, dim_head]
|
||||
const QKV_TYPE *qkv_input, // qkv
|
||||
@@ -2425,29 +2459,31 @@ void gqa_rotary_qk_variable(
|
||||
const float *cos_emb = rotary_emb;
|
||||
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
|
||||
if (qkv_out_scales) {
|
||||
launchWithPdlWhenEnabled(IntGQAVariableLengthRotaryKernel<T, PackSize>,
|
||||
grid_size,
|
||||
blocksize,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const int *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d);
|
||||
launchWithPdlWhenEnabled(
|
||||
IntGQAVariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>,
|
||||
grid_size,
|
||||
blocksize,
|
||||
0,
|
||||
stream,
|
||||
reinterpret_cast<const int *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d);
|
||||
} else {
|
||||
auto *kernelFn = GQAVariableLengthRotaryKernel<T, PackSize>;
|
||||
auto *kernelFn =
|
||||
GQAVariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>;
|
||||
launchWithPdlWhenEnabled(kernelFn,
|
||||
grid_size,
|
||||
blocksize,
|
||||
@@ -2473,7 +2509,7 @@ void gqa_rotary_qk_variable(
|
||||
const float *sin_emb = rotary_emb + input_output_len * dim_head;
|
||||
if (qkv_out_scales) {
|
||||
launchWithPdlWhenEnabled(
|
||||
IntGQANeoxVariableLengthRotaryKernel<T, PackSize>,
|
||||
IntGQANeoxVariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>,
|
||||
grid_size,
|
||||
blocksize,
|
||||
0,
|
||||
@@ -2507,7 +2543,10 @@ void gqa_rotary_qk_variable(
|
||||
}
|
||||
const int pack_num_new = elem_nums / PackSize;
|
||||
GetNumBlocks<128>(pack_num_new, &grid_size);
|
||||
auto *kernelFn = GQANeoxVariableLengthPartialRotaryKernel<T, PackSize>;
|
||||
auto *kernelFn =
|
||||
GQANeoxVariableLengthPartialRotaryKernel<T,
|
||||
PackSize,
|
||||
EnforceFmulRN>;
|
||||
launchWithPdlWhenEnabled(kernelFn,
|
||||
grid_size,
|
||||
blocksize,
|
||||
@@ -2531,7 +2570,8 @@ void gqa_rotary_qk_variable(
|
||||
rotary_dim,
|
||||
rope_3d);
|
||||
} else {
|
||||
auto *kernelFn = GQANeoxVariableLengthRotaryKernel<T, PackSize>;
|
||||
auto *kernelFn =
|
||||
GQANeoxVariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>;
|
||||
launchWithPdlWhenEnabled(kernelFn,
|
||||
grid_size,
|
||||
blocksize,
|
||||
@@ -2558,7 +2598,7 @@ void gqa_rotary_qk_variable(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
|
||||
void gqa_rotary_qk_quant_variable(
|
||||
T *qkv_out, // [token_num, 3, num_head, dim_head]
|
||||
const QKV_TYPE *qkv_input, // qkv
|
||||
@@ -2595,7 +2635,7 @@ void gqa_rotary_qk_quant_variable(
|
||||
if (!use_neox_style) {
|
||||
if (qkv_out_scales) {
|
||||
launchWithPdlWhenEnabled(
|
||||
IntGQAVariableLengthRotaryQuantKVKernel<T, PackSize>,
|
||||
IntGQAVariableLengthRotaryQuantKVKernel<T, PackSize, EnforceFmulRN>,
|
||||
grid_size,
|
||||
blocksize,
|
||||
0,
|
||||
@@ -2620,7 +2660,7 @@ void gqa_rotary_qk_quant_variable(
|
||||
rope_3d);
|
||||
} else {
|
||||
launchWithPdlWhenEnabled(
|
||||
GQAVariableLengthRotaryQuantKVKernel<T, PackSize>,
|
||||
GQAVariableLengthRotaryQuantKVKernel<T, PackSize, EnforceFmulRN>,
|
||||
grid_size,
|
||||
blocksize,
|
||||
0,
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
#include "encoder_write_cache_with_rope_impl.cuh"
|
||||
#include "remote_cache_kv_ipc.h"
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
|
||||
void EncoderWriteCacheWithRopeKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor&
|
||||
@@ -77,7 +77,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (num_heads != kv_num_heads && !is_scale_channel_wise &&
|
||||
!use_neox_style) {
|
||||
gqa_rotary_qk_norm_variable(
|
||||
gqa_rotary_qk_norm_variable<T, QKV_TYPE, EnforceFmulRN>(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
@@ -106,7 +106,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
}
|
||||
} else {
|
||||
if (num_heads == kv_num_heads) {
|
||||
rotary_qk_variable(
|
||||
rotary_qk_variable<T, QKV_TYPE, EnforceFmulRN>(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
@@ -126,7 +126,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
rope_3d);
|
||||
} else {
|
||||
if (!is_scale_channel_wise) {
|
||||
gqa_rotary_qk_variable(
|
||||
gqa_rotary_qk_variable<T, QKV_TYPE, EnforceFmulRN>(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
@@ -147,7 +147,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
gqa_rotary_qk_quant_variable(
|
||||
gqa_rotary_qk_quant_variable<T, QKV_TYPE, EnforceFmulRN>(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
#include "qwen3_rope.h"
|
||||
#include "remote_cache_kv_ipc.h"
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void GQAVariableLengthRotarySplitKernel(
|
||||
const T *qkv,
|
||||
const float *cos_emb,
|
||||
@@ -117,8 +117,10 @@ __global__ void GQAVariableLengthRotarySplitKernel(
|
||||
const float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
|
||||
tmp_vec[2 * i] = tmp1;
|
||||
tmp_vec[2 * i + 1] = tmp2;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
@@ -157,9 +159,11 @@ __global__ void GQAVariableLengthRotarySplitKernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
src_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
src_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -168,7 +172,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool EnforceFmulRN = false>
|
||||
void gqa_rotary_qk_split_variable(
|
||||
T *qkv_out, // [token_num, 3, num_head, head_dim]
|
||||
T *q,
|
||||
@@ -205,35 +209,36 @@ void gqa_rotary_qk_split_variable(
|
||||
|
||||
const float *cos_emb = rotary_emb;
|
||||
const float *sin_emb = rotary_emb + input_output_len * head_dim / 2;
|
||||
launchWithPdlWhenEnabled(GQAVariableLengthRotarySplitKernel<T, PackSize>,
|
||||
grid_size,
|
||||
block_size,
|
||||
0,
|
||||
stream,
|
||||
qkv_input,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
cu_seqlens_k,
|
||||
qkv_out,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
elem_nums,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_model_len,
|
||||
head_dim,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
launchWithPdlWhenEnabled(
|
||||
GQAVariableLengthRotarySplitKernel<T, PackSize, EnforceFmulRN>,
|
||||
grid_size,
|
||||
block_size,
|
||||
0,
|
||||
stream,
|
||||
qkv_input,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
cu_seqlens_k,
|
||||
qkv_out,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
elem_nums,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_model_len,
|
||||
head_dim,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void GQAVariableLengthNeoxPartialRotarySplitKernel(
|
||||
const T *qkv,
|
||||
const float *cos_emb,
|
||||
@@ -329,10 +334,12 @@ __global__ void GQAVariableLengthNeoxPartialRotarySplitKernel(
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
if (h_bias < half_rotary_dim) {
|
||||
src_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
} else {
|
||||
src_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp + input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -346,7 +353,7 @@ __global__ void GQAVariableLengthNeoxPartialRotarySplitKernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool EnforceFmulRN = false>
|
||||
void gqa_neox_partial_rotary_qk_split_variable(
|
||||
T *qkv_out, // [token_num, 3, num_head, head_dim]
|
||||
T *q,
|
||||
@@ -381,7 +388,7 @@ void gqa_neox_partial_rotary_qk_split_variable(
|
||||
const float *cos_emb = rotary_emb;
|
||||
const float *sin_emb = rotary_emb + max_model_len * rotary_dim / 2;
|
||||
launchWithPdlWhenEnabled(
|
||||
GQAVariableLengthNeoxPartialRotarySplitKernel<T, PackSize>,
|
||||
GQAVariableLengthNeoxPartialRotarySplitKernel<T, PackSize, EnforceFmulRN>,
|
||||
grid_size,
|
||||
block_size,
|
||||
0,
|
||||
@@ -1376,35 +1383,60 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
paddle::Tensor v = GetEmptyTensor(
|
||||
{kv_token_num, kv_num_heads, head_dim}, qkv.dtype(), qkv.place());
|
||||
|
||||
if (use_neox_rotary_style) {
|
||||
if (rotary_dim == head_dim) {
|
||||
gqa_rotary_qk_split_variable_qwen3<data_t>(
|
||||
qkv_out.data<data_t>(),
|
||||
q.data<data_t>(),
|
||||
k.data<data_t>(),
|
||||
v.data<data_t>(),
|
||||
qkv.data<data_t>(),
|
||||
rotary_embs.data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
|
||||
head_dim,
|
||||
rope_3d,
|
||||
stream);
|
||||
bool enforce_fmul_rn = getEnvEnableRL();
|
||||
DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, {
|
||||
if (use_neox_rotary_style) {
|
||||
if (rotary_dim == head_dim) {
|
||||
gqa_rotary_qk_split_variable_qwen3<data_t, EnforceFmulRN>(
|
||||
qkv_out.data<data_t>(),
|
||||
q.data<data_t>(),
|
||||
k.data<data_t>(),
|
||||
v.data<data_t>(),
|
||||
qkv.data<data_t>(),
|
||||
rotary_embs.data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
|
||||
head_dim,
|
||||
rope_3d,
|
||||
stream);
|
||||
} else {
|
||||
gqa_neox_partial_rotary_qk_split_variable<data_t, EnforceFmulRN>(
|
||||
qkv_out.data<data_t>(),
|
||||
q.data<data_t>(),
|
||||
k.data<data_t>(),
|
||||
v.data<data_t>(),
|
||||
qkv.data<data_t>(),
|
||||
rotary_embs.data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_seq_len,
|
||||
head_dim,
|
||||
rotary_dim,
|
||||
stream);
|
||||
}
|
||||
} else {
|
||||
gqa_neox_partial_rotary_qk_split_variable<data_t>(
|
||||
gqa_rotary_qk_split_variable<data_t, EnforceFmulRN>(
|
||||
qkv_out.data<data_t>(),
|
||||
q.data<data_t>(),
|
||||
k.data<data_t>(),
|
||||
v.data<data_t>(),
|
||||
qkv.data<data_t>(),
|
||||
rotary_embs.data<float>(),
|
||||
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
|
||||
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
|
||||
batch_id_per_token.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
@@ -1414,35 +1446,13 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_seq_len,
|
||||
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
|
||||
head_dim,
|
||||
rotary_dim,
|
||||
rope_3d,
|
||||
rms_norm_eps,
|
||||
stream);
|
||||
}
|
||||
} else {
|
||||
gqa_rotary_qk_split_variable<data_t>(
|
||||
qkv_out.data<data_t>(),
|
||||
q.data<data_t>(),
|
||||
k.data<data_t>(),
|
||||
v.data<data_t>(),
|
||||
qkv.data<data_t>(),
|
||||
rotary_embs.data<float>(),
|
||||
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
|
||||
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
|
||||
batch_id_per_token.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_seq_len,
|
||||
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
|
||||
head_dim,
|
||||
rope_3d,
|
||||
rms_norm_eps,
|
||||
stream);
|
||||
}
|
||||
})
|
||||
|
||||
if (token_num < kv_token_num) {
|
||||
AppendCacheKV<data_t, 128, 64>(key_cache,
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "paddle/phi/core/memory/memcpy.h"
|
||||
#include "remote_cache_kv_ipc.h"
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
|
||||
__global__ void GQAVariableLengthRotarySplitKernel_Qwen3(
|
||||
const T *qkv,
|
||||
const float *cos_emb,
|
||||
@@ -99,9 +99,11 @@ __global__ void GQAVariableLengthRotarySplitKernel_Qwen3(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
src_vec0[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
src_vec1[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
}
|
||||
Store<T, VecSize>(src_vec0, &qkv_out[read_idx]);
|
||||
@@ -111,7 +113,7 @@ __global__ void GQAVariableLengthRotarySplitKernel_Qwen3(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool EnforceFmulRN = false>
|
||||
void gqa_rotary_qk_split_variable_qwen3(T *qkv_out,
|
||||
T *q,
|
||||
T *k,
|
||||
@@ -145,7 +147,7 @@ void gqa_rotary_qk_split_variable_qwen3(T *qkv_out,
|
||||
const float *cos_emb = rotary_emb;
|
||||
const float *sin_emb = rotary_emb + max_model_len * head_dim;
|
||||
launchWithPdlWhenEnabled(
|
||||
GQAVariableLengthRotarySplitKernel_Qwen3<T, PackSize>,
|
||||
GQAVariableLengthRotarySplitKernel_Qwen3<T, PackSize, EnforceFmulRN>,
|
||||
grid_size,
|
||||
block_size,
|
||||
0,
|
||||
|
||||
@@ -18,7 +18,10 @@
|
||||
#include "mma_tensor_op.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, int VecSize = 1, typename InT = T>
|
||||
template <typename T,
|
||||
int VecSize = 1,
|
||||
typename InT = T,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void append_speculate_cache_T_rope_qk_norm_kernel(
|
||||
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
|
||||
// head_size]
|
||||
@@ -129,8 +132,10 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
tmp_vec[2 * i] = tmp1;
|
||||
tmp_vec[2 * i + 1] = tmp2;
|
||||
@@ -331,7 +336,10 @@ __global__ void append_clear_cache_int4_block(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1, typename InT = T>
|
||||
template <typename T,
|
||||
int VecSize = 1,
|
||||
typename InT = T,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void append_speculate_cache_rope_kernel(
|
||||
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
|
||||
// head_size]
|
||||
@@ -434,9 +442,11 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
bias_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
bias_vec[2 * i] = static_cast<T>(input_left);
|
||||
bias_vec[2 * i + 1] = static_cast<T>(input_right);
|
||||
@@ -462,7 +472,10 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1, typename InT = T>
|
||||
template <typename T,
|
||||
int VecSize = 1,
|
||||
typename InT = T,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void append_speculate_cache_neox_rope_kernel(
|
||||
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
|
||||
// head_size]
|
||||
@@ -568,9 +581,11 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
left_bias_vec[i] = static_cast<T>(input_left);
|
||||
right_bias_vec[i] = static_cast<T>(input_right);
|
||||
@@ -601,7 +616,10 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1, typename InT = T>
|
||||
template <typename T,
|
||||
int VecSize = 1,
|
||||
typename InT = T,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void append_speculate_cache_neox_partial_rope_kernel(
|
||||
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
|
||||
// head_size]
|
||||
@@ -724,9 +742,11 @@ __global__ void append_speculate_cache_neox_partial_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
left_bias_vec[i] = static_cast<T>(input_left);
|
||||
right_bias_vec[i] = static_cast<T>(input_right);
|
||||
@@ -766,7 +786,8 @@ template <typename T,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamic = true>
|
||||
bool IsDynamic = true,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
|
||||
const T* __restrict__ quant_qkv, // [num_head, num_heads + 2 *
|
||||
// gqa_group_size, head_size]
|
||||
@@ -864,8 +885,10 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
bias_vec[2 * i] = static_cast<T>(tmp1);
|
||||
bias_vec[2 * i + 1] = static_cast<T>(tmp2);
|
||||
@@ -929,8 +952,10 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
bias_vec1[0] = static_cast<T>(tmp1);
|
||||
bias_vec1[1] = static_cast<T>(tmp2);
|
||||
@@ -944,8 +969,10 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
|
||||
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
bias_vec2[0] = static_cast<T>(tmp1);
|
||||
bias_vec2[1] = static_cast<T>(tmp2);
|
||||
@@ -1045,7 +1072,8 @@ template <typename T,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
typename InT = int,
|
||||
bool IsFP8 = false>
|
||||
bool IsFP8 = false,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void append_speculate_cache_int8_rope_kernel(
|
||||
const InT* __restrict__ quant_qkv, // [num_head, num_heads + 2 *
|
||||
// gqa_group_size, head_size]
|
||||
@@ -1149,9 +1177,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
bias_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(bias_vec, &qkv_out_now[bias_idx]);
|
||||
}
|
||||
@@ -1215,9 +1245,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
bias_vec1[0] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec1[1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
bias_vec1[0] = static_cast<T>(input_left);
|
||||
bias_vec1[1] = static_cast<T>(input_right);
|
||||
@@ -1238,9 +1270,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
bias_vec2[0] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec2[1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
bias_vec2[0] = static_cast<T>(input_left);
|
||||
bias_vec2[1] = static_cast<T>(input_right);
|
||||
@@ -1285,7 +1319,8 @@ template <typename T,
|
||||
int VecSize = 4,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
typename InT = int>
|
||||
typename InT = int,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
const InT* __restrict__ quant_qkv, // [num_head, num_heads + 2 *
|
||||
// gqa_group_size, head_size]
|
||||
@@ -1396,9 +1431,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(left_bias_vec, &qkv_out_now[bias_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &qkv_out_now[bias_idx_right]);
|
||||
@@ -1485,9 +1522,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[i];
|
||||
float sin_tmp = sin_emb_vec1[i];
|
||||
left_bias_vec1[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec1[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
|
||||
input_left = static_cast<float>(left_src_vec2[i]);
|
||||
input_right = static_cast<float>(right_src_vec2[i]);
|
||||
@@ -1500,9 +1539,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
cos_tmp = cos_emb_vec2[i];
|
||||
sin_tmp = sin_emb_vec2[i];
|
||||
left_bias_vec2[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec2[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
|
||||
float quant_value1 = static_cast<float>(scale * left_bias_vec1[i]);
|
||||
float quant_value2 = static_cast<float>(scale * left_bias_vec2[i]);
|
||||
@@ -1669,7 +1710,8 @@ template <typename T,
|
||||
int VecSize = 4,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
typename InT = int>
|
||||
typename InT = int,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void append_speculate_cache_int4_rope_kernel(
|
||||
const InT* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size,
|
||||
// head_size]
|
||||
@@ -1774,9 +1816,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
bias_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(bias_vec, &qkv_out_now[bias_idx]);
|
||||
}
|
||||
@@ -1877,9 +1921,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
bias_vec1[0] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec1[1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
bias_vec1[0] = static_cast<T>(input_left);
|
||||
bias_vec1[1] = static_cast<T>(input_right);
|
||||
@@ -1895,9 +1941,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
bias_vec2[0] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
bias_vec2[1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
} else {
|
||||
bias_vec2[0] = static_cast<T>(input_left);
|
||||
bias_vec2[1] = static_cast<T>(input_right);
|
||||
@@ -2017,7 +2065,8 @@ template <typename T,
|
||||
int VecSize = 4,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
typename InT = int>
|
||||
typename InT = int,
|
||||
bool EnforceFmulRN = false>
|
||||
__global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
const InT* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size,
|
||||
// head_size]
|
||||
@@ -2133,9 +2182,11 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
}
|
||||
Store<T, VecSize>(left_bias_vec, &qkv_out_now[bias_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &qkv_out_now[bias_idx_right]);
|
||||
@@ -2234,18 +2285,22 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
left_bias_vec1[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec1[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
|
||||
input_left = static_cast<float>(left_src_vec2[i]);
|
||||
input_right = static_cast<float>(right_src_vec2[i]);
|
||||
cos_tmp = cos_emb_vec2[i];
|
||||
sin_tmp = sin_emb_vec2[i];
|
||||
left_bias_vec2[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
|
||||
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
|
||||
right_bias_vec2[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
|
||||
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
|
||||
// quant + write k
|
||||
}
|
||||
LoadKVResT left_cache_vec, right_cache_vec;
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#include "speculate_write_cache_with_rope_kernel.h"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
|
||||
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
T* key_cache,
|
||||
T* value_cache,
|
||||
@@ -58,7 +58,10 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
PD_THROW("append_speculate_cache_rope_qk_norm not support neox rope yet");
|
||||
} else {
|
||||
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
|
||||
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
|
||||
append_speculate_cache_T_rope_qk_norm_kernel<T,
|
||||
PackSize,
|
||||
QKV_TYPE,
|
||||
EnforceFmulRN>
|
||||
<<<grid_size, block_dim, 0, stream>>>(qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
@@ -88,7 +91,7 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
}
|
||||
|
||||
// rope + write
|
||||
template <typename T, typename QKV_TYPE>
|
||||
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
|
||||
void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
T* key_cache,
|
||||
T* value_cache,
|
||||
@@ -127,7 +130,10 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
GetNumBlocks(pack_num, &grid_size);
|
||||
if (use_neox_style) {
|
||||
if (rotary_dim < dim_head) {
|
||||
append_speculate_cache_neox_partial_rope_kernel<T, PackSize>
|
||||
append_speculate_cache_neox_partial_rope_kernel<T,
|
||||
PackSize,
|
||||
QKV_TYPE,
|
||||
EnforceFmulRN>
|
||||
<<<grid_size, threads_per_block, 0, stream>>>(
|
||||
qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size]
|
||||
key_cache,
|
||||
@@ -153,7 +159,10 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_speculate_cache_neox_rope_kernel<T, PackSize>
|
||||
append_speculate_cache_neox_rope_kernel<T,
|
||||
PackSize,
|
||||
QKV_TYPE,
|
||||
EnforceFmulRN>
|
||||
<<<grid_size, threads_per_block, 0, stream>>>(
|
||||
qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size]
|
||||
key_cache,
|
||||
@@ -179,7 +188,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
rope_3d);
|
||||
}
|
||||
} else {
|
||||
append_speculate_cache_rope_kernel<T, PackSize>
|
||||
append_speculate_cache_rope_kernel<T, PackSize, QKV_TYPE, EnforceFmulRN>
|
||||
<<<grid_size, threads_per_block, 0, stream>>>(
|
||||
qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size]
|
||||
key_cache,
|
||||
@@ -206,7 +215,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool IsDynamic = true>
|
||||
template <typename T, bool IsDynamic = true, bool EnforceFmulRN = false>
|
||||
void append_speculate_cache_fp8_rope(const T* qkv,
|
||||
uint8_t* key_cache,
|
||||
uint8_t* value_cache,
|
||||
@@ -238,7 +247,7 @@ void append_speculate_cache_fp8_rope(const T* qkv,
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
dim3 grids(token_num, all_warps / num_warps);
|
||||
|
||||
append_clear_cache_int8_block<4>
|
||||
append_clear_cache_int8_block<4, 128>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(key_cache,
|
||||
value_cache,
|
||||
seq_lens,
|
||||
@@ -256,7 +265,8 @@ void append_speculate_cache_fp8_rope(const T* qkv,
|
||||
0,
|
||||
128,
|
||||
true,
|
||||
IsDynamic>
|
||||
IsDynamic,
|
||||
EnforceFmulRN>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
@@ -283,7 +293,10 @@ void append_speculate_cache_fp8_rope(const T* qkv,
|
||||
rms_norm_eps);
|
||||
}
|
||||
|
||||
template <typename T, typename QKV_TYPE, bool IsFP8 = false>
|
||||
template <typename T,
|
||||
typename QKV_TYPE,
|
||||
bool IsFP8 = false,
|
||||
bool EnforceFmulRN = false>
|
||||
void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* key_cache,
|
||||
uint8_t* value_cache,
|
||||
@@ -315,7 +328,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
dim3 grids(token_num, all_warps / num_warps);
|
||||
|
||||
append_clear_cache_int8_block<4>
|
||||
append_clear_cache_int8_block<4, 128>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(key_cache,
|
||||
value_cache,
|
||||
seq_lens,
|
||||
@@ -329,7 +342,12 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
kv_num_heads);
|
||||
if (use_neox_style) {
|
||||
append_speculate_cache_int8_neox_rope_kernel<T, 4>
|
||||
append_speculate_cache_int8_neox_rope_kernel<T,
|
||||
4,
|
||||
0,
|
||||
128,
|
||||
QKV_TYPE,
|
||||
EnforceFmulRN>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
@@ -354,7 +372,13 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8>
|
||||
append_speculate_cache_int8_rope_kernel<T,
|
||||
4,
|
||||
0,
|
||||
128,
|
||||
QKV_TYPE,
|
||||
IsFP8,
|
||||
EnforceFmulRN>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
@@ -381,7 +405,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
|
||||
void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* key_cache,
|
||||
uint8_t* value_cache,
|
||||
@@ -415,7 +439,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
dim3 grids(token_num, all_warps / num_warps);
|
||||
|
||||
append_clear_cache_int4_block<4>
|
||||
append_clear_cache_int4_block<4, 128>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(key_cache,
|
||||
value_cache,
|
||||
seq_lens,
|
||||
@@ -429,7 +453,12 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
kv_num_heads);
|
||||
if (use_neox_style) {
|
||||
append_speculate_cache_int4_neox_rope_kernel<T, 4>
|
||||
append_speculate_cache_int4_neox_rope_kernel<T,
|
||||
4,
|
||||
0,
|
||||
128,
|
||||
QKV_TYPE,
|
||||
EnforceFmulRN>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
@@ -456,7 +485,12 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_speculate_cache_int4_rope_kernel<T, 4>
|
||||
append_speculate_cache_int4_rope_kernel<T,
|
||||
4,
|
||||
0,
|
||||
128,
|
||||
QKV_TYPE,
|
||||
EnforceFmulRN>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
@@ -484,7 +518,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
template <typename T, typename QKV_TYPE>
|
||||
template <typename T, typename QKV_TYPE, bool EnforceFmulRN>
|
||||
void SpeculateWriteCacheWithRoPEKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& qkv,
|
||||
@@ -549,7 +583,7 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
}
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope_qk_norm(
|
||||
append_speculate_cache_rope_qk_norm<DataType_, QKV_TYPE, EnforceFmulRN>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
@@ -580,7 +614,7 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
rms_norm_eps,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
append_speculate_cache_fp8_rope<DataType_, true>(
|
||||
append_speculate_cache_fp8_rope<DataType_, true, EnforceFmulRN>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
@@ -610,7 +644,7 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_speculate_cache_fp8_rope<DataType_, false>(
|
||||
append_speculate_cache_fp8_rope<DataType_, false, EnforceFmulRN>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
@@ -648,7 +682,7 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
|
||||
} else {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope(
|
||||
append_speculate_cache_rope<DataType_, QKV_TYPE, EnforceFmulRN>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
@@ -677,7 +711,10 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
append_speculate_cache_int8_rope(
|
||||
append_speculate_cache_int8_rope<DataType_,
|
||||
QKV_TYPE,
|
||||
false,
|
||||
EnforceFmulRN>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
@@ -711,7 +748,10 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
append_speculate_cache_int8_rope<DataType_,
|
||||
QKV_TYPE,
|
||||
true,
|
||||
EnforceFmulRN>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
@@ -745,7 +785,7 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
append_speculate_cache_fp8_rope(
|
||||
append_speculate_cache_fp8_rope<DataType_, true, EnforceFmulRN>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
@@ -775,7 +815,7 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_speculate_cache_int4_rope(
|
||||
append_speculate_cache_int4_rope<DataType_, QKV_TYPE, EnforceFmulRN>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
#include "speculate_write_cache_with_rope_impl.cuh"
|
||||
|
||||
template <typename T, typename QKV_TYPE = int>
|
||||
template <typename T, typename QKV_TYPE = int, bool EnforceFmulRN = false>
|
||||
void SpeculateWriteCacheWithRoPEKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor&
|
||||
|
||||
@@ -157,6 +157,14 @@ bool getEnvEnablePDL() {
|
||||
return enablePDL;
|
||||
}
|
||||
|
||||
bool getEnvEnableRL() {
|
||||
static std::once_flag flag;
|
||||
static bool enableRL = false;
|
||||
|
||||
std::call_once(flag, [&]() { enableRL = getBoolEnv("FD_ENABLE_RL"); });
|
||||
return enableRL;
|
||||
}
|
||||
|
||||
bool getEnvDeterministicMode() {
|
||||
static std::once_flag flag;
|
||||
static bool value = false;
|
||||
|
||||
@@ -189,6 +189,17 @@ inline int GetGPUComputeCapability(int id) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef DISPATCH_BOOL_DTYPE
|
||||
#define DISPATCH_BOOL_DTYPE(runtime_flag, STATIC_FLAG, ...) \
|
||||
if (runtime_flag) { \
|
||||
constexpr bool STATIC_FLAG = true; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
constexpr bool STATIC_FLAG = false; \
|
||||
__VA_ARGS__ \
|
||||
}
|
||||
#endif
|
||||
|
||||
inline constexpr uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1) return num;
|
||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||
@@ -731,7 +742,15 @@ inline bool getBoolEnv(char const *name) {
|
||||
return env && env[0] == '1' && env[1] == '\0';
|
||||
}
|
||||
|
||||
template <const bool EnforceRN = false>
|
||||
__device__ __forceinline__ float fmul_func(float a, float b) {
|
||||
// Use __fmul_rn for IEEE-754 compliant rounding when EnforceRN is enabled
|
||||
float res = EnforceRN ? __fmul_rn(a, b) : a * b;
|
||||
return res;
|
||||
}
|
||||
|
||||
bool getEnvEnablePDL();
|
||||
bool getEnvEnableRL();
|
||||
bool getEnvDeterministicMode();
|
||||
bool getEnvDeterministicDebug();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user