mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] Refactor get_padding_offset to single kernel. (#7029)
* [XPU] Refactor get_padding_offset to single kernel.
* Add unit test.
* Fix code style.
* Remove cum_offsets_now.
* Remove max_len.
This commit is contained in:
@@ -16,29 +16,29 @@
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &token_num,
|
||||
const paddle::Tensor &seq_len) {
|
||||
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& seq_len,
|
||||
const int64_t cpu_token_num) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
|
||||
std::vector<int64_t> input_ids_shape = input_ids.shape();
|
||||
const int bsz = seq_len.shape()[0];
|
||||
const int seq_length = input_ids_shape[1];
|
||||
auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
|
||||
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
|
||||
const int token_num_data = static_cast<int>(cpu_token_num);
|
||||
|
||||
const int token_num_data = cpu_token_num.data<int64_t>()[0];
|
||||
auto x_remove_padding = paddle::full(
|
||||
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
|
||||
auto batch_id_per_token = paddle::full(
|
||||
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
auto cum_offsets_out =
|
||||
paddle::full({bsz}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_q =
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_k =
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
|
||||
if (token_num_data > 0) {
|
||||
int r =
|
||||
fastdeploy::plugin::get_padding_offset(xpu_ctx->x_context(),
|
||||
@@ -48,7 +48,6 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
||||
cu_seqlens_k.data<int>(),
|
||||
x_remove_padding.data<int64_t>(),
|
||||
input_ids.data<int64_t>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_len.data<int>(),
|
||||
seq_length,
|
||||
bsz,
|
||||
@@ -64,20 +63,15 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
|
||||
const std::vector<int64_t> &input_ids_shape,
|
||||
const std::vector<int64_t> &cum_offsets_shape,
|
||||
const std::vector<int64_t> &token_num_shape,
|
||||
const std::vector<int64_t> &seq_len_shape) {
|
||||
const std::vector<int64_t>& input_ids_shape,
|
||||
const std::vector<int64_t>& seq_len_shape) {
|
||||
int64_t bsz = seq_len_shape[0];
|
||||
int64_t seq_len = input_ids_shape[1];
|
||||
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
||||
const paddle::DataType &input_ids_dtype,
|
||||
const paddle::DataType &cum_offsets_dtype,
|
||||
const paddle::DataType &token_num_dtype,
|
||||
const paddle::DataType &seq_len_dtype) {
|
||||
const paddle::DataType& input_ids_dtype,
|
||||
const paddle::DataType& seq_len_dtype) {
|
||||
return {input_ids_dtype,
|
||||
seq_len_dtype,
|
||||
seq_len_dtype,
|
||||
@@ -86,12 +80,13 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
||||
}
|
||||
|
||||
PD_BUILD_OP(get_padding_offset)
|
||||
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
|
||||
.Inputs({"input_ids", "seq_len"})
|
||||
.Outputs({"x_remove_padding",
|
||||
"cum_offsets_out",
|
||||
"batch_id_per_token",
|
||||
"cu_seqlens_q",
|
||||
"cu_seqlens_k"})
|
||||
.Attrs({"cpu_token_num: int64_t"})
|
||||
.SetKernelFn(PD_KERNEL(GetPaddingOffset))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype));
|
||||
|
||||
@@ -453,9 +453,8 @@ void GetOutputEPDynamic(const paddle::Tensor& x,
|
||||
int msg_queue_id);
|
||||
|
||||
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& token_num,
|
||||
const paddle::Tensor& seq_len);
|
||||
const paddle::Tensor& seq_len,
|
||||
const int64_t cpu_token_num);
|
||||
|
||||
void GetStopFlagsMulti(const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& stop_flags,
|
||||
@@ -996,9 +995,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("get_padding_offset",
|
||||
&GetPaddingOffset,
|
||||
py::arg("input_ids"),
|
||||
py::arg("cum_offsets"),
|
||||
py::arg("token_num"),
|
||||
py::arg("seq_len"),
|
||||
py::arg("cpu_token_num"),
|
||||
"get padding offset function");
|
||||
|
||||
m.def("init_kv_signal_per_query",
|
||||
|
||||
@@ -68,13 +68,12 @@ DLL_EXPORT int token_penalty_multi_scores(api::Context* ctx,
|
||||
const int64_t length_bad_words);
|
||||
|
||||
DLL_EXPORT int get_padding_offset(api::Context* ctx,
|
||||
int* padding_offset,
|
||||
int* batch_id_per_token,
|
||||
int* cum_offsets_out,
|
||||
int* cu_seqlens_q,
|
||||
int* cu_seqlens_k,
|
||||
int64_t* x_remove_padding,
|
||||
const int64_t* input_ids,
|
||||
const int* cum_offsets,
|
||||
const int* seq_lens,
|
||||
const int max_seq_len,
|
||||
const int bs,
|
||||
|
||||
@@ -1,50 +1,123 @@
|
||||
#include "xpu/kernel/cluster.h"
|
||||
#include "xpu/kernel/cluster_partition.h"
|
||||
#include "xpu/kernel/cluster_primitive.h"
|
||||
#include "xpu/kernel/cluster_simd.h"
|
||||
|
||||
namespace fd_xpu3 {
|
||||
|
||||
__global__ void get_padding_offset(int *batch_id_per_token,
|
||||
int *cum_offsets_out,
|
||||
int *cu_seqlens_q,
|
||||
int *cu_seqlens_k,
|
||||
const int *cum_offsets,
|
||||
const int *seq_lens,
|
||||
#define MAX_BATCH_SIZE 1024
|
||||
|
||||
// Folds the vector `v0` onto itself four times and returns one extracted lane.
//
// Each step adds `v0` to the result of `vsrlp_int32x16` applied to `v0` with
// the arguments 1<<8, 1<<7, 1<<6 and 1<<5, in that order. This looks like a
// halving tree reduction that yields the sum of all 16 lanes -- TODO confirm
// the shift semantics of `vsrlp_int32x16` and the lane numbering of
// `vextract_int32x16` against the XPU intrinsic reference.
//
// NOTE: `v0` is taken by non-const reference and is overwritten with the
// intermediate folded values; callers must not rely on its previous content.
static inline __device__ int v_reduce_sum_int32(int32x16_t& v0) {
  // The shift amounts are kept as literal expressions at each call site,
  // exactly as in the original sequence.
  const auto fold_256 = vsrlp_int32x16(1 << 8, v0);
  v0 = vvadd_int32x16(v0, fold_256);

  const auto fold_128 = vsrlp_int32x16(1 << 7, v0);
  v0 = vvadd_int32x16(v0, fold_128);

  const auto fold_64 = vsrlp_int32x16(1 << 6, v0);
  v0 = vvadd_int32x16(v0, fold_64);

  const auto fold_32 = vsrlp_int32x16(1 << 5, v0);
  v0 = vvadd_int32x16(v0, fold_32);

  return vextract_int32x16(v0, 1);
}
|
||||
|
||||
// Returns the sum of the `len` ints starting at `x`.
//
// x:   pointer declared `__shared_ptr__`; read through `vload2_sm` /
//      `vload2_sm_mz`.
// len: number of ints to accumulate.
//
// The input is consumed 32 elements at a time (two 16-lane vectors per load).
// A final partial chunk is read with a masked load whose mask has one bit per
// remaining element. The per-lane partial sums are then collapsed to a scalar
// by `v_reduce_sum_int32`.
inline __device__ int primitive_reduce_sum_sm(__shared_ptr__ const int* x,
                                              int64_t len) {
  int32x16_t x_l, x_h;
  int32x16_t sum = vset_zero_int();
  const auto rounddown_len = rounddown32(len);

  // Full 32-element chunks.
  for (int64_t i = 0; i < rounddown_len; i += 32) {
    vload2_sm(x + i, x_l, x_h);
    sum = vvadd_int32x16(sum, x_l);
    sum = vvadd_int32x16(sum, x_h);
  }

  // Remaining elements, if any.
  if (rounddown_len < len) {
    // Low `tail` bits set. The mask is built from an unsigned all-ones value:
    // the former `~(-1 << tail)` left-shifts a negative signed int, which is
    // undefined behaviour before C++20. `tail` is presumably in [1, 31]
    // (assuming `rounddown32` rounds down to a multiple of 32 -- TODO
    // confirm), so the result always fits in `int`, the type the mask had
    // before.
    const auto tail = static_cast<unsigned>(len - rounddown_len);
    const int mask = static_cast<int>(~(~0u << tail));
    // Presumably lanes whose mask bit is clear are loaded as zero, so they do
    // not contribute to the sum -- TODO confirm `vload2_sm_mz` semantics.
    vload2_sm_mz(x + rounddown_len, x_l, x_h, mask);
    sum = vvadd_int32x16(sum, x_l);
    sum = vvadd_int32x16(sum, x_h);
  }
  return v_reduce_sum_int32(sum);
}
|
||||
|
||||
__global__ void get_padding_offset(int64_t* ids_remove_padding,
|
||||
int* batch_id_per_token,
|
||||
int* cum_offsets_out,
|
||||
int* cu_seqlens_q,
|
||||
int* cu_seqlens_k,
|
||||
const int64_t* input_data,
|
||||
const int* seq_lens,
|
||||
const int max_seq_len,
|
||||
const int bs) {
|
||||
int cid = core_id();
|
||||
int ncores = core_num();
|
||||
int clusterid = cluster_id();
|
||||
int nclusters = cluster_num();
|
||||
int tid = clusterid * ncores + cid;
|
||||
|
||||
int buf_len = 32;
|
||||
__simd__ int batch_id_per_token_lm[buf_len];
|
||||
__simd__ int cum_offsets_lm[16];
|
||||
int seq_len_lm;
|
||||
for (int i = clusterid; i < bs; i += nclusters) {
|
||||
GM2LM_ASYNC(seq_lens + i, &seq_len_lm, sizeof(int));
|
||||
GM2LM(cum_offsets + i - 1, cum_offsets_lm, 2 * sizeof(int));
|
||||
if (i == 0) {
|
||||
cum_offsets_lm[0] = 0;
|
||||
}
|
||||
for (int j = cid * buf_len; j < seq_len_lm; j += ncores * buf_len) {
|
||||
int cur_len = min(seq_len_lm - j, buf_len);
|
||||
for (int k = 0; k < cur_len; k++) {
|
||||
batch_id_per_token_lm[k] = i;
|
||||
}
|
||||
mfence_lm();
|
||||
LM2GM(batch_id_per_token_lm,
|
||||
batch_id_per_token + i * max_seq_len - cum_offsets_lm[0] + j,
|
||||
cur_len * sizeof(int));
|
||||
__shared__ int sm_seq_lens[MAX_BATCH_SIZE];
|
||||
__shared__ int sm_cum_seq_len;
|
||||
__simd__ __shared__ int buffer_cu_seqlens[64];
|
||||
|
||||
if (cid == 0) {
|
||||
GM2SM(seq_lens, sm_seq_lens, sizeof(int) * bs);
|
||||
}
|
||||
sync_all();
|
||||
|
||||
for (int bi = clusterid; bi < bs; bi += nclusters) {
|
||||
int cum_seq_len = 0;
|
||||
for (int i = cid; i <= bi; i += ncores) {
|
||||
cum_seq_len += sm_seq_lens[i];
|
||||
}
|
||||
buffer_cu_seqlens[cid] = cum_seq_len;
|
||||
mfence();
|
||||
sync_all();
|
||||
|
||||
if (cid == 0) {
|
||||
int cum_seq_len = (i + 1) * max_seq_len - cum_offsets_lm[1];
|
||||
mfence_lm();
|
||||
LM2GM_ASYNC(cum_offsets_lm, cum_offsets_out + i, sizeof(int));
|
||||
LM2GM_ASYNC(&cum_seq_len, cu_seqlens_q + i + 1, sizeof(int));
|
||||
LM2GM(&cum_seq_len, cu_seqlens_k + i + 1, sizeof(int));
|
||||
cum_seq_len =
|
||||
primitive_reduce_sum_sm(buffer_cu_seqlens, min(bi + 1, ncores));
|
||||
|
||||
LM2GM_ASYNC(&cum_seq_len, cu_seqlens_q + bi + 1, sizeof(int));
|
||||
LM2GM_ASYNC(&cum_seq_len, cu_seqlens_k + bi + 1, sizeof(int));
|
||||
|
||||
int cum_offset = bi * max_seq_len - (cum_seq_len - sm_seq_lens[bi]);
|
||||
LM2GM(&cum_offset, cum_offsets_out + bi, sizeof(int));
|
||||
|
||||
sm_cum_seq_len = cum_seq_len;
|
||||
}
|
||||
mfence();
|
||||
sync_all();
|
||||
|
||||
const int lm_seq_lens = sm_seq_lens[bi];
|
||||
const int tgt_offset = sm_cum_seq_len - lm_seq_lens;
|
||||
const int buf_len = 32;
|
||||
__simd__ int64_t input_lm[buf_len];
|
||||
__simd__ int batch_id_lm[buf_len];
|
||||
|
||||
for (int k = 0; k < buf_len; k++) {
|
||||
batch_id_lm[k] = bi;
|
||||
}
|
||||
mfence_lm();
|
||||
|
||||
for (int j = cid * buf_len; j < lm_seq_lens; j += ncores * buf_len) {
|
||||
int cur_len = min(lm_seq_lens - j, buf_len);
|
||||
GM2LM(input_data + bi * max_seq_len + j,
|
||||
input_lm,
|
||||
sizeof(int64_t) * cur_len);
|
||||
LM2GM(input_lm,
|
||||
ids_remove_padding + tgt_offset + j,
|
||||
sizeof(int64_t) * cur_len);
|
||||
LM2GM(batch_id_lm,
|
||||
batch_id_per_token + tgt_offset + j,
|
||||
sizeof(int) * cur_len);
|
||||
}
|
||||
mfence();
|
||||
sync_all();
|
||||
}
|
||||
|
||||
if (cid == 0 && clusterid == 0) {
|
||||
const int lm_zero = 0;
|
||||
LM2GM_ASYNC(&lm_zero, cu_seqlens_q, sizeof(int));
|
||||
LM2GM(&lm_zero, cu_seqlens_k, sizeof(int));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,176 +19,121 @@
|
||||
|
||||
namespace fd_xpu3 {
|
||||
|
||||
__attribute__((global)) void get_padding_offset(int *padding_offset,
|
||||
int *cum_offsets_out,
|
||||
int *cu_seqlens_q,
|
||||
int *cu_seqlens_k,
|
||||
const int *cum_offsets,
|
||||
const int *seq_lens,
|
||||
__attribute__((global)) void get_padding_offset(int64_t* ids_remove_padding,
|
||||
int* batch_id_per_token,
|
||||
int* cum_offsets_out,
|
||||
int* cu_seqlens_q,
|
||||
int* cu_seqlens_k,
|
||||
const int64_t* input_data,
|
||||
const int* seq_lens,
|
||||
const int max_seq_len,
|
||||
const int bs);
|
||||
__attribute__((global)) void remove_padding(int64_t *x_remove_padding,
|
||||
const int64_t *input_data,
|
||||
const int *seq_lens,
|
||||
const int *cum_offsets,
|
||||
const int sequence_length,
|
||||
const int bs);
|
||||
|
||||
} // namespace fd_xpu3
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace plugin {
|
||||
|
||||
static int get_padding_offset_cpu(int *padding_offset,
|
||||
int *cum_offsets_out,
|
||||
int *cu_seqlens_q,
|
||||
int *cu_seqlens_k,
|
||||
const int *cum_offsets,
|
||||
const int *seq_lens,
|
||||
const int max_seq_len,
|
||||
const int bs) {
|
||||
static int cpu_wrapper(api::Context* ctx,
|
||||
int* batch_id_per_token,
|
||||
int* cum_offsets_out,
|
||||
int* cu_seqlens_q,
|
||||
int* cu_seqlens_k,
|
||||
int64_t* x_remove_padding,
|
||||
const int64_t* input_ids,
|
||||
const int* seq_lens,
|
||||
const int max_seq_len,
|
||||
const int bs) {
|
||||
int cum_seq_len = 0;
|
||||
cu_seqlens_q[0] = 0;
|
||||
cu_seqlens_k[0] = 0;
|
||||
for (int i = 0; i < bs; i++) {
|
||||
int cum_offset = i == 0 ? 0 : cum_offsets[i - 1];
|
||||
cum_offsets_out[i] = i * max_seq_len - cum_seq_len;
|
||||
for (int j = 0; j < seq_lens[i]; j++) {
|
||||
// TODO(mayang02): check offset of padding_offset
|
||||
padding_offset[i * max_seq_len - cum_offset + j] = cum_offset;
|
||||
const int tgt = cum_seq_len + j;
|
||||
x_remove_padding[tgt] = input_ids[i * max_seq_len + j];
|
||||
batch_id_per_token[tgt] = i;
|
||||
}
|
||||
cum_offsets_out[i] = cum_offset;
|
||||
int cum_seq_len = (i + 1) * max_seq_len - cum_offsets[i];
|
||||
cum_seq_len += seq_lens[i];
|
||||
cu_seqlens_q[i + 1] = cum_seq_len;
|
||||
cu_seqlens_k[i + 1] = cum_seq_len;
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
static int remove_padding_cpu(int64_t *x_remove_padding,
|
||||
const int64_t *input_data,
|
||||
const int *seq_lens,
|
||||
const int *cum_offsets,
|
||||
const int sequence_length,
|
||||
const int bs) {
|
||||
for (int i = 0; i < bs; i++) {
|
||||
for (int j = 0; j < seq_lens[i]; j++) {
|
||||
const int tgt_seq_id = i * sequence_length - cum_offsets[i] + j;
|
||||
const int src_seq_id = i * sequence_length + j;
|
||||
// TODO(mayang02): check offset of x_remove_padding
|
||||
x_remove_padding[tgt_seq_id] = input_data[src_seq_id];
|
||||
}
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
static int cpu_wrapper(api::Context *ctx,
|
||||
int *padding_offset,
|
||||
int *cum_offsets_out,
|
||||
int *cu_seqlens_q,
|
||||
int *cu_seqlens_k,
|
||||
int64_t *x_remove_padding,
|
||||
const int64_t *input_ids,
|
||||
const int *cum_offsets,
|
||||
const int *seq_lens,
|
||||
const int max_seq_len,
|
||||
const int bs) {
|
||||
get_padding_offset_cpu(padding_offset,
|
||||
cum_offsets_out,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
max_seq_len,
|
||||
bs);
|
||||
remove_padding_cpu(
|
||||
x_remove_padding, input_ids, seq_lens, cum_offsets_out, max_seq_len, bs);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
static int xpu3_wrapper(api::Context *ctx,
|
||||
int *padding_offset,
|
||||
int *cum_offsets_out,
|
||||
int *cu_seqlens_q,
|
||||
int *cu_seqlens_k,
|
||||
int64_t *x_remove_padding,
|
||||
const int64_t *input_ids,
|
||||
const int *cum_offsets,
|
||||
const int *seq_lens,
|
||||
static int xpu3_wrapper(api::Context* ctx,
|
||||
int* batch_id_per_token,
|
||||
int* cum_offsets_out,
|
||||
int* cu_seqlens_q,
|
||||
int* cu_seqlens_k,
|
||||
int64_t* x_remove_padding,
|
||||
const int64_t* input_ids,
|
||||
const int* seq_lens,
|
||||
const int max_seq_len,
|
||||
const int bs) {
|
||||
using XPU_INT64 = typename api::XPUIndexType<int64_t>::type;
|
||||
auto get_padding_offset = fd_xpu3::get_padding_offset;
|
||||
auto remove_padding = fd_xpu3::remove_padding;
|
||||
int32_t ret_xre =
|
||||
get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
padding_offset,
|
||||
fd_xpu3::get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INT64*>(x_remove_padding),
|
||||
batch_id_per_token,
|
||||
cum_offsets_out,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
cum_offsets,
|
||||
reinterpret_cast<const XPU_INT64*>(input_ids),
|
||||
seq_lens,
|
||||
max_seq_len,
|
||||
bs);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
ret_xre = remove_padding<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INT64 *>(x_remove_padding),
|
||||
reinterpret_cast<const XPU_INT64 *>(input_ids),
|
||||
seq_lens,
|
||||
cum_offsets_out,
|
||||
max_seq_len,
|
||||
bs);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
int get_padding_offset(api::Context *ctx,
|
||||
int *padding_offset,
|
||||
int *cum_offsets_out,
|
||||
int *cu_seqlens_q,
|
||||
int *cu_seqlens_k,
|
||||
int64_t *x_remove_padding,
|
||||
const int64_t *input_ids,
|
||||
const int *cum_offsets,
|
||||
const int *seq_lens,
|
||||
int get_padding_offset(api::Context* ctx,
|
||||
int* batch_id_per_token,
|
||||
int* cum_offsets_out,
|
||||
int* cu_seqlens_q,
|
||||
int* cu_seqlens_k,
|
||||
int64_t* x_remove_padding,
|
||||
const int64_t* input_ids,
|
||||
const int* seq_lens,
|
||||
const int max_seq_len,
|
||||
const int bs,
|
||||
const int64_t token_num) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "get_padding_offset", int);
|
||||
WRAPPER_DUMP_PARAM4(
|
||||
ctx, padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k);
|
||||
WRAPPER_DUMP_PARAM4(ctx, x_remove_padding, input_ids, cum_offsets, seq_lens);
|
||||
WRAPPER_DUMP_PARAM2(ctx, max_seq_len, bs);
|
||||
ctx, batch_id_per_token, cum_offsets_out, cu_seqlens_q, cu_seqlens_k);
|
||||
WRAPPER_DUMP_PARAM4(ctx, x_remove_padding, input_ids, seq_lens, max_seq_len);
|
||||
WRAPPER_DUMP_PARAM2(ctx, bs, token_num);
|
||||
WRAPPER_DUMP(ctx);
|
||||
WRAPPER_ASSERT_GT(ctx, bs, 0);
|
||||
WRAPPER_ASSERT_GT(ctx, max_seq_len, 0);
|
||||
WRAPPER_CHECK_PTR(ctx, int, token_num, padding_offset);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, token_num, x_remove_padding);
|
||||
WRAPPER_CHECK_PTR(ctx, int, token_num, batch_id_per_token);
|
||||
WRAPPER_CHECK_PTR(ctx, int, bs, cum_offsets_out);
|
||||
WRAPPER_CHECK_PTR(ctx, int, bs + 1, cu_seqlens_q);
|
||||
WRAPPER_CHECK_PTR(ctx, int, bs + 1, cu_seqlens_k);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, token_num, x_remove_padding);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, bs * max_seq_len, input_ids);
|
||||
WRAPPER_CHECK_PTR(ctx, int, bs, cum_offsets);
|
||||
WRAPPER_CHECK_PTR(ctx, int, bs, seq_lens);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper(ctx,
|
||||
padding_offset,
|
||||
batch_id_per_token,
|
||||
cum_offsets_out,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
x_remove_padding,
|
||||
input_ids,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
max_seq_len,
|
||||
bs);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx,
|
||||
padding_offset,
|
||||
batch_id_per_token,
|
||||
cum_offsets_out,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
x_remove_padding,
|
||||
input_ids,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
max_seq_len,
|
||||
bs);
|
||||
|
||||
@@ -21,8 +21,7 @@ np.random.seed(2023)
|
||||
|
||||
max_len = 10
|
||||
seq_lens = np.array([4, 3, 6], "int32").reshape(-1, 1)
|
||||
cum_offset = np.cumsum((max_len - seq_lens).flatten(), -1, "int32")
|
||||
token_num = np.sum(seq_lens)
|
||||
token_num = int(np.sum(seq_lens))
|
||||
bs = seq_lens.shape[0]
|
||||
input_ids = np.zeros([bs, max_len], "int64")
|
||||
for i in range(bs):
|
||||
@@ -32,34 +31,44 @@ for i in range(bs):
|
||||
(
|
||||
x_remove_padding,
|
||||
cum_offsets_out,
|
||||
padding_offset,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
) = get_padding_offset(
|
||||
paddle.to_tensor(input_ids),
|
||||
paddle.to_tensor(cum_offset),
|
||||
paddle.to_tensor(token_num),
|
||||
paddle.to_tensor(seq_lens),
|
||||
paddle.to_tensor(seq_lens.flatten()),
|
||||
token_num,
|
||||
)
|
||||
|
||||
print("input_ids:\n", input_ids)
|
||||
print("cum_offset:\n", cum_offset)
|
||||
print("seq_lens:\n", seq_lens.flatten())
|
||||
print("token_num:\n", token_num)
|
||||
print("seq_lens:\n", seq_lens)
|
||||
print("x_remove_padding:\n", x_remove_padding)
|
||||
print("cum_offsets_out:\n", cum_offsets_out)
|
||||
print("padding_offset:\n", padding_offset)
|
||||
print("batch_id_per_token:\n", batch_id_per_token)
|
||||
print("cu_seqlens_q:\n", cu_seqlens_q)
|
||||
print("cu_seqlens_k:\n", cu_seqlens_k)
|
||||
|
||||
ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64")
|
||||
ref_cum_offsets_out = np.array([0, 6, 13], "int32")
|
||||
ref_padding_offset = np.array([0, 0, 0, 0, 6, 6, 6, 13, 13, 13, 13, 13, 13], "int32")
|
||||
ref_batch_id_per_token = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2], "int32")
|
||||
ref_cu_seqlens_q = np.array([0, 4, 7, 13], "int32")
|
||||
ref_cu_seqlens_k = np.array([0, 4, 7, 13], "int32")
|
||||
|
||||
assert sum(ref_x_remove_padding - x_remove_padding) == 0, "Check x_remove_padding failed."
|
||||
assert sum(ref_cum_offsets_out - cum_offsets_out) == 0, "Check cum_offsets_out failed."
|
||||
assert sum(ref_padding_offset - padding_offset) == 0, "Check padding_offset failed."
|
||||
assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, "Check cu_seqlens_q failed."
|
||||
assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, "Check cu_seqlens_k failed."
|
||||
assert (
|
||||
np.sum(np.abs(ref_x_remove_padding - x_remove_padding.numpy())) == 0
|
||||
), f"Check x_remove_padding failed.\nref: {ref_x_remove_padding}\ngot: {x_remove_padding.numpy()}"
|
||||
assert (
|
||||
np.sum(np.abs(ref_cum_offsets_out - cum_offsets_out.numpy())) == 0
|
||||
), f"Check cum_offsets_out failed.\nref: {ref_cum_offsets_out}\ngot: {cum_offsets_out.numpy()}"
|
||||
assert (
|
||||
np.sum(np.abs(ref_batch_id_per_token - batch_id_per_token.numpy())) == 0
|
||||
), f"Check batch_id_per_token failed.\nref: {ref_batch_id_per_token}\ngot: {batch_id_per_token.numpy()}"
|
||||
assert (
|
||||
np.sum(np.abs(ref_cu_seqlens_q - cu_seqlens_q.numpy())) == 0
|
||||
), f"Check cu_seqlens_q failed.\nref: {ref_cu_seqlens_q}\ngot: {cu_seqlens_q.numpy()}"
|
||||
assert (
|
||||
np.sum(np.abs(ref_cu_seqlens_k - cu_seqlens_k.numpy())) == 0
|
||||
), f"Check cu_seqlens_k failed.\nref: {ref_cu_seqlens_k}\ngot: {cu_seqlens_k.numpy()}"
|
||||
|
||||
print("\nAll checks passed!")
|
||||
|
||||
Reference in New Issue
Block a user