From c92e277cf1b8613ddc2438a09da58cfaa6ee822e Mon Sep 17 00:00:00 2001 From: chen <103103266+ckl117@users.noreply.github.com> Date: Tue, 24 Mar 2026 21:19:53 +0800 Subject: [PATCH] [RL] RoPE without fmad opt (#6901) * env FD_ENABLE_RL=1 do fmul_rn(a*b) in rope --- custom_ops/gpu_ops/append_attention.cu | 287 ++++++------- .../decoder_write_cache_with_rope_impl.cuh | 332 ++++++++++----- .../decoder_write_cache_with_rope_kernel.cu | 313 ++++++++------ .../decoder_write_cache_with_rope_kernel.h | 2 +- .../encoder_write_cache_with_rope_impl.cuh | 386 ++++++++++-------- .../encoder_write_cache_with_rope_kernel.h | 10 +- .../append_attn/gqa_rope_write_cache.cu | 180 ++++---- custom_ops/gpu_ops/append_attn/qwen3_rope.h | 12 +- .../speculate_write_cache_with_rope_impl.cuh | 149 ++++--- .../speculate_write_cache_with_rope_kernel.cu | 92 +++-- .../speculate_write_cache_with_rope_kernel.h | 2 +- custom_ops/gpu_ops/helper.cu | 8 + custom_ops/gpu_ops/helper.h | 19 + 13 files changed, 1067 insertions(+), 725 deletions(-) diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 4ce8d5bdc3..c1586945cc 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -108,6 +108,7 @@ void AppendAttentionKernel( static cudaEvent_t decoder_event; static cudaStream_t decoder_stream; static bool init_flag = false; + bool enforce_fmul_rn = getEnvEnableRL(); if (max_just_dec_len_this_time > 0 && max_enc_len_this_time > 0 && !init_flag) { cudaEventCreateWithFlags(&main_event, cudaEventDisableTiming); @@ -185,37 +186,41 @@ void AppendAttentionKernel( auto dispatch_EncoderWriteCacheWithRopeKernel = [&](auto temp_args) -> void { - EncoderWriteCacheWithRopeKernel( - meta_data, - qkv, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - kv_batch_ids, - kv_tile_ids_per_batch, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - kv_signal_data, - cache_quant_type_str, - kv_num_blocks_data, - max_input_length, - use_neox_rotary_style, - rope_3d, - main_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache), - q_norm_weight, - k_norm_weight, - rms_norm_eps); + DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, { + EncoderWriteCacheWithRopeKernel( + meta_data, + qkv, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + kv_batch_ids, + kv_tile_ids_per_batch, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + kv_signal_data, + cache_quant_type_str, + kv_num_blocks_data, + max_input_length, + use_neox_rotary_style, + rope_3d, + main_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + }) }; if (qkv_out_scales) { @@ -284,117 +289,119 @@ void AppendAttentionKernel( } else { exec_stream = main_stream; } - if (speculate_decoder) { - if (qkv_out_scales) { - SpeculateWriteCacheWithRoPEKernel( - meta_data, - qkv, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache), - q_norm_weight, - k_norm_weight, - rms_norm_eps); + DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, { + if (speculate_decoder) { + if (qkv_out_scales) { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + exec_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + } else { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv_out, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + exec_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + } } else { - SpeculateWriteCacheWithRoPEKernel( - meta_data, - qkv_out, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache), - q_norm_weight, - k_norm_weight, - rms_norm_eps); + if (qkv_out_scales) { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + exec_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + } else { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv_out, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + exec_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + } } - } else { - if (qkv_out_scales) { - DecoderWriteCacheWithRoPEKernel( - meta_data, - qkv, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache), - q_norm_weight, - k_norm_weight, - rms_norm_eps); - } else { - DecoderWriteCacheWithRoPEKernel( - meta_data, - qkv_out, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache), - q_norm_weight, - k_norm_weight, - rms_norm_eps); - } - } + }) if (out_linear_in_scale > 0.0) { switch (fmha_out.dtype()) { diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh index 5d1daed91e..7dd4612c52 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh @@ -22,7 +22,11 @@ // This function is very easy! // just make HeadDim data to be new HeadDim data! -template +template __device__ __forceinline__ void apply_rope(const T* input, const float* cos_emb, const float* sin_emb, @@ -54,15 +58,17 @@ __device__ __forceinline__ void apply_rope(const T* input, const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; out_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); out_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(out_vec, &output[head_bias]); } } -template +template __global__ void append_decode_cache_T_rope_qk_norm_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -154,8 +160,10 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel( if (hi < num_heads + kv_num_heads) { const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; - float tmp1 = input_left * cos_tmp - input_right * sin_tmp; - float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; tmp_vec[2 * i] = tmp1; tmp_vec[2 * i + 1] = tmp2; @@ -206,7 +214,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel( #endif } -template +template __global__ void append_decode_cache_T_rope_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -289,9 +297,11 @@ __global__ void append_decode_cache_T_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; out_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); out_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { out_vec[2 * i] = src_vec[2 * i]; out_vec[2 * i + 1] = src_vec[2 * i + 1]; @@ -319,7 +329,7 @@ __global__ void append_decode_cache_T_rope_kernel( #endif } -template +template __global__ void append_decode_cache_T_quant_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -417,9 +427,11 @@ __global__ void append_decode_cache_T_quant_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec[2 * i] = static_cast(input_left); bias_vec[2 * i + 1] = static_cast(input_right); @@ -447,7 +459,7 @@ __global__ void append_decode_cache_T_quant_rope_kernel( #endif } -template +template __global__ void append_decode_cache_T_neox_partial_rope_kernel( const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -551,9 +563,11 @@ __global__ void append_decode_cache_T_neox_partial_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { left_bias_vec[i] = static_cast(input_left); right_bias_vec[i] = static_cast(input_right); @@ -591,7 +605,7 @@ __global__ void append_decode_cache_T_neox_partial_rope_kernel( #endif } -template +template __global__ void append_decode_cache_T_neox_rope_kernel( const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -676,9 +690,11 @@ __global__ void append_decode_cache_T_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { left_bias_vec[i] = static_cast(input_left); right_bias_vec[i] = static_cast(input_right); @@ -710,7 +726,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel( #endif } -template +template __global__ void append_decode_cache_T_quant_neox_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -815,9 +831,11 @@ __global__ void append_decode_cache_T_quant_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { left_bias_vec[i] = static_cast(input_left); right_bias_vec[i] = static_cast(input_right); @@ -855,7 +873,8 @@ template + bool IsDynamic = true, + bool EnforceFmulRN = false> __global__ void append_decode_cache_T_int8_neox_rope_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -941,8 +960,10 @@ __global__ void append_decode_cache_T_int8_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; - float tmp1 = input_left * cos_tmp - input_right * sin_tmp; - float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; out_vec[i] = static_cast(tmp1); out_vec_right[i] = static_cast(tmp2); @@ -1032,9 +1053,11 @@ __global__ void append_decode_cache_T_int8_neox_rope_kernel( float sin_tmp = sin_emb_vec1[0]; float tmp1 = 0; if (head_bias < half_head_size) { - tmp1 = input_left * cos_tmp - input_right * sin_tmp; + tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); } else { - tmp1 = input_left * cos_tmp + input_right * sin_tmp; + tmp1 = fmul_func(input_left, cos_tmp) + + fmul_func(input_right, sin_tmp); } out_vec1[0] = static_cast(tmp1); input_left = static_cast(src_vec1[1]); @@ -1042,9 +1065,11 @@ __global__ void append_decode_cache_T_int8_neox_rope_kernel( cos_tmp = cos_emb_vec1[1]; sin_tmp = sin_emb_vec1[1]; if (head_bias < half_head_size) { - tmp1 = input_left * cos_tmp - input_right * sin_tmp; + tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); } else { - tmp1 = input_left * cos_tmp + input_right * sin_tmp; + tmp1 = fmul_func(input_left, cos_tmp) + + fmul_func(input_right, sin_tmp); } out_vec1[1] = static_cast(tmp1); } else { @@ -1060,9 +1085,11 @@ __global__ void append_decode_cache_T_int8_neox_rope_kernel( float sin_tmp = sin_emb_vec2[0]; float tmp1 = 0; if (head_bias < half_head_size) { - tmp1 = input_left * cos_tmp - input_right * sin_tmp; + tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); } else { - tmp1 = input_left * cos_tmp + input_right * sin_tmp; + tmp1 = fmul_func(input_left, cos_tmp) + + fmul_func(input_right, sin_tmp); } out_vec2[0] = static_cast(tmp1); input_left = static_cast(src_vec2[1]); @@ -1070,9 +1097,11 @@ __global__ void append_decode_cache_T_int8_neox_rope_kernel( cos_tmp = cos_emb_vec2[1]; sin_tmp = sin_emb_vec2[1]; if (head_bias < half_head_size) { - tmp1 = input_left * cos_tmp - input_right * sin_tmp; + tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); } else { - tmp1 = input_left * cos_tmp + input_right * sin_tmp; + tmp1 = fmul_func(input_left, cos_tmp) + + fmul_func(input_right, sin_tmp); } out_vec2[1] = static_cast(tmp1); } else { @@ -1164,7 +1193,8 @@ template + bool IsDynamic = true, + bool EnforceFmulRN = false> __global__ void append_decode_cache_int8_rope_qk_norm_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -1250,8 +1280,10 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; - float tmp1 = input_left * cos_tmp - input_right * sin_tmp; - float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; out_vec[2 * i] = static_cast(tmp1); out_vec[2 * i + 1] = static_cast(tmp2); @@ -1344,8 +1376,10 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel( if (head_idx < num_heads + kv_num_heads) { float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; - float tmp1 = input_left * cos_tmp - input_right * sin_tmp; - float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; out_vec1[0] = static_cast(tmp1); out_vec1[1] = static_cast(tmp2); @@ -1360,8 +1394,10 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel( if (head_idx < num_heads + kv_num_heads) { float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; - float tmp1 = input_left * cos_tmp - input_right * sin_tmp; - float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; out_vec2[0] = static_cast(tmp1); out_vec2[1] = static_cast(tmp2); @@ -1472,7 +1508,8 @@ template + bool IsFP8 = false, + bool EnforceFmulRN = false> __global__ void append_decode_cache_int8_rope_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -1528,11 +1565,11 @@ __global__ void append_decode_cache_int8_rope_kernel( uint32_t emb_offset = write_seq_id * half_head_size; emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0; - apply_rope(qkv_now, - cos_emb + emb_offset, - sin_emb + emb_offset, - qkv_out_now, - lane_id); + apply_rope(qkv_now, + cos_emb + emb_offset, + sin_emb + emb_offset, + qkv_out_now, + lane_id); } else if (head_idx < num_heads + 2 * kv_num_heads) { // k @@ -1619,9 +1656,11 @@ __global__ void append_decode_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; out_vec1[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); out_vec1[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { out_vec1[0] = src_vec1[0]; out_vec1[1] = src_vec1[1]; @@ -1631,10 +1670,12 @@ __global__ void append_decode_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; out_vec1[0] = - static_cast((input_left * cos_tmp - input_right * sin_tmp) * + static_cast((fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)) * float(cache_k_scale_cur[0])); out_vec1[1] = - static_cast((input_right * cos_tmp + input_left * sin_tmp) * + static_cast((fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)) * float(cache_k_scale_cur[1])); } else { out_vec1[0] = static_cast(input_left * float(cache_v_scale_cur[0])); @@ -1649,9 +1690,11 @@ __global__ void append_decode_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; out_vec2[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); out_vec2[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { out_vec2[0] = src_vec2[0]; out_vec2[1] = src_vec2[1]; @@ -1661,10 +1704,12 @@ __global__ void append_decode_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; out_vec2[0] = - static_cast((input_left * cos_tmp - input_right * sin_tmp) * + static_cast((fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)) * float(cache_k_scale_cur[8])); out_vec2[1] = - static_cast((input_right * cos_tmp + input_left * sin_tmp) * + static_cast((fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)) * float(cache_k_scale_cur[9])); } else { out_vec2[0] = static_cast(input_left * float(cache_v_scale_cur[8])); @@ -1715,7 +1760,8 @@ template + bool IsFP8 = false, + bool EnforceFmulRN = false> __global__ void int_append_decode_cache_int8_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -1815,9 +1861,11 @@ __global__ void int_append_decode_cache_int8_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(bias_vec, &qkv_out_now[bias_idx]); } @@ -1922,9 +1970,11 @@ __global__ void int_append_decode_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; bias_vec1[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec1[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec1[0] = static_cast(input_left); bias_vec1[1] = static_cast(input_right); @@ -1934,10 +1984,12 @@ __global__ void int_append_decode_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; bias_vec1[0] = - static_cast((input_left * cos_tmp - input_right * sin_tmp) * + static_cast((fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)) * float(cache_k_scale_cur[0])); bias_vec1[1] = - static_cast((input_right * cos_tmp + input_left * sin_tmp) * + static_cast((fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)) * float(cache_k_scale_cur[1])); } else { bias_vec1[0] = static_cast(input_left * float(cache_v_scale_cur[0])); @@ -1959,9 +2011,11 @@ __global__ void int_append_decode_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; bias_vec2[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec2[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec2[0] = static_cast(input_left); bias_vec2[1] = static_cast(input_right); @@ -1971,10 +2025,12 @@ __global__ void int_append_decode_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; bias_vec2[0] = - static_cast((input_left * cos_tmp - input_right * sin_tmp) * + static_cast((fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)) * float(cache_k_scale_cur[8])); bias_vec2[1] = - static_cast((input_right * cos_tmp + input_left * sin_tmp) * + static_cast((fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)) * float(cache_k_scale_cur[9])); } else { bias_vec2[0] = static_cast(input_left * float(cache_v_scale_cur[8])); @@ -2033,7 +2089,11 @@ __global__ void int_append_decode_cache_int8_rope_kernel( #endif } -template +template __global__ void append_decode_cache_int8_neox_rope_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -2120,9 +2180,11 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_bias_vec, &qkv_out_now[bias_idx_left]); Store(right_bias_vec, &qkv_out_now[bias_idx_right]); @@ -2206,18 +2268,22 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( float cos_tmp = cos_emb_vec1[i]; float sin_tmp = sin_emb_vec1[i]; left_bias_vec1[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec1[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); input_left = static_cast(left_src_vec2[i]); input_right = static_cast(right_src_vec2[i]); cos_tmp = cos_emb_vec2[i]; sin_tmp = sin_emb_vec2[i]; left_bias_vec2[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec2[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); float quant_value1 = static_cast(scale * left_bias_vec1[i]); float quant_value2 = static_cast(scale * left_bias_vec2[i]); @@ -2341,7 +2407,11 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( #endif } -template +template __global__ void int_append_decode_cache_int8_neox_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -2451,9 +2521,11 @@ __global__ void int_append_decode_cache_int8_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_bias_vec, &qkv_out_now[bias_idx_left]); Store(right_bias_vec, &qkv_out_now[bias_idx_right]); @@ -2564,9 +2636,11 @@ __global__ void int_append_decode_cache_int8_neox_rope_kernel( float cos_tmp = cos_emb_vec1[i]; float sin_tmp = sin_emb_vec1[i]; left_bias_vec1[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec1[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); input_left = static_cast(left_src_vec2[i]); input_right = static_cast(right_src_vec2[i]); @@ -2579,9 +2653,11 @@ __global__ void int_append_decode_cache_int8_neox_rope_kernel( cos_tmp = cos_emb_vec2[i]; sin_tmp = sin_emb_vec2[i]; left_bias_vec2[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec2[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); float quant_value1 = static_cast(scale * left_bias_vec1[i]); float quant_value2 = static_cast(scale * left_bias_vec2[i]); @@ -2743,7 +2819,11 @@ __global__ void int_append_decode_cache_int8_neox_rope_kernel( #endif } -template +template __global__ void append_decode_cache_int4_rope_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -2803,11 +2883,11 @@ __global__ void append_decode_cache_int4_rope_kernel( uint32_t emb_offset = write_seq_id * half_head_size; emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0; - apply_rope(qkv_now, - cos_emb + emb_offset, - sin_emb + emb_offset, - qkv_out_now, - lane_id); + apply_rope(qkv_now, + cos_emb + emb_offset, + sin_emb + emb_offset, + qkv_out_now, + lane_id); } else if (head_idx < num_heads + 2 * kv_num_heads) { // k @@ -2890,9 +2970,11 @@ __global__ void append_decode_cache_int4_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; out_vec1[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); out_vec1[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { out_vec1[0] = src_vec1[0]; out_vec1[1] = src_vec1[1]; @@ -2904,9 +2986,11 @@ __global__ void append_decode_cache_int4_rope_kernel( float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; out_vec2[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); out_vec2[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { out_vec2[0] = src_vec2[0]; out_vec2[1] = src_vec2[1]; @@ -3022,7 +3106,11 @@ __global__ void append_decode_cache_int4_rope_kernel( #endif } -template +template __global__ void int_append_decode_cache_int4_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -3122,9 +3210,11 @@ __global__ void int_append_decode_cache_int4_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(bias_vec, &qkv_out_now[bias_idx]); } @@ -3223,9 +3313,11 @@ __global__ void int_append_decode_cache_int4_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; bias_vec1[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec1[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec1[0] = static_cast(input_left); bias_vec1[1] = static_cast(input_right); @@ -3243,9 +3335,11 @@ __global__ void int_append_decode_cache_int4_rope_kernel( float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; bias_vec2[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec2[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec2[0] = static_cast(input_left); bias_vec2[1] = static_cast(input_right); @@ -3360,7 +3454,11 @@ __global__ void int_append_decode_cache_int4_rope_kernel( #endif } -template +template __global__ void append_decode_cache_int4_neox_rope_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -3449,9 +3547,11 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_out_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_out_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_out_vec, &qkv_out_now[bias_idx_left]); Store(right_out_vec, &qkv_out_now[bias_idx_right]); @@ -3549,18 +3649,22 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; left_out_vec1[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_out_vec1[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); input_left = static_cast(left_src_vec2[i]); input_right = static_cast(right_src_vec2[i]); cos_tmp = cos_emb_vec2[i]; sin_tmp = sin_emb_vec2[i]; left_out_vec2[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_out_vec2[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); // quant + write k } LoadKVResT left_cache_vec, right_cache_vec; @@ -3739,7 +3843,11 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( #endif } -template +template __global__ void int_append_decode_cache_int4_neox_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -3848,9 +3956,11 @@ __global__ void int_append_decode_cache_int4_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_bias_vec, &qkv_out_now[bias_idx_left]); Store(right_bias_vec, &qkv_out_now[bias_idx_right]); @@ -3977,18 +4087,22 @@ __global__ void int_append_decode_cache_int4_neox_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; left_bias_vec1[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec1[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); input_left = static_cast(left_src_vec2[i]); input_right = static_cast(right_src_vec2[i]); cos_tmp = cos_emb_vec2[i]; sin_tmp = sin_emb_vec2[i]; left_bias_vec2[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec2[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); // quant + write k } diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu index 264299e0a6..c16af564fe 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu @@ -15,7 +15,7 @@ #include "decoder_write_cache_with_rope_kernel.h" #include "utils.cuh" -template +template void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv, T* key_cache, T* value_cache, @@ -53,7 +53,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv, GetNumBlocks<128>(pack_num, &grid_size); dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1); launchWithPdlWhenEnabled( - append_decode_cache_T_rope_qk_norm_kernel, + append_decode_cache_T_rope_qk_norm_kernel, grid_size, block_dim, 0, @@ -81,7 +81,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv, rms_norm_eps); } -template +template void append_decode_cache_rope(const QKV_TYPE* qkv, T* key_cache, T* value_cache, @@ -117,7 +117,9 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, if (use_neox_style) { if (qkv_out_scales) { launchWithPdlWhenEnabled( - append_decode_cache_T_quant_neox_rope_kernel, + append_decode_cache_T_quant_neox_rope_kernel, grid_size, blocksize, 0, @@ -145,7 +147,9 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, } else { if (rotary_dim < dim_head) { auto* kernelFn = - append_decode_cache_T_neox_partial_rope_kernel; + append_decode_cache_T_neox_partial_rope_kernel; launchWithPdlWhenEnabled(kernelFn, grid_size, blocksize, @@ -171,7 +175,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, kv_num_heads, rope_3d); } else { - auto* kernelFn = append_decode_cache_T_neox_rope_kernel; + auto* kernelFn = + append_decode_cache_T_neox_rope_kernel; launchWithPdlWhenEnabled(kernelFn, grid_size, blocksize, @@ -200,7 +205,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, } else { if (qkv_out_scales) { launchWithPdlWhenEnabled( - append_decode_cache_T_quant_rope_kernel, + append_decode_cache_T_quant_rope_kernel, grid_size, blocksize, 0, @@ -226,7 +231,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, kv_num_heads, rope_3d); } else { - auto* kernelFn = append_decode_cache_T_rope_kernel; + auto* kernelFn = + append_decode_cache_T_rope_kernel; launchWithPdlWhenEnabled(kernelFn, grid_size, blocksize, @@ -257,7 +263,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, template + bool IsFP8 = false, + bool EnforceFmulRN = false> void append_decode_cache_int8_rope(const QKV_TYPE* qkv, uint8_t* key_cache, uint8_t* value_cache, @@ -289,7 +296,11 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, if (use_neox_style) { if (qkv_out_scales) { launchWithPdlWhenEnabled( - int_append_decode_cache_int8_neox_rope_kernel, + int_append_decode_cache_int8_neox_rope_kernel, grids, num_warps * 32, 0, @@ -317,31 +328,36 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, kv_num_heads, rope_3d); } else { - launchWithPdlWhenEnabled(append_decode_cache_int8_neox_rope_kernel, - grids, - num_warps * 32, - 0, - stream, - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - cache_k_scale, - cache_v_scale, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 127.0f, - -127.0f, - kv_num_heads, - rope_3d); + launchWithPdlWhenEnabled( + append_decode_cache_int8_neox_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d); } } else { if (qkv_out_scales) { @@ -351,7 +367,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, 0, 128, is_scale_channel_wise, - IsFP8>, + IsFP8, + EnforceFmulRN>, grids, num_warps * 32, 0, @@ -385,7 +402,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, 0, 128, is_scale_channel_wise, - IsFP8>, + IsFP8, + EnforceFmulRN>, grids, num_warps * 32, 0, @@ -414,7 +432,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, } } -template +template void append_decode_cache_int4_rope(const QKV_TYPE* qkv, uint8_t* key_cache, uint8_t* value_cache, @@ -448,7 +466,11 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv, if (use_neox_style) { if (qkv_out_scales) { launchWithPdlWhenEnabled( - int_append_decode_cache_int4_neox_rope_kernel, + int_append_decode_cache_int4_neox_rope_kernel, grids, num_warps * 32, 0, @@ -478,97 +500,104 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv, kv_num_heads, rope_3d); } else { - launchWithPdlWhenEnabled(append_decode_cache_int4_neox_rope_kernel, - grids, - num_warps * 32, - 0, - stream, - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 7.0f, - -8.0f, - kv_num_heads, - rope_3d); + launchWithPdlWhenEnabled( + append_decode_cache_int4_neox_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 7.0f, + -8.0f, + kv_num_heads, + rope_3d); } } else { if (qkv_out_scales) { - launchWithPdlWhenEnabled(int_append_decode_cache_int4_rope_kernel, - grids, - num_warps * 32, - 0, - stream, - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - qkv_out_scales, - qkv_biases, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 7.0f, - -8.0f, - kv_num_heads, - rope_3d); + launchWithPdlWhenEnabled( + int_append_decode_cache_int4_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 7.0f, + -8.0f, + kv_num_heads, + rope_3d); } else { - launchWithPdlWhenEnabled(append_decode_cache_int4_rope_kernel, - grids, - num_warps * 32, - 0, - stream, - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 7.0f, - -8.0f, - kv_num_heads, - rope_3d); + launchWithPdlWhenEnabled( + append_decode_cache_int4_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 7.0f, + -8.0f, + kv_num_heads, + rope_3d); } } } -template +template void DecoderWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, @@ -632,7 +661,7 @@ void DecoderWriteCacheWithRoPEKernel( if (q_norm_weight && k_norm_weight) { if (cache_quant_type_str == "none") { - append_decode_cache_rope_qk_norm( + append_decode_cache_rope_qk_norm( reinterpret_cast(qkv_ptr), reinterpret_cast(key_cache_out->data()), reinterpret_cast(value_cache_out->data()), @@ -672,7 +701,8 @@ void DecoderWriteCacheWithRoPEKernel( 128, false, true, - true>, + true, + EnforceFmulRN>, grids, num_warps * 32, 0, @@ -714,7 +744,8 @@ void DecoderWriteCacheWithRoPEKernel( 128, false, true, - false>, + false, + EnforceFmulRN>, grids, num_warps * 32, 0, @@ -751,7 +782,7 @@ void DecoderWriteCacheWithRoPEKernel( } } else { if (cache_quant_type_str == "none") { - append_decode_cache_rope( + append_decode_cache_rope( reinterpret_cast(qkv_ptr), reinterpret_cast(key_cache_out->data()), reinterpret_cast(value_cache_out->data()), @@ -784,7 +815,11 @@ void DecoderWriteCacheWithRoPEKernel( is_scale_channel_wise = true; } if (is_scale_channel_wise) { - append_decode_cache_int8_rope( + append_decode_cache_int8_rope( reinterpret_cast(qkv_ptr), key_cache_out->data(), value_cache_out->data(), @@ -816,7 +851,11 @@ void DecoderWriteCacheWithRoPEKernel( use_neox_rotary_style, rope_3d); } else { - append_decode_cache_int8_rope( + append_decode_cache_int8_rope( reinterpret_cast(qkv_ptr), key_cache_out->data(), value_cache_out->data(), @@ -849,7 +888,11 @@ void DecoderWriteCacheWithRoPEKernel( rope_3d); } } else if (cache_quant_type_str == "cache_fp8") { - append_decode_cache_int8_rope( + append_decode_cache_int8_rope( reinterpret_cast(qkv_ptr), key_cache_out->data(), value_cache_out->data(), @@ -892,7 +935,9 @@ void DecoderWriteCacheWithRoPEKernel( 0, 128, false, - true>, + true, + true, + EnforceFmulRN>, grids, num_warps * 32, 0, @@ -927,7 +972,9 @@ void DecoderWriteCacheWithRoPEKernel( 0, 128, false, - true>, + true, + true, + EnforceFmulRN>, grids, num_warps * 32, 0, @@ -959,7 +1006,7 @@ void DecoderWriteCacheWithRoPEKernel( rms_norm_eps); } } else if (cache_quant_type_str == "cache_int4_zp") { - append_decode_cache_int4_rope( + append_decode_cache_int4_rope( reinterpret_cast(qkv_ptr), key_cache_out->data(), value_cache_out->data(), diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h index 5095ef1062..2acb4f8293 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h @@ -15,7 +15,7 @@ #include "decoder_write_cache_with_rope_impl.cuh" -template +template void DecoderWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh index 33580edffd..0cdea53732 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh @@ -18,7 +18,7 @@ #include "mma_tensor_op.cuh" #include "utils.cuh" -template +template __global__ void IntVariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] @@ -97,9 +97,11 @@ __global__ void IntVariableLengthRotaryKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec[2 * i] = static_cast(input_left); bias_vec[2 * i + 1] = static_cast(input_right); @@ -112,7 +114,7 @@ __global__ void IntVariableLengthRotaryKernel( #endif } -template +template __global__ void VariableLengthRotaryKernel( const T *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] @@ -171,9 +173,11 @@ __global__ void VariableLengthRotaryKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; src_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); src_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(src_vec, &qkv_out[base_idx]); } @@ -182,7 +186,7 @@ __global__ void VariableLengthRotaryKernel( #endif } -template +template __global__ void IntNeoxVariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] @@ -271,9 +275,11 @@ __global__ void IntNeoxVariableLengthRotaryKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { left_bias_vec[i] = static_cast(input_left); right_bias_vec[i] = static_cast(input_right); @@ -287,7 +293,7 @@ __global__ void IntNeoxVariableLengthRotaryKernel( #endif } -template +template __global__ void NeoxVariableLengthRotaryKernel( const T *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] @@ -352,9 +358,11 @@ __global__ void NeoxVariableLengthRotaryKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_vec, &qkv_out[base_idx_left]); Store(right_vec, &qkv_out[base_idx_right]); @@ -364,7 +372,7 @@ __global__ void NeoxVariableLengthRotaryKernel( #endif } -template +template __global__ void IntGQAVariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] @@ -442,9 +450,11 @@ __global__ void IntGQAVariableLengthRotaryKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec[2 * i] = static_cast(input_left); bias_vec[2 * i + 1] = static_cast(input_right); @@ -457,7 +467,7 @@ __global__ void IntGQAVariableLengthRotaryKernel( #endif } -template +template __global__ void GQAVariableLengthRotaryQKNormKernel( const T *qkv, const float *cos_emb, @@ -525,8 +535,10 @@ __global__ void GQAVariableLengthRotaryQKNormKernel( const float input_right = static_cast(src_vec[2 * i + 1]); const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; - float tmp1 = input_left * cos_tmp - input_right * sin_tmp; - float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); tmp_vec[2 * i] = tmp1; tmp_vec[2 * i + 1] = tmp2; thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; @@ -554,7 +566,7 @@ __global__ void GQAVariableLengthRotaryQKNormKernel( #endif } -template +template __global__ void GQAVariableLengthRotaryKernel(const T *qkv, const float *cos_emb, const float *sin_emb, @@ -614,9 +626,11 @@ __global__ void GQAVariableLengthRotaryKernel(const T *qkv, const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; src_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); src_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(src_vec, &qkv_out[base_idx]); } @@ -625,7 +639,7 @@ __global__ void GQAVariableLengthRotaryKernel(const T *qkv, #endif } -template +template __global__ void IntGQAVariableLengthRotaryQuantKVKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] @@ -703,19 +717,23 @@ __global__ void IntGQAVariableLengthRotaryQuantKVKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else if (hi < q_num_head + kv_num_head) { int k_hi = hi - q_num_head; const int scale_idx = k_hi * last_dim + h_bias; const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast((input_left * cos_tmp - input_right * sin_tmp) * + static_cast((fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)) * float(cache_k_scales[scale_idx + 2 * i])); bias_vec[2 * i + 1] = - static_cast((input_right * cos_tmp + input_left * sin_tmp) * + static_cast((fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)) * float(cache_k_scales[scale_idx + 2 * i + 1])); } else { int v_hi = hi - q_num_head - kv_num_head; @@ -733,7 +751,7 @@ __global__ void IntGQAVariableLengthRotaryQuantKVKernel( #endif } -template +template __global__ void GQAVariableLengthRotaryQuantKVKernel( const T *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] @@ -803,26 +821,31 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel( : static_cast(src_vec[2 * i + 1]); // const float cos_tmp = cos_emb_vec[i]; // const float sin_tmp = sin_emb_vec[i]; - // src_vec[2 * i] = static_cast(input_left * cos_tmp - input_right * - // sin_tmp); src_vec[2 * i + 1] = static_cast(input_right * cos_tmp + - // input_left * sin_tmp); + // src_vec[2 * i] = static_cast(fmul_func(input_left, + // cos_tmp) - input_right * sin_tmp); src_vec[2 * i + 1] = + // static_cast(fmul_func(input_right, cos_tmp) + + // fmul_func(input_left, sin_tmp)); if (hi < q_num_head) { // qk rope const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; src_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); src_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else if (hi < q_num_head + kv_num_head) { int k_hi = hi - q_num_head; const int scale_idx = k_hi * last_dim + h_bias; const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; src_vec[2 * i] = - static_cast((input_left * cos_tmp - input_right * sin_tmp) * + static_cast((fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)) * float(cache_k_scales[scale_idx + 2 * i])); src_vec[2 * i + 1] = - static_cast((input_right * cos_tmp + input_left * sin_tmp) * + static_cast((fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)) * float(cache_k_scales[scale_idx + 2 * i + 1])); } else { int v_hi = hi - q_num_head - kv_num_head; @@ -840,7 +863,7 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel( #endif } -template +template __global__ void IntGQANeoxVariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] @@ -926,9 +949,11 @@ __global__ void IntGQANeoxVariableLengthRotaryKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { left_bias_vec[i] = static_cast(input_left); right_bias_vec[i] = static_cast(input_right); @@ -942,7 +967,7 @@ __global__ void IntGQANeoxVariableLengthRotaryKernel( #endif } -template +template __global__ void GQANeoxVariableLengthRotaryKernel(const T *qkv, const float *cos_emb, const float *sin_emb, @@ -1005,9 +1030,11 @@ __global__ void GQANeoxVariableLengthRotaryKernel(const T *qkv, const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_vec, &qkv_out[base_idx_left]); Store(right_vec, &qkv_out[base_idx_right]); @@ -1017,7 +1044,7 @@ __global__ void GQANeoxVariableLengthRotaryKernel(const T *qkv, #endif } -template +template __global__ void GQANeoxVariableLengthPartialRotaryKernel( const T *qkv, const float *cos_emb, @@ -1082,9 +1109,11 @@ __global__ void GQANeoxVariableLengthPartialRotaryKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_vec, &qkv_out[base_idx_left]); Store(right_vec, &qkv_out[base_idx_right]); @@ -1094,7 +1123,7 @@ __global__ void GQANeoxVariableLengthPartialRotaryKernel( #endif } -template +template __global__ void cache_kernel( const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * kv_num_heads, // head_size] @@ -2199,7 +2228,7 @@ __global__ void append_write_cache_kv_c4_qkv( #endif } -template +template void rotary_qk_variable( T *qkv_out, // [token_num, 3, num_head, dim_head] const QKV_TYPE *qkv_input, // qkv @@ -2233,94 +2262,98 @@ void rotary_qk_variable( const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; if (qkv_out_scales) { - launchWithPdlWhenEnabled(IntVariableLengthRotaryKernel, - grid_size, - blocksize, - 0, - stream, - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out_scales, - qkv_bias, - qkv_out, - elem_nums, - head_num, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled( + IntVariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head, + rope_3d); } else { - launchWithPdlWhenEnabled(VariableLengthRotaryKernel, - grid_size, - blocksize, - 0, - stream, - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out, - elem_nums, - head_num, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled( + VariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head, + rope_3d); } } else { const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * dim_head; if (qkv_out_scales) { - launchWithPdlWhenEnabled(IntNeoxVariableLengthRotaryKernel, - grid_size, - blocksize, - 0, - stream, - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out_scales, - qkv_bias, - qkv_out, - elem_nums, - head_num, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled( + IntNeoxVariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head, + rope_3d); } else { - launchWithPdlWhenEnabled(NeoxVariableLengthRotaryKernel, - grid_size, - blocksize, - 0, - stream, - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out, - elem_nums, - head_num, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled( + NeoxVariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head, + rope_3d); } } } -template +template void gqa_rotary_qk_norm_variable( T *qkv_out, // [token_num, 3, num_head, dim_head] const QKV_TYPE *qkv_input, // qkv @@ -2362,31 +2395,32 @@ void gqa_rotary_qk_norm_variable( const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; - launchWithPdlWhenEnabled(GQAVariableLengthRotaryQKNormKernel, - grid_size, - Block_Size, - 0, - stream, - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head, - rope_3d, - q_norm_weight, - k_norm_weight, - rms_norm_eps); + launchWithPdlWhenEnabled( + GQAVariableLengthRotaryQKNormKernel, + grid_size, + Block_Size, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d, + q_norm_weight, + k_norm_weight, + rms_norm_eps); } -template +template void gqa_rotary_qk_variable( T *qkv_out, // [token_num, 3, num_head, dim_head] const QKV_TYPE *qkv_input, // qkv @@ -2425,29 +2459,31 @@ void gqa_rotary_qk_variable( const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; if (qkv_out_scales) { - launchWithPdlWhenEnabled(IntGQAVariableLengthRotaryKernel, - grid_size, - blocksize, - 0, - stream, - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out_scales, - qkv_bias, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled( + IntGQAVariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d); } else { - auto *kernelFn = GQAVariableLengthRotaryKernel; + auto *kernelFn = + GQAVariableLengthRotaryKernel; launchWithPdlWhenEnabled(kernelFn, grid_size, blocksize, @@ -2473,7 +2509,7 @@ void gqa_rotary_qk_variable( const float *sin_emb = rotary_emb + input_output_len * dim_head; if (qkv_out_scales) { launchWithPdlWhenEnabled( - IntGQANeoxVariableLengthRotaryKernel, + IntGQANeoxVariableLengthRotaryKernel, grid_size, blocksize, 0, @@ -2507,7 +2543,10 @@ void gqa_rotary_qk_variable( } const int pack_num_new = elem_nums / PackSize; GetNumBlocks<128>(pack_num_new, &grid_size); - auto *kernelFn = GQANeoxVariableLengthPartialRotaryKernel; + auto *kernelFn = + GQANeoxVariableLengthPartialRotaryKernel; launchWithPdlWhenEnabled(kernelFn, grid_size, blocksize, @@ -2531,7 +2570,8 @@ void gqa_rotary_qk_variable( rotary_dim, rope_3d); } else { - auto *kernelFn = GQANeoxVariableLengthRotaryKernel; + auto *kernelFn = + GQANeoxVariableLengthRotaryKernel; launchWithPdlWhenEnabled(kernelFn, grid_size, blocksize, @@ -2558,7 +2598,7 @@ void gqa_rotary_qk_variable( } } -template +template void gqa_rotary_qk_quant_variable( T *qkv_out, // [token_num, 3, num_head, dim_head] const QKV_TYPE *qkv_input, // qkv @@ -2595,7 +2635,7 @@ void gqa_rotary_qk_quant_variable( if (!use_neox_style) { if (qkv_out_scales) { launchWithPdlWhenEnabled( - IntGQAVariableLengthRotaryQuantKVKernel, + IntGQAVariableLengthRotaryQuantKVKernel, grid_size, blocksize, 0, @@ -2620,7 +2660,7 @@ void gqa_rotary_qk_quant_variable( rope_3d); } else { launchWithPdlWhenEnabled( - GQAVariableLengthRotaryQuantKVKernel, + GQAVariableLengthRotaryQuantKVKernel, grid_size, blocksize, 0, diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h index 2349a0e5e1..23969aa429 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h @@ -16,7 +16,7 @@ #include "encoder_write_cache_with_rope_impl.cuh" #include "remote_cache_kv_ipc.h" -template +template void EncoderWriteCacheWithRopeKernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& @@ -77,7 +77,7 @@ void EncoderWriteCacheWithRopeKernel( if (q_norm_weight && k_norm_weight) { if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) { - gqa_rotary_qk_norm_variable( + gqa_rotary_qk_norm_variable( qkv_out->data(), qkv.data(), qkv_out_scales ? qkv_out_scales.get().data() : nullptr, @@ -106,7 +106,7 @@ void EncoderWriteCacheWithRopeKernel( } } else { if (num_heads == kv_num_heads) { - rotary_qk_variable( + rotary_qk_variable( qkv_out->data(), qkv.data(), qkv_out_scales ? qkv_out_scales.get().data() : nullptr, @@ -126,7 +126,7 @@ void EncoderWriteCacheWithRopeKernel( rope_3d); } else { if (!is_scale_channel_wise) { - gqa_rotary_qk_variable( + gqa_rotary_qk_variable( qkv_out->data(), qkv.data(), qkv_out_scales ? qkv_out_scales.get().data() : nullptr, @@ -147,7 +147,7 @@ void EncoderWriteCacheWithRopeKernel( use_neox_style, rope_3d); } else { - gqa_rotary_qk_quant_variable( + gqa_rotary_qk_quant_variable( qkv_out->data(), qkv.data(), qkv_out_scales ? qkv_out_scales.get().data() : nullptr, diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu index cf4ee1e263..e39522159f 100644 --- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu +++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu @@ -20,7 +20,7 @@ #include "qwen3_rope.h" #include "remote_cache_kv_ipc.h" -template +template __global__ void GQAVariableLengthRotarySplitKernel( const T *qkv, const float *cos_emb, @@ -117,8 +117,10 @@ __global__ void GQAVariableLengthRotarySplitKernel( const float input_right = static_cast(src_vec[2 * i + 1]); const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; - float tmp1 = input_left * cos_tmp - input_right * sin_tmp; - float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); tmp_vec[2 * i] = tmp1; tmp_vec[2 * i + 1] = tmp2; thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; @@ -157,9 +159,11 @@ __global__ void GQAVariableLengthRotarySplitKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; src_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); src_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } } } @@ -168,7 +172,7 @@ __global__ void GQAVariableLengthRotarySplitKernel( } } -template +template void gqa_rotary_qk_split_variable( T *qkv_out, // [token_num, 3, num_head, head_dim] T *q, @@ -205,35 +209,36 @@ void gqa_rotary_qk_split_variable( const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * head_dim / 2; - launchWithPdlWhenEnabled(GQAVariableLengthRotarySplitKernel, - grid_size, - block_size, - 0, - stream, - qkv_input, - cos_emb, - sin_emb, - q_norm_weight, - k_norm_weight, - batch_id_per_token, - cu_seqlens_q, - seq_lens_encoder, - seq_lens_decoder, - cu_seqlens_k, - qkv_out, - q, - k, - v, - elem_nums, - num_heads, - kv_num_heads, - max_model_len, - head_dim, - rope_3d, - rms_norm_eps); + launchWithPdlWhenEnabled( + GQAVariableLengthRotarySplitKernel, + grid_size, + block_size, + 0, + stream, + qkv_input, + cos_emb, + sin_emb, + q_norm_weight, + k_norm_weight, + batch_id_per_token, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + cu_seqlens_k, + qkv_out, + q, + k, + v, + elem_nums, + num_heads, + kv_num_heads, + max_model_len, + head_dim, + rope_3d, + rms_norm_eps); } -template +template __global__ void GQAVariableLengthNeoxPartialRotarySplitKernel( const T *qkv, const float *cos_emb, @@ -329,10 +334,12 @@ __global__ void GQAVariableLengthNeoxPartialRotarySplitKernel( const float sin_tmp = sin_emb_vec[i]; if (h_bias < half_rotary_dim) { src_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); } else { src_vec[i] = - static_cast(input_left * cos_tmp + input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) + + fmul_func(input_right, sin_tmp)); } } } @@ -346,7 +353,7 @@ __global__ void GQAVariableLengthNeoxPartialRotarySplitKernel( #endif } -template +template void gqa_neox_partial_rotary_qk_split_variable( T *qkv_out, // [token_num, 3, num_head, head_dim] T *q, @@ -381,7 +388,7 @@ void gqa_neox_partial_rotary_qk_split_variable( const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + max_model_len * rotary_dim / 2; launchWithPdlWhenEnabled( - GQAVariableLengthNeoxPartialRotarySplitKernel, + GQAVariableLengthNeoxPartialRotarySplitKernel, grid_size, block_size, 0, @@ -1376,35 +1383,60 @@ std::vector GQARopeWriteCacheKernel( paddle::Tensor v = GetEmptyTensor( {kv_token_num, kv_num_heads, head_dim}, qkv.dtype(), qkv.place()); - if (use_neox_rotary_style) { - if (rotary_dim == head_dim) { - gqa_rotary_qk_split_variable_qwen3( - qkv_out.data(), - q.data(), - k.data(), - v.data(), - qkv.data(), - rotary_embs.data(), - batch_id_per_token.data(), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - token_num, - num_heads, - kv_num_heads, - rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2], - head_dim, - rope_3d, - stream); + bool enforce_fmul_rn = getEnvEnableRL(); + DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, { + if (use_neox_rotary_style) { + if (rotary_dim == head_dim) { + gqa_rotary_qk_split_variable_qwen3( + qkv_out.data(), + q.data(), + k.data(), + v.data(), + qkv.data(), + rotary_embs.data(), + batch_id_per_token.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + token_num, + num_heads, + kv_num_heads, + rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2], + head_dim, + rope_3d, + stream); + } else { + gqa_neox_partial_rotary_qk_split_variable( + qkv_out.data(), + q.data(), + k.data(), + v.data(), + qkv.data(), + rotary_embs.data(), + batch_id_per_token.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + token_num, + num_heads, + kv_num_heads, + max_seq_len, + head_dim, + rotary_dim, + stream); + } } else { - gqa_neox_partial_rotary_qk_split_variable( + gqa_rotary_qk_split_variable( qkv_out.data(), q.data(), k.data(), v.data(), qkv.data(), rotary_embs.data(), + q_norm_weight ? q_norm_weight.get().data() : nullptr, + k_norm_weight ? k_norm_weight.get().data() : nullptr, batch_id_per_token.data(), seq_lens_encoder.data(), seq_lens_decoder.data(), @@ -1414,35 +1446,13 @@ std::vector GQARopeWriteCacheKernel( num_heads, kv_num_heads, max_seq_len, + rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2], head_dim, - rotary_dim, + rope_3d, + rms_norm_eps, stream); } - } else { - gqa_rotary_qk_split_variable( - qkv_out.data(), - q.data(), - k.data(), - v.data(), - qkv.data(), - rotary_embs.data(), - q_norm_weight ? q_norm_weight.get().data() : nullptr, - k_norm_weight ? k_norm_weight.get().data() : nullptr, - batch_id_per_token.data(), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - token_num, - num_heads, - kv_num_heads, - max_seq_len, - rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2], - head_dim, - rope_3d, - rms_norm_eps, - stream); - } + }) if (token_num < kv_token_num) { AppendCacheKV(key_cache, diff --git a/custom_ops/gpu_ops/append_attn/qwen3_rope.h b/custom_ops/gpu_ops/append_attn/qwen3_rope.h index 6c6325c335..42017e4215 100644 --- a/custom_ops/gpu_ops/append_attn/qwen3_rope.h +++ b/custom_ops/gpu_ops/append_attn/qwen3_rope.h @@ -5,7 +5,7 @@ #include "paddle/phi/core/memory/memcpy.h" #include "remote_cache_kv_ipc.h" -template +template __global__ void GQAVariableLengthRotarySplitKernel_Qwen3( const T *qkv, const float *cos_emb, @@ -99,9 +99,11 @@ __global__ void GQAVariableLengthRotarySplitKernel_Qwen3( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; src_vec0[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); src_vec1[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } } Store(src_vec0, &qkv_out[read_idx]); @@ -111,7 +113,7 @@ __global__ void GQAVariableLengthRotarySplitKernel_Qwen3( } } -template +template void gqa_rotary_qk_split_variable_qwen3(T *qkv_out, T *q, T *k, @@ -145,7 +147,7 @@ void gqa_rotary_qk_split_variable_qwen3(T *qkv_out, const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + max_model_len * head_dim; launchWithPdlWhenEnabled( - GQAVariableLengthRotarySplitKernel_Qwen3, + GQAVariableLengthRotarySplitKernel_Qwen3, grid_size, block_size, 0, diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh index c321107237..9e63cf4e35 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh @@ -18,7 +18,10 @@ #include "mma_tensor_op.cuh" #include "utils.cuh" -template +template __global__ void append_speculate_cache_T_rope_qk_norm_kernel( const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size, // head_size] @@ -129,8 +132,10 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel( if (hi < num_heads + gqa_group_size) { const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; - float tmp1 = input_left * cos_tmp - input_right * sin_tmp; - float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; tmp_vec[2 * i] = tmp1; tmp_vec[2 * i + 1] = tmp2; @@ -331,7 +336,10 @@ __global__ void append_clear_cache_int4_block( } } -template +template __global__ void append_speculate_cache_rope_kernel( const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size, // head_size] @@ -434,9 +442,11 @@ __global__ void append_speculate_cache_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec[2 * i] = static_cast(input_left); bias_vec[2 * i + 1] = static_cast(input_right); @@ -462,7 +472,10 @@ __global__ void append_speculate_cache_rope_kernel( } } -template +template __global__ void append_speculate_cache_neox_rope_kernel( const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size, // head_size] @@ -568,9 +581,11 @@ __global__ void append_speculate_cache_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { left_bias_vec[i] = static_cast(input_left); right_bias_vec[i] = static_cast(input_right); @@ -601,7 +616,10 @@ __global__ void append_speculate_cache_neox_rope_kernel( } } -template +template __global__ void append_speculate_cache_neox_partial_rope_kernel( const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size, // head_size] @@ -724,9 +742,11 @@ __global__ void append_speculate_cache_neox_partial_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { left_bias_vec[i] = static_cast(input_left); right_bias_vec[i] = static_cast(input_right); @@ -766,7 +786,8 @@ template + bool IsDynamic = true, + bool EnforceFmulRN = false> __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( const T* __restrict__ quant_qkv, // [num_head, num_heads + 2 * // gqa_group_size, head_size] @@ -864,8 +885,10 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( float input_right = static_cast(src_vec[2 * i + 1]); const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; - float tmp1 = input_left * cos_tmp - input_right * sin_tmp; - float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; bias_vec[2 * i] = static_cast(tmp1); bias_vec[2 * i + 1] = static_cast(tmp2); @@ -929,8 +952,10 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( if (head_idx < num_heads + gqa_group_size) { float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; - float tmp1 = input_left * cos_tmp - input_right * sin_tmp; - float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; bias_vec1[0] = static_cast(tmp1); bias_vec1[1] = static_cast(tmp2); @@ -944,8 +969,10 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( if (head_idx < num_heads + gqa_group_size) { float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; - float tmp1 = input_left * cos_tmp - input_right * sin_tmp; - float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; bias_vec2[0] = static_cast(tmp1); bias_vec2[1] = static_cast(tmp2); @@ -1045,7 +1072,8 @@ template + bool IsFP8 = false, + bool EnforceFmulRN = false> __global__ void append_speculate_cache_int8_rope_kernel( const InT* __restrict__ quant_qkv, // [num_head, num_heads + 2 * // gqa_group_size, head_size] @@ -1149,9 +1177,11 @@ __global__ void append_speculate_cache_int8_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(bias_vec, &qkv_out_now[bias_idx]); } @@ -1215,9 +1245,11 @@ __global__ void append_speculate_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; bias_vec1[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec1[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec1[0] = static_cast(input_left); bias_vec1[1] = static_cast(input_right); @@ -1238,9 +1270,11 @@ __global__ void append_speculate_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; bias_vec2[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec2[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec2[0] = static_cast(input_left); bias_vec2[1] = static_cast(input_right); @@ -1285,7 +1319,8 @@ template + typename InT = int, + bool EnforceFmulRN = false> __global__ void append_speculate_cache_int8_neox_rope_kernel( const InT* __restrict__ quant_qkv, // [num_head, num_heads + 2 * // gqa_group_size, head_size] @@ -1396,9 +1431,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_bias_vec, &qkv_out_now[bias_idx_left]); Store(right_bias_vec, &qkv_out_now[bias_idx_right]); @@ -1485,9 +1522,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( float cos_tmp = cos_emb_vec1[i]; float sin_tmp = sin_emb_vec1[i]; left_bias_vec1[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec1[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); input_left = static_cast(left_src_vec2[i]); input_right = static_cast(right_src_vec2[i]); @@ -1500,9 +1539,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( cos_tmp = cos_emb_vec2[i]; sin_tmp = sin_emb_vec2[i]; left_bias_vec2[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec2[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); float quant_value1 = static_cast(scale * left_bias_vec1[i]); float quant_value2 = static_cast(scale * left_bias_vec2[i]); @@ -1669,7 +1710,8 @@ template + typename InT = int, + bool EnforceFmulRN = false> __global__ void append_speculate_cache_int4_rope_kernel( const InT* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size, // head_size] @@ -1774,9 +1816,11 @@ __global__ void append_speculate_cache_int4_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(bias_vec, &qkv_out_now[bias_idx]); } @@ -1877,9 +1921,11 @@ __global__ void append_speculate_cache_int4_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; bias_vec1[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec1[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec1[0] = static_cast(input_left); bias_vec1[1] = static_cast(input_right); @@ -1895,9 +1941,11 @@ __global__ void append_speculate_cache_int4_rope_kernel( float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; bias_vec2[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec2[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec2[0] = static_cast(input_left); bias_vec2[1] = static_cast(input_right); @@ -2017,7 +2065,8 @@ template + typename InT = int, + bool EnforceFmulRN = false> __global__ void append_speculate_cache_int4_neox_rope_kernel( const InT* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size, // head_size] @@ -2133,9 +2182,11 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_bias_vec, &qkv_out_now[bias_idx_left]); Store(right_bias_vec, &qkv_out_now[bias_idx_right]); @@ -2234,18 +2285,22 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; left_bias_vec1[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec1[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); input_left = static_cast(left_src_vec2[i]); input_right = static_cast(right_src_vec2[i]); cos_tmp = cos_emb_vec2[i]; sin_tmp = sin_emb_vec2[i]; left_bias_vec2[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec2[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); // quant + write k } LoadKVResT left_cache_vec, right_cache_vec; diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu index d6897ac5ed..fdf01a1df4 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu @@ -15,7 +15,7 @@ #include "speculate_write_cache_with_rope_kernel.h" #include "utils.cuh" -template +template void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv, T* key_cache, T* value_cache, @@ -58,7 +58,10 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv, PD_THROW("append_speculate_cache_rope_qk_norm not support neox rope yet"); } else { dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1); - append_speculate_cache_T_rope_qk_norm_kernel + append_speculate_cache_T_rope_qk_norm_kernel <<>>(qkv, key_cache, value_cache, @@ -88,7 +91,7 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv, } // rope + write -template +template void append_speculate_cache_rope(const QKV_TYPE* qkv, T* key_cache, T* value_cache, @@ -127,7 +130,10 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, GetNumBlocks(pack_num, &grid_size); if (use_neox_style) { if (rotary_dim < dim_head) { - append_speculate_cache_neox_partial_rope_kernel + append_speculate_cache_neox_partial_rope_kernel <<>>( qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size] key_cache, @@ -153,7 +159,10 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, kv_num_heads, rope_3d); } else { - append_speculate_cache_neox_rope_kernel + append_speculate_cache_neox_rope_kernel <<>>( qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size] key_cache, @@ -179,7 +188,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, rope_3d); } } else { - append_speculate_cache_rope_kernel + append_speculate_cache_rope_kernel <<>>( qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size] key_cache, @@ -206,7 +215,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, } } -template +template void append_speculate_cache_fp8_rope(const T* qkv, uint8_t* key_cache, uint8_t* value_cache, @@ -238,7 +247,7 @@ void append_speculate_cache_fp8_rope(const T* qkv, ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; dim3 grids(token_num, all_warps / num_warps); - append_clear_cache_int8_block<4> + append_clear_cache_int8_block<4, 128> <<>>(key_cache, value_cache, seq_lens, @@ -256,7 +265,8 @@ void append_speculate_cache_fp8_rope(const T* qkv, 0, 128, true, - IsDynamic> + IsDynamic, + EnforceFmulRN> <<>>(qkv, key_cache, value_cache, @@ -283,7 +293,10 @@ void append_speculate_cache_fp8_rope(const T* qkv, rms_norm_eps); } -template +template void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, uint8_t* key_cache, uint8_t* value_cache, @@ -315,7 +328,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; dim3 grids(token_num, all_warps / num_warps); - append_clear_cache_int8_block<4> + append_clear_cache_int8_block<4, 128> <<>>(key_cache, value_cache, seq_lens, @@ -329,7 +342,12 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, block_size, kv_num_heads); if (use_neox_style) { - append_speculate_cache_int8_neox_rope_kernel + append_speculate_cache_int8_neox_rope_kernel <<>>(qkv, key_cache, value_cache, @@ -354,7 +372,13 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, kv_num_heads, rope_3d); } else { - append_speculate_cache_int8_rope_kernel + append_speculate_cache_int8_rope_kernel <<>>(qkv, key_cache, value_cache, @@ -381,7 +405,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, } } -template +template void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, uint8_t* key_cache, uint8_t* value_cache, @@ -415,7 +439,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; dim3 grids(token_num, all_warps / num_warps); - append_clear_cache_int4_block<4> + append_clear_cache_int4_block<4, 128> <<>>(key_cache, value_cache, seq_lens, @@ -429,7 +453,12 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, block_size, kv_num_heads); if (use_neox_style) { - append_speculate_cache_int4_neox_rope_kernel + append_speculate_cache_int4_neox_rope_kernel <<>>(qkv, key_cache, value_cache, @@ -456,7 +485,12 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, kv_num_heads, rope_3d); } else { - append_speculate_cache_int4_rope_kernel + append_speculate_cache_int4_rope_kernel <<>>(qkv, key_cache, value_cache, @@ -484,7 +518,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, rope_3d); } } -template +template void SpeculateWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, @@ -549,7 +583,7 @@ void SpeculateWriteCacheWithRoPEKernel( } if (q_norm_weight && k_norm_weight) { if (cache_quant_type_str == "none") { - append_speculate_cache_rope_qk_norm( + append_speculate_cache_rope_qk_norm( reinterpret_cast(qkv_ptr), reinterpret_cast(key_cache_out->data()), reinterpret_cast(value_cache_out->data()), @@ -580,7 +614,7 @@ void SpeculateWriteCacheWithRoPEKernel( rms_norm_eps, rope_3d); } else if (cache_quant_type_str == "block_wise_fp8") { - append_speculate_cache_fp8_rope( + append_speculate_cache_fp8_rope( reinterpret_cast(qkv_ptr), key_cache_out->data(), value_cache_out->data(), @@ -610,7 +644,7 @@ void SpeculateWriteCacheWithRoPEKernel( rope_3d, rms_norm_eps); } else if (cache_quant_type_str == "cache_fp8") { - append_speculate_cache_fp8_rope( + append_speculate_cache_fp8_rope( reinterpret_cast(qkv_ptr), key_cache_out->data(), value_cache_out->data(), @@ -648,7 +682,7 @@ void SpeculateWriteCacheWithRoPEKernel( } else { if (cache_quant_type_str == "none") { - append_speculate_cache_rope( + append_speculate_cache_rope( reinterpret_cast(qkv_ptr), reinterpret_cast(key_cache_out->data()), reinterpret_cast(value_cache_out->data()), @@ -677,7 +711,10 @@ void SpeculateWriteCacheWithRoPEKernel( use_neox_rotary_style, rope_3d); } else if (cache_quant_type_str == "cache_int8") { - append_speculate_cache_int8_rope( + append_speculate_cache_int8_rope( reinterpret_cast(qkv_ptr), key_cache_out->data(), value_cache_out->data(), @@ -711,7 +748,10 @@ void SpeculateWriteCacheWithRoPEKernel( use_neox_rotary_style, rope_3d); } else if (cache_quant_type_str == "cache_fp8") { - append_speculate_cache_int8_rope( + append_speculate_cache_int8_rope( reinterpret_cast(qkv_ptr), key_cache_out->data(), value_cache_out->data(), @@ -745,7 +785,7 @@ void SpeculateWriteCacheWithRoPEKernel( use_neox_rotary_style, rope_3d); } else if (cache_quant_type_str == "block_wise_fp8") { - append_speculate_cache_fp8_rope( + append_speculate_cache_fp8_rope( reinterpret_cast(qkv_ptr), key_cache_out->data(), value_cache_out->data(), @@ -775,7 +815,7 @@ void SpeculateWriteCacheWithRoPEKernel( rope_3d, rms_norm_eps); } else if (cache_quant_type_str == "cache_int4_zp") { - append_speculate_cache_int4_rope( + append_speculate_cache_int4_rope( reinterpret_cast(qkv_ptr), key_cache_out->data(), value_cache_out->data(), diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h index 2db42bc26e..c9c3ff9e0b 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h @@ -15,7 +15,7 @@ #include "speculate_write_cache_with_rope_impl.cuh" -template +template void SpeculateWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& diff --git a/custom_ops/gpu_ops/helper.cu b/custom_ops/gpu_ops/helper.cu index 7144b05c23..45a3660d39 100644 --- a/custom_ops/gpu_ops/helper.cu +++ b/custom_ops/gpu_ops/helper.cu @@ -157,6 +157,14 @@ bool getEnvEnablePDL() { return enablePDL; } +bool getEnvEnableRL() { + static std::once_flag flag; + static bool enableRL = false; + + std::call_once(flag, [&]() { enableRL = getBoolEnv("FD_ENABLE_RL"); }); + return enableRL; +} + bool getEnvDeterministicMode() { static std::once_flag flag; static bool value = false; diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h index dd03e5d21e..83f3ad1077 100644 --- a/custom_ops/gpu_ops/helper.h +++ b/custom_ops/gpu_ops/helper.h @@ -189,6 +189,17 @@ inline int GetGPUComputeCapability(int id) { } #endif +#ifndef DISPATCH_BOOL_DTYPE +#define DISPATCH_BOOL_DTYPE(runtime_flag, STATIC_FLAG, ...) \ + if (runtime_flag) { \ + constexpr bool STATIC_FLAG = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool STATIC_FLAG = false; \ + __VA_ARGS__ \ + } +#endif + inline constexpr uint32_t next_pow_2(uint32_t const num) { if (num <= 1) return num; return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); @@ -731,7 +742,15 @@ inline bool getBoolEnv(char const *name) { return env && env[0] == '1' && env[1] == '\0'; } +template +__device__ __forceinline__ float fmul_func(float a, float b) { + // Use __fmul_rn for IEEE-754 compliant rounding when EnforceRN is enabled + float res = EnforceRN ? __fmul_rn(a, b) : a * b; + return res; +} + bool getEnvEnablePDL(); +bool getEnvEnableRL(); bool getEnvDeterministicMode(); bool getEnvDeterministicDebug();