mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
support fa3 qwen-vl rope (#5869)
This commit is contained in:
@@ -1377,23 +1377,25 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
|
||||
if (use_neox_rotary_style) {
|
||||
if (rotary_dim == head_dim) {
|
||||
gqa_rotary_qk_split_variable_qwen3<data_t>(qkv_out.data<data_t>(),
|
||||
q.data<data_t>(),
|
||||
k.data<data_t>(),
|
||||
v.data<data_t>(),
|
||||
qkv.data<data_t>(),
|
||||
rotary_embs.data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_seq_len,
|
||||
head_dim,
|
||||
stream);
|
||||
gqa_rotary_qk_split_variable_qwen3<data_t>(
|
||||
qkv_out.data<data_t>(),
|
||||
q.data<data_t>(),
|
||||
k.data<data_t>(),
|
||||
v.data<data_t>(),
|
||||
qkv.data<data_t>(),
|
||||
rotary_embs.data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
|
||||
head_dim,
|
||||
rope_3d,
|
||||
stream);
|
||||
} else {
|
||||
gqa_neox_partial_rotary_qk_split_variable<data_t>(
|
||||
qkv_out.data<data_t>(),
|
||||
|
||||
@@ -23,7 +23,8 @@ __global__ void GQAVariableLengthRotarySplitKernel_Qwen3(
|
||||
const int q_num_head,
|
||||
const int kv_num_head,
|
||||
const int max_model_len,
|
||||
const int head_dim) {
|
||||
const int head_dim,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadEmbT = AlignedVector<float, VecSize>;
|
||||
LoadEmbT cos_emb_vec;
|
||||
@@ -84,7 +85,8 @@ __global__ void GQAVariableLengthRotarySplitKernel_Qwen3(
|
||||
}
|
||||
|
||||
// TODO check this correct or not
|
||||
int64_t new_emb_idx = emb_idx;
|
||||
int64_t new_emb_idx =
|
||||
rope_3d ? emb_idx + ori_bi * 2 * max_model_len * head_dim : emb_idx;
|
||||
|
||||
if (hi < q_num_head + kv_num_head) {
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
@@ -126,6 +128,7 @@ void gqa_rotary_qk_split_variable_qwen3(T *qkv_out,
|
||||
const int kv_num_heads,
|
||||
const int max_model_len,
|
||||
const int head_dim,
|
||||
const bool rope_3d,
|
||||
const cudaStream_t &stream) {
|
||||
assert(head_dim == 128 && "head_dim must be 128");
|
||||
|
||||
@@ -163,5 +166,6 @@ void gqa_rotary_qk_split_variable_qwen3(T *qkv_out,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_model_len,
|
||||
head_dim);
|
||||
head_dim,
|
||||
rope_3d);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user