mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] add support for rope3d (#7518)
* [XPU] add support for rope3d
* support decoder
---------
Co-authored-by: yinwei <yinwei_hust@163.com>
This commit is contained in:
@@ -239,11 +239,7 @@ void split_kvcache_encoder(api::Context* xpu_ctx,
|
||||
bool use_neox_rotary_style) {
|
||||
int ret;
|
||||
int64_t real_kv_num_heads = (kv_num_heads == -1) ? q_num_heads : kv_num_heads;
|
||||
// TODO: spliced split kvcache should support rope3d
|
||||
if (FLAGS_encoder_splice && !rope_3d) {
|
||||
if (rope_3d) {
|
||||
PD_THROW("split_kvcache_encoder does not support rope_3d == true!");
|
||||
}
|
||||
if (FLAGS_encoder_splice) {
|
||||
paddle::Place place = qkv.place();
|
||||
xftblock::DataType KV_BUF_TYPE = std::is_same<bfloat16, TQKV>::value
|
||||
? xftblock::DataType::DT_BFLOAT16
|
||||
@@ -274,6 +270,47 @@ void split_kvcache_encoder(api::Context* xpu_ctx,
|
||||
head_dim);
|
||||
PD_CHECK(ret == api::SUCCESS, "split_qkv_block failed.");
|
||||
|
||||
paddle::Tensor pos_emb_offset_xpu;
|
||||
paddle::Tensor pos_emb_offset_cpu;
|
||||
api::VectorParam<int32_t> pos_emb_offset;
|
||||
pos_emb_offset.len = start_tokens.len;
|
||||
if (rope_3d) {
|
||||
pos_emb_offset_xpu =
|
||||
paddle::empty({batch_size}, paddle::DataType::INT32, place);
|
||||
pos_emb_offset_cpu = paddle::empty(
|
||||
{batch_size}, paddle::DataType::INT32, paddle::CPUPlace());
|
||||
auto pos_emb_offset_xpu_ptr =
|
||||
const_cast<int32_t*>(pos_emb_offset_xpu.data<int32_t>());
|
||||
auto pos_emb_offset_cpu_ptr =
|
||||
const_cast<int32_t*>(pos_emb_offset_cpu.data<int32_t>());
|
||||
// bs_offset = real_batch.cpu[bs] * 2 * rope_max_seqlen +
|
||||
// start_tokens.cpu[bs];
|
||||
ret = api::scale(xpu_ctx,
|
||||
real_batch.xpu,
|
||||
pos_emb_offset_xpu_ptr,
|
||||
batch_size,
|
||||
true,
|
||||
2 * static_cast<int32_t>(rope_max_seqlen),
|
||||
0);
|
||||
PD_CHECK(ret == api::SUCCESS, "api::scale failed.");
|
||||
ret = api::broadcast_add<int32_t>(xpu_ctx,
|
||||
pos_emb_offset_xpu_ptr,
|
||||
start_tokens.xpu,
|
||||
pos_emb_offset_xpu_ptr,
|
||||
{pos_emb_offset.len},
|
||||
{pos_emb_offset.len});
|
||||
PD_CHECK(ret == api::SUCCESS, "api::broadcast_add failed.");
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
pos_emb_offset_cpu_ptr[i] =
|
||||
real_batch.cpu[i] * 2 * static_cast<int32_t>(rope_max_seqlen) +
|
||||
start_tokens.cpu[i];
|
||||
}
|
||||
pos_emb_offset.cpu = pos_emb_offset_cpu_ptr;
|
||||
pos_emb_offset.xpu = pos_emb_offset_xpu_ptr;
|
||||
} else {
|
||||
pos_emb_offset.cpu = start_tokens.cpu;
|
||||
pos_emb_offset.xpu = start_tokens.xpu;
|
||||
}
|
||||
if (!use_neox_rotary_style) {
|
||||
ret = infer_ops::vsl_rotary_embedding_gptj<TQKV, TR, TID>(
|
||||
xpu_ctx,
|
||||
@@ -288,8 +325,8 @@ void split_kvcache_encoder(api::Context* xpu_ctx,
|
||||
q_num_heads,
|
||||
head_dim,
|
||||
"BLHD",
|
||||
start_tokens,
|
||||
"NORMAL",
|
||||
pos_emb_offset,
|
||||
pos_emb_type,
|
||||
real_kv_num_heads,
|
||||
false);
|
||||
PD_CHECK(ret == api::SUCCESS, "vsl_rotary_embedding_gptj failed.");
|
||||
@@ -308,7 +345,7 @@ void split_kvcache_encoder(api::Context* xpu_ctx,
|
||||
head_dim,
|
||||
rope_head_dim,
|
||||
"BLHD",
|
||||
start_tokens,
|
||||
pos_emb_offset,
|
||||
"NORMAL",
|
||||
real_kv_num_heads,
|
||||
false);
|
||||
@@ -546,12 +583,7 @@ void split_kvcache_decoder(api::Context* xpu_ctx,
|
||||
bool use_neox_rotary_style) {
|
||||
int64_t real_kv_num_heads = (kv_num_heads == -1) ? q_num_heads : kv_num_heads;
|
||||
int ret;
|
||||
// TODO: spliced split kvcache should support rope3d
|
||||
if (FLAGS_decoder_splice && !rope_3d) {
|
||||
// not yet supported
|
||||
if (rope_3d) {
|
||||
PD_THROW("split_kvcache_decoder does not support rope_3d == true!");
|
||||
}
|
||||
if (FLAGS_decoder_splice) {
|
||||
if (std::is_same<TKV_CACHE, int8_t>::value &&
|
||||
(k_cache_scale_inv == nullptr || v_cache_scale_inv == nullptr)) {
|
||||
PD_THROW(
|
||||
@@ -591,6 +623,47 @@ void split_kvcache_decoder(api::Context* xpu_ctx,
|
||||
head_dim);
|
||||
PD_CHECK(ret == api::SUCCESS, "split_qkv_block failed.");
|
||||
|
||||
paddle::Tensor pos_emb_offset_xpu;
|
||||
paddle::Tensor pos_emb_offset_cpu;
|
||||
api::VectorParam<int32_t> pos_emb_offset;
|
||||
pos_emb_offset.len = start_tokens.len;
|
||||
if (rope_3d) {
|
||||
pos_emb_offset_xpu =
|
||||
paddle::empty({batch_size}, paddle::DataType::INT32, place);
|
||||
pos_emb_offset_cpu = paddle::empty(
|
||||
{batch_size}, paddle::DataType::INT32, paddle::CPUPlace());
|
||||
auto pos_emb_offset_xpu_ptr =
|
||||
const_cast<int32_t*>(pos_emb_offset_xpu.data<int32_t>());
|
||||
auto pos_emb_offset_cpu_ptr =
|
||||
const_cast<int32_t*>(pos_emb_offset_cpu.data<int32_t>());
|
||||
// bs_offset = real_batch.cpu[bs] * 2 * rope_max_seqlen +
|
||||
// start_tokens.cpu[bs];
|
||||
ret = api::scale(xpu_ctx,
|
||||
real_batch.xpu,
|
||||
pos_emb_offset_xpu_ptr,
|
||||
batch_size,
|
||||
true,
|
||||
2 * static_cast<int32_t>(rope_max_seqlen),
|
||||
0);
|
||||
PD_CHECK(ret == api::SUCCESS, "api::scale failed.");
|
||||
ret = api::broadcast_add<int32_t>(xpu_ctx,
|
||||
pos_emb_offset_xpu_ptr,
|
||||
start_tokens.xpu,
|
||||
pos_emb_offset_xpu_ptr,
|
||||
{pos_emb_offset.len},
|
||||
{pos_emb_offset.len});
|
||||
PD_CHECK(ret == api::SUCCESS, "api::broadcast_add failed.");
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
pos_emb_offset_cpu_ptr[i] =
|
||||
real_batch.cpu[i] * 2 * static_cast<int32_t>(rope_max_seqlen) +
|
||||
start_tokens.cpu[i];
|
||||
}
|
||||
pos_emb_offset.cpu = pos_emb_offset_cpu_ptr;
|
||||
pos_emb_offset.xpu = pos_emb_offset_xpu_ptr;
|
||||
} else {
|
||||
pos_emb_offset.cpu = start_tokens.cpu;
|
||||
pos_emb_offset.xpu = start_tokens.xpu;
|
||||
}
|
||||
if (!use_neox_rotary_style) {
|
||||
ret = infer_ops::vsl_rotary_embedding_gptj<TQKV, TR, TID>(
|
||||
xpu_ctx,
|
||||
@@ -605,8 +678,8 @@ void split_kvcache_decoder(api::Context* xpu_ctx,
|
||||
q_num_heads,
|
||||
head_dim,
|
||||
"BLHD",
|
||||
start_tokens,
|
||||
"NORMAL",
|
||||
pos_emb_offset,
|
||||
pos_emb_type,
|
||||
real_kv_num_heads,
|
||||
false);
|
||||
PD_CHECK(ret == api::SUCCESS, "vsl_rotary_embedding_gptj failed.");
|
||||
@@ -625,7 +698,7 @@ void split_kvcache_decoder(api::Context* xpu_ctx,
|
||||
head_dim,
|
||||
rope_head_dim,
|
||||
"BLHD",
|
||||
start_tokens,
|
||||
pos_emb_offset,
|
||||
"NORMAL",
|
||||
real_kv_num_heads,
|
||||
false);
|
||||
|
||||
Reference in New Issue
Block a user