Support FA3 Qwen-VL RoPE (#5869)

This commit is contained in:
chen
2026-01-05 15:29:34 +08:00
committed by GitHub
parent adb91dcacc
commit ac39c0f887
2 changed files with 26 additions and 20 deletions
@@ -1377,23 +1377,25 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
if (use_neox_rotary_style) {
if (rotary_dim == head_dim) {
gqa_rotary_qk_split_variable_qwen3<data_t>(qkv_out.data<data_t>(),
q.data<data_t>(),
k.data<data_t>(),
v.data<data_t>(),
qkv.data<data_t>(),
rotary_embs.data<float>(),
batch_id_per_token.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
head_dim,
stream);
gqa_rotary_qk_split_variable_qwen3<data_t>(
qkv_out.data<data_t>(),
q.data<data_t>(),
k.data<data_t>(),
v.data<data_t>(),
qkv.data<data_t>(),
rotary_embs.data<float>(),
batch_id_per_token.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
token_num,
num_heads,
kv_num_heads,
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
head_dim,
rope_3d,
stream);
} else {
gqa_neox_partial_rotary_qk_split_variable<data_t>(
qkv_out.data<data_t>(),
+7 -3
View File
@@ -23,7 +23,8 @@ __global__ void GQAVariableLengthRotarySplitKernel_Qwen3(
const int q_num_head,
const int kv_num_head,
const int max_model_len,
const int head_dim) {
const int head_dim,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadEmbT cos_emb_vec;
@@ -84,7 +85,8 @@ __global__ void GQAVariableLengthRotarySplitKernel_Qwen3(
}
// TODO: check whether this is correct
int64_t new_emb_idx = emb_idx;
int64_t new_emb_idx =
rope_3d ? emb_idx + ori_bi * 2 * max_model_len * head_dim : emb_idx;
if (hi < q_num_head + kv_num_head) {
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
@@ -126,6 +128,7 @@ void gqa_rotary_qk_split_variable_qwen3(T *qkv_out,
const int kv_num_heads,
const int max_model_len,
const int head_dim,
const bool rope_3d,
const cudaStream_t &stream) {
assert(head_dim == 128 && "head_dim must be 128");
@@ -163,5 +166,6 @@ void gqa_rotary_qk_split_variable_qwen3(T *qkv_out,
num_heads,
kv_num_heads,
max_model_len,
head_dim);
head_dim,
rope_3d);
}