[RL] RoPE without fmad opt (#6901)

* With env FD_ENABLE_RL=1, use __fmul_rn(a, b) for the multiplies in RoPE (disables fmad contraction)
chen
2026-03-24 21:19:53 +08:00
committed by GitHub
parent 6f5aa883f7
commit c92e277cf1
13 changed files with 1067 additions and 725 deletions
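For context, the two helpers this diff leans on are not shown in the commit view. A minimal sketch of both, assuming the usual layout (the real definitions live elsewhere in the repo and may differ):

#include <cstdlib>

// Assumed reading of the runtime switch: FD_ENABLE_RL=1 enables the
// round-to-nearest multiply path in RoPE.
inline bool getEnvEnableRL() {
  const char* v = std::getenv("FD_ENABLE_RL");
  return v != nullptr && v[0] == '1';
}

// Assumed shape of fmul_func: with EnforceFmulRN set, every product goes
// through __fmul_rn, which rounds to nearest-even and cannot be contracted
// by nvcc into a fused multiply-add (fmad); with it unset, the plain
// multiply is kept and fmad contraction stays legal.
template <bool EnforceFmulRN>
__device__ __forceinline__ float fmul_func(const float a, const float b) {
  if constexpr (EnforceFmulRN) {
    return __fmul_rn(a, b);
  } else {
    return a * b;
  }
}

The [RL] tag suggests the motivation: RL pipelines want RoPE results that are reproducible multiply-for-multiply, independent of whatever contraction the compiler picks, and setting FD_ENABLE_RL=1 in the serving environment opts into that.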
@@ -108,6 +108,7 @@ void AppendAttentionKernel(
static cudaEvent_t decoder_event;
static cudaStream_t decoder_stream;
static bool init_flag = false;
bool enforce_fmul_rn = getEnvEnableRL();
if (max_just_dec_len_this_time > 0 && max_enc_len_this_time > 0 &&
!init_flag) {
cudaEventCreateWithFlags(&main_event, cudaEventDisableTiming);
@@ -185,37 +186,41 @@ void AppendAttentionKernel(
auto dispatch_EncoderWriteCacheWithRopeKernel =
[&](auto temp_args) -> void {
EncoderWriteCacheWithRopeKernel<data_t, decltype(temp_args)>(
meta_data,
qkv,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_batch_ids,
kv_tile_ids_per_batch,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
kv_signal_data,
cache_quant_type_str,
kv_num_blocks_data,
max_input_length,
use_neox_rotary_style,
rope_3d,
main_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, {
EncoderWriteCacheWithRopeKernel<data_t,
decltype(temp_args),
EnforceFmulRN>(
meta_data,
qkv,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_batch_ids,
kv_tile_ids_per_batch,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
kv_signal_data,
cache_quant_type_str,
kv_num_blocks_data,
max_input_length,
use_neox_rotary_style,
rope_3d,
main_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
})
};
if (qkv_out_scales) {
@@ -284,117 +289,119 @@ void AppendAttentionKernel(
} else {
exec_stream = main_stream;
}
if (speculate_decoder) {
if (qkv_out_scales) {
SpeculateWriteCacheWithRoPEKernel<data_t, int>(
meta_data,
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, {
if (speculate_decoder) {
if (qkv_out_scales) {
SpeculateWriteCacheWithRoPEKernel<data_t, int, EnforceFmulRN>(
meta_data,
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
} else {
SpeculateWriteCacheWithRoPEKernel<data_t, data_t, EnforceFmulRN>(
meta_data,
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
} else {
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
meta_data,
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
if (qkv_out_scales) {
DecoderWriteCacheWithRoPEKernel<data_t, int, EnforceFmulRN>(
meta_data,
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
} else {
DecoderWriteCacheWithRoPEKernel<data_t, data_t, EnforceFmulRN>(
meta_data,
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
}
} else {
if (qkv_out_scales) {
DecoderWriteCacheWithRoPEKernel<data_t, int>(
meta_data,
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
} else {
DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
meta_data,
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
}
})
if (out_linear_in_scale > 0.0) {
switch (fmha_out.dtype()) {
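DISPATCH_BOOL_DTYPE is also defined elsewhere; a plausible sketch of what it must do (an assumption, not the repo's actual macro) is a runtime-bool to compile-time-constant bridge:

// Hypothetical: instantiate the body for both constant values and pick a
// branch at runtime, so kernels can be templated on EnforceFmulRN.
#define DISPATCH_BOOL_DTYPE(flag, const_name, ...) \
  if (flag) {                                      \
    constexpr bool const_name = true;              \
    __VA_ARGS__                                    \
  } else {                                         \
    constexpr bool const_name = false;             \
    __VA_ARGS__                                    \
  }

This also explains why the call sites above close with `})`: the entire templated call is handed to the macro as a braced block.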
@@ -22,7 +22,11 @@
// Applies the rotary embedding to one head: each (even, odd) pair along
// HeadDim is rotated by its per-position cos/sin angle.
template <typename T, int VecSize = 8, int HEAD_DIM = 128, int NUM_THREADS = 32>
template <typename T,
int VecSize = 8,
int HEAD_DIM = 128,
int NUM_THREADS = 32,
bool EnforceFmulRN = false>
__device__ __forceinline__ void apply_rope(const T* input,
const float* cos_emb,
const float* sin_emb,
@@ -54,15 +58,17 @@ __device__ __forceinline__ void apply_rope(const T* input,
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
out_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
out_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(out_vec, &output[head_bias]);
}
}
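The numeric effect of every rotation rewritten in this commit is a one-rounding difference. With fmad contraction (nvcc's default at -fmad=true), input_left * cos_tmp - input_right * sin_tmp may compile to a fused multiply-add whose product reaches the subtraction unrounded; the __fmul_rn form rounds both products first:

  fmad eligible:  $\mathrm{rn}\big(a c - \mathrm{rn}(b s)\big)$
  enforced rn:    $\mathrm{rn}\big(\mathrm{rn}(a c) - \mathrm{rn}(b s)\big)$

so the two can differ in the last bit. A standalone sketch (hypothetical, not part of the commit) that can make the difference visible for one pair:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void rope_pair(float a, float b, float c, float s, float* out) {
  // Plain form: nvcc may contract this into fmaf(a, c, -(b * s)).
  out[0] = a * c - b * s;
  // Enforced form: both products round to nearest before the subtraction,
  // so no fused multiply-add can be emitted.
  out[1] = __fmul_rn(a, c) - __fmul_rn(b, s);
}

int main() {
  float* d = nullptr;
  cudaMalloc(&d, 2 * sizeof(float));
  rope_pair<<<1, 1>>>(0.1f, 0.2f, 0.3f, 0.7f, d);
  float h[2];
  cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
  printf("fmad-eligible: %.9g  enforced-rn: %.9g\n", h[0], h[1]);
  cudaFree(d);
  return 0;
}

Whether the two printed values differ depends on the inputs and the -fmad flag; the commit's point is that the enforced form is identical everywhere, regardless of per-kernel contraction choices.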
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void append_decode_cache_T_rope_qk_norm_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -154,8 +160,10 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
if (hi < num_heads + kv_num_heads) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
@@ -206,7 +214,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void append_decode_cache_T_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -289,9 +297,11 @@ __global__ void append_decode_cache_T_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
out_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
out_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
out_vec[2 * i] = src_vec[2 * i];
out_vec[2 * i + 1] = src_vec[2 * i + 1];
@@ -319,7 +329,7 @@ __global__ void append_decode_cache_T_rope_kernel(
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void append_decode_cache_T_quant_rope_kernel(
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -417,9 +427,11 @@ __global__ void append_decode_cache_T_quant_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
bias_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
bias_vec[2 * i] = static_cast<T>(input_left);
bias_vec[2 * i + 1] = static_cast<T>(input_right);
@@ -447,7 +459,7 @@ __global__ void append_decode_cache_T_quant_rope_kernel(
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void append_decode_cache_T_neox_partial_rope_kernel(
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -551,9 +563,11 @@ __global__ void append_decode_cache_T_neox_partial_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
left_bias_vec[i] = static_cast<T>(input_left);
right_bias_vec[i] = static_cast<T>(input_right);
@@ -591,7 +605,7 @@ __global__ void append_decode_cache_T_neox_partial_rope_kernel(
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void append_decode_cache_T_neox_rope_kernel(
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -676,9 +690,11 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
left_bias_vec[i] = static_cast<T>(input_left);
right_bias_vec[i] = static_cast<T>(input_right);
@@ -710,7 +726,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void append_decode_cache_T_quant_neox_rope_kernel(
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -815,9 +831,11 @@ __global__ void append_decode_cache_T_quant_neox_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
left_bias_vec[i] = static_cast<T>(input_left);
right_bias_vec[i] = static_cast<T>(input_right);
@@ -855,7 +873,8 @@ template <typename T,
int HeadDim = 128,
bool is_scale_channel_wise = false,
bool IsFP8 = true,
bool IsDynamic = true>
bool IsDynamic = true,
bool EnforceFmulRN = false>
__global__ void append_decode_cache_T_int8_neox_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -941,8 +960,10 @@ __global__ void append_decode_cache_T_int8_neox_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec[i] = static_cast<T>(tmp1);
out_vec_right[i] = static_cast<T>(tmp2);
@@ -1032,9 +1053,11 @@ __global__ void append_decode_cache_T_int8_neox_rope_kernel(
float sin_tmp = sin_emb_vec1[0];
float tmp1 = 0;
if (head_bias < half_head_size) {
tmp1 = input_left * cos_tmp - input_right * sin_tmp;
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
} else {
tmp1 = input_left * cos_tmp + input_right * sin_tmp;
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) +
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
}
out_vec1[0] = static_cast<T>(tmp1);
input_left = static_cast<float>(src_vec1[1]);
@@ -1042,9 +1065,11 @@ __global__ void append_decode_cache_T_int8_neox_rope_kernel(
cos_tmp = cos_emb_vec1[1];
sin_tmp = sin_emb_vec1[1];
if (head_bias < half_head_size) {
tmp1 = input_left * cos_tmp - input_right * sin_tmp;
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
} else {
tmp1 = input_left * cos_tmp + input_right * sin_tmp;
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) +
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
}
out_vec1[1] = static_cast<T>(tmp1);
} else {
@@ -1060,9 +1085,11 @@ __global__ void append_decode_cache_T_int8_neox_rope_kernel(
float sin_tmp = sin_emb_vec2[0];
float tmp1 = 0;
if (head_bias < half_head_size) {
tmp1 = input_left * cos_tmp - input_right * sin_tmp;
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
} else {
tmp1 = input_left * cos_tmp + input_right * sin_tmp;
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) +
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
}
out_vec2[0] = static_cast<T>(tmp1);
input_left = static_cast<float>(src_vec2[1]);
@@ -1070,9 +1097,11 @@ __global__ void append_decode_cache_T_int8_neox_rope_kernel(
cos_tmp = cos_emb_vec2[1];
sin_tmp = sin_emb_vec2[1];
if (head_bias < half_head_size) {
tmp1 = input_left * cos_tmp - input_right * sin_tmp;
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
} else {
tmp1 = input_left * cos_tmp + input_right * sin_tmp;
tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) +
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
}
out_vec2[1] = static_cast<T>(tmp1);
} else {
@@ -1164,7 +1193,8 @@ template <typename T,
int HeadDim = 128,
bool is_scale_channel_wise = false,
bool IsFP8 = true,
bool IsDynamic = true>
bool IsDynamic = true,
bool EnforceFmulRN = false>
__global__ void append_decode_cache_int8_rope_qk_norm_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -1250,8 +1280,10 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec[2 * i] = static_cast<T>(tmp1);
out_vec[2 * i + 1] = static_cast<T>(tmp2);
@@ -1344,8 +1376,10 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
if (head_idx < num_heads + kv_num_heads) {
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec1[0] = static_cast<T>(tmp1);
out_vec1[1] = static_cast<T>(tmp2);
@@ -1360,8 +1394,10 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
if (head_idx < num_heads + kv_num_heads) {
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec2[0] = static_cast<T>(tmp1);
out_vec2[1] = static_cast<T>(tmp2);
@@ -1472,7 +1508,8 @@ template <typename T,
int RoundType = 0,
int HeadDim = 128,
bool is_scale_channel_wise = false,
bool IsFP8 = false>
bool IsFP8 = false,
bool EnforceFmulRN = false>
__global__ void append_decode_cache_int8_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -1528,11 +1565,11 @@ __global__ void append_decode_cache_int8_rope_kernel(
uint32_t emb_offset = write_seq_id * half_head_size;
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
apply_rope<T, VecSize, HeadDim, 32>(qkv_now,
cos_emb + emb_offset,
sin_emb + emb_offset,
qkv_out_now,
lane_id);
apply_rope<T, VecSize, HeadDim, 32, EnforceFmulRN>(qkv_now,
cos_emb + emb_offset,
sin_emb + emb_offset,
qkv_out_now,
lane_id);
} else if (head_idx < num_heads + 2 * kv_num_heads) {
// k
@@ -1619,9 +1656,11 @@ __global__ void append_decode_cache_int8_rope_kernel(
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
out_vec1[0] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
out_vec1[1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
out_vec1[0] = src_vec1[0];
out_vec1[1] = src_vec1[1];
@@ -1631,10 +1670,12 @@ __global__ void append_decode_cache_int8_rope_kernel(
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
out_vec1[0] =
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
static_cast<T>((fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp)) *
float(cache_k_scale_cur[0]));
out_vec1[1] =
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
static_cast<T>((fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp)) *
float(cache_k_scale_cur[1]));
} else {
out_vec1[0] = static_cast<T>(input_left * float(cache_v_scale_cur[0]));
@@ -1649,9 +1690,11 @@ __global__ void append_decode_cache_int8_rope_kernel(
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
out_vec2[0] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
out_vec2[1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
out_vec2[0] = src_vec2[0];
out_vec2[1] = src_vec2[1];
@@ -1661,10 +1704,12 @@ __global__ void append_decode_cache_int8_rope_kernel(
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
out_vec2[0] =
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
static_cast<T>((fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp)) *
float(cache_k_scale_cur[8]));
out_vec2[1] =
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
static_cast<T>((fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp)) *
float(cache_k_scale_cur[9]));
} else {
out_vec2[0] = static_cast<T>(input_left * float(cache_v_scale_cur[8]));
@@ -1715,7 +1760,8 @@ template <typename T,
int RoundType = 0,
int HeadDim = 128,
bool is_scale_channel_wise = false,
bool IsFP8 = false>
bool IsFP8 = false,
bool EnforceFmulRN = false>
__global__ void int_append_decode_cache_int8_rope_kernel(
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -1815,9 +1861,11 @@ __global__ void int_append_decode_cache_int8_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
bias_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(bias_vec, &qkv_out_now[bias_idx]);
}
@@ -1922,9 +1970,11 @@ __global__ void int_append_decode_cache_int8_rope_kernel(
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
bias_vec1[0] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec1[1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
bias_vec1[0] = static_cast<T>(input_left);
bias_vec1[1] = static_cast<T>(input_right);
@@ -1934,10 +1984,12 @@ __global__ void int_append_decode_cache_int8_rope_kernel(
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
bias_vec1[0] =
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
static_cast<T>((fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp)) *
float(cache_k_scale_cur[0]));
bias_vec1[1] =
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
static_cast<T>((fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp)) *
float(cache_k_scale_cur[1]));
} else {
bias_vec1[0] = static_cast<T>(input_left * float(cache_v_scale_cur[0]));
@@ -1959,9 +2011,11 @@ __global__ void int_append_decode_cache_int8_rope_kernel(
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
bias_vec2[0] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec2[1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
bias_vec2[0] = static_cast<T>(input_left);
bias_vec2[1] = static_cast<T>(input_right);
@@ -1971,10 +2025,12 @@ __global__ void int_append_decode_cache_int8_rope_kernel(
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
bias_vec2[0] =
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
static_cast<T>((fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp)) *
float(cache_k_scale_cur[8]));
bias_vec2[1] =
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
static_cast<T>((fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp)) *
float(cache_k_scale_cur[9]));
} else {
bias_vec2[0] = static_cast<T>(input_left * float(cache_v_scale_cur[8]));
@@ -2033,7 +2089,11 @@ __global__ void int_append_decode_cache_int8_rope_kernel(
#endif
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
template <typename T,
int VecSize = 4,
int RoundType = 0,
int HeadDim = 128,
bool EnforceFmulRN = false>
__global__ void append_decode_cache_int8_neox_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -2120,9 +2180,11 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(left_bias_vec, &qkv_out_now[bias_idx_left]);
Store<T, VecSize>(right_bias_vec, &qkv_out_now[bias_idx_right]);
@@ -2206,18 +2268,22 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
float cos_tmp = cos_emb_vec1[i];
float sin_tmp = sin_emb_vec1[i];
left_bias_vec1[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec1[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
input_left = static_cast<float>(left_src_vec2[i]);
input_right = static_cast<float>(right_src_vec2[i]);
cos_tmp = cos_emb_vec2[i];
sin_tmp = sin_emb_vec2[i];
left_bias_vec2[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec2[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
float quant_value1 = static_cast<float>(scale * left_bias_vec1[i]);
float quant_value2 = static_cast<float>(scale * left_bias_vec2[i]);
@@ -2341,7 +2407,11 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
#endif
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
template <typename T,
int VecSize = 4,
int RoundType = 0,
int HeadDim = 128,
bool EnforceFmulRN = false>
__global__ void int_append_decode_cache_int8_neox_rope_kernel(
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -2451,9 +2521,11 @@ __global__ void int_append_decode_cache_int8_neox_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(left_bias_vec, &qkv_out_now[bias_idx_left]);
Store<T, VecSize>(right_bias_vec, &qkv_out_now[bias_idx_right]);
@@ -2564,9 +2636,11 @@ __global__ void int_append_decode_cache_int8_neox_rope_kernel(
float cos_tmp = cos_emb_vec1[i];
float sin_tmp = sin_emb_vec1[i];
left_bias_vec1[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec1[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
input_left = static_cast<float>(left_src_vec2[i]);
input_right = static_cast<float>(right_src_vec2[i]);
@@ -2579,9 +2653,11 @@ __global__ void int_append_decode_cache_int8_neox_rope_kernel(
cos_tmp = cos_emb_vec2[i];
sin_tmp = sin_emb_vec2[i];
left_bias_vec2[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec2[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
float quant_value1 = static_cast<float>(scale * left_bias_vec1[i]);
float quant_value2 = static_cast<float>(scale * left_bias_vec2[i]);
@@ -2743,7 +2819,11 @@ __global__ void int_append_decode_cache_int8_neox_rope_kernel(
#endif
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
template <typename T,
int VecSize = 4,
int RoundType = 0,
int HeadDim = 128,
bool EnforceFmulRN = false>
__global__ void append_decode_cache_int4_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -2803,11 +2883,11 @@ __global__ void append_decode_cache_int4_rope_kernel(
uint32_t emb_offset = write_seq_id * half_head_size;
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
apply_rope<T, VecSize, HeadDim, 32>(qkv_now,
cos_emb + emb_offset,
sin_emb + emb_offset,
qkv_out_now,
lane_id);
apply_rope<T, VecSize, HeadDim, 32, EnforceFmulRN>(qkv_now,
cos_emb + emb_offset,
sin_emb + emb_offset,
qkv_out_now,
lane_id);
} else if (head_idx < num_heads + 2 * kv_num_heads) {
// k
@@ -2890,9 +2970,11 @@ __global__ void append_decode_cache_int4_rope_kernel(
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
out_vec1[0] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
out_vec1[1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
out_vec1[0] = src_vec1[0];
out_vec1[1] = src_vec1[1];
@@ -2904,9 +2986,11 @@ __global__ void append_decode_cache_int4_rope_kernel(
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
out_vec2[0] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
out_vec2[1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
out_vec2[0] = src_vec2[0];
out_vec2[1] = src_vec2[1];
@@ -3022,7 +3106,11 @@ __global__ void append_decode_cache_int4_rope_kernel(
#endif
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
template <typename T,
int VecSize = 4,
int RoundType = 0,
int HeadDim = 128,
bool EnforceFmulRN = false>
__global__ void int_append_decode_cache_int4_rope_kernel(
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -3122,9 +3210,11 @@ __global__ void int_append_decode_cache_int4_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
bias_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(bias_vec, &qkv_out_now[bias_idx]);
}
@@ -3223,9 +3313,11 @@ __global__ void int_append_decode_cache_int4_rope_kernel(
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
bias_vec1[0] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec1[1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
bias_vec1[0] = static_cast<T>(input_left);
bias_vec1[1] = static_cast<T>(input_right);
@@ -3243,9 +3335,11 @@ __global__ void int_append_decode_cache_int4_rope_kernel(
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
bias_vec2[0] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec2[1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
bias_vec2[0] = static_cast<T>(input_left);
bias_vec2[1] = static_cast<T>(input_right);
@@ -3360,7 +3454,11 @@ __global__ void int_append_decode_cache_int4_rope_kernel(
#endif
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
template <typename T,
int VecSize = 4,
int RoundType = 0,
int HeadDim = 128,
bool EnforceFmulRN = false>
__global__ void append_decode_cache_int4_neox_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -3449,9 +3547,11 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_out_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_out_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(left_out_vec, &qkv_out_now[bias_idx_left]);
Store<T, VecSize>(right_out_vec, &qkv_out_now[bias_idx_right]);
@@ -3549,18 +3649,22 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
left_out_vec1[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_out_vec1[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
input_left = static_cast<float>(left_src_vec2[i]);
input_right = static_cast<float>(right_src_vec2[i]);
cos_tmp = cos_emb_vec2[i];
sin_tmp = sin_emb_vec2[i];
left_out_vec2[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_out_vec2[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
// quant + write k
}
LoadKVResT left_cache_vec, right_cache_vec;
@@ -3739,7 +3843,11 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
#endif
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
template <typename T,
int VecSize = 4,
int RoundType = 0,
int HeadDim = 128,
bool EnforceFmulRN = false>
__global__ void int_append_decode_cache_int4_neox_rope_kernel(
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
@@ -3848,9 +3956,11 @@ __global__ void int_append_decode_cache_int4_neox_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(left_bias_vec, &qkv_out_now[bias_idx_left]);
Store<T, VecSize>(right_bias_vec, &qkv_out_now[bias_idx_right]);
@@ -3977,18 +4087,22 @@ __global__ void int_append_decode_cache_int4_neox_rope_kernel(
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
left_bias_vec1[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec1[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
input_left = static_cast<float>(left_src_vec2[i]);
input_right = static_cast<float>(right_src_vec2[i]);
cos_tmp = cos_emb_vec2[i];
sin_tmp = sin_emb_vec2[i];
left_bias_vec2[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec2[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
// quant + write k
}
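From here the diff moves into the host-side launchers. launchWithPdlWhenEnabled is used unchanged; judging only from the call sites, its shape is roughly the following (an assumption, the real helper also gates CUDA programmatic dependent launch, which this sketch ignores):

// Hypothetical reduction of the launch helper to what these hunks rely on:
// a kernel pointer plus launch configuration, arguments forwarded by value.
template <typename Kernel, typename... Args>
void launchWithPdlWhenEnabled(Kernel kernel, dim3 grid, dim3 block,
                              size_t dynamic_shm, cudaStream_t stream,
                              Args... args) {
  kernel<<<grid, block, dynamic_shm, stream>>>(args...);
}

All this commit changes at these call sites is the kernel's template argument list, which now carries EnforceFmulRN through to the device code.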
@@ -15,7 +15,7 @@
#include "decoder_write_cache_with_rope_kernel.h"
#include "utils.cuh"
template <typename T, typename QKV_TYPE>
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
@@ -53,7 +53,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
GetNumBlocks<128>(pack_num, &grid_size);
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
launchWithPdlWhenEnabled(
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>,
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize, EnforceFmulRN>,
grid_size,
block_dim,
0,
@@ -81,7 +81,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
rms_norm_eps);
}
template <typename T, typename QKV_TYPE>
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
void append_decode_cache_rope(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
@@ -117,7 +117,9 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
if (use_neox_style) {
if (qkv_out_scales) {
launchWithPdlWhenEnabled(
append_decode_cache_T_quant_neox_rope_kernel<T, PackSize>,
append_decode_cache_T_quant_neox_rope_kernel<T,
PackSize,
EnforceFmulRN>,
grid_size,
blocksize,
0,
@@ -145,7 +147,9 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
} else {
if (rotary_dim < dim_head) {
auto* kernelFn =
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>;
append_decode_cache_T_neox_partial_rope_kernel<T,
PackSize,
EnforceFmulRN>;
launchWithPdlWhenEnabled(kernelFn,
grid_size,
blocksize,
@@ -171,7 +175,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
kv_num_heads,
rope_3d);
} else {
auto* kernelFn = append_decode_cache_T_neox_rope_kernel<T, PackSize>;
auto* kernelFn =
append_decode_cache_T_neox_rope_kernel<T, PackSize, EnforceFmulRN>;
launchWithPdlWhenEnabled(kernelFn,
grid_size,
blocksize,
@@ -200,7 +205,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
} else {
if (qkv_out_scales) {
launchWithPdlWhenEnabled(
append_decode_cache_T_quant_rope_kernel<T, PackSize>,
append_decode_cache_T_quant_rope_kernel<T, PackSize, EnforceFmulRN>,
grid_size,
blocksize,
0,
@@ -226,7 +231,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
kv_num_heads,
rope_3d);
} else {
auto* kernelFn = append_decode_cache_T_rope_kernel<T, PackSize>;
auto* kernelFn =
append_decode_cache_T_rope_kernel<T, PackSize, EnforceFmulRN>;
launchWithPdlWhenEnabled(kernelFn,
grid_size,
blocksize,
@@ -257,7 +263,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
template <typename T,
typename QKV_TYPE,
bool is_scale_channel_wise = false,
bool IsFP8 = false>
bool IsFP8 = false,
bool EnforceFmulRN = false>
void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
uint8_t* key_cache,
uint8_t* value_cache,
@@ -289,7 +296,11 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
if (use_neox_style) {
if (qkv_out_scales) {
launchWithPdlWhenEnabled(
int_append_decode_cache_int8_neox_rope_kernel<T, 4>,
int_append_decode_cache_int8_neox_rope_kernel<T,
4,
0,
128,
EnforceFmulRN>,
grids,
num_warps * 32,
0,
@@ -317,31 +328,36 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
kv_num_heads,
rope_3d);
} else {
launchWithPdlWhenEnabled(append_decode_cache_int8_neox_rope_kernel<T, 4>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
cache_k_scale,
cache_v_scale,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
launchWithPdlWhenEnabled(
append_decode_cache_int8_neox_rope_kernel<T,
4,
0,
128,
EnforceFmulRN>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
cache_k_scale,
cache_v_scale,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
}
} else {
if (qkv_out_scales) {
@@ -351,7 +367,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
0,
128,
is_scale_channel_wise,
IsFP8>,
IsFP8,
EnforceFmulRN>,
grids,
num_warps * 32,
0,
@@ -385,7 +402,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
0,
128,
is_scale_channel_wise,
IsFP8>,
IsFP8,
EnforceFmulRN>,
grids,
num_warps * 32,
0,
@@ -414,7 +432,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
}
}
template <typename T, typename QKV_TYPE>
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
uint8_t* key_cache,
uint8_t* value_cache,
@@ -448,7 +466,11 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
if (use_neox_style) {
if (qkv_out_scales) {
launchWithPdlWhenEnabled(
int_append_decode_cache_int4_neox_rope_kernel<T, 4>,
int_append_decode_cache_int4_neox_rope_kernel<T,
4,
0,
128,
EnforceFmulRN>,
grids,
num_warps * 32,
0,
@@ -478,97 +500,104 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
kv_num_heads,
rope_3d);
} else {
launchWithPdlWhenEnabled(append_decode_cache_int4_neox_rope_kernel<T, 4>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
launchWithPdlWhenEnabled(
append_decode_cache_int4_neox_rope_kernel<T,
4,
0,
128,
EnforceFmulRN>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
}
} else {
if (qkv_out_scales) {
launchWithPdlWhenEnabled(int_append_decode_cache_int4_rope_kernel<T, 4>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const int*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
launchWithPdlWhenEnabled(
int_append_decode_cache_int4_rope_kernel<T, 4, 0, 128, EnforceFmulRN>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const int*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
} else {
launchWithPdlWhenEnabled(append_decode_cache_int4_rope_kernel<T, 4>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
launchWithPdlWhenEnabled(
append_decode_cache_int4_rope_kernel<T, 4, 0, 128, EnforceFmulRN>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
}
}
}
template <typename T, typename QKV_TYPE>
template <typename T, typename QKV_TYPE, bool EnforceFmulRN>
void DecoderWriteCacheWithRoPEKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& qkv,
@@ -632,7 +661,7 @@ void DecoderWriteCacheWithRoPEKernel(
if (q_norm_weight && k_norm_weight) {
if (cache_quant_type_str == "none") {
append_decode_cache_rope_qk_norm(
append_decode_cache_rope_qk_norm<DataType_, QKV_TYPE, EnforceFmulRN>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
@@ -672,7 +701,8 @@ void DecoderWriteCacheWithRoPEKernel(
128,
false,
true,
true>,
true,
EnforceFmulRN>,
grids,
num_warps * 32,
0,
@@ -714,7 +744,8 @@ void DecoderWriteCacheWithRoPEKernel(
128,
false,
true,
false>,
false,
EnforceFmulRN>,
grids,
num_warps * 32,
0,
@@ -751,7 +782,7 @@ void DecoderWriteCacheWithRoPEKernel(
}
} else {
if (cache_quant_type_str == "none") {
append_decode_cache_rope(
append_decode_cache_rope<DataType_, QKV_TYPE, EnforceFmulRN>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
@@ -784,7 +815,11 @@ void DecoderWriteCacheWithRoPEKernel(
is_scale_channel_wise = true;
}
if (is_scale_channel_wise) {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
append_decode_cache_int8_rope<DataType_,
QKV_TYPE,
true,
false,
EnforceFmulRN>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
@@ -816,7 +851,11 @@ void DecoderWriteCacheWithRoPEKernel(
use_neox_rotary_style,
rope_3d);
} else {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
append_decode_cache_int8_rope<DataType_,
QKV_TYPE,
false,
false,
EnforceFmulRN>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
@@ -849,7 +888,11 @@ void DecoderWriteCacheWithRoPEKernel(
rope_3d);
}
} else if (cache_quant_type_str == "cache_fp8") {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
append_decode_cache_int8_rope<DataType_,
QKV_TYPE,
false,
true,
EnforceFmulRN>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
@@ -892,7 +935,9 @@ void DecoderWriteCacheWithRoPEKernel(
0,
128,
false,
true>,
true,
true,
EnforceFmulRN>,
grids,
num_warps * 32,
0,
@@ -927,7 +972,9 @@ void DecoderWriteCacheWithRoPEKernel(
0,
128,
false,
true>,
true,
true,
EnforceFmulRN>,
grids,
num_warps * 32,
0,
@@ -959,7 +1006,7 @@ void DecoderWriteCacheWithRoPEKernel(
rms_norm_eps);
}
} else if (cache_quant_type_str == "cache_int4_zp") {
append_decode_cache_int4_rope(
append_decode_cache_int4_rope<DataType_, QKV_TYPE, EnforceFmulRN>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
@@ -15,7 +15,7 @@
#include "decoder_write_cache_with_rope_impl.cuh"
template <typename T, typename QKV_TYPE = int>
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
void DecoderWriteCacheWithRoPEKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor&
@@ -18,7 +18,7 @@
#include "mma_tensor_op.cuh"
#include "utils.cuh"
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void IntVariableLengthRotaryKernel(
const int *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
@@ -97,9 +97,11 @@ __global__ void IntVariableLengthRotaryKernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
bias_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
bias_vec[2 * i] = static_cast<T>(input_left);
bias_vec[2 * i + 1] = static_cast<T>(input_right);
@@ -112,7 +114,7 @@ __global__ void IntVariableLengthRotaryKernel(
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void VariableLengthRotaryKernel(
const T *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
@@ -171,9 +173,11 @@ __global__ void VariableLengthRotaryKernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
src_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
src_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
}
@@ -182,7 +186,7 @@ __global__ void VariableLengthRotaryKernel(
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void IntNeoxVariableLengthRotaryKernel(
const int *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
@@ -271,9 +275,11 @@ __global__ void IntNeoxVariableLengthRotaryKernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
left_bias_vec[i] = static_cast<T>(input_left);
right_bias_vec[i] = static_cast<T>(input_right);
@@ -287,7 +293,7 @@ __global__ void IntNeoxVariableLengthRotaryKernel(
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void NeoxVariableLengthRotaryKernel(
const T *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
@@ -352,9 +358,11 @@ __global__ void NeoxVariableLengthRotaryKernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
@@ -364,7 +372,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void IntGQAVariableLengthRotaryKernel(
const int *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
@@ -442,9 +450,11 @@ __global__ void IntGQAVariableLengthRotaryKernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
bias_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
bias_vec[2 * i] = static_cast<T>(input_left);
bias_vec[2 * i + 1] = static_cast<T>(input_right);
@@ -457,7 +467,7 @@ __global__ void IntGQAVariableLengthRotaryKernel(
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void GQAVariableLengthRotaryQKNormKernel(
const T *qkv,
const float *cos_emb,
@@ -525,8 +535,10 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
const float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
@@ -554,7 +566,7 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void GQAVariableLengthRotaryKernel(const T *qkv,
const float *cos_emb,
const float *sin_emb,
@@ -614,9 +626,11 @@ __global__ void GQAVariableLengthRotaryKernel(const T *qkv,
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
src_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
src_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
}
@@ -625,7 +639,7 @@ __global__ void GQAVariableLengthRotaryKernel(const T *qkv,
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void IntGQAVariableLengthRotaryQuantKVKernel(
const int *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
@@ -703,19 +717,23 @@ __global__ void IntGQAVariableLengthRotaryQuantKVKernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
bias_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else if (hi < q_num_head + kv_num_head) {
int k_hi = hi - q_num_head;
const int scale_idx = k_hi * last_dim + h_bias;
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
bias_vec[2 * i] =
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
static_cast<T>((fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp)) *
float(cache_k_scales[scale_idx + 2 * i]));
bias_vec[2 * i + 1] =
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
static_cast<T>((fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp)) *
float(cache_k_scales[scale_idx + 2 * i + 1]));
} else {
int v_hi = hi - q_num_head - kv_num_head;
@@ -733,7 +751,7 @@ __global__ void IntGQAVariableLengthRotaryQuantKVKernel(
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void GQAVariableLengthRotaryQuantKVKernel(
const T *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
@@ -803,26 +821,31 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(
: static_cast<float>(src_vec[2 * i + 1]);
// const float cos_tmp = cos_emb_vec[i];
// const float sin_tmp = sin_emb_vec[i];
// src_vec[2 * i] = static_cast<T>(input_left * cos_tmp - input_right *
// sin_tmp); src_vec[2 * i + 1] = static_cast<T>(input_right * cos_tmp +
// input_left * sin_tmp);
// src_vec[2 * i] = static_cast<T>(fmul_func<EnforceFmulRN>(input_left,
// cos_tmp) - fmul_func<EnforceFmulRN>(input_right, sin_tmp));
// src_vec[2 * i + 1] =
// static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
// fmul_func<EnforceFmulRN>(input_left, sin_tmp));
if (hi < q_num_head) { // qk rope
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
src_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
src_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else if (hi < q_num_head + kv_num_head) {
int k_hi = hi - q_num_head;
const int scale_idx = k_hi * last_dim + h_bias;
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
src_vec[2 * i] =
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
static_cast<T>((fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp)) *
float(cache_k_scales[scale_idx + 2 * i]));
src_vec[2 * i + 1] =
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
static_cast<T>((fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp)) *
float(cache_k_scales[scale_idx + 2 * i + 1]));
} else {
int v_hi = hi - q_num_head - kv_num_head;
@@ -840,7 +863,7 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void IntGQANeoxVariableLengthRotaryKernel(
const int *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
@@ -926,9 +949,11 @@ __global__ void IntGQANeoxVariableLengthRotaryKernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
left_bias_vec[i] = static_cast<T>(input_left);
right_bias_vec[i] = static_cast<T>(input_right);
@@ -942,7 +967,7 @@ __global__ void IntGQANeoxVariableLengthRotaryKernel(
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void GQANeoxVariableLengthRotaryKernel(const T *qkv,
const float *cos_emb,
const float *sin_emb,
@@ -1005,9 +1030,11 @@ __global__ void GQANeoxVariableLengthRotaryKernel(const T *qkv,
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
@@ -1017,7 +1044,7 @@ __global__ void GQANeoxVariableLengthRotaryKernel(const T *qkv,
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void GQANeoxVariableLengthPartialRotaryKernel(
const T *qkv,
const float *cos_emb,
@@ -1082,9 +1109,11 @@ __global__ void GQANeoxVariableLengthPartialRotaryKernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
@@ -1094,7 +1123,7 @@ __global__ void GQANeoxVariableLengthPartialRotaryKernel(
#endif
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void cache_kernel(
const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * kv_num_heads,
// head_size]
@@ -2199,7 +2228,7 @@ __global__ void append_write_cache_kv_c4_qkv(
#endif
}
template <typename T, typename QKV_TYPE>
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
void rotary_qk_variable(
T *qkv_out, // [token_num, 3, num_head, dim_head]
const QKV_TYPE *qkv_input, // qkv
@@ -2233,94 +2262,98 @@ void rotary_qk_variable(
const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
if (qkv_out_scales) {
launchWithPdlWhenEnabled(IntVariableLengthRotaryKernel<T, PackSize>,
grid_size,
blocksize,
0,
stream,
reinterpret_cast<const int *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
qkv_bias,
qkv_out,
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
launchWithPdlWhenEnabled(
IntVariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>,
grid_size,
blocksize,
0,
stream,
reinterpret_cast<const int *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
qkv_bias,
qkv_out,
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
} else {
launchWithPdlWhenEnabled(VariableLengthRotaryKernel<T, PackSize>,
grid_size,
blocksize,
0,
stream,
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
launchWithPdlWhenEnabled(
VariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>,
grid_size,
blocksize,
0,
stream,
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
}
} else {
const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + input_output_len * dim_head;
if (qkv_out_scales) {
launchWithPdlWhenEnabled(IntNeoxVariableLengthRotaryKernel<T, PackSize>,
grid_size,
blocksize,
0,
stream,
reinterpret_cast<const int *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
qkv_bias,
qkv_out,
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
launchWithPdlWhenEnabled(
IntNeoxVariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>,
grid_size,
blocksize,
0,
stream,
reinterpret_cast<const int *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
qkv_bias,
qkv_out,
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
} else {
launchWithPdlWhenEnabled(NeoxVariableLengthRotaryKernel<T, PackSize>,
grid_size,
blocksize,
0,
stream,
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
launchWithPdlWhenEnabled(
NeoxVariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>,
grid_size,
blocksize,
0,
stream,
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
}
}
}
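Note why the host-side call sites change shape throughout this patch: EnforceFmulRN is a trailing template parameter with a default, and nothing in the function arguments lets the compiler deduce it, so callers that previously relied on deduction must now spell out the full template argument list to forward the dispatched flag. A reduced illustration with hypothetical names:
// A trailing non-deduced bool cannot be supplied without naming the
// earlier template arguments; illustrative only, not from this patch.
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
void rope_wrapper_demo(const QKV_TYPE *in, T *out);

template <typename T, typename QKV_TYPE, bool EnforceFmulRN>
void caller_demo(const QKV_TYPE *in, T *out) {
  // rope_wrapper_demo(in, out);  // would silently pick EnforceFmulRN = false
  rope_wrapper_demo<T, QKV_TYPE, EnforceFmulRN>(in, out);  // forwards the flag
}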
template <typename T, typename QKV_TYPE>
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
void gqa_rotary_qk_norm_variable(
T *qkv_out, // [token_num, 3, num_head, dim_head]
const QKV_TYPE *qkv_input, // qkv
@@ -2362,31 +2395,32 @@ void gqa_rotary_qk_norm_variable(
const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
launchWithPdlWhenEnabled(GQAVariableLengthRotaryQKNormKernel<T, PackSize>,
grid_size,
Block_Size,
0,
stream,
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
elem_nums,
num_heads,
kv_num_heads,
seq_len,
dim_head,
rope_3d,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
launchWithPdlWhenEnabled(
GQAVariableLengthRotaryQKNormKernel<T, PackSize, EnforceFmulRN>,
grid_size,
Block_Size,
0,
stream,
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
elem_nums,
num_heads,
kv_num_heads,
seq_len,
dim_head,
rope_3d,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
template <typename T, typename QKV_TYPE>
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
void gqa_rotary_qk_variable(
T *qkv_out, // [token_num, 3, num_head, dim_head]
const QKV_TYPE *qkv_input, // qkv
@@ -2425,29 +2459,31 @@ void gqa_rotary_qk_variable(
const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
if (qkv_out_scales) {
launchWithPdlWhenEnabled(IntGQAVariableLengthRotaryKernel<T, PackSize>,
grid_size,
blocksize,
0,
stream,
reinterpret_cast<const int *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
qkv_bias,
qkv_out,
elem_nums,
num_heads,
kv_num_heads,
seq_len,
dim_head,
rope_3d);
launchWithPdlWhenEnabled(
IntGQAVariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>,
grid_size,
blocksize,
0,
stream,
reinterpret_cast<const int *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
qkv_bias,
qkv_out,
elem_nums,
num_heads,
kv_num_heads,
seq_len,
dim_head,
rope_3d);
} else {
auto *kernelFn = GQAVariableLengthRotaryKernel<T, PackSize>;
auto *kernelFn =
GQAVariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>;
launchWithPdlWhenEnabled(kernelFn,
grid_size,
blocksize,
@@ -2473,7 +2509,7 @@ void gqa_rotary_qk_variable(
const float *sin_emb = rotary_emb + input_output_len * dim_head;
if (qkv_out_scales) {
launchWithPdlWhenEnabled(
IntGQANeoxVariableLengthRotaryKernel<T, PackSize>,
IntGQANeoxVariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>,
grid_size,
blocksize,
0,
@@ -2507,7 +2543,10 @@ void gqa_rotary_qk_variable(
}
const int pack_num_new = elem_nums / PackSize;
GetNumBlocks<128>(pack_num_new, &grid_size);
auto *kernelFn = GQANeoxVariableLengthPartialRotaryKernel<T, PackSize>;
auto *kernelFn =
GQANeoxVariableLengthPartialRotaryKernel<T,
PackSize,
EnforceFmulRN>;
launchWithPdlWhenEnabled(kernelFn,
grid_size,
blocksize,
@@ -2531,7 +2570,8 @@ void gqa_rotary_qk_variable(
rotary_dim,
rope_3d);
} else {
auto *kernelFn = GQANeoxVariableLengthRotaryKernel<T, PackSize>;
auto *kernelFn =
GQANeoxVariableLengthRotaryKernel<T, PackSize, EnforceFmulRN>;
launchWithPdlWhenEnabled(kernelFn,
grid_size,
blocksize,
@@ -2558,7 +2598,7 @@ void gqa_rotary_qk_variable(
}
}
template <typename T, typename QKV_TYPE>
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
void gqa_rotary_qk_quant_variable(
T *qkv_out, // [token_num, 3, num_head, dim_head]
const QKV_TYPE *qkv_input, // qkv
@@ -2595,7 +2635,7 @@ void gqa_rotary_qk_quant_variable(
if (!use_neox_style) {
if (qkv_out_scales) {
launchWithPdlWhenEnabled(
IntGQAVariableLengthRotaryQuantKVKernel<T, PackSize>,
IntGQAVariableLengthRotaryQuantKVKernel<T, PackSize, EnforceFmulRN>,
grid_size,
blocksize,
0,
@@ -2620,7 +2660,7 @@ void gqa_rotary_qk_quant_variable(
rope_3d);
} else {
launchWithPdlWhenEnabled(
GQAVariableLengthRotaryQuantKVKernel<T, PackSize>,
GQAVariableLengthRotaryQuantKVKernel<T, PackSize, EnforceFmulRN>,
grid_size,
blocksize,
0,
@@ -16,7 +16,7 @@
#include "encoder_write_cache_with_rope_impl.cuh"
#include "remote_cache_kv_ipc.h"
template <typename T, typename QKV_TYPE>
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
void EncoderWriteCacheWithRopeKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor&
@@ -77,7 +77,7 @@ void EncoderWriteCacheWithRopeKernel(
if (q_norm_weight && k_norm_weight) {
if (num_heads != kv_num_heads && !is_scale_channel_wise &&
!use_neox_style) {
gqa_rotary_qk_norm_variable(
gqa_rotary_qk_norm_variable<T, QKV_TYPE, EnforceFmulRN>(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
@@ -106,7 +106,7 @@ void EncoderWriteCacheWithRopeKernel(
}
} else {
if (num_heads == kv_num_heads) {
rotary_qk_variable(
rotary_qk_variable<T, QKV_TYPE, EnforceFmulRN>(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
@@ -126,7 +126,7 @@ void EncoderWriteCacheWithRopeKernel(
rope_3d);
} else {
if (!is_scale_channel_wise) {
gqa_rotary_qk_variable(
gqa_rotary_qk_variable<T, QKV_TYPE, EnforceFmulRN>(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
@@ -147,7 +147,7 @@ void EncoderWriteCacheWithRopeKernel(
use_neox_style,
rope_3d);
} else {
gqa_rotary_qk_quant_variable(
gqa_rotary_qk_quant_variable<T, QKV_TYPE, EnforceFmulRN>(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
@@ -20,7 +20,7 @@
#include "qwen3_rope.h"
#include "remote_cache_kv_ipc.h"
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void GQAVariableLengthRotarySplitKernel(
const T *qkv,
const float *cos_emb,
@@ -117,8 +117,10 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
@@ -157,9 +159,11 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
src_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
src_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
}
}
@@ -168,7 +172,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
}
}
template <typename T>
template <typename T, bool EnforceFmulRN = false>
void gqa_rotary_qk_split_variable(
T *qkv_out, // [token_num, 3, num_head, head_dim]
T *q,
@@ -205,35 +209,36 @@ void gqa_rotary_qk_split_variable(
const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + input_output_len * head_dim / 2;
launchWithPdlWhenEnabled(GQAVariableLengthRotarySplitKernel<T, PackSize>,
grid_size,
block_size,
0,
stream,
qkv_input,
cos_emb,
sin_emb,
q_norm_weight,
k_norm_weight,
batch_id_per_token,
cu_seqlens_q,
seq_lens_encoder,
seq_lens_decoder,
cu_seqlens_k,
qkv_out,
q,
k,
v,
elem_nums,
num_heads,
kv_num_heads,
max_model_len,
head_dim,
rope_3d,
rms_norm_eps);
launchWithPdlWhenEnabled(
GQAVariableLengthRotarySplitKernel<T, PackSize, EnforceFmulRN>,
grid_size,
block_size,
0,
stream,
qkv_input,
cos_emb,
sin_emb,
q_norm_weight,
k_norm_weight,
batch_id_per_token,
cu_seqlens_q,
seq_lens_encoder,
seq_lens_decoder,
cu_seqlens_k,
qkv_out,
q,
k,
v,
elem_nums,
num_heads,
kv_num_heads,
max_model_len,
head_dim,
rope_3d,
rms_norm_eps);
}
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void GQAVariableLengthNeoxPartialRotarySplitKernel(
const T *qkv,
const float *cos_emb,
@@ -329,10 +334,12 @@ __global__ void GQAVariableLengthNeoxPartialRotarySplitKernel(
const float sin_tmp = sin_emb_vec[i];
if (h_bias < half_rotary_dim) {
src_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
} else {
src_vec[i] =
static_cast<T>(input_left * cos_tmp + input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) +
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
}
}
}
@@ -346,7 +353,7 @@ __global__ void GQAVariableLengthNeoxPartialRotarySplitKernel(
#endif
}
template <typename T>
template <typename T, bool EnforceFmulRN = false>
void gqa_neox_partial_rotary_qk_split_variable(
T *qkv_out, // [token_num, 3, num_head, head_dim]
T *q,
@@ -381,7 +388,7 @@ void gqa_neox_partial_rotary_qk_split_variable(
const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + max_model_len * rotary_dim / 2;
launchWithPdlWhenEnabled(
GQAVariableLengthNeoxPartialRotarySplitKernel<T, PackSize>,
GQAVariableLengthNeoxPartialRotarySplitKernel<T, PackSize, EnforceFmulRN>,
grid_size,
block_size,
0,
@@ -1376,35 +1383,60 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
paddle::Tensor v = GetEmptyTensor(
{kv_token_num, kv_num_heads, head_dim}, qkv.dtype(), qkv.place());
if (use_neox_rotary_style) {
if (rotary_dim == head_dim) {
gqa_rotary_qk_split_variable_qwen3<data_t>(
qkv_out.data<data_t>(),
q.data<data_t>(),
k.data<data_t>(),
v.data<data_t>(),
qkv.data<data_t>(),
rotary_embs.data<float>(),
batch_id_per_token.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
token_num,
num_heads,
kv_num_heads,
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
head_dim,
rope_3d,
stream);
bool enforce_fmul_rn = getEnvEnableRL();
DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, {
if (use_neox_rotary_style) {
if (rotary_dim == head_dim) {
gqa_rotary_qk_split_variable_qwen3<data_t, EnforceFmulRN>(
qkv_out.data<data_t>(),
q.data<data_t>(),
k.data<data_t>(),
v.data<data_t>(),
qkv.data<data_t>(),
rotary_embs.data<float>(),
batch_id_per_token.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
token_num,
num_heads,
kv_num_heads,
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
head_dim,
rope_3d,
stream);
} else {
gqa_neox_partial_rotary_qk_split_variable<data_t, EnforceFmulRN>(
qkv_out.data<data_t>(),
q.data<data_t>(),
k.data<data_t>(),
v.data<data_t>(),
qkv.data<data_t>(),
rotary_embs.data<float>(),
batch_id_per_token.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
head_dim,
rotary_dim,
stream);
}
} else {
gqa_neox_partial_rotary_qk_split_variable<data_t>(
gqa_rotary_qk_split_variable<data_t, EnforceFmulRN>(
qkv_out.data<data_t>(),
q.data<data_t>(),
k.data<data_t>(),
v.data<data_t>(),
qkv.data<data_t>(),
rotary_embs.data<float>(),
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
batch_id_per_token.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
@@ -1414,35 +1446,13 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
num_heads,
kv_num_heads,
max_seq_len,
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
head_dim,
rotary_dim,
rope_3d,
rms_norm_eps,
stream);
}
} else {
gqa_rotary_qk_split_variable<data_t>(
qkv_out.data<data_t>(),
q.data<data_t>(),
k.data<data_t>(),
v.data<data_t>(),
qkv.data<data_t>(),
rotary_embs.data<float>(),
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
batch_id_per_token.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
head_dim,
rope_3d,
rms_norm_eps,
stream);
}
})
if (token_num < kv_token_num) {
AppendCacheKV<data_t, 128, 64>(key_cache,
@@ -5,7 +5,7 @@
#include "paddle/phi/core/memory/memcpy.h"
#include "remote_cache_kv_ipc.h"
template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, bool EnforceFmulRN = false>
__global__ void GQAVariableLengthRotarySplitKernel_Qwen3(
const T *qkv,
const float *cos_emb,
@@ -99,9 +99,11 @@ __global__ void GQAVariableLengthRotarySplitKernel_Qwen3(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
src_vec0[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
src_vec1[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
}
Store<T, VecSize>(src_vec0, &qkv_out[read_idx]);
@@ -111,7 +113,7 @@ __global__ void GQAVariableLengthRotarySplitKernel_Qwen3(
}
}
template <typename T>
template <typename T, bool EnforceFmulRN = false>
void gqa_rotary_qk_split_variable_qwen3(T *qkv_out,
T *q,
T *k,
@@ -145,7 +147,7 @@ void gqa_rotary_qk_split_variable_qwen3(T *qkv_out,
const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + max_model_len * head_dim;
launchWithPdlWhenEnabled(
GQAVariableLengthRotarySplitKernel_Qwen3<T, PackSize>,
GQAVariableLengthRotarySplitKernel_Qwen3<T, PackSize, EnforceFmulRN>,
grid_size,
block_size,
0,
@@ -18,7 +18,10 @@
#include "mma_tensor_op.cuh"
#include "utils.cuh"
template <typename T, int VecSize = 1, typename InT = T>
template <typename T,
int VecSize = 1,
typename InT = T,
bool EnforceFmulRN = false>
__global__ void append_speculate_cache_T_rope_qk_norm_kernel(
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
// head_size]
@@ -129,8 +132,10 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
if (hi < num_heads + gqa_group_size) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
@@ -331,7 +336,10 @@ __global__ void append_clear_cache_int4_block(
}
}
template <typename T, int VecSize = 1, typename InT = T>
template <typename T,
int VecSize = 1,
typename InT = T,
bool EnforceFmulRN = false>
__global__ void append_speculate_cache_rope_kernel(
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
// head_size]
@@ -434,9 +442,11 @@ __global__ void append_speculate_cache_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
bias_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
bias_vec[2 * i] = static_cast<T>(input_left);
bias_vec[2 * i + 1] = static_cast<T>(input_right);
@@ -462,7 +472,10 @@ __global__ void append_speculate_cache_rope_kernel(
}
}
template <typename T, int VecSize = 1, typename InT = T>
template <typename T,
int VecSize = 1,
typename InT = T,
bool EnforceFmulRN = false>
__global__ void append_speculate_cache_neox_rope_kernel(
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
// head_size]
@@ -568,9 +581,11 @@ __global__ void append_speculate_cache_neox_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
left_bias_vec[i] = static_cast<T>(input_left);
right_bias_vec[i] = static_cast<T>(input_right);
@@ -601,7 +616,10 @@ __global__ void append_speculate_cache_neox_rope_kernel(
}
}
template <typename T, int VecSize = 1, typename InT = T>
template <typename T,
int VecSize = 1,
typename InT = T,
bool EnforceFmulRN = false>
__global__ void append_speculate_cache_neox_partial_rope_kernel(
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
// head_size]
@@ -724,9 +742,11 @@ __global__ void append_speculate_cache_neox_partial_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
left_bias_vec[i] = static_cast<T>(input_left);
right_bias_vec[i] = static_cast<T>(input_right);
@@ -766,7 +786,8 @@ template <typename T,
int RoundType = 0,
int HeadDim = 128,
bool IsFP8 = false,
bool IsDynamic = true>
bool IsDynamic = true,
bool EnforceFmulRN = false>
__global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
const T* __restrict__ quant_qkv, // [num_head, num_heads + 2 *
// gqa_group_size, head_size]
@@ -864,8 +885,10 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
bias_vec[2 * i] = static_cast<T>(tmp1);
bias_vec[2 * i + 1] = static_cast<T>(tmp2);
@@ -929,8 +952,10 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
if (head_idx < num_heads + gqa_group_size) {
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
bias_vec1[0] = static_cast<T>(tmp1);
bias_vec1[1] = static_cast<T>(tmp2);
@@ -944,8 +969,10 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
if (head_idx < num_heads + gqa_group_size) {
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
float tmp1 = fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp);
float tmp2 = fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp);
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
bias_vec2[0] = static_cast<T>(tmp1);
bias_vec2[1] = static_cast<T>(tmp2);
@@ -1045,7 +1072,8 @@ template <typename T,
int RoundType = 0,
int HeadDim = 128,
typename InT = int,
bool IsFP8 = false>
bool IsFP8 = false,
bool EnforceFmulRN = false>
__global__ void append_speculate_cache_int8_rope_kernel(
const InT* __restrict__ quant_qkv, // [num_head, num_heads + 2 *
// gqa_group_size, head_size]
@@ -1149,9 +1177,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
bias_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(bias_vec, &qkv_out_now[bias_idx]);
}
@@ -1215,9 +1245,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
bias_vec1[0] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec1[1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
bias_vec1[0] = static_cast<T>(input_left);
bias_vec1[1] = static_cast<T>(input_right);
@@ -1238,9 +1270,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
bias_vec2[0] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec2[1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
bias_vec2[0] = static_cast<T>(input_left);
bias_vec2[1] = static_cast<T>(input_right);
@@ -1285,7 +1319,8 @@ template <typename T,
int VecSize = 4,
int RoundType = 0,
int HeadDim = 128,
typename InT = int>
typename InT = int,
bool EnforceFmulRN = false>
__global__ void append_speculate_cache_int8_neox_rope_kernel(
const InT* __restrict__ quant_qkv, // [num_head, num_heads + 2 *
// gqa_group_size, head_size]
@@ -1396,9 +1431,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(left_bias_vec, &qkv_out_now[bias_idx_left]);
Store<T, VecSize>(right_bias_vec, &qkv_out_now[bias_idx_right]);
@@ -1485,9 +1522,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
float cos_tmp = cos_emb_vec1[i];
float sin_tmp = sin_emb_vec1[i];
left_bias_vec1[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec1[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
input_left = static_cast<float>(left_src_vec2[i]);
input_right = static_cast<float>(right_src_vec2[i]);
@@ -1500,9 +1539,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
cos_tmp = cos_emb_vec2[i];
sin_tmp = sin_emb_vec2[i];
left_bias_vec2[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec2[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
float quant_value1 = static_cast<float>(scale * left_bias_vec1[i]);
float quant_value2 = static_cast<float>(scale * left_bias_vec2[i]);
@@ -1669,7 +1710,8 @@ template <typename T,
int VecSize = 4,
int RoundType = 0,
int HeadDim = 128,
typename InT = int>
typename InT = int,
bool EnforceFmulRN = false>
__global__ void append_speculate_cache_int4_rope_kernel(
const InT* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size,
// head_size]
@@ -1774,9 +1816,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
bias_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(bias_vec, &qkv_out_now[bias_idx]);
}
@@ -1877,9 +1921,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
bias_vec1[0] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec1[1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
bias_vec1[0] = static_cast<T>(input_left);
bias_vec1[1] = static_cast<T>(input_right);
@@ -1895,9 +1941,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
bias_vec2[0] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
bias_vec2[1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
} else {
bias_vec2[0] = static_cast<T>(input_left);
bias_vec2[1] = static_cast<T>(input_right);
@@ -2017,7 +2065,8 @@ template <typename T,
int VecSize = 4,
int RoundType = 0,
int HeadDim = 128,
typename InT = int>
typename InT = int,
bool EnforceFmulRN = false>
__global__ void append_speculate_cache_int4_neox_rope_kernel(
const InT* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size,
// head_size]
@@ -2133,9 +2182,11 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
}
Store<T, VecSize>(left_bias_vec, &qkv_out_now[bias_idx_left]);
Store<T, VecSize>(right_bias_vec, &qkv_out_now[bias_idx_right]);
@@ -2234,18 +2285,22 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
left_bias_vec1[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec1[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
input_left = static_cast<float>(left_src_vec2[i]);
input_right = static_cast<float>(right_src_vec2[i]);
cos_tmp = cos_emb_vec2[i];
sin_tmp = sin_emb_vec2[i];
left_bias_vec2[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_left, cos_tmp) -
fmul_func<EnforceFmulRN>(input_right, sin_tmp));
right_bias_vec2[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
static_cast<T>(fmul_func<EnforceFmulRN>(input_right, cos_tmp) +
fmul_func<EnforceFmulRN>(input_left, sin_tmp));
// quant + write k
}
LoadKVResT left_cache_vec, right_cache_vec;
@@ -15,7 +15,7 @@
#include "speculate_write_cache_with_rope_kernel.h"
#include "utils.cuh"
template <typename T, typename QKV_TYPE>
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
@@ -58,7 +58,10 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
PD_THROW("append_speculate_cache_rope_qk_norm not support neox rope yet");
} else {
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
append_speculate_cache_T_rope_qk_norm_kernel<T,
PackSize,
QKV_TYPE,
EnforceFmulRN>
<<<grid_size, block_dim, 0, stream>>>(qkv,
key_cache,
value_cache,
@@ -88,7 +91,7 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
}
// rope + write
template <typename T, typename QKV_TYPE>
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
void append_speculate_cache_rope(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
@@ -127,7 +130,10 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
GetNumBlocks(pack_num, &grid_size);
if (use_neox_style) {
if (rotary_dim < dim_head) {
append_speculate_cache_neox_partial_rope_kernel<T, PackSize>
append_speculate_cache_neox_partial_rope_kernel<T,
PackSize,
QKV_TYPE,
EnforceFmulRN>
<<<grid_size, threads_per_block, 0, stream>>>(
qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size]
key_cache,
@@ -153,7 +159,10 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
kv_num_heads,
rope_3d);
} else {
append_speculate_cache_neox_rope_kernel<T, PackSize>
append_speculate_cache_neox_rope_kernel<T,
PackSize,
QKV_TYPE,
EnforceFmulRN>
<<<grid_size, threads_per_block, 0, stream>>>(
qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size]
key_cache,
@@ -179,7 +188,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
rope_3d);
}
} else {
append_speculate_cache_rope_kernel<T, PackSize>
append_speculate_cache_rope_kernel<T, PackSize, QKV_TYPE, EnforceFmulRN>
<<<grid_size, threads_per_block, 0, stream>>>(
qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size]
key_cache,
@@ -206,7 +215,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
}
}
template <typename T, bool IsDynamic = true>
template <typename T, bool IsDynamic = true, bool EnforceFmulRN = false>
void append_speculate_cache_fp8_rope(const T* qkv,
uint8_t* key_cache,
uint8_t* value_cache,
@@ -238,7 +247,7 @@ void append_speculate_cache_fp8_rope(const T* qkv,
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
dim3 grids(token_num, all_warps / num_warps);
append_clear_cache_int8_block<4>
append_clear_cache_int8_block<4, 128>
<<<grids, num_warps * 32, 0, stream>>>(key_cache,
value_cache,
seq_lens,
@@ -256,7 +265,8 @@ void append_speculate_cache_fp8_rope(const T* qkv,
0,
128,
true,
IsDynamic>
IsDynamic,
EnforceFmulRN>
<<<grids, num_warps * 32, 0, stream>>>(qkv,
key_cache,
value_cache,
@@ -283,7 +293,10 @@ void append_speculate_cache_fp8_rope(const T* qkv,
rms_norm_eps);
}
template <typename T, typename QKV_TYPE, bool IsFP8 = false>
template <typename T,
typename QKV_TYPE,
bool IsFP8 = false,
bool EnforceFmulRN = false>
void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
uint8_t* key_cache,
uint8_t* value_cache,
@@ -315,7 +328,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
dim3 grids(token_num, all_warps / num_warps);
append_clear_cache_int8_block<4>
append_clear_cache_int8_block<4, 128>
<<<grids, num_warps * 32, 0, stream>>>(key_cache,
value_cache,
seq_lens,
@@ -329,7 +342,12 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
kv_num_heads);
if (use_neox_style) {
append_speculate_cache_int8_neox_rope_kernel<T, 4>
append_speculate_cache_int8_neox_rope_kernel<T,
4,
0,
128,
QKV_TYPE,
EnforceFmulRN>
<<<grids, num_warps * 32, 0, stream>>>(qkv,
key_cache,
value_cache,
@@ -354,7 +372,13 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
kv_num_heads,
rope_3d);
} else {
append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8>
append_speculate_cache_int8_rope_kernel<T,
4,
0,
128,
QKV_TYPE,
IsFP8,
EnforceFmulRN>
<<<grids, num_warps * 32, 0, stream>>>(qkv,
key_cache,
value_cache,
@@ -381,7 +405,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
}
}
template <typename T, typename QKV_TYPE>
template <typename T, typename QKV_TYPE, bool EnforceFmulRN = false>
void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
uint8_t* key_cache,
uint8_t* value_cache,
@@ -415,7 +439,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
dim3 grids(token_num, all_warps / num_warps);
append_clear_cache_int4_block<4>
append_clear_cache_int4_block<4, 128>
<<<grids, num_warps * 32, 0, stream>>>(key_cache,
value_cache,
seq_lens,
@@ -429,7 +453,12 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
kv_num_heads);
if (use_neox_style) {
append_speculate_cache_int4_neox_rope_kernel<T, 4>
append_speculate_cache_int4_neox_rope_kernel<T,
4,
0,
128,
QKV_TYPE,
EnforceFmulRN>
<<<grids, num_warps * 32, 0, stream>>>(qkv,
key_cache,
value_cache,
@@ -456,7 +485,12 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
kv_num_heads,
rope_3d);
} else {
append_speculate_cache_int4_rope_kernel<T, 4>
append_speculate_cache_int4_rope_kernel<T,
4,
0,
128,
QKV_TYPE,
EnforceFmulRN>
<<<grids, num_warps * 32, 0, stream>>>(qkv,
key_cache,
value_cache,
@@ -484,7 +518,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
rope_3d);
}
}
template <typename T, typename QKV_TYPE>
template <typename T, typename QKV_TYPE, bool EnforceFmulRN>
void SpeculateWriteCacheWithRoPEKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& qkv,
@@ -549,7 +583,7 @@ void SpeculateWriteCacheWithRoPEKernel(
}
if (q_norm_weight && k_norm_weight) {
if (cache_quant_type_str == "none") {
append_speculate_cache_rope_qk_norm(
append_speculate_cache_rope_qk_norm<DataType_, QKV_TYPE, EnforceFmulRN>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
@@ -580,7 +614,7 @@ void SpeculateWriteCacheWithRoPEKernel(
rms_norm_eps,
rope_3d);
} else if (cache_quant_type_str == "block_wise_fp8") {
append_speculate_cache_fp8_rope<DataType_, true>(
append_speculate_cache_fp8_rope<DataType_, true, EnforceFmulRN>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
@@ -610,7 +644,7 @@ void SpeculateWriteCacheWithRoPEKernel(
rope_3d,
rms_norm_eps);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_fp8_rope<DataType_, false>(
append_speculate_cache_fp8_rope<DataType_, false, EnforceFmulRN>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
@@ -648,7 +682,7 @@ void SpeculateWriteCacheWithRoPEKernel(
} else {
if (cache_quant_type_str == "none") {
append_speculate_cache_rope(
append_speculate_cache_rope<DataType_, QKV_TYPE, EnforceFmulRN>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
@@ -677,7 +711,10 @@ void SpeculateWriteCacheWithRoPEKernel(
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int8") {
append_speculate_cache_int8_rope(
append_speculate_cache_int8_rope<DataType_,
QKV_TYPE,
false,
EnforceFmulRN>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
@@ -711,7 +748,10 @@ void SpeculateWriteCacheWithRoPEKernel(
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
append_speculate_cache_int8_rope<DataType_,
QKV_TYPE,
true,
EnforceFmulRN>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
@@ -745,7 +785,7 @@ void SpeculateWriteCacheWithRoPEKernel(
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "block_wise_fp8") {
append_speculate_cache_fp8_rope(
append_speculate_cache_fp8_rope<DataType_, true, EnforceFmulRN>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
@@ -775,7 +815,7 @@ void SpeculateWriteCacheWithRoPEKernel(
rope_3d,
rms_norm_eps);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_speculate_cache_int4_rope(
append_speculate_cache_int4_rope<DataType_, QKV_TYPE, EnforceFmulRN>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
@@ -15,7 +15,7 @@
#include "speculate_write_cache_with_rope_impl.cuh"
template <typename T, typename QKV_TYPE = int>
template <typename T, typename QKV_TYPE = int, bool EnforceFmulRN = false>
void SpeculateWriteCacheWithRoPEKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor&
@@ -157,6 +157,14 @@ bool getEnvEnablePDL() {
return enablePDL;
}
bool getEnvEnableRL() {
static std::once_flag flag;
static bool enableRL = false;
std::call_once(flag, [&]() { enableRL = getBoolEnv("FD_ENABLE_RL"); });
return enableRL;
}
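Two caveats about this flag are easy to miss: getBoolEnv (shown further down) treats only the exact string "1" as true, and std::call_once latches the first answer for the lifetime of the process. A hedged host-side sketch of those semantics, assuming this is the first query in the process; the snippet is illustrative and not part of this patch:
// Demonstrates the latching semantics of getEnvEnableRL; illustrative only.
#include <cstdlib>
bool rl_flag_demo() {
  setenv("FD_ENABLE_RL", "1", /*overwrite=*/1);  // "true"/"01" would not count
  const bool first = getEnvEnableRL();           // latched via std::call_once
  setenv("FD_ENABLE_RL", "0", /*overwrite=*/1);
  return first && getEnvEnableRL();  // still true: later changes are ignored
}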
bool getEnvDeterministicMode() {
static std::once_flag flag;
static bool value = false;
@@ -189,6 +189,17 @@ inline int GetGPUComputeCapability(int id) {
}
#endif
#ifndef DISPATCH_BOOL_DTYPE
#define DISPATCH_BOOL_DTYPE(runtime_flag, STATIC_FLAG, ...) \
if (runtime_flag) { \
constexpr bool STATIC_FLAG = true; \
__VA_ARGS__ \
} else { \
constexpr bool STATIC_FLAG = false; \
__VA_ARGS__ \
}
#endif
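DISPATCH_BOOL_DTYPE is the usual runtime-to-compile-time bridge: the body is compiled twice, once per constexpr value of STATIC_FLAG, and a single branch picks between the two instantiations at runtime. The cost is duplicated codegen for every kernel under the macro; the benefit is that EnforceFmulRN stays a template parameter with zero per-element overhead. A minimal usage sketch, with demo names that are not from this patch:
// Minimal dispatch example; demo_launch is hypothetical.
template <bool EnforceFmulRN>
void demo_launch();

void demo_dispatch(bool enforce_fmul_rn) {
  DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, {
    demo_launch<EnforceFmulRN>();  // EnforceFmulRN is constexpr in here
  })
}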
inline constexpr uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
@@ -731,7 +742,15 @@ inline bool getBoolEnv(char const *name) {
return env && env[0] == '1' && env[1] == '\0';
}
template <const bool EnforceRN = false>
__device__ __forceinline__ float fmul_func(float a, float b) {
// When EnforceRN is enabled, __fmul_rn emits a standalone round-to-nearest
// multiply that the compiler is guaranteed never to contract into an FMA
// (fmad), keeping results bitwise reproducible.
return EnforceRN ? __fmul_rn(a, b) : a * b;
}
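The intrinsic matters because nvcc contracts floating-point expressions by default: a * b - c * d may compile to a single fmaf(a, b, -(c * d)), where the product a·b is kept exact and only the final sum is rounded, so the low bits differ from separately rounding each multiply. CUDA guarantees that __fmul_rn is never merged into an fmad/fma, which makes the EnforceRN path bitwise stable regardless of how the compiler groups the expression. A side-by-side sketch with illustrative function names:
// Two ways to compile the same rotation term; illustrative only.
__device__ float rope_term_contractable(float a, float b, float c, float d) {
  return a * b - c * d;  // may become fmaf(a, b, -(c * d)) under contraction
}
__device__ float rope_term_pinned(float a, float b, float c, float d) {
  return __fmul_rn(a, b) - __fmul_rn(c, d);  // two rounded muls, then a sub
}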
bool getEnvEnablePDL();
bool getEnvEnableRL();
bool getEnvDeterministicMode();
bool getEnvDeterministicDebug();