[XPU] add support for rope3d (#7518)

* [XPU] add support for rope3d

* support rope3d in the decoder

---------

Co-authored-by: yinwei <yinwei_hust@163.com>
RuohengMa
2026-04-21 13:39:00 +08:00
committed by GitHub
parent 609f649dd7
commit 9d3551cfbb
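
For orientation before the hunks: with rope_3d enabled, split_kvcache_encoder and split_kvcache_decoder now build a per-batch position-embedding offset and pass it to the VSL rotary-embedding kernels where start_tokens was passed before; without rope_3d the offset simply aliases start_tokens. Below is a minimal host-side sketch of that offset, assuming (inferred from the 2 * rope_max_seqlen stride in the hunks, not stated explicitly in the patch) that each request addresses a rotary table of 2 * rope_max_seqlen rows. The function name and vector-based signature are illustrative and are not part of the patch.

#include <cstdint>
#include <vector>

// offset[bs] = real_batch[bs] * 2 * rope_max_seqlen + start_tokens[bs]
std::vector<int32_t> rope3d_pos_emb_offset(const std::vector<int32_t>& real_batch,
                                           const std::vector<int32_t>& start_tokens,
                                           int64_t rope_max_seqlen) {
  std::vector<int32_t> offset(real_batch.size());
  for (size_t bs = 0; bs < real_batch.size(); ++bs) {
    // real_batch[bs] picks the request's rotary table (assumed stride of
    // 2 * rope_max_seqlen rows per request); start_tokens[bs] is the position
    // of the request's first token, exactly as in the non-rope_3d path.
    offset[bs] = real_batch[bs] * 2 * static_cast<int32_t>(rope_max_seqlen) +
                 start_tokens[bs];
  }
  return offset;
}

In the patch the same values are materialized twice, on the host with a plain loop and on the device via api::scale followed by api::broadcast_add, so the api::VectorParam<int32_t> pos_emb_offset carries matching .cpu and .xpu pointers; the GPT-J-style rotary call additionally switches its hard-coded "NORMAL" mode to the caller-supplied pos_emb_type.
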
@@ -239,11 +239,7 @@ void split_kvcache_encoder(api::Context* xpu_ctx,
bool use_neox_rotary_style) {
int ret;
int64_t real_kv_num_heads = (kv_num_heads == -1) ? q_num_heads : kv_num_heads;
// TODO: spliced split kvcache should support rope3d
if (FLAGS_encoder_splice && !rope_3d) {
if (rope_3d) {
PD_THROW("split_kvcache_encoder does not support rope_3d == true!");
}
if (FLAGS_encoder_splice) {
paddle::Place place = qkv.place();
xftblock::DataType KV_BUF_TYPE = std::is_same<bfloat16, TQKV>::value
? xftblock::DataType::DT_BFLOAT16
@@ -274,6 +270,47 @@ void split_kvcache_encoder(api::Context* xpu_ctx,
head_dim);
PD_CHECK(ret == api::SUCCESS, "split_qkv_block failed.");
paddle::Tensor pos_emb_offset_xpu;
paddle::Tensor pos_emb_offset_cpu;
api::VectorParam<int32_t> pos_emb_offset;
pos_emb_offset.len = start_tokens.len;
if (rope_3d) {
pos_emb_offset_xpu =
paddle::empty({batch_size}, paddle::DataType::INT32, place);
pos_emb_offset_cpu = paddle::empty(
{batch_size}, paddle::DataType::INT32, paddle::CPUPlace());
auto pos_emb_offset_xpu_ptr =
const_cast<int32_t*>(pos_emb_offset_xpu.data<int32_t>());
auto pos_emb_offset_cpu_ptr =
const_cast<int32_t*>(pos_emb_offset_cpu.data<int32_t>());
// bs_offset = real_batch.cpu[bs] * 2 * rope_max_seqlen +
// start_tokens.cpu[bs];
ret = api::scale(xpu_ctx,
real_batch.xpu,
pos_emb_offset_xpu_ptr,
batch_size,
true,
2 * static_cast<int32_t>(rope_max_seqlen),
0);
PD_CHECK(ret == api::SUCCESS, "api::scale failed.");
ret = api::broadcast_add<int32_t>(xpu_ctx,
pos_emb_offset_xpu_ptr,
start_tokens.xpu,
pos_emb_offset_xpu_ptr,
{pos_emb_offset.len},
{pos_emb_offset.len});
PD_CHECK(ret == api::SUCCESS, "api::broadcast_add failed.");
for (int i = 0; i < batch_size; i++) {
pos_emb_offset_cpu_ptr[i] =
real_batch.cpu[i] * 2 * static_cast<int32_t>(rope_max_seqlen) +
start_tokens.cpu[i];
}
pos_emb_offset.cpu = pos_emb_offset_cpu_ptr;
pos_emb_offset.xpu = pos_emb_offset_xpu_ptr;
} else {
pos_emb_offset.cpu = start_tokens.cpu;
pos_emb_offset.xpu = start_tokens.xpu;
}
if (!use_neox_rotary_style) {
ret = infer_ops::vsl_rotary_embedding_gptj<TQKV, TR, TID>(
xpu_ctx,
@@ -288,8 +325,8 @@ void split_kvcache_encoder(api::Context* xpu_ctx,
q_num_heads,
head_dim,
"BLHD",
start_tokens,
"NORMAL",
pos_emb_offset,
pos_emb_type,
real_kv_num_heads,
false);
PD_CHECK(ret == api::SUCCESS, "vsl_rotary_embedding_gptj failed.");
@@ -308,7 +345,7 @@ void split_kvcache_encoder(api::Context* xpu_ctx,
head_dim,
rope_head_dim,
"BLHD",
start_tokens,
pos_emb_offset,
"NORMAL",
real_kv_num_heads,
false);
@@ -546,12 +583,7 @@ void split_kvcache_decoder(api::Context* xpu_ctx,
bool use_neox_rotary_style) {
int64_t real_kv_num_heads = (kv_num_heads == -1) ? q_num_heads : kv_num_heads;
int ret;
// TODO: spliced split kvcache should support rope3d
if (FLAGS_decoder_splice && !rope_3d) {
// not yet supported
if (rope_3d) {
PD_THROW("split_kvcache_decoder does not support rope_3d == true!");
}
if (FLAGS_decoder_splice) {
if (std::is_same<TKV_CACHE, int8_t>::value &&
(k_cache_scale_inv == nullptr || v_cache_scale_inv == nullptr)) {
PD_THROW(
@@ -591,6 +623,47 @@ void split_kvcache_decoder(api::Context* xpu_ctx,
head_dim);
PD_CHECK(ret == api::SUCCESS, "split_qkv_block failed.");
paddle::Tensor pos_emb_offset_xpu;
paddle::Tensor pos_emb_offset_cpu;
api::VectorParam<int32_t> pos_emb_offset;
pos_emb_offset.len = start_tokens.len;
if (rope_3d) {
pos_emb_offset_xpu =
paddle::empty({batch_size}, paddle::DataType::INT32, place);
pos_emb_offset_cpu = paddle::empty(
{batch_size}, paddle::DataType::INT32, paddle::CPUPlace());
auto pos_emb_offset_xpu_ptr =
const_cast<int32_t*>(pos_emb_offset_xpu.data<int32_t>());
auto pos_emb_offset_cpu_ptr =
const_cast<int32_t*>(pos_emb_offset_cpu.data<int32_t>());
// bs_offset = real_batch.cpu[bs] * 2 * rope_max_seqlen +
// start_tokens.cpu[bs];
ret = api::scale(xpu_ctx,
real_batch.xpu,
pos_emb_offset_xpu_ptr,
batch_size,
true,
2 * static_cast<int32_t>(rope_max_seqlen),
0);
PD_CHECK(ret == api::SUCCESS, "api::scale failed.");
ret = api::broadcast_add<int32_t>(xpu_ctx,
pos_emb_offset_xpu_ptr,
start_tokens.xpu,
pos_emb_offset_xpu_ptr,
{pos_emb_offset.len},
{pos_emb_offset.len});
PD_CHECK(ret == api::SUCCESS, "api::broadcast_add failed.");
for (int i = 0; i < batch_size; i++) {
pos_emb_offset_cpu_ptr[i] =
real_batch.cpu[i] * 2 * static_cast<int32_t>(rope_max_seqlen) +
start_tokens.cpu[i];
}
pos_emb_offset.cpu = pos_emb_offset_cpu_ptr;
pos_emb_offset.xpu = pos_emb_offset_xpu_ptr;
} else {
pos_emb_offset.cpu = start_tokens.cpu;
pos_emb_offset.xpu = start_tokens.xpu;
}
if (!use_neox_rotary_style) {
ret = infer_ops::vsl_rotary_embedding_gptj<TQKV, TR, TID>(
xpu_ctx,
@@ -605,8 +678,8 @@ void split_kvcache_decoder(api::Context* xpu_ctx,
q_num_heads,
head_dim,
"BLHD",
start_tokens,
"NORMAL",
pos_emb_offset,
pos_emb_type,
real_kv_num_heads,
false);
PD_CHECK(ret == api::SUCCESS, "vsl_rotary_embedding_gptj failed.");
@@ -625,7 +698,7 @@ void split_kvcache_decoder(api::Context* xpu_ctx,
head_dim,
rope_head_dim,
"BLHD",
start_tokens,
pos_emb_offset,
"NORMAL",
real_kv_num_heads,
false);
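
As a follow-up, here is a tiny self-contained check (toy values only; it does not call the XPU APIs above) that the two-step device composition used in the patch, scaling real_batch by 2 * rope_max_seqlen and then broadcast-adding start_tokens, reproduces the fused formula of the host loop. The scale/bias reading of api::scale's arguments is my interpretation, not a documented signature.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t rope_max_seqlen = 8;                  // toy value
  const std::vector<int32_t> real_batch   = {0, 0, 1, 2};
  const std::vector<int32_t> start_tokens = {0, 3, 0, 5};

  for (size_t bs = 0; bs < real_batch.size(); ++bs) {
    // Fused formula from the CPU loop in the patch.
    int32_t fused = real_batch[bs] * 2 * static_cast<int32_t>(rope_max_seqlen) +
                    start_tokens[bs];
    // Two-step composition mirroring api::scale (scale = 2 * rope_max_seqlen,
    // bias = 0) followed by api::broadcast_add with start_tokens.
    int32_t scaled = real_batch[bs] * 2 * static_cast<int32_t>(rope_max_seqlen);
    int32_t composed = scaled + start_tokens[bs];
    assert(fused == composed);
  }
  return 0;
}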