[XPU] get_infer_param use inplace copy, remove block_tables abundant d2h copy (#7431)

* inplace_copy: encoder_batch_idx/decoder_batch_idx bs == 9 ok

* inplace_copy: encoder_seq_lod/decoder_seq_lod bs == 9 ok

* inplace_copy: all bs == 9 ok

* inplace_copy: all cpu bs == 9 ok

* inplace_copy: len_info_cpu bs == 9 ok

* finished and rm unused code

* prefix_block_tables reuse

* refine

* improve performance

* remove block_table copy to cpu

* fix unit test

* fix

* resolve conflict

* refine code

* fix

* fix

* fix

* fix

* fix

* try fix unit tests

* fix

* tmp save

* fix unit test

* get_infer_param try less return values

* add yinwei fix

---------

Co-authored-by: yinwei <yinwei_hust@163.com>
This commit is contained in:
RuohengMa
2026-04-22 11:01:32 +08:00
committed by GitHub
parent 2edb30c2d0
commit 36d47aa23e
10 changed files with 628 additions and 302 deletions
+101 -159
View File
@@ -25,7 +25,7 @@ namespace api = baidu::xpu::api;
void lod_to_slot_mapping(api::Context* xpu_ctx, void lod_to_slot_mapping(api::Context* xpu_ctx,
paddle::Place place, paddle::Place place,
const std::vector<int32_t>& block_table, const paddle::Tensor& block_table_xpu,
const std::vector<int32_t>& kv_seq_lod, const std::vector<int32_t>& kv_seq_lod,
const std::vector<int32_t>& start_tokens, const std::vector<int32_t>& start_tokens,
const std::vector<int32_t>& real_batch, const std::vector<int32_t>& real_batch,
@@ -35,10 +35,16 @@ void lod_to_slot_mapping(api::Context* xpu_ctx,
int32_t batch_size, int32_t batch_size,
int32_t max_num_blocks_per_seq, int32_t max_num_blocks_per_seq,
int32_t num_speculative_tokens) { int32_t num_speculative_tokens) {
if (token_num <= 0) { int32_t actual_token_num = kv_seq_lod[batch_size];
if (token_num <= 0 || actual_token_num <= 0) {
return; return;
} }
std::vector<int32_t> slot_mapping_vec(token_num, -1);
int ret;
std::vector<int32_t> block_table_idx_vec(actual_token_num);
std::vector<int32_t> seq_offset_vec(actual_token_num);
int32_t idx = 0; int32_t idx = 0;
// For each Batch // For each Batch
for (auto batch_ = 0; batch_ < batch_size; batch_++) { for (auto batch_ = 0; batch_ < batch_size; batch_++) {
@@ -47,20 +53,56 @@ void lod_to_slot_mapping(api::Context* xpu_ctx,
int32_t dst_batch_id = real_batch[batch_]; int32_t dst_batch_id = real_batch[batch_];
// for each token // for each token
for (auto seq_ = seq_start; seq_ < seq_start + seq_len; seq_++) { for (auto seq_ = seq_start; seq_ < seq_start + seq_len; seq_++) {
int32_t table_id = seq_ / block_size; block_table_idx_vec[idx] =
int32_t block_id = seq_ / block_size + dst_batch_id * max_num_blocks_per_seq;
block_table[dst_batch_id * max_num_blocks_per_seq + table_id]; seq_offset_vec[idx] = seq_ % block_size;
int32_t seq_offset = seq_ % block_size;
int32_t dst_token_offset = block_id * block_size + seq_offset;
slot_mapping_vec[idx] = dst_token_offset;
idx++; idx++;
} }
} }
int ret = api::do_host2device(xpu_ctx, auto block_table_idx =
slot_mapping_vec.data(), paddle::empty({actual_token_num}, paddle::DataType::INT32, place);
slot_mapping, auto seq_offset =
token_num * sizeof(int32_t)); paddle::empty({actual_token_num}, paddle::DataType::INT32, place);
ret = api::do_host2device(xpu_ctx,
block_table_idx_vec.data(),
block_table_idx.data<int32_t>(),
actual_token_num * sizeof(int32_t));
PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed."); PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed.");
ret = api::do_host2device(xpu_ctx,
seq_offset_vec.data(),
seq_offset.data<int32_t>(),
actual_token_num * sizeof(int32_t));
PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed.");
// int32_t block_id =
// block_table[dst_batch_id * max_num_blocks_per_seq + table_id];
auto block_id =
paddle::empty({actual_token_num}, paddle::DataType::INT32, place);
auto block_size_tensor =
paddle::full({1}, block_size, paddle::DataType::INT32, place);
ret = api::index_select<int32_t, int32_t>(xpu_ctx,
block_table_xpu.data<int32_t>(),
block_table_idx.data<int32_t>(),
block_id.data<int32_t>(),
{block_table_xpu.numel()},
actual_token_num,
0);
PD_CHECK(ret == api::SUCCESS, "api::index_select failed.");
// int32_t dst_token_offset = block_id * block_size + seq_offset;
ret = api::broadcast_mul<int32_t>(xpu_ctx,
block_id.data<int32_t>(),
block_size_tensor.data<int32_t>(),
block_id.data<int32_t>(),
{actual_token_num},
{1});
PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul failed.");
ret = api::broadcast_add<int32_t>(xpu_ctx,
block_id.data<int32_t>(),
seq_offset.data<int32_t>(),
slot_mapping,
{actual_token_num},
{actual_token_num});
PD_CHECK(ret == api::SUCCESS, "api::broadcast_add failed.");
} }
std::vector<paddle::Tensor> GetInferParam( std::vector<paddle::Tensor> GetInferParam(
@@ -68,6 +110,28 @@ std::vector<paddle::Tensor> GetInferParam(
const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& block_tables, const paddle::Tensor& block_tables,
paddle::Tensor& encoder_batch_map,
paddle::Tensor& decoder_batch_map,
paddle::Tensor& encoder_batch_idx,
paddle::Tensor& decoder_batch_idx,
paddle::Tensor& encoder_seq_lod,
paddle::Tensor& decoder_seq_lod,
paddle::Tensor& encoder_kv_lod,
paddle::Tensor& prefix_len,
paddle::Tensor& decoder_context_len,
paddle::Tensor& decoder_context_len_cache,
paddle::Tensor& prefix_block_tables,
paddle::Tensor& encoder_batch_map_cpu,
paddle::Tensor& decoder_batch_map_cpu,
paddle::Tensor& encoder_batch_idx_cpu,
paddle::Tensor& decoder_batch_idx_cpu,
paddle::Tensor& encoder_seq_lod_cpu,
paddle::Tensor& decoder_seq_lod_cpu,
paddle::Tensor& encoder_kv_lod_cpu,
paddle::Tensor& prefix_len_cpu,
paddle::Tensor& decoder_context_len_cpu,
paddle::Tensor& decoder_context_len_cache_cpu,
paddle::Tensor& len_info_cpu,
int block_size, int block_size,
int num_speculative_tokens) { int num_speculative_tokens) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
@@ -128,18 +192,18 @@ std::vector<paddle::Tensor> GetInferParam(
if (seq_lens_encoder_vec[i] > 0) { if (seq_lens_encoder_vec[i] > 0) {
enc_batch++; enc_batch++;
int seq_len = seq_lens_encoder_vec[i]; int seq_len = seq_lens_encoder_vec[i];
int prefix_len = seq_lens_decoder_vec[i]; int prefix_len_int = seq_lens_decoder_vec[i];
total_enc_len += seq_len; total_enc_len += seq_len;
max_seq_len = std::max(max_seq_len, seq_len); max_seq_len = std::max(max_seq_len, seq_len);
max_prefix_len = std::max(max_prefix_len, prefix_len); max_prefix_len = std::max(max_prefix_len, prefix_len_int);
max_kv_len = std::max(max_kv_len, seq_len + prefix_len); max_kv_len = std::max(max_kv_len, seq_len + prefix_len_int);
encoder_batch_map_vec[enc_batch - 1] = i; encoder_batch_map_vec[enc_batch - 1] = i;
encoder_batch_idx_vec[enc_batch - 1] = i - batch_offset; encoder_batch_idx_vec[enc_batch - 1] = i - batch_offset;
encoder_seq_lod_vec[enc_batch] = encoder_seq_lod_vec[enc_batch] =
seq_len + encoder_seq_lod_vec[enc_batch - 1]; seq_len + encoder_seq_lod_vec[enc_batch - 1];
encoder_kv_lod_vec[enc_batch] = encoder_kv_lod_vec[enc_batch] =
seq_len + prefix_len + encoder_kv_lod_vec[enc_batch - 1]; seq_len + prefix_len_int + encoder_kv_lod_vec[enc_batch - 1];
prefix_len_vec[enc_batch - 1] = prefix_len; prefix_len_vec[enc_batch - 1] = prefix_len_int;
} else if (seq_lens_decoder_vec[i] > 0 && seq_lens_this_time_vec[i] > 0) { } else if (seq_lens_decoder_vec[i] > 0 && seq_lens_this_time_vec[i] > 0) {
dec_batch++; dec_batch++;
max_dec_len = std::max(max_dec_len, seq_lens_this_time_vec[i]); max_dec_len = std::max(max_dec_len, seq_lens_this_time_vec[i]);
@@ -186,42 +250,6 @@ std::vector<paddle::Tensor> GetInferParam(
prefix_block_num_per_seq = -1; prefix_block_num_per_seq = -1;
} }
auto encoder_batch_map = paddle::empty({encoder_batch_map_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_batch_map = paddle::empty({decoder_batch_map_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto encoder_batch_idx = paddle::empty({encoder_batch_idx_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_batch_idx = paddle::empty({decoder_batch_idx_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto encoder_seq_lod = paddle::empty({encoder_seq_lod_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_seq_lod = paddle::empty({decoder_seq_lod_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto encoder_kv_lod = paddle::empty({encoder_kv_lod_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto prefix_len = paddle::empty({prefix_len_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_context_len = paddle::empty({decoder_context_len_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_context_len_cache =
paddle::empty({decoder_context_len_cache_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto prefix_block_tables =
paddle::empty({block_bs, block_num_per_seq}, // full size
seq_lens_encoder.type(),
seq_lens_encoder.place());
// for store_paged_kv_cache of cudagraph mode // for store_paged_kv_cache of cudagraph mode
// if slot_mapping is -1, store_paged_kv_cache will not write to kv cache // if slot_mapping is -1, store_paged_kv_cache will not write to kv cache
paddle::Tensor slot_mapping_enc = paddle::full( paddle::Tensor slot_mapping_enc = paddle::full(
@@ -232,21 +260,15 @@ std::vector<paddle::Tensor> GetInferParam(
-1, -1,
paddle::DataType::INT32, paddle::DataType::INT32,
seq_lens_decoder.place()); seq_lens_decoder.place());
if (FLAGS_encoder_splice || FLAGS_decoder_splice) {
std::vector<int32_t> block_tables_vec(block_bs * block_num_per_seq);
r = xpu_memcpy(block_tables_vec.data(),
block_tables.data<int32_t>(),
sizeof(int32_t) * block_bs * block_num_per_seq,
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
if (FLAGS_encoder_splice) { if (FLAGS_encoder_splice) {
lod_to_slot_mapping(xpu_ctx->x_context(), lod_to_slot_mapping(xpu_ctx->x_context(),
seq_lens_encoder.place(), seq_lens_encoder.place(),
block_tables_vec, block_tables,
encoder_seq_lod_vec, encoder_seq_lod_vec,
prefix_len_vec, prefix_len_vec,
encoder_batch_map_vec, encoder_batch_map_vec,
slot_mapping_enc.data<int32_t>(), slot_mapping_enc.data<int32_t>(),
total_enc_len, slot_mapping_enc.numel(),
block_size, block_size,
enc_batch, enc_batch,
block_num_per_seq, block_num_per_seq,
@@ -255,49 +277,17 @@ std::vector<paddle::Tensor> GetInferParam(
if (FLAGS_decoder_splice) { if (FLAGS_decoder_splice) {
lod_to_slot_mapping(xpu_ctx->x_context(), lod_to_slot_mapping(xpu_ctx->x_context(),
seq_lens_decoder.place(), seq_lens_decoder.place(),
block_tables_vec, block_tables,
decoder_seq_lod_vec, decoder_seq_lod_vec,
decoder_context_len_cache_vec, decoder_context_len_cache_vec,
decoder_batch_map_vec, decoder_batch_map_vec,
slot_mapping_dec.data<int32_t>(), slot_mapping_dec.data<int32_t>(),
bsz * (1 + num_speculative_tokens), slot_mapping_dec.numel(),
block_size, block_size,
dec_batch, dec_batch,
block_num_per_seq, block_num_per_seq,
num_speculative_tokens); num_speculative_tokens);
} }
}
auto encoder_batch_map_cpu = paddle::empty({encoder_batch_map_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto decoder_batch_map_cpu = paddle::empty({decoder_batch_map_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto encoder_batch_idx_cpu = paddle::empty({encoder_batch_idx_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto decoder_batch_idx_cpu = paddle::empty({decoder_batch_idx_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto encoder_seq_lod_cpu = paddle::empty({encoder_seq_lod_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto decoder_seq_lod_cpu = paddle::empty({decoder_seq_lod_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto encoder_kv_lod_cpu = paddle::empty(
{encoder_kv_lod_vec.size()}, seq_lens_encoder.type(), paddle::CPUPlace());
auto prefix_len_cpu = paddle::empty(
{prefix_len_vec.size()}, seq_lens_encoder.type(), paddle::CPUPlace());
auto decoder_context_len_cpu = paddle::empty({decoder_context_len_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto decoder_context_len_cache_cpu =
paddle::empty({decoder_context_len_cache_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
ret = api::do_host2device( ret = api::do_host2device(
xpu_ctx->x_context(), xpu_ctx->x_context(),
@@ -400,65 +390,25 @@ std::vector<paddle::Tensor> GetInferParam(
max_kv_len, max_kv_len,
prefix_block_num_per_seq, prefix_block_num_per_seq,
max_dec_len}; max_dec_len};
auto len_info_cpu =
paddle::empty({7}, seq_lens_encoder.type(), paddle::CPUPlace());
std::memcpy(len_info_cpu.data<int32_t>(), std::memcpy(len_info_cpu.data<int32_t>(),
len_info_vec.data(), len_info_vec.data(),
sizeof(int32_t) * len_info_vec.size()); sizeof(int32_t) * len_info_vec.size());
return {encoder_batch_map, return {slot_mapping_enc, slot_mapping_dec};
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
slot_mapping_enc,
slot_mapping_dec};
} }
std::vector<std::vector<int64_t>> GetInferParamInferShape( std::vector<std::vector<int64_t>> GetInferParamInferShape(
const std::vector<int64_t>& seq_lens_encoder_shape, const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape, const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape, const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& block_tables_shape) { const std::vector<int64_t>& block_tables_shape,
return {seq_lens_encoder_shape, int num_speculative_tokens) {
seq_lens_encoder_shape, // Return shapes for slot_mapping_enc and slot_mapping_dec
seq_lens_encoder_shape, // slot_mapping_enc shape depends on encoder token count (unknown at shape
seq_lens_encoder_shape, // inference time) slot_mapping_dec shape depends on batch size and
{seq_lens_encoder_shape[0] + 1}, // speculative token count
{seq_lens_encoder_shape[0] + 1}, return {{-1}, {seq_lens_encoder_shape[0] * (1 + num_speculative_tokens)}};
{seq_lens_encoder_shape[0] + 1},
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
block_tables_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
{seq_lens_encoder_shape[0] + 1},
{seq_lens_encoder_shape[0] + 1},
{seq_lens_encoder_shape[0] + 1},
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
{7}};
} }
std::vector<paddle::DataType> GetInferParamInferDtype( std::vector<paddle::DataType> GetInferParamInferDtype(
@@ -466,23 +416,16 @@ std::vector<paddle::DataType> GetInferParamInferDtype(
const paddle::DataType& seq_lens_decoder_dtype, const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype, const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& block_tables_dtype) { const paddle::DataType& block_tables_dtype) {
return { // Return dtypes for slot_mapping_enc and slot_mapping_dec (both INT32)
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype, return {paddle::DataType::INT32, paddle::DataType::INT32};
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, block_tables_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype};
} }
PD_BUILD_OP(get_infer_param) PD_BUILD_OP(get_infer_param)
.Inputs({"seq_lens_encoder", .Inputs({"seq_lens_encoder",
"seq_lens_decoder", "seq_lens_decoder",
"seq_lens_this_time", "seq_lens_this_time",
"block_tables"}) "block_tables",
.Outputs({"encoder_batch_map", "encoder_batch_map",
"decoder_batch_map", "decoder_batch_map",
"encoder_batch_idx", "encoder_batch_idx",
"decoder_batch_idx", "decoder_batch_idx",
@@ -503,9 +446,8 @@ PD_BUILD_OP(get_infer_param)
"prefix_len_cpu", "prefix_len_cpu",
"decoder_context_len_cpu", "decoder_context_len_cpu",
"decoder_context_len_cache_cpu", "decoder_context_len_cache_cpu",
"len_info_cpu", "len_info_cpu"})
"slot_mapping_enc", .Outputs({"slot_mapping_enc", "slot_mapping_dec"})
"slot_mapping_dec"})
.SetKernelFn(PD_KERNEL(GetInferParam)) .SetKernelFn(PD_KERNEL(GetInferParam))
.Attrs({"block_size: int", "num_speculative_tokens: int"}) .Attrs({"block_size: int", "num_speculative_tokens: int"})
.SetInferShapeFn(PD_INFER_SHAPE(GetInferParamInferShape)) .SetInferShapeFn(PD_INFER_SHAPE(GetInferParamInferShape))
@@ -478,6 +478,28 @@ std::vector<paddle::Tensor> GetInferParam(
const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& block_tables, const paddle::Tensor& block_tables,
paddle::Tensor& encoder_batch_map,
paddle::Tensor& decoder_batch_map,
paddle::Tensor& encoder_batch_idx,
paddle::Tensor& decoder_batch_idx,
paddle::Tensor& encoder_seq_lod,
paddle::Tensor& decoder_seq_lod,
paddle::Tensor& encoder_kv_lod,
paddle::Tensor& prefix_len,
paddle::Tensor& decoder_context_len,
paddle::Tensor& decoder_context_len_cache,
paddle::Tensor& prefix_block_tables,
paddle::Tensor& encoder_batch_map_cpu,
paddle::Tensor& decoder_batch_map_cpu,
paddle::Tensor& encoder_batch_idx_cpu,
paddle::Tensor& decoder_batch_idx_cpu,
paddle::Tensor& encoder_seq_lod_cpu,
paddle::Tensor& decoder_seq_lod_cpu,
paddle::Tensor& encoder_kv_lod_cpu,
paddle::Tensor& prefix_len_cpu,
paddle::Tensor& decoder_context_len_cpu,
paddle::Tensor& decoder_context_len_cache_cpu,
paddle::Tensor& len_info_cpu,
int block_size, int block_size,
int num_speculative_tokens); int num_speculative_tokens);
@@ -1052,6 +1074,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("seq_lens_decoder"), py::arg("seq_lens_decoder"),
py::arg("seq_lens_this_time"), py::arg("seq_lens_this_time"),
py::arg("block_tables"), py::arg("block_tables"),
py::arg("encoder_batch_map"),
py::arg("decoder_batch_map"),
py::arg("encoder_batch_idx"),
py::arg("decoder_batch_idx"),
py::arg("encoder_seq_lod"),
py::arg("decoder_seq_lod"),
py::arg("encoder_kv_lod"),
py::arg("prefix_len"),
py::arg("decoder_context_len"),
py::arg("decoder_context_len_cache"),
py::arg("prefix_block_tables"),
py::arg("encoder_batch_map_cpu"),
py::arg("decoder_batch_map_cpu"),
py::arg("encoder_batch_idx_cpu"),
py::arg("decoder_batch_idx_cpu"),
py::arg("encoder_seq_lod_cpu"),
py::arg("decoder_seq_lod_cpu"),
py::arg("encoder_kv_lod_cpu"),
py::arg("prefix_len_cpu"),
py::arg("decoder_context_len_cpu"),
py::arg("decoder_context_len_cache_cpu"),
py::arg("len_info_cpu"),
py::arg("block_size"), py::arg("block_size"),
py::arg("num_speculative_tokens"), py::arg("num_speculative_tokens"),
"Get infer parameters for block attention in XPU"); "Get infer parameters for block attention in XPU");
@@ -16,6 +16,7 @@ import unittest # 导入 unittest
import numpy as np import numpy as np
import paddle import paddle
from utils import init_inplace_tensor
from fastdeploy.model_executor.ops.xpu import ( from fastdeploy.model_executor.ops.xpu import (
adjust_batch, adjust_batch,
@@ -33,11 +34,8 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
seq_lens_this_time = paddle.to_tensor(seq_lens_this_time_data, dtype="int32") seq_lens_this_time = paddle.to_tensor(seq_lens_this_time_data, dtype="int32")
bsz = seq_lens_this_time.shape[0] bsz = seq_lens_this_time.shape[0]
cum_offsets = paddle.zeros(bsz, dtype="int32")
block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8)) block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8))
infer_params = get_infer_param(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64)
( (
encoder_batch_map, encoder_batch_map,
decoder_batch_map, decoder_batch_map,
@@ -45,23 +43,56 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
decoder_batch_idx, decoder_batch_idx,
encoder_seq_lod, encoder_seq_lod,
decoder_seq_lod, decoder_seq_lod,
_, encoder_kv_lod,
_, prefix_len,
_, decoder_context_len,
_, decoder_context_len_cache,
_, prefix_block_tables,
encoder_batch_map_cpu, encoder_batch_map_cpu,
decoder_batch_map_cpu, decoder_batch_map_cpu,
encoder_batch_idx_cpu, encoder_batch_idx_cpu,
decoder_batch_idx_cpu, decoder_batch_idx_cpu,
encoder_seq_lod_cpu, encoder_seq_lod_cpu,
decoder_seq_lod_cpu, decoder_seq_lod_cpu,
_, encoder_kv_lod_cpu,
_, prefix_len_cpu,
_, decoder_context_len_cpu,
_, decoder_context_len_cache_cpu,
len_info_cpu, len_info_cpu,
) = infer_params ) = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape)
(
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
block_table,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
0,
)
token_num = seq_lens_this_time.sum().cpu().item() token_num = seq_lens_this_time.sum().cpu().item()
hidden_dim = 8192 hidden_dim = 8192
@@ -72,7 +103,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
# 测试 adjust_batch # 测试 adjust_batch
adjusted_output = adjust_batch( adjusted_output = adjust_batch(
input_tensor, input_tensor,
cum_offsets,
encoder_seq_lod, encoder_seq_lod,
decoder_seq_lod, decoder_seq_lod,
encoder_batch_idx, encoder_batch_idx,
@@ -88,7 +118,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
adjusted_output_cpu = adjust_batch( adjusted_output_cpu = adjust_batch(
input_tensor.cpu(), input_tensor.cpu(),
cum_offsets,
encoder_seq_lod, encoder_seq_lod,
decoder_seq_lod, decoder_seq_lod,
encoder_batch_idx, encoder_batch_idx,
@@ -110,7 +139,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
# 测试 gather_next_token # 测试 gather_next_token
gather_out = gather_next_token( gather_out = gather_next_token(
adjusted_output, adjusted_output,
cum_offsets,
encoder_seq_lod, encoder_seq_lod,
decoder_seq_lod, decoder_seq_lod,
encoder_batch_map, encoder_batch_map,
@@ -126,7 +154,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
gather_out_cpu = gather_next_token( gather_out_cpu = gather_next_token(
adjusted_output.cpu(), adjusted_output.cpu(),
cum_offsets,
encoder_seq_lod, encoder_seq_lod,
decoder_seq_lod, decoder_seq_lod,
encoder_batch_map, encoder_batch_map,
@@ -16,6 +16,7 @@ import unittest
import numpy as np import numpy as np
import paddle import paddle
from utils import init_inplace_tensor
from fastdeploy.model_executor.ops.xpu import ( from fastdeploy.model_executor.ops.xpu import (
adjust_batch, adjust_batch,
@@ -33,8 +34,6 @@ def _run_test_base(seq_lens_this_time_data):
cum_offsets = paddle.zeros(bsz, dtype="int32") cum_offsets = paddle.zeros(bsz, dtype="int32")
block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8)) block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8))
infer_params = get_infer_param(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64)
( (
encoder_batch_map, encoder_batch_map,
decoder_batch_map, decoder_batch_map,
@@ -42,23 +41,56 @@ def _run_test_base(seq_lens_this_time_data):
decoder_batch_idx, decoder_batch_idx,
encoder_seq_lod, encoder_seq_lod,
decoder_seq_lod, decoder_seq_lod,
_, encoder_kv_lod,
_, prefix_len,
_, decoder_context_len,
_, decoder_context_len_cache,
_, prefix_block_tables,
encoder_batch_map_cpu, encoder_batch_map_cpu,
decoder_batch_map_cpu, decoder_batch_map_cpu,
encoder_batch_idx_cpu, encoder_batch_idx_cpu,
decoder_batch_idx_cpu, decoder_batch_idx_cpu,
encoder_seq_lod_cpu, encoder_seq_lod_cpu,
decoder_seq_lod_cpu, decoder_seq_lod_cpu,
_, encoder_kv_lod_cpu,
_, prefix_len_cpu,
_, decoder_context_len_cpu,
_, decoder_context_len_cache_cpu,
len_info_cpu, len_info_cpu,
) = infer_params ) = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape)
(
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
block_table,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
0,
)
token_num = seq_lens_this_time.sum().cpu().item() token_num = seq_lens_this_time.sum().cpu().item()
hidden_dim = 8192 hidden_dim = 8192
@@ -68,7 +100,6 @@ def _run_test_base(seq_lens_this_time_data):
# test adjust_batch # test adjust_batch
adjusted_output = adjust_batch( adjusted_output = adjust_batch(
input_tensor, input_tensor,
cum_offsets,
encoder_seq_lod, encoder_seq_lod,
decoder_seq_lod, decoder_seq_lod,
encoder_batch_idx, encoder_batch_idx,
@@ -84,7 +115,6 @@ def _run_test_base(seq_lens_this_time_data):
adjusted_output_cpu = adjust_batch( adjusted_output_cpu = adjust_batch(
input_tensor.cpu(), input_tensor.cpu(),
cum_offsets,
encoder_seq_lod, encoder_seq_lod,
decoder_seq_lod, decoder_seq_lod,
encoder_batch_idx, encoder_batch_idx,
+64 -3
View File
@@ -16,6 +16,7 @@ import random
import numpy as np import numpy as np
import paddle import paddle
from utils import init_inplace_tensor
# block_attn_fused is deprecated and should be removed in the future # block_attn_fused is deprecated and should be removed in the future
from fastdeploy.model_executor.ops.xpu import ( from fastdeploy.model_executor.ops.xpu import (
@@ -76,6 +77,7 @@ def run_prefix_cache_block_attn(
# prefix cache block attn # prefix cache block attn
seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32") seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32")
seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32") seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
( (
encoder_batch_map, encoder_batch_map,
decoder_batch_map, decoder_batch_map,
@@ -99,11 +101,40 @@ def run_prefix_cache_block_attn(
decoder_context_len_cpu, decoder_context_len_cpu,
decoder_context_len_cache_cpu, decoder_context_len_cache_cpu,
len_info_cpu, len_info_cpu,
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
(
slot_mapping_enc, slot_mapping_enc,
slot_mapping_dec, slot_mapping_dec,
) = get_infer_param( ) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens seq_lens_encoder,
) # block_size seq_lens_decoder,
seq_lens_this_time,
block_tables,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
num_speculative_tokens,
)
qkv_prefix = qkv[hit_prefix_len:] qkv_prefix = qkv[hit_prefix_len:]
attn_out_prefix_cache = block_attn_func( attn_out_prefix_cache = block_attn_func(
qkv_prefix, qkv_prefix,
@@ -194,6 +225,7 @@ def run_block_attn(
seq_lens_this_time = paddle.to_tensor([seq_len, 0, 0, 0, 0], dtype="int32") seq_lens_this_time = paddle.to_tensor([seq_len, 0, 0, 0, 0], dtype="int32")
block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32") block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32")
block_tables = block_tables.reshape((block_batch, max_block_per_seq)) block_tables = block_tables.reshape((block_batch, max_block_per_seq))
( (
encoder_batch_map, encoder_batch_map,
decoder_batch_map, decoder_batch_map,
@@ -217,10 +249,39 @@ def run_block_attn(
decoder_context_len_cpu, decoder_context_len_cpu,
decoder_context_len_cache_cpu, decoder_context_len_cache_cpu,
len_info_cpu, len_info_cpu,
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
(
slot_mapping_enc, slot_mapping_enc,
slot_mapping_dec, slot_mapping_dec,
) = get_infer_param( ) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
block_tables,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
num_speculative_tokens,
) )
qkv = paddle.uniform( qkv = paddle.uniform(
shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], dtype="bfloat16", min=-1.0, max=1.0, seed=seed shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], dtype="bfloat16", min=-1.0, max=1.0, seed=seed
@@ -14,6 +14,7 @@
import numpy as np import numpy as np
import paddle import paddle
from utils import init_inplace_tensor
from fastdeploy.model_executor.ops.xpu import block_attn_fused, get_infer_param from fastdeploy.model_executor.ops.xpu import block_attn_fused, get_infer_param
@@ -24,6 +25,7 @@ seq_len = 128
block_batch = 5 block_batch = 5
max_block_per_seq = 128 max_block_per_seq = 128
block_size = 64 block_size = 64
num_speculative_tokens = 0
seq_lens_encoder = paddle.to_tensor([128, 0, 0, 0, 0], dtype="int32") seq_lens_encoder = paddle.to_tensor([128, 0, 0, 0, 0], dtype="int32")
seq_lens_decoder = paddle.to_tensor([0, 0, 0, 0, 0], dtype="int32") seq_lens_decoder = paddle.to_tensor([0, 0, 0, 0, 0], dtype="int32")
@@ -53,11 +55,40 @@ block_tables = block_tables.reshape((block_batch, max_block_per_seq))
decoder_context_len_cpu, decoder_context_len_cpu,
decoder_context_len_cache_cpu, decoder_context_len_cache_cpu,
len_info_cpu, len_info_cpu,
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
(
slot_mapping_enc, slot_mapping_enc,
slot_mapping_dec, slot_mapping_dec,
) = get_infer_param( ) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0 seq_lens_encoder,
) # block_size seq_lens_decoder,
seq_lens_this_time,
block_tables,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
num_speculative_tokens,
)
qkv = paddle.uniform( qkv = paddle.uniform(
shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim],
@@ -247,11 +278,40 @@ seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
decoder_context_len_cpu, decoder_context_len_cpu,
decoder_context_len_cache_cpu, decoder_context_len_cache_cpu,
len_info_cpu, len_info_cpu,
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
(
slot_mapping_enc, slot_mapping_enc,
slot_mapping_dec, slot_mapping_dec,
) = get_infer_param( ) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0 seq_lens_encoder,
) # block_size seq_lens_decoder,
seq_lens_this_time,
block_tables,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
num_speculative_tokens,
)
qkv_prefix = qkv[hit_prefix_len:] qkv_prefix = qkv[hit_prefix_len:]
attn_out_prefix_cache = block_attn_fused( attn_out_prefix_cache = block_attn_fused(
@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import paddle import paddle
from utils import init_inplace_tensor
from fastdeploy.model_executor.ops.xpu import get_infer_param from fastdeploy.model_executor.ops.xpu import get_infer_param
@@ -21,6 +22,7 @@ seq_lens_decoder = paddle.to_tensor([0, 5, 0, 25, 64], dtype="int32")
seq_lens_this_time = paddle.to_tensor([100, 1, 0, 1, 300], dtype="int32") seq_lens_this_time = paddle.to_tensor([100, 1, 0, 1, 300], dtype="int32")
block_table = paddle.arange(0, 40, dtype="int32") block_table = paddle.arange(0, 40, dtype="int32")
block_table = block_table.reshape((5, 8)) block_table = block_table.reshape((5, 8))
( (
encoder_batch_map, encoder_batch_map,
decoder_batch_map, decoder_batch_map,
@@ -44,9 +46,40 @@ block_table = block_table.reshape((5, 8))
decoder_context_len_cpu, decoder_context_len_cpu,
decoder_context_len_cache_cpu, decoder_context_len_cache_cpu,
len_info_cpu, len_info_cpu,
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape)
(
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param( ) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64 seq_lens_encoder,
) # block_size seq_lens_decoder,
seq_lens_this_time,
block_table,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
0,
)
print("block_table", block_table) print("block_table", block_table)
print("encoder_batch_map", encoder_batch_map) # [0, 4, 0, 0, 0] print("encoder_batch_map", encoder_batch_map) # [0, 4, 0, 0, 0]
+54
View File
@@ -0,0 +1,54 @@
import paddle
def init_inplace_tensor(bsz, block_tables_shape):
    """Allocate the scratch tensors that get_infer_param fills in place.

    The XPU kernel writes its batch maps / LoD tables directly into these
    pre-allocated buffers (plus mirrored CPU-side copies), which avoids the
    per-step allocation and device-to-host copies the op used to perform.

    Args:
        bsz: batch size; per-batch buffers get shape ``bsz`` and the LoD
            buffers get ``bsz + 1`` (one extra slot for the exclusive end).
        block_tables_shape: shape of the block table, reused for the
            ``prefix_block_tables`` buffer.

    Returns:
        A 22-tuple of uninitialized int32 tensors, in the exact positional
        order expected by ``get_infer_param``: 11 device buffers followed by
        10 CPU mirrors and the 7-element ``len_info_cpu`` summary buffer.
    """

    def _dev(shape):
        # Device-side (XPU) int32 buffer.
        return paddle.empty(shape, dtype="int32")

    def _cpu(shape):
        # Host-side mirror used by the op's d2h in-place copies.
        return paddle.empty(shape, dtype="int32", device="cpu")

    lod_len = bsz + 1  # LoD arrays carry bsz + 1 cumulative entries
    return (
        _dev(bsz),                 # encoder_batch_map
        _dev(bsz),                 # decoder_batch_map
        _dev(bsz),                 # encoder_batch_idx
        _dev(bsz),                 # decoder_batch_idx
        _dev(lod_len),             # encoder_seq_lod
        _dev(lod_len),             # decoder_seq_lod
        _dev(lod_len),             # encoder_kv_lod
        _dev(bsz),                 # prefix_len
        _dev(bsz),                 # decoder_context_len
        _dev(bsz),                 # decoder_context_len_cache
        _dev(block_tables_shape),  # prefix_block_tables
        _cpu(bsz),                 # encoder_batch_map_cpu
        _cpu(bsz),                 # decoder_batch_map_cpu
        _cpu(bsz),                 # encoder_batch_idx_cpu
        _cpu(bsz),                 # decoder_batch_idx_cpu
        _cpu(lod_len),             # encoder_seq_lod_cpu
        _cpu(lod_len),             # decoder_seq_lod_cpu
        _cpu(lod_len),             # encoder_kv_lod_cpu
        _cpu(bsz),                 # prefix_len_cpu
        _cpu(bsz),                 # decoder_context_len_cpu
        _cpu(bsz),                 # decoder_context_len_cache_cpu
        _cpu(7),                   # len_info_cpu
    )
+27
View File
@@ -289,6 +289,33 @@ class XPUForwardMeta(ForwardMeta):
# #
slot_mapping_dec: Optional[paddle.Tensor] = None slot_mapping_dec: Optional[paddle.Tensor] = None
def init_inplace_tensor(self, bsz, block_tables_shape):
    """Pre-allocate the buffers that get_infer_param fills via in-place copy.

    Called once per forward-meta setup so the XPU op can write its batch
    maps / LoD tables directly into persistent tensors instead of
    allocating fresh outputs (and doing extra d2h copies) every step.

    Args:
        bsz: batch size; plain buffers are length ``bsz``, LoD buffers
            ``bsz + 1`` (cumulative-offset arrays need one extra slot).
        block_tables_shape: shape of the block table, mirrored for
            ``prefix_block_tables``.
    """
    # NOTE(review): this relies on paddle.empty accepting a bare int shape
    # and a `device=` kwarg — confirm against the paddle version in use.
    # Device-side (XPU) buffers consumed by the attention kernels.
    self.encoder_batch_map = paddle.empty(bsz, dtype="int32")
    self.decoder_batch_map = paddle.empty(bsz, dtype="int32")
    self.encoder_batch_idx = paddle.empty(bsz, dtype="int32")
    self.decoder_batch_idx = paddle.empty(bsz, dtype="int32")
    self.encoder_seq_lod = paddle.empty(bsz + 1, dtype="int32")
    self.decoder_seq_lod = paddle.empty(bsz + 1, dtype="int32")
    self.encoder_kv_lod = paddle.empty(bsz + 1, dtype="int32")
    self.prefix_len = paddle.empty(bsz, dtype="int32")
    self.decoder_context_len = paddle.empty(bsz, dtype="int32")
    self.decoder_context_len_cache = paddle.empty(bsz, dtype="int32")
    self.prefix_block_tables = paddle.empty(block_tables_shape, dtype="int32")
    # Host-side mirrors the op updates via in-place d2h copies, so Python
    # code can read lengths without extra synchronizing copies.
    self.encoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
    self.decoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
    self.encoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
    self.decoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
    self.encoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
    self.decoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
    self.encoder_kv_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
    self.prefix_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
    self.decoder_context_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
    self.decoder_context_len_cache_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
    # 7-element summary (enc/dec batch counts, total lengths, ...) read on CPU.
    self.len_info_cpu = paddle.empty(7, dtype="int32", device="cpu")
def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None): def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None):
""" """
Synchronize attributes from another XPUForwardMeta object Synchronize attributes from another XPUForwardMeta object
@@ -158,6 +158,23 @@ def xpu_pre_process(
share_inputs["cu_seqlens_q"] = cu_seqlens_q share_inputs["cu_seqlens_q"] = cu_seqlens_q
share_inputs["cu_seqlens_k"] = cu_seqlens_k share_inputs["cu_seqlens_k"] = cu_seqlens_k
if use_cudagraph and forward_meta is not None:
forward_meta.ids_remove_padding.copy_(share_inputs["ids_remove_padding"], False)
forward_meta.rotary_embs.copy_(share_inputs["rope_emb"], False)
forward_meta.attn_backend = None
forward_meta.seq_lens_encoder.copy_(share_inputs["seq_lens_encoder"], False)
forward_meta.seq_lens_decoder.copy_(share_inputs["seq_lens_decoder"], False)
forward_meta.seq_lens_this_time.copy_(share_inputs["seq_lens_this_time"], False)
forward_meta.batch_id_per_token.copy_(share_inputs["batch_id_per_token"], False)
forward_meta.cu_seqlens_q.copy_(share_inputs["cu_seqlens_q"], False)
forward_meta.cu_seqlens_k.copy_(share_inputs["cu_seqlens_k"], False)
forward_meta.block_tables.copy_(share_inputs["block_tables"], False)
forward_meta.caches = share_inputs["caches"]
forward_meta.max_num_seqs = share_inputs["seq_lens_this_time"].shape[0]
forward_meta.is_speculative = use_speculate_method
xpu_forward_meta = forward_meta
else:
xpu_forward_meta = XPUForwardMeta( xpu_forward_meta = XPUForwardMeta(
ids_remove_padding=share_inputs["ids_remove_padding"], ids_remove_padding=share_inputs["ids_remove_padding"],
rotary_embs=share_inputs["rope_emb"], rotary_embs=share_inputs["rope_emb"],
@@ -173,55 +190,81 @@ def xpu_pre_process(
max_num_seqs=share_inputs["seq_lens_this_time"].shape[0], max_num_seqs=share_inputs["seq_lens_this_time"].shape[0],
is_speculative=use_speculate_method, is_speculative=use_speculate_method,
) )
xpu_forward_meta.init_inplace_tensor(seq_lens_encoder.shape[0], share_inputs["block_tables"].shape)
block_tables = xpu_forward_meta.block_tables
encoder_batch_map = xpu_forward_meta.encoder_batch_map
decoder_batch_map = xpu_forward_meta.decoder_batch_map
encoder_batch_idx = xpu_forward_meta.encoder_batch_idx
decoder_batch_idx = xpu_forward_meta.decoder_batch_idx
encoder_seq_lod = xpu_forward_meta.encoder_seq_lod
decoder_seq_lod = xpu_forward_meta.decoder_seq_lod
encoder_kv_lod = xpu_forward_meta.encoder_kv_lod
prefix_len = xpu_forward_meta.prefix_len
decoder_context_len = xpu_forward_meta.decoder_context_len
decoder_context_len_cache = xpu_forward_meta.decoder_context_len_cache
prefix_block_tables = xpu_forward_meta.prefix_block_tables
encoder_batch_map_cpu = xpu_forward_meta.encoder_batch_map_cpu
decoder_batch_map_cpu = xpu_forward_meta.decoder_batch_map_cpu
encoder_batch_idx_cpu = xpu_forward_meta.encoder_batch_idx_cpu
decoder_batch_idx_cpu = xpu_forward_meta.decoder_batch_idx_cpu
encoder_seq_lod_cpu = xpu_forward_meta.encoder_seq_lod_cpu
decoder_seq_lod_cpu = xpu_forward_meta.decoder_seq_lod_cpu
encoder_kv_lod_cpu = xpu_forward_meta.encoder_kv_lod_cpu
prefix_len_cpu = xpu_forward_meta.prefix_len_cpu
decoder_context_len_cpu = xpu_forward_meta.decoder_context_len_cpu
decoder_context_len_cache_cpu = xpu_forward_meta.decoder_context_len_cache_cpu
len_info_cpu = xpu_forward_meta.len_info_cpu
( (
xpu_forward_meta.encoder_batch_map, slot_mapping_enc,
xpu_forward_meta.decoder_batch_map, slot_mapping_dec,
xpu_forward_meta.encoder_batch_idx,
xpu_forward_meta.decoder_batch_idx,
xpu_forward_meta.encoder_seq_lod,
xpu_forward_meta.decoder_seq_lod,
xpu_forward_meta.encoder_kv_lod,
xpu_forward_meta.prefix_len,
xpu_forward_meta.decoder_context_len,
xpu_forward_meta.decoder_context_len_cache,
xpu_forward_meta.prefix_block_tables,
xpu_forward_meta.encoder_batch_map_cpu,
xpu_forward_meta.decoder_batch_map_cpu,
xpu_forward_meta.encoder_batch_idx_cpu,
xpu_forward_meta.decoder_batch_idx_cpu,
xpu_forward_meta.encoder_seq_lod_cpu,
xpu_forward_meta.decoder_seq_lod_cpu,
xpu_forward_meta.encoder_kv_lod_cpu,
xpu_forward_meta.prefix_len_cpu,
xpu_forward_meta.decoder_context_len_cpu,
xpu_forward_meta.decoder_context_len_cache_cpu,
xpu_forward_meta.len_info_cpu,
xpu_forward_meta.slot_mapping_enc,
xpu_forward_meta.slot_mapping_dec,
) = get_infer_param( ) = get_infer_param(
seq_lens_encoder, seq_lens_encoder,
seq_lens_decoder, seq_lens_decoder,
seq_lens_this_time, seq_lens_this_time,
xpu_forward_meta.block_tables, block_tables,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
block_size, block_size,
num_speculative_tokens, num_speculative_tokens,
) )
xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0]
xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1]
xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2]
adjusted_input = adjust_batch( adjusted_input = adjust_batch(
ids_remove_padding.reshape([-1, 1]), ids_remove_padding.reshape([-1, 1]),
xpu_forward_meta.encoder_seq_lod, encoder_seq_lod,
xpu_forward_meta.decoder_seq_lod, decoder_seq_lod,
xpu_forward_meta.encoder_batch_idx, encoder_batch_idx,
xpu_forward_meta.decoder_batch_idx, decoder_batch_idx,
xpu_forward_meta.encoder_seq_lod_cpu, encoder_seq_lod_cpu,
xpu_forward_meta.decoder_seq_lod_cpu, decoder_seq_lod_cpu,
xpu_forward_meta.encoder_batch_idx_cpu, encoder_batch_idx_cpu,
xpu_forward_meta.decoder_batch_idx_cpu, decoder_batch_idx_cpu,
xpu_forward_meta.len_info_cpu, len_info_cpu,
None, # output_padding_offset None, # output_padding_offset
-1, # max bs -1, # max bs
) )
@@ -229,16 +272,21 @@ def xpu_pre_process(
adjusted_input = adjusted_input.squeeze(1) adjusted_input = adjusted_input.squeeze(1)
share_inputs["ids_remove_padding"].copy_(adjusted_input, False) share_inputs["ids_remove_padding"].copy_(adjusted_input, False)
xpu_forward_meta.enc_batch = len_info_cpu[0]
xpu_forward_meta.dec_batch = len_info_cpu[1]
xpu_forward_meta.total_enc_len = len_info_cpu[2]
xpu_forward_meta.ids_remove_padding = adjusted_input xpu_forward_meta.ids_remove_padding = adjusted_input
# Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends # Set xpu_forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
xpu_forward_meta.is_profiling = is_profiling xpu_forward_meta.is_profiling = is_profiling
if use_cudagraph:
if forward_meta is None: # prefill does not use cudagraph, inplace copy is not needed
return xpu_forward_meta xpu_forward_meta.slot_mapping_enc = slot_mapping_enc
else: if use_cudagraph and forward_meta is not None:
forward_meta.copy_from(xpu_forward_meta) xpu_forward_meta.slot_mapping_dec.copy_(slot_mapping_dec, False)
return forward_meta
else: else:
xpu_forward_meta.slot_mapping_dec = slot_mapping_dec
return xpu_forward_meta return xpu_forward_meta