Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2026-04-23 00:17:25 +08:00)
[RL] change glm rope_emb calculation (#7316)
* change glm rope_emb calculation
* glm without EnforceFmulRN
* fix ci
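A note on the flag being dropped: `EnforceFmulRN` is a compile-time template parameter on these rotary-embedding kernels, and its definition is not part of this patch. A plausible reading (an assumption, not confirmed by the diff) is that it switches float multiplies to CUDA's `__fmul_rn` intrinsic, which rounds to nearest-even and prevents the compiler from contracting the multiply into an FMA. A minimal sketch of such a toggle:

// Hypothetical sketch (not FastDeploy's actual definition) of what an
// EnforceFmulRN flag could select. __fmul_rn() is a real CUDA intrinsic:
// a single-precision multiply, rounded to nearest-even, that the compiler
// will not contract into an FMA.
template <bool EnforceFmulRN>
__device__ __forceinline__ float mul_maybe_rn(float a, float b) {
  if constexpr (EnforceFmulRN) {
    return __fmul_rn(a, b);  // contraction into fmaf is suppressed
  } else {
    return a * b;            // plain multiply; nvcc may fuse it into an FMA
  }
}

Under that reading, hardcoding `false` at the GLM call sites lets the compiler emit ordinary (possibly fused) multiplies, presumably so the rollout-side RoPE numerics match GLM's reference computation; the commit message itself only says "glm without EnforceFmulRN".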
@@ -146,10 +146,10 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
                                 rope_3d);
   } else {
     if (rotary_dim < dim_head) {
-      auto* kernelFn =
-          append_decode_cache_T_neox_partial_rope_kernel<T,
-                                                         PackSize,
-                                                         EnforceFmulRN>;
+      auto* kernelFn = append_decode_cache_T_neox_partial_rope_kernel<
+          T,
+          PackSize,
+          false>;  // GLM use EnforceFmulRN=false
       launchWithPdlWhenEnabled(kernelFn,
                                grid_size,
                                blocksize,
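All of the touched call sites sit under a `rotary_dim < dim_head` guard, i.e. partial RoPE: only the first `rotary_dim` channels of each head are rotated, and the tail passes through unchanged. In the NeoX layout, channel `i` pairs with channel `i + rotary_dim / 2` rather than with its adjacent neighbor. A simplified per-head sketch (names and the scalar loop are illustrative; the real kernels are vectorized by `PackSize`):

// Simplified sketch of NeoX-style partial RoPE for one head vector.
// Only the first rotary_dim channels are rotated; channels
// [rotary_dim, dim_head) are copied through untouched.
__device__ void neox_partial_rope(const float* x,      // [dim_head] input
                                  const float* cos_v,  // [rotary_dim/2] cos at this position
                                  const float* sin_v,  // [rotary_dim/2] sin at this position
                                  float* out,          // [dim_head] output
                                  int dim_head,
                                  int rotary_dim) {
  const int half = rotary_dim / 2;
  for (int i = 0; i < half; ++i) {
    // NeoX pairing: channel i rotates with channel i + half.
    float a = x[i];
    float b = x[i + half];
    out[i]        = a * cos_v[i] - b * sin_v[i];
    out[i + half] = b * cos_v[i] + a * sin_v[i];
  }
  for (int i = rotary_dim; i < dim_head; ++i) {
    out[i] = x[i];  // pass-through for the non-rotary tail
  }
}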
@@ -2543,10 +2543,10 @@ void gqa_rotary_qk_variable(
     }
     const int pack_num_new = elem_nums / PackSize;
     GetNumBlocks<128>(pack_num_new, &grid_size);
-    auto *kernelFn =
-        GQANeoxVariableLengthPartialRotaryKernel<T,
-                                                 PackSize,
-                                                 EnforceFmulRN>;
+    auto *kernelFn = GQANeoxVariableLengthPartialRotaryKernel<
+        T,
+        PackSize,
+        false>;  // GLM use EnforceFmulRN=false
     launchWithPdlWhenEnabled(kernelFn,
                              grid_size,
                              blocksize,
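`launchWithPdlWhenEnabled` is FastDeploy's launch helper; its body is not part of this diff. Going only by its name, it launches through PDL (Programmatic Dependent Launch) when that feature is enabled and falls back to a plain launch otherwise. A hedged sketch of such a wrapper built on the CUDA 12 `cudaLaunchKernelEx` API — the real helper's signature, gating logic, and attribute choices are assumptions:

#include <cuda_runtime.h>

// Sketch of a launchWithPdlWhenEnabled-style helper. With PDL enabled,
// the kernel is launched with the programmatic-stream-serialization
// attribute so it may overlap with the tail of its predecessor on the
// same stream; otherwise it uses an ordinary triple-chevron launch.
template <typename Kernel, typename... Args>
void launch_with_pdl_when_enabled(Kernel kernel,
                                  dim3 grid,
                                  dim3 block,
                                  size_t shared_mem,
                                  cudaStream_t stream,
                                  bool use_pdl,
                                  Args... args) {
  if (use_pdl) {
    cudaLaunchConfig_t config = {};
    config.gridDim = grid;
    config.blockDim = block;
    config.dynamicSmemBytes = shared_mem;
    config.stream = stream;
    cudaLaunchAttribute attr[1];
    attr[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attr[0].val.programmaticStreamSerializationAllowed = 1;
    config.attrs = attr;
    config.numAttrs = 1;
    cudaLaunchKernelEx(&config, kernel, args...);
  } else {
    kernel<<<grid, block, shared_mem, stream>>>(args...);
  }
}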
@@ -387,30 +387,32 @@ void gqa_neox_partial_rotary_qk_split_variable(

   const float *cos_emb = rotary_emb;
   const float *sin_emb = rotary_emb + max_model_len * rotary_dim / 2;
-  launchWithPdlWhenEnabled(
-      GQAVariableLengthNeoxPartialRotarySplitKernel<T, PackSize, EnforceFmulRN>,
-      grid_size,
-      block_size,
-      0,
-      stream,
-      qkv_input,
-      cos_emb,
-      sin_emb,
-      batch_id_per_token,
-      cu_seqlens_q,
-      seq_lens_encoder,
-      seq_lens_decoder,
-      cu_seqlens_k,
-      qkv_out,
-      q,
-      k,
-      v,
-      elem_nums,
-      num_heads,
-      kv_num_heads,
-      max_model_len,
-      head_dim,
-      rotary_dim);
+  launchWithPdlWhenEnabled(GQAVariableLengthNeoxPartialRotarySplitKernel<
+                               T,
+                               PackSize,
+                               false>,  // GLM use EnforceFmulRN=false
+                           grid_size,
+                           block_size,
+                           0,
+                           stream,
+                           qkv_input,
+                           cos_emb,
+                           sin_emb,
+                           batch_id_per_token,
+                           cu_seqlens_q,
+                           seq_lens_encoder,
+                           seq_lens_decoder,
+                           cu_seqlens_k,
+                           qkv_out,
+                           q,
+                           k,
+                           v,
+                           elem_nums,
+                           num_heads,
+                           kv_num_heads,
+                           max_model_len,
+                           head_dim,
+                           rotary_dim);
 }

 template <typename T,
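The two pointer offsets at the top of the hunk above imply a packed layout for `rotary_emb`: a single `[2, max_model_len, rotary_dim / 2]` float buffer with the cos table first and the sin table immediately after it. Illustrative indexing helpers under that assumption (helper names are hypothetical, not from the repo):

// Assumed layout implied by the cos_emb/sin_emb pointer math above:
// rotary_emb = [cos table | sin table], each of shape
// [max_model_len, rotary_dim / 2]. For token position `pos` and rotary
// pair index `i` (0 <= i < rotary_dim / 2):
__device__ __forceinline__ float load_cos(const float* rotary_emb, int pos,
                                          int i, int rotary_dim) {
  return rotary_emb[pos * (rotary_dim / 2) + i];
}

__device__ __forceinline__ float load_sin(const float* rotary_emb, int pos,
                                          int i, int rotary_dim,
                                          int max_model_len) {
  // The sin table starts max_model_len * rotary_dim / 2 floats in.
  return rotary_emb[max_model_len * (rotary_dim / 2) +
                    pos * (rotary_dim / 2) + i];
}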
@@ -130,10 +130,11 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
   GetNumBlocks(pack_num, &grid_size);
   if (use_neox_style) {
     if (rotary_dim < dim_head) {
-      append_speculate_cache_neox_partial_rope_kernel<T,
-                                                      PackSize,
-                                                      QKV_TYPE,
-                                                      EnforceFmulRN>
+      append_speculate_cache_neox_partial_rope_kernel<
+          T,
+          PackSize,
+          QKV_TYPE,
+          false>  // GLM use EnforceFmulRN=false
           <<<grid_size, threads_per_block, 0, stream>>>(
               qkv,  // [token_num, num_heads + 2 * gqa_group_size, head_size]
               key_cache,
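Whether hardcoding the flag observably changes results depends on FMA contraction: under nvcc's default `--fmad=true`, an expression like `x * c - y * s` can compile to an fmaf chain whose rounding differs from two round-to-nearest multiplies followed by a subtraction, typically in the last ulp. A self-contained toy check (an illustration under the `__fmul_rn` reading above, not code from the repo) that prints both variants for comparison:

// Toy check: one RoPE-style rotation term computed with a plain multiply
// (which nvcc may contract into fmaf) vs. forced __fmul_rn multiplies.
// The two can differ in the last ulp; whether they do depends on the
// inputs and on how nvcc schedules the expression.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void rope_pair(float x, float y, float c, float s, float* out) {
  out[0] = x * c - y * s;                      // may compile to an fmaf chain
  out[1] = __fmul_rn(x, c) - __fmul_rn(y, s);  // contraction suppressed
}

int main() {
  float* d;
  cudaMalloc(&d, 2 * sizeof(float));
  rope_pair<<<1, 1>>>(0.3f, 0.7f, 0.6f, 0.8f, d);
  float h[2];
  cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
  // Exact equality is intentional: bitwise matching is what matters here.
  printf("plain=%.9g  fmul_rn=%.9g  equal=%d\n", h[0], h[1], h[0] == h[1]);
  cudaFree(d);
  return 0;
}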