diff --git a/custom_ops/xpu_ops/src/ops/block_attn_spliced.cc b/custom_ops/xpu_ops/src/ops/block_attn_spliced.cc
index beb950ec31..9bc69dd27f 100644
--- a/custom_ops/xpu_ops/src/ops/block_attn_spliced.cc
+++ b/custom_ops/xpu_ops/src/ops/block_attn_spliced.cc
@@ -239,11 +239,7 @@ void split_kvcache_encoder(api::Context* xpu_ctx,
                            bool use_neox_rotary_style) {
   int ret;
   int64_t real_kv_num_heads = (kv_num_heads == -1) ? q_num_heads : kv_num_heads;
-  // TODO: spliced split kvcache should support rope3d
-  if (FLAGS_encoder_splice && !rope_3d) {
-    if (rope_3d) {
-      PD_THROW("split_kvcache_encoder does not support rope_3d == true!");
-    }
+  if (FLAGS_encoder_splice) {
     paddle::Place place = qkv.place();
     xftblock::DataType KV_BUF_TYPE = std::is_same::value
                                          ? xftblock::DataType::DT_BFLOAT16
@@ -274,6 +270,47 @@ void split_kvcache_encoder(api::Context* xpu_ctx,
         head_dim);
     PD_CHECK(ret == api::SUCCESS, "split_qkv_block failed.");
 
+    paddle::Tensor pos_emb_offset_xpu;
+    paddle::Tensor pos_emb_offset_cpu;
+    api::VectorParam<int32_t> pos_emb_offset;
+    pos_emb_offset.len = start_tokens.len;
+    if (rope_3d) {
+      pos_emb_offset_xpu =
+          paddle::empty({batch_size}, paddle::DataType::INT32, place);
+      pos_emb_offset_cpu = paddle::empty(
+          {batch_size}, paddle::DataType::INT32, paddle::CPUPlace());
+      auto pos_emb_offset_xpu_ptr =
+          const_cast<int32_t*>(pos_emb_offset_xpu.data<int32_t>());
+      auto pos_emb_offset_cpu_ptr =
+          const_cast<int32_t*>(pos_emb_offset_cpu.data<int32_t>());
+      // bs_offset = real_batch.cpu[bs] * 2 * rope_max_seqlen +
+      // start_tokens.cpu[bs];
+      ret = api::scale(xpu_ctx,
+                       real_batch.xpu,
+                       pos_emb_offset_xpu_ptr,
+                       batch_size,
+                       true,
+                       2 * static_cast<float>(rope_max_seqlen),
+                       0);
+      PD_CHECK(ret == api::SUCCESS, "api::scale failed.");
+      ret = api::broadcast_add(xpu_ctx,
+                               pos_emb_offset_xpu_ptr,
+                               start_tokens.xpu,
+                               pos_emb_offset_xpu_ptr,
+                               {pos_emb_offset.len},
+                               {pos_emb_offset.len});
+      PD_CHECK(ret == api::SUCCESS, "api::broadcast_add failed.");
+      for (int i = 0; i < batch_size; i++) {
+        pos_emb_offset_cpu_ptr[i] =
+            real_batch.cpu[i] * 2 * static_cast<int32_t>(rope_max_seqlen) +
+            start_tokens.cpu[i];
+      }
+      pos_emb_offset.cpu = pos_emb_offset_cpu_ptr;
+      pos_emb_offset.xpu = pos_emb_offset_xpu_ptr;
+    } else {
+      pos_emb_offset.cpu = start_tokens.cpu;
+      pos_emb_offset.xpu = start_tokens.xpu;
+    }
     if (!use_neox_rotary_style) {
       ret = infer_ops::vsl_rotary_embedding_gptj(
           xpu_ctx,
@@ -288,8 +325,8 @@ void split_kvcache_encoder(api::Context* xpu_ctx,
           q_num_heads,
           head_dim,
           "BLHD",
-          start_tokens,
-          "NORMAL",
+          pos_emb_offset,
+          pos_emb_type,
           real_kv_num_heads,
           false);
       PD_CHECK(ret == api::SUCCESS, "vsl_rotary_embedding_gptj failed.");
@@ -308,7 +345,7 @@ void split_kvcache_encoder(api::Context* xpu_ctx,
           head_dim,
           rope_head_dim,
           "BLHD",
-          start_tokens,
+          pos_emb_offset,
           "NORMAL",
           real_kv_num_heads,
           false);
@@ -546,12 +583,7 @@ void split_kvcache_decoder(api::Context* xpu_ctx,
                            bool use_neox_rotary_style) {
   int64_t real_kv_num_heads = (kv_num_heads == -1) ? q_num_heads : kv_num_heads;
   int ret;
-  // TODO: spliced split kvcache should support rope3d
-  if (FLAGS_decoder_splice && !rope_3d) {
-    // not yet supported
-    if (rope_3d) {
-      PD_THROW("split_kvcache_decoder does not support rope_3d == true!");
-    }
+  if (FLAGS_decoder_splice) {
     if (std::is_same::value &&
         (k_cache_scale_inv == nullptr || v_cache_scale_inv == nullptr)) {
       PD_THROW(
@@ -591,6 +623,47 @@ void split_kvcache_decoder(api::Context* xpu_ctx,
         head_dim);
     PD_CHECK(ret == api::SUCCESS, "split_qkv_block failed.");
 
+    paddle::Tensor pos_emb_offset_xpu;
+    paddle::Tensor pos_emb_offset_cpu;
+    api::VectorParam<int32_t> pos_emb_offset;
+    pos_emb_offset.len = start_tokens.len;
+    if (rope_3d) {
+      pos_emb_offset_xpu =
+          paddle::empty({batch_size}, paddle::DataType::INT32, place);
+      pos_emb_offset_cpu = paddle::empty(
+          {batch_size}, paddle::DataType::INT32, paddle::CPUPlace());
+      auto pos_emb_offset_xpu_ptr =
+          const_cast<int32_t*>(pos_emb_offset_xpu.data<int32_t>());
+      auto pos_emb_offset_cpu_ptr =
+          const_cast<int32_t*>(pos_emb_offset_cpu.data<int32_t>());
+      // bs_offset = real_batch.cpu[bs] * 2 * rope_max_seqlen +
+      // start_tokens.cpu[bs];
+      ret = api::scale(xpu_ctx,
+                       real_batch.xpu,
+                       pos_emb_offset_xpu_ptr,
+                       batch_size,
+                       true,
+                       2 * static_cast<float>(rope_max_seqlen),
+                       0);
+      PD_CHECK(ret == api::SUCCESS, "api::scale failed.");
+      ret = api::broadcast_add(xpu_ctx,
+                               pos_emb_offset_xpu_ptr,
+                               start_tokens.xpu,
+                               pos_emb_offset_xpu_ptr,
+                               {pos_emb_offset.len},
+                               {pos_emb_offset.len});
+      PD_CHECK(ret == api::SUCCESS, "api::broadcast_add failed.");
+      for (int i = 0; i < batch_size; i++) {
+        pos_emb_offset_cpu_ptr[i] =
+            real_batch.cpu[i] * 2 * static_cast<int32_t>(rope_max_seqlen) +
+            start_tokens.cpu[i];
+      }
+      pos_emb_offset.cpu = pos_emb_offset_cpu_ptr;
+      pos_emb_offset.xpu = pos_emb_offset_xpu_ptr;
+    } else {
+      pos_emb_offset.cpu = start_tokens.cpu;
+      pos_emb_offset.xpu = start_tokens.xpu;
+    }
     if (!use_neox_rotary_style) {
       ret = infer_ops::vsl_rotary_embedding_gptj(
           xpu_ctx,
@@ -605,8 +678,8 @@ void split_kvcache_decoder(api::Context* xpu_ctx,
           q_num_heads,
           head_dim,
           "BLHD",
-          start_tokens,
-          "NORMAL",
+          pos_emb_offset,
+          pos_emb_type,
           real_kv_num_heads,
           false);
       PD_CHECK(ret == api::SUCCESS, "vsl_rotary_embedding_gptj failed.");
@@ -625,7 +698,7 @@ void split_kvcache_decoder(api::Context* xpu_ctx,
          head_dim,
          rope_head_dim,
          "BLHD",
-          start_tokens,
+          pos_emb_offset,
          "NORMAL",
          real_kv_num_heads,
          false);
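
Note on the added offset computation (explanatory, not part of the patch): for rope_3d the patch builds pos_emb_offset twice, once on the host in the plain for loop and once on the device via api::scale (multiply real_batch by the 2 * rope_max_seqlen stride) followed by api::broadcast_add (add start_tokens), so the .cpu and .xpu sides of the VectorParam stay consistent; without rope_3d it simply aliases start_tokens. A minimal standalone sketch of the same formula, using hypothetical std::vector inputs in place of the VectorParam/tensor plumbing:

#include <cstdint>
#include <vector>

// Sketch only: mirrors the per-request offset the patch introduces. For the
// rope_3d case each request bs reads its rotary table starting at
//   real_batch[bs] * 2 * rope_max_seqlen + start_tokens[bs]
// (the 2 * rope_max_seqlen stride is taken from the patch's own comment);
// without rope_3d the offsets degenerate to start_tokens, matching the
// else branch in the diff.
std::vector<int32_t> pos_emb_offsets(const std::vector<int32_t>& real_batch,
                                     const std::vector<int32_t>& start_tokens,
                                     int64_t rope_max_seqlen,
                                     bool rope_3d) {
  if (!rope_3d) {
    return start_tokens;  // NORMAL path: offsets are just the start tokens
  }
  std::vector<int32_t> out(start_tokens.size());
  for (size_t bs = 0; bs < start_tokens.size(); ++bs) {
    // bs_offset = real_batch[bs] * 2 * rope_max_seqlen + start_tokens[bs]
    out[bs] = real_batch[bs] * 2 * static_cast<int32_t>(rope_max_seqlen) +
              start_tokens[bs];
  }
  return out;
}

The encoder and decoder paths compute this identically; the sketch corresponds to the host-side loop, while the api::scale/api::broadcast_add pair is the device-side equivalent of the same two-term formula.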