[RL] change glm rope_emb calculation (#7316)

* Change the GLM rope_emb calculation

* Instantiate the GLM (NeoX partial-rotary) kernels with EnforceFmulRN=false

* Fix CI
This commit is contained in:
JYChen
2026-04-11 18:36:28 +08:00
committed by GitHub
parent fcf8b1336d
commit 076ab07528
5 changed files with 47 additions and 38 deletions
@@ -146,10 +146,10 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
rope_3d);
} else {
if (rotary_dim < dim_head) {
auto* kernelFn =
append_decode_cache_T_neox_partial_rope_kernel<T,
PackSize,
EnforceFmulRN>;
auto* kernelFn = append_decode_cache_T_neox_partial_rope_kernel<
T,
PackSize,
false>; // GLM use EnforceFmulRN=false
launchWithPdlWhenEnabled(kernelFn,
grid_size,
blocksize,
@@ -2543,10 +2543,10 @@ void gqa_rotary_qk_variable(
}
const int pack_num_new = elem_nums / PackSize;
GetNumBlocks<128>(pack_num_new, &grid_size);
auto *kernelFn =
GQANeoxVariableLengthPartialRotaryKernel<T,
PackSize,
EnforceFmulRN>;
auto *kernelFn = GQANeoxVariableLengthPartialRotaryKernel<
T,
PackSize,
false>; // GLM use EnforceFmulRN=false
launchWithPdlWhenEnabled(kernelFn,
grid_size,
blocksize,
@@ -387,30 +387,32 @@ void gqa_neox_partial_rotary_qk_split_variable(
const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + max_model_len * rotary_dim / 2;
launchWithPdlWhenEnabled(
GQAVariableLengthNeoxPartialRotarySplitKernel<T, PackSize, EnforceFmulRN>,
grid_size,
block_size,
0,
stream,
qkv_input,
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens_encoder,
seq_lens_decoder,
cu_seqlens_k,
qkv_out,
q,
k,
v,
elem_nums,
num_heads,
kv_num_heads,
max_model_len,
head_dim,
rotary_dim);
launchWithPdlWhenEnabled(GQAVariableLengthNeoxPartialRotarySplitKernel<
T,
PackSize,
false>, // GLM use EnforceFmulRN=false
grid_size,
block_size,
0,
stream,
qkv_input,
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens_encoder,
seq_lens_decoder,
cu_seqlens_k,
qkv_out,
q,
k,
v,
elem_nums,
num_heads,
kv_num_heads,
max_model_len,
head_dim,
rotary_dim);
}
template <typename T,
@@ -130,10 +130,11 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
GetNumBlocks(pack_num, &grid_size);
if (use_neox_style) {
if (rotary_dim < dim_head) {
append_speculate_cache_neox_partial_rope_kernel<T,
PackSize,
QKV_TYPE,
EnforceFmulRN>
append_speculate_cache_neox_partial_rope_kernel<
T,
PackSize,
QKV_TYPE,
false> // GLM use EnforceFmulRN=false
<<<grid_size, threads_per_block, 0, stream>>>(
qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size]
key_cache,