[XPU] Refactor get_padding_offset to single kernel. (#7029)

* [XPU] Refactor get_padding_offset to single kernel.

* add unittest.

* fix codestyle.

* remove cum_offsets_now.

* remove max_len.
Jiajun Ji committed 2026-04-13 11:04:50 +08:00 (committed via GitHub)
parent 26d6a20c2f · commit cb03958b52
7 changed files with 199 additions and 182 deletions
@@ -16,29 +16,29 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"
-std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
-const paddle::Tensor &cum_offsets,
-const paddle::Tensor &token_num,
-const paddle::Tensor &seq_len) {
+std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
+const paddle::Tensor& seq_len,
+const int64_t cpu_token_num) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
-auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
+auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
-auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
-auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
-const int token_num_data = cpu_token_num.data<int64_t>()[0];
+const int token_num_data = static_cast<int>(cpu_token_num);
auto x_remove_padding = paddle::full(
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
+auto batch_id_per_token = paddle::full(
+{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
+auto cum_offsets_out =
+paddle::full({bsz}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
if (token_num_data > 0) {
int r =
fastdeploy::plugin::get_padding_offset(xpu_ctx->x_context(),
@@ -48,7 +48,6 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
cu_seqlens_k.data<int>(),
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
-cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length,
bsz,
@@ -64,20 +63,15 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
}
std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
-const std::vector<int64_t> &input_ids_shape,
-const std::vector<int64_t> &cum_offsets_shape,
-const std::vector<int64_t> &token_num_shape,
-const std::vector<int64_t> &seq_len_shape) {
+const std::vector<int64_t>& input_ids_shape,
+const std::vector<int64_t>& seq_len_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
-const paddle::DataType &input_ids_dtype,
-const paddle::DataType &cum_offsets_dtype,
-const paddle::DataType &token_num_dtype,
-const paddle::DataType &seq_len_dtype) {
+const paddle::DataType& input_ids_dtype,
+const paddle::DataType& seq_len_dtype) {
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
@@ -86,12 +80,13 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
}
PD_BUILD_OP(get_padding_offset)
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
.Inputs({"input_ids", "seq_len"})
.Outputs({"x_remove_padding",
"cum_offsets_out",
"batch_id_per_token",
"cu_seqlens_q",
"cu_seqlens_k"})
.Attrs({"cpu_token_num: int64_t"})
.SetKernelFn(PD_KERNEL(GetPaddingOffset))
.SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype));
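For reference, the five outputs of the refactored op can be reproduced on the host with a few lines of numpy. This is an illustrative sketch, not FastDeploy code; the values match the expected arrays used by the unit test at the end of this commit:

import numpy as np

# Sketch of what the fused get_padding_offset computes, using the
# unit test's inputs: seq_lens = [4, 3, 6], max_seq_len = 10.
seq_lens = np.array([4, 3, 6], dtype=np.int32)
max_seq_len = 10
bs = seq_lens.shape[0]

# cu_seqlens_q / cu_seqlens_k: inclusive prefix sums of seq_lens.
cu_seqlens = np.zeros(bs + 1, dtype=np.int32)
cu_seqlens[1:] = np.cumsum(seq_lens)  # [0, 4, 7, 13]

# cum_offsets_out[i]: padding tokens preceding row i in the padded
# [bs, max_seq_len] layout, i.e. i * max_seq_len - tokens before row i.
cum_offsets_out = np.arange(bs, dtype=np.int32) * max_seq_len - cu_seqlens[:-1]
# -> [0, 6, 13]

# batch_id_per_token: the batch row each packed token came from.
batch_id_per_token = np.repeat(np.arange(bs, dtype=np.int32), seq_lens)
# -> [0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2]

# x_remove_padding: input_ids with padding stripped, rows concatenated.
input_ids = np.zeros((bs, max_seq_len), dtype=np.int64)
x_remove_padding = np.concatenate(
    [input_ids[i, : seq_lens[i]] for i in range(bs)])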
@@ -453,9 +453,8 @@ void GetOutputEPDynamic(const paddle::Tensor& x,
int msg_queue_id);
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
-const paddle::Tensor& cum_offsets,
-const paddle::Tensor& token_num,
-const paddle::Tensor& seq_len);
+const paddle::Tensor& seq_len,
+const int64_t cpu_token_num);
void GetStopFlagsMulti(const paddle::Tensor& topk_ids,
const paddle::Tensor& stop_flags,
@@ -996,9 +995,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_padding_offset",
&GetPaddingOffset,
py::arg("input_ids"),
py::arg("cum_offsets"),
py::arg("token_num"),
py::arg("seq_len"),
py::arg("cpu_token_num"),
"get padding offset function");
m.def("init_kv_signal_per_query",
@@ -68,13 +68,12 @@ DLL_EXPORT int token_penalty_multi_scores(api::Context* ctx,
const int64_t length_bad_words);
DLL_EXPORT int get_padding_offset(api::Context* ctx,
-int* padding_offset,
+int* batch_id_per_token,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
int64_t* x_remove_padding,
const int64_t* input_ids,
-const int* cum_offsets,
const int* seq_lens,
const int max_seq_len,
const int bs,
@@ -1,50 +1,123 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/cluster_simd.h"
namespace fd_xpu3 {
-__global__ void get_padding_offset(int *batch_id_per_token,
-int *cum_offsets_out,
-int *cu_seqlens_q,
-int *cu_seqlens_k,
-const int *cum_offsets,
-const int *seq_lens,
+#define MAX_BATCH_SIZE 1024
+
+static inline __device__ int v_reduce_sum_int32(int32x16_t& v0) {
+  auto v1 = vsrlp_int32x16(1 << 8, v0);
+  v0 = vvadd_int32x16(v0, v1);
+  v1 = vsrlp_int32x16(1 << 7, v0);
+  v0 = vvadd_int32x16(v0, v1);
+  v1 = vsrlp_int32x16(1 << 6, v0);
+  v0 = vvadd_int32x16(v0, v1);
+  v1 = vsrlp_int32x16(1 << 5, v0);
+  v0 = vvadd_int32x16(v0, v1);
+  return vextract_int32x16(v0, 1);
+}
+
+inline __device__ int primitive_reduce_sum_sm(__shared_ptr__ const int* x,
+                                              int64_t len) {
+  int32x16_t x_l, x_h;
+  int32x16_t sum = vset_zero_int();
+  const auto rounddown_len = rounddown32(len);
+  for (int64_t i = 0; i < rounddown_len; i += 32) {
+    vload2_sm(x + i, x_l, x_h);
+    sum = vvadd_int32x16(sum, x_l);
+    sum = vvadd_int32x16(sum, x_h);
+  }
+  if (rounddown_len < len) {
+    const auto mask = ~(-1 << (len - rounddown_len));
+    vload2_sm_mz(x + rounddown_len, x_l, x_h, mask);
+    sum = vvadd_int32x16(sum, x_l);
+    sum = vvadd_int32x16(sum, x_h);
+  }
+  return v_reduce_sum_int32(sum);
+}
+
+__global__ void get_padding_offset(int64_t* ids_remove_padding,
+                                   int* batch_id_per_token,
+                                   int* cum_offsets_out,
+                                   int* cu_seqlens_q,
+                                   int* cu_seqlens_k,
+                                   const int64_t* input_data,
+                                   const int* seq_lens,
const int max_seq_len,
const int bs) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
int nclusters = cluster_num();
int tid = clusterid * ncores + cid;
-  int buf_len = 32;
-  __simd__ int batch_id_per_token_lm[buf_len];
-  __simd__ int cum_offsets_lm[16];
-  int seq_len_lm;
-  for (int i = clusterid; i < bs; i += nclusters) {
-    GM2LM_ASYNC(seq_lens + i, &seq_len_lm, sizeof(int));
-    GM2LM(cum_offsets + i - 1, cum_offsets_lm, 2 * sizeof(int));
-    if (i == 0) {
-      cum_offsets_lm[0] = 0;
-    }
-    for (int j = cid * buf_len; j < seq_len_lm; j += ncores * buf_len) {
-      int cur_len = min(seq_len_lm - j, buf_len);
-      for (int k = 0; k < cur_len; k++) {
-        batch_id_per_token_lm[k] = i;
-      }
-      mfence_lm();
-      LM2GM(batch_id_per_token_lm,
-            batch_id_per_token + i * max_seq_len - cum_offsets_lm[0] + j,
-            cur_len * sizeof(int));
-    }
-    if (cid == 0) {
-      int cum_seq_len = (i + 1) * max_seq_len - cum_offsets_lm[1];
-      mfence_lm();
-      LM2GM_ASYNC(cum_offsets_lm, cum_offsets_out + i, sizeof(int));
-      LM2GM_ASYNC(&cum_seq_len, cu_seqlens_q + i + 1, sizeof(int));
-      LM2GM(&cum_seq_len, cu_seqlens_k + i + 1, sizeof(int));
-    }
-  }
-}
+  __shared__ int sm_seq_lens[MAX_BATCH_SIZE];
+  __shared__ int sm_cum_seq_len;
+  __simd__ __shared__ int buffer_cu_seqlens[64];
+  if (cid == 0) {
+    GM2SM(seq_lens, sm_seq_lens, sizeof(int) * bs);
+  }
+  sync_all();
+  for (int bi = clusterid; bi < bs; bi += nclusters) {
+    int cum_seq_len = 0;
+    for (int i = cid; i <= bi; i += ncores) {
+      cum_seq_len += sm_seq_lens[i];
+    }
+    buffer_cu_seqlens[cid] = cum_seq_len;
+    mfence();
+    sync_all();
+    if (cid == 0) {
+      cum_seq_len =
+          primitive_reduce_sum_sm(buffer_cu_seqlens, min(bi + 1, ncores));
+      LM2GM_ASYNC(&cum_seq_len, cu_seqlens_q + bi + 1, sizeof(int));
+      LM2GM_ASYNC(&cum_seq_len, cu_seqlens_k + bi + 1, sizeof(int));
+      int cum_offset = bi * max_seq_len - (cum_seq_len - sm_seq_lens[bi]);
+      LM2GM(&cum_offset, cum_offsets_out + bi, sizeof(int));
+      sm_cum_seq_len = cum_seq_len;
+    }
+    mfence();
+    sync_all();
+    const int lm_seq_lens = sm_seq_lens[bi];
+    const int tgt_offset = sm_cum_seq_len - lm_seq_lens;
+    const int buf_len = 32;
+    __simd__ int64_t input_lm[buf_len];
+    __simd__ int batch_id_lm[buf_len];
+    for (int k = 0; k < buf_len; k++) {
+      batch_id_lm[k] = bi;
+    }
+    mfence_lm();
+    for (int j = cid * buf_len; j < lm_seq_lens; j += ncores * buf_len) {
+      int cur_len = min(lm_seq_lens - j, buf_len);
+      GM2LM(input_data + bi * max_seq_len + j,
+            input_lm,
+            sizeof(int64_t) * cur_len);
+      LM2GM(input_lm,
+            ids_remove_padding + tgt_offset + j,
+            sizeof(int64_t) * cur_len);
+      LM2GM(batch_id_lm,
+            batch_id_per_token + tgt_offset + j,
+            sizeof(int) * cur_len);
+    }
+    mfence();
+    sync_all();
+  }
+  if (cid == 0 && clusterid == 0) {
+    const int lm_zero = 0;
+    LM2GM_ASYNC(&lm_zero, cu_seqlens_q, sizeof(int));
+    LM2GM(&lm_zero, cu_seqlens_k, sizeof(int));
+  }
+}
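v_reduce_sum_int32 is a halving tree reduction over the 16 int32 lanes: each vsrlp_int32x16 shifts the vector down by 8, then 4, 2, and 1 lanes (the shift amounts 1 << 8 through 1 << 5 are in bits, 256 bits being 8 int32 lanes), and each vvadd folds the shifted copy back on, so the total accumulates in the low lanes; primitive_reduce_sum_sm feeds it 32 shared-memory ints per iteration (two 16-lane vectors) with a masked tail load. A scalar Python sketch of the same folding pattern (the exact lane-extraction semantics of the intrinsics are an assumption here, not documented in this diff):

def reduce_sum_lanes(lanes):
    # Model of v_reduce_sum_int32: fold the upper half of the active
    # lanes onto the lower half, halving the active width each step.
    step = len(lanes) // 2  # 8, 4, 2, 1 for a 16-lane vector
    while step >= 1:
        for i in range(step):
            lanes[i] += lanes[i + step]
        step //= 2
    return lanes[0]

def reduce_sum_sm(values):
    # Loose model of primitive_reduce_sum_sm: accumulate everything into
    # 16 per-lane partial sums (the kernel takes 32 values per trip, with
    # a bitmask zeroing the ragged tail), then reduce the lanes.
    lanes = [0] * 16
    for i, v in enumerate(values):
        lanes[i % 16] += v
    return reduce_sum_lanes(lanes)

assert reduce_sum_sm(list(range(100))) == sum(range(100))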
@@ -19,176 +19,121 @@
namespace fd_xpu3 {
-__attribute__((global)) void get_padding_offset(int *padding_offset,
-int *cum_offsets_out,
-int *cu_seqlens_q,
-int *cu_seqlens_k,
-const int *cum_offsets,
-const int *seq_lens,
+__attribute__((global)) void get_padding_offset(int64_t* ids_remove_padding,
+int* batch_id_per_token,
+int* cum_offsets_out,
+int* cu_seqlens_q,
+int* cu_seqlens_k,
+const int64_t* input_data,
+const int* seq_lens,
const int max_seq_len,
const int bs);
-__attribute__((global)) void remove_padding(int64_t *x_remove_padding,
-const int64_t *input_data,
-const int *seq_lens,
-const int *cum_offsets,
-const int sequence_length,
-const int bs);
} // namespace fd_xpu3
namespace fastdeploy {
namespace plugin {
-static int get_padding_offset_cpu(int *padding_offset,
-int *cum_offsets_out,
-int *cu_seqlens_q,
-int *cu_seqlens_k,
-const int *cum_offsets,
-const int *seq_lens,
-const int max_seq_len,
-const int bs) {
+static int cpu_wrapper(api::Context* ctx,
+int* batch_id_per_token,
+int* cum_offsets_out,
+int* cu_seqlens_q,
+int* cu_seqlens_k,
+int64_t* x_remove_padding,
+const int64_t* input_ids,
+const int* seq_lens,
+const int max_seq_len,
+const int bs) {
+int cum_seq_len = 0;
cu_seqlens_q[0] = 0;
cu_seqlens_k[0] = 0;
for (int i = 0; i < bs; i++) {
-int cum_offset = i == 0 ? 0 : cum_offsets[i - 1];
+cum_offsets_out[i] = i * max_seq_len - cum_seq_len;
for (int j = 0; j < seq_lens[i]; j++) {
-// TODO(mayang02): check offset of padding_offset
-padding_offset[i * max_seq_len - cum_offset + j] = cum_offset;
+const int tgt = cum_seq_len + j;
+x_remove_padding[tgt] = input_ids[i * max_seq_len + j];
+batch_id_per_token[tgt] = i;
}
-cum_offsets_out[i] = cum_offset;
-int cum_seq_len = (i + 1) * max_seq_len - cum_offsets[i];
+cum_seq_len += seq_lens[i];
cu_seqlens_q[i + 1] = cum_seq_len;
cu_seqlens_k[i + 1] = cum_seq_len;
}
return api::SUCCESS;
}
-static int remove_padding_cpu(int64_t *x_remove_padding,
-                              const int64_t *input_data,
-                              const int *seq_lens,
-                              const int *cum_offsets,
-                              const int sequence_length,
-                              const int bs) {
-  for (int i = 0; i < bs; i++) {
-    for (int j = 0; j < seq_lens[i]; j++) {
-      const int tgt_seq_id = i * sequence_length - cum_offsets[i] + j;
-      const int src_seq_id = i * sequence_length + j;
-      // TODO(mayang02): check offset of x_remove_padding
-      x_remove_padding[tgt_seq_id] = input_data[src_seq_id];
-    }
-  }
-  return api::SUCCESS;
-}
-static int cpu_wrapper(api::Context *ctx,
-                       int *padding_offset,
-                       int *cum_offsets_out,
-                       int *cu_seqlens_q,
-                       int *cu_seqlens_k,
-                       int64_t *x_remove_padding,
-                       const int64_t *input_ids,
-                       const int *cum_offsets,
-                       const int *seq_lens,
-                       const int max_seq_len,
-                       const int bs) {
-  get_padding_offset_cpu(padding_offset,
-                         cum_offsets_out,
-                         cu_seqlens_q,
-                         cu_seqlens_k,
-                         cum_offsets,
-                         seq_lens,
-                         max_seq_len,
-                         bs);
-  remove_padding_cpu(
-      x_remove_padding, input_ids, seq_lens, cum_offsets_out, max_seq_len, bs);
-  return api::SUCCESS;
-}
-static int xpu3_wrapper(api::Context *ctx,
-int *padding_offset,
-int *cum_offsets_out,
-int *cu_seqlens_q,
-int *cu_seqlens_k,
-int64_t *x_remove_padding,
-const int64_t *input_ids,
-const int *cum_offsets,
-const int *seq_lens,
+static int xpu3_wrapper(api::Context* ctx,
+int* batch_id_per_token,
+int* cum_offsets_out,
+int* cu_seqlens_q,
+int* cu_seqlens_k,
+int64_t* x_remove_padding,
+const int64_t* input_ids,
+const int* seq_lens,
const int max_seq_len,
const int bs) {
using XPU_INT64 = typename api::XPUIndexType<int64_t>::type;
-auto get_padding_offset = fd_xpu3::get_padding_offset;
-auto remove_padding = fd_xpu3::remove_padding;
int32_t ret_xre =
-get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
-padding_offset,
+fd_xpu3::get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
+reinterpret_cast<XPU_INT64*>(x_remove_padding),
+batch_id_per_token,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
-cum_offsets,
+reinterpret_cast<const XPU_INT64*>(input_ids),
seq_lens,
max_seq_len,
bs);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
-ret_xre = remove_padding<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
-reinterpret_cast<XPU_INT64 *>(x_remove_padding),
-reinterpret_cast<const XPU_INT64 *>(input_ids),
-seq_lens,
-cum_offsets_out,
-max_seq_len,
-bs);
-KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
-int get_padding_offset(api::Context *ctx,
-int *padding_offset,
-int *cum_offsets_out,
-int *cu_seqlens_q,
-int *cu_seqlens_k,
-int64_t *x_remove_padding,
-const int64_t *input_ids,
-const int *cum_offsets,
-const int *seq_lens,
+int get_padding_offset(api::Context* ctx,
+int* batch_id_per_token,
+int* cum_offsets_out,
+int* cu_seqlens_q,
+int* cu_seqlens_k,
+int64_t* x_remove_padding,
+const int64_t* input_ids,
+const int* seq_lens,
const int max_seq_len,
const int bs,
const int64_t token_num) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "get_padding_offset", int);
WRAPPER_DUMP_PARAM4(
-ctx, padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k);
-WRAPPER_DUMP_PARAM4(ctx, x_remove_padding, input_ids, cum_offsets, seq_lens);
-WRAPPER_DUMP_PARAM2(ctx, max_seq_len, bs);
+ctx, batch_id_per_token, cum_offsets_out, cu_seqlens_q, cu_seqlens_k);
+WRAPPER_DUMP_PARAM4(ctx, x_remove_padding, input_ids, seq_lens, max_seq_len);
+WRAPPER_DUMP_PARAM2(ctx, bs, token_num);
WRAPPER_DUMP(ctx);
WRAPPER_ASSERT_GT(ctx, bs, 0);
WRAPPER_ASSERT_GT(ctx, max_seq_len, 0);
-WRAPPER_CHECK_PTR(ctx, int, token_num, padding_offset);
-WRAPPER_CHECK_PTR(ctx, int64_t, token_num, x_remove_padding);
+WRAPPER_CHECK_PTR(ctx, int, token_num, batch_id_per_token);
WRAPPER_CHECK_PTR(ctx, int, bs, cum_offsets_out);
WRAPPER_CHECK_PTR(ctx, int, bs + 1, cu_seqlens_q);
WRAPPER_CHECK_PTR(ctx, int, bs + 1, cu_seqlens_k);
+WRAPPER_CHECK_PTR(ctx, int64_t, token_num, x_remove_padding);
WRAPPER_CHECK_PTR(ctx, int64_t, bs * max_seq_len, input_ids);
-WRAPPER_CHECK_PTR(ctx, int, bs, cum_offsets);
WRAPPER_CHECK_PTR(ctx, int, bs, seq_lens);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
-padding_offset,
+batch_id_per_token,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
x_remove_padding,
input_ids,
-cum_offsets,
seq_lens,
max_seq_len,
bs);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
-padding_offset,
+batch_id_per_token,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
x_remove_padding,
input_ids,
-cum_offsets,
seq_lens,
max_seq_len,
bs);
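The WRAPPER_CHECK_PTR calls above spell out the buffer contract of the fused kernel. As a sketch of the caller-side allocations (shapes match the paddle::full calls in the op and the check macros here; the concrete numbers come from the unit test below):

import numpy as np

bs, max_seq_len = 3, 10
token_num = 13  # sum of seq_lens

x_remove_padding = np.empty(token_num, dtype=np.int64)    # packed input_ids
batch_id_per_token = np.empty(token_num, dtype=np.int32)  # row id per packed token
cum_offsets_out = np.empty(bs, dtype=np.int32)            # padding before each row
cu_seqlens_q = np.empty(bs + 1, dtype=np.int32)           # prefix sums of seq_lens
cu_seqlens_k = np.empty(bs + 1, dtype=np.int32)
input_ids = np.empty((bs, max_seq_len), dtype=np.int64)   # read-only input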
@@ -21,8 +21,7 @@ np.random.seed(2023)
max_len = 10
seq_lens = np.array([4, 3, 6], "int32").reshape(-1, 1)
-cum_offset = np.cumsum((max_len - seq_lens).flatten(), -1, "int32")
-token_num = np.sum(seq_lens)
+token_num = int(np.sum(seq_lens))
bs = seq_lens.shape[0]
input_ids = np.zeros([bs, max_len], "int64")
for i in range(bs):
@@ -32,34 +31,44 @@ for i in range(bs):
(
x_remove_padding,
cum_offsets_out,
-    padding_offset,
+    batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
) = get_padding_offset(
paddle.to_tensor(input_ids),
-    paddle.to_tensor(cum_offset),
-    paddle.to_tensor(token_num),
-    paddle.to_tensor(seq_lens),
+    paddle.to_tensor(seq_lens.flatten()),
+    token_num,
)
print("input_ids:\n", input_ids)
print("cum_offset:\n", cum_offset)
print("seq_lens:\n", seq_lens.flatten())
print("token_num:\n", token_num)
print("seq_lens:\n", seq_lens)
print("x_remove_padding:\n", x_remove_padding)
print("cum_offsets_out:\n", cum_offsets_out)
print("padding_offset:\n", padding_offset)
print("batch_id_per_token:\n", batch_id_per_token)
print("cu_seqlens_q:\n", cu_seqlens_q)
print("cu_seqlens_k:\n", cu_seqlens_k)
ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64")
ref_cum_offsets_out = np.array([0, 6, 13], "int32")
-ref_padding_offset = np.array([0, 0, 0, 0, 6, 6, 6, 13, 13, 13, 13, 13, 13], "int32")
+ref_batch_id_per_token = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2], "int32")
ref_cu_seqlens_q = np.array([0, 4, 7, 13], "int32")
ref_cu_seqlens_k = np.array([0, 4, 7, 13], "int32")
-assert sum(ref_x_remove_padding - x_remove_padding) == 0, "Check x_remove_padding failed."
-assert sum(ref_cum_offsets_out - cum_offsets_out) == 0, "Check cum_offsets_out failed."
-assert sum(ref_padding_offset - padding_offset) == 0, "Check padding_offset failed."
-assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, "Check cu_seqlens_q failed."
-assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, "Check cu_seqlens_k failed."
+assert (
+    np.sum(np.abs(ref_x_remove_padding - x_remove_padding.numpy())) == 0
+), f"Check x_remove_padding failed.\nref: {ref_x_remove_padding}\ngot: {x_remove_padding.numpy()}"
+assert (
+    np.sum(np.abs(ref_cum_offsets_out - cum_offsets_out.numpy())) == 0
+), f"Check cum_offsets_out failed.\nref: {ref_cum_offsets_out}\ngot: {cum_offsets_out.numpy()}"
+assert (
+    np.sum(np.abs(ref_batch_id_per_token - batch_id_per_token.numpy())) == 0
+), f"Check batch_id_per_token failed.\nref: {ref_batch_id_per_token}\ngot: {batch_id_per_token.numpy()}"
+assert (
+    np.sum(np.abs(ref_cu_seqlens_q - cu_seqlens_q.numpy())) == 0
+), f"Check cu_seqlens_q failed.\nref: {ref_cu_seqlens_q}\ngot: {cu_seqlens_q.numpy()}"
+assert (
+    np.sum(np.abs(ref_cu_seqlens_k - cu_seqlens_k.numpy())) == 0
+), f"Check cu_seqlens_k failed.\nref: {ref_cu_seqlens_k}\ngot: {cu_seqlens_k.numpy()}"
+print("\nAll checks passed!")