c++ code format (#4527)

This commit is contained in:
zhupengyang
2025-10-22 17:59:50 +08:00
committed by GitHub
parent d7bcedf421
commit 3a6883ac1a
97 changed files with 8760 additions and 7382 deletions
+1 -1
View File
@@ -16,7 +16,7 @@
---
Language: Cpp
BasedOnStyle: Google
IndentWidth: 4
IndentWidth: 2
TabWidth: 2
ContinuationIndentWidth: 4
AccessModifierOffset: -1 # The private/protected/public has no indent in class
+13
View File
@@ -1,3 +1,7 @@
exclude: |
(?x)^(
dockerfiles/.+
)$
default_install_hook_types:
- pre-commit
- commit-msg
@@ -27,6 +31,15 @@ repos:
hooks:
- id: ruff
args: [--output-format, github, --fix, --line-length=120, --config, pyproject.toml]
# For C++ files
- repo: local
hooks:
- id: clang-format
name: clang-format
description: Format files with ClangFormat.
entry: clang-format -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|cuh|hpp|hxx|xpu|kps)$
# # 拼写检查
# - repo: https://github.com/codespell-project/codespell
# rev: v2.4.1
+10 -10
View File
@@ -19,28 +19,28 @@ std::vector<paddle::Tensor> InvokeAvxWeightOnly(const paddle::Tensor &x,
const paddle::Tensor &w_bias,
const std::string &alog,
bool trans) {
auto out_shape = x.shape();
out_shape[out_shape.size() - 1] = weight.shape()[1];
auto out = paddle::empty(out_shape, x.dtype(), paddle::CPUPlace());
return {out};
auto out_shape = x.shape();
out_shape[out_shape.size() - 1] = weight.shape()[1];
auto out = paddle::empty(out_shape, x.dtype(), paddle::CPUPlace());
return {out};
}
std::vector<std::vector<int64_t>> AvxWeightOnlyInferShape(
std::vector<int64_t> x_shape,
std::vector<int64_t> weigh_shape,
std::vector<int64_t> weigh_bias_shape) {
int m = 1;
for (int i = 0; i < x_shape.size() - 1; i++) {
m = m * x_shape[i];
}
return {std::vector<int64_t>{m, weigh_shape[1]}};
int m = 1;
for (int i = 0; i < x_shape.size() - 1; i++) {
m = m * x_shape[i];
}
return {std::vector<int64_t>{m, weigh_shape[1]}};
}
std::vector<paddle::DataType> AvxWeightOnlyInferDtype(
paddle::DataType x_dtype,
paddle::DataType weight_dtype,
paddle::DataType weight_bias_dtype) {
return {x_dtype};
return {x_dtype};
}
PD_BUILD_STATIC_OP(avx_weight_only)
+51 -59
View File
@@ -20,13 +20,13 @@ void remove_padding(int64_t *output_data,
const int *cum_offsets,
const int sequence_length,
const int bsz) {
for (int bi = 0; bi < bsz; ++bi) {
for (int i = 0; i < seq_lens[bi]; ++i) {
const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
const int src_seq_id = bi * sequence_length + i;
output_data[tgt_seq_id] = input_data[src_seq_id];
}
for (int bi = 0; bi < bsz; ++bi) {
for (int i = 0; i < seq_lens[bi]; ++i) {
const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
const int src_seq_id = bi * sequence_length + i;
output_data[tgt_seq_id] = input_data[src_seq_id];
}
}
}
void get_padding_offset_kernel(int *padding_offset,
@@ -37,56 +37,53 @@ void get_padding_offset_kernel(int *padding_offset,
const int *seq_lens,
const int max_seq_len,
const int bsz) {
for (int bi = 0; bi < bsz; ++bi) {
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
auto seq_len_now = seq_lens[bi];
for (int i = 0; i < seq_len_now; ++i) {
padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
}
cum_offsets_out[bi] = cum_offset;
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
cu_seqlens_q[bi + 1] = cum_seq_len;
cu_seqlens_k[bi + 1] = cum_seq_len;
for (int bi = 0; bi < bsz; ++bi) {
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
auto seq_len_now = seq_lens[bi];
for (int i = 0; i < seq_len_now; ++i) {
padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
}
cum_offsets_out[bi] = cum_offset;
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
cu_seqlens_q[bi + 1] = cum_seq_len;
cu_seqlens_k[bi + 1] = cum_seq_len;
}
}
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &token_num,
const paddle::Tensor &seq_len) {
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false);
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false);
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
const int token_num_data = cpu_token_num.data<int64_t>()[0];
auto x_remove_padding = paddle::empty(
{token_num_data}, paddle::DataType::INT64, input_ids.place());
auto padding_offset = paddle::empty(
{token_num_data}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
get_padding_offset_kernel(padding_offset.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length,
bsz);
remove_padding(x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
seq_len.data<int>(),
cum_offsets_out.data<int>(),
seq_length,
bsz);
return {x_remove_padding,
padding_offset,
cu_seqlens_q,
cu_seqlens_k};
const int token_num_data = cpu_token_num.data<int64_t>()[0];
auto x_remove_padding = paddle::empty(
{token_num_data}, paddle::DataType::INT64, input_ids.place());
auto padding_offset = paddle::empty(
{token_num_data}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
get_padding_offset_kernel(padding_offset.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length,
bsz);
remove_padding(x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
seq_len.data<int>(),
cum_offsets_out.data<int>(),
seq_length,
bsz);
return {x_remove_padding, padding_offset, cu_seqlens_q, cu_seqlens_k};
}
std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
@@ -94,9 +91,9 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
const std::vector<int64_t> &cum_offsets_shape,
const std::vector<int64_t> &token_num_shape,
const std::vector<int64_t> &seq_len_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
}
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
@@ -104,18 +101,13 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
const paddle::DataType &cum_offsets_dtype,
const paddle::DataType &token_num_dtype,
const paddle::DataType &seq_len_dtype) {
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype};
return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype};
}
PD_BUILD_STATIC_OP(get_padding_offset_cpu)
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
.Outputs({"x_remove_padding",
"padding_offset",
"cu_seqlens_q",
"cu_seqlens_k"})
.Outputs(
{"x_remove_padding", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"})
.SetKernelFn(PD_KERNEL(GetPaddingOffset))
.SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype));
+173 -177
View File
@@ -19,7 +19,6 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
template <typename T>
void RebuildPaddingCPUImpl(T *output_data,
const T *input_data,
@@ -30,27 +29,27 @@ void RebuildPaddingCPUImpl(T *output_data,
int max_input_length,
int dim_embed,
const int elem_nums) {
for (int i = 0; i < elem_nums; ++i) {
const int bi = i / dim_embed;
const int bias_idx = i % dim_embed;
int seq_id = 0;
for (int i = 0; i < elem_nums; ++i) {
const int bi = i / dim_embed;
const int bias_idx = i % dim_embed;
int seq_id = 0;
if (seq_len_this_time_data[bi] == 0) {
continue;
}
if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) {
continue;
}
if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1;
}
const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id;
const int src_offset = ori_token_idx * dim_embed + bias_idx;
output_data[i] = input_data[src_offset];
if (seq_len_this_time_data[bi] == 0) {
continue;
}
if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) {
continue;
}
if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1;
}
const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id;
const int src_offset = ori_token_idx * dim_embed + bias_idx;
output_data[i] = input_data[src_offset];
}
}
template <typename T>
@@ -64,27 +63,25 @@ void RebuildAppendPaddingCPUImpl(T *output_data,
const int max_input_length,
const int dim_embed,
const int64_t output_elem_nums) {
for (int i = 0; i < output_elem_nums; ++i) {
int out_token_id = i / dim_embed;
int ori_token_id =
out_token_id + output_padding_offset_data[out_token_id];
int bi = ori_token_id / max_input_length;
if (seq_len_this_time_data[bi] == 0 ||
(seq_lens_decoder_data[bi] == 0 &&
seq_lens_encoder_data[bi] == 0)) {
continue;
}
int seq_id = 0;
if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1;
}
int input_token_id = cu_seqlens_q_data[bi] + seq_id;
int bias_idx = i % dim_embed;
int src_offset = input_token_id * dim_embed + bias_idx;
output_data[i] = input_data[src_offset];
for (int i = 0; i < output_elem_nums; ++i) {
int out_token_id = i / dim_embed;
int ori_token_id = out_token_id + output_padding_offset_data[out_token_id];
int bi = ori_token_id / max_input_length;
if (seq_len_this_time_data[bi] == 0 ||
(seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0)) {
continue;
}
int seq_id = 0;
if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1;
}
int input_token_id = cu_seqlens_q_data[bi] + seq_id;
int bias_idx = i % dim_embed;
int src_offset = input_token_id * dim_embed + bias_idx;
output_data[i] = input_data[src_offset];
}
}
std::vector<paddle::Tensor> RebuildPaddingCPU(
@@ -95,140 +92,139 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true);
auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true);
auto seq_len_this_time_cpu =
seq_len_this_time.copy_to(paddle::CPUPlace(), true);
auto seq_lens_decoder_cpu =
seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
auto seq_lens_encoder_cpu =
seq_lens_encoder.copy_to(paddle::CPUPlace(), true);
paddle::optional<paddle::Tensor> output_padding_offset_cpu;
if (output_padding_offset) {
output_padding_offset_cpu =
output_padding_offset->copy_to(paddle::CPUPlace(), true);
auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true);
auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true);
auto seq_len_this_time_cpu =
seq_len_this_time.copy_to(paddle::CPUPlace(), true);
auto seq_lens_decoder_cpu =
seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
auto seq_lens_encoder_cpu =
seq_lens_encoder.copy_to(paddle::CPUPlace(), true);
paddle::optional<paddle::Tensor> output_padding_offset_cpu;
if (output_padding_offset) {
output_padding_offset_cpu =
output_padding_offset->copy_to(paddle::CPUPlace(), true);
}
int token_num = tmp_out_cpu.shape()[0];
int dim_embed = tmp_out_cpu.shape()[1];
int bsz = cu_seqlens_q_cpu.shape()[0] - 1;
paddle::Tensor out;
if (output_padding_offset_cpu) {
int need_delete_token_num = 0;
for (int i = 0; i < bsz; ++i) {
if (seq_lens_encoder_cpu.data<int>()[i] > 0) {
need_delete_token_num += seq_lens_encoder_cpu.data<int>()[i] - 1;
}
}
int output_token_num = token_num - need_delete_token_num;
out = paddle::full({output_token_num, dim_embed},
0,
tmp_out_cpu.dtype(),
paddle::CPUPlace());
} else {
out = paddle::full(
{bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace());
}
int token_num = tmp_out_cpu.shape()[0];
int dim_embed = tmp_out_cpu.shape()[1];
int bsz = cu_seqlens_q_cpu.shape()[0] - 1;
const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data<int>();
const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>();
const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>();
int elem_nums = out.numel();
paddle::Tensor out;
if (output_padding_offset_cpu) {
int need_delete_token_num = 0;
for (int i = 0; i < bsz; ++i) {
if (seq_lens_encoder_cpu.data<int>()[i] > 0) {
need_delete_token_num +=
seq_lens_encoder_cpu.data<int>()[i] - 1;
}
}
int output_token_num = token_num - need_delete_token_num;
out = paddle::full({output_token_num, dim_embed},
0,
tmp_out_cpu.dtype(),
paddle::CPUPlace());
} else {
out = paddle::full(
{bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace());
if (output_padding_offset_cpu) {
const int *output_padding_offset_data =
output_padding_offset_cpu->data<int>();
switch (tmp_out_cpu.dtype()) {
case paddle::DataType::FLOAT32:
RebuildAppendPaddingCPUImpl<float>(out.data<float>(),
tmp_out_cpu.data<float>(),
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
output_padding_offset_data,
max_input_length,
dim_embed,
elem_nums);
break;
case paddle::DataType::FLOAT16:
RebuildAppendPaddingCPUImpl<paddle::float16>(
out.data<paddle::float16>(),
tmp_out_cpu.data<paddle::float16>(),
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
output_padding_offset_data,
max_input_length,
dim_embed,
elem_nums);
break;
case paddle::DataType::BFLOAT16:
RebuildAppendPaddingCPUImpl<paddle::bfloat16>(
out.data<paddle::bfloat16>(),
tmp_out_cpu.data<paddle::bfloat16>(),
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
output_padding_offset_data,
max_input_length,
dim_embed,
elem_nums);
break;
default:
PD_THROW(
"Unsupported data type for rebuild_padding_cpu. "
"Only float32, float16, and bfloat16 are supported.");
}
const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data<int>();
const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>();
const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>();
int elem_nums = out.numel();
if (output_padding_offset_cpu) {
const int *output_padding_offset_data =
output_padding_offset_cpu->data<int>();
switch (tmp_out_cpu.dtype()) {
case paddle::DataType::FLOAT32:
RebuildAppendPaddingCPUImpl<float>(out.data<float>(),
tmp_out_cpu.data<float>(),
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
output_padding_offset_data,
max_input_length,
dim_embed,
elem_nums);
break;
case paddle::DataType::FLOAT16:
RebuildAppendPaddingCPUImpl<paddle::float16>(
out.data<paddle::float16>(),
tmp_out_cpu.data<paddle::float16>(),
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
output_padding_offset_data,
max_input_length,
dim_embed,
elem_nums);
break;
case paddle::DataType::BFLOAT16:
RebuildAppendPaddingCPUImpl<paddle::bfloat16>(
out.data<paddle::bfloat16>(),
tmp_out_cpu.data<paddle::bfloat16>(),
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
output_padding_offset_data,
max_input_length,
dim_embed,
elem_nums);
break;
default:
PD_THROW(
"Unsupported data type for rebuild_padding_cpu. "
"Only float32, float16, and bfloat16 are supported.");
}
} else {
switch (tmp_out_cpu.dtype()) {
case paddle::DataType::FLOAT32:
RebuildPaddingCPUImpl<float>(out.data<float>(),
tmp_out_cpu.data<float>(),
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
max_input_length,
dim_embed,
elem_nums);
break;
case paddle::DataType::FLOAT16:
RebuildPaddingCPUImpl<paddle::float16>(
out.data<paddle::float16>(),
tmp_out_cpu.data<paddle::float16>(),
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
max_input_length,
dim_embed,
elem_nums);
break;
case paddle::DataType::BFLOAT16:
RebuildPaddingCPUImpl<paddle::bfloat16>(
out.data<paddle::bfloat16>(),
tmp_out_cpu.data<paddle::bfloat16>(),
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
max_input_length,
dim_embed,
elem_nums);
break;
default:
PD_THROW(
"Unsupported data type for rebuild_padding_cpu. "
"Only float32, float16, and bfloat16 are supported.");
}
} else {
switch (tmp_out_cpu.dtype()) {
case paddle::DataType::FLOAT32:
RebuildPaddingCPUImpl<float>(out.data<float>(),
tmp_out_cpu.data<float>(),
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
max_input_length,
dim_embed,
elem_nums);
break;
case paddle::DataType::FLOAT16:
RebuildPaddingCPUImpl<paddle::float16>(
out.data<paddle::float16>(),
tmp_out_cpu.data<paddle::float16>(),
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
max_input_length,
dim_embed,
elem_nums);
break;
case paddle::DataType::BFLOAT16:
RebuildPaddingCPUImpl<paddle::bfloat16>(
out.data<paddle::bfloat16>(),
tmp_out_cpu.data<paddle::bfloat16>(),
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
max_input_length,
dim_embed,
elem_nums);
break;
default:
PD_THROW(
"Unsupported data type for rebuild_padding_cpu. "
"Only float32, float16, and bfloat16 are supported.");
}
return {out};
}
return {out};
}
std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
@@ -238,13 +234,13 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
const std::vector<int64_t> &seq_lens_decoder_shape,
const std::vector<int64_t> &seq_lens_encoder_shape,
const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
int64_t dim_embed = tmp_out_shape[1];
if (output_padding_offset_shape) {
return {{-1, dim_embed}};
} else {
int64_t bsz = cu_seqlens_q_shape[0] - 1;
return {{bsz, dim_embed}};
}
int64_t dim_embed = tmp_out_shape[1];
if (output_padding_offset_shape) {
return {{-1, dim_embed}};
} else {
int64_t bsz = cu_seqlens_q_shape[0] - 1;
return {{bsz, dim_embed}};
}
}
std::vector<paddle::DataType> RebuildPaddingInferDtype(
@@ -254,7 +250,7 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
const paddle::DataType &seq_lens_decoder_dtype,
const paddle::DataType &seq_lens_encoder_dtype,
const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
return {tmp_out_dtype};
return {tmp_out_dtype};
}
PD_BUILD_STATIC_OP(rebuild_padding_cpu)
+25 -25
View File
@@ -15,27 +15,27 @@
#include "paddle/extension.h"
void set_value_by_flags_and_idx(const bool *stop_flags,
int64_t *pre_ids_all,
const int64_t *input_ids,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *step_idx,
int bs,
int length,
int length_input_ids) {
for (int bi = 0; bi < bs; bi++) {
if (!stop_flags[bi]) {
const int seq_len_dec = seq_lens_decoder[bi];
const int seq_len_enc = seq_lens_encoder[bi];
int64_t *pre_ids_all_now = pre_ids_all + bi * length;
const int64_t *input_ids_now = input_ids + bi * length_input_ids;
if (seq_len_dec == 0) {
pre_ids_all_now[step_idx[bi]] = input_ids_now[seq_len_enc - 1];
} else {
pre_ids_all_now[step_idx[bi]] = input_ids_now[0];
}
}
int64_t *pre_ids_all,
const int64_t *input_ids,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *step_idx,
int bs,
int length,
int length_input_ids) {
for (int bi = 0; bi < bs; bi++) {
if (!stop_flags[bi]) {
const int seq_len_dec = seq_lens_decoder[bi];
const int seq_len_enc = seq_lens_encoder[bi];
int64_t *pre_ids_all_now = pre_ids_all + bi * length;
const int64_t *input_ids_now = input_ids + bi * length_input_ids;
if (seq_len_dec == 0) {
pre_ids_all_now[step_idx[bi]] = input_ids_now[seq_len_enc - 1];
} else {
pre_ids_all_now[step_idx[bi]] = input_ids_now[0];
}
}
}
}
void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
@@ -45,12 +45,12 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags) {
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
int bs = seq_lens_this_time.shape()[0];
int length = pre_ids_all_shape[1];
int length_input_ids = input_ids.shape()[1];
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
int bs = seq_lens_this_time.shape()[0];
int length = pre_ids_all_shape[1];
int length_input_ids = input_ids.shape()[1];
set_value_by_flags_and_idx(stop_flags.data<bool>(),
set_value_by_flags_and_idx(stop_flags.data<bool>(),
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
input_ids.data<int64_t>(),
seq_lens_encoder.data<int>(),
+29 -29
View File
@@ -21,45 +21,45 @@ void probs_sort(const float *probs,
float *ProbsVals,
int vocab_size,
int bsz) {
float cursum = 0;
std::vector<int64_t> elementsIds(vocab_size);
std::vector<float> elementsProbs(vocab_size);
float cursum = 0;
std::vector<int64_t> elementsIds(vocab_size);
std::vector<float> elementsProbs(vocab_size);
#pragma omp parallel for
for (int j = 0; j < vocab_size; j++) {
elementsIds[j] = j;
elementsProbs[j] = probs[j];
}
x86simdsortStatic::keyvalue_qsort(
elementsProbs.data(), elementsIds.data(), vocab_size, false, true);
for (int j = 0; j < vocab_size; j++) {
elementsIds[j] = j;
elementsProbs[j] = probs[j];
}
x86simdsortStatic::keyvalue_qsort(
elementsProbs.data(), elementsIds.data(), vocab_size, false, true);
#pragma omp parallel for
for (int j = 0; j < vocab_size; ++j) {
ProbsVals[j] = elementsProbs[j];
ProbsIds[j] = elementsIds[j];
}
for (int j = 0; j < vocab_size; ++j) {
ProbsVals[j] = elementsProbs[j];
ProbsIds[j] = elementsIds[j];
}
}
std::vector<paddle::Tensor> SimdSort(const paddle::Tensor &probs) {
const int bsz = probs.shape()[0];
const int vocab_size = probs.shape()[1];
auto sorted_indices = paddle::empty(
{bsz, vocab_size}, paddle::DataType::INT64, probs.place());
auto sorted_probs = paddle::empty(
{bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place());
probs_sort(probs.data<float>(),
const_cast<int64_t *>(sorted_indices.data<int64_t>()),
const_cast<float *>(sorted_probs.data<float>()),
vocab_size,
bsz);
return {sorted_indices, sorted_probs};
const int bsz = probs.shape()[0];
const int vocab_size = probs.shape()[1];
auto sorted_indices =
paddle::empty({bsz, vocab_size}, paddle::DataType::INT64, probs.place());
auto sorted_probs = paddle::empty(
{bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place());
probs_sort(probs.data<float>(),
const_cast<int64_t *>(sorted_indices.data<int64_t>()),
const_cast<float *>(sorted_probs.data<float>()),
vocab_size,
bsz);
return {sorted_indices, sorted_probs};
}
std::vector<std::vector<int64_t>> SimdSortInferShape(
const std::vector<int64_t> &probs_shape) {
int64_t bsz = probs_shape[0];
int64_t vocab_size = probs_shape[1];
return {{bsz, vocab_size}, {bsz, vocab_size}};
int64_t bsz = probs_shape[0];
int64_t vocab_size = probs_shape[1];
return {{bsz, vocab_size}, {bsz, vocab_size}};
}
std::vector<paddle::DataType> SimdSortInferDtype(
const paddle::DataType &probs_dtype) {
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
}
PD_BUILD_STATIC_OP(simd_sort)
.Inputs({"probs"})
+11 -11
View File
@@ -16,23 +16,23 @@
#include "paddle/extension.h"
std::vector<paddle::Tensor> SimdSort(const paddle::Tensor &probs) {
const int bsz = probs.shape()[0];
const int vocab_size = probs.shape()[1];
auto sorted_indices = paddle::empty(
{bsz, vocab_size}, paddle::DataType::INT64, probs.place());
auto sorted_probs = paddle::empty(
{bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place());
return {sorted_indices, sorted_probs};
const int bsz = probs.shape()[0];
const int vocab_size = probs.shape()[1];
auto sorted_indices =
paddle::empty({bsz, vocab_size}, paddle::DataType::INT64, probs.place());
auto sorted_probs = paddle::empty(
{bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place());
return {sorted_indices, sorted_probs};
}
std::vector<std::vector<int64_t>> SimdSortInferShape(
const std::vector<int64_t> &probs_shape) {
int64_t bsz = probs_shape[0];
int64_t vocab_size = probs_shape[1];
return {{bsz, vocab_size}, {bsz, vocab_size}};
int64_t bsz = probs_shape[0];
int64_t vocab_size = probs_shape[1];
return {{bsz, vocab_size}, {bsz, vocab_size}};
}
std::vector<paddle::DataType> SimdSortInferDtype(
const paddle::DataType &probs_dtype) {
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
}
PD_BUILD_STATIC_OP(simd_sort)
.Inputs({"probs"})
@@ -23,13 +23,13 @@
#endif
bool is_in_end(const int64_t id, const int64_t *end_ids, int length) {
bool flag = false;
for (int i = 0; i < length; i++) {
if (id == end_ids[i]) {
return true;
}
bool flag = false;
for (int i = 0; i < length; i++) {
if (id == end_ids[i]) {
return true;
}
return flag;
}
return flag;
}
void set_value_by_flags(bool *stop_flags,
@@ -40,23 +40,23 @@ void set_value_by_flags(bool *stop_flags,
const int bs,
const int end_length,
bool beam_search) {
for (int bi = 0; bi < bs; bi++) {
if (stop_flags[bi]) {
if ((seq_lens[bi] == 0)) {
topk_ids[bi] = -1;
} else {
topk_ids[bi] = end_ids[0];
next_tokens[bi] = end_ids[0];
}
} else {
next_tokens[bi] = topk_ids[bi];
}
if (!beam_search && is_in_end(topk_ids[bi], end_ids, end_length)) {
stop_flags[bi] = true;
topk_ids[bi] = end_ids[0];
next_tokens[bi] = end_ids[0];
}
for (int bi = 0; bi < bs; bi++) {
if (stop_flags[bi]) {
if ((seq_lens[bi] == 0)) {
topk_ids[bi] = -1;
} else {
topk_ids[bi] = end_ids[0];
next_tokens[bi] = end_ids[0];
}
} else {
next_tokens[bi] = topk_ids[bi];
}
if (!beam_search && is_in_end(topk_ids[bi], end_ids, end_length)) {
stop_flags[bi] = true;
topk_ids[bi] = end_ids[0];
next_tokens[bi] = end_ids[0];
}
}
}
void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
@@ -65,17 +65,17 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &end_ids,
const paddle::Tensor &next_tokens,
const bool beam_search) {
std::vector<int64_t> shape = topk_ids.shape();
int64_t bs_now = shape[0];
int64_t end_length = end_ids.shape()[0];
set_value_by_flags(const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(next_tokens.data<int64_t>()),
end_ids.data<int64_t>(),
seq_lens.data<int>(),
bs_now,
end_length,
false);
std::vector<int64_t> shape = topk_ids.shape();
int64_t bs_now = shape[0];
int64_t end_length = end_ids.shape()[0];
set_value_by_flags(const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(next_tokens.data<int64_t>()),
end_ids.data<int64_t>(),
seq_lens.data<int>(),
bs_now,
end_length,
false);
}
PD_BUILD_STATIC_OP(set_stop_value_multi_ends_cpu)
@@ -23,16 +23,16 @@ void min_length_logits_process(float *logits,
const int64_t bs,
const int64_t length,
const int64_t end_length) {
for (int bi = 0; bi < bs; ++bi) {
if (cur_len[bi] < 0) {
continue;
}
if (cur_len[bi] < min_len[bi]) {
for (int i = 0; i < end_length; ++i) {
logits[bi * length + eos_token_id[i]] = -1e10;
}
}
for (int bi = 0; bi < bs; ++bi) {
if (cur_len[bi] < 0) {
continue;
}
if (cur_len[bi] < min_len[bi]) {
for (int i = 0; i < end_length; ++i) {
logits[bi * length + eos_token_id[i]] = -1e10;
}
}
}
}
void update_repeat_times(const int64_t *pre_ids,
@@ -41,20 +41,20 @@ void update_repeat_times(const int64_t *pre_ids,
const int64_t bs,
const int64_t length,
const int64_t length_id) {
for (int bi = 0; bi < bs; ++bi) {
if (cur_len[bi] < 0) {
continue;
}
const int64_t *pre_ids_now = pre_ids + bi * length_id;
int *repeat_times_now = repeat_times + bi * length;
for (int i = 0; i < length_id; i++) {
int64_t id = pre_ids_now[i];
if (id < 0) {
break;
}
repeat_times_now[id] += 1;
}
for (int bi = 0; bi < bs; ++bi) {
if (cur_len[bi] < 0) {
continue;
}
const int64_t *pre_ids_now = pre_ids + bi * length_id;
int *repeat_times_now = repeat_times + bi * length;
for (int i = 0; i < length_id; i++) {
int64_t id = pre_ids_now[i];
if (id < 0) {
break;
}
repeat_times_now[id] += 1;
}
}
}
void update_value_by_repeat_times(const int *repeat_times,
@@ -65,24 +65,22 @@ void update_value_by_repeat_times(const int *repeat_times,
float *logits,
const int64_t bs,
const int64_t length) {
for (int bi = 0; bi < bs; ++bi) {
float *logits_now = logits + bi * length;
const int *repeat_times_now = repeat_times + bi * length;
float alpha = static_cast<float>(penalty_scores[bi]);
float beta = static_cast<float>(frequency_score[bi]);
float gamma = static_cast<float>(presence_score[bi]);
for (int i = 0; i < length; ++i) {
int times = repeat_times_now[i];
float logit_now = static_cast<float>(logits_now[i]);
if (times == 0) {
logits_now[i] =
static_cast<float>(logit_now / temperatures[bi]);
}
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
logits_now[i] =
static_cast<float>(logit_now - times * beta - gamma);
}
for (int bi = 0; bi < bs; ++bi) {
float *logits_now = logits + bi * length;
const int *repeat_times_now = repeat_times + bi * length;
float alpha = static_cast<float>(penalty_scores[bi]);
float beta = static_cast<float>(frequency_score[bi]);
float gamma = static_cast<float>(presence_score[bi]);
for (int i = 0; i < length; ++i) {
int times = repeat_times_now[i];
float logit_now = static_cast<float>(logits_now[i]);
if (times == 0) {
logits_now[i] = static_cast<float>(logit_now / temperatures[bi]);
}
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
logits_now[i] = static_cast<float>(logit_now - times * beta - gamma);
}
}
}
void ban_bad_words(float *logits,
@@ -90,15 +88,14 @@ void ban_bad_words(float *logits,
const int64_t bs,
const int64_t length,
const int64_t bad_words_length) {
for (int bi = 0; bi < bs; ++bi) {
float *logits_now = logits + bi * length;
for (int bwid = 0; bwid < bad_words_length; ++bwid) {
const int64_t bad_words_token_id = bad_words_list[bwid];
if (bad_words_token_id >= length || bad_words_token_id < 0)
continue;
logits_now[bad_words_token_id] = -1e10;
}
for (int bi = 0; bi < bs; ++bi) {
float *logits_now = logits + bi * length;
for (int bwid = 0; bwid < bad_words_length; ++bwid) {
const int64_t bad_words_token_id = bad_words_list[bwid];
if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
logits_now[bad_words_token_id] = -1e10;
}
}
}
template <paddle::DataType D>
@@ -112,44 +109,44 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
const paddle::Tensor &cur_len,
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id) {
std::vector<int64_t> shape = logits.shape();
auto repeat_times =
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
int64_t bs = shape[0];
int64_t length = shape[1];
int64_t length_id = pre_ids.shape()[1];
int64_t end_length = eos_token_id.shape()[0];
int64_t length_bad_words = bad_tokens.shape()[0];
std::vector<int64_t> shape = logits.shape();
auto repeat_times =
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
int64_t bs = shape[0];
int64_t length = shape[1];
int64_t length_id = pre_ids.shape()[1];
int64_t end_length = eos_token_id.shape()[0];
int64_t length_bad_words = bad_tokens.shape()[0];
min_length_logits_process(const_cast<float *>(logits.data<float>()),
cur_len.data<int64_t>(),
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
bs,
length,
end_length);
min_length_logits_process(const_cast<float *>(logits.data<float>()),
cur_len.data<int64_t>(),
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
bs,
length,
end_length);
update_repeat_times(pre_ids.data<int64_t>(),
cur_len.data<int64_t>(),
repeat_times.data<int>(),
bs,
length,
length_id);
update_repeat_times(pre_ids.data<int64_t>(),
cur_len.data<int64_t>(),
repeat_times.data<int>(),
bs,
length,
length_id);
update_value_by_repeat_times(repeat_times.data<int>(),
penalty_scores.data<float>(),
frequency_score.data<float>(),
presence_score.data<float>(),
temperatures.data<float>(),
const_cast<float *>(logits.data<float>()),
bs,
length);
update_value_by_repeat_times(repeat_times.data<int>(),
penalty_scores.data<float>(),
frequency_score.data<float>(),
presence_score.data<float>(),
temperatures.data<float>(),
const_cast<float *>(logits.data<float>()),
bs,
length);
ban_bad_words(const_cast<float *>(logits.data<float>()),
bad_tokens.data<int64_t>(),
bs,
length,
length_bad_words);
ban_bad_words(const_cast<float *>(logits.data<float>()),
bad_tokens.data<int64_t>(),
bs,
length,
length_bad_words);
}
void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
@@ -162,17 +159,17 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
const paddle::Tensor &cur_len,
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id) {
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
pre_ids,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
pre_ids,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
}
PD_BUILD_STATIC_OP(get_token_penalty_multi_scores_cpu)
+41 -41
View File
@@ -24,50 +24,50 @@ void update_inputs_kernel(bool *not_need_stop,
const int64_t *next_tokens,
const int bsz,
const int input_ids_stride) {
int64_t stop_sum = 0;
for (int bi = 0; bi < bsz; ++bi) {
bool stop_flag_now = false;
int64_t stop_flag_now_int = 0;
stop_flag_now = stop_flags[bi];
stop_flag_now_int = static_cast<int64_t>(stop_flag_now);
auto seq_len_this_time = seq_lens_this_time[bi];
auto seq_len_encoder = seq_lens_encoder[bi];
auto seq_len_decoder = seq_lens_decoder[bi];
seq_lens_decoder[bi] =
stop_flag_now ? 0
: (seq_len_decoder == 0 ? seq_len_encoder
: seq_len_decoder + 1);
seq_lens_this_time[bi] = stop_flag_now ? 0 : 1;
seq_lens_encoder[bi] = 0;
int64_t *input_ids_now = input_ids + bi * input_ids_stride;
input_ids_now[0] = next_tokens[bi];
stop_sum += stop_flag_now_int;
}
not_need_stop[0] = stop_sum < stop_nums[0];
int64_t stop_sum = 0;
for (int bi = 0; bi < bsz; ++bi) {
bool stop_flag_now = false;
int64_t stop_flag_now_int = 0;
stop_flag_now = stop_flags[bi];
stop_flag_now_int = static_cast<int64_t>(stop_flag_now);
auto seq_len_this_time = seq_lens_this_time[bi];
auto seq_len_encoder = seq_lens_encoder[bi];
auto seq_len_decoder = seq_lens_decoder[bi];
seq_lens_decoder[bi] =
stop_flag_now
? 0
: (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1);
seq_lens_this_time[bi] = stop_flag_now ? 0 : 1;
seq_lens_encoder[bi] = 0;
int64_t *input_ids_now = input_ids + bi * input_ids_stride;
input_ids_now[0] = next_tokens[bi];
stop_sum += stop_flag_now_int;
}
not_need_stop[0] = stop_sum < stop_nums[0];
}
void UpdateInputs(const paddle::Tensor &stop_flags,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &input_ids,
const paddle::Tensor &stop_nums,
const paddle::Tensor &next_tokens,
const paddle::Tensor &is_block_step) {
const int bsz = input_ids.shape()[0];
const int input_ids_stride = input_ids.shape()[1];
update_inputs_kernel(const_cast<bool *>(not_need_stop.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int64_t *>(input_ids.data<int64_t>()),
stop_nums.data<int64_t>(),
stop_flags.data<bool>(),
is_block_step.data<bool>(),
next_tokens.data<int64_t>(),
bsz,
input_ids_stride);
const paddle::Tensor &not_need_stop,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &input_ids,
const paddle::Tensor &stop_nums,
const paddle::Tensor &next_tokens,
const paddle::Tensor &is_block_step) {
const int bsz = input_ids.shape()[0];
const int input_ids_stride = input_ids.shape()[1];
update_inputs_kernel(const_cast<bool *>(not_need_stop.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int64_t *>(input_ids.data<int64_t>()),
stop_nums.data<int64_t>(),
stop_flags.data<bool>(),
is_block_step.data<bool>(),
next_tokens.data<int64_t>(),
bsz,
input_ids_stride);
}
PD_BUILD_STATIC_OP(update_inputs_cpu)
+4 -4
View File
@@ -45,18 +45,18 @@ std::vector<paddle::Tensor> InvokeAllLLaMALayer(
int maxPositions,
int maxPosEmbed,
int intermediateSize) {
auto out = paddle::empty_like(input);
return {out};
auto out = paddle::empty_like(input);
return {out};
}
std::vector<std::vector<int64_t>> AllLLaMALayerInferShape(
std::vector<int64_t> x_shape) {
return {x_shape};
return {x_shape};
}
std::vector<paddle::DataType> AllLLaMALayerInferDtype(
paddle::DataType x_dtype) {
return {x_dtype};
return {x_dtype};
}
PD_BUILD_STATIC_OP(xft_llama_all_layer)
+8 -8
View File
@@ -16,20 +16,20 @@
#include "paddle/extension.h"
std::vector<paddle::Tensor> XftGreedySearch(const paddle::Tensor &probs) {
const int bsz = probs.shape()[0];
const int vocab_size = probs.shape()[1];
auto next_tokens =
paddle::empty({bsz, 1}, paddle::DataType::INT64, probs.place());
return {next_tokens};
const int bsz = probs.shape()[0];
const int vocab_size = probs.shape()[1];
auto next_tokens =
paddle::empty({bsz, 1}, paddle::DataType::INT64, probs.place());
return {next_tokens};
}
std::vector<std::vector<int64_t>> XftGreedySearchInferShape(
const std::vector<int64_t> &probs_shape) {
int64_t bsz = probs_shape[0];
return {{bsz, 1}};
int64_t bsz = probs_shape[0];
return {{bsz, 1}};
}
std::vector<paddle::DataType> XftGreedySearchInferDtype(
const paddle::DataType &probs_dtype) {
return {paddle::DataType::INT64};
return {paddle::DataType::INT64};
}
PD_BUILD_STATIC_OP(xft_greedy_search)
.Inputs({"probs"})
+1 -1
View File
@@ -16,8 +16,8 @@
*/
#pragma once
#include <string>
#include <sstream>
#include <string>
#include "cub/cub.cuh"
namespace phi {
+80 -61
View File
@@ -19,8 +19,8 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include "fused_moe_imp_op.h"
#include "fused_moe_helper.h"
#include "fused_moe_imp_op.h"
// Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
@@ -34,8 +34,8 @@
namespace phi {
struct GpuLaunchConfig {
dim3 block_per_grid;
dim3 thread_per_block;
dim3 block_per_grid;
dim3 thread_per_block;
};
inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
@@ -81,7 +81,6 @@ __launch_bounds__(TPB) __global__
cub::Sum sum;
float threadData(-FLT_MAX);
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
@@ -275,7 +274,8 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert]
: inputs_after_softmax[idx];
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const IdxT prior_winning_expert = indices[k * block_row + prior_k];
@@ -292,7 +292,9 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = k * block_row + k_idx;
output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
output[idx] =
bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]
: result_kvp.value;
indices[idx] = should_process_row ? result_kvp.key : num_experts;
source_rows[idx] = k_idx * num_rows + block_row;
}
@@ -301,14 +303,15 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
__launch_bounds__(TPB) __global__
void moe_softmax_top_k_fused(const T* input,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
// softmax
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -321,11 +324,12 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
return;
}
const int64_t thread_row_offset = globalIdx * num_experts;
const int64_t idx = thread_row_offset+threadIdx.x;
const int64_t idx = thread_row_offset + threadIdx.x;
cub::Sum sum;
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
float threadData =
(threadIdx.x < num_experts) ? static_cast<float>(input[idx]) : (-FLT_MAX);
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
@@ -377,7 +381,8 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int cur_idx = k * globalIdx + k_idx;
output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
output[cur_idx] =
bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
indices[cur_idx] = result_kvp.key;
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
}
@@ -386,14 +391,15 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_softmax,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
__launch_bounds__(TPB) __global__
void moe_top_k_normed(const T* inputs_after_softmax,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -422,7 +428,8 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert]
: inputs_after_softmax[idx];
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const int prior_winning_expert = indices[k * block_row + prior_k];
@@ -439,11 +446,14 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = k * block_row + k_idx;
// output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
// output[idx] = bias ? inputs_after_softmax[thread_read_offset +
// result_kvp.key]: result_kvp.value;
indices[idx] = should_process_row ? result_kvp.key : num_experts;
source_rows[idx] = k_idx * num_rows + block_row;
T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
T row_out =
bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]
: result_kvp.value;
row_outputs[k_idx] = row_out;
weight_sum += row_out;
}
@@ -458,16 +468,16 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so
}
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
__launch_bounds__(TPB) __global__
void moe_softmax_top_k_normed_fused(const T* input,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
// softmax
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -480,11 +490,12 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
return;
}
const int64_t thread_row_offset = globalIdx * num_experts;
const int64_t idx = thread_row_offset+threadIdx.x;
const int64_t idx = thread_row_offset + threadIdx.x;
cub::Sum sum;
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
float threadData =
(threadIdx.x < num_experts) ? static_cast<float>(input[idx]) : (-FLT_MAX);
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
@@ -542,7 +553,8 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
if (threadIdx.x == 0) {
const int cur_idx = k * globalIdx + k_idx;
T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
T row_out =
bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
row_outputs[k_idx] = row_out;
weight_sum += row_out;
@@ -595,29 +607,36 @@ void topk_gating_softmax_kernelLauncher(const T* input,
if (topk_only_mode) {
static constexpr int TPB = 256;
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, gating_correction_bias, output, indices, source_row, num_experts, k, num_rows);
moe_top_k<T, TPB>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
gating_correction_bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
return;
}
static constexpr int WARPS_PER_TB = 4;
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
input, output, indices, source_row, num_rows, num_experts, k, stream); \
break; \
}
int64_t tem_num_experts = num_experts;
if(gating_correction_bias != nullptr) tem_num_experts = 0;
if (gating_correction_bias != nullptr) tem_num_experts = 0;
switch (tem_num_experts) {
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16)
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32)
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64)
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128)
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256)
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16)
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32)
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64)
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128)
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256)
default: {
static constexpr int TPB = 256;
@@ -646,15 +665,15 @@ void topk_gating_softmax_kernelLauncher(const T* input,
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, softmax, num_experts, num_rows);
moe_top_k<T, TPB>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
gating_correction_bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
softmax,
gating_correction_bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
}
}
}
+279 -243
View File
@@ -23,8 +23,8 @@ void MixedFusedPagedAttnKernel(const paddle::Tensor& qkv,
const paddle::Tensor& decode_block_table,
const paddle::Tensor& cu_seqlens_qkv,
const paddle::Tensor& seq_lens,
const paddle::optional<paddle::Tensor> &rope_sin,
const paddle::optional<paddle::Tensor> &rope_cos,
const paddle::optional<paddle::Tensor>& rope_sin,
const paddle::optional<paddle::Tensor>& rope_cos,
int prefill_num_tokens,
int num_heads,
int head_dim,
@@ -42,318 +42,354 @@ void MixedFusedPagedAttnKernel(const paddle::Tensor& qkv,
bool enable_cuda_graph,
bool use_sqrt_alibi,
paddle::Tensor& out) {
typedef PDTraits<T> traits_;
typedef typename traits_::data_t data_t;
typedef PDTraits<T> traits_;
typedef typename traits_::data_t data_t;
const auto& dtype = qkv.dtype();
cuinferDataType_t cuinfer_data_type;
cudaDataType_t cu_data_type;
if (dtype == paddle::DataType::FLOAT16) {
cuinfer_data_type = CUINFER_DATA_HALF;
cu_data_type = CUDA_R_16F;
} else {
cuinfer_data_type = CUINFER_DATA_BFLOAT16;
cu_data_type = CUDA_R_16BF;
}
const auto& dtype = qkv.dtype();
cuinferDataType_t cuinfer_data_type;
cudaDataType_t cu_data_type;
if (dtype == paddle::DataType::FLOAT16) {
cuinfer_data_type = CUINFER_DATA_HALF;
cu_data_type = CUDA_R_16F;
} else {
cuinfer_data_type = CUINFER_DATA_BFLOAT16;
cu_data_type = CUDA_R_16BF;
}
const auto& qkv_dims = qkv.dims();
const auto& kv_cache_dims = k_cache.dims();
const auto& prefill_block_table_dims = prefill_block_table.dims();
const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims();
const auto& qkv_dims = qkv.dims();
const auto& kv_cache_dims = k_cache.dims();
const auto& prefill_block_table_dims = prefill_block_table.dims();
const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims();
int prefill_batch_size = prefill_block_table_dims[0];
int num_tokens = qkv_dims[0];
int decode_num_tokens = num_tokens - prefill_num_tokens;
int num_total_heads = num_heads + 2 * num_kv_heads;
int max_num_blocks_per_seq = prefill_block_table_dims[1];
int qkv_stride = qkv.strides()[0];
int num_blocks = kv_cache_dims[0];
int prefill_batch_size = prefill_block_table_dims[0];
int num_tokens = qkv_dims[0];
int decode_num_tokens = num_tokens - prefill_num_tokens;
int num_total_heads = num_heads + 2 * num_kv_heads;
int max_num_blocks_per_seq = prefill_block_table_dims[1];
int qkv_stride = qkv.strides()[0];
int num_blocks = kv_cache_dims[0];
int kv_block_stride = k_cache.strides()[0];
int kv_head_stride = k_cache.strides()[1];
int block_table_stride = prefill_block_table.strides()[0];
const float* rope_sin_ptr = rope_sin ? rope_sin.get().data<float>() : nullptr;
const float* rope_cos_ptr = rope_cos ? rope_cos.get().data<float>() : nullptr;
int kv_block_stride = k_cache.strides()[0];
int kv_head_stride = k_cache.strides()[1];
int block_table_stride = prefill_block_table.strides()[0];
const float *rope_sin_ptr = rope_sin ? rope_sin.get().data<float>() : nullptr;
const float *rope_cos_ptr = rope_cos ? rope_cos.get().data<float>() : nullptr;
cuinferTensorDescriptor_t qkv_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cuinferTensorDescriptor_t qkv_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
qkv_desc,
cuinfer_data_type,
3,
std::vector<int>({prefill_num_tokens, num_total_heads, head_dim}).data(),
std::vector<int>({num_total_heads * head_dim, head_dim, 1}).data()));
cuinferTensorDescriptor_t qkv_seqlens_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cuinferTensorDescriptor_t qkv_seqlens_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
qkv_seqlens_desc,
CUINFER_DATA_INT32,
1,
std::vector<int>({prefill_batch_size + 1}).data(),
std::vector<int>({1}).data()));
cuinferTensorDescriptor_t block_table_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cuinferTensorDescriptor_t block_table_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
block_table_desc,
CUINFER_DATA_INT32,
2,
std::vector<int>({prefill_batch_size, block_table_stride}).data(),
std::vector<int>({block_table_stride, 1}).data()));
cuinferTensorDescriptor_t o_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cuinferTensorDescriptor_t o_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
o_desc,
cuinfer_data_type,
3,
std::vector<int>({prefill_num_tokens, num_heads, head_dim}).data(),
std::vector<int>({num_heads * head_dim, head_dim, 1}).data()));
cuinferTensorDescriptor_t k_cache_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cuinferTensorDescriptor_t k_cache_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
k_cache_desc,
cuinfer_data_type,
4,
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
std::vector<int>({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data()));
std::vector<int>({num_kv_heads * block_size * head_dim,
block_size * head_dim,
head_dim,
1})
.data()));
cuinferTensorDescriptor_t v_cache_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cuinferTensorDescriptor_t v_cache_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
v_cache_desc,
cuinfer_data_type,
4,
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
std::vector<int>({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data()));
std::vector<int>({num_kv_heads * block_size * head_dim,
block_size * head_dim,
head_dim,
1})
.data()));
cuinferTensorDescriptor_t cos_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cuinferTensorDescriptor_t cos_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cos_desc,
CUINFER_DATA_FLOAT,
2,
std::vector<int>({max_seq_len, head_dim}).data(),
std::vector<int>({head_dim, 1}).data()));
cuinferTensorDescriptor_t sin_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cuinferTensorDescriptor_t sin_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
sin_desc,
CUINFER_DATA_FLOAT,
2,
std::vector<int>({max_seq_len, head_dim}).data(),
std::vector<int>({head_dim, 1}).data()));
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
cuinferHandle_t cuinfer_handle =
iluvatar::getContextInstance()->getIxInferHandle();
size_t prefill_workspace_size = 0;
CUINFER_CHECK(cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(prefill_num_tokens,
num_heads,
num_kv_heads,
head_dim,
q_rope,
k_rope,
v_rope,
cuinfer_data_type,
cuinfer_data_type,
cuinfer_data_type,
&prefill_workspace_size));
auto* allocator = paddle::GetAllocator(qkv.place());
phi::Allocator::AllocationPtr prefill_tmp_workspace = allocator->Allocate(prefill_workspace_size);
void* prefill_workspace_ptr = prefill_tmp_workspace->ptr();
CUINFER_CHECK(cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle,
qkv_desc,
qkv.data(),
qkv_seqlens_desc,
cu_seqlens_qkv.data<int32_t>(),
block_table_desc,
prefill_block_table.data<int32_t>(),
o_desc,
out.data(),
k_cache_desc,
k_cache.data(),
v_cache_desc,
v_cache.data(),
prefill_workspace_ptr,
prefill_workspace_size,
cos_desc,
rope_cos_ptr,
sin_desc,
rope_sin_ptr,
prefill_batch_size,
size_t prefill_workspace_size = 0;
CUINFER_CHECK(
cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(prefill_num_tokens,
num_heads,
num_kv_heads,
head_dim,
causal,
scale,
q_rope,
k_rope,
v_rope));
v_rope,
cuinfer_data_type,
cuinfer_data_type,
cuinfer_data_type,
&prefill_workspace_size));
size_t decode_workspace_size = 0;
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(decode_num_tokens,
num_heads,
num_kv_heads,
head_dim,
block_size,
max_seq_len,
&decode_workspace_size));
auto* allocator = paddle::GetAllocator(qkv.place());
phi::Allocator::AllocationPtr decode_tmp_workspace = allocator->Allocate(decode_workspace_size);
void* decode_workspace_ptr = decode_tmp_workspace->ptr();
phi::Allocator::AllocationPtr prefill_tmp_workspace =
allocator->Allocate(prefill_workspace_size);
void* prefill_workspace_ptr = prefill_tmp_workspace->ptr();
void* decode_qkv_ptr = (void*)(qkv.data<data_t>() + prefill_num_tokens * qkv_stride);
void* decode_out_ptr = (void*)(out.data<data_t>() + prefill_num_tokens * out.strides()[0]);
CUINFER_CHECK(
cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle,
qkv_desc,
qkv.data(),
qkv_seqlens_desc,
cu_seqlens_qkv.data<int32_t>(),
block_table_desc,
prefill_block_table.data<int32_t>(),
o_desc,
out.data(),
k_cache_desc,
k_cache.data(),
v_cache_desc,
v_cache.data(),
prefill_workspace_ptr,
prefill_workspace_size,
cos_desc,
rope_cos_ptr,
sin_desc,
rope_sin_ptr,
prefill_batch_size,
num_heads,
num_kv_heads,
head_dim,
causal,
scale,
q_rope,
k_rope,
v_rope));
PageAttentionWithKVCacheArguments args{
static_cast<float>(scale), 1.0, 1.0, static_cast<float>(softcap), window_left, window_right,
causal, use_sqrt_alibi, enable_cuda_graph, false, nullptr, decode_qkv_ptr, decode_qkv_ptr,
decode_workspace_ptr, true, rope_sin_ptr, rope_cos_ptr};
size_t decode_workspace_size = 0;
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(decode_num_tokens,
num_heads,
num_kv_heads,
head_dim,
block_size,
max_seq_len,
&decode_workspace_size));
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
decode_out_ptr,
cu_data_type,
phi::Allocator::AllocationPtr decode_tmp_workspace =
allocator->Allocate(decode_workspace_size);
void* decode_workspace_ptr = decode_tmp_workspace->ptr();
void* decode_qkv_ptr =
(void*)(qkv.data<data_t>() + prefill_num_tokens * qkv_stride);
void* decode_out_ptr =
(void*)(out.data<data_t>() + prefill_num_tokens * out.strides()[0]);
PageAttentionWithKVCacheArguments args{static_cast<float>(scale),
1.0,
1.0,
static_cast<float>(softcap),
window_left,
window_right,
causal,
use_sqrt_alibi,
enable_cuda_graph,
false,
nullptr,
decode_qkv_ptr,
cu_data_type,
decode_num_tokens,
num_heads,
num_kv_heads,
head_dim,
qkv_stride,
kv_block_stride,
kv_head_stride,
k_cache.data(),
cu_data_type,
v_cache.data(),
cu_data_type,
block_size,
max_num_blocks_per_seq,
max_seq_len,
decode_block_table.data<int32_t>(),
seq_lens.data<int32_t>(),
args));
decode_qkv_ptr,
decode_workspace_ptr,
true,
rope_sin_ptr,
rope_cos_ptr};
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc));
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
decode_out_ptr,
cu_data_type,
decode_qkv_ptr,
cu_data_type,
decode_num_tokens,
num_heads,
num_kv_heads,
head_dim,
qkv_stride,
kv_block_stride,
kv_head_stride,
k_cache.data(),
cu_data_type,
v_cache.data(),
cu_data_type,
block_size,
max_num_blocks_per_seq,
max_seq_len,
decode_block_table.data<int32_t>(),
seq_lens.data<int32_t>(),
args));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc));
}
std::vector<paddle::Tensor> MixedFusedPagedAttn(const paddle::Tensor& qkv,
paddle::Tensor& k_cache,
paddle::Tensor& v_cache,
const paddle::Tensor& prefill_block_table,
const paddle::Tensor& decode_block_table,
const paddle::Tensor& cu_seqlens_qkv,
const paddle::Tensor& seq_lens,
const paddle::optional<paddle::Tensor> &rope_sin,
const paddle::optional<paddle::Tensor> &rope_cos,
int prefill_num_tokens,
int num_heads,
int head_dim,
int num_kv_heads,
int block_size,
int max_seq_len,
float scale,
bool causal,
bool q_rope,
bool k_rope,
bool v_rope,
int window_left,
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi) {
const auto dtype = qkv.dtype();
auto out = paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place());
std::vector<paddle::Tensor> MixedFusedPagedAttn(
const paddle::Tensor& qkv,
paddle::Tensor& k_cache,
paddle::Tensor& v_cache,
const paddle::Tensor& prefill_block_table,
const paddle::Tensor& decode_block_table,
const paddle::Tensor& cu_seqlens_qkv,
const paddle::Tensor& seq_lens,
const paddle::optional<paddle::Tensor>& rope_sin,
const paddle::optional<paddle::Tensor>& rope_cos,
int prefill_num_tokens,
int num_heads,
int head_dim,
int num_kv_heads,
int block_size,
int max_seq_len,
float scale,
bool causal,
bool q_rope,
bool k_rope,
bool v_rope,
int window_left,
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi) {
const auto dtype = qkv.dtype();
auto out =
paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place());
switch (dtype) {
case paddle::DataType::BFLOAT16:
MixedFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(qkv,
k_cache,
v_cache,
prefill_block_table,
decode_block_table,
cu_seqlens_qkv,
seq_lens,
rope_sin,
rope_cos,
prefill_num_tokens,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
out);
break;
case paddle::DataType::FLOAT16:
MixedFusedPagedAttnKernel<paddle::DataType::FLOAT16>(qkv,
k_cache,
v_cache,
prefill_block_table,
decode_block_table,
cu_seqlens_qkv,
seq_lens,
rope_sin,
rope_cos,
prefill_num_tokens,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
out);
break;
default:
PD_THROW("Unsupported data type for mixed paged attn");
}
return {out};
switch (dtype) {
case paddle::DataType::BFLOAT16:
MixedFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(qkv,
k_cache,
v_cache,
prefill_block_table,
decode_block_table,
cu_seqlens_qkv,
seq_lens,
rope_sin,
rope_cos,
prefill_num_tokens,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
out);
break;
case paddle::DataType::FLOAT16:
MixedFusedPagedAttnKernel<paddle::DataType::FLOAT16>(qkv,
k_cache,
v_cache,
prefill_block_table,
decode_block_table,
cu_seqlens_qkv,
seq_lens,
rope_sin,
rope_cos,
prefill_num_tokens,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
out);
break;
default:
PD_THROW("Unsupported data type for mixed paged attn");
}
return {out};
}
std::vector<std::vector<int64_t>> MixedFusedPagedAttnInferShape(const std::vector<int64_t>& qkv_shape,
int num_heads,
int head_dim) {
return {{qkv_shape[0], num_heads * head_dim}};
std::vector<std::vector<int64_t>> MixedFusedPagedAttnInferShape(
const std::vector<int64_t>& qkv_shape, int num_heads, int head_dim) {
return {{qkv_shape[0], num_heads * head_dim}};
}
std::vector<paddle::DataType> MixedFusedPagedAttnInferDtype(const paddle::DataType& qkv_dtype) {
return {qkv_dtype};
std::vector<paddle::DataType> MixedFusedPagedAttnInferDtype(
const paddle::DataType& qkv_dtype) {
return {qkv_dtype};
}
PD_BUILD_STATIC_OP(mixed_fused_paged_attn)
.Inputs({"qkv", "k_cache", "v_cache", "prefill_block_table", "decode_block_table",
"cu_seqlens_qkv", "seq_lens", paddle::Optional("rope_sin"), paddle::Optional("rope_cos")})
.Inputs({"qkv",
"k_cache",
"v_cache",
"prefill_block_table",
"decode_block_table",
"cu_seqlens_qkv",
"seq_lens",
paddle::Optional("rope_sin"),
paddle::Optional("rope_cos")})
.Outputs({"out"})
.Attrs({"prefill_num_tokens:int",
"num_heads: int",
@@ -362,14 +398,14 @@ PD_BUILD_STATIC_OP(mixed_fused_paged_attn)
"block_size:int",
"max_seq_len:int",
"scale:float",
"causal:bool",
"q_rope:bool",
"causal:bool",
"q_rope:bool",
"k_rope:bool",
"v_rope:bool",
"window_left:int",
"window_right:int",
"softcap:float",
"enable_cuda_graph:bool",
"enable_cuda_graph:bool",
"use_sqrt_alibi:bool"})
.SetKernelFn(PD_KERNEL(MixedFusedPagedAttn))
.SetInferShapeFn(PD_INFER_SHAPE(MixedFusedPagedAttnInferShape))
+55 -54
View File
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
@@ -29,10 +28,10 @@ __global__ void compute_total_rows_before_expert_kernel(
const int64_t sorted_experts_len,
const int64_t num_experts,
int64_t* total_rows_before_expert) {
const int expert = blockIdx.x * blockDim.x + threadIdx.x;
if (expert >= num_experts) return;
total_rows_before_expert[expert] =
phi::find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
const int expert = blockIdx.x * blockDim.x + threadIdx.x;
if (expert >= num_experts) return;
total_rows_before_expert[expert] = phi::find_total_elts_leq_target(
sorted_experts, sorted_experts_len, expert);
}
void compute_total_rows_before_expert(int* sorted_indices,
@@ -40,36 +39,38 @@ void compute_total_rows_before_expert(int* sorted_indices,
const int64_t num_experts,
int64_t* total_rows_before_expert,
cudaStream_t stream) {
const int threads = std::min(int64_t(1024), num_experts);
const int blocks = (num_experts + threads - 1) / threads;
const int threads = std::min(int64_t(1024), num_experts);
const int blocks = (num_experts + threads - 1) / threads;
compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
sorted_indices, total_indices, num_experts, total_rows_before_expert);
compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
sorted_indices, total_indices, num_experts, total_rows_before_expert);
}
template <paddle::DataType T>
void MoeDispatchKernel(const paddle::Tensor& input,
const paddle::Tensor& gating_output,
const paddle::optional<paddle::Tensor>& gating_correction_bias,
const int moe_topk,
const bool group_moe,
const std::string &moe_quant_type,
const bool topk_only_mode,
const int num_rows,
const int hidden_size,
const int expert_num,
paddle::Tensor* permute_input,
paddle::Tensor* tokens_expert_prefix_sum,
paddle::Tensor* permute_indices_per_token,
paddle::Tensor* top_k_weight,
paddle::Tensor* top_k_indices) {
void MoeDispatchKernel(
const paddle::Tensor& input,
const paddle::Tensor& gating_output,
const paddle::optional<paddle::Tensor>& gating_correction_bias,
const int moe_topk,
const bool group_moe,
const std::string& moe_quant_type,
const bool topk_only_mode,
const int num_rows,
const int hidden_size,
const int expert_num,
paddle::Tensor* permute_input,
paddle::Tensor* tokens_expert_prefix_sum,
paddle::Tensor* permute_indices_per_token,
paddle::Tensor* top_k_weight,
paddle::Tensor* top_k_indices) {
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto place = input.place();
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input.place()));
auto dev_ctx = static_cast<const phi::CustomContext*>(
paddle::experimental::DeviceContextPool::Instance().Get(input.place()));
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
if (group_moe) {
// Check if expert_num is divisible by moe_topk, else throw an error
@@ -131,19 +132,21 @@ void MoeDispatchKernel(const paddle::Tensor& input,
softmax_out_ = nullptr;
}
topk_gating_softmax_kernelLauncher<float>(gating_output.data<float>(),
gating_correction_bias ? gating_correction_bias.get().data<float>() : nullptr,
top_k_weight->data<float>(),
softmax_out_,
expert_for_source_row,
source_rows_,
softmax_max_prob,
num_rows,
expert_num,
moe_topk,
group_moe,
stream,
topk_only_mode);
topk_gating_softmax_kernelLauncher<float>(
gating_output.data<float>(),
gating_correction_bias ? gating_correction_bias.get().data<float>()
: nullptr,
top_k_weight->data<float>(),
softmax_out_,
expert_for_source_row,
source_rows_,
softmax_max_prob,
num_rows,
expert_num,
moe_topk,
group_moe,
stream,
topk_only_mode);
sorter_.run(reinterpret_cast<void*>(sorter_ws_ptr),
sorter_ws_size_bytes,
@@ -155,7 +158,6 @@ void MoeDispatchKernel(const paddle::Tensor& input,
false,
stream);
initialize_moe_routing_kernelLauncher(
input.data<data_t>(),
permute_input->data<data_t>(),
@@ -167,16 +169,13 @@ void MoeDispatchKernel(const paddle::Tensor& input,
moe_topk,
stream);
compute_total_rows_before_expert(
permuted_experts_,
moe_topk * num_rows,
expert_num,
tokens_expert_prefix_sum->data<int64_t>(),
stream);
compute_total_rows_before_expert(permuted_experts_,
moe_topk * num_rows,
expert_num,
tokens_expert_prefix_sum->data<int64_t>(),
stream);
}
std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::Tensor& input,
const paddle::Tensor& gating_output,
@@ -184,7 +183,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::optional<paddle::Tensor>& w4a8_in_scale,
const int moe_topk,
const bool group_moe,
const std::string &moe_quant_type,
const std::string& moe_quant_type,
const bool topk_only_mode) {
const auto input_type = input.dtype();
auto place = input.place();
@@ -214,7 +213,6 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
auto permute_indices_per_token =
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
@@ -261,7 +259,6 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
top_k_indices};
}
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& gating_output_shape,
@@ -299,17 +296,21 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
paddle::DataType::INT32};
}
PD_BUILD_STATIC_OP(moe_expert_dispatch)
.Inputs({"input", "gating_output", paddle::Optional("gating_correction_bias"),
paddle::Optional("w4a8_in_scale")})
.Inputs({"input",
"gating_output",
paddle::Optional("gating_correction_bias"),
paddle::Optional("w4a8_in_scale")})
.Outputs({"permute_input",
"tokens_expert_prefix_sum",
"permute_indices_per_token",
"top_k_weight",
"top_k_indices",
"expert_idx_per_token"})
.Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"})
.Attrs({"moe_topk:int",
"group_moe:bool",
"moe_quant_type:std::string",
"topk_only_mode:bool"})
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
+68 -69
View File
@@ -16,9 +16,9 @@
#pragma once
#include "helper.h"
#include "fused_moe_helper.h"
#include "fused_moe_op.h"
#include "helper.h"
template <paddle::DataType T>
void MoeReduceKernel(const paddle::Tensor& ffn_out,
@@ -32,27 +32,28 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out,
const int hidden_size,
const int topk,
paddle::Tensor* output) {
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(ffn_out.place()));
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto dev_ctx = static_cast<const phi::CustomContext*>(
paddle::experimental::DeviceContextPool::Instance().Get(ffn_out.place()));
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
finalize_moe_routing_kernelLauncher(
ffn_out.data<data_t>(),
output->data<data_t>(),
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
top_k_weight.data<float>(),
permute_indices_per_token.data<int32_t>(),
top_k_indices.data<int>(),
num_rows,
hidden_size,
topk,
static_cast<int>(1),
norm_topk_prob,
routed_scaling_factor,
stream);
finalize_moe_routing_kernelLauncher(
ffn_out.data<data_t>(),
output->data<data_t>(),
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
top_k_weight.data<float>(),
permute_indices_per_token.data<int32_t>(),
top_k_indices.data<int>(),
num_rows,
hidden_size,
topk,
static_cast<int>(1),
norm_topk_prob,
routed_scaling_factor,
stream);
}
paddle::Tensor MoeExpertReduceFunc(
@@ -63,48 +64,46 @@ paddle::Tensor MoeExpertReduceFunc(
const paddle::optional<paddle::Tensor>& down_proj_bias,
const bool norm_topk_prob,
const float routed_scaling_factor) {
const auto input_type = ffn_out.dtype();
auto place = ffn_out.place();
const auto input_type = ffn_out.dtype();
auto place = ffn_out.place();
const int topk = top_k_indices.dims()[1];
const int num_rows = ffn_out.dims()[0] / topk;
const int hidden_size = ffn_out.dims()[1];
const int topk = top_k_indices.dims()[1];
const int num_rows = ffn_out.dims()[0] / topk;
const int hidden_size = ffn_out.dims()[1];
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeReduceKernel<paddle::DataType::BFLOAT16>(
ffn_out,
top_k_weight,
permute_indices_per_token,
top_k_indices,
down_proj_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows,
hidden_size,
topk,
&output);
break;
case paddle::DataType::FLOAT16:
MoeReduceKernel<paddle::DataType::BFLOAT16>(
ffn_out,
top_k_weight,
permute_indices_per_token,
top_k_indices,
down_proj_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows,
hidden_size,
topk,
&output);
break;
default:
PD_THROW("Unsupported data type for MoeDispatchKernel");
}
return output;
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeReduceKernel<paddle::DataType::BFLOAT16>(ffn_out,
top_k_weight,
permute_indices_per_token,
top_k_indices,
down_proj_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows,
hidden_size,
topk,
&output);
break;
case paddle::DataType::FLOAT16:
MoeReduceKernel<paddle::DataType::BFLOAT16>(ffn_out,
top_k_weight,
permute_indices_per_token,
top_k_indices,
down_proj_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows,
hidden_size,
topk,
&output);
break;
default:
PD_THROW("Unsupported data type for MoeDispatchKernel");
}
return output;
}
std::vector<paddle::Tensor> MoeExpertReduce(
@@ -115,13 +114,13 @@ std::vector<paddle::Tensor> MoeExpertReduce(
const paddle::optional<paddle::Tensor>& down_proj_bias,
const bool norm_topk_prob,
const float routed_scaling_factor) {
return {MoeExpertReduceFunc(ffn_out,
top_k_weight,
permute_indices_per_token,
top_k_indices,
down_proj_bias,
norm_topk_prob,
routed_scaling_factor)};
return {MoeExpertReduceFunc(ffn_out,
top_k_weight,
permute_indices_per_token,
top_k_indices,
down_proj_bias,
norm_topk_prob,
routed_scaling_factor)};
}
std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
@@ -130,7 +129,7 @@ std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
const std::vector<int64_t>& permute_indices_per_token_shape,
const std::vector<int64_t>& top_k_indices_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape) {
return {ffn_out_shape};
return {ffn_out_shape};
}
std::vector<paddle::DataType> MoeExpertReduceInferDtype(
@@ -139,7 +138,7 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
const paddle::DataType& permute_indices_per_token_dtype,
const paddle::DataType& top_k_indices_dtype,
const paddle::optional<paddle::DataType>& down_proj_bias_dtype) {
return {ffn_out_dtype};
return {ffn_out_dtype};
}
PD_BUILD_STATIC_OP(moe_expert_reduce)
+305 -278
View File
@@ -15,18 +15,17 @@
#include "helper.h"
#include "iluvatar_context.h"
template <paddle::DataType T>
void PagedAttnKernel(const paddle::Tensor& q,
const paddle::Tensor& k_cache,
const paddle::Tensor& v_cache,
const paddle::Tensor& block_table,
const paddle::Tensor& seq_lens,
const paddle::optional<paddle::Tensor> &alibi_slopes,
const paddle::optional<paddle::Tensor> &k,
const paddle::optional<paddle::Tensor> &v,
const paddle::optional<paddle::Tensor> &rope_sin,
const paddle::optional<paddle::Tensor> &rope_cos,
const paddle::optional<paddle::Tensor>& alibi_slopes,
const paddle::optional<paddle::Tensor>& k,
const paddle::optional<paddle::Tensor>& v,
const paddle::optional<paddle::Tensor>& rope_sin,
const paddle::optional<paddle::Tensor>& rope_cos,
int num_heads,
int head_dim,
int num_kv_heads,
@@ -41,298 +40,326 @@ void PagedAttnKernel(const paddle::Tensor& q,
bool use_sqrt_alibi,
bool merged_qkv,
paddle::Tensor& out) {
if (alibi_slopes) {
PADDLE_ENFORCE_EQ(alibi_slopes.get().dtype(),
paddle::DataType::FLOAT32,
common::errors::InvalidArgument(
"paged_attention expects alibi_slopes float tensor"));
PADDLE_ENFORCE_EQ(alibi_slopes.get().is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects alibi_slopes is contiguous"));
}
if (alibi_slopes) {
PADDLE_ENFORCE_EQ(alibi_slopes.get().dtype(),
paddle::DataType::FLOAT32,
common::errors::InvalidArgument(
"paged_attention expects alibi_slopes float tensor"));
PADDLE_ENFORCE_EQ(
alibi_slopes.get().is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects alibi_slopes is contiguous"));
}
// check dtype and contiguous
const auto& dtype = q.dtype();
cudaDataType_t data_type;
if (dtype == paddle::DataType::FLOAT16) {
data_type = CUDA_R_16F;
} else if (dtype == paddle::DataType::BFLOAT16) {
data_type = CUDA_R_16BF;
} else {
common::errors::InvalidArgument("paged_attention support half and bfloat16 now");
}
// check dtype and contiguous
const auto& dtype = q.dtype();
cudaDataType_t data_type;
if (dtype == paddle::DataType::FLOAT16) {
data_type = CUDA_R_16F;
} else if (dtype == paddle::DataType::BFLOAT16) {
data_type = CUDA_R_16BF;
} else {
common::errors::InvalidArgument(
"paged_attention support half and bfloat16 now");
}
PADDLE_ENFORCE_EQ(k_cache.dtype(),
dtype,
common::errors::InvalidArgument(
"k_cache dtype must be the same as query dtype"));
PADDLE_ENFORCE_EQ(k_cache.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects k_cache is contiguous"));
PADDLE_ENFORCE_EQ(block_table.dtype(),
paddle::DataType::INT32,
common::errors::InvalidArgument(
"block_table dtype must be int32"));
PADDLE_ENFORCE_EQ(block_table.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects block_table is contiguous"));
PADDLE_ENFORCE_EQ(seq_lens.dtype(),
paddle::DataType::INT32,
common::errors::InvalidArgument(
"seq_lens dtype must be int32"));
PADDLE_ENFORCE_EQ(seq_lens.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects seq_lens is contiguous"));
// check dim and shape
// k_cache: [num_blocks, kv_num_heads, block_size, head_dim]
// v_cache: [num_blocks, kv_num_heads, block_size, head_dim]
// block_table: [num_seqs, max_num_blocks_per_seq]
// seq_lens: [num_seqs]
// q and out:
// if merged_qkv = false:
// q:[num_seqs, hidden_size]
// out:[num_seqs, hidden_size]
// if merged_qkv = true:
// q: [num_seqs, (num_heads+2*num_kv_heads)*head_dim]
// out: [num_seqs, hidden_size]
PADDLE_ENFORCE_EQ(k_cache.dtype(),
dtype,
common::errors::InvalidArgument(
"k_cache dtype must be the same as query dtype"));
PADDLE_ENFORCE_EQ(k_cache.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects k_cache is contiguous"));
PADDLE_ENFORCE_EQ(
block_table.dtype(),
paddle::DataType::INT32,
common::errors::InvalidArgument("block_table dtype must be int32"));
PADDLE_ENFORCE_EQ(block_table.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects block_table is contiguous"));
PADDLE_ENFORCE_EQ(
seq_lens.dtype(),
paddle::DataType::INT32,
common::errors::InvalidArgument("seq_lens dtype must be int32"));
PADDLE_ENFORCE_EQ(seq_lens.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects seq_lens is contiguous"));
// check dim and shape
// k_cache: [num_blocks, kv_num_heads, block_size, head_dim]
// v_cache: [num_blocks, kv_num_heads, block_size, head_dim]
// block_table: [num_seqs, max_num_blocks_per_seq]
// seq_lens: [num_seqs]
// q and out:
// if merged_qkv = false:
// q:[num_seqs, hidden_size]
// out:[num_seqs, hidden_size]
// if merged_qkv = true:
// q: [num_seqs, (num_heads+2*num_kv_heads)*head_dim]
// out: [num_seqs, hidden_size]
const auto& q_dims = q.dims();
PADDLE_ENFORCE_EQ(q_dims.size(),
2,
common::errors::InvalidArgument(
"paged_attn receive query dims is "
"[num_seqs, (num_heads+2*num_kv_heads)*head_dim]"));
PADDLE_ENFORCE_EQ(out.dims().size(),
2,
common::errors::InvalidArgument(
"paged_attn receive out dims is "
"[num_seqs, hidden_size]"));
const auto& q_dims = q.dims();
PADDLE_ENFORCE_EQ(q_dims.size(),
2,
common::errors::InvalidArgument(
"paged_attn receive query dims is "
"[num_seqs, (num_heads+2*num_kv_heads)*head_dim]"));
PADDLE_ENFORCE_EQ(
out.dims().size(),
2,
common::errors::InvalidArgument("paged_attn receive out dims is "
"[num_seqs, hidden_size]"));
const auto& kv_cache_dims = k_cache.dims();
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
4,
common::errors::InvalidArgument(
"paged_attn receive kv cache dims is "
"[num_blocks, kv_num_heads, block_size, head_dim]"));
const auto& kv_cache_dims = k_cache.dims();
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
4,
common::errors::InvalidArgument(
"paged_attn receive kv cache dims is "
"[num_blocks, kv_num_heads, block_size, head_dim]"));
const auto& block_table_dims = block_table.dims();
PADDLE_ENFORCE_EQ(block_table_dims.size(),
2,
common::errors::InvalidArgument(
"paged_attn receive block_table dims is "
"[num_seqs, max_num_blocks_per_seq]"));
const auto& block_table_dims = block_table.dims();
PADDLE_ENFORCE_EQ(
block_table_dims.size(),
2,
common::errors::InvalidArgument("paged_attn receive block_table dims is "
"[num_seqs, max_num_blocks_per_seq]"));
const auto& seq_lens_dims = seq_lens.dims();
PADDLE_ENFORCE_EQ(seq_lens_dims.size(),
1,
common::errors::InvalidArgument(
"paged_attn receive seq_lens dims is [num_seqs]"));
const auto& seq_lens_dims = seq_lens.dims();
PADDLE_ENFORCE_EQ(seq_lens_dims.size(),
1,
common::errors::InvalidArgument(
"paged_attn receive seq_lens dims is [num_seqs]"));
int num_seqs = q_dims[0];
int max_num_blocks_per_seq = block_table_dims[1];
int q_stride = q.strides()[0];
int num_blocks = kv_cache_dims[0];
int num_seqs = q_dims[0];
int max_num_blocks_per_seq = block_table_dims[1];
int q_stride = q.strides()[0];
int num_blocks = kv_cache_dims[0];
PADDLE_ENFORCE_EQ(kv_cache_dims[1],
num_kv_heads,
common::errors::InvalidArgument(
"kv_cache_dims[1] must be equal to num_kv_head"));
PADDLE_ENFORCE_EQ(kv_cache_dims[2],
block_size,
common::errors::InvalidArgument(
"kv_cache_dims[2] must be equal to block_size"));
PADDLE_ENFORCE_EQ(kv_cache_dims[3],
head_dim,
common::errors::InvalidArgument(
"kv_cache_dims[3] must be equal to head_dim"));
PADDLE_ENFORCE_EQ(block_table_dims[0],
num_seqs,
common::errors::InvalidArgument(
"block_table_dims[0] must be equal to num_seqs"));
PADDLE_ENFORCE_EQ(seq_lens_dims[0],
num_seqs,
common::errors::InvalidArgument(
"seq_lens_dims[0] must be equal to num_seqs"));
PADDLE_ENFORCE_EQ(kv_cache_dims[1],
num_kv_heads,
common::errors::InvalidArgument(
"kv_cache_dims[1] must be equal to num_kv_head"));
PADDLE_ENFORCE_EQ(kv_cache_dims[2],
block_size,
common::errors::InvalidArgument(
"kv_cache_dims[2] must be equal to block_size"));
PADDLE_ENFORCE_EQ(kv_cache_dims[3],
head_dim,
common::errors::InvalidArgument(
"kv_cache_dims[3] must be equal to head_dim"));
PADDLE_ENFORCE_EQ(block_table_dims[0],
num_seqs,
common::errors::InvalidArgument(
"block_table_dims[0] must be equal to num_seqs"));
PADDLE_ENFORCE_EQ(seq_lens_dims[0],
num_seqs,
common::errors::InvalidArgument(
"seq_lens_dims[0] must be equal to num_seqs"));
int kv_block_stride = k_cache.strides()[0];
int kv_head_stride = k_cache.strides()[1];
const float *alibi_slopes_ptr = alibi_slopes ? alibi_slopes.get().data<float>() : nullptr;
const void *key_ptr = k ? k.get().data() : nullptr;
const void *value_ptr = v ? v.get().data() : nullptr;
const float *rope_sin_ptr = merged_qkv ? rope_sin.get().data<float>() : nullptr;
const float *rope_cos_ptr = merged_qkv ? rope_cos.get().data<float>() : nullptr;
int kv_block_stride = k_cache.strides()[0];
int kv_head_stride = k_cache.strides()[1];
const float* alibi_slopes_ptr =
alibi_slopes ? alibi_slopes.get().data<float>() : nullptr;
const void* key_ptr = k ? k.get().data() : nullptr;
const void* value_ptr = v ? v.get().data() : nullptr;
const float* rope_sin_ptr =
merged_qkv ? rope_sin.get().data<float>() : nullptr;
const float* rope_cos_ptr =
merged_qkv ? rope_cos.get().data<float>() : nullptr;
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
cuinferHandle_t cuinfer_handle =
iluvatar::getContextInstance()->getIxInferHandle();
size_t workspace_size = 0;
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(num_seqs,
num_heads,
num_kv_heads,
head_dim,
block_size,
max_context_len,
&workspace_size));
auto* allocator = paddle::GetAllocator(q.place());
phi::Allocator::AllocationPtr tmp_workspace = allocator->Allocate(workspace_size);
void* workspace_ptr = tmp_workspace->ptr();
size_t workspace_size = 0;
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(num_seqs,
num_heads,
num_kv_heads,
head_dim,
block_size,
max_context_len,
&workspace_size));
auto* allocator = paddle::GetAllocator(q.place());
phi::Allocator::AllocationPtr tmp_workspace =
allocator->Allocate(workspace_size);
void* workspace_ptr = tmp_workspace->ptr();
PageAttentionWithKVCacheArguments args{
static_cast<float>(scale), 1.0, 1.0, static_cast<float>(softcap), window_left, window_right,
causal, use_sqrt_alibi, enable_cuda_graph, false, alibi_slopes_ptr, key_ptr, value_ptr,
workspace_ptr, merged_qkv, rope_sin_ptr, rope_cos_ptr};
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
out.data(),
data_type,
q.data(),
data_type,
num_seqs,
num_heads,
num_kv_heads,
head_dim,
q_stride,
kv_block_stride,
kv_head_stride,
k_cache.data(),
data_type,
v_cache.data(),
data_type,
block_size,
max_num_blocks_per_seq,
max_context_len,
block_table.data<int32_t>(),
seq_lens.data<int32_t>(),
args));
PageAttentionWithKVCacheArguments args{static_cast<float>(scale),
1.0,
1.0,
static_cast<float>(softcap),
window_left,
window_right,
causal,
use_sqrt_alibi,
enable_cuda_graph,
false,
alibi_slopes_ptr,
key_ptr,
value_ptr,
workspace_ptr,
merged_qkv,
rope_sin_ptr,
rope_cos_ptr};
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
out.data(),
data_type,
q.data(),
data_type,
num_seqs,
num_heads,
num_kv_heads,
head_dim,
q_stride,
kv_block_stride,
kv_head_stride,
k_cache.data(),
data_type,
v_cache.data(),
data_type,
block_size,
max_num_blocks_per_seq,
max_context_len,
block_table.data<int32_t>(),
seq_lens.data<int32_t>(),
args));
}
std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
const paddle::Tensor& k_cache,
const paddle::Tensor& v_cache,
const paddle::Tensor& block_table,
const paddle::Tensor& seq_lens,
const paddle::optional<paddle::Tensor> &alibi_slopes,
const paddle::optional<paddle::Tensor> &k,
const paddle::optional<paddle::Tensor> &v,
const paddle::optional<paddle::Tensor> &rope_sin,
const paddle::optional<paddle::Tensor> &rope_cos,
int num_heads,
int head_dim,
int num_kv_heads,
float scale,
int block_size,
int max_context_len,
bool causal,
int window_left,
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi,
bool merged_qkv) {
std::vector<paddle::Tensor> PagedAttn(
const paddle::Tensor& q,
const paddle::Tensor& k_cache,
const paddle::Tensor& v_cache,
const paddle::Tensor& block_table,
const paddle::Tensor& seq_lens,
const paddle::optional<paddle::Tensor>& alibi_slopes,
const paddle::optional<paddle::Tensor>& k,
const paddle::optional<paddle::Tensor>& v,
const paddle::optional<paddle::Tensor>& rope_sin,
const paddle::optional<paddle::Tensor>& rope_cos,
int num_heads,
int head_dim,
int num_kv_heads,
float scale,
int block_size,
int max_context_len,
bool causal,
int window_left,
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi,
bool merged_qkv) {
const auto dtype = q.dtype();
auto out =
paddle::empty({q.shape()[0], num_heads * head_dim}, dtype, q.place());
const auto dtype = q.dtype();
auto out = paddle::empty({q.shape()[0], num_heads * head_dim}, dtype, q.place());
switch (dtype) {
case paddle::DataType::BFLOAT16:
PagedAttnKernel<paddle::DataType::BFLOAT16>(q,
k_cache,
v_cache,
block_table,
seq_lens,
alibi_slopes,
k,
v,
rope_sin,
rope_cos,
num_heads,
head_dim,
num_kv_heads,
scale,
block_size,
max_context_len,
causal,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
merged_qkv,
out);
break;
case paddle::DataType::FLOAT16:
PagedAttnKernel<paddle::DataType::FLOAT16>(q,
k_cache,
v_cache,
block_table,
seq_lens,
alibi_slopes,
k,
v,
rope_sin,
rope_cos,
num_heads,
head_dim,
num_kv_heads,
scale,
block_size,
max_context_len,
causal,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
merged_qkv,
out);
break;
default:
PD_THROW("Unsupported data type for Paged attn");
}
return {out};
switch (dtype) {
case paddle::DataType::BFLOAT16:
PagedAttnKernel<paddle::DataType::BFLOAT16>(q,
k_cache,
v_cache,
block_table,
seq_lens,
alibi_slopes,
k,
v,
rope_sin,
rope_cos,
num_heads,
head_dim,
num_kv_heads,
scale,
block_size,
max_context_len,
causal,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
merged_qkv,
out);
break;
case paddle::DataType::FLOAT16:
PagedAttnKernel<paddle::DataType::FLOAT16>(q,
k_cache,
v_cache,
block_table,
seq_lens,
alibi_slopes,
k,
v,
rope_sin,
rope_cos,
num_heads,
head_dim,
num_kv_heads,
scale,
block_size,
max_context_len,
causal,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
merged_qkv,
out);
break;
default:
PD_THROW("Unsupported data type for Paged attn");
}
return {out};
}
std::vector<std::vector<int64_t>> PagedAttnInferShape(const std::vector<int64_t>& q_shape,
const std::vector<int64_t>& k_cache_shape,
const std::vector<int64_t>& v_cache_shape,
const std::vector<int64_t>& block_table_shape,
const std::vector<int64_t>& seq_lens_shape,
const std::vector<int64_t>& alibi_slopes_shape,
const std::vector<int64_t>& k_shape,
const std::vector<int64_t>& v_shape,
const std::vector<int64_t>& rope_sin_shape,
const std::vector<int64_t>& rope_cos_shape,
int num_heads,
int head_dim,
int num_kv_heads,
float scale,
int block_size,
int max_context_len,
bool causal,
int window_left,
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi,
bool merged_qkv) {
if (merged_qkv) {
return {{q_shape[0], num_heads * head_dim}};
} else {
return {q_shape};
}
std::vector<std::vector<int64_t>> PagedAttnInferShape(
const std::vector<int64_t>& q_shape,
const std::vector<int64_t>& k_cache_shape,
const std::vector<int64_t>& v_cache_shape,
const std::vector<int64_t>& block_table_shape,
const std::vector<int64_t>& seq_lens_shape,
const std::vector<int64_t>& alibi_slopes_shape,
const std::vector<int64_t>& k_shape,
const std::vector<int64_t>& v_shape,
const std::vector<int64_t>& rope_sin_shape,
const std::vector<int64_t>& rope_cos_shape,
int num_heads,
int head_dim,
int num_kv_heads,
float scale,
int block_size,
int max_context_len,
bool causal,
int window_left,
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi,
bool merged_qkv) {
if (merged_qkv) {
return {{q_shape[0], num_heads * head_dim}};
} else {
return {q_shape};
}
}
std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtype) {
return {q_dtype};
std::vector<paddle::DataType> PagedAttnInferDtype(
const paddle::DataType& q_dtype) {
return {q_dtype};
}
PD_BUILD_STATIC_OP(paged_attn)
.Inputs({"q", "k_cache", "v_cache", "block_table", "seq_lens",
paddle::Optional("alibi_slopes"), paddle::Optional("k"),
paddle::Optional("v"), paddle::Optional("rope_sin"),
.Inputs({"q",
"k_cache",
"v_cache",
"block_table",
"seq_lens",
paddle::Optional("alibi_slopes"),
paddle::Optional("k"),
paddle::Optional("v"),
paddle::Optional("rope_sin"),
paddle::Optional("rope_cos")})
.Outputs({"out"})
.Attrs({"num_heads:int",
@@ -341,11 +368,11 @@ PD_BUILD_STATIC_OP(paged_attn)
"scale:float",
"block_size:int",
"max_context_len:int",
"causal:bool",
"causal:bool",
"window_left:int",
"window_right:int",
"softcap:float",
"enable_cuda_graph:bool",
"enable_cuda_graph:bool",
"use_sqrt_alibi:bool",
"merged_qkv:bool"})
.SetKernelFn(PD_KERNEL(PagedAttn))
+306 -284
View File
@@ -16,352 +16,374 @@
#include "iluvatar_context.h"
template <paddle::DataType T>
void PrefillFusedPagedAttnKernel(const paddle::Tensor& qkv,
paddle::Tensor& k_cache,
paddle::Tensor& v_cache,
const paddle::Tensor& block_table,
const paddle::Tensor& cu_seqlens_qkv,
const paddle::optional<paddle::Tensor> &rope_sin,
const paddle::optional<paddle::Tensor> &rope_cos,
int num_heads,
int head_dim,
int num_kv_heads,
int block_size,
int max_seq_len,
float scale,
bool causal,
bool q_rope,
bool k_rope,
bool v_rope,
paddle::Tensor& out) {
void PrefillFusedPagedAttnKernel(
const paddle::Tensor& qkv,
paddle::Tensor& k_cache,
paddle::Tensor& v_cache,
const paddle::Tensor& block_table,
const paddle::Tensor& cu_seqlens_qkv,
const paddle::optional<paddle::Tensor>& rope_sin,
const paddle::optional<paddle::Tensor>& rope_cos,
int num_heads,
int head_dim,
int num_kv_heads,
int block_size,
int max_seq_len,
float scale,
bool causal,
bool q_rope,
bool k_rope,
bool v_rope,
paddle::Tensor& out) {
// check dtype and contiguous
const auto& dtype = qkv.dtype();
cuinferDataType_t data_type;
if (dtype == paddle::DataType::FLOAT16) {
data_type = CUINFER_DATA_HALF;
// check dtype and contiguous
const auto& dtype = qkv.dtype();
cuinferDataType_t data_type;
if (dtype == paddle::DataType::FLOAT16) {
data_type = CUINFER_DATA_HALF;
} else if (dtype == paddle::DataType::BFLOAT16) {
data_type = CUINFER_DATA_BFLOAT16;
} else {
common::errors::InvalidArgument(
"paged_attention support half and bfloat16 now");
}
} else if (dtype == paddle::DataType::BFLOAT16) {
data_type = CUINFER_DATA_BFLOAT16;
} else {
common::errors::InvalidArgument("paged_attention support half and bfloat16 now");
}
PADDLE_ENFORCE_EQ(k_cache.dtype(),
dtype,
common::errors::InvalidArgument(
"k_cache dtype must be the same as query dtype"));
PADDLE_ENFORCE_EQ(k_cache.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects k_cache is contiguous"));
PADDLE_ENFORCE_EQ(
block_table.dtype(),
paddle::DataType::INT32,
common::errors::InvalidArgument("block_table dtype must be int32"));
PADDLE_ENFORCE_EQ(block_table.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects block_table is contiguous"));
PADDLE_ENFORCE_EQ(
cu_seqlens_qkv.dtype(),
paddle::DataType::INT32,
common::errors::InvalidArgument("cu_seqlens_qkv dtype must be int32"));
PADDLE_ENFORCE_EQ(
cu_seqlens_qkv.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects cu_seqlens_qkv is contiguous"));
// check dim and shape
// k_cache: [num_blocks, kv_num_heads, block_size, head_dim]
// v_cache: [num_blocks, kv_num_heads, block_size, head_dim]
// block_table: [batch_size, max_num_blocks_per_seq]
// seq_lens: [batch_size]
// qkv: [num_tokens, (num_heads+2*num_kv_heads)*head_dim]
// out: [num_tokens, hidden_size]
PADDLE_ENFORCE_EQ(k_cache.dtype(),
dtype,
common::errors::InvalidArgument(
"k_cache dtype must be the same as query dtype"));
PADDLE_ENFORCE_EQ(k_cache.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects k_cache is contiguous"));
PADDLE_ENFORCE_EQ(block_table.dtype(),
paddle::DataType::INT32,
common::errors::InvalidArgument(
"block_table dtype must be int32"));
PADDLE_ENFORCE_EQ(block_table.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects block_table is contiguous"));
PADDLE_ENFORCE_EQ(cu_seqlens_qkv.dtype(),
paddle::DataType::INT32,
common::errors::InvalidArgument(
"cu_seqlens_qkv dtype must be int32"));
PADDLE_ENFORCE_EQ(cu_seqlens_qkv.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects cu_seqlens_qkv is contiguous"));
// check dim and shape
// k_cache: [num_blocks, kv_num_heads, block_size, head_dim]
// v_cache: [num_blocks, kv_num_heads, block_size, head_dim]
// block_table: [batch_size, max_num_blocks_per_seq]
// seq_lens: [batch_size]
// qkv: [num_tokens, (num_heads+2*num_kv_heads)*head_dim]
// out: [num_tokens, hidden_size]
const auto& qkv_dims = qkv.dims();
PADDLE_ENFORCE_EQ(qkv_dims.size(),
2,
common::errors::InvalidArgument(
"paged_attn receive query dims is "
"[num_tokens, (num_heads+2*num_kv_heads)*head_dim]"));
PADDLE_ENFORCE_EQ(
out.dims().size(),
2,
common::errors::InvalidArgument("paged_attn receive out dims is "
"[num_tokens, hidden_size]"));
const auto& qkv_dims = qkv.dims();
PADDLE_ENFORCE_EQ(qkv_dims.size(),
2,
common::errors::InvalidArgument(
"paged_attn receive query dims is "
"[num_tokens, (num_heads+2*num_kv_heads)*head_dim]"));
PADDLE_ENFORCE_EQ(out.dims().size(),
2,
common::errors::InvalidArgument(
"paged_attn receive out dims is "
"[num_tokens, hidden_size]"));
const auto& kv_cache_dims = k_cache.dims();
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
4,
common::errors::InvalidArgument(
"paged_attn receive kv cache dims is "
"[num_blocks, kv_num_heads, block_size, head_dim]"));
const auto& kv_cache_dims = k_cache.dims();
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
4,
common::errors::InvalidArgument(
"paged_attn receive kv cache dims is "
"[num_blocks, kv_num_heads, block_size, head_dim]"));
const auto& block_table_dims = block_table.dims();
PADDLE_ENFORCE_EQ(
block_table_dims.size(),
2,
common::errors::InvalidArgument("paged_attn receive block_table dims is "
"[batch_size, max_num_blocks_per_seq]"));
const auto& block_table_dims = block_table.dims();
PADDLE_ENFORCE_EQ(block_table_dims.size(),
2,
common::errors::InvalidArgument(
"paged_attn receive block_table dims is "
"[batch_size, max_num_blocks_per_seq]"));
const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims();
PADDLE_ENFORCE_EQ(
cu_seqlens_qkv_dims.size(),
1,
common::errors::InvalidArgument(
"paged_attn receive cu_seqlens_qkv dims is [batch_size]"));
const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims();
PADDLE_ENFORCE_EQ(cu_seqlens_qkv_dims.size(),
1,
common::errors::InvalidArgument(
"paged_attn receive cu_seqlens_qkv dims is [batch_size]"));
int batch_size = block_table_dims[0];
int num_tokens = qkv_dims[0];
int num_total_heads = num_heads + 2 * num_kv_heads;
int qkv_stride = qkv.strides()[0];
int num_blocks = kv_cache_dims[0];
int batch_size = block_table_dims[0];
int num_tokens = qkv_dims[0];
int num_total_heads = num_heads + 2 * num_kv_heads;
int qkv_stride = qkv.strides()[0];
int num_blocks = kv_cache_dims[0];
PADDLE_ENFORCE_EQ(kv_cache_dims[1],
num_kv_heads,
common::errors::InvalidArgument(
"kv_cache_dims[1] must be equal to num_kv_head"));
PADDLE_ENFORCE_EQ(kv_cache_dims[2],
block_size,
common::errors::InvalidArgument(
"kv_cache_dims[2] must be equal to block_size"));
PADDLE_ENFORCE_EQ(kv_cache_dims[3],
head_dim,
common::errors::InvalidArgument(
"kv_cache_dims[3] must be equal to head_dim"));
PADDLE_ENFORCE_EQ(
cu_seqlens_qkv_dims[0],
batch_size + 1,
common::errors::InvalidArgument(
"cu_seqlens_qkv_dims[0] must be equal to batch_size + 1"));
PADDLE_ENFORCE_EQ(kv_cache_dims[1],
num_kv_heads,
common::errors::InvalidArgument(
"kv_cache_dims[1] must be equal to num_kv_head"));
PADDLE_ENFORCE_EQ(kv_cache_dims[2],
block_size,
common::errors::InvalidArgument(
"kv_cache_dims[2] must be equal to block_size"));
PADDLE_ENFORCE_EQ(kv_cache_dims[3],
head_dim,
common::errors::InvalidArgument(
"kv_cache_dims[3] must be equal to head_dim"));
PADDLE_ENFORCE_EQ(cu_seqlens_qkv_dims[0],
batch_size + 1,
common::errors::InvalidArgument(
"cu_seqlens_qkv_dims[0] must be equal to batch_size + 1"));
int block_table_stride = block_table.strides()[0];
const float* rope_sin_ptr = rope_sin ? rope_sin.get().data<float>() : nullptr;
const float* rope_cos_ptr = rope_cos ? rope_cos.get().data<float>() : nullptr;
int block_table_stride = block_table.strides()[0];
const float *rope_sin_ptr = rope_sin ? rope_sin.get().data<float>() : nullptr;
const float *rope_cos_ptr = rope_cos ? rope_cos.get().data<float>() : nullptr;
cuinferHandle_t cuinfer_handle =
iluvatar::getContextInstance()->getIxInferHandle();
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
size_t workspace_size = 0;
CUINFER_CHECK(cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(num_tokens,
num_heads,
num_kv_heads,
head_dim,
q_rope,
k_rope,
v_rope,
data_type,
data_type,
data_type,
&workspace_size));
auto* allocator = paddle::GetAllocator(qkv.place());
phi::Allocator::AllocationPtr tmp_workspace =
allocator->Allocate(workspace_size);
void* workspace_ptr = tmp_workspace->ptr();
size_t workspace_size = 0;
CUINFER_CHECK(cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(num_tokens,
num_heads,
num_kv_heads,
head_dim,
q_rope,
k_rope,
v_rope,
data_type,
data_type,
data_type,
&workspace_size));
auto* allocator = paddle::GetAllocator(qkv.place());
phi::Allocator::AllocationPtr tmp_workspace = allocator->Allocate(workspace_size);
void* workspace_ptr = tmp_workspace->ptr();
cuinferTensorDescriptor_t qkv_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cuinferTensorDescriptor_t qkv_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
qkv_desc,
data_type,
3,
std::vector<int>({num_tokens, num_total_heads, head_dim}).data(),
std::vector<int>({num_total_heads * head_dim, head_dim, 1}).data()));
cuinferTensorDescriptor_t qkv_seqlens_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
qkv_seqlens_desc,
CUINFER_DATA_INT32,
1,
std::vector<int>({batch_size + 1}).data(),
std::vector<int>({1}).data()));
cuinferTensorDescriptor_t qkv_seqlens_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc));
CUINFER_CHECK(
cuinferSetTensorNdDescriptor(qkv_seqlens_desc,
CUINFER_DATA_INT32,
1,
std::vector<int>({batch_size + 1}).data(),
std::vector<int>({1}).data()));
cuinferTensorDescriptor_t block_table_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cuinferTensorDescriptor_t block_table_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
block_table_desc,
CUINFER_DATA_INT32,
2,
std::vector<int>({batch_size, block_table_stride}).data(),
std::vector<int>({block_table_stride, 1}).data()));
cuinferTensorDescriptor_t o_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cuinferTensorDescriptor_t o_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
o_desc,
data_type,
3,
std::vector<int>({num_tokens, num_heads, head_dim}).data(),
std::vector<int>({num_heads * head_dim, head_dim, 1}).data()));
cuinferTensorDescriptor_t k_cache_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cuinferTensorDescriptor_t k_cache_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
k_cache_desc,
data_type,
4,
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
std::vector<int>({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data()));
std::vector<int>({num_kv_heads * block_size * head_dim,
block_size * head_dim,
head_dim,
1})
.data()));
cuinferTensorDescriptor_t v_cache_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cuinferTensorDescriptor_t v_cache_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
v_cache_desc,
data_type,
4,
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
std::vector<int>({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data()));
std::vector<int>({num_kv_heads * block_size * head_dim,
block_size * head_dim,
head_dim,
1})
.data()));
cuinferTensorDescriptor_t cos_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cuinferTensorDescriptor_t cos_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cos_desc,
CUINFER_DATA_FLOAT,
2,
std::vector<int>({max_seq_len, head_dim}).data(),
std::vector<int>({head_dim, 1}).data()));
cuinferTensorDescriptor_t sin_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cuinferTensorDescriptor_t sin_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
sin_desc,
CUINFER_DATA_FLOAT,
2,
std::vector<int>({max_seq_len, head_dim}).data(),
std::vector<int>({head_dim, 1}).data()));
CUINFER_CHECK(cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle,
qkv_desc,
qkv.data(),
qkv_seqlens_desc,
cu_seqlens_qkv.data<int32_t>(),
block_table_desc,
block_table.data<int32_t>(),
o_desc,
out.data(),
k_cache_desc,
k_cache.data(),
v_cache_desc,
v_cache.data(),
workspace_ptr,
workspace_size,
cos_desc,
rope_cos_ptr,
sin_desc,
rope_sin_ptr,
batch_size,
num_heads,
num_kv_heads,
head_dim,
causal,
scale,
q_rope,
k_rope,
v_rope));
CUINFER_CHECK(cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle,
qkv_desc,
qkv.data(),
qkv_seqlens_desc,
cu_seqlens_qkv.data<int32_t>(),
block_table_desc,
block_table.data<int32_t>(),
o_desc,
out.data(),
k_cache_desc,
k_cache.data(),
v_cache_desc,
v_cache.data(),
workspace_ptr,
workspace_size,
cos_desc,
rope_cos_ptr,
sin_desc,
rope_sin_ptr,
batch_size,
num_heads,
num_kv_heads,
head_dim,
causal,
scale,
q_rope,
k_rope,
v_rope));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc));
}
std::vector<paddle::Tensor> PrefillFusedPagedAttn(const paddle::Tensor& qkv,
paddle::Tensor& k_cache,
paddle::Tensor& v_cache,
const paddle::Tensor& block_table,
const paddle::Tensor& cu_seqlens_qkv,
const paddle::optional<paddle::Tensor> &rope_sin,
const paddle::optional<paddle::Tensor> &rope_cos,
int num_heads,
int head_dim,
int num_kv_heads,
int block_size,
int max_seq_len,
float scale,
bool causal,
bool q_rope,
bool k_rope,
bool v_rope) {
std::vector<paddle::Tensor> PrefillFusedPagedAttn(
const paddle::Tensor& qkv,
paddle::Tensor& k_cache,
paddle::Tensor& v_cache,
const paddle::Tensor& block_table,
const paddle::Tensor& cu_seqlens_qkv,
const paddle::optional<paddle::Tensor>& rope_sin,
const paddle::optional<paddle::Tensor>& rope_cos,
int num_heads,
int head_dim,
int num_kv_heads,
int block_size,
int max_seq_len,
float scale,
bool causal,
bool q_rope,
bool k_rope,
bool v_rope) {
const auto dtype = qkv.dtype();
auto out =
paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place());
const auto dtype = qkv.dtype();
auto out = paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place());
switch (dtype) {
case paddle::DataType::BFLOAT16:
PrefillFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(qkv,
k_cache,
v_cache,
block_table,
cu_seqlens_qkv,
rope_sin,
rope_cos,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
out);
break;
case paddle::DataType::FLOAT16:
PrefillFusedPagedAttnKernel<paddle::DataType::FLOAT16>(qkv,
k_cache,
v_cache,
block_table,
cu_seqlens_qkv,
rope_sin,
rope_cos,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
out);
break;
default:
PD_THROW("Unsupported data type for Paged attn");
}
return {out};
switch (dtype) {
case paddle::DataType::BFLOAT16:
PrefillFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(qkv,
k_cache,
v_cache,
block_table,
cu_seqlens_qkv,
rope_sin,
rope_cos,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
out);
break;
case paddle::DataType::FLOAT16:
PrefillFusedPagedAttnKernel<paddle::DataType::FLOAT16>(qkv,
k_cache,
v_cache,
block_table,
cu_seqlens_qkv,
rope_sin,
rope_cos,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
out);
break;
default:
PD_THROW("Unsupported data type for Paged attn");
}
return {out};
}
std::vector<std::vector<int64_t>> PrefillFusedPagedAttnInferShape(const std::vector<int64_t>& qkv_shape,
const std::vector<int64_t>& k_cache_shape,
const std::vector<int64_t>& v_cache_shape,
const std::vector<int64_t>& block_table_shape,
const std::vector<int64_t>& cu_seqlens_qkv_shape,
const std::vector<int64_t>& rope_sin_shape,
const std::vector<int64_t>& rope_cos_shape,
int num_heads,
int head_dim,
int num_kv_heads,
int block_size,
int max_seq_len,
float scale,
bool causal,
bool q_rope,
bool k_rope,
bool v_rope) {
return {{qkv_shape[0], num_heads * head_dim}};
std::vector<std::vector<int64_t>> PrefillFusedPagedAttnInferShape(
const std::vector<int64_t>& qkv_shape,
const std::vector<int64_t>& k_cache_shape,
const std::vector<int64_t>& v_cache_shape,
const std::vector<int64_t>& block_table_shape,
const std::vector<int64_t>& cu_seqlens_qkv_shape,
const std::vector<int64_t>& rope_sin_shape,
const std::vector<int64_t>& rope_cos_shape,
int num_heads,
int head_dim,
int num_kv_heads,
int block_size,
int max_seq_len,
float scale,
bool causal,
bool q_rope,
bool k_rope,
bool v_rope) {
return {{qkv_shape[0], num_heads * head_dim}};
}
std::vector<paddle::DataType> PrefillFusedPagedAttnInferDtype(const paddle::DataType& qkv_dtype) {
return {qkv_dtype};
std::vector<paddle::DataType> PrefillFusedPagedAttnInferDtype(
const paddle::DataType& qkv_dtype) {
return {qkv_dtype};
}
PD_BUILD_STATIC_OP(prefill_fused_paged_attn)
.Inputs({"qkv", "k_cache", "v_cache", "block_table", "cu_seqlens_qkv",
paddle::Optional("rope_sin"), paddle::Optional("rope_cos")})
.Inputs({"qkv",
"k_cache",
"v_cache",
"block_table",
"cu_seqlens_qkv",
paddle::Optional("rope_sin"),
paddle::Optional("rope_cos")})
.Outputs({"out"})
.Attrs({"num_heads:int",
"head_dim:int",
@@ -369,8 +391,8 @@ PD_BUILD_STATIC_OP(prefill_fused_paged_attn)
"block_size:int",
"max_seq_len:int",
"scale:float",
"causal:bool",
"q_rope:bool",
"causal:bool",
"q_rope:bool",
"k_rope:bool",
"v_rope:bool"})
.SetKernelFn(PD_KERNEL(PrefillFusedPagedAttn))
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iluvatar_context.h"
#include <memory>
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -33,27 +32,26 @@
#include <vector>
#define CUINFER_CHECK(func) \
do { \
cuinferStatus_t status = (func); \
if (status != CUINFER_STATUS_SUCCESS) { \
std::cerr << "Error in file " << __FILE__ << " on line " \
<< __LINE__ << ": " << cuinferGetErrorString(status) \
<< std::endl; \
throw std::runtime_error("CUINFER_CHECK ERROR"); \
} \
} while (0)
do { \
cuinferStatus_t status = (func); \
if (status != CUINFER_STATUS_SUCCESS) { \
std::cerr << "Error in file " << __FILE__ << " on line " << __LINE__ \
<< ": " << cuinferGetErrorString(status) << std::endl; \
throw std::runtime_error("CUINFER_CHECK ERROR"); \
} \
} while (0)
namespace iluvatar {
class IluvatarContext {
public:
IluvatarContext() = default;
~IluvatarContext();
public:
IluvatarContext() = default;
~IluvatarContext();
cuinferHandle_t getIxInferHandle();
cuinferHandle_t getIxInferHandle();
private:
cuinferHandle_t ixinfer_handle_{nullptr};
private:
cuinferHandle_t ixinfer_handle_{nullptr};
};
IluvatarContext* getContextInstance();
+149 -149
View File
@@ -20,157 +20,157 @@ std::vector<paddle::Tensor> GroupGemm(const paddle::Tensor& x,
const paddle::Tensor& weight_scale,
const paddle::Tensor& prefix_sum,
const int32_t group_size) {
auto dev_ctx = static_cast<const phi::CustomContext*>(
paddle::experimental::DeviceContextPool::Instance().Get(x.place()));
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
const auto& x_dims = x.dims();
const auto& w_dims = weight.dims();
const auto& ws_dims = weight_scale.dims();
const auto& prefix_sum_dims = prefix_sum.dims();
// [m, k]
PD_CHECK(x_dims.size() == 2, "x should be 2D");
// [n_experts, n, k]
PD_CHECK(w_dims.size() == 3, "weight should be 3D");
// [n_experts, n]
PD_CHECK(ws_dims.size() == 2, "weight_scale should be 2D");
// [n_experts]
PD_CHECK(prefix_sum_dims.size() == 1, "prefix_sum should be 1D");
PD_CHECK(group_size == -1);
auto m = x_dims[0];
auto k = x_dims[1];
auto n_experts = w_dims[0];
auto n = w_dims[1];
PD_CHECK(w_dims[2] == k);
PD_CHECK(ws_dims[0] == n_experts);
PD_CHECK(ws_dims[1] == n);
PD_CHECK(prefix_sum_dims[0] == n_experts);
auto dev_ctx = static_cast<const phi::CustomContext*>(
paddle::experimental::DeviceContextPool::Instance().Get(x.place()));
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
const auto& x_dims = x.dims();
const auto& w_dims = weight.dims();
const auto& ws_dims = weight_scale.dims();
const auto& prefix_sum_dims = prefix_sum.dims();
// [m, k]
PD_CHECK(x_dims.size() == 2, "x should be 2D");
// [n_experts, n, k]
PD_CHECK(w_dims.size() == 3, "weight should be 3D");
// [n_experts, n]
PD_CHECK(ws_dims.size() == 2, "weight_scale should be 2D");
// [n_experts]
PD_CHECK(prefix_sum_dims.size() == 1, "prefix_sum should be 1D");
PD_CHECK(group_size == -1);
auto m = x_dims[0];
auto k = x_dims[1];
auto n_experts = w_dims[0];
auto n = w_dims[1];
PD_CHECK(w_dims[2] == k);
PD_CHECK(ws_dims[0] == n_experts);
PD_CHECK(ws_dims[1] == n);
PD_CHECK(prefix_sum_dims[0] == n_experts);
PD_CHECK(prefix_sum.dtype() == paddle::DataType::INT64);
PD_CHECK(prefix_sum.is_cpu());
PD_CHECK(x.dtype() == paddle::DataType::BFLOAT16 ||
x.dtype() == paddle::DataType::FLOAT16);
PD_CHECK(weight.dtype() == paddle::DataType::INT8);
PD_CHECK(weight_scale.dtype() == x.dtype());
PD_CHECK(x.is_contiguous());
PD_CHECK(weight.is_contiguous());
PD_CHECK(weight_scale.is_contiguous());
PD_CHECK(prefix_sum.dtype() == paddle::DataType::INT64);
PD_CHECK(prefix_sum.is_cpu());
PD_CHECK(x.dtype() == paddle::DataType::BFLOAT16 ||
x.dtype() == paddle::DataType::FLOAT16);
PD_CHECK(weight.dtype() == paddle::DataType::INT8);
PD_CHECK(weight_scale.dtype() == x.dtype());
PD_CHECK(x.is_contiguous());
PD_CHECK(weight.is_contiguous());
PD_CHECK(weight_scale.is_contiguous());
const int64_t* prefix_sum_ptr = prefix_sum.data<int64_t>();
auto output = GetEmptyTensor({m, n}, x.dtype(), x.place());
int16_t* out_data = static_cast<int16_t*>(output.data());
const int16_t* x_data = static_cast<const int16_t*>(x.data());
const int8_t* weight_data = weight.data<int8_t>();
const int16_t* weight_scale_data =
static_cast<const int16_t*>(weight_scale.data());
const int64_t* prefix_sum_ptr = prefix_sum.data<int64_t>();
auto output = GetEmptyTensor({m, n}, x.dtype(), x.place());
int16_t* out_data = static_cast<int16_t*>(output.data());
const int16_t* x_data = static_cast<const int16_t*>(x.data());
const int8_t* weight_data = weight.data<int8_t>();
const int16_t* weight_scale_data =
static_cast<const int16_t*>(weight_scale.data());
cuinferHandle_t handle = iluvatar::getContextInstance()->getIxInferHandle();
cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST;
cuinferOperation_t transa = CUINFER_OP_T;
cuinferOperation_t transb = CUINFER_OP_N;
cudaDataType_t a_type = CUDA_R_8I;
cudaDataType_t b_type;
cudaDataType_t c_type;
if (x.dtype() == paddle::DataType::FLOAT16) {
b_type = CUDA_R_16F;
} else if (x.dtype() == paddle::DataType::BFLOAT16) {
b_type = CUDA_R_16BF;
} else {
PADDLE_THROW(common::errors::Unimplemented("Unsupported input dtype."));
cuinferHandle_t handle = iluvatar::getContextInstance()->getIxInferHandle();
cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST;
cuinferOperation_t transa = CUINFER_OP_T;
cuinferOperation_t transb = CUINFER_OP_N;
cudaDataType_t a_type = CUDA_R_8I;
cudaDataType_t b_type;
cudaDataType_t c_type;
if (x.dtype() == paddle::DataType::FLOAT16) {
b_type = CUDA_R_16F;
} else if (x.dtype() == paddle::DataType::BFLOAT16) {
b_type = CUDA_R_16BF;
} else {
PADDLE_THROW(common::errors::Unimplemented("Unsupported input dtype."));
}
c_type = b_type;
cudaDataType_t Atype = a_type;
cudaDataType_t Btype = b_type;
cudaDataType_t Ctype = c_type;
cudaDataType_t computeType = CUDA_R_32F;
cudaDataType_t scaleType = CUDA_R_32F;
cuinferGEMMCustomOption_t customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE;
cuinferQuantGEMMHostParam cust_host_param;
cust_host_param.size = sizeof(cuinferQuantGEMMHostParam);
cust_host_param.persistent = 0;
cust_host_param.groupSize = group_size;
cuinferQuantGEMMDeviceParam cust_device_param;
cust_device_param.bias = nullptr;
cust_device_param.workspace = nullptr;
int lda = k;
int ldb = k;
int ldc = n;
float beta = 0.f;
float alpha = 1.f;
int batch_count = 1;
size_t pre = 0;
auto* allocator = paddle::GetAllocator(x.place());
phi::Allocator::AllocationPtr tmp_workspace;
for (int i = 0; i < n_experts; i++) {
size_t expert_i_end = prefix_sum_ptr[i];
size_t cur_len = expert_i_end - pre;
pre = expert_i_end;
if (cur_len != 0) {
cust_device_param.scale = weight_scale_data;
if (k % 64 != 0) {
size_t workspace_size;
CUINFER_CHECK(cuinferGetCustomGemmWorkspace(transa,
transb,
n,
cur_len,
k,
Atype,
lda,
lda,
Btype,
ldb,
ldb,
Ctype,
ldc,
ldc,
batch_count,
computeType,
scaleType,
&workspace_size));
tmp_workspace = allocator->Allocate(workspace_size);
cust_device_param.workspace = tmp_workspace->ptr();
} else {
cust_device_param.workspace = nullptr;
}
CUINFER_CHECK(cuinferCustomGemm(handle,
stream,
cuinfer_ptr_mode,
transa,
transb,
n,
cur_len,
k,
&alpha,
weight_data,
Atype,
lda,
lda,
x_data,
Btype,
ldb,
ldb,
&beta,
out_data,
Ctype,
ldc,
ldc,
batch_count,
computeType,
scaleType,
&cust_host_param,
&cust_device_param,
customOption));
}
c_type = b_type;
cudaDataType_t Atype = a_type;
cudaDataType_t Btype = b_type;
cudaDataType_t Ctype = c_type;
cudaDataType_t computeType = CUDA_R_32F;
cudaDataType_t scaleType = CUDA_R_32F;
cuinferGEMMCustomOption_t customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE;
cuinferQuantGEMMHostParam cust_host_param;
cust_host_param.size = sizeof(cuinferQuantGEMMHostParam);
cust_host_param.persistent = 0;
cust_host_param.groupSize = group_size;
cuinferQuantGEMMDeviceParam cust_device_param;
cust_device_param.bias = nullptr;
cust_device_param.workspace = nullptr;
int lda = k;
int ldb = k;
int ldc = n;
float beta = 0.f;
float alpha = 1.f;
int batch_count = 1;
size_t pre = 0;
auto* allocator = paddle::GetAllocator(x.place());
phi::Allocator::AllocationPtr tmp_workspace;
for (int i = 0; i < n_experts; i++) {
size_t expert_i_end = prefix_sum_ptr[i];
size_t cur_len = expert_i_end - pre;
pre = expert_i_end;
if (cur_len != 0) {
cust_device_param.scale = weight_scale_data;
if (k % 64 != 0) {
size_t workspace_size;
CUINFER_CHECK(cuinferGetCustomGemmWorkspace(transa,
transb,
n,
cur_len,
k,
Atype,
lda,
lda,
Btype,
ldb,
ldb,
Ctype,
ldc,
ldc,
batch_count,
computeType,
scaleType,
&workspace_size));
tmp_workspace = allocator->Allocate(workspace_size);
cust_device_param.workspace = tmp_workspace->ptr();
} else {
cust_device_param.workspace = nullptr;
}
CUINFER_CHECK(cuinferCustomGemm(handle,
stream,
cuinfer_ptr_mode,
transa,
transb,
n,
cur_len,
k,
&alpha,
weight_data,
Atype,
lda,
lda,
x_data,
Btype,
ldb,
ldb,
&beta,
out_data,
Ctype,
ldc,
ldc,
batch_count,
computeType,
scaleType,
&cust_host_param,
&cust_device_param,
customOption));
}
x_data += cur_len * k;
weight_data += k * n;
weight_scale_data += n;
out_data += cur_len * n;
}
return {output};
x_data += cur_len * k;
weight_data += k * n;
weight_scale_data += n;
out_data += cur_len * n;
}
return {output};
}
std::vector<std::vector<int64_t>> GroupGemmInferShape(
@@ -178,7 +178,7 @@ std::vector<std::vector<int64_t>> GroupGemmInferShape(
const std::vector<int64_t>& weight_shape,
const std::vector<int64_t>& weight_scale_shape,
const std::vector<int64_t>& prefix_sum_shape) {
return {{x_shape[0], weight_shape[1]}};
return {{x_shape[0], weight_shape[1]}};
}
std::vector<paddle::DataType> GroupGemmInferDtype(
const paddle::DataType& input_dtype,
@@ -186,7 +186,7 @@ std::vector<paddle::DataType> GroupGemmInferDtype(
const paddle::DataType& weight_scale_dtype,
const paddle::DataType& prefix_sum_dtype,
const int moe_topk) {
return {input_dtype};
return {input_dtype};
}
PD_BUILD_STATIC_OP(w8a16_group_gemm)
+38 -35
View File
@@ -12,12 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fused_moe_op.h"
#include "helper.h"
#include "mc_fused_moe_helper.h"
#include "fused_moe_op.h"
__global__ void compute_total_rows_before_expert_kernel(
int* sorted_experts,
@@ -43,7 +42,10 @@ void compute_total_rows_before_expert(int* sorted_indices,
sorted_indices, total_indices, num_experts, total_rows_before_expert);
}
template <paddle::DataType T, typename ElementA, typename ElementB, typename ElementC>
template <paddle::DataType T,
typename ElementA,
typename ElementB,
typename ElementC>
void FusedMoeKernel(const paddle::Tensor& input,
const paddle::Tensor& gate_weight,
const paddle::Tensor& ffn1_weight,
@@ -63,27 +65,26 @@ void FusedMoeKernel(const paddle::Tensor& input,
auto* output_data = output->data<data_t>();
auto moe_compute = McMoeHelper<data_t, ElementA, ElementB, ElementC>(quant_method);
auto moe_compute =
McMoeHelper<data_t, ElementA, ElementB, ElementC>(quant_method);
moe_compute.computeFFN(
&input,
&gate_weight,
&ffn1_weight,
ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
&ffn2_weight,
ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
nullptr,
moe_topk,
group_moe,
norm_topk_prob,
1.0, // ComputeFFN
"ffn",
output);
moe_compute.computeFFN(&input,
&gate_weight,
&ffn1_weight,
ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
&ffn2_weight,
ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
nullptr,
moe_topk,
group_moe,
norm_topk_prob,
1.0, // ComputeFFN
"ffn",
output);
}
std::vector<paddle::Tensor> FusedExpertMoe(
const paddle::Tensor& input,
const paddle::Tensor& gate_weight,
@@ -102,19 +103,22 @@ std::vector<paddle::Tensor> FusedExpertMoe(
switch (input_type) {
case paddle::DataType::BFLOAT16:
FusedMoeKernel<paddle::DataType::BFLOAT16, maca_bfloat16, int8_t, maca_bfloat16>(input,
gate_weight,
ffn1_weight,
ffn1_scale,
ffn1_bias,
ffn2_weight,
ffn2_scale,
ffn2_bias,
quant_method,
moe_topk,
group_moe,
norm_topk_prob,
&output);
FusedMoeKernel<paddle::DataType::BFLOAT16,
maca_bfloat16,
int8_t,
maca_bfloat16>(input,
gate_weight,
ffn1_weight,
ffn1_scale,
ffn1_bias,
ffn2_weight,
ffn2_scale,
ffn2_bias,
quant_method,
moe_topk,
group_moe,
norm_topk_prob,
&output);
break;
// case paddle::DataType::FLOAT16:
// FusedMoeKernel<paddle::DataType::FLOAT16>(input,
@@ -161,7 +165,6 @@ std::vector<paddle::DataType> FusedExpertMoeInferDtype(
return {input_dtype};
}
PD_BUILD_OP(fused_expert_moe)
.Inputs({"input",
"gate_weight",
+1 -1
View File
@@ -16,8 +16,8 @@
*/
#pragma once
#include <string>
#include <sstream>
#include <string>
#include "cub/cub.cuh"
static const float HALF_FLT_MAX = 65504.F;
+7 -14
View File
@@ -19,9 +19,9 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include "fused_moe_imp_op.h"
#include "fused_moe_helper.h"
#include "mctlass/numeric_conversion.h" // BUILD_MARK
#include "fused_moe_imp_op.h"
#include "mctlass/numeric_conversion.h" // BUILD_MARK
// Ignore mctlass warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
@@ -35,8 +35,8 @@
#define WARP_SIZE 32
struct GpuLaunchConfig {
dim3 block_per_grid;
dim3 thread_per_block;
dim3 block_per_grid;
dim3 thread_per_block;
};
inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
@@ -82,7 +82,6 @@ __launch_bounds__(TPB) __global__
cub::Sum sum;
float threadData(-FLT_MAX);
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
@@ -603,7 +602,7 @@ void topk_gating_softmax_kernelLauncher(const T* input,
}
static constexpr int WARPS_PER_TB = 4;
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
input, output, indices, source_row, num_rows, num_experts, k, stream); \
@@ -646,14 +645,8 @@ void topk_gating_softmax_kernelLauncher(const T* input,
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, softmax, num_experts, num_rows);
moe_top_k<T, TPB>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
output,
indices,
source_row,
num_experts,
k,
num_rows);
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
softmax, output, indices, source_row, num_experts, k, num_rows);
}
}
}
+376 -320
View File
@@ -1,52 +1,71 @@
#include "fused_moe_helper.h"
#include "mctlass/numeric_conversion.h"
#include "mctlassEx/mctlassEx.h"
#include "fused_moe_helper.h"
template <typename ElementA, typename ElementB, typename ElementC>
void mc_grouped_gemm_basic_kernel(
const ElementA* ptrA,
mctlassExOrder_t majorA,
const ElementB* ptrB,
mctlassExOrder_t majorB,
const ElementA* ptrScale,
const ElementA* ptrBias,
ElementC* ptrC,
mctlassExOrder_t majorC,
const int *ptrSegInd,
int numExperts,
int m, // expanded_active_expert_rows
int n, // inter_dim
int k, // hidden_size
mcStream_t stream) {
void mc_grouped_gemm_basic_kernel(const ElementA *ptrA,
mctlassExOrder_t majorA,
const ElementB *ptrB,
mctlassExOrder_t majorB,
const ElementA *ptrScale,
const ElementA *ptrBias,
ElementC *ptrC,
mctlassExOrder_t majorC,
const int *ptrSegInd,
int numExperts,
int m, // expanded_active_expert_rows
int n, // inter_dim
int k, // hidden_size
mcStream_t stream) {
mctlassExHandle_t handle;
mctlassExHandleCreate(&handle);
int* ptrMNumTilesInd;
mcMallocAsync((void**)&ptrMNumTilesInd, sizeof(int) * numExperts, stream);
int *ptrMNumTilesInd;
mcMallocAsync((void **)&ptrMNumTilesInd, sizeof(int) * numExperts, stream);
mctlassExMatrixLayout_t matLayoutA;
mctlassExMatrixLayout_t matLayoutB;
mctlassExMatrixLayout_t matLayoutC;
// mat A: (m, k)
mctlassExMatrixLayoutCreate(&matLayoutA, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, k, k);
mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
&majorA, sizeof(mctlassExOrder_t));
mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
&numExperts, sizeof(int));
mctlassExMatrixLayoutCreate(
&matLayoutA, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, k, k);
mctlassExMatrixLayoutSetAttribute(
matLayoutA,
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
&majorA,
sizeof(mctlassExOrder_t));
mctlassExMatrixLayoutSetAttribute(
matLayoutA,
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
&numExperts,
sizeof(int));
// mat B: (num_experts, n, k)
mctlassExMatrixLayoutCreate(&matLayoutB, mctlassExDataType::MCTLASS_EX_DATATYPE_INT8, k, n, k);
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
&majorB, sizeof(mctlassExOrder_t));
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
&numExperts, sizeof(int));
mctlassExMatrixLayoutCreate(
&matLayoutB, mctlassExDataType::MCTLASS_EX_DATATYPE_INT8, k, n, k);
mctlassExMatrixLayoutSetAttribute(
matLayoutB,
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
&majorB,
sizeof(mctlassExOrder_t));
mctlassExMatrixLayoutSetAttribute(
matLayoutB,
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
&numExperts,
sizeof(int));
// mat C: (m, n)
mctlassExMatrixLayoutCreate(&matLayoutC, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, n, n);
mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
&majorC, sizeof(mctlassExOrder_t));
mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
&numExperts, sizeof(int));
mctlassExMatrixLayoutCreate(
&matLayoutC, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, n, n);
mctlassExMatrixLayoutSetAttribute(
matLayoutC,
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
&majorC,
sizeof(mctlassExOrder_t));
mctlassExMatrixLayoutSetAttribute(
matLayoutC,
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
&numExperts,
sizeof(int));
// bias: (num_experts, n)
// scale: (num, n)
@@ -55,44 +74,81 @@ void mc_grouped_gemm_basic_kernel(
mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_DATATYPE_BF16;
mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_DATATYPE_INT8;
mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_DATATYPE_FP32;
mctlassExEpilogueType epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT;
mctlassExEpilogueType epilogue_type =
mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT;
if (ptrBias) {
epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_BIAS;
}
// set scale
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER,
&ptrScale, sizeof(ptrScale));
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE,
&input_type, sizeof(mctlassExDataType));
mctlassExDescSetAttribute(
mctlass_desc,
mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER,
&ptrScale,
sizeof(ptrScale));
mctlassExDescSetAttribute(
mctlass_desc,
mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE,
&input_type,
sizeof(mctlassExDataType));
// set bias
if (ptrBias) {
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER,
&ptrBias, sizeof(ptrBias));
mctlassExDescSetAttribute(
mctlass_desc,
mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER,
&ptrBias,
sizeof(ptrBias));
}
// set coumpute type
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE,
&compute_type, sizeof(mctlassExDataType));
mctlassExDescSetAttribute(
mctlass_desc,
mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE,
&compute_type,
sizeof(mctlassExDataType));
// set epilogue type
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE,
&epilogue_type, sizeof(mctlassExEpilogueType));
mctlassExDescSetAttribute(
mctlass_desc,
mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE,
&epilogue_type,
sizeof(mctlassExEpilogueType));
const mctlassExContiguousGroupedGemmAlgo_t algo = mctlassExContiguousGroupedGemmAlgo_t::MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;
const mctlassExContiguousGroupedGemmAlgo_t algo =
mctlassExContiguousGroupedGemmAlgo_t::
MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;
mctlassExContiguousGroupedDesc_t contiguous_group_desc;
mctlassExContiguousGroupedDescCreate(&contiguous_group_desc,
ptrSegInd,
nullptr,
ptrMNumTilesInd,
1);
mctlassExContiguousGroupedDescCreate(
&contiguous_group_desc, ptrSegInd, nullptr, ptrMNumTilesInd, 1);
int blocksizeM;
mctlassExContiguousGroupedGemmGetBlocksizeM(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, &blocksizeM);
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, contiguous_group_desc, numExperts, blocksizeM, stream);
mctlassExContiguousGroupedGemmGetBlocksizeM(handle,
mctlass_desc,
matLayoutA,
matLayoutB,
matLayoutC,
&algo,
&blocksizeM);
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle,
mctlass_desc,
matLayoutA,
matLayoutB,
matLayoutC,
&algo,
contiguous_group_desc,
numExperts,
blocksizeM,
stream);
mctlassExContiguousGroupedGemmBasic(handle, mctlass_desc,
ptrA, matLayoutA,
ptrB, matLayoutB,
ptrC, matLayoutC,
mctlassExContiguousGroupedGemmBasic(handle,
mctlass_desc,
ptrA,
matLayoutA,
ptrB,
matLayoutB,
ptrC,
matLayoutC,
contiguous_group_desc,
&algo, nullptr, 0, stream);
&algo,
nullptr,
0,
stream);
mctlassExHandleDestroy(handle);
mctlassExMatrixLayoutDestroy(matLayoutA);
@@ -103,312 +159,312 @@ void mc_grouped_gemm_basic_kernel(
mcFreeAsync(ptrMNumTilesInd, stream);
}
template<typename T, typename ElementA, typename ElementB, typename ElementC>
template <typename T, typename ElementA, typename ElementB, typename ElementC>
class McMoeHelper {
public:
McMoeHelper(const std::string gemm_method): gemm_method_(gemm_method) {}
public:
McMoeHelper(const std::string gemm_method) : gemm_method_(gemm_method) {}
// -------- getWorkspaceSize -------- //
template <typename KeyT>
size_t getWorkspaceSize(const int64_t num_rows,
const int64_t hidden_size,
const int64_t inter_size,
const int64_t num_experts,
const int64_t k) {
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
const size_t padded_experts = AlignTo16(num_experts);
const size_t num_moe_inputs = AlignTo16(k * num_rows);
// softmax output, permuted_rows and permuted_experts have moved to outside
// of moe kernel, allocate them in Encoder or Decoder before invoking
// FfnLayer forward.
size_t total_ws_bytes =
5 * num_moe_inputs *
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
total_ws_bytes +=
padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_
// -------- getWorkspaceSize -------- //
template <typename KeyT>
size_t getWorkspaceSize(const int64_t num_rows,
const int64_t hidden_size,
const int64_t inter_size,
const int64_t num_experts,
const int64_t k) {
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
const size_t padded_experts = AlignTo16(num_experts);
const size_t num_moe_inputs = AlignTo16(k * num_rows);
// softmax output, permuted_rows and permuted_experts have moved to outside
// of moe kernel, allocate them in Encoder or Decoder before invoking
// FfnLayer forward.
size_t total_ws_bytes =
5 * num_moe_inputs *
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
total_ws_bytes +=
padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
const size_t sorter_ws_size_bytes =
AlignTo16(sorter_.getWorkspaceSize(num_rows));
sorter_.update_num_experts(num_experts);
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
const size_t sorter_ws_size_bytes =
AlignTo16(sorter_.getWorkspaceSize(num_rows));
sorter_.update_num_experts(num_experts);
int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
int64_t remaining_bytes =
AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
bytes_for_intermediate_and_sorting += remaining_bytes;
}
total_ws_bytes +=
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
// sorting workspace
int64_t num_softmax_outs = 0;
const bool is_pow_2 =
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256) {
num_softmax_outs = AlignTo16(num_rows * num_experts);
}
total_ws_bytes += num_softmax_outs * sizeof(float);
return total_ws_bytes;
int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
int64_t remaining_bytes =
AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
bytes_for_intermediate_and_sorting += remaining_bytes;
}
void computeFFN(const paddle::Tensor *input,
const paddle::Tensor *gate_weight,
const paddle::Tensor *ffn1_weight,
const paddle::Tensor *ffn1_scale,
const paddle::Tensor *ffn1_bias,
const paddle::Tensor *ffn2_weight,
const paddle::Tensor *ffn2_scale,
const paddle::Tensor *ffn2_bias,
const paddle::Tensor *moe_token_type_ids,
const int moe_topk,
const bool group_moe,
const bool norm_topk_prob,
const float routed_scaling_factor,
const std::string moe_type,
paddle::Tensor *output) {
auto *input_activations = input->data<T>();
auto *gating_weights = gate_weight->data<float>();
const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
total_ws_bytes +=
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
// sorting workspace
auto *output_ = output->data<T>();
auto stream = input->stream();
auto place = input->place();
auto input_type = input->dtype();
int64_t num_softmax_outs = 0;
const bool is_pow_2 =
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256) {
num_softmax_outs = AlignTo16(num_rows * num_experts);
}
auto input_dims = input->dims();
auto ffn1_dims = ffn1_weight->dims();
int64_t token_num = 0;
if (input_dims.size() == 3) {
token_num = input_dims[0] * input_dims[1];
} else {
token_num = input_dims[0];
}
const int64_t num_rows = token_num;
total_ws_bytes += num_softmax_outs * sizeof(float);
const int64_t hidden_size = ffn1_dims[2];
int64_t inter_dim = 0;
if (moe_type == "qkv") {
inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
} else {
inter_dim = ffn1_dims[1];
}
return total_ws_bytes;
}
// if (gemm_method == "weight_only_int4") {
// inter_dim = inter_dim * 2;
// }
void computeFFN(const paddle::Tensor *input,
const paddle::Tensor *gate_weight,
const paddle::Tensor *ffn1_weight,
const paddle::Tensor *ffn1_scale,
const paddle::Tensor *ffn1_bias,
const paddle::Tensor *ffn2_weight,
const paddle::Tensor *ffn2_scale,
const paddle::Tensor *ffn2_bias,
const paddle::Tensor *moe_token_type_ids,
const int moe_topk,
const bool group_moe,
const bool norm_topk_prob,
const float routed_scaling_factor,
const std::string moe_type,
paddle::Tensor *output) {
auto *input_activations = input->data<T>();
auto *gating_weights = gate_weight->data<float>();
const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
const int64_t inter_size = inter_dim;
const int64_t num_experts = ffn1_dims[0];
const int64_t k = moe_topk;
auto *output_ = output->data<T>();
auto stream = input->stream();
auto place = input->place();
auto input_type = input->dtype();
auto input_dims = input->dims();
auto ffn1_dims = ffn1_weight->dims();
int64_t token_num = 0;
if (input_dims.size() == 3) {
token_num = input_dims[0] * input_dims[1];
} else {
token_num = input_dims[0];
}
const int64_t num_rows = token_num;
int64_t bytes =
getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
const int64_t hidden_size = ffn1_dims[2];
int64_t inter_dim = 0;
if (moe_type == "qkv") {
inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
} else {
inter_dim = ffn1_dims[1];
}
// Pointers
int *expert_for_source_row;
int *source_rows_;
int *permuted_rows_;
int *permuted_experts_;
int *expanded_source_row_to_expanded_dest_row;
// if (gemm_method == "weight_only_int4") {
// inter_dim = inter_dim * 2;
// }
T *permuted_data_;
int32_t *total_rows_before_expert_;
T *fc1_result_;
float *softmax_out_;
const int64_t inter_size = inter_dim;
const int64_t num_experts = ffn1_dims[0];
const int64_t k = moe_topk;
paddle::Tensor ws_ptr_tensor =
GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
int64_t bytes =
getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
const int64_t padded_experts = AlignTo16(num_experts);
const int64_t num_moe_inputs = AlignTo16(k * num_rows);
// Pointers
int *expert_for_source_row;
int *source_rows_;
int *permuted_rows_;
int *permuted_experts_;
int *expanded_source_row_to_expanded_dest_row;
expert_for_source_row = reinterpret_cast<int *>(ws_ptr);
source_rows_ = expert_for_source_row + num_moe_inputs;
permuted_rows_ = source_rows_ + num_moe_inputs;
permuted_experts_ = permuted_rows_ + num_moe_inputs;
expanded_source_row_to_expanded_dest_row =
permuted_experts_ + num_moe_inputs;
permuted_data_ = reinterpret_cast<T *>(
expanded_source_row_to_expanded_dest_row + num_moe_inputs);
total_rows_before_expert_ =
reinterpret_cast<int32_t *>(permuted_data_ + buf_size);
fc1_result_ =
reinterpret_cast<T *>(total_rows_before_expert_ + padded_experts);
T *permuted_data_;
int32_t *total_rows_before_expert_;
T *fc1_result_;
float *softmax_out_;
const bool is_pow_2 =
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256) {
softmax_out_ = reinterpret_cast<float *>(fc1_result_ + interbuf_size);
} else {
softmax_out_ = nullptr;
}
paddle::Tensor ws_ptr_tensor =
GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
paddle::Tensor expert_scales_float_tensor =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
float *expert_scales_float = expert_scales_float_tensor.data<float>();
const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
const int64_t padded_experts = AlignTo16(num_experts);
const int64_t num_moe_inputs = AlignTo16(k * num_rows);
float *softmax_max_prob = nullptr;
if (group_moe) {
paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
{num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
// (TODO: check fill success ?)
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
softmax_max_prob = softmax_max_prob_tensor.data<float>();
}
expert_for_source_row = reinterpret_cast<int *>(ws_ptr);
source_rows_ = expert_for_source_row + num_moe_inputs;
permuted_rows_ = source_rows_ + num_moe_inputs;
permuted_experts_ = permuted_rows_ + num_moe_inputs;
expanded_source_row_to_expanded_dest_row =
permuted_experts_ + num_moe_inputs;
permuted_data_ = reinterpret_cast<T *>(
expanded_source_row_to_expanded_dest_row + num_moe_inputs);
total_rows_before_expert_ =
reinterpret_cast<int32_t *>(permuted_data_ + buf_size);
fc1_result_ =
reinterpret_cast<T *>(total_rows_before_expert_ + padded_experts);
paddle::Tensor fc1_out_tensor =
GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
T *fc1_out = fc1_out_tensor.data<T>();
const bool is_pow_2 =
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256) {
softmax_out_ = reinterpret_cast<float *>(fc1_result_ + interbuf_size);
} else {
softmax_out_ = nullptr;
}
auto input_cast_tensor =
paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
auto gate_tensor =
paddle::experimental::matmul(input_cast_tensor, *gate_weight);
float *gating_output = gate_tensor.data<float>();
paddle::Tensor expert_scales_float_tensor =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
float *expert_scales_float = expert_scales_float_tensor.data<float>();
if (moe_token_type_ids) {
auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
moe_token_type_ids_kernelLauncher<float>(gating_output,
moe_token_type_ids_out,
num_rows,
num_experts,
k,
stream);
}
float *softmax_max_prob = nullptr;
if (group_moe) {
paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
{num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
// (TODO: check fill success ?)
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
softmax_max_prob = softmax_max_prob_tensor.data<float>();
}
topk_gating_softmax_kernelLauncher<float>(gating_output,
expert_scales_float,
softmax_out_,
expert_for_source_row,
source_rows_,
softmax_max_prob,
num_rows,
num_experts,
k,
group_moe,
stream);
paddle::Tensor fc1_out_tensor =
GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
T *fc1_out = fc1_out_tensor.data<T>();
const int64_t sorter_ws_size_bytes =
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
auto input_cast_tensor =
paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
auto gate_tensor =
paddle::experimental::matmul(input_cast_tensor, *gate_weight);
float *gating_output = gate_tensor.data<float>();
sorter_.run(fc1_result_,
sorter_ws_size_bytes,
expert_for_source_row,
permuted_experts_,
source_rows_,
permuted_rows_,
k * num_rows,
false,
stream);
if (moe_token_type_ids) {
auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
moe_token_type_ids_kernelLauncher<float>(gating_output,
moe_token_type_ids_out,
num_rows,
num_experts,
k,
stream);
}
initialize_moe_routing_kernelLauncher(
input_activations,
permuted_data_,
permuted_rows_,
expanded_source_row_to_expanded_dest_row,
num_rows,
num_rows,
hidden_size,
k,
stream);
topk_gating_softmax_kernelLauncher<float>(gating_output,
expert_scales_float,
softmax_out_,
expert_for_source_row,
source_rows_,
softmax_max_prob,
num_rows,
num_experts,
k,
group_moe,
stream);
const int64_t expanded_active_expert_rows = k * num_rows;
const int64_t sorter_ws_size_bytes =
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
compute_total_rows_before_expert(permuted_experts_,
expanded_active_expert_rows,
num_experts,
total_rows_before_expert_,
stream);
sorter_.run(fc1_result_,
sorter_ws_size_bytes,
expert_for_source_row,
permuted_experts_,
source_rows_,
permuted_rows_,
k * num_rows,
false,
stream);
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
initialize_moe_routing_kernelLauncher(
input_activations,
permuted_data_,
permuted_rows_,
expanded_source_row_to_expanded_dest_row,
num_rows,
num_rows,
hidden_size,
k,
stream);
const int64_t expanded_active_expert_rows = k * num_rows;
compute_total_rows_before_expert(permuted_experts_,
expanded_active_expert_rows,
num_experts,
total_rows_before_expert_,
stream);
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
mctlassExOrder_t column_major =
mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
reinterpret_cast<const ElementA *>(permuted_data_),
row_major,
reinterpret_cast<const ElementB *>(ffn1_weight->data<ElementB>()),
column_major,
reinterpret_cast<const ElementA *>(ffn1_scale->data<T>()),
reinterpret_cast<const ElementA *>(fc1_expert_biases),
reinterpret_cast<ElementC *>(fc1_out),
row_major,
total_rows_before_expert_,
num_experts,
expanded_active_expert_rows,
inter_size,
hidden_size,
stream);
if (moe_type == "ffn") {
auto act_out_tensor =
paddle::experimental::swiglu(fc1_out_tensor, nullptr);
auto act_out = act_out_tensor.data<T>();
paddle::Tensor fc2_output_tensor =
GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
T *fc2_result = fc2_output_tensor.data<T>();
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
reinterpret_cast<const ElementA *>(permuted_data_),
reinterpret_cast<const ElementA *>(act_out),
row_major,
reinterpret_cast<const ElementB *>(ffn1_weight->data<ElementB>()),
reinterpret_cast<const ElementB *>(ffn2_weight->data<ElementB>()),
column_major,
reinterpret_cast<const ElementA *>(ffn1_scale->data<T>()),
reinterpret_cast<const ElementA *>(fc1_expert_biases),
reinterpret_cast<ElementC *>(fc1_out),
reinterpret_cast<const ElementA *>(ffn2_scale->data<T>()),
nullptr,
reinterpret_cast<ElementC *>(fc2_result),
row_major,
total_rows_before_expert_,
num_experts,
expanded_active_expert_rows,
inter_size,
hidden_size,
inter_size / 2,
stream);
if (moe_type == "ffn") {
auto act_out_tensor =
paddle::experimental::swiglu(fc1_out_tensor, nullptr);
auto act_out = act_out_tensor.data<T>();
paddle::Tensor fc2_output_tensor =
GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
T *fc2_result = fc2_output_tensor.data<T>();
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
reinterpret_cast<const ElementA *>(act_out),
row_major,
reinterpret_cast<const ElementB *>(ffn2_weight->data<ElementB>()),
column_major,
reinterpret_cast<const ElementA *>(ffn2_scale->data<T>()),
nullptr,
reinterpret_cast<ElementC *>(fc2_result),
row_major,
total_rows_before_expert_,
num_experts,
expanded_active_expert_rows,
hidden_size,
inter_size / 2,
stream);
finalize_moe_routing_kernelLauncher(
fc2_result,
output_,
fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float),
expanded_source_row_to_expanded_dest_row,
expert_for_source_row,
num_rows,
hidden_size,
k,
static_cast<int>(1),
norm_topk_prob,
routed_scaling_factor,
stream);
} else {
finalize_moe_routing_kernelLauncher(
// fc2_result,
fc1_out,
output_,
fc1_expert_biases, // fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float),
expanded_source_row_to_expanded_dest_row,
expert_for_source_row,
num_rows,
inter_size,
k,
static_cast<int>(0),
norm_topk_prob,
routed_scaling_factor,
stream);
}
finalize_moe_routing_kernelLauncher(
fc2_result,
output_,
fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float),
expanded_source_row_to_expanded_dest_row,
expert_for_source_row,
num_rows,
hidden_size,
k,
static_cast<int>(1),
norm_topk_prob,
routed_scaling_factor,
stream);
} else {
finalize_moe_routing_kernelLauncher(
// fc2_result,
fc1_out,
output_,
fc1_expert_biases, // fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float),
expanded_source_row_to_expanded_dest_row,
expert_for_source_row,
num_rows,
inter_size,
k,
static_cast<int>(0),
norm_topk_prob,
routed_scaling_factor,
stream);
}
}
private:
private:
std::string gemm_method_;
CubKeyValueSorter sorter_;
};
+5 -14
View File
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wunused-function"
@@ -24,7 +23,6 @@
#include "helper.h"
template <paddle::DataType T>
void MoeDispatchKernel(const paddle::Tensor& input,
const paddle::Tensor& gating_output,
@@ -128,7 +126,6 @@ void MoeDispatchKernel(const paddle::Tensor& input,
false,
stream);
initialize_moe_routing_kernelLauncher(
input.data<data_t>(),
permute_input->data<data_t>(),
@@ -140,16 +137,13 @@ void MoeDispatchKernel(const paddle::Tensor& input,
moe_topk,
stream);
compute_total_rows_before_expert(
permuted_experts_,
moe_topk * num_rows,
expert_num,
tokens_expert_prefix_sum->data<int32_t>(),
stream);
compute_total_rows_before_expert(permuted_experts_,
moe_topk * num_rows,
expert_num,
tokens_expert_prefix_sum->data<int32_t>(),
stream);
}
std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::Tensor& input,
const paddle::Tensor& gating_output,
@@ -184,7 +178,6 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
auto permute_indices_per_token =
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
@@ -226,7 +219,6 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
top_k_indices};
}
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& gating_output_shape,
@@ -260,7 +252,6 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
paddle::DataType::INT32};
}
PD_BUILD_OP(moe_expert_dispatch)
.Inputs({"input", "gating_output"})
.Outputs({"permute_input",
+68 -58
View File
@@ -14,19 +14,22 @@
// BUILD_MARK
#pragma once
#include "mc_fused_moe_helper.h"
#include "helper.h"
#include "mc_fused_moe_helper.h"
template <paddle::DataType T, typename ElementA, typename ElementB, typename ElementC>
template <paddle::DataType T,
typename ElementA,
typename ElementB,
typename ElementC>
void McMoeFFNKernel(const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const std::string& quant_method,
paddle::Tensor ffn_out) {
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const std::string& quant_method,
paddle::Tensor ffn_out) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -37,61 +40,65 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input,
auto input_type = permute_input.dtype();
auto stream = permute_input.stream();
const int expanded_active_expert_rows = permute_input.dims()[0]; // permute_input.dims(): m, k
const int num_experts = ffn1_weight.dims()[0]; // batchsize
const int hidden_size = ffn1_weight.dims()[2]; // n
int inter_dim = ffn1_weight.dims()[1]; // k
const int expanded_active_expert_rows =
permute_input.dims()[0]; // permute_input.dims(): m, k
const int num_experts = ffn1_weight.dims()[0]; // batchsize
const int hidden_size = ffn1_weight.dims()[2]; // n
int inter_dim = ffn1_weight.dims()[1]; // k
const int64_t inter_size = inter_dim; // since weight_only_int_8
const int64_t inter_size = inter_dim; // since weight_only_int_8
paddle::Tensor fc1_out_tensor = GetEmptyTensor(
{expanded_active_expert_rows, inter_size}, input_type, place);
auto fc1_out_ptr = fc1_out_tensor.data<data_t>();
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
mctlassExOrder_t column_major =
mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
// ffn1
auto fc1_expert_biases =
ffn1_bias
? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
: nullptr;
auto fc1_expert_scales = const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())->data<data_t>();
ffn1_bias
? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
: nullptr;
auto fc1_expert_scales =
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())->data<data_t>();
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
reinterpret_cast<const ElementA *>(permuted_input_ptr),
row_major,
reinterpret_cast<const ElementB *>(ffn1_weight.data<ElementB>()),
column_major,
reinterpret_cast<const ElementA *>(fc1_expert_scales),
reinterpret_cast<const ElementA *>(fc1_expert_biases),
reinterpret_cast<ElementC *>(fc1_out_ptr),
row_major,
tokens_expert_prefix_sum.data<int>(),
num_experts,
expanded_active_expert_rows,
inter_dim,
hidden_size,
stream);
reinterpret_cast<const ElementA*>(permuted_input_ptr),
row_major,
reinterpret_cast<const ElementB*>(ffn1_weight.data<ElementB>()),
column_major,
reinterpret_cast<const ElementA*>(fc1_expert_scales),
reinterpret_cast<const ElementA*>(fc1_expert_biases),
reinterpret_cast<ElementC*>(fc1_out_ptr),
row_major,
tokens_expert_prefix_sum.data<int>(),
num_experts,
expanded_active_expert_rows,
inter_dim,
hidden_size,
stream);
// swiglu
auto act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
auto act_out = act_out_tensor.data<data_t>();
auto fc2_expert_scales = const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())->data<data_t>();
auto fc2_expert_scales =
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())->data<data_t>();
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
reinterpret_cast<const ElementA *>(act_out),
row_major,
reinterpret_cast<const ElementB *>(ffn2_weight.data<ElementB>()),
column_major,
reinterpret_cast<const ElementA *>(fc2_expert_scales),
nullptr,
reinterpret_cast<ElementC *>(ffn_out_ptr),
row_major,
tokens_expert_prefix_sum.data<int>(),
num_experts,
expanded_active_expert_rows,
hidden_size,
inter_dim / 2,
stream);
reinterpret_cast<const ElementA*>(act_out),
row_major,
reinterpret_cast<const ElementB*>(ffn2_weight.data<ElementB>()),
column_major,
reinterpret_cast<const ElementA*>(fc2_expert_scales),
nullptr,
reinterpret_cast<ElementC*>(ffn_out_ptr),
row_major,
tokens_expert_prefix_sum.data<int>(),
num_experts,
expanded_active_expert_rows,
hidden_size,
inter_dim / 2,
stream);
}
std::vector<paddle::Tensor> MoeExpertFFN(
@@ -109,15 +116,18 @@ std::vector<paddle::Tensor> MoeExpertFFN(
switch (input_type) {
case paddle::DataType::BFLOAT16:
McMoeFFNKernel<paddle::DataType::BFLOAT16, maca_bfloat16, int8_t, maca_bfloat16>(permute_input,
tokens_expert_prefix_sum,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
quant_method,
ffn_out);
McMoeFFNKernel<paddle::DataType::BFLOAT16,
maca_bfloat16,
int8_t,
maca_bfloat16>(permute_input,
tokens_expert_prefix_sum,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
quant_method,
ffn_out);
break;
// case paddle::DataType::FLOAT16:
// MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
+1 -5
View File
@@ -12,12 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "fused_moe_helper.h"
#include "fused_moe_op.h"
#include "helper.h"
template <paddle::DataType T>
void MoeReduceKernel(const paddle::Tensor& ffn_out,
@@ -52,7 +51,6 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out,
stream);
}
std::vector<paddle::Tensor> MoeExpertReduce(
const paddle::Tensor& ffn_out,
const paddle::Tensor& top_k_weight,
@@ -106,7 +104,6 @@ std::vector<paddle::Tensor> MoeExpertReduce(
return {output};
}
std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
const std::vector<int64_t>& ffn_out_shape,
const std::vector<int64_t>& top_k_weight_shape,
@@ -129,7 +126,6 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
return {ffn_out_dtype};
}
PD_BUILD_OP(moe_expert_reduce)
.Inputs({"ffn_out",
"top_k_weight",
+106 -86
View File
@@ -12,64 +12,69 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "utility/helper.h"
#include "xpu/plugin.h"
#include <paddle/phi/backends/xpu/xpu_context.h>
template <paddle::DataType T>
std::vector<paddle::Tensor>
AdjustBatchKernel(const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &encoder_batch_idx,
const paddle::Tensor &decoder_batch_idx,
const paddle::Tensor &encoder_seq_lod_cpu,
const paddle::Tensor &encoder_batch_idx_cpu,
const paddle::Tensor &decoder_batch_idx_cpu,
const paddle::Tensor &enc_batch_tensor,
const paddle::Tensor &dec_batch_tensor,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
PD_CHECK(x.dtype() == T);
PD_CHECK(x.dims().size() == 2);
std::vector<paddle::Tensor> AdjustBatchKernel(
const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &encoder_batch_idx,
const paddle::Tensor &decoder_batch_idx,
const paddle::Tensor &encoder_seq_lod_cpu,
const paddle::Tensor &encoder_batch_idx_cpu,
const paddle::Tensor &decoder_batch_idx_cpu,
const paddle::Tensor &enc_batch_tensor,
const paddle::Tensor &dec_batch_tensor,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
PD_CHECK(x.dtype() == T);
PD_CHECK(x.dims().size() == 2);
using XPUType = typename XPUTypeTrait<typename PDTraits<T>::DataType>::Type;
using data_t = typename PDTraits<T>::data_t;
const int token_num = x.dims()[0];
const int dim = x.dims()[1];
const int bsz = cum_offsets.shape()[0];
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
using XPUType = typename XPUTypeTrait<typename PDTraits<T>::DataType>::Type;
using data_t = typename PDTraits<T>::data_t;
const int token_num = x.dims()[0];
const int dim = x.dims()[1];
const int bsz = cum_offsets.shape()[0];
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
enc_batch + 1, const_cast<int32_t *>(encoder_seq_lod.data<int32_t>())};
baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
const_cast<int32_t *>(encoder_batch_idx_cpu.data<int32_t>()), enc_batch,
const_cast<int32_t *>(encoder_batch_idx.data<int32_t>())};
baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp{
const_cast<int32_t *>(decoder_batch_idx_cpu.data<int32_t>()), dec_batch,
const_cast<int32_t *>(decoder_batch_idx.data<int32_t>())};
baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
enc_batch + 1,
const_cast<int32_t *>(encoder_seq_lod.data<int32_t>())};
baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
const_cast<int32_t *>(encoder_batch_idx_cpu.data<int32_t>()),
enc_batch,
const_cast<int32_t *>(encoder_batch_idx.data<int32_t>())};
baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp{
const_cast<int32_t *>(decoder_batch_idx_cpu.data<int32_t>()),
dec_batch,
const_cast<int32_t *>(decoder_batch_idx.data<int32_t>())};
auto out = paddle::full({token_num, dim}, -2, x.type(), x.place());
auto out = paddle::full({token_num, dim}, -2, x.type(), x.place());
int r = baidu::xpu::api::plugin::eb_adjust_batch<XPUType, XPUType>(
xpu_ctx->x_context(),
reinterpret_cast<const XPUType *>(x.data<data_t>()),
reinterpret_cast<XPUType *>(out.data<data_t>()), encoder_seqs_lods_vp,
encoder_batch_map_vp, decoder_batch_map_vp, dim);
return {out};
int r = baidu::xpu::api::plugin::eb_adjust_batch<XPUType, XPUType>(
xpu_ctx->x_context(),
reinterpret_cast<const XPUType *>(x.data<data_t>()),
reinterpret_cast<XPUType *>(out.data<data_t>()),
encoder_seqs_lods_vp,
encoder_batch_map_vp,
decoder_batch_map_vp,
dim);
return {out};
}
using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)(
const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &encoder_batch_idx,
const paddle::Tensor &decoder_batch_idx,
@@ -81,42 +86,50 @@ using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)(
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length);
std::vector<paddle::Tensor>
AdjustBatch(const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &encoder_batch_idx,
const paddle::Tensor &decoder_batch_idx,
const paddle::Tensor &encoder_seq_lod_cpu,
const paddle::Tensor &encoder_batch_idx_cpu,
const paddle::Tensor &decoder_batch_idx_cpu,
const paddle::Tensor &enc_batch_tensor,
const paddle::Tensor &dec_batch_tensor,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
AdjustBatchKernelFuncPtr func = nullptr;
std::vector<paddle::Tensor> AdjustBatch(
const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &encoder_batch_idx,
const paddle::Tensor &decoder_batch_idx,
const paddle::Tensor &encoder_seq_lod_cpu,
const paddle::Tensor &encoder_batch_idx_cpu,
const paddle::Tensor &decoder_batch_idx_cpu,
const paddle::Tensor &enc_batch_tensor,
const paddle::Tensor &dec_batch_tensor,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
AdjustBatchKernelFuncPtr func = nullptr;
switch (x.dtype()) {
switch (x.dtype()) {
case paddle::DataType::BFLOAT16:
func = &AdjustBatchKernel<paddle::DataType::BFLOAT16>;
break;
func = &AdjustBatchKernel<paddle::DataType::BFLOAT16>;
break;
case paddle::DataType::FLOAT16:
func = &AdjustBatchKernel<paddle::DataType::FLOAT16>;
break;
func = &AdjustBatchKernel<paddle::DataType::FLOAT16>;
break;
case paddle::DataType::FLOAT32:
func = &AdjustBatchKernel<paddle::DataType::FLOAT32>;
break;
func = &AdjustBatchKernel<paddle::DataType::FLOAT32>;
break;
case paddle::DataType::INT64:
func = &AdjustBatchKernel<paddle::DataType::INT64>;
break;
func = &AdjustBatchKernel<paddle::DataType::INT64>;
break;
default:
PD_THROW("Unsupported data type: ", x.dtype());
}
PD_THROW("Unsupported data type: ", x.dtype());
}
return func(x, cum_offsets, encoder_seq_lod, encoder_batch_idx,
decoder_batch_idx, encoder_seq_lod_cpu, encoder_batch_idx_cpu,
decoder_batch_idx_cpu, enc_batch_tensor, dec_batch_tensor,
output_padding_offset, max_input_length);
return func(x,
cum_offsets,
encoder_seq_lod,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
enc_batch_tensor,
dec_batch_tensor,
output_padding_offset,
max_input_length);
}
std::vector<std::vector<int64_t>> AdjustBatchInferShape(
@@ -131,16 +144,17 @@ std::vector<std::vector<int64_t>> AdjustBatchInferShape(
const std::vector<int64_t> &enc_batch_tensor_shape,
const std::vector<int64_t> &dec_batch_tensor_shape,
const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
if (output_padding_offset_shape) {
PD_THROW("speculative decoding is not supported in XPU.");
}
int64_t token_num = x_shape[0];
int64_t dim_embed = x_shape[1];
return {{token_num, dim_embed}};
if (output_padding_offset_shape) {
PD_THROW("speculative decoding is not supported in XPU.");
}
int64_t token_num = x_shape[0];
int64_t dim_embed = x_shape[1];
return {{token_num, dim_embed}};
}
std::vector<paddle::DataType> AdjustBatchInferDtype(
const paddle::DataType &x_dtype, const paddle::DataType &cum_offsets_dtype,
const paddle::DataType &x_dtype,
const paddle::DataType &cum_offsets_dtype,
const paddle::DataType &encoder_seq_lod_dtype,
const paddle::DataType &encoder_batch_idx_dtype,
const paddle::DataType &decoder_batch_idx_dtype,
@@ -150,14 +164,20 @@ std::vector<paddle::DataType> AdjustBatchInferDtype(
const paddle::DataType &enc_batch_tensor_dtype,
const paddle::DataType &dec_batch_tensor_dtype,
const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
return {x_dtype};
return {x_dtype};
}
PD_BUILD_OP(adjust_batch)
.Inputs({"x", "cum_offsets", "encoder_seq_lod", "encoder_batch_idx",
"decoder_batch_idx", "encoder_seq_lod_cpu",
"encoder_batch_idx_cpu", "decoder_batch_idx_cpu",
"enc_batch_tensor", "dec_batch_tensor",
.Inputs({"x",
"cum_offsets",
"encoder_seq_lod",
"encoder_batch_idx",
"decoder_batch_idx",
"encoder_seq_lod_cpu",
"encoder_batch_idx_cpu",
"decoder_batch_idx_cpu",
"enc_batch_tensor",
"dec_batch_tensor",
paddle::Optional("output_padding_offset")})
.Outputs({"out"})
.Attrs({"max_input_length: int"})
+63 -61
View File
@@ -89,7 +89,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
const paddle::optional<paddle::Tensor>& smooth,
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
const std::string &pos_emb_type,
const std::string& pos_emb_type,
bool rope_3d) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
@@ -215,8 +215,8 @@ std::vector<paddle::Tensor> BlockAttnKernel(
param.prefill_len = is_prefix_cache ? param.max_valid_seqlen : -1;
param.page_attn.block_size = block_size;
param.page_attn.max_num_blocks_per_seq = prefix_block_num_per_seq;
// prefix_block_tables is a subset of block_tables, which is used for prefix
// cache
// prefix_block_tables is a subset of block_tables, which is used for
// prefix cache
xftblock::Tensor prefix_block_tables_tensor(
is_prefix_cache ? reinterpret_cast<void*>(const_cast<int32_t*>(
prefix_block_tables.data<int32_t>()))
@@ -306,12 +306,12 @@ std::vector<paddle::Tensor> BlockAttnKernel(
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
value_cache.data<cdata_t>())),
vsl.usual_lod_vp, // seq_lod
vsl.slot_mapping_vp, // real_batch
prefix_lens_vp, // start_tokens
param.batch_size, // batch_size
1, // emb_batch_size
rope_max_seqlen, // max_seqlen
vsl.usual_lod_vp, // seq_lod
vsl.slot_mapping_vp, // real_batch
prefix_lens_vp, // start_tokens
param.batch_size, // batch_size
1, // emb_batch_size
rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
@@ -480,7 +480,8 @@ std::vector<paddle::Tensor> BlockAttnKernel(
api::VectorParam<int32_t> decoder_context_len_vp = {
const_cast<int32_t*>(decoder_context_len_cpu.data<int32_t>()),
dec_batch,
nullptr}; // use for speculative_attention_decoder seq_len in MTP
nullptr}; // use for speculative_attention_decoder seq_len in
// MTP
api::VectorParam<int32_t> decoder_context_len_cache_vp = {
const_cast<int32_t*>(decoder_context_len_cache_cpu.data<int32_t>()),
dec_batch,
@@ -597,49 +598,49 @@ std::vector<paddle::Tensor> BlockAttnKernel(
tfloat32,
int8_wo_t>;
constexpr int quant_mode = std::is_same_v<XPU_CType, int8_t> ? 3 : 0;
ret = baidu::xpu::xfa::speculative_attention_decoder<XPU_XType,
XPU_CType,
XPU_XType,
TGEMM,
TGEMM,
float,
int32_t,
quant_mode>(
xpu_ctx->x_context(),
decode_output_ptr, // out
q_buf_ptr, // q
nullptr, // k
nullptr, // v
reinterpret_cast<const XPU_CType*>(
key_cache.data<cdata_t>()), // k_cache
reinterpret_cast<const XPU_CType*>(
value_cache.data<cdata_t>()), // v_cache
reinterpret_cast<const int32_t*>(
block_tables.data<int32_t>()), // block_tables
decoder_context_len_vp, // seq_lengths
decoder_batch_map_vp, // valid_batch
param.max_batch_size, // batch_num
q_len, // qlen
max_seq_len, // max_seq_len
param.head_num, // head_num
param.head_dim, // head_dim
param.kv_head_num, // kv_head_num
nullptr, // attn_mask
1.0f /
std::sqrt(static_cast<float>(param.head_dim)), // scale 【check】
block_size, // block_size
max_block_per_seq, // max_blocks_per_seq
-1, // max_window_size
nullptr, // q_maxptr
has_zp // k_cache_maxptr
? fake_perhead_scale
: quant_k_scale_inv,
has_zp // v_cache_maxptr
? fake_perhead_scale
: quant_v_scale_inv,
nullptr, // o_maxptr
param.head_dim); // vo_head_dim
PD_CHECK(0, "speculative_attention unimplemented");
ret = baidu::xpu::xfa::speculative_attention_decoder<XPU_XType,
XPU_CType,
XPU_XType,
TGEMM,
TGEMM,
float,
int32_t,
quant_mode>(
xpu_ctx->x_context(),
decode_output_ptr, // out
q_buf_ptr, // q
nullptr, // k
nullptr, // v
reinterpret_cast<const XPU_CType*>(
key_cache.data<cdata_t>()), // k_cache
reinterpret_cast<const XPU_CType*>(
value_cache.data<cdata_t>()), // v_cache
reinterpret_cast<const int32_t*>(
block_tables.data<int32_t>()), // block_tables
decoder_context_len_vp, // seq_lengths
decoder_batch_map_vp, // valid_batch
param.max_batch_size, // batch_num
q_len, // qlen
max_seq_len, // max_seq_len
param.head_num, // head_num
param.head_dim, // head_dim
param.kv_head_num, // kv_head_num
nullptr, // attn_mask
1.0f /
std::sqrt(static_cast<float>(param.head_dim)), // scale 【check】
block_size, // block_size
max_block_per_seq, // max_blocks_per_seq
-1, // max_window_size
nullptr, // q_maxptr
has_zp // k_cache_maxptr
? fake_perhead_scale
: quant_k_scale_inv,
has_zp // v_cache_maxptr
? fake_perhead_scale
: quant_v_scale_inv,
nullptr, // o_maxptr
param.head_dim); // vo_head_dim
PD_CHECK(0, "speculative_attention unimplemented");
PD_CHECK(ret == api::SUCCESS,
"xfa::speculative_attention_decoder failed.");
if (!Eq_len) {
@@ -702,11 +703,11 @@ std::vector<paddle::Tensor> BlockAttnKernel(
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(value_cache.data<cdata_t>())),
vsl.usual_lod_vp, // seq_lod
vsl.slot_mapping_vp, // real_batch
param.batch_size, // batch_size
1, // emb_batch_size = rotary_embs.dims()[1] = 1
rope_max_seqlen, // max_seqlen
vsl.usual_lod_vp, // seq_lod
vsl.slot_mapping_vp, // real_batch
param.batch_size, // batch_size
1, // emb_batch_size = rotary_embs.dims()[1] = 1
rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
@@ -777,7 +778,8 @@ std::vector<paddle::Tensor> BlockAttnKernel(
ret = xftblock::xft_decoder_core_attenion_block<
XPU_XType,
XPU_CType,
XPU_XType>( // TGEMM = XPU_XType TODOlizan03: used high precision
XPU_XType>( // TGEMM = XPU_XType TODOlizan03: used high
// precision
&xctx,
&q_buf,
&key_cache_tensor,
@@ -867,8 +869,8 @@ std::vector<paddle::Tensor> BlockAttn(
const paddle::optional<paddle::Tensor>& smooth,
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
const std::string &pos_emb_type="NORMAL",
bool rope_3d=false) {
const std::string& pos_emb_type = "NORMAL",
bool rope_3d = false) {
#define APPLY_KERNEL(TX, TC, TS) \
return BlockAttnKernel<TX, TC, TS>(qkv, \
key_cache, \
@@ -12,13 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "xpu/plugin.h"
#include "xpu/xpuml.h"
#include <cstdlib>
#include <fcntl.h>
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <random>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
@@ -26,27 +21,31 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <cstdlib>
#include <random>
#include "paddle/extension.h"
#include "xpu/plugin.h"
#include "xpu/xpuml.h"
std::vector<paddle::Tensor> GetMaxMemDemand(int64_t device_id) {
if (device_id == -1) {
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
}
phi::XPUPlace place(device_id);
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
if (device_id == -1) {
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
}
phi::XPUPlace place(device_id);
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
paddle::Tensor max_mem_demand = paddle::zeros({1}, paddle::DataType::INT64);
paddle::Tensor max_mem_demand = paddle::zeros({1}, paddle::DataType::INT64);
max_mem_demand.data<int64_t>()[0] =
xpu_ctx->x_context()->_gm_mgr.get_max_mem_demand();
return {max_mem_demand};
max_mem_demand.data<int64_t>()[0] =
xpu_ctx->x_context()->_gm_mgr.get_max_mem_demand();
return {max_mem_demand};
}
std::vector<std::vector<int64_t>> GetMaxMemDemandInferShape() { return {{1}}; }
std::vector<paddle::DataType> GetMaxMemDemandInferDtype() {
return {paddle::DataType::INT64};
return {paddle::DataType::INT64};
}
PD_BUILD_OP(xpu_get_context_gm_max_mem_demand)
@@ -12,13 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "xpu/plugin.h"
#include "xpu/xpuml.h"
#include <cstdlib>
#include <fcntl.h>
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <random>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
@@ -26,30 +21,35 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <cstdlib>
#include <random>
#include "paddle/extension.h"
#include "xpu/plugin.h"
#include "xpu/xpuml.h"
std::vector<paddle::Tensor> GetFreeGlobalMemory(int64_t device_id) {
if (device_id == -1) {
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
}
if (device_id == -1) {
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
}
paddle::Tensor free_global_memory =
paddle::zeros({1}, paddle::DataType::INT64);
paddle::Tensor free_global_memory =
paddle::zeros({1}, paddle::DataType::INT64);
xpumlDevice_t device_handle;
xpumlInit();
xpumlDeviceGetHandleByIndex(device_id, &device_handle);
xpumlMemory_t device_memory;
xpumlDeviceGetMemoryInfo(device_handle, &device_memory);
free_global_memory.data<int64_t>()[0] = device_memory.freeGlobalMemory;
return {free_global_memory};
xpumlDevice_t device_handle;
xpumlInit();
xpumlDeviceGetHandleByIndex(device_id, &device_handle);
xpumlMemory_t device_memory;
xpumlDeviceGetMemoryInfo(device_handle, &device_memory);
free_global_memory.data<int64_t>()[0] = device_memory.freeGlobalMemory;
return {free_global_memory};
}
std::vector<std::vector<int64_t>> GetFreeGlobalMemoryInferShape() {
return {{1}};
return {{1}};
}
std::vector<paddle::DataType> GetFreeGlobalMemoryInferDtype() {
return {paddle::DataType::INT64};
return {paddle::DataType::INT64};
}
PD_BUILD_OP(xpu_get_free_global_memory)
@@ -12,13 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "xpu/plugin.h"
#include "xpu/xpuml.h"
#include <cstdlib>
#include <fcntl.h>
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <random>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
@@ -26,29 +21,34 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <cstdlib>
#include <random>
#include "paddle/extension.h"
#include "xpu/plugin.h"
#include "xpu/xpuml.h"
std::vector<paddle::Tensor> GetTotalGlobalMemory(int64_t device_id) {
if (device_id == -1) {
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
}
if (device_id == -1) {
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
}
paddle::Tensor total_global_memory =
paddle::zeros({1}, paddle::DataType::INT64);
xpumlDevice_t device_handle;
xpumlInit();
xpumlDeviceGetHandleByIndex(device_id, &device_handle);
xpumlMemory_t device_memory;
xpumlDeviceGetMemoryInfo(device_handle, &device_memory);
total_global_memory.data<int64_t>()[0] = device_memory.totalGlobalMemory;
return {total_global_memory};
paddle::Tensor total_global_memory =
paddle::zeros({1}, paddle::DataType::INT64);
xpumlDevice_t device_handle;
xpumlInit();
xpumlDeviceGetHandleByIndex(device_id, &device_handle);
xpumlMemory_t device_memory;
xpumlDeviceGetMemoryInfo(device_handle, &device_memory);
total_global_memory.data<int64_t>()[0] = device_memory.totalGlobalMemory;
return {total_global_memory};
}
std::vector<std::vector<int64_t>> GetTotalGlobalMemoryInferShape() {
return {{1}};
return {{1}};
}
std::vector<paddle::DataType> GetTotalGlobalMemoryInferDtype() {
return {paddle::DataType::INT64};
return {paddle::DataType::INT64};
}
PD_BUILD_OP(xpu_get_total_global_memory)
@@ -12,13 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "xpu/plugin.h"
#include "xpu/xpuml.h"
#include <cstdlib>
#include <fcntl.h>
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <random>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
@@ -26,29 +21,34 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <cstdlib>
#include <random>
#include "paddle/extension.h"
#include "xpu/plugin.h"
#include "xpu/xpuml.h"
std::vector<paddle::Tensor> GetUsedGlobalMemory(int64_t device_id) {
if (device_id == -1) {
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
}
if (device_id == -1) {
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
}
paddle::Tensor used_global_memory =
paddle::zeros({1}, paddle::DataType::INT64);
xpumlDevice_t device_handle;
xpumlInit();
xpumlDeviceGetHandleByIndex(device_id, &device_handle);
xpumlMemory_t device_memory;
xpumlDeviceGetMemoryInfo(device_handle, &device_memory);
used_global_memory.data<int64_t>()[0] = device_memory.usedGlobalMemory;
return {used_global_memory};
paddle::Tensor used_global_memory =
paddle::zeros({1}, paddle::DataType::INT64);
xpumlDevice_t device_handle;
xpumlInit();
xpumlDeviceGetHandleByIndex(device_id, &device_handle);
xpumlMemory_t device_memory;
xpumlDeviceGetMemoryInfo(device_handle, &device_memory);
used_global_memory.data<int64_t>()[0] = device_memory.usedGlobalMemory;
return {used_global_memory};
}
std::vector<std::vector<int64_t>> GetUsedGlobalMemoryInferShape() {
return {{1}};
return {{1}};
}
std::vector<paddle::DataType> GetUsedGlobalMemoryInferDtype() {
return {paddle::DataType::INT64};
return {paddle::DataType::INT64};
}
PD_BUILD_OP(xpu_get_used_global_memory)
+63 -52
View File
@@ -12,52 +12,57 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
#include <paddle/phi/backends/xpu/xpu_context.h>
std::vector<paddle::Tensor>
GatherNextToken(const paddle::Tensor &tmp_out, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &encoder_batch_map,
const paddle::Tensor &decoder_batch_map,
const paddle::Tensor &encoder_seq_lod_cpu,
const paddle::Tensor &encoder_batch_map_cpu,
const paddle::Tensor &decoder_batch_map_cpu,
const paddle::Tensor &enc_batch_tensor,
const paddle::Tensor &dec_batch_tensor,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
using XPUType =
typename XPUTypeTrait<bfloat16>::Type; // only support bfloat16
typedef paddle::bfloat16 data_t;
const int dim = tmp_out.dims()[1];
const int bsz = cum_offsets.shape()[0];
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor &tmp_out, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &encoder_batch_map,
const paddle::Tensor &decoder_batch_map,
const paddle::Tensor &encoder_seq_lod_cpu,
const paddle::Tensor &encoder_batch_map_cpu,
const paddle::Tensor &decoder_batch_map_cpu,
const paddle::Tensor &enc_batch_tensor,
const paddle::Tensor &dec_batch_tensor,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
using XPUType =
typename XPUTypeTrait<bfloat16>::Type; // only support bfloat16
typedef paddle::bfloat16 data_t;
const int dim = tmp_out.dims()[1];
const int bsz = cum_offsets.shape()[0];
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
enc_batch + 1, const_cast<int32_t *>(encoder_seq_lod.data<int32_t>())};
baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
const_cast<int32_t *>(encoder_batch_map_cpu.data<int32_t>()), enc_batch,
const_cast<int32_t *>(encoder_batch_map.data<int32_t>())};
baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp{
const_cast<int32_t *>(decoder_batch_map_cpu.data<int32_t>()), dec_batch,
const_cast<int32_t *>(decoder_batch_map.data<int32_t>())};
baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
enc_batch + 1,
const_cast<int32_t *>(encoder_seq_lod.data<int32_t>())};
baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
const_cast<int32_t *>(encoder_batch_map_cpu.data<int32_t>()),
enc_batch,
const_cast<int32_t *>(encoder_batch_map.data<int32_t>())};
baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp{
const_cast<int32_t *>(decoder_batch_map_cpu.data<int32_t>()),
dec_batch,
const_cast<int32_t *>(decoder_batch_map.data<int32_t>())};
auto out = paddle::full({bsz, dim}, -2, tmp_out.type(), tmp_out.place());
auto out = paddle::full({bsz, dim}, -2, tmp_out.type(), tmp_out.place());
int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
xpu_ctx->x_context(),
reinterpret_cast<const XPUType *>(tmp_out.data<data_t>()),
reinterpret_cast<XPUType *>(out.data<data_t>()), encoder_seqs_lods_vp,
encoder_batch_map_vp, decoder_batch_map_vp, dim);
return {out};
int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
xpu_ctx->x_context(),
reinterpret_cast<const XPUType *>(tmp_out.data<data_t>()),
reinterpret_cast<XPUType *>(out.data<data_t>()),
encoder_seqs_lods_vp,
encoder_batch_map_vp,
decoder_batch_map_vp,
dim);
return {out};
}
std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
@@ -72,12 +77,12 @@ std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
const std::vector<int64_t> &enc_batch_tensor_shape,
const std::vector<int64_t> &dec_batch_tensor_shape,
const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
if (output_padding_offset_shape) {
PD_THROW("speculative decoding is not supported in XPU.");
}
int64_t bsz = cum_offsets_shape[0];
int64_t dim_embed = tmp_out_shape[1];
return {{bsz, dim_embed}};
if (output_padding_offset_shape) {
PD_THROW("speculative decoding is not supported in XPU.");
}
int64_t bsz = cum_offsets_shape[0];
int64_t dim_embed = tmp_out_shape[1];
return {{bsz, dim_embed}};
}
std::vector<paddle::DataType> GatherNextTokenInferDtype(
@@ -92,14 +97,20 @@ std::vector<paddle::DataType> GatherNextTokenInferDtype(
const paddle::DataType &enc_batch_tensor_dtype,
const paddle::DataType &dec_batch_tensor_dtype,
const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
return {tmp_out_dtype};
return {tmp_out_dtype};
}
PD_BUILD_OP(gather_next_token)
.Inputs({"tmp_out", "cum_offsets", "encoder_seq_lod", "encoder_batch_map",
"decoder_batch_map", "encoder_seq_lod_cpu",
"encoder_batch_map_cpu", "decoder_batch_map_cpu",
"enc_batch_tensor", "dec_batch_tensor",
.Inputs({"tmp_out",
"cum_offsets",
"encoder_seq_lod",
"encoder_batch_map",
"decoder_batch_map",
"encoder_seq_lod_cpu",
"encoder_batch_map_cpu",
"decoder_batch_map_cpu",
"enc_batch_tensor",
"dec_batch_tensor",
paddle::Optional("output_padding_offset")})
.Outputs({"out"})
.Attrs({"max_input_length: int"})
@@ -14,43 +14,48 @@
#include "paddle/extension.h"
std::vector<paddle::Tensor> GetImgBoundaries(const paddle::Tensor& task_input_ids,
const paddle::Tensor& grid_thw,
const int64_t image_patch_id) {
// All tensor in cpu
auto input_ids_ptr = task_input_ids.data<int64_t>();
int64_t seq_lens_origin = task_input_ids.numel();
auto grid_thw_ptr = grid_thw.data<int64_t>();
std::vector<paddle::Tensor> GetImgBoundaries(
const paddle::Tensor& task_input_ids,
const paddle::Tensor& grid_thw,
const int64_t image_patch_id) {
// All tensor in cpu
auto input_ids_ptr = task_input_ids.data<int64_t>();
int64_t seq_lens_origin = task_input_ids.numel();
auto grid_thw_ptr = grid_thw.data<int64_t>();
int token_times = 4;
int token_idx = 0;
int image_idx = 0;
std::vector<int> img_boundaries, img_nums;
img_boundaries.emplace_back(0);
img_nums.emplace_back(0);
while (token_idx < seq_lens_origin) {
if (input_ids_ptr[token_idx] != image_patch_id) {
do {
token_idx++;
} while (token_idx < seq_lens_origin && input_ids_ptr[token_idx] != image_patch_id);
} else {
int cur_image_token_len = (grid_thw_ptr[image_idx * 3 + 1] * grid_thw_ptr[image_idx * 3 + 2]) / token_times;
image_idx++;
token_idx += cur_image_token_len;
}
img_boundaries.emplace_back(token_idx);
img_nums.emplace_back(image_idx);
int token_times = 4;
int token_idx = 0;
int image_idx = 0;
std::vector<int> img_boundaries, img_nums;
img_boundaries.emplace_back(0);
img_nums.emplace_back(0);
while (token_idx < seq_lens_origin) {
if (input_ids_ptr[token_idx] != image_patch_id) {
do {
token_idx++;
} while (token_idx < seq_lens_origin &&
input_ids_ptr[token_idx] != image_patch_id);
} else {
int cur_image_token_len =
(grid_thw_ptr[image_idx * 3 + 1] * grid_thw_ptr[image_idx * 3 + 2]) /
token_times;
image_idx++;
token_idx += cur_image_token_len;
}
img_boundaries.emplace_back(token_idx);
img_nums.emplace_back(image_idx);
}
int64_t num_img_boundaries = static_cast<int64_t>(img_boundaries.size());
auto out = paddle::full({2, num_img_boundaries}, 0, paddle::DataType::INT64, paddle::CPUPlace());
int64_t num_img_boundaries = static_cast<int64_t>(img_boundaries.size());
auto out = paddle::full(
{2, num_img_boundaries}, 0, paddle::DataType::INT64, paddle::CPUPlace());
for (int i = 0; i < num_img_boundaries; i++) {
out.data<int64_t>()[i] = img_boundaries[i];
out.data<int64_t>()[num_img_boundaries + i] = img_nums[i];
}
for (int i = 0; i < num_img_boundaries; i++) {
out.data<int64_t>()[i] = img_boundaries[i];
out.data<int64_t>()[num_img_boundaries + i] = img_nums[i];
}
return {out};
return {out};
}
PD_BUILD_OP(get_img_boundaries)
+61 -58
View File
@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "msg_utils.h"
#include "paddle/extension.h"
void GetOutputKVSignal(const paddle::Tensor& x,
void GetOutputKVSignal(const paddle::Tensor &x,
int64_t rank_id,
bool wait_flag) {
int msg_queue_id = 1024 + rank_id;
@@ -28,7 +28,7 @@ void GetOutputKVSignal(const paddle::Tensor& x,
static key_t key = ftok("/opt/", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
int* out_data = const_cast<int*>(x.data<int>());
int *out_data = const_cast<int *>(x.data<int>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
@@ -48,69 +48,72 @@ void GetOutputKVSignal(const paddle::Tensor& x,
return;
}
void GetOutput(const paddle::Tensor &x, int64_t rank_id, bool wait_flag,
void GetOutput(const paddle::Tensor &x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {
if (rank_id > 0) {
return;
}
static struct msgdata msg_rcv;
if (const char *inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(
inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static key_t key = ftok("/dev/shm", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef GET_OUTPUT_DEBUG
std::cout << "get_output msg_queue_id: " << msg_queue_id << std::endl;
std::cout << "get_output key: " << key << std::endl;
std::cout << "get_output msgid: " << msgid << std::endl;
std::cout << "get_output wait_flag: " << wait_flag << std::endl;
#endif
int64_t *out_data = const_cast<int64_t *>(x.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
}
#ifdef GET_OUTPUT_DEBUG
std::cout << "get_output finish msgrcv" << std::endl;
#endif
if (ret == -1) {
out_data[0] = -2;
out_data[1] = 0;
return;
}
int bsz = msg_rcv.mtext[1];
for (int64_t i = 0; i < bsz + 2; i++) {
out_data[i] = (int64_t)msg_rcv.mtext[i];
}
#ifdef GET_OUTPUT_DEBUG
std::cout << "get_output finished: " << msgid << std::endl;
#endif
if (rank_id > 0) {
return;
}
static struct msgdata msg_rcv;
if (const char *inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static key_t key = ftok("/dev/shm", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef GET_OUTPUT_DEBUG
std::cout << "get_output msg_queue_id: " << msg_queue_id << std::endl;
std::cout << "get_output key: " << key << std::endl;
std::cout << "get_output msgid: " << msgid << std::endl;
std::cout << "get_output wait_flag: " << wait_flag << std::endl;
#endif
int64_t *out_data = const_cast<int64_t *>(x.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
}
#ifdef GET_OUTPUT_DEBUG
std::cout << "get_output finish msgrcv" << std::endl;
#endif
if (ret == -1) {
out_data[0] = -2;
out_data[1] = 0;
return;
}
int bsz = msg_rcv.mtext[1];
for (int64_t i = 0; i < bsz + 2; i++) {
out_data[i] = (int64_t)msg_rcv.mtext[i];
}
#ifdef GET_OUTPUT_DEBUG
std::cout << "get_output finished: " << msgid << std::endl;
#endif
return;
}
void GetOutputStatic(const paddle::Tensor &x, int64_t rank_id, bool wait_flag) {
GetOutput(x, rank_id, wait_flag, 1);
GetOutput(x, rank_id, wait_flag, 1);
}
void GetOutputDynamic(const paddle::Tensor &x, int64_t rank_id, bool wait_flag,
void GetOutputDynamic(const paddle::Tensor &x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {
GetOutput(x, rank_id, wait_flag, msg_queue_id);
GetOutput(x, rank_id, wait_flag, msg_queue_id);
}
PD_BUILD_OP(get_output)
@@ -20,44 +20,43 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &token_num,
const paddle::Tensor &seq_len) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
const int token_num_data = cpu_token_num.data<int64_t>()[0];
auto x_remove_padding = paddle::full(
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
auto batch_id_per_token = paddle::full(
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
int r = baidu::xpu::api::plugin::get_padding_offset(
xpu_ctx->x_context(),
batch_id_per_token.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length,
bsz);
PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed.");
return {x_remove_padding,
cum_offsets_out,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k};
const int token_num_data = cpu_token_num.data<int64_t>()[0];
auto x_remove_padding = paddle::full(
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
auto batch_id_per_token = paddle::full(
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
int r = baidu::xpu::api::plugin::get_padding_offset(
xpu_ctx->x_context(),
batch_id_per_token.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length,
bsz);
PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed.");
return {x_remove_padding,
cum_offsets_out,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k};
}
std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
@@ -65,9 +64,9 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
const std::vector<int64_t> &cum_offsets_shape,
const std::vector<int64_t> &token_num_shape,
const std::vector<int64_t> &seq_len_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
@@ -75,11 +74,11 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
const paddle::DataType &cum_offsets_dtype,
const paddle::DataType &token_num_dtype,
const paddle::DataType &seq_len_dtype) {
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype};
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype};
}
PD_BUILD_OP(get_padding_offset)
@@ -12,70 +12,97 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
#include <paddle/phi/backends/xpu/xpu_context.h>
void TokenPenaltyMultiScores(
const paddle::Tensor &pre_ids, const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_scores,
const paddle::Tensor &presence_scores, const paddle::Tensor &temperatures,
const paddle::Tensor &bad_tokens, const paddle::Tensor &cur_len,
const paddle::Tensor &min_len, const paddle::Tensor &eos_token_id) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
int64_t bs = logits.shape()[0];
PADDLE_ENFORCE_LE(
bs, 640,
phi::errors::InvalidArgument(
"Only support bsz <= 1024, but received bsz is %d", bs));
int64_t length = logits.shape()[1];
int64_t length_id = pre_ids.shape()[1];
int64_t length_bad_words = bad_tokens.shape()[0];
int64_t end_length = eos_token_id.shape()[0];
switch (logits.type()) {
void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_scores,
const paddle::Tensor &presence_scores,
const paddle::Tensor &temperatures,
const paddle::Tensor &bad_tokens,
const paddle::Tensor &cur_len,
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
int64_t bs = logits.shape()[0];
PADDLE_ENFORCE_LE(
bs,
640,
phi::errors::InvalidArgument(
"Only support bsz <= 1024, but received bsz is %d", bs));
int64_t length = logits.shape()[1];
int64_t length_id = pre_ids.shape()[1];
int64_t length_bad_words = bad_tokens.shape()[0];
int64_t end_length = eos_token_id.shape()[0];
switch (logits.type()) {
case paddle::DataType::FLOAT16: {
using XPUType = typename XPUTypeTrait<float16>::Type;
typedef paddle::float16 data_t;
int r = baidu::xpu::api::plugin::token_penalty_multi_scores(
xpu_ctx->x_context(), pre_ids.data<int64_t>(),
reinterpret_cast<XPUType *>(
const_cast<data_t *>(logits.data<data_t>())),
reinterpret_cast<const XPUType *>(penalty_scores.data<data_t>()),
reinterpret_cast<const XPUType *>(frequency_scores.data<data_t>()),
reinterpret_cast<const XPUType *>(presence_scores.data<data_t>()),
temperatures.data<float>(), cur_len.data<int64_t>(),
min_len.data<int64_t>(), eos_token_id.data<int64_t>(),
bad_tokens.data<int64_t>(), bs, length, length_id, end_length,
length_bad_words);
PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
using XPUType = typename XPUTypeTrait<float16>::Type;
typedef paddle::float16 data_t;
int r = baidu::xpu::api::plugin::token_penalty_multi_scores(
xpu_ctx->x_context(),
pre_ids.data<int64_t>(),
reinterpret_cast<XPUType *>(
const_cast<data_t *>(logits.data<data_t>())),
reinterpret_cast<const XPUType *>(penalty_scores.data<data_t>()),
reinterpret_cast<const XPUType *>(frequency_scores.data<data_t>()),
reinterpret_cast<const XPUType *>(presence_scores.data<data_t>()),
temperatures.data<float>(),
cur_len.data<int64_t>(),
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
bad_tokens.data<int64_t>(),
bs,
length,
length_id,
end_length,
length_bad_words);
PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
} break;
case paddle::DataType::FLOAT32: {
int r = baidu::xpu::api::plugin::token_penalty_multi_scores(
xpu_ctx->x_context(), pre_ids.data<int64_t>(),
const_cast<float *>(logits.data<float>()),
penalty_scores.data<float>(), frequency_scores.data<float>(),
presence_scores.data<float>(), temperatures.data<float>(),
cur_len.data<int64_t>(), min_len.data<int64_t>(),
eos_token_id.data<int64_t>(), bad_tokens.data<int64_t>(), bs,
length, length_id, end_length, length_bad_words);
PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
int r = baidu::xpu::api::plugin::token_penalty_multi_scores(
xpu_ctx->x_context(),
pre_ids.data<int64_t>(),
const_cast<float *>(logits.data<float>()),
penalty_scores.data<float>(),
frequency_scores.data<float>(),
presence_scores.data<float>(),
temperatures.data<float>(),
cur_len.data<int64_t>(),
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
bad_tokens.data<int64_t>(),
bs,
length,
length_id,
end_length,
length_bad_words);
PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
} break;
default:
PD_THROW("NOT supported data type. "
"Only float16 and float32 are supported. ");
break;
}
PD_THROW(
"NOT supported data type. "
"Only float16 and float32 are supported. ");
break;
}
}
PD_BUILD_OP(get_token_penalty_multi_scores)
.Inputs({"pre_ids", "logits", "penalty_scores", "frequency_scores",
"presence_scores", "temperatures", "bad_tokens", "cur_len",
"min_len", "eos_token_id"})
.Inputs({"pre_ids",
"logits",
"penalty_scores",
"frequency_scores",
"presence_scores",
"temperatures",
"bad_tokens",
"cur_len",
"min_len",
"eos_token_id"})
.Outputs({"logits_out"})
.SetInplaceMap({{"logits", "logits_out"}})
.SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));
+1 -1
View File
@@ -72,7 +72,7 @@ void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
is_padding_input ? token_num_info : nullptr,
expert_num,
1, // moe_topk
0, // group_size
0, // group_size
ffn1_out_shape.size() == 2 ? xftblock::MoeFCInputMode::DENSE
: xftblock::MoeFCInputMode::SPARSE);
PD_CHECK(ret == 0);
+230 -188
View File
@@ -29,210 +29,246 @@
namespace xftblock = baidu::xpu::xftblock;
namespace api = baidu::xpu::api;
template <typename TX, typename TW> struct fused_moe_ffn_trait {
using GEMM_TYPE = TW;
template <typename TX, typename TW>
struct fused_moe_ffn_trait {
using GEMM_TYPE = TW;
};
template <> struct fused_moe_ffn_trait<bfloat16, bfloat16> {
using GEMM_TYPE = float;
template <>
struct fused_moe_ffn_trait<bfloat16, bfloat16> {
using GEMM_TYPE = float;
};
template <> struct fused_moe_ffn_trait<bfloat16, int8_t> {
using GEMM_TYPE = float;
template <>
struct fused_moe_ffn_trait<bfloat16, int8_t> {
using GEMM_TYPE = float;
};
template <> struct fused_moe_ffn_trait<bfloat16, int4_t> {
using GEMM_TYPE = int4_wo_int15;
template <>
struct fused_moe_ffn_trait<bfloat16, int4_t> {
using GEMM_TYPE = int4_wo_int15;
};
template <typename TX, typename TW>
std::vector<paddle::Tensor> MoeLayerKernel(
const paddle::Tensor &x, const paddle::Tensor &gate_weight,
const paddle::Tensor &x,
const paddle::Tensor &gate_weight,
const paddle::optional<paddle::Tensor> &gate_correction_bias,
const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight,
const paddle::Tensor &up_gate_proj_weight,
const paddle::Tensor &down_proj_weight,
const paddle::optional<paddle::Tensor> &up_gate_proj_bias,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const paddle::optional<paddle::Tensor> &up_gate_proj_weight_scale,
const paddle::optional<paddle::Tensor> &down_proj_weight_scale,
const paddle::optional<paddle::Tensor> &down_proj_in_scale, // not support
const std::string &quant_method, const int moe_top_k,
const paddle::optional<paddle::Tensor> &down_proj_in_scale, // not support
const std::string &quant_method,
const int moe_top_k,
const bool moe_group) {
// std::cout << "[Op Debug] enter moe layer" << std::endl;
using XPU_TX = typename XPUTypeTrait<TX>::Type;
using XPU_TW = typename XPUTypeTrait<TW>::Type;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
auto rt_guard = xctx.get_rt_guard();
// std::cout << "[Op Debug] enter moe layer" << std::endl;
using XPU_TX = typename XPUTypeTrait<TX>::Type;
using XPU_TW = typename XPUTypeTrait<TW>::Type;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
auto rt_guard = xctx.get_rt_guard();
const auto xtype = x.dtype();
auto x_dims = x.shape();
auto up_gate_proj_dims = up_gate_proj_weight.shape();
PD_CHECK(x_dims.size() == 2, "x_dims.size() should be 2.");
PD_CHECK(up_gate_proj_dims.size() == 3,
"up_gate_proj_dims.size() should be 3.");
PD_CHECK(down_proj_in_scale.get_ptr() == nullptr,
"down_proj_in_scale not support.");
if (quant_method == "weight_only_int4") {
PD_CHECK(x_dims[1] == up_gate_proj_dims[2] * 2,
"x_dims[1] should equal to up_gate_proj_dims[2], (weight must be "
"[e,n,k]).");
} else {
PD_CHECK(x_dims[1] == up_gate_proj_dims[2],
"x_dims[1] should equal to up_gate_proj_dims[2], (weight must be "
"[e,n,k]).");
}
int token_num = x_dims[0];
int hidden_dim = x_dims[1];
int expert_num = up_gate_proj_dims[0];
int inter_dim = up_gate_proj_dims[1];
int outer_dim = inter_dim / 2;
paddle::Tensor fused_moe_out = paddle::empty_like(x);
auto x_mpart_shape = x_dims;
int MPART_SIZE = 2048;
if (const char *env_val = std::getenv("XPU_MPART_SIZE")) {
MPART_SIZE = std::atoi(env_val);
}
int bsz = x_dims[0];
for (int m_part_start = 0; m_part_start < bsz; m_part_start += MPART_SIZE) {
auto m_part_end = std::min(m_part_start + MPART_SIZE, bsz);
auto x_offset = m_part_start * hidden_dim;
x_mpart_shape[0] = m_part_end - m_part_start;
int ret = -1;
auto xftblock_tx = xftblock::DataTypeToEnum<XPU_TX>::value;
auto xftblock_tw = xftblock::DataTypeToEnum<XPU_TW>::value;
// input + output
xftblock::Tensor xin(
const_cast<TX *>(x.data<TX>() + x_offset), xftblock_tx, x_mpart_shape);
xftblock::Tensor xout(fused_moe_out.mutable_data<TX>() + x_offset,
xftblock_tx,
x_mpart_shape);
// gate
xftblock::Tensor xgate_w(const_cast<float *>(gate_weight.data<float>()),
xftblock::DataType::DT_FLOAT,
gate_weight.shape());
std::shared_ptr<xftblock::Tensor> xgate_correct_bias;
if (gate_correction_bias.get_ptr()) {
xgate_correct_bias = std::make_shared<xftblock::Tensor>(
const_cast<float *>(gate_correction_bias.get_ptr()->data<float>()),
xftblock::DataType::DT_FLOAT,
gate_correction_bias.get_ptr()->shape());
}
// up_gate_proj + down_proj
std::shared_ptr<xftblock::Tensor> xup_gate_proj_w, xdown_proj_w;
if (std::is_same<TW, int4_t>::value) {
xup_gate_proj_w = std::make_shared<xftblock::Tensor>(
const_cast<int8_t *>(up_gate_proj_weight.data<int8_t>()),
nullptr,
const_cast<float *>(
up_gate_proj_weight_scale.get_ptr()
? up_gate_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, inter_dim, hidden_dim});
xdown_proj_w = std::make_shared<xftblock::Tensor>(
const_cast<int8_t *>(down_proj_weight.data<int8_t>()),
nullptr,
const_cast<float *>(
down_proj_weight_scale.get_ptr()
? down_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, hidden_dim, outer_dim});
const auto xtype = x.dtype();
auto x_dims = x.shape();
auto up_gate_proj_dims = up_gate_proj_weight.shape();
PD_CHECK(x_dims.size() == 2, "x_dims.size() should be 2.");
PD_CHECK(up_gate_proj_dims.size() == 3, "up_gate_proj_dims.size() should be 3.");
PD_CHECK(down_proj_in_scale.get_ptr() == nullptr, "down_proj_in_scale not support.");
if (quant_method == "weight_only_int4") {
PD_CHECK(x_dims[1] == up_gate_proj_dims[2] * 2,
"x_dims[1] should equal to up_gate_proj_dims[2], (weight must be "
"[e,n,k]).");
} else {
PD_CHECK(x_dims[1] == up_gate_proj_dims[2],
"x_dims[1] should equal to up_gate_proj_dims[2], (weight must be "
"[e,n,k]).");
xup_gate_proj_w = std::make_shared<xftblock::Tensor>(
const_cast<TW *>(up_gate_proj_weight.data<TW>()),
nullptr,
const_cast<float *>(
up_gate_proj_weight_scale.get_ptr()
? up_gate_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, inter_dim, hidden_dim});
xdown_proj_w = std::make_shared<xftblock::Tensor>(
const_cast<TW *>(down_proj_weight.data<TW>()),
nullptr,
const_cast<float *>(
down_proj_weight_scale.get_ptr()
? down_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, hidden_dim, outer_dim});
}
int token_num = x_dims[0];
int hidden_dim = x_dims[1];
int expert_num = up_gate_proj_dims[0];
int inter_dim = up_gate_proj_dims[1];
int outer_dim = inter_dim / 2;
paddle::Tensor fused_moe_out = paddle::empty_like(x);
auto x_mpart_shape = x_dims;
int MPART_SIZE = 2048;
if (const char* env_val = std::getenv("XPU_MPART_SIZE")) {
MPART_SIZE = std::atoi(env_val);
std::shared_ptr<xftblock::Tensor> xup_gate_proj_bias;
std::shared_ptr<xftblock::Tensor> xdown_proj_bias;
if (up_gate_proj_bias.get_ptr()) {
xup_gate_proj_bias = std::make_shared<xftblock::Tensor>(
const_cast<float *>(up_gate_proj_bias.get_ptr()->data<float>()),
xftblock::DataType::DT_FLOAT,
up_gate_proj_bias.get_ptr()->shape());
}
int bsz = x_dims[0];
for (int m_part_start = 0; m_part_start < bsz; m_part_start += MPART_SIZE) {
auto m_part_end = std::min(m_part_start + MPART_SIZE, bsz);
auto x_offset = m_part_start * hidden_dim;
x_mpart_shape[0] = m_part_end - m_part_start;
int ret = -1;
auto xftblock_tx = xftblock::DataTypeToEnum<XPU_TX>::value;
auto xftblock_tw = xftblock::DataTypeToEnum<XPU_TW>::value;
// input + output
xftblock::Tensor xin(const_cast<TX *>(x.data<TX>() + x_offset), xftblock_tx,
x_mpart_shape);
xftblock::Tensor xout(fused_moe_out.mutable_data<TX>() + x_offset, xftblock_tx,
x_mpart_shape);
// gate
xftblock::Tensor xgate_w(const_cast<float *>(gate_weight.data<float>()),
xftblock::DataType::DT_FLOAT, gate_weight.shape());
std::shared_ptr<xftblock::Tensor> xgate_correct_bias;
if (gate_correction_bias.get_ptr()) {
xgate_correct_bias = std::make_shared<xftblock::Tensor>(
const_cast<float *>(gate_correction_bias.get_ptr()->data<float>()),
xftblock::DataType::DT_FLOAT,
gate_correction_bias.get_ptr()->shape());
}
// up_gate_proj + down_proj
std::shared_ptr<xftblock::Tensor> xup_gate_proj_w, xdown_proj_w;
if (std::is_same<TW, int4_t>::value) {
xup_gate_proj_w = std::make_shared<xftblock::Tensor>(
const_cast<int8_t *>(up_gate_proj_weight.data<int8_t>()), nullptr,
const_cast<float *>(up_gate_proj_weight_scale.get_ptr()
? up_gate_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, inter_dim, hidden_dim});
xdown_proj_w = std::make_shared<xftblock::Tensor>(
const_cast<int8_t *>(down_proj_weight.data<int8_t>()), nullptr,
const_cast<float *>(down_proj_weight_scale.get_ptr()
? down_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, hidden_dim, outer_dim});
} else {
xup_gate_proj_w = std::make_shared<xftblock::Tensor>(
const_cast<TW *>(up_gate_proj_weight.data<TW>()), nullptr,
const_cast<float *>(up_gate_proj_weight_scale.get_ptr()
? up_gate_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, inter_dim, hidden_dim}
);
xdown_proj_w = std::make_shared<xftblock::Tensor>(
const_cast<TW *>(down_proj_weight.data<TW>()), nullptr,
const_cast<float *>(down_proj_weight_scale.get_ptr()
? down_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, hidden_dim, outer_dim}
);
}
std::shared_ptr<xftblock::Tensor> xup_gate_proj_bias;
std::shared_ptr<xftblock::Tensor> xdown_proj_bias;
if (up_gate_proj_bias.get_ptr()) {
xup_gate_proj_bias = std::make_shared<xftblock::Tensor>(
const_cast<float *>(up_gate_proj_bias.get_ptr()->data<float>()),
xftblock::DataType::DT_FLOAT, up_gate_proj_bias.get_ptr()->shape());
}
if (down_proj_bias.get_ptr()) {
xdown_proj_bias = std::make_shared<xftblock::Tensor>(
const_cast<float *>(down_proj_bias.get_ptr()->data<float>()),
xftblock::DataType::DT_FLOAT, down_proj_bias.get_ptr()->shape());
}
// std::cout << "[Op Debug] start init moe_ffn weight and bias" <<
// std::endl; MoeFFNWeight
xftblock::MoeFFNWeight moe_ffn_w_struct;
moe_ffn_w_struct.gate_weight = &xgate_w;
moe_ffn_w_struct.ffn_inter_weights = xup_gate_proj_w.get();
moe_ffn_w_struct.ffn_inter_bias = xup_gate_proj_bias.get();
moe_ffn_w_struct.ffn_outer_weights = xdown_proj_w.get();
moe_ffn_w_struct.ffn_outer_bias = xdown_proj_bias.get();
moe_ffn_w_struct.score_bias = xgate_correct_bias.get();
// MoeFFNParam
xftblock::MoeFFNParam moe_ffn_param;
moe_ffn_param.expert_num = expert_num;
moe_ffn_param.moe_top_k = moe_top_k;
moe_ffn_param.fast_swiglu = true;
// std::cout << "[Op Debug] pre in xvfblock moe_ffn" << std::endl;
using XPU_TGEMM = typename fused_moe_ffn_trait<XPU_TX, XPU_TW>::GEMM_TYPE;
ret = baidu::xpu::xftblock::moe_ffn_block_sorted_castte_per_token<
XPU_TX, XPU_TW, XPU_TX, XPU_TGEMM>(&xctx, &xin, &xout, moe_ffn_w_struct,
moe_ffn_param);
PD_CHECK(ret == 0,
"xftblock::moe_ffn_block_sorted_castte_per_token failed");
if (down_proj_bias.get_ptr()) {
xdown_proj_bias = std::make_shared<xftblock::Tensor>(
const_cast<float *>(down_proj_bias.get_ptr()->data<float>()),
xftblock::DataType::DT_FLOAT,
down_proj_bias.get_ptr()->shape());
}
// std::cout << "[Op Debug] start init moe_ffn weight and bias" <<
// std::endl; MoeFFNWeight
xftblock::MoeFFNWeight moe_ffn_w_struct;
moe_ffn_w_struct.gate_weight = &xgate_w;
moe_ffn_w_struct.ffn_inter_weights = xup_gate_proj_w.get();
moe_ffn_w_struct.ffn_inter_bias = xup_gate_proj_bias.get();
moe_ffn_w_struct.ffn_outer_weights = xdown_proj_w.get();
moe_ffn_w_struct.ffn_outer_bias = xdown_proj_bias.get();
moe_ffn_w_struct.score_bias = xgate_correct_bias.get();
// MoeFFNParam
xftblock::MoeFFNParam moe_ffn_param;
moe_ffn_param.expert_num = expert_num;
moe_ffn_param.moe_top_k = moe_top_k;
moe_ffn_param.fast_swiglu = true;
return {fused_moe_out};
// std::cout << "[Op Debug] pre in xvfblock moe_ffn" << std::endl;
using XPU_TGEMM = typename fused_moe_ffn_trait<XPU_TX, XPU_TW>::GEMM_TYPE;
ret =
baidu::xpu::xftblock::moe_ffn_block_sorted_castte_per_token<XPU_TX,
XPU_TW,
XPU_TX,
XPU_TGEMM>(
&xctx, &xin, &xout, moe_ffn_w_struct, moe_ffn_param);
PD_CHECK(ret == 0,
"xftblock::moe_ffn_block_sorted_castte_per_token failed");
}
return {fused_moe_out};
}
std::vector<paddle::Tensor>
MoeLayer(const paddle::Tensor &x, const paddle::Tensor &gate_weight,
const paddle::optional<paddle::Tensor> &gate_correction_bias,
const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight,
const paddle::optional<paddle::Tensor> &up_gate_proj_bias,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const paddle::optional<paddle::Tensor> &up_gate_proj_weight_scale,
const paddle::optional<paddle::Tensor> &down_proj_weight_scale,
const paddle::optional<paddle::Tensor> &down_proj_in_scale,
const std::string &quant_method, const int moe_top_k,
const bool moe_group) {
const auto x_type = x.dtype();
const auto w_type = up_gate_proj_weight.dtype();
std::vector<paddle::Tensor> MoeLayer(
const paddle::Tensor &x,
const paddle::Tensor &gate_weight,
const paddle::optional<paddle::Tensor> &gate_correction_bias,
const paddle::Tensor &up_gate_proj_weight,
const paddle::Tensor &down_proj_weight,
const paddle::optional<paddle::Tensor> &up_gate_proj_bias,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const paddle::optional<paddle::Tensor> &up_gate_proj_weight_scale,
const paddle::optional<paddle::Tensor> &down_proj_weight_scale,
const paddle::optional<paddle::Tensor> &down_proj_in_scale,
const std::string &quant_method,
const int moe_top_k,
const bool moe_group) {
const auto x_type = x.dtype();
const auto w_type = up_gate_proj_weight.dtype();
#define APPLY_MOE_LAYER_KERNEL(TX, TW) \
return MoeLayerKernel<TX, TW>( \
x, gate_weight, gate_correction_bias, up_gate_proj_weight, down_proj_weight, \
up_gate_proj_bias, down_proj_bias, up_gate_proj_weight_scale, down_proj_weight_scale, \
down_proj_in_scale, quant_method, moe_top_k, moe_group);
#define APPLY_MOE_LAYER_KERNEL(TX, TW) \
return MoeLayerKernel<TX, TW>(x, \
gate_weight, \
gate_correction_bias, \
up_gate_proj_weight, \
down_proj_weight, \
up_gate_proj_bias, \
down_proj_bias, \
up_gate_proj_weight_scale, \
down_proj_weight_scale, \
down_proj_in_scale, \
quant_method, \
moe_top_k, \
moe_group);
// TODO(mayang02): how to use quant_method?
if (x_type == paddle::DataType::BFLOAT16 &&
w_type == paddle::DataType::BFLOAT16) {
APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, paddle::bfloat16);
} else if (x_type == paddle::DataType::BFLOAT16 &&
quant_method == "weight_only_int8") {
APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, int8_t);
} else if (x_type == paddle::DataType::BFLOAT16 &&
quant_method == "weight_only_int4") {
APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, int4_t);
} else {
PD_THROW("MoeLayer not support x_type=", static_cast<int>(x_type),
", w_type=", static_cast<int>(w_type),
", quant_method=", quant_method);
return {};
}
// TODO(mayang02): how to use quant_method?
if (x_type == paddle::DataType::BFLOAT16 &&
w_type == paddle::DataType::BFLOAT16) {
APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, paddle::bfloat16);
} else if (x_type == paddle::DataType::BFLOAT16 &&
quant_method == "weight_only_int8") {
APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, int8_t);
} else if (x_type == paddle::DataType::BFLOAT16 &&
quant_method == "weight_only_int4") {
APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, int4_t);
} else {
PD_THROW("MoeLayer not support x_type=",
static_cast<int>(x_type),
", w_type=",
static_cast<int>(w_type),
", quant_method=",
quant_method);
return {};
}
#undef APPLY_MOE_LAYER_KERNEL
}
@@ -244,14 +280,16 @@ std::vector<std::vector<int64_t>> MoeLayerInferShape(
const std::vector<int64_t> &down_proj_weight_shape,
const paddle::optional<std::vector<int64_t>> &up_gate_proj_bias_shape,
const paddle::optional<std::vector<int64_t>> &down_proj_bias_shape,
const paddle::optional<std::vector<int64_t>> &up_gate_proj_weight_scale_shape,
const paddle::optional<std::vector<int64_t>>
&up_gate_proj_weight_scale_shape,
const paddle::optional<std::vector<int64_t>> &down_proj_weight_scale_shape,
const paddle::optional<std::vector<int64_t>> &down_proj_in_scale_shape) {
return {x_shape};
return {x_shape};
}
std::vector<paddle::DataType> MoeLayerInferDtype(
const paddle::DataType &x_dtype, const paddle::DataType &gate_weight_dtype,
const paddle::DataType &x_dtype,
const paddle::DataType &gate_weight_dtype,
const paddle::optional<paddle::DataType> &gate_correction_bias_dtype,
const paddle::DataType &up_gate_proj_weight_dtype,
const paddle::DataType &down_proj_weight_dtype,
@@ -260,12 +298,16 @@ std::vector<paddle::DataType> MoeLayerInferDtype(
const paddle::optional<paddle::DataType> &up_gate_proj_weight_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_weight_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype) {
return {x_dtype};
return {x_dtype};
}
PD_BUILD_OP(xpu_moe_layer) // fused_moe
.Inputs({"x", "gate_weight", paddle::Optional("gate_correction_bias"),
"up_gate_proj_weight", "down_proj_weight", paddle::Optional("up_gate_proj_bias"),
PD_BUILD_OP(xpu_moe_layer) // fused_moe
.Inputs({"x",
"gate_weight",
paddle::Optional("gate_correction_bias"),
"up_gate_proj_weight",
"down_proj_weight",
paddle::Optional("up_gate_proj_bias"),
paddle::Optional("down_proj_bias"),
paddle::Optional("up_gate_proj_weight_scale"),
paddle::Optional("down_proj_weight_scale"),
@@ -14,9 +14,9 @@
#include <sys/mman.h> // NOLINT
#include "cuda_runtime_api.h" // NOLINT
#include "ops/pybind/pybind.h"
#include "paddle/extension.h"
#include "xpu/runtime.h"
#include "ops/pybind/pybind.h"
void check_xpu_error(int error) {
if (error != XPU_SUCCESS) {
+7 -7
View File
@@ -33,13 +33,13 @@ void prof_start();
void prof_stop();
void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
const paddle::Tensor &seq_lens_this_time_tensor,
const paddle::Tensor &seq_lens_decoder_tensor,
void InitKVSignalPerQuery(const paddle::Tensor& seq_lens_encoder_tensor,
const paddle::Tensor& seq_lens_this_time_tensor,
const paddle::Tensor& seq_lens_decoder_tensor,
const int rank,
const int num_layers);
void GetOutputKVSignal(const paddle::Tensor &x,
void GetOutputKVSignal(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag);
@@ -70,8 +70,8 @@ std::vector<paddle::Tensor> BlockAttn(
const paddle::optional<paddle::Tensor>& smooth,
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
const std::string &pos_emb_type="NORMAL",
bool rope_3d=false);
const std::string& pos_emb_type = "NORMAL",
bool rope_3d = false);
std::vector<paddle::Tensor> MoERedundantTopKSelect(
const paddle::Tensor& gating_logits,
@@ -477,7 +477,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("bias"),
py::arg("weight_dtype"),
py::arg("arch"),
py::arg("group_size")=-1);
py::arg("group_size") = -1);
m.def("ep_moe_expert_combine",
&MoeEPCombine,
@@ -18,32 +18,31 @@
#include "xpu/plugin.h"
void RecoverDecodeTask(const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &block_tables,
const paddle::Tensor &is_block_step,
const int block_size) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
const int bsz = seq_lens_this_time.shape()[0];
const int block_num_per_seq = block_tables.shape()[1];
int r = baidu::xpu::api::plugin::recover_decode_task(
xpu_ctx->x_context(),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<bool *>(is_block_step.data<bool>()),
bsz,
block_num_per_seq,
block_size);
PD_CHECK(r == 0, "baidu::xpu::api::plugin::recover_decode_task failed.");
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &block_tables,
const paddle::Tensor &is_block_step,
const int block_size) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
const int bsz = seq_lens_this_time.shape()[0];
const int block_num_per_seq = block_tables.shape()[1];
int r = baidu::xpu::api::plugin::recover_decode_task(
xpu_ctx->x_context(),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<bool *>(is_block_step.data<bool>()),
bsz,
block_num_per_seq,
block_size);
PD_CHECK(r == 0, "baidu::xpu::api::plugin::recover_decode_task failed.");
}
PD_BUILD_OP(recover_decode_task)
@@ -74,8 +74,8 @@ RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
using type_meta_data =
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;
// std::printf("#### open_shm_and_get_complete_signal_meta_data layer idx:%d,
// to ptx:%p \n",
// std::printf("#### open_shm_and_get_complete_signal_meta_data layer
// idx:%d, to ptx:%p \n",
// -1, signal_ptr);
type_meta_data meta_data(-1, signal_ptr, signal_shm_fd);
@@ -102,8 +102,8 @@ void RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise(
int32_t layer_id = meta_data_ptr[0];
int32_t* ptr = reinterpret_cast<int32_t*>(meta_data_ptr[1]);
*ptr = layer_id;
// std::printf("#### save_cache_kv_complete_signal_layerwise layer idx:%d, to
// ptx:%p \n",
// std::printf("#### save_cache_kv_complete_signal_layerwise layer idx:%d,
// to ptx:%p \n",
// *ptr, meta_data_ptr[1]);
}
@@ -12,114 +12,118 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#define MAX_BSZ 256
// #define SAVE_WITH_OUTPUT_DEBUG
struct msgdata {
long mtype;
int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens
long mtype;
int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens
};
// #define SAVE_WITH_OUTPUT_DEBUG
void SaveOutMmsg(const paddle::Tensor &x, const paddle::Tensor &not_need_stop,
int64_t rank_id, int msg_queue_id, bool save_each_rank) {
if (!save_each_rank && rank_id > 0) {
return;
}
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
int64_t *x_data = x_cpu.data<int64_t>();
static struct msgdata msg_sed;
if (const char *inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(
inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
msg_queue_id = inference_msg_queue_id_from_env;
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
} else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default."
<< std::endl;
#endif
}
int inference_msg_id_from_env = 1;
if (const char *inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2) {
// 2 and -2 is preserve for no-output indication.
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be 2, please use other number.");
}
if (inference_msg_id_from_env < 0) {
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be negative, please use other "
"number.");
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
<< std::endl;
#endif
} else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout
<< "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default."
<< std::endl;
#endif
}
static key_t key = ftok("/dev/shm", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "save_output key: " << key << std::endl;
std::cout << "save_output msgid: " << msgid << std::endl;
#endif
msg_sed.mtype = 1;
bool not_need_stop_data = not_need_stop.data<bool>()[0];
// printf("not_need_stop_data %d\n", (int)not_need_stop_data);
msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
: -inference_msg_id_from_env;
int bsz = x.shape()[0];
msg_sed.mtext[1] = bsz;
for (int i = 2; i < bsz + 2; i++) {
msg_sed.mtext[i] = (int)x_data[i - 2];
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "save_output msg data: ";
for (int i = 0; i < bsz; i++) {
std::cout << " " << (int)x_data[i];
}
std::cout << std::endl;
#endif
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) {
printf("save_output full msg buffer\n");
}
void SaveOutMmsg(const paddle::Tensor &x,
const paddle::Tensor &not_need_stop,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
if (!save_each_rank && rank_id > 0) {
return;
}
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
int64_t *x_data = x_cpu.data<int64_t>();
static struct msgdata msg_sed;
if (const char *inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
msg_queue_id = inference_msg_queue_id_from_env;
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
} else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default."
<< std::endl;
#endif
}
int inference_msg_id_from_env = 1;
if (const char *inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2) {
// 2 and -2 is preserve for no-output indication.
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be 2, please use other number.");
}
if (inference_msg_id_from_env < 0) {
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be negative, please use other "
"number.");
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
<< std::endl;
#endif
} else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default."
<< std::endl;
#endif
}
static key_t key = ftok("/dev/shm", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "save_output key: " << key << std::endl;
std::cout << "save_output msgid: " << msgid << std::endl;
#endif
msg_sed.mtype = 1;
bool not_need_stop_data = not_need_stop.data<bool>()[0];
// printf("not_need_stop_data %d\n", (int)not_need_stop_data);
msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
: -inference_msg_id_from_env;
int bsz = x.shape()[0];
msg_sed.mtext[1] = bsz;
for (int i = 2; i < bsz + 2; i++) {
msg_sed.mtext[i] = (int)x_data[i - 2];
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "save_output msg data: ";
for (int i = 0; i < bsz; i++) {
std::cout << " " << (int)x_data[i];
}
std::cout << std::endl;
#endif
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) {
printf("save_output full msg buffer\n");
}
return;
}
void SaveOutMmsgStatic(const paddle::Tensor &x,
const paddle::Tensor &not_need_stop, int64_t rank_id,
const paddle::Tensor &not_need_stop,
int64_t rank_id,
bool save_each_rank) {
SaveOutMmsg(x, not_need_stop, rank_id, 1, save_each_rank);
SaveOutMmsg(x, not_need_stop, rank_id, 1, save_each_rank);
}
void SaveOutMmsgDynamic(const paddle::Tensor &x,
const paddle::Tensor &not_need_stop, int64_t rank_id,
int msg_queue_id, bool save_each_rank) {
SaveOutMmsg(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
const paddle::Tensor &not_need_stop,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
SaveOutMmsg(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
}
PD_BUILD_OP(save_output)
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
#include <paddle/phi/backends/xpu/xpu_context.h>
void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
const paddle::Tensor &input_ids,
@@ -23,26 +23,35 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
int bs = seq_lens_this_time.shape()[0];
int length = pre_ids_all.shape()[1];
int length_input_ids = input_ids.shape()[1];
int r = baidu::xpu::api::plugin::set_value_by_flags_and_idx(
xpu_ctx->x_context(), stop_flags.data<bool>(),
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
input_ids.data<int64_t>(), seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(), step_idx.data<int64_t>(), bs, length,
length_input_ids);
PD_CHECK(r == 0, "xpu::plugin::set_value_by_flags_and_idx failed.");
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
int bs = seq_lens_this_time.shape()[0];
int length = pre_ids_all.shape()[1];
int length_input_ids = input_ids.shape()[1];
int r = baidu::xpu::api::plugin::set_value_by_flags_and_idx(
xpu_ctx->x_context(),
stop_flags.data<bool>(),
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
input_ids.data<int64_t>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
step_idx.data<int64_t>(),
bs,
length,
length_input_ids);
PD_CHECK(r == 0, "xpu::plugin::set_value_by_flags_and_idx failed.");
}
PD_BUILD_OP(set_value_by_flags_and_idx)
.Inputs({"pre_ids_all", "input_ids", "seq_lens_this_time",
"seq_lens_encoder", "seq_lens_decoder", "step_idx", "stop_flags"})
.Inputs({"pre_ids_all",
"input_ids",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx",
"stop_flags"})
.Outputs({"pre_ids_all_out"})
.SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
.SetKernelFn(PD_KERNEL(SetValueByFlagsAndIdx));
@@ -30,8 +30,9 @@ std::vector<paddle::Tensor> ShareExternalData(const paddle::Tensor &input,
void *data_ptr_addr = nullptr;
if (use_ipc) {
#if XPURT_VERSION_MAJOR == 5
int ret = xpu_ipc_open_memhandle(
&data_ptr_addr, *(XPUIpcMemHandle *)&shm->memHandle, 0x01); // NOLINT
int ret = xpu_ipc_open_memhandle(&data_ptr_addr,
*(XPUIpcMemHandle *)&shm->memHandle,
0x01); // NOLINT
PD_CHECK(ret == XPU_SUCCESS, "xpu_ipc_open_memhandle failed");
#elif XPURT_VERSION_MAJOR == 4
PD_THROW("kl2 not support prefix cache");
+103 -74
View File
@@ -12,82 +12,100 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
#include <paddle/phi/backends/xpu/xpu_context.h>
void StepPaddle(
const paddle::Tensor &stop_flags, const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &ori_seq_lens_encoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor &encoder_block_lens,
const paddle::Tensor &is_block_step, const paddle::Tensor &step_block_list,
const paddle::Tensor &step_lens, const paddle::Tensor &recover_block_list,
const paddle::Tensor &recover_lens, const paddle::Tensor &need_block_list,
const paddle::Tensor &need_block_len, const paddle::Tensor &used_list_len,
const paddle::Tensor &free_list, const paddle::Tensor &free_list_len,
const paddle::Tensor &input_ids, const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx, const paddle::Tensor &next_tokens,
const paddle::Tensor &first_token_ids, const int block_size,
const int encoder_decoder_block_num) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
void StepPaddle(const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &ori_seq_lens_encoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor &encoder_block_lens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &step_block_list,
const paddle::Tensor &step_lens,
const paddle::Tensor &recover_block_list,
const paddle::Tensor &recover_lens,
const paddle::Tensor &need_block_list,
const paddle::Tensor &need_block_len,
const paddle::Tensor &used_list_len,
const paddle::Tensor &free_list,
const paddle::Tensor &free_list_len,
const paddle::Tensor &input_ids,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &next_tokens,
const paddle::Tensor &first_token_ids,
const int block_size,
const int encoder_decoder_block_num) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
const int bsz = seq_lens_this_time.shape()[0];
PADDLE_ENFORCE_LE(
bsz, 640,
phi::errors::InvalidArgument(
"Only support bsz <= 640, but received bsz is %d", bsz));
const int block_num_per_seq = block_tables.shape()[1];
const int length = input_ids.shape()[1];
const int pre_id_length = pre_ids.shape()[1];
const int max_decoder_block_num = pre_id_length / block_size;
int r = baidu::xpu::api::plugin::free_and_dispatch_block(
xpu_ctx->x_context(), const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<int *>(encoder_block_lens.data<int>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<int *>(step_block_list.data<int>()),
const_cast<int *>(step_lens.data<int>()),
const int bsz = seq_lens_this_time.shape()[0];
PADDLE_ENFORCE_LE(
bsz,
640,
phi::errors::InvalidArgument(
"Only support bsz <= 640, but received bsz is %d", bsz));
const int block_num_per_seq = block_tables.shape()[1];
const int length = input_ids.shape()[1];
const int pre_id_length = pre_ids.shape()[1];
const int max_decoder_block_num = pre_id_length / block_size;
int r = baidu::xpu::api::plugin::free_and_dispatch_block(
xpu_ctx->x_context(),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<int *>(encoder_block_lens.data<int>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<int *>(step_block_list.data<int>()),
const_cast<int *>(step_lens.data<int>()),
const_cast<int *>(recover_block_list.data<int>()),
const_cast<int *>(recover_lens.data<int>()),
const_cast<int *>(need_block_list.data<int>()),
const_cast<int *>(need_block_len.data<int>()),
const_cast<int *>(used_list_len.data<int>()),
const_cast<int *>(free_list.data<int>()),
const_cast<int *>(free_list_len.data<int>()),
const_cast<int64_t *>(first_token_ids.data<int64_t>()),
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num);
PD_CHECK(r == 0, "free_and_dispatch_block failed.");
auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false);
int recover_lens_cpu_data = recover_lens_cpu.data<int>()[0];
if (recover_lens_cpu_data > 0) {
r = baidu::xpu::api::plugin::recover_block(
xpu_ctx->x_context(),
const_cast<int *>(recover_block_list.data<int>()),
const_cast<int *>(recover_lens.data<int>()),
const_cast<int *>(need_block_list.data<int>()),
const_cast<int *>(need_block_len.data<int>()),
const_cast<int *>(used_list_len.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
ori_seq_lens_encoder.data<int>(),
const_cast<int *>(seq_lens_encoder.data<int>()),
seq_lens_decoder.data<int>(),
const_cast<int *>(block_tables.data<int>()),
const_cast<int *>(free_list.data<int>()),
const_cast<int *>(free_list_len.data<int>()),
const_cast<int64_t *>(first_token_ids.data<int64_t>()), bsz, block_size,
block_num_per_seq, max_decoder_block_num);
PD_CHECK(r == 0, "free_and_dispatch_block failed.");
auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false);
int recover_lens_cpu_data = recover_lens_cpu.data<int>()[0];
if (recover_lens_cpu_data > 0) {
r = baidu::xpu::api::plugin::recover_block(
xpu_ctx->x_context(),
const_cast<int *>(recover_block_list.data<int>()),
const_cast<int *>(recover_lens.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
ori_seq_lens_encoder.data<int>(),
const_cast<int *>(seq_lens_encoder.data<int>()),
seq_lens_decoder.data<int>(),
const_cast<int *>(block_tables.data<int>()),
const_cast<int *>(free_list.data<int>()),
const_cast<int *>(free_list_len.data<int>()),
const_cast<int64_t *>(input_ids.data<int64_t>()),
pre_ids.data<int64_t>(), step_idx.data<int64_t>(),
encoder_block_lens.data<int>(), used_list_len.data<int>(),
next_tokens.data<int64_t>(), first_token_ids.data<int64_t>(), bsz,
block_num_per_seq, length, pre_id_length);
PD_CHECK(r == 0, "recover_block failed.");
}
const_cast<int64_t *>(input_ids.data<int64_t>()),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
encoder_block_lens.data<int>(),
used_list_len.data<int>(),
next_tokens.data<int64_t>(),
first_token_ids.data<int64_t>(),
bsz,
block_num_per_seq,
length,
pre_id_length);
PD_CHECK(r == 0, "recover_block failed.");
}
}
PD_BUILD_OP(step_paddle)
@@ -114,13 +132,24 @@ PD_BUILD_OP(step_paddle)
"next_tokens",
"first_token_ids"})
.Attrs({"block_size: int", "encoder_decoder_block_num: int"})
.Outputs({"stop_flags_out", "seq_lens_this_time_out",
"seq_lens_encoder_out", "seq_lens_decoder_out",
"block_tables_out", "encoder_block_lens_out", "is_block_step_out",
"step_block_list_out", "step_lens_out", "recover_block_list_out",
"recover_lens_out", "need_block_list_out", "need_block_len_out",
"used_list_len_out", "free_list_out", "free_list_len_out",
"input_ids_out", "first_token_ids_out"})
.Outputs({"stop_flags_out",
"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"block_tables_out",
"encoder_block_lens_out",
"is_block_step_out",
"step_block_list_out",
"step_lens_out",
"recover_block_list_out",
"recover_lens_out",
"need_block_list_out",
"need_block_len_out",
"used_list_len_out",
"free_list_out",
"free_list_len_out",
"input_ids_out",
"first_token_ids_out"})
.SetInplaceMap({{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
@@ -12,9 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "xpu/plugin.h"
#include <cstdlib>
#include <fcntl.h>
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <stdio.h>
@@ -24,6 +21,9 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <cstdlib>
#include "paddle/extension.h"
#include "xpu/plugin.h"
void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &stop_flags,
@@ -31,22 +31,25 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &end_ids,
const paddle::Tensor &next_tokens,
const bool beam_search) {
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
std::vector<int64_t> shape = topk_ids.shape();
int64_t bs_now = shape[0];
int64_t end_length = end_ids.shape()[0];
int r = baidu::xpu::api::plugin::set_stop_value_multi_ends<int64_t>(
xpu_ctx->x_context(), const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(next_tokens.data<int64_t>()),
end_ids.data<int64_t>(), seq_lens.data<int>(), bs_now, end_length,
beam_search);
PD_CHECK(r == 0, "xpu::plugin::set_stop_value_multi_ends failed.");
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
std::vector<int64_t> shape = topk_ids.shape();
int64_t bs_now = shape[0];
int64_t end_length = end_ids.shape()[0];
int r = baidu::xpu::api::plugin::set_stop_value_multi_ends<int64_t>(
xpu_ctx->x_context(),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(next_tokens.data<int64_t>()),
end_ids.data<int64_t>(),
seq_lens.data<int>(),
bs_now,
end_length,
beam_search);
PD_CHECK(r == 0, "xpu::plugin::set_stop_value_multi_ends failed.");
}
PD_BUILD_OP(set_stop_value_multi_ends)
@@ -17,53 +17,49 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"
void TextImageGatherScatter(
paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
void TextImageGatherScatter(paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
const int64_t token_num = input.dims()[0];
const int64_t hidden_size = input.dims()[1];
const int64_t text_token_num = text_input.dims()[0];
const int64_t image_token_num = image_input.dims()[0];
const int64_t token_num = input.dims()[0];
const int64_t hidden_size = input.dims()[1];
const int64_t text_token_num = text_input.dims()[0];
const int64_t image_token_num = image_input.dims()[0];
switch (input.type()) {
case paddle::DataType::BFLOAT16: {
using XPUType = typename XPUTypeTrait<bfloat16>::Type;
typedef paddle::bfloat16 data_t;
int r = baidu::xpu::api::plugin::text_image_gather_scatter<XPUType>(
xpu_ctx->x_context(),
reinterpret_cast<XPUType*>(input.data<data_t>()),
reinterpret_cast<XPUType*>(text_input.data<data_t>()),
reinterpret_cast<XPUType*>(image_input.data<data_t>()),
reinterpret_cast<int*>(token_type_ids.data<int>()),
reinterpret_cast<int*>(text_index.data<int>()),
reinterpret_cast<int*>(image_index.data<int>()),
token_num,
text_token_num,
image_token_num,
hidden_size,
is_scatter
);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_gather_scatter");
break;
}
default: {
PD_THROW(
"NOT supported data type. Only support BFLOAT16. ");
break;
}
switch (input.type()) {
case paddle::DataType::BFLOAT16: {
using XPUType = typename XPUTypeTrait<bfloat16>::Type;
typedef paddle::bfloat16 data_t;
int r = baidu::xpu::api::plugin::text_image_gather_scatter<XPUType>(
xpu_ctx->x_context(),
reinterpret_cast<XPUType*>(input.data<data_t>()),
reinterpret_cast<XPUType*>(text_input.data<data_t>()),
reinterpret_cast<XPUType*>(image_input.data<data_t>()),
reinterpret_cast<int*>(token_type_ids.data<int>()),
reinterpret_cast<int*>(text_index.data<int>()),
reinterpret_cast<int*>(image_index.data<int>()),
token_num,
text_token_num,
image_token_num,
hidden_size,
is_scatter);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_gather_scatter");
break;
}
default: {
PD_THROW("NOT supported data type. Only support BFLOAT16. ");
break;
}
}
}
PD_BUILD_OP(text_image_gather_scatter)
.Inputs({"input",
"text_input",
@@ -16,33 +16,30 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"
void TextImageIndexOut(
const paddle::Tensor& token_type_ids,
const paddle::Tensor& text_index,
const paddle::Tensor& image_index) {
if (token_type_ids.type() != paddle::DataType::INT32 || text_index.type()
!= paddle::DataType::INT32 || image_index.type() != paddle::DataType::INT32) {
PD_THROW("NOT supported data type. Only support BFLOAT16. ");
}
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
const int64_t token_num = token_type_ids.shape()[0];
int r = baidu::xpu::api::plugin::text_image_index_out(xpu_ctx->x_context(),
token_type_ids.data<int32_t>(),
const_cast<int32_t*>(text_index.data<int32_t>()),
const_cast<int32_t*>(image_index.data<int32_t>()),
token_num);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_index_out");
void TextImageIndexOut(const paddle::Tensor& token_type_ids,
const paddle::Tensor& text_index,
const paddle::Tensor& image_index) {
if (token_type_ids.type() != paddle::DataType::INT32 ||
text_index.type() != paddle::DataType::INT32 ||
image_index.type() != paddle::DataType::INT32) {
PD_THROW("NOT supported data type. Only support BFLOAT16. ");
}
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
const int64_t token_num = token_type_ids.shape()[0];
int r = baidu::xpu::api::plugin::text_image_index_out(
xpu_ctx->x_context(),
token_type_ids.data<int32_t>(),
const_cast<int32_t*>(text_index.data<int32_t>()),
const_cast<int32_t*>(image_index.data<int32_t>()),
token_num);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_index_out");
}
PD_BUILD_OP(text_image_index_out)
.Inputs({"token_type_ids",
"text_index",
"image_index"})
.Outputs({"text_index_out",
"image_index_out"})
.Inputs({"token_type_ids", "text_index", "image_index"})
.Outputs({"text_index_out", "image_index_out"})
.SetInplaceMap({{"text_index", "text_index_out"},
{"image_index", "image_index_out"}})
.SetKernelFn(PD_KERNEL(TextImageIndexOut));
+31 -32
View File
@@ -26,40 +26,39 @@ void UpdateInputes(const paddle::Tensor &stop_flags,
const paddle::Tensor &stop_nums,
const paddle::Tensor &next_tokens,
const paddle::Tensor &is_block_step) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
const int max_bsz = stop_flags.shape()[0];
PADDLE_ENFORCE_LE(
max_bsz,
1024,
phi::errors::InvalidArgument(
"Only support max_bs <= 1024, but received max_bs is %d", max_bsz));
const int now_bsz = seq_lens_this_time.shape()[0];
const int input_ids_stride = input_ids.shape()[1];
auto not_need_stop_xpu = not_need_stop.copy_to(stop_flags.place(), false);
const int max_bsz = stop_flags.shape()[0];
PADDLE_ENFORCE_LE(
max_bsz,
1024,
phi::errors::InvalidArgument(
"Only support max_bs <= 1024, but received max_bs is %d", max_bsz));
const int now_bsz = seq_lens_this_time.shape()[0];
const int input_ids_stride = input_ids.shape()[1];
auto not_need_stop_xpu = not_need_stop.copy_to(stop_flags.place(), false);
int r = baidu::xpu::api::plugin::update_inputs(
xpu_ctx->x_context(),
const_cast<bool *>(not_need_stop_xpu.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int64_t *>(input_ids.data<int64_t>()),
stop_nums.data<int64_t>(),
stop_flags.data<bool>(),
is_block_step.data<bool>(),
next_tokens.data<int64_t>(),
now_bsz,
max_bsz,
input_ids_stride);
PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs failed.");
auto not_need_stop_cpu =
not_need_stop_xpu.copy_to(not_need_stop.place(), false);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
int r = baidu::xpu::api::plugin::update_inputs(
xpu_ctx->x_context(),
const_cast<bool *>(not_need_stop_xpu.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int64_t *>(input_ids.data<int64_t>()),
stop_nums.data<int64_t>(),
stop_flags.data<bool>(),
is_block_step.data<bool>(),
next_tokens.data<int64_t>(),
now_bsz,
max_bsz,
input_ids_stride);
PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs failed.");
auto not_need_stop_cpu =
not_need_stop_xpu.copy_to(not_need_stop.place(), false);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_OP(update_inputs)
+47 -48
View File
@@ -18,55 +18,54 @@
#include "xpu/plugin.h"
void UpdateInputesV1(const paddle::Tensor &stop_flags,
const paddle::Tensor &not_need_stop, // only on cpu
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &topk_ids,
const paddle::Tensor &input_ids,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_nums,
const paddle::Tensor &next_tokens,
const paddle::Tensor &is_block_step,
const int block_size) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
const paddle::Tensor &not_need_stop, // only on cpu
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &topk_ids,
const paddle::Tensor &input_ids,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_nums,
const paddle::Tensor &next_tokens,
const paddle::Tensor &is_block_step,
const int block_size) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
const int max_bsz = stop_flags.shape()[0];
const int now_bsz = seq_lens_this_time.shape()[0];
// std::cout << "now_bsz: " << now_bsz << std::endl;
const int input_ids_stride = input_ids.shape()[1];
const int block_num_per_seq = block_tables.shape()[1];
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
int r = baidu::xpu::api::plugin::update_inputs_v1(
xpu_ctx->x_context(),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int64_t *>(prompt_lens.data<int64_t>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(input_ids.data<int64_t>()),
const_cast<int *>(block_tables.data<int>()),
stop_nums.data<int64_t>(),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<bool *>(is_block_step.data<bool>()),
next_tokens.data<int64_t>(),
now_bsz,
max_bsz,
input_ids_stride,
block_num_per_seq,
block_size);
PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs_kernel_v1 failed.");
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
const int max_bsz = stop_flags.shape()[0];
const int now_bsz = seq_lens_this_time.shape()[0];
// std::cout << "now_bsz: " << now_bsz << std::endl;
const int input_ids_stride = input_ids.shape()[1];
const int block_num_per_seq = block_tables.shape()[1];
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
int r = baidu::xpu::api::plugin::update_inputs_v1(
xpu_ctx->x_context(),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int64_t *>(prompt_lens.data<int64_t>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(input_ids.data<int64_t>()),
const_cast<int *>(block_tables.data<int>()),
stop_nums.data<int64_t>(),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<bool *>(is_block_step.data<bool>()),
next_tokens.data<int64_t>(),
now_bsz,
max_bsz,
input_ids_stride,
block_num_per_seq,
block_size);
PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs_kernel_v1 failed.");
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_OP(update_inputs_v1)
View File
+35 -28
View File
@@ -15,8 +15,6 @@
#pragma once
#include <fcntl.h>
#include <fstream>
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
@@ -24,47 +22,56 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <fstream>
#include <iostream>
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/core/dense_tensor.h"
#include "xpu/plugin.h"
#include <paddle/phi/backends/xpu/xpu_context.h>
template <paddle::DataType D> class PDTraits;
template <paddle::DataType D>
class PDTraits;
template <> class PDTraits<paddle::DataType::FLOAT32> {
public:
typedef float DataType;
typedef float data_t;
template <>
class PDTraits<paddle::DataType::FLOAT32> {
public:
typedef float DataType;
typedef float data_t;
};
template <> class PDTraits<paddle::DataType::FLOAT16> {
public:
typedef float16 DataType;
typedef paddle::float16 data_t;
template <>
class PDTraits<paddle::DataType::FLOAT16> {
public:
typedef float16 DataType;
typedef paddle::float16 data_t;
};
template <> class PDTraits<paddle::DataType::BFLOAT16> {
public:
typedef bfloat16 DataType;
typedef paddle::bfloat16 data_t;
template <>
class PDTraits<paddle::DataType::BFLOAT16> {
public:
typedef bfloat16 DataType;
typedef paddle::bfloat16 data_t;
};
template <> class PDTraits<paddle::DataType::INT8> {
public:
typedef int8_t DataType;
typedef int8_t data_t;
template <>
class PDTraits<paddle::DataType::INT8> {
public:
typedef int8_t DataType;
typedef int8_t data_t;
};
template <> class PDTraits<paddle::DataType::UINT8> {
public:
typedef uint8_t DataType;
typedef uint8_t data_t;
template <>
class PDTraits<paddle::DataType::UINT8> {
public:
typedef uint8_t DataType;
typedef uint8_t data_t;
};
template <> class PDTraits<paddle::DataType::INT64> {
public:
typedef int64_t DataType;
typedef int64_t data_t;
template <>
class PDTraits<paddle::DataType::INT64> {
public:
typedef int64_t DataType;
typedef int64_t data_t;
};
@@ -11,110 +11,124 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include <infer_ops.h>
#include <infer_ops_eb.h>
#include <paddle/extension.h>
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "xpu/plugin.h"
template <typename T>
std::vector<paddle::Tensor>
WeightQuantizeKernel(const paddle::Tensor &x, const std::string &algo,
const int32_t arch, const int32_t group_size) {
using XPUType = typename XPUTypeTrait<T>::Type;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
int64_t k = x.shape()[0];
int64_t n = x.shape()[1];
std::vector<paddle::Tensor> WeightQuantizeKernel(const paddle::Tensor &x,
const std::string &algo,
const int32_t arch,
const int32_t group_size) {
using XPUType = typename XPUTypeTrait<T>::Type;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
int64_t k = x.shape()[0];
int64_t n = x.shape()[1];
paddle::Tensor scale =
paddle::full({n}, 0, paddle::DataType::FLOAT32, x.place());
if (algo == "weight_only_int8") {
paddle::Tensor out =
paddle::full({k, n}, 0, paddle::DataType::INT8, x.place());
int ret = baidu::xpu::api::plugin::quant2d_per_channel<XPUType, float,
int8_t>(
paddle::Tensor scale =
paddle::full({n}, 0, paddle::DataType::FLOAT32, x.place());
if (algo == "weight_only_int8") {
paddle::Tensor out =
paddle::full({k, n}, 0, paddle::DataType::INT8, x.place());
int ret =
baidu::xpu::api::plugin::quant2d_per_channel<XPUType, float, int8_t>(
xpu_ctx->x_context(),
reinterpret_cast<const XPUType *>(x.template data<T>()), nullptr,
out.data<int8_t>(), scale.data<float>(), k, n);
PD_CHECK(ret == 0);
return {out, scale};
} else if (algo == "weight_only_int4") {
// TODO(mayang02): fix quant2d_per_channel int4 bugs, use transpose +
// quant2d_per_token + transpose at now
PD_CHECK(k % 2 == 0);
paddle::Tensor out = paddle::full({(k + 1) / 2, n}, 0,
paddle::DataType::INT8, x.place());
xpu::ctx_guard RAII_GUARD(xpu_ctx->x_context());
XPUType *x_trans = RAII_GUARD.alloc<XPUType>(k * n);
int8_t *out_trans = RAII_GUARD.alloc<int8_t>(k * n / 2);
PD_CHECK(x_trans != nullptr);
PD_CHECK(out_trans != nullptr);
int ret = baidu::xpu::api::transpose<XPUType>(
xpu_ctx->x_context(),
reinterpret_cast<const XPUType *>(x.data<T>()), x_trans, {k, n},
{1, 0});
PD_CHECK(ret == 0);
ret = infer_ops::quant2d_per_token<XPUType, float, int4_t>(
xpu_ctx->x_context(), x_trans, nullptr,
reinterpret_cast<int4_t *>(out_trans), scale.data<float>(), n, k);
PD_CHECK(ret == 0);
ret = baidu::xpu::api::transpose<int8_t>(xpu_ctx->x_context(),
out_trans, out.data<int8_t>(),
{n, k / 2}, {1, 0});
PD_CHECK(ret == 0);
return {out, scale};
} else {
PD_THROW("Weight quantize only supports weight_only_int8 on XPU now.");
return {};
}
reinterpret_cast<const XPUType *>(x.template data<T>()),
nullptr,
out.data<int8_t>(),
scale.data<float>(),
k,
n);
PD_CHECK(ret == 0);
return {out, scale};
} else if (algo == "weight_only_int4") {
// TODO(mayang02): fix quant2d_per_channel int4 bugs, use transpose +
// quant2d_per_token + transpose at now
PD_CHECK(k % 2 == 0);
paddle::Tensor out =
paddle::full({(k + 1) / 2, n}, 0, paddle::DataType::INT8, x.place());
xpu::ctx_guard RAII_GUARD(xpu_ctx->x_context());
XPUType *x_trans = RAII_GUARD.alloc<XPUType>(k * n);
int8_t *out_trans = RAII_GUARD.alloc<int8_t>(k * n / 2);
PD_CHECK(x_trans != nullptr);
PD_CHECK(out_trans != nullptr);
int ret = baidu::xpu::api::transpose<XPUType>(
xpu_ctx->x_context(),
reinterpret_cast<const XPUType *>(x.data<T>()),
x_trans,
{k, n},
{1, 0});
PD_CHECK(ret == 0);
ret = infer_ops::quant2d_per_token<XPUType, float, int4_t>(
xpu_ctx->x_context(),
x_trans,
nullptr,
reinterpret_cast<int4_t *>(out_trans),
scale.data<float>(),
n,
k);
PD_CHECK(ret == 0);
ret = baidu::xpu::api::transpose<int8_t>(xpu_ctx->x_context(),
out_trans,
out.data<int8_t>(),
{n, k / 2},
{1, 0});
PD_CHECK(ret == 0);
return {out, scale};
} else {
PD_THROW("Weight quantize only supports weight_only_int8 on XPU now.");
return {};
}
}
std::vector<paddle::Tensor> WeightQuantize(const paddle::Tensor &x,
const std::string &algo,
const int32_t arch,
const int32_t group_size) {
const auto x_type = x.dtype();
#define APPLY_WEIGHT_QUANTIZE_KERNEL(TX) \
return WeightQuantizeKernel<TX>(x, algo, arch, group_size);
const auto x_type = x.dtype();
#define APPLY_WEIGHT_QUANTIZE_KERNEL(TX) \
return WeightQuantizeKernel<TX>(x, algo, arch, group_size);
if (x_type == paddle::DataType::BFLOAT16) {
APPLY_WEIGHT_QUANTIZE_KERNEL(paddle::bfloat16);
} else if (x_type == paddle::DataType::FLOAT32) {
APPLY_WEIGHT_QUANTIZE_KERNEL(float);
} else {
PD_THROW("WeightQuantize not support x_type==%d",
static_cast<int>(x_type));
return {};
}
if (x_type == paddle::DataType::BFLOAT16) {
APPLY_WEIGHT_QUANTIZE_KERNEL(paddle::bfloat16);
} else if (x_type == paddle::DataType::FLOAT32) {
APPLY_WEIGHT_QUANTIZE_KERNEL(float);
} else {
PD_THROW("WeightQuantize not support x_type==%d", static_cast<int>(x_type));
return {};
}
}
std::vector<std::vector<int64_t>>
WeightQuantizeInferShape(const std::vector<int64_t> &x_shape,
const std::string &algo, const int32_t arch,
const int32_t group_size) {
if (algo == "weight_only_int8") {
return {x_shape, {x_shape[1]}};
} else if (algo == "weight_only_int4") {
return {{x_shape[0] / 2, x_shape[1]}, {x_shape[1]}};
} else {
PD_THROW("weight_quantize not support algo=%s", algo);
}
std::vector<std::vector<int64_t>> WeightQuantizeInferShape(
const std::vector<int64_t> &x_shape,
const std::string &algo,
const int32_t arch,
const int32_t group_size) {
if (algo == "weight_only_int8") {
return {x_shape, {x_shape[1]}};
} else if (algo == "weight_only_int4") {
return {{x_shape[0] / 2, x_shape[1]}, {x_shape[1]}};
} else {
PD_THROW("weight_quantize not support algo=%s", algo);
}
}
std::vector<paddle::DataType>
WeightQuantizeInferDtype(const paddle::DataType &x_dtype,
const std::string &algo, const int32_t arch,
const int32_t group_size) {
if (algo == "weight_only_int8") {
return {paddle::DataType::INT8, paddle::DataType::FLOAT32};
} else if (algo == "weight_only_int4") {
return {paddle::DataType::INT8, paddle::DataType::FLOAT32};
} else {
PD_THROW("weight_quantize not support algo=%s", algo);
}
std::vector<paddle::DataType> WeightQuantizeInferDtype(
const paddle::DataType &x_dtype,
const std::string &algo,
const int32_t arch,
const int32_t group_size) {
if (algo == "weight_only_int8") {
return {paddle::DataType::INT8, paddle::DataType::FLOAT32};
} else if (algo == "weight_only_int4") {
return {paddle::DataType::INT8, paddle::DataType::FLOAT32};
} else {
PD_THROW("weight_quantize not support algo=%s", algo);
}
}
PD_BUILD_OP(weight_quantize_xpu)
+31 -31
View File
@@ -24,60 +24,60 @@
#include <sys/un.h>
#include <sys/wait.h>
#include <unistd.h>
#include <vector>
#include <xpu/runtime.h>
#include <xpu/version.h>
#include <vector>
struct shmStruct {
size_t nprocesses;
size_t nprocesses;
#if XPURT_VERSION_MAJOR == 5
XPUIpcMemHandle memHandle;
XPUIpcMemHandle memHandle;
#endif
uint64_t data_ptr_addr;
uint64_t data_ptr_addr;
};
struct sharedMemoryInfo {
void *addr;
size_t size;
int shmFd;
void *addr;
size_t size;
int shmFd;
};
static int sharedMemoryCreate(const char *name, size_t sz,
static int sharedMemoryCreate(const char *name,
size_t sz,
sharedMemoryInfo *info) {
info->size = sz;
info->size = sz;
info->shmFd = shm_open(name, O_RDWR | O_CREAT, 0777);
PD_CHECK(info->shmFd >= 0, "shm_open failed");
info->shmFd = shm_open(name, O_RDWR | O_CREAT, 0777);
PD_CHECK(info->shmFd >= 0, "shm_open failed");
int status = ftruncate(info->shmFd, sz);
PD_CHECK(status == 0, "ftruncate failed");
int status = ftruncate(info->shmFd, sz);
PD_CHECK(status == 0, "ftruncate failed");
info->addr =
mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0);
PD_CHECK(info->addr != NULL, "mmap failed");
info->addr = mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0);
PD_CHECK(info->addr != NULL, "mmap failed");
return 0;
return 0;
}
static int sharedMemoryOpen(const char *name, size_t sz,
static int sharedMemoryOpen(const char *name,
size_t sz,
sharedMemoryInfo *info) {
info->size = sz;
info->size = sz;
info->shmFd = shm_open(name, O_RDWR, 0777);
PD_CHECK(info->shmFd >= 0, "shm_open failed");
info->shmFd = shm_open(name, O_RDWR, 0777);
PD_CHECK(info->shmFd >= 0, "shm_open failed");
info->addr =
mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0);
PD_CHECK(info->addr != nullptr, "mmap failed");
info->addr = mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0);
PD_CHECK(info->addr != nullptr, "mmap failed");
return 0;
return 0;
}
static void sharedMemoryClose(sharedMemoryInfo *info) {
if (info->addr) {
munmap(info->addr, info->size);
}
if (info->shmFd) {
close(info->shmFd);
}
if (info->addr) {
munmap(info->addr, info->size);
}
if (info->shmFd) {
close(info->shmFd);
}
}
+159 -103
View File
@@ -24,121 +24,176 @@ namespace api {
namespace plugin {
template <typename T>
DLL_EXPORT int set_stop_value_multi_ends(Context *ctx, bool *stop_flags,
T *topk_ids, T *next_tokens,
const T *end_ids, const int *seq_lens,
const int bs, const int end_length,
DLL_EXPORT int set_stop_value_multi_ends(Context* ctx,
bool* stop_flags,
T* topk_ids,
T* next_tokens,
const T* end_ids,
const int* seq_lens,
const int bs,
const int end_length,
const bool beam_search);
DLL_EXPORT int set_value_by_flags_and_idx(Context *ctx, const bool *stop_flags,
int64_t *pre_ids_all,
const int64_t *input_ids,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *step_idx, int bs,
int length, int length_input_ids);
DLL_EXPORT int set_value_by_flags_and_idx(Context* ctx,
const bool* stop_flags,
int64_t* pre_ids_all,
const int64_t* input_ids,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int64_t* step_idx,
int bs,
int length,
int length_input_ids);
template <typename T>
DLL_EXPORT int token_penalty_multi_scores(
Context *ctx, const int64_t *pre_ids, T *logits, const T *penalty_scores,
const T *frequency_scores, const T *presence_scores,
const float *temperatures, const int64_t *cur_len, const int64_t *min_len,
const int64_t *eos_token_id, const int64_t *bad_words, const int64_t bs,
const int64_t length, const int64_t length_id, const int64_t end_length,
const int64_t length_bad_words);
DLL_EXPORT int token_penalty_multi_scores(Context* ctx,
const int64_t* pre_ids,
T* logits,
const T* penalty_scores,
const T* frequency_scores,
const T* presence_scores,
const float* temperatures,
const int64_t* cur_len,
const int64_t* min_len,
const int64_t* eos_token_id,
const int64_t* bad_words,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t end_length,
const int64_t length_bad_words);
DLL_EXPORT int get_padding_offset(Context *ctx, int *padding_offset,
int *cum_offsets_out, int *cu_seqlens_q,
int *cu_seqlens_k, int64_t *x_remove_padding,
const int64_t *input_ids,
const int *cum_offsets, const int *seq_lens,
const int max_seq_len, const int bs);
DLL_EXPORT int get_padding_offset(Context* ctx,
int* padding_offset,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
int64_t* x_remove_padding,
const int64_t* input_ids,
const int* cum_offsets,
const int* seq_lens,
const int max_seq_len,
const int bs);
DLL_EXPORT int update_inputs(Context *ctx, bool *not_need_stop,
int *seq_lens_this_time, int *seq_lens_encoder,
int *seq_lens_decoder, int64_t *input_ids,
const int64_t *stop_nums, const bool *stop_flags,
const bool *is_block_step,
const int64_t *next_tokens, const int bsz,
const int max_bsz, const int input_ids_stride);
DLL_EXPORT int update_inputs(Context* ctx,
bool* not_need_stop,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* input_ids,
const int64_t* stop_nums,
const bool* stop_flags,
const bool* is_block_step,
const int64_t* next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride);
DLL_EXPORT int free_and_dispatch_block(
Context *ctx, bool *stop_flags, int *seq_lens_this_time,
int *seq_lens_decoder, int *block_tables, int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len, int *recover_block_list, int *recover_len,
int *need_block_list, int *need_block_len, int *used_list_len,
int *free_list, int *free_list_len, int64_t *first_token_ids, const int bsz,
const int block_size, const int block_num_per_seq,
const int max_decoder_block_num);
DLL_EXPORT int free_and_dispatch_block(Context* ctx,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_decoder,
int* block_tables,
int* encoder_block_lens,
bool* is_block_step,
int* step_block_list, // [bsz]
int* step_len,
int* recover_block_list,
int* recover_len,
int* need_block_list,
int* need_block_len,
int* used_list_len,
int* free_list,
int* free_list_len,
int64_t* first_token_ids,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num);
DLL_EXPORT int
recover_block(Context *ctx,
int *recover_block_list, // [bsz]
int *recover_len, bool *stop_flags, int *seq_lens_this_time,
const int *ori_seq_lens_encoder, int *seq_lens_encoder,
const int *seq_lens_decoder, int *block_tables, int *free_list,
int *free_list_len, int64_t *input_ids, const int64_t *pre_ids,
const int64_t *step_idx, const int *encoder_block_lens,
const int *used_list_len, const int64_t *next_tokens,
const int64_t *first_token_ids, const int bsz,
const int block_num_per_seq, const int length,
const int pre_id_length);
DLL_EXPORT int recover_block(Context* ctx,
int* recover_block_list, // [bsz]
int* recover_len,
bool* stop_flags,
int* seq_lens_this_time,
const int* ori_seq_lens_encoder,
int* seq_lens_encoder,
const int* seq_lens_decoder,
int* block_tables,
int* free_list,
int* free_list_len,
int64_t* input_ids,
const int64_t* pre_ids,
const int64_t* step_idx,
const int* encoder_block_lens,
const int* used_list_len,
const int64_t* next_tokens,
const int64_t* first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length);
DLL_EXPORT int
recover_decode_task(Context *ctx, bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
DLL_EXPORT int recover_decode_task(Context* ctx,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int* step_seq_lens_decoder,
int* block_tables,
bool* is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size);
DLL_EXPORT int
update_inputs_v1(Context *ctx, bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size);
DLL_EXPORT int update_inputs_v1(Context* ctx,
bool* not_need_stop,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int* step_seq_lens_decoder,
int64_t* prompt_lens,
int64_t* topk_ids,
int64_t* input_ids,
int* block_tables,
const int64_t* stop_nums,
bool* stop_flags,
bool* is_block_step,
const int64_t* next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size);
template <typename TX, typename TY>
DLL_EXPORT int
eb_adjust_batch(Context *ctx, const TX *x, TY *y,
VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
VectorParam<int32_t> &encoder_batch_map, // NOLINT
VectorParam<int32_t> &decoder_batch_map, // NOLINT
int64_t hidden_dim);
DLL_EXPORT int eb_adjust_batch(
Context* ctx,
const TX* x,
TY* y,
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
VectorParam<int32_t>& encoder_batch_map, // NOLINT
VectorParam<int32_t>& decoder_batch_map, // NOLINT
int64_t hidden_dim);
template <typename TX, typename TY>
DLL_EXPORT int
eb_gather_next_token(Context *ctx, const TX *x, TY *y,
VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
VectorParam<int32_t> &encoder_batch_map, // NOLINT
VectorParam<int32_t> &decoder_batch_map, // NOLINT
int64_t hidden_dim);
DLL_EXPORT int eb_gather_next_token(
Context* ctx,
const TX* x,
TY* y,
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
VectorParam<int32_t>& encoder_batch_map, // NOLINT
VectorParam<int32_t>& decoder_batch_map, // NOLINT
int64_t hidden_dim);
template <typename TX, typename TSCALE = float, typename TY = int8_t>
DLL_EXPORT int quant2d_per_channel(api::Context *ctx, const TX *x,
const TSCALE *scale_in, TY *y,
TSCALE *scale_out, int64_t m, int64_t n);
DLL_EXPORT int quant2d_per_channel(api::Context* ctx,
const TX* x,
const TSCALE* scale_in,
TY* y,
TSCALE* scale_out,
int64_t m,
int64_t n);
DLL_EXPORT int text_image_index_out(Context* ctx,
const int* token_type_ids, // x
@@ -160,7 +215,8 @@ DLL_EXPORT int text_image_gather_scatter(api::Context* ctx,
int64_t hidden_size,
bool is_scatter);
/*--------------------------------------- MTP being --------------------------------------------*/
/*--------------------------------------- MTP being
* --------------------------------------------*/
template <typename T>
DLL_EXPORT int speculate_token_penalty_multi_scores(
@@ -200,7 +256,6 @@ DLL_EXPORT int mtp_free_and_dispatch_block(Context* ctx,
const int block_num_per_seq,
const int max_draft_tokens);
template <bool ENABLE_TOPP, bool USE_TOPK>
DLL_EXPORT int speculate_verify(Context* ctx,
int64_t* accept_tokens,
@@ -457,9 +512,10 @@ DLL_EXPORT int rebuild_self_hidden_states(api::Context* ctx,
T* output,
int dim_embed,
int elem_cnt);
/*--------------------------------------- MTP end --------------------------------------------*/
/*--------------------------------------- MTP end
* --------------------------------------------*/
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
@@ -36,8 +36,8 @@ __global__ void get_padding_offset(int *batch_id_per_token,
}
mfence_lm();
LM2GM(batch_id_per_token_lm,
batch_id_per_token + i * max_seq_len - cum_offsets_lm[0] + j,
cur_len * sizeof(int));
batch_id_per_token + i * max_seq_len - cum_offsets_lm[0] + j,
cur_len * sizeof(int));
}
if (cid == 0) {
int cum_seq_len = (i + 1) * max_seq_len - cum_offsets_lm[1];
@@ -15,72 +15,72 @@ __global__ void RebuildAppendPaddingKernel(const T *full_hidden_states,
int dim_embed,
int elem_nums,
T *out) {
int ncores = core_num();
int cid = core_id();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
int64_t mstart = -1;
int64_t mend = -1;
int64_t nstart = -1;
int64_t nend = -1;
partition2d<int64_t>(tid,
nthreads,
elem_nums / dim_embed,
dim_embed,
&mstart,
&mend,
&nstart,
&nend);
int ncores = core_num();
int cid = core_id();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
int64_t mstart = -1;
int64_t mend = -1;
int64_t nstart = -1;
int64_t nend = -1;
partition2d<int64_t>(tid,
nthreads,
elem_nums / dim_embed,
dim_embed,
&mstart,
&mend,
&nstart,
&nend);
const int64_t BUFFER_LEN = rounddown(6144 / sizeof(T), 64);
__simd__ T lm_full_hidden_states[BUFFER_LEN];
int output_padding_offset_val, cum_offset_val, seq_len_encoder_val,
seq_len_decoder_val;
const int64_t BUFFER_LEN = rounddown(6144 / sizeof(T), 64);
__simd__ T lm_full_hidden_states[BUFFER_LEN];
int output_padding_offset_val, cum_offset_val, seq_len_encoder_val,
seq_len_decoder_val;
for (int64_t _m = mstart; _m < mend; _m++) {
int out_token_id = _m;
GM2LM(output_padding_offset + out_token_id,
&output_padding_offset_val,
sizeof(int));
int ori_token_id = out_token_id + output_padding_offset_val;
int bi = ori_token_id / max_seq_len;
GM2LM_ASYNC(seq_len_encoder + bi, &seq_len_encoder_val, sizeof(int));
GM2LM(seq_len_decoder + bi, &seq_len_decoder_val, sizeof(int));
int seq_id = 0;
if (seq_len_encoder_val == 0 and seq_len_decoder_val == 0) {
continue;
} else if (seq_len_encoder_val != 0) {
seq_id = seq_len_encoder_val - 1;
}
GM2LM(cum_offset + bi, &cum_offset_val, sizeof(int));
int input_token_id = ori_token_id - cum_offset_val + seq_id;
for (int64_t _n = nstart; _n < nend; _n += BUFFER_LEN) {
int64_t read_size = min(BUFFER_LEN, nend - _n);
// out[i] = full_hidden_states[(i / dim_embed +
// output_padding_offset[i / dim_embed] - cum_offset[(i / dim_embed
// + output_padding_offset[i / dim_embed]) / max_seq_len] + seq_id)
// * dim_embed + i % dim_embed]
GM2LM(full_hidden_states + input_token_id * dim_embed + _n,
lm_full_hidden_states,
read_size * sizeof(T));
LM2GM(lm_full_hidden_states,
out + _m * dim_embed + _n,
read_size * sizeof(T));
}
for (int64_t _m = mstart; _m < mend; _m++) {
int out_token_id = _m;
GM2LM(output_padding_offset + out_token_id,
&output_padding_offset_val,
sizeof(int));
int ori_token_id = out_token_id + output_padding_offset_val;
int bi = ori_token_id / max_seq_len;
GM2LM_ASYNC(seq_len_encoder + bi, &seq_len_encoder_val, sizeof(int));
GM2LM(seq_len_decoder + bi, &seq_len_decoder_val, sizeof(int));
int seq_id = 0;
if (seq_len_encoder_val == 0 and seq_len_decoder_val == 0) {
continue;
} else if (seq_len_encoder_val != 0) {
seq_id = seq_len_encoder_val - 1;
}
GM2LM(cum_offset + bi, &cum_offset_val, sizeof(int));
int input_token_id = ori_token_id - cum_offset_val + seq_id;
for (int64_t _n = nstart; _n < nend; _n += BUFFER_LEN) {
int64_t read_size = min(BUFFER_LEN, nend - _n);
// out[i] = full_hidden_states[(i / dim_embed +
// output_padding_offset[i / dim_embed] - cum_offset[(i / dim_embed
// + output_padding_offset[i / dim_embed]) / max_seq_len] + seq_id)
// * dim_embed + i % dim_embed]
GM2LM(full_hidden_states + input_token_id * dim_embed + _n,
lm_full_hidden_states,
read_size * sizeof(T));
LM2GM(lm_full_hidden_states,
out + _m * dim_embed + _n,
read_size * sizeof(T));
}
}
}
#define _XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(T) \
template __global__ void RebuildAppendPaddingKernel<T>( \
const T *full_hidden_states, \
const int *cum_offset, \
const int *seq_len_encoder, \
const int *seq_len_decoder, \
const int *output_padding_offset, \
int max_seq_len, \
int dim_embed, \
int elem_nums, \
T *out);
#define _XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(T) \
template __global__ void RebuildAppendPaddingKernel<T>( \
const T *full_hidden_states, \
const int *cum_offset, \
const int *seq_len_encoder, \
const int *seq_len_decoder, \
const int *output_padding_offset, \
int max_seq_len, \
int dim_embed, \
int elem_nums, \
T *out);
_XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(bfloat16);
_XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(float16);
@@ -152,8 +152,8 @@ __device__ void speculate_update_repeat_times_optimized(
repeat_times_read_size_per_core * sizeof(int));
}
sync_all();
// each core loads pre_ids step by step and record the index of pre_ids
// which is less than zero, and store the index to boundary
// each core loads pre_ids step by step and record the index of
// pre_ids which is less than zero, and store the index to boundary
if (repeat_times_start == 0) {
bool do_prone = false;
int64_t j = cid * pre_ids_lm_len;
@@ -190,8 +190,8 @@ __device__ void speculate_update_repeat_times_optimized(
buffer_ptr_pre_ids.toggle();
}
}
// each core loads all the needed pre_ids into lm without mfence in between
// according to the index recorded by previous iteration
// each core loads all the needed pre_ids into lm without mfence in
// between according to the index recorded by previous iteration
else {
int cnt = -1;
int64_t pre_ids_read_size = 0;
@@ -240,18 +240,20 @@ __global__ void speculate_update_value_by_repeat_times_simd(
alpha,
logits_,
logits_,
(time_mask &
~logit_mask)); // when time != 0 && logit < 0, do alpha * logit
(time_mask & ~logit_mask)); // when time != 0 && logit < 0, do
// alpha * logit
logits_ = svmul_float32x16_mh(
1.0f / alpha,
logits_,
logits_,
(time_mask & logit_mask)); // when time != 0 && >=0, do logit / alpha
logits_ = vvsub_float32x16_mh(
logits_, time_, logits_, time_mask); // when time != 0, do logit =
// logit - time * beta - gamma;
logits_ =
svmul_float32x16(1.0f / temperature, logits_); // logit / temperature
logits_ = vvsub_float32x16_mh(logits_,
time_,
logits_,
time_mask); // when time != 0, do logit =
// logit - time * beta - gamma;
logits_ = svmul_float32x16(1.0f / temperature,
logits_); // logit / temperature
vstore_lm_float32x16(logits_lm + j, logits_);
}
mfence_lm();
@@ -6,15 +6,15 @@ namespace xpu3 {
namespace plugin {
__global__ void recover_decode_task(bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size) {
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
@@ -23,15 +23,17 @@ __global__ void recover_decode_task(bool *stop_flags,
int nthreads = nclusters * ncores;
// if (clusterid != 0) return;
for (; thread_idx < bsz; thread_idx += nthreads) {
if(is_block_step[thread_idx] == true) {
// int *block_table_now = block_tables + thread_idx * block_num_per_seq;
if (block_tables[thread_idx * block_num_per_seq + step_seq_lens_decoder[thread_idx] / block_size] != -1) {
// can be recovered for decoding
is_block_step[thread_idx] = false;
seq_lens_this_time[thread_idx]= 1;
stop_flags[thread_idx] = false;
seq_lens_encoder[thread_idx] = 0;
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
if (is_block_step[thread_idx] == true) {
// int *block_table_now = block_tables + thread_idx *
// block_num_per_seq;
if (block_tables[thread_idx * block_num_per_seq +
step_seq_lens_decoder[thread_idx] / block_size] != -1) {
// can be recovered for decoding
is_block_step[thread_idx] = false;
seq_lens_this_time[thread_idx] = 1;
stop_flags[thread_idx] = false;
seq_lens_encoder[thread_idx] = 0;
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
}
}
}
@@ -30,8 +30,8 @@ __global__ void remove_padding(int64_t *x_remove_padding,
input_lm,
sizeof(int64_t) * cur_len);
LM2GM(input_lm,
x_remove_padding + i * sequence_length - cum_offset_lm + j,
sizeof(int64_t) * cur_len);
x_remove_padding + i * sequence_length - cum_offset_lm + j,
sizeof(int64_t) * cur_len);
}
}
}
@@ -54,14 +54,13 @@ __global__ void set_stop_value_multi_ends(bool* stop_flags,
GM2LM_ASYNC(seq_lens + i, seq_lens_lm, sizeof(int) * readlen);
mfence();
for (int j = 0; j < readlen; j++) {
if(prefill_one_step_stop){
if (prefill_one_step_stop) {
stop_flags_lm[j] = true;
if (seq_lens_lm[j] == 0) {
topk_ids_lm[j] = -1;
}
next_tokens_lm[j] = topk_ids_lm[j];
}
else{
} else {
if (stop_flags_lm[j]) {
if (seq_lens_lm[j] == 0) {
topk_ids_lm[j] = -1;
@@ -8,166 +8,206 @@ namespace plugin {
template <typename T>
static __device__ inline void text_image_gather(
__global_ptr__ T* input,
__global_ptr__ T* text_input,
__global_ptr__ T* image_input,
__global_ptr__ int* token_type_ids,
__global_ptr__ int* text_index,
__global_ptr__ int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
T* input_lm) {
int cid = core_id();
int clusterid = cluster_id();
int token_start_cluster;
int token_end_cluster;
int token_start_core;
int token_end_core;
__global_ptr__ T* input,
__global_ptr__ T* text_input,
__global_ptr__ T* image_input,
__global_ptr__ int* token_type_ids,
__global_ptr__ int* text_index,
__global_ptr__ int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
T* input_lm) {
int cid = core_id();
int clusterid = cluster_id();
int token_start_cluster;
int token_end_cluster;
int token_start_core;
int token_end_core;
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
// cluster partition
partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster);
if (token_start_cluster >= token_end_cluster) {
return;
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
// cluster partition
partition(cluster_id(),
cluster_num(),
(int)token_num,
1,
&token_start_cluster,
&token_end_cluster);
if (token_start_cluster >= token_end_cluster) {
return;
}
int rows_cluster =
token_end_cluster - token_start_cluster; // total rows for a cluster
// core partition
partition(core_id(),
core_num(),
rows_cluster,
1,
&token_start_core,
&token_end_core);
int rows_core = token_end_core - token_start_core; // total rows for a core
token_start_core += token_start_cluster;
token_end_core += token_start_cluster;
int read_len;
for (int i = token_start_core; i < token_end_core; i += 1) {
int token_type, text_image_token_idx;
__global_ptr__ T* text_image_input = nullptr;
__global_ptr__ int* text_image_index = nullptr;
GM2LM(token_type_ids + i, &token_type, sizeof(int));
if (token_type == 0) {
text_image_input = text_input;
text_image_index = text_index;
} else {
text_image_input = image_input;
text_image_index = image_index;
}
int rows_cluster = token_end_cluster - token_start_cluster; // total rows for a cluster
// core partition
partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core);
int rows_core = token_end_core - token_start_core; // total rows for a core
token_start_core += token_start_cluster;
token_end_core += token_start_cluster;
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
int input_offset = i * hidden_size;
int text_image_offset = text_image_token_idx * hidden_size;
int read_len;
for (int i = token_start_core; i < token_end_core; i += 1) {
int token_type, text_image_token_idx;
__global_ptr__ T* text_image_input = nullptr;
__global_ptr__ int* text_image_index = nullptr;
GM2LM(token_type_ids + i, &token_type, sizeof(int));
if (token_type == 0) {
text_image_input = text_input;
text_image_index = text_index;
} else {
text_image_input = image_input;
text_image_index = image_index;
}
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
int input_offset = i * hidden_size;
int text_image_offset = text_image_token_idx * hidden_size;
for (int j = 0; j < hidden_size; j += BUFSIZE) {
read_len = min(hidden_size - j, BUFSIZE);
GM2LM(text_image_input + text_image_offset + j, input_lm, sizeof(T) * read_len);
LM2GM(input_lm, input + input_offset + j, sizeof(T) * read_len);
}
for (int j = 0; j < hidden_size; j += BUFSIZE) {
read_len = min(hidden_size - j, BUFSIZE);
GM2LM(text_image_input + text_image_offset + j,
input_lm,
sizeof(T) * read_len);
LM2GM(input_lm, input + input_offset + j, sizeof(T) * read_len);
}
}
}
template <typename T>
static __device__ inline void text_image_scatter(
__global_ptr__ T* input,
__global_ptr__ T* text_input,
__global_ptr__ T* image_input,
__global_ptr__ int* token_type_ids,
__global_ptr__ int* text_index,
__global_ptr__ int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
T* input_lm) {
int cid = core_id();
int clusterid = cluster_id();
int token_start_cluster;
int token_end_cluster;
int token_start_core;
int token_end_core;
__global_ptr__ T* input,
__global_ptr__ T* text_input,
__global_ptr__ T* image_input,
__global_ptr__ int* token_type_ids,
__global_ptr__ int* text_index,
__global_ptr__ int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
T* input_lm) {
int cid = core_id();
int clusterid = cluster_id();
int token_start_cluster;
int token_end_cluster;
int token_start_core;
int token_end_core;
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
// cluster partition
partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster);
if (token_start_cluster >= token_end_cluster) {
return;
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
// cluster partition
partition(cluster_id(),
cluster_num(),
(int)token_num,
1,
&token_start_cluster,
&token_end_cluster);
if (token_start_cluster >= token_end_cluster) {
return;
}
int rows_cluster =
token_end_cluster - token_start_cluster; // total rows for a cluster
// core partition
partition(core_id(),
core_num(),
rows_cluster,
1,
&token_start_core,
&token_end_core);
int rows_core = token_end_core - token_start_core; // total rows for a core
token_start_core += token_start_cluster;
token_end_core += token_start_cluster;
int read_len;
for (int i = token_start_core; i < token_end_core; i += 1) {
int token_type, text_image_token_idx;
__global_ptr__ T* text_image_input = nullptr;
__global_ptr__ int* text_image_index = nullptr;
GM2LM(token_type_ids + i, &token_type, sizeof(int));
if (token_type == 0) {
text_image_input = text_input;
text_image_index = text_index;
} else {
text_image_input = image_input;
text_image_index = image_index;
}
int rows_cluster = token_end_cluster - token_start_cluster; // total rows for a cluster
// core partition
partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core);
int rows_core = token_end_core - token_start_core; // total rows for a core
token_start_core += token_start_cluster;
token_end_core += token_start_cluster;
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
int input_offset = i * hidden_size;
int text_image_offset = text_image_token_idx * hidden_size;
int read_len;
for (int i = token_start_core; i < token_end_core; i += 1) {
int token_type, text_image_token_idx;
__global_ptr__ T* text_image_input = nullptr;
__global_ptr__ int* text_image_index = nullptr;
GM2LM(token_type_ids + i, &token_type, sizeof(int));
if (token_type == 0) {
text_image_input = text_input;
text_image_index = text_index;
} else {
text_image_input = image_input;
text_image_index = image_index;
}
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
int input_offset = i * hidden_size;
int text_image_offset = text_image_token_idx * hidden_size;
for (int j = 0; j < hidden_size; j += BUFSIZE) {
read_len = min(hidden_size - j, BUFSIZE);
GM2LM(input + input_offset + j, input_lm, sizeof(T) * read_len);
LM2GM(input_lm, text_image_input + text_image_offset + j, sizeof(T) * read_len);
}
for (int j = 0; j < hidden_size; j += BUFSIZE) {
read_len = min(hidden_size - j, BUFSIZE);
GM2LM(input + input_offset + j, input_lm, sizeof(T) * read_len);
LM2GM(input_lm,
text_image_input + text_image_offset + j,
sizeof(T) * read_len);
}
}
}
template <typename T>
__global__ void text_image_gather_scatter(
T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
int nclusters = cluster_num();
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
__simd__ T input_lm[BUFSIZE]; // 2KB for bf16 and fp32
if (is_scatter) {
text_image_scatter(
input, text_input, image_input, token_type_ids, text_index, image_index,
token_num, text_token_num, image_token_num, hidden_size, input_lm);
} else {
text_image_gather(
input, text_input, image_input, token_type_ids, text_index, image_index,
token_num, text_token_num, image_token_num, hidden_size, input_lm);
}
__global__ void text_image_gather_scatter(T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
int nclusters = cluster_num();
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
__simd__ T input_lm[BUFSIZE]; // 2KB for bf16 and fp32
if (is_scatter) {
text_image_scatter(input,
text_input,
image_input,
token_type_ids,
text_index,
image_index,
token_num,
text_token_num,
image_token_num,
hidden_size,
input_lm);
} else {
text_image_gather(input,
text_input,
image_input,
token_type_ids,
text_index,
image_index,
token_num,
text_token_num,
image_token_num,
hidden_size,
input_lm);
}
}
#define _XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(T) \
template __global__ void text_image_gather_scatter<T>( \
T* input, \
T* text_input, \
T* image_input, \
int* token_type_ids, \
int* text_index, \
int* image_index, \
int64_t token_num, \
int64_t text_token_num, \
int64_t image_token_num, \
int64_t hidden_size, \
bool is_scatter);
#define _XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(T) \
template __global__ void text_image_gather_scatter<T>( \
T * input, \
T * text_input, \
T * image_input, \
int* token_type_ids, \
int* text_index, \
int* image_index, \
int64_t token_num, \
int64_t text_token_num, \
int64_t image_token_num, \
int64_t hidden_size, \
bool is_scatter);
_XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(bfloat16);
@@ -23,75 +23,92 @@
namespace xpu3 {
namespace plugin {
static __device__ void do_calc(const _shared_ptr_ int* lm_x, int* lm_y1, int* lm_y2, int64_t size, int& text_count, int& images_count) {
for (int j = 0; j < size; j++) {
if (lm_x[j] == 0) {
lm_y1[j] = text_count;
text_count += 1;
} else {
lm_y2[j] = images_count;
images_count += 1;
}
static __device__ void do_calc(const _shared_ptr_ int* lm_x,
int* lm_y1,
int* lm_y2,
int64_t size,
int& text_count,
int& images_count) {
for (int j = 0; j < size; j++) {
if (lm_x[j] == 0) {
lm_y1[j] = text_count;
text_count += 1;
} else {
lm_y2[j] = images_count;
images_count += 1;
}
mfence_lm_sm();
}
mfence_lm_sm();
}
__global__ void text_image_index_out_kernel(
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num) {
const int cid = core_id();
const int tid = core_id() * cluster_num() + cluster_id();
const int nthreads = core_num() * cluster_num();
if (tid >= 1) return;
constexpr int BUFSIZE = 1024;
constexpr int READ_MAX_SIZE = BUFSIZE / sizeof(int);
const int64_t len = token_num;
__global__ void text_image_index_out_kernel(const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num) {
const int cid = core_id();
const int tid = core_id() * cluster_num() + cluster_id();
const int nthreads = core_num() * cluster_num();
if (tid >= 1) return;
constexpr int BUFSIZE = 1024;
constexpr int READ_MAX_SIZE = BUFSIZE / sizeof(int);
const int64_t len = token_num;
__simd__ char buffer0[BUFSIZE * 3];
__simd__ char buffer1[BUFSIZE * 3];
__simd__ __shared__ char buffer2[64][BUFSIZE * 2];
__simd__ char buffer0[BUFSIZE * 3];
__simd__ char buffer1[BUFSIZE * 3];
__simd__ __shared__ char buffer2[64][BUFSIZE * 2];
DoublePtr<READ_MAX_SIZE, SmPtr<int>> buffer_ptr_x((SmPtr<int>((_shared_ptr_ int*)buffer2[cid])));
TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y1((LmPtr<int>((int*)buffer0)));
TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y2((LmPtr<int>((int*)buffer1)));
int64_t buflen = get_1d_buflen(len, nthreads, READ_MAX_SIZE, 64);
int64_t i = tid * buflen;
int read_size = 0;
int offset = nthreads * buflen;
DoublePtr<READ_MAX_SIZE, SmPtr<int>> buffer_ptr_x(
(SmPtr<int>((_shared_ptr_ int*)buffer2[cid])));
TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y1(
(LmPtr<int>((int*)buffer0)));
TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y2(
(LmPtr<int>((int*)buffer1)));
int64_t buflen = get_1d_buflen(len, nthreads, READ_MAX_SIZE, 64);
int64_t i = tid * buflen;
int read_size = 0;
int offset = nthreads * buflen;
int text_count = 0;
int images_count = 0;
int text_count = 0;
int images_count = 0;
if (i < len) {
read_size = min<int64_t>(buflen, len - i);
buffer_ptr_y1.gm_load_async(text_index + tid * buflen, read_size);
buffer_ptr_y2.gm_load_async(image_index + tid * buflen, read_size);
buffer_ptr_x.gm_load_async(token_type_ids + tid * buflen, read_size);
mfence();
}
while (i < len && i + offset < len) {
i = i + offset;
int read_size_next = min<int64_t>(buflen, len - i);
buffer_ptr_x.next().gm_load_async(token_type_ids + i, read_size_next);
buffer_ptr_y1.next().gm_load_async(text_index + i, read_size_next);
buffer_ptr_y2.next().gm_load_async(image_index + i, read_size_next);
if (i < len) {
read_size = min<int64_t>(buflen, len - i);
buffer_ptr_y1.gm_load_async(text_index + tid * buflen, read_size);
buffer_ptr_y2.gm_load_async(image_index + tid * buflen, read_size);
buffer_ptr_x.gm_load_async(token_type_ids + tid * buflen, read_size);
mfence();
}
while (i < len && i + offset < len) {
i = i + offset;
int read_size_next = min<int64_t>(buflen, len - i);
buffer_ptr_x.next().gm_load_async(token_type_ids + i, read_size_next);
buffer_ptr_y1.next().gm_load_async(text_index + i, read_size_next);
buffer_ptr_y2.next().gm_load_async(image_index + i, read_size_next);
do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, read_size, text_count, images_count);
do_calc(buffer_ptr_x.ptr,
buffer_ptr_y1.ptr,
buffer_ptr_y2.ptr,
read_size,
text_count,
images_count);
buffer_ptr_y1.gm_store_async(text_index + i - offset, read_size);
buffer_ptr_y2.gm_store_async(image_index + i - offset, read_size);
buffer_ptr_x.toggle();
buffer_ptr_y1.toggle();
buffer_ptr_y2.toggle();
read_size = read_size_next;
}
if (i < len) {
do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, read_size, text_count, images_count);
buffer_ptr_y1.gm_store_async(text_index + i, read_size);
buffer_ptr_y2.gm_store(image_index + i, read_size);
}
buffer_ptr_y1.gm_store_async(text_index + i - offset, read_size);
buffer_ptr_y2.gm_store_async(image_index + i - offset, read_size);
buffer_ptr_x.toggle();
buffer_ptr_y1.toggle();
buffer_ptr_y2.toggle();
read_size = read_size_next;
}
if (i < len) {
do_calc(buffer_ptr_x.ptr,
buffer_ptr_y1.ptr,
buffer_ptr_y2.ptr,
read_size,
text_count,
images_count);
buffer_ptr_y1.gm_store_async(text_index + i, read_size);
buffer_ptr_y2.gm_store(image_index + i, read_size);
}
}
} // namespace plugin
} // namespace xpu3
@@ -46,7 +46,8 @@ __global__ void update_inputs(bool *not_need_stop,
int seq_len_decoder_update =
stop_flag_now
? 0
: (seq_len_encoder > 0 ? (seq_len_encoder + seq_len_decoder) : seq_len_decoder + 1);
: (seq_len_encoder > 0 ? (seq_len_encoder + seq_len_decoder)
: seq_len_decoder + 1);
int seq_len_this_time_update = !stop_flag_now;
int seq_len_encoder_update = 0;
mfence_lm();
@@ -4,32 +4,30 @@
// #include <stdio.h>
// using namespace std;
#include "xpu/kernel/xtdk_io.h"
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_io.h"
namespace xpu3 {
namespace plugin {
__global__ void update_inputs_v1(bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size) {
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size) {
// std::cout << "seq_lens_this_time " << seq_lens_this_time[0] << std::endl;
int cid = core_id();
int ncores = core_num();
@@ -41,74 +39,83 @@ __global__ void update_inputs_v1(bool *not_need_stop,
const int max_bs = 1024;
__shared__ bool stop_flags_sm[max_bs];
__shared__ int stop_flags_int_sm[max_bs];
if(cid == 0){
if (cid == 0) {
GM2SM(stop_flags, stop_flags_sm, sizeof(bool) * bsz);
}
sync_all();
for(int i = cid; i < bsz; i+= ncores){
if(i < bsz){
stop_flags_sm[i] = stop_flags[i];
stop_flags_int_sm[i] = static_cast<int64_t>(stop_flags_sm[i]);
}else{
stop_flags_sm[i] = true;
stop_flags_int_sm[i] = 1;
for (int i = cid; i < bsz; i += ncores) {
if (i < bsz) {
stop_flags_sm[i] = stop_flags[i];
stop_flags_int_sm[i] = static_cast<int64_t>(stop_flags_sm[i]);
} else {
stop_flags_sm[i] = true;
stop_flags_int_sm[i] = 1;
}
if(i<bsz){
int seq_len_this_time_update = 0;
int seq_len_decoder_update = 0;
int seq_lens_encoder_update = 0;
if(stop_flags_sm[i]){
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
if (i < bsz) {
int seq_len_this_time_update = 0;
int seq_len_decoder_update = 0;
int seq_lens_encoder_update = 0;
if (stop_flags_sm[i]) {
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
} else {
GM2LM(seq_lens_this_time + i, &seq_len_this_time_update, sizeof(int));
GM2LM(seq_lens_decoder + i, &seq_len_decoder_update, sizeof(int));
GM2LM(seq_lens_encoder + i, &seq_lens_encoder_update, sizeof(int));
int sum_of_seq_lens_this_time_and_seq_lens_decoder =
seq_len_this_time_update + seq_len_decoder_update;
int prompt_lens_update = 0;
GM2LM(prompt_lens + i, &prompt_lens_update, sizeof(int64_t));
// decoding
if (sum_of_seq_lens_this_time_and_seq_lens_decoder >=
prompt_lens_update) {
seq_len_decoder_update =
seq_len_this_time_update + seq_len_decoder_update;
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
seq_len_this_time_update = 1;
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
seq_lens_encoder_update = 0;
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
int64_t input_ids_update;
GM2LM(next_tokens + i, &input_ids_update, sizeof(int64_t));
LM2GM(&input_ids_update,
input_ids + i * input_ids_stride,
sizeof(int64_t));
// to judge whether block is not enough
if (seq_len_this_time_update != 0 &&
block_tables[i * block_num_per_seq +
seq_len_decoder_update / block_size] == -1) {
is_block_step[i] = true;
seq_len_this_time_update = 0;
LM2GM(
&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
stop_flags_sm[i] = true;
SM2GM(stop_flags_sm + i, stop_flags + i, sizeof(bool));
LM2GM(&seq_len_decoder_update,
step_seq_lens_decoder + i,
sizeof(int));
seq_len_decoder_update = 0;
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
}else{
GM2LM(seq_lens_this_time+i, &seq_len_this_time_update, sizeof(int));
GM2LM(seq_lens_decoder+i, &seq_len_decoder_update, sizeof(int));
GM2LM(seq_lens_encoder+i, &seq_lens_encoder_update, sizeof(int));
int sum_of_seq_lens_this_time_and_seq_lens_decoder = seq_len_this_time_update + seq_len_decoder_update;
int prompt_lens_update = 0;
GM2LM(prompt_lens+i, &prompt_lens_update, sizeof(int64_t));
// decoding
if(sum_of_seq_lens_this_time_and_seq_lens_decoder >= prompt_lens_update){
seq_len_decoder_update = seq_len_this_time_update + seq_len_decoder_update;
LM2GM(&seq_len_decoder_update, seq_lens_decoder+i, sizeof(int));
seq_len_this_time_update = 1;
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
seq_lens_encoder_update = 0;
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
int64_t input_ids_update;
GM2LM(next_tokens + i, &input_ids_update, sizeof(int64_t));
LM2GM(&input_ids_update, input_ids + i * input_ids_stride, sizeof(int64_t));
// to judge whether block is not enough
if(seq_len_this_time_update != 0 && block_tables[i * block_num_per_seq + seq_len_decoder_update/block_size] == -1){
is_block_step[i] = true;
seq_len_this_time_update = 0;
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
stop_flags_sm[i] = true;
SM2GM(stop_flags_sm+i, stop_flags+i, sizeof(bool));
LM2GM(&seq_len_decoder_update, step_seq_lens_decoder+i, sizeof(int));
seq_len_decoder_update = 0;
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
seq_len_decoder_update = 0;
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
stop_flags_int_sm[i] = 1;
}
}else{
stop_flags_sm[i] = true;
SM2GM(stop_flags_sm+i, stop_flags+i, sizeof(bool));
seq_len_this_time_update = 0;
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
seq_len_decoder_update = 0;
seq_lens_encoder_update = 0;
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
int64_t topk_ids_update = -1;
LM2GM(&topk_ids_update, topk_ids + i, sizeof(int64_t));
stop_flags_int_sm[i] = 1;
}
seq_len_decoder_update = 0;
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
stop_flags_int_sm[i] = 1;
}
} else {
stop_flags_sm[i] = true;
SM2GM(stop_flags_sm + i, stop_flags + i, sizeof(bool));
seq_len_this_time_update = 0;
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
seq_len_decoder_update = 0;
seq_lens_encoder_update = 0;
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
int64_t topk_ids_update = -1;
LM2GM(&topk_ids_update, topk_ids + i, sizeof(int64_t));
stop_flags_int_sm[i] = 1;
}
}
}
}
sync_all();
@@ -6,16 +6,16 @@
namespace xpu3 {
namespace plugin {
__device__ void do_cast(const int* xlm, float* ylm, int64_t len) {
for (int64_t i = 0; i < len; i += 32) {
int32x16_t xl = vload_lm_int32x16(xlm + i);
int32x16_t xh = vload_lm_int32x16(xlm + i + 16);
float32x16_t yl = vfix2float(xl);
float32x16_t yh = vfix2float(xh);
vstore_lm_float32x16(ylm + i, yl);
vstore_lm_float32x16(ylm + i + 16, yh);
}
mfence_lm();
__device__ void do_cast(const int *xlm, float *ylm, int64_t len) {
for (int64_t i = 0; i < len; i += 32) {
int32x16_t xl = vload_lm_int32x16(xlm + i);
int32x16_t xh = vload_lm_int32x16(xlm + i + 16);
float32x16_t yl = vfix2float(xl);
float32x16_t yh = vfix2float(xh);
vstore_lm_float32x16(ylm + i, yl);
vstore_lm_float32x16(ylm + i + 16, yh);
}
mfence_lm();
}
template <typename T>
@@ -124,7 +124,8 @@ __global__ void update_value_by_repeat_times_simd(
int nthreads = cluster_num() * ncores;
int start = -1;
int end = -1;
partition(thread_id, nthreads, static_cast<int>(bs * length), 16, &start, &end);
partition(
thread_id, nthreads, static_cast<int>(bs * length), 16, &start, &end);
const int param_len = 256;
// ncores = 64 for xpu3
@@ -178,14 +179,28 @@ __global__ void update_value_by_repeat_times_simd(
alpha = alpha_buf[param_idx];
beta = beta_buf[param_idx];
gamma = gamma_buf[param_idx];
time_mask = svneq_float32x16(0.f, time_); // time != 0 mask
logit_mask = svle_float32x16(0.f, logits_); // logit >= 0 mask
time_ = svmul_float32x16(beta, time_); // time * beta
time_ = svadd_float32x16(gamma, time_); // time * beta + gamma
logits_ = svmul_float32x16_mh(alpha, logits_, logits_, (time_mask & ~logit_mask)); // when time != 0 && logit < 0, do alpha * logit
logits_ = svmul_float32x16_mh(1.0f / alpha, logits_, logits_, (time_mask & logit_mask)); // when time != 0 && >=0, do logit / alpha
logits_ = vvsub_float32x16_mh(logits_, time_, logits_, time_mask); // when time != 0, do logit = logit - time * beta - gamma;
logits_ = svmul_float32x16(1.0f / temperature, logits_); // logit / temperature
time_mask = svneq_float32x16(0.f, time_); // time != 0 mask
logit_mask = svle_float32x16(0.f, logits_); // logit >= 0 mask
time_ = svmul_float32x16(beta, time_); // time * beta
time_ = svadd_float32x16(gamma, time_); // time * beta + gamma
logits_ = svmul_float32x16_mh(
alpha,
logits_,
logits_,
(time_mask & ~logit_mask)); // when time != 0 && logit < 0, do
// alpha * logit
logits_ = svmul_float32x16_mh(
1.0f / alpha,
logits_,
logits_,
(time_mask & logit_mask)); // when time != 0 && >=0, do logit / alpha
logits_ = vvsub_float32x16_mh(logits_,
time_,
logits_,
time_mask); // when time != 0, do logit =
// logit - time * beta - gamma;
logits_ = svmul_float32x16(1.0f / temperature,
logits_); // logit / temperature
vstore_lm_float32x16(logits_lm + j, logits_);
}
mfence_lm();
@@ -195,14 +210,14 @@ __global__ void update_value_by_repeat_times_simd(
}
#define _XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(DATA_TYPE) \
template __global__ void update_value_by_repeat_times_simd( \
const int *repeat_times, \
const DATA_TYPE *penalty_scores, \
const DATA_TYPE *frequency_score, \
const DATA_TYPE *presence_score, \
const float *temperatures, \
DATA_TYPE *logits, \
const int64_t bs, \
template __global__ void update_value_by_repeat_times_simd( \
const int *repeat_times, \
const DATA_TYPE *penalty_scores, \
const DATA_TYPE *frequency_score, \
const DATA_TYPE *presence_score, \
const float *temperatures, \
DATA_TYPE *logits, \
const int64_t bs, \
const int64_t length);
_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(float);
_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(float16);
@@ -20,12 +20,16 @@
namespace xpu3 {
namespace plugin {
template <typename TX, typename TY>
__attribute__((global)) void
eb_adjust_batch(TX *src, TY *dst, int *encoder_seqs_lods,
int *encoder_batch_map, int *decoder_batch_map, int en_batch,
int de_batch, int64_t copy_size);
} // namespace plugin
} // namespace xpu3
__attribute__((global)) void eb_adjust_batch(TX *src,
TY *dst,
int *encoder_seqs_lods,
int *encoder_batch_map,
int *decoder_batch_map,
int en_batch,
int de_batch,
int64_t copy_size);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
@@ -33,10 +37,15 @@ namespace api {
namespace plugin {
template <typename TX, typename TY>
static int
cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods,
const int *encoder_batch_map, const int *decoder_batch_map,
int en_batch, int de_batch, int64_t hidden_dim) {
static int cpu_wrapper(api::Context *ctx,
const TX *x,
TY *y,
const int *encoder_seqs_lods,
const int *encoder_batch_map,
const int *decoder_batch_map,
int en_batch,
int de_batch,
int64_t hidden_dim) {
int ret = 0;
int cur_offset = 0;
int en_idx = 0;
@@ -48,7 +57,8 @@ cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods,
int cpy_m = 0;
if (de_batch > 0 && decoder_batch_map[de_idx] == i) {
cpy_m = 1;
ret = api::cast<TX, TY>(ctx, x + cur_offset * hidden_dim,
ret = api::cast<TX, TY>(ctx,
x + cur_offset * hidden_dim,
y + (encoder_len_total + de_idx) * hidden_dim,
cpy_m * hidden_dim);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
@@ -56,7 +66,8 @@ cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods,
}
if (en_batch > 0 && encoder_batch_map[en_idx] == i) {
cpy_m = encoder_seqs_lods[en_idx + 1] - encoder_seqs_lods[en_idx];
ret = api::cast<TX, TY>(ctx, x + cur_offset * hidden_dim,
ret = api::cast<TX, TY>(ctx,
x + cur_offset * hidden_dim,
y + encoder_seqs_lods[en_idx] * hidden_dim,
cpy_m * hidden_dim);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
@@ -69,11 +80,15 @@ cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods,
}
template <typename TX, typename TY>
static int xpu3_wrapper(api::Context *ctx, const TX *x, TY *y,
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
int en_batch, int de_batch, int64_t hidden_dim) {
static int xpu3_wrapper(api::Context *ctx,
const TX *x,
TY *y,
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
int en_batch,
int de_batch,
int64_t hidden_dim) {
using XPU_INDEX_TYPE_TX = typename XPUIndexType<TX>::type;
using XPU_INDEX_TYPE_TY = typename XPUIndexType<TY>::type;
auto eb_adjust_batch_kernel =
@@ -81,17 +96,23 @@ static int xpu3_wrapper(api::Context *ctx, const TX *x, TY *y,
// NOTE: Don't change 16 to 64, because kernel use gsm
eb_adjust_batch_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
reinterpret_cast<XPU_INDEX_TYPE_TX *>(const_cast<TX *>(x)),
reinterpret_cast<XPU_INDEX_TYPE_TY *>(y), encoder_seqs_lods.xpu,
encoder_batch_map.xpu, decoder_batch_map.xpu, en_batch, de_batch,
reinterpret_cast<XPU_INDEX_TYPE_TY *>(y),
encoder_seqs_lods.xpu,
encoder_batch_map.xpu,
decoder_batch_map.xpu,
en_batch,
de_batch,
hidden_dim);
return api::SUCCESS;
}
template <typename TX, typename TY>
int eb_adjust_batch(api::Context *ctx, const TX *x, TY *y,
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
int eb_adjust_batch(api::Context *ctx,
const TX *x,
TY *y,
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
int64_t hidden_dim) {
// int dev_id = -1;
// xpu_current_device(&dev_id);
@@ -101,8 +122,13 @@ int eb_adjust_batch(api::Context *ctx, const TX *x, TY *y,
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T2(ctx, "eb_adjust_batch", TX, TY);
WRAPPER_DUMP_PARAM6(ctx, x, y, encoder_seqs_lods, encoder_batch_map,
decoder_batch_map, hidden_dim);
WRAPPER_DUMP_PARAM6(ctx,
x,
y,
encoder_seqs_lods,
encoder_batch_map,
decoder_batch_map,
hidden_dim);
WRAPPER_DUMP(ctx);
int encoder_batch = encoder_batch_map.len;
int total_batch = encoder_batch + decoder_batch_map.len;
@@ -126,9 +152,14 @@ int eb_adjust_batch(api::Context *ctx, const TX *x, TY *y,
WRAPPER_ASSERT_LT(ctx, decoder_batch_map.cpu[i], total_batch)
}
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<TX, TY>(ctx, x, y, encoder_seqs_lods.cpu,
encoder_batch_map.cpu, decoder_batch_map.cpu,
encoder_batch_map.len, decoder_batch_map.len,
return cpu_wrapper<TX, TY>(ctx,
x,
y,
encoder_seqs_lods.cpu,
encoder_batch_map.cpu,
decoder_batch_map.cpu,
encoder_batch_map.len,
decoder_batch_map.len,
hidden_dim);
}
if (ctx->dev().type() == api::kXPU3) {
@@ -139,18 +170,27 @@ int eb_adjust_batch(api::Context *ctx, const TX *x, TY *y,
encoder_batch_map.to_xpu(RAII_GUARD);
api::VectorParam<int32_t> decoder_batch_map_xpu =
decoder_batch_map.to_xpu(RAII_GUARD);
return xpu3_wrapper<TX, TY>(ctx, x, y, encoder_seqs_lods_xpu,
encoder_batch_map_xpu, decoder_batch_map_xpu,
encoder_batch_map.len, decoder_batch_map.len,
return xpu3_wrapper<TX, TY>(ctx,
x,
y,
encoder_seqs_lods_xpu,
encoder_batch_map_xpu,
decoder_batch_map_xpu,
encoder_batch_map.len,
decoder_batch_map.len,
hidden_dim);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
#define INSTANTIATION_EB_ADJUST_BATCH(TX, TY) \
template int eb_adjust_batch<TX, TY>( \
api::Context *, const TX *, TY *, api::VectorParam<int32_t> &, \
api::VectorParam<int32_t> &, api::VectorParam<int32_t> &, int64_t);
#define INSTANTIATION_EB_ADJUST_BATCH(TX, TY) \
template int eb_adjust_batch<TX, TY>(api::Context *, \
const TX *, \
TY *, \
api::VectorParam<int32_t> &, \
api::VectorParam<int32_t> &, \
api::VectorParam<int32_t> &, \
int64_t);
INSTANTIATION_EB_ADJUST_BATCH(float16, float16);
INSTANTIATION_EB_ADJUST_BATCH(bfloat16, bfloat16);
@@ -163,7 +203,7 @@ INSTANTIATION_EB_ADJUST_BATCH(bfloat16, float);
INSTANTIATION_EB_ADJUST_BATCH(float, bfloat16);
INSTANTIATION_EB_ADJUST_BATCH(int32_t, int32_t);
INSTANTIATION_EB_ADJUST_BATCH(int64_t, int64_t);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
@@ -20,62 +20,92 @@
namespace xpu3 {
namespace plugin {
template <typename TX, typename TY>
__attribute__((global)) void
eb_gather_next_token(TX *src, TY *dst, int *encoder_seqs_lods,
int *encoder_batch_map, int *decoder_batch_map,
int en_batch, int de_batch, int64_t copy_size);
} // namespace plugin
} // namespace xpu3
__attribute__((global)) void eb_gather_next_token(TX *src,
TY *dst,
int *encoder_seqs_lods,
int *encoder_batch_map,
int *decoder_batch_map,
int en_batch,
int de_batch,
int64_t copy_size);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename TX, typename TY>
static int
cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods,
const int *encoder_batch_map, const int *decoder_batch_map,
int en_batch, int de_batch, int64_t hidden_dim) {
static int cpu_wrapper(api::Context *ctx,
const TX *x,
TY *y,
const int *encoder_seqs_lods,
const int *encoder_batch_map,
const int *decoder_batch_map,
int en_batch,
int de_batch,
int64_t hidden_dim) {
int ret = 0;
int encoder_len_total = encoder_seqs_lods[en_batch];
for (int i = 0; i < en_batch; i++) {
ret =
api::cast<TX, TY>(ctx, x + (encoder_seqs_lods[i + 1] - 1) * hidden_dim,
y + encoder_batch_map[i] * hidden_dim, hidden_dim);
ret = api::cast<TX, TY>(ctx,
x + (encoder_seqs_lods[i + 1] - 1) * hidden_dim,
y + encoder_batch_map[i] * hidden_dim,
hidden_dim);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
}
for (int i = 0; i < de_batch; i++) {
ret = api::cast<TX, TY>(ctx, x + (encoder_len_total + i) * hidden_dim,
y + decoder_batch_map[i] * hidden_dim, hidden_dim);
ret = api::cast<TX, TY>(ctx,
x + (encoder_len_total + i) * hidden_dim,
y + decoder_batch_map[i] * hidden_dim,
hidden_dim);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
}
return api::SUCCESS;
}
template <typename TX, typename TY>
static int xpu3_wrapper(api::Context *ctx, const TX *x, TY *y,
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
int en_batch, int de_batch, int64_t hidden_dim) {
static int xpu3_wrapper(api::Context *ctx,
const TX *x,
TY *y,
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
int en_batch,
int de_batch,
int64_t hidden_dim) {
auto eb_gather_next_token_kernel = xpu3::plugin::eb_gather_next_token<TX, TY>;
// NOTE: Don't change 16 to 64, because kernel use gsm
eb_gather_next_token_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
const_cast<TX *>(x), y, encoder_seqs_lods.xpu, encoder_batch_map.xpu,
decoder_batch_map.xpu, en_batch, de_batch, hidden_dim);
const_cast<TX *>(x),
y,
encoder_seqs_lods.xpu,
encoder_batch_map.xpu,
decoder_batch_map.xpu,
en_batch,
de_batch,
hidden_dim);
return api::SUCCESS;
}
template <typename TX, typename TY>
int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y,
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
int64_t hidden_dim) {
int eb_gather_next_token(
api::Context *ctx,
const TX *x,
TY *y,
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
int64_t hidden_dim) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T2(ctx, "eb_gather_next_token", TX, TY);
WRAPPER_DUMP_PARAM6(ctx, x, y, encoder_seqs_lods, encoder_batch_map,
decoder_batch_map, hidden_dim);
WRAPPER_DUMP_PARAM6(ctx,
x,
y,
encoder_seqs_lods,
encoder_batch_map,
decoder_batch_map,
hidden_dim);
WRAPPER_DUMP(ctx);
int encoder_batch = encoder_batch_map.len;
int batch = encoder_batch + decoder_batch_map.len;
@@ -99,9 +129,14 @@ int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y,
WRAPPER_ASSERT_GE(ctx, decoder_batch_map.cpu[i], 0);
}
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<TX, TY>(ctx, x, y, encoder_seqs_lods.cpu,
encoder_batch_map.cpu, decoder_batch_map.cpu,
encoder_batch_map.len, decoder_batch_map.len,
return cpu_wrapper<TX, TY>(ctx,
x,
y,
encoder_seqs_lods.cpu,
encoder_batch_map.cpu,
decoder_batch_map.cpu,
encoder_batch_map.len,
decoder_batch_map.len,
hidden_dim);
}
if (ctx->dev().type() == api::kXPU3) {
@@ -112,17 +147,26 @@ int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y,
encoder_batch_map.to_xpu(RAII_GUARD);
api::VectorParam<int32_t> decoder_batch_map_xpu =
decoder_batch_map.to_xpu(RAII_GUARD);
return xpu3_wrapper<TX, TY>(ctx, x, y, encoder_seqs_lods_xpu,
encoder_batch_map_xpu, decoder_batch_map_xpu,
encoder_batch_map.len, decoder_batch_map.len,
return xpu3_wrapper<TX, TY>(ctx,
x,
y,
encoder_seqs_lods_xpu,
encoder_batch_map_xpu,
decoder_batch_map_xpu,
encoder_batch_map.len,
decoder_batch_map.len,
hidden_dim);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
#define INSTANTIATION_EB_GATHER_NEXT_TOKEN(TX, TY) \
template int eb_gather_next_token<TX, TY>( \
api::Context *, const TX *, TY *, api::VectorParam<int32_t> &, \
api::VectorParam<int32_t> &, api::VectorParam<int32_t> &, int64_t);
#define INSTANTIATION_EB_GATHER_NEXT_TOKEN(TX, TY) \
template int eb_gather_next_token<TX, TY>(api::Context *, \
const TX *, \
TY *, \
api::VectorParam<int32_t> &, \
api::VectorParam<int32_t> &, \
api::VectorParam<int32_t> &, \
int64_t);
INSTANTIATION_EB_GATHER_NEXT_TOKEN(float16, float16);
INSTANTIATION_EB_GATHER_NEXT_TOKEN(bfloat16, bfloat16);
@@ -133,7 +177,7 @@ INSTANTIATION_EB_GATHER_NEXT_TOKEN(bfloat16, float16);
INSTANTIATION_EB_GATHER_NEXT_TOKEN(float16, bfloat16);
INSTANTIATION_EB_GATHER_NEXT_TOKEN(bfloat16, float);
INSTANTIATION_EB_GATHER_NEXT_TOKEN(float, bfloat16);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
@@ -12,211 +12,304 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void free_and_dispatch_block(
bool *stop_flags, int *seq_lens_this_time, int *seq_lens_decoder,
int *block_tables, int *encoder_block_lens, bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len, int *recover_block_list, int *recover_len,
int *need_block_list, int *need_block_len, int *used_list_len,
int *free_list, int *free_list_len, int64_t *first_token_ids, const int bsz,
const int block_size, const int block_num_per_seq,
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num);
} // namespace plugin
} // namespace xpu3
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context *ctx, bool *stop_flags, int *seq_lens_this_time,
int *seq_lens_decoder, int *block_tables,
int *encoder_block_lens, bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len, int *recover_block_list, int *recover_len,
int *need_block_list, int *need_block_len,
int *used_list_len, int *free_list, int *free_list_len,
int64_t *first_token_ids, const int bsz,
const int block_size, const int block_num_per_seq,
static int cpu_wrapper(Context *ctx,
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num) {
for (int i = 0; i < bsz; i++) {
int *block_table_now = block_tables + i * block_num_per_seq;
if (stop_flags[i] && !is_block_step[i]) {
// 回收block块
const int encoder_block_len = encoder_block_lens[i];
const int decoder_used_len = used_list_len[i];
if (decoder_used_len > 0) {
const int ori_free_list_len = free_list_len[0];
free_list_len[0] += decoder_used_len;
for (int j = 0; j < decoder_used_len; j++) {
free_list[ori_free_list_len + j] =
block_table_now[encoder_block_len + j];
block_table_now[encoder_block_len + j] = -1;
}
encoder_block_lens[i] = 0;
used_list_len[i] = 0;
}
} else if (block_table_now[seq_lens_decoder[i] / block_size] == -1) {
// 统计需要分配block的位置和总数
const int ori_need_block_len = need_block_len[0];
need_block_len[0] += 1;
need_block_list[ori_need_block_len] = i;
}
}
while (need_block_len[0] > free_list_len[0]) {
// 调度block,根据used_list_len从大到小回收block,直到满足need_block_len
int max_used_list_len_id = 0;
int max_used_list_len = 0;
for (int i = 0; i < bsz; i++) {
int *block_table_now = block_tables + i * block_num_per_seq;
if (stop_flags[i] && !is_block_step[i]) {
// 回收block块
const int encoder_block_len = encoder_block_lens[i];
const int decoder_used_len = used_list_len[i];
if (decoder_used_len > 0) {
const int ori_free_list_len = free_list_len[0];
free_list_len[0] += decoder_used_len;
for (int j = 0; j < decoder_used_len; j++) {
free_list[ori_free_list_len + j] =
block_table_now[encoder_block_len + j];
block_table_now[encoder_block_len + j] = -1;
}
encoder_block_lens[i] = 0;
used_list_len[i] = 0;
}
} else if (block_table_now[seq_lens_decoder[i] / block_size] == -1) {
// 统计需要分配block的位置和总数
const int ori_need_block_len = need_block_len[0];
need_block_len[0] += 1;
need_block_list[ori_need_block_len] = i;
}
const int used_block_num = !is_block_step[i] ? used_list_len[i] : 0;
if (used_block_num > max_used_list_len) {
max_used_list_len_id = i;
max_used_list_len = used_block_num;
}
}
while (need_block_len[0] > free_list_len[0]) {
// 调度block,根据used_list_len从大到小回收block,直到满足need_block_len
int max_used_list_len_id = 0;
int max_used_list_len = 0;
for (int i = 0; i < bsz; i++) {
const int used_block_num = !is_block_step[i] ? used_list_len[i] : 0;
if (used_block_num > max_used_list_len) {
max_used_list_len_id = i;
max_used_list_len = used_block_num;
}
}
const int encoder_block_len = encoder_block_lens[max_used_list_len_id];
int *block_table_now =
block_tables + max_used_list_len_id * block_num_per_seq;
for (int i = 0; i < max_used_list_len; i++) {
free_list[free_list_len[0] + i] =
block_table_now[encoder_block_len + i];
block_table_now[encoder_block_len + i] = -1;
}
step_block_list[step_len[0]] = max_used_list_len_id;
step_len[0] += 1;
free_list_len[0] += max_used_list_len;
stop_flags[max_used_list_len_id] = true;
is_block_step[max_used_list_len_id] = true;
seq_lens_this_time[max_used_list_len_id] = 0;
seq_lens_decoder[max_used_list_len_id] = 0;
const int encoder_block_len = encoder_block_lens[max_used_list_len_id];
int *block_table_now =
block_tables + max_used_list_len_id * block_num_per_seq;
for (int i = 0; i < max_used_list_len; i++) {
free_list[free_list_len[0] + i] = block_table_now[encoder_block_len + i];
block_table_now[encoder_block_len + i] = -1;
}
step_block_list[step_len[0]] = max_used_list_len_id;
step_len[0] += 1;
free_list_len[0] += max_used_list_len;
stop_flags[max_used_list_len_id] = true;
is_block_step[max_used_list_len_id] = true;
seq_lens_this_time[max_used_list_len_id] = 0;
seq_lens_decoder[max_used_list_len_id] = 0;
}
// 为需要block的位置分配block,每个位置分配一个block
for (int i = 0; i < bsz; i++) {
if (i < need_block_len[0]) {
const int need_block_id = need_block_list[i];
if (!stop_flags[need_block_id]) {
// 如果需要的位置正好是上一步中被释放的位置,不做处理
used_list_len[need_block_id] += 1;
const int ori_free_list_len = free_list_len[0];
free_list_len[0]--;
int *block_table_now =
block_tables + need_block_id * block_num_per_seq;
block_table_now[seq_lens_decoder[need_block_id] / block_size] =
free_list[ori_free_list_len - 1];
}
need_block_list[i] = -1;
}
// 为需要block的位置分配block,每个位置分配一个block
for (int i = 0; i < bsz; i++) {
if (i < need_block_len[0]) {
const int need_block_id = need_block_list[i];
if (!stop_flags[need_block_id]) {
// 如果需要的位置正好是上一步中被释放的位置,不做处理
used_list_len[need_block_id] += 1;
const int ori_free_list_len = free_list_len[0];
free_list_len[0]--;
int *block_table_now = block_tables + need_block_id * block_num_per_seq;
block_table_now[seq_lens_decoder[need_block_id] / block_size] =
free_list[ori_free_list_len - 1];
}
need_block_list[i] = -1;
}
}
// 计算可以复原的query id
int ori_step_len = step_len[0];
if (ori_step_len > 0) {
int ori_free_list_len = free_list_len[0];
int ori_step_block_id = step_block_list[ori_step_len - 1];
int tmp_used_len = used_list_len[ori_step_block_id];
// 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中)
int used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1
: tmp_used_len;
while (ori_step_len > 0 && ori_free_list_len >= used_len) {
recover_block_list[recover_len[0]] = ori_step_block_id;
is_block_step[ori_step_block_id] = false;
used_list_len[ori_step_block_id] = used_len;
ori_free_list_len -= used_len;
step_block_list[ori_step_len - 1] = -1;
step_len[0] -= 1;
recover_len[0] += 1;
ori_step_len = step_len[0];
if (ori_step_len > 0) {
ori_step_block_id = step_block_list[ori_step_len - 1];
tmp_used_len = used_list_len[ori_step_block_id];
used_len = tmp_used_len < max_decoder_block_num
? tmp_used_len + 1
: tmp_used_len;
}
}
need_block_len[0] = 0;
// 计算可以复原的query id
int ori_step_len = step_len[0];
if (ori_step_len > 0) {
int ori_free_list_len = free_list_len[0];
int ori_step_block_id = step_block_list[ori_step_len - 1];
int tmp_used_len = used_list_len[ori_step_block_id];
// 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中)
int used_len =
tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len;
while (ori_step_len > 0 && ori_free_list_len >= used_len) {
recover_block_list[recover_len[0]] = ori_step_block_id;
is_block_step[ori_step_block_id] = false;
used_list_len[ori_step_block_id] = used_len;
ori_free_list_len -= used_len;
step_block_list[ori_step_len - 1] = -1;
step_len[0] -= 1;
recover_len[0] += 1;
ori_step_len = step_len[0];
if (ori_step_len > 0) {
ori_step_block_id = step_block_list[ori_step_len - 1];
tmp_used_len = used_list_len[ori_step_block_id];
used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1
: tmp_used_len;
}
}
return api::SUCCESS;
need_block_len[0] = 0;
}
return api::SUCCESS;
}
static int xpu3_wrapper(Context *ctx, bool *stop_flags, int *seq_lens_this_time,
int *seq_lens_decoder, int *block_tables,
int *encoder_block_lens, bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len, int *recover_block_list,
int *recover_len, int *need_block_list,
int *need_block_len, int *used_list_len, int *free_list,
int *free_list_len, int64_t *first_token_ids,
const int bsz, const int block_size,
static int xpu3_wrapper(Context *ctx,
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto free_and_dispatch_block_kernel = xpu3::plugin::free_and_dispatch_block;
free_and_dispatch_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags, seq_lens_this_time, seq_lens_decoder, block_tables,
encoder_block_lens, is_block_step, step_block_list, step_len,
recover_block_list, recover_len, need_block_list, need_block_len,
used_list_len, free_list, free_list_len,
reinterpret_cast<XPU_INT64 *>(first_token_ids), bsz, block_size,
block_num_per_seq, max_decoder_block_num);
return api::SUCCESS;
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto free_and_dispatch_block_kernel = xpu3::plugin::free_and_dispatch_block;
free_and_dispatch_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
reinterpret_cast<XPU_INT64 *>(first_token_ids),
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num);
return api::SUCCESS;
}
int free_and_dispatch_block(Context *ctx, bool *stop_flags,
int *seq_lens_this_time, int *seq_lens_decoder,
int *block_tables, int *encoder_block_lens,
int free_and_dispatch_block(Context *ctx,
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len, int *recover_block_list,
int *recover_len, int *need_block_list,
int *need_block_len, int *used_list_len,
int *free_list, int *free_list_len,
int64_t *first_token_ids, const int bsz,
const int block_size, const int block_num_per_seq,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "free_and_dispatch_block", float);
WRAPPER_DUMP_PARAM6(ctx, stop_flags, seq_lens_this_time, seq_lens_decoder,
block_tables, encoder_block_lens, is_block_step);
WRAPPER_DUMP_PARAM6(ctx, step_block_list, step_len, recover_block_list,
recover_len, need_block_list, need_block_len);
WRAPPER_DUMP_PARAM4(ctx, used_list_len, free_list, free_list_len,
first_token_ids);
WRAPPER_DUMP_PARAM4(ctx, bsz, block_size, block_num_per_seq,
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "free_and_dispatch_block", float);
WRAPPER_DUMP_PARAM6(ctx,
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step);
WRAPPER_DUMP_PARAM6(ctx,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len);
WRAPPER_DUMP_PARAM4(
ctx, used_list_len, free_list, free_list_len, first_token_ids);
WRAPPER_DUMP_PARAM4(
ctx, bsz, block_size, block_num_per_seq, max_decoder_block_num);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
first_token_ids,
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
first_token_ids,
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(
ctx, stop_flags, seq_lens_this_time, seq_lens_decoder, block_tables,
encoder_block_lens, is_block_step, step_block_list, step_len,
recover_block_list, recover_len, need_block_list, need_block_len,
used_list_len, free_list, free_list_len, first_token_ids, bsz,
block_size, block_num_per_seq, max_decoder_block_num);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(
ctx, stop_flags, seq_lens_this_time, seq_lens_decoder, block_tables,
encoder_block_lens, is_block_step, step_block_list, step_len,
recover_block_list, recover_len, need_block_list, need_block_len,
used_list_len, free_list, free_list_len, first_token_ids, bsz,
block_size, block_num_per_seq, max_decoder_block_num);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
@@ -12,120 +12,178 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void
get_padding_offset(int *padding_offset, int *cum_offsets_out, int *cu_seqlens_q,
int *cu_seqlens_k, const int *cum_offsets,
const int *seq_lens, const int max_seq_len, const int bs);
__attribute__((global)) void
remove_padding(int64_t *x_remove_padding, const int64_t *input_data,
const int *seq_lens, const int *cum_offsets,
const int sequence_length, const int bs);
__attribute__((global)) void get_padding_offset(int *padding_offset,
int *cum_offsets_out,
int *cu_seqlens_q,
int *cu_seqlens_k,
const int *cum_offsets,
const int *seq_lens,
const int max_seq_len,
const int bs);
__attribute__((global)) void remove_padding(int64_t *x_remove_padding,
const int64_t *input_data,
const int *seq_lens,
const int *cum_offsets,
const int sequence_length,
const int bs);
} // namespace plugin
} // namespace xpu3
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int get_padding_offset_cpu(int *padding_offset, int *cum_offsets_out,
int *cu_seqlens_q, int *cu_seqlens_k,
const int *cum_offsets, const int *seq_lens,
const int max_seq_len, const int bs) {
for (int i = 0; i < bs; i++) {
int cum_offset = i == 0 ? 0 : cum_offsets[i - 1];
for (int j = 0; j < seq_lens[i]; j++) {
padding_offset[i * max_seq_len - cum_offset + j] = cum_offset;
}
cum_offsets_out[i] = cum_offset;
int cum_seq_len = (i + 1) * max_seq_len - cum_offsets[i];
cu_seqlens_q[i + 1] = cum_seq_len;
cu_seqlens_k[i + 1] = cum_seq_len;
static int get_padding_offset_cpu(int *padding_offset,
int *cum_offsets_out,
int *cu_seqlens_q,
int *cu_seqlens_k,
const int *cum_offsets,
const int *seq_lens,
const int max_seq_len,
const int bs) {
for (int i = 0; i < bs; i++) {
int cum_offset = i == 0 ? 0 : cum_offsets[i - 1];
for (int j = 0; j < seq_lens[i]; j++) {
padding_offset[i * max_seq_len - cum_offset + j] = cum_offset;
}
return api::SUCCESS;
cum_offsets_out[i] = cum_offset;
int cum_seq_len = (i + 1) * max_seq_len - cum_offsets[i];
cu_seqlens_q[i + 1] = cum_seq_len;
cu_seqlens_k[i + 1] = cum_seq_len;
}
return api::SUCCESS;
}
static int remove_padding_cpu(int64_t *x_remove_padding,
const int64_t *input_data, const int *seq_lens,
const int *cum_offsets, const int sequence_length,
const int64_t *input_data,
const int *seq_lens,
const int *cum_offsets,
const int sequence_length,
const int bs) {
for (int i = 0; i < bs; i++) {
for (int j = 0; j < seq_lens[i]; j++) {
const int tgt_seq_id = i * sequence_length - cum_offsets[i] + j;
const int src_seq_id = i * sequence_length + j;
x_remove_padding[tgt_seq_id] = input_data[src_seq_id];
}
for (int i = 0; i < bs; i++) {
for (int j = 0; j < seq_lens[i]; j++) {
const int tgt_seq_id = i * sequence_length - cum_offsets[i] + j;
const int src_seq_id = i * sequence_length + j;
x_remove_padding[tgt_seq_id] = input_data[src_seq_id];
}
return api::SUCCESS;
}
return api::SUCCESS;
}
static int cpu_wrapper(Context *ctx, int *padding_offset, int *cum_offsets_out,
int *cu_seqlens_q, int *cu_seqlens_k,
int64_t *x_remove_padding, const int64_t *input_ids,
const int *cum_offsets, const int *seq_lens,
const int max_seq_len, const int bs) {
get_padding_offset_cpu(padding_offset, cum_offsets_out, cu_seqlens_q,
cu_seqlens_k, cum_offsets, seq_lens, max_seq_len,
bs);
remove_padding_cpu(x_remove_padding, input_ids, seq_lens, cum_offsets_out,
max_seq_len, bs);
return api::SUCCESS;
static int cpu_wrapper(Context *ctx,
int *padding_offset,
int *cum_offsets_out,
int *cu_seqlens_q,
int *cu_seqlens_k,
int64_t *x_remove_padding,
const int64_t *input_ids,
const int *cum_offsets,
const int *seq_lens,
const int max_seq_len,
const int bs) {
get_padding_offset_cpu(padding_offset,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
cum_offsets,
seq_lens,
max_seq_len,
bs);
remove_padding_cpu(
x_remove_padding, input_ids, seq_lens, cum_offsets_out, max_seq_len, bs);
return api::SUCCESS;
}
static int xpu3_wrapper(Context *ctx, int *padding_offset, int *cum_offsets_out,
int *cu_seqlens_q, int *cu_seqlens_k,
int64_t *x_remove_padding, const int64_t *input_ids,
const int *cum_offsets, const int *seq_lens,
const int max_seq_len, const int bs) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto get_padding_offset = xpu3::plugin::get_padding_offset;
auto remove_padding = xpu3::plugin::remove_padding;
get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k,
cum_offsets, seq_lens, max_seq_len, bs);
remove_padding<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<XPU_INT64 *>(x_remove_padding),
reinterpret_cast<const XPU_INT64 *>(input_ids), seq_lens,
cum_offsets_out, max_seq_len, bs);
return api::SUCCESS;
static int xpu3_wrapper(Context *ctx,
int *padding_offset,
int *cum_offsets_out,
int *cu_seqlens_q,
int *cu_seqlens_k,
int64_t *x_remove_padding,
const int64_t *input_ids,
const int *cum_offsets,
const int *seq_lens,
const int max_seq_len,
const int bs) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto get_padding_offset = xpu3::plugin::get_padding_offset;
auto remove_padding = xpu3::plugin::remove_padding;
get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(padding_offset,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
cum_offsets,
seq_lens,
max_seq_len,
bs);
remove_padding<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<XPU_INT64 *>(x_remove_padding),
reinterpret_cast<const XPU_INT64 *>(input_ids),
seq_lens,
cum_offsets_out,
max_seq_len,
bs);
return api::SUCCESS;
}
int get_padding_offset(Context *ctx, int *padding_offset, int *cum_offsets_out,
int *cu_seqlens_q, int *cu_seqlens_k,
int64_t *x_remove_padding, const int64_t *input_ids,
const int *cum_offsets, const int *seq_lens,
const int max_seq_len, const int bs) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "get_padding_offset", int);
WRAPPER_DUMP_PARAM4(ctx, padding_offset, cum_offsets_out, cu_seqlens_q,
cu_seqlens_k);
WRAPPER_DUMP_PARAM4(ctx, x_remove_padding, input_ids, cum_offsets,
seq_lens);
WRAPPER_DUMP_PARAM2(ctx, max_seq_len, bs);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx, padding_offset, cum_offsets_out, cu_seqlens_q,
cu_seqlens_k, x_remove_padding, input_ids,
cum_offsets, seq_lens, max_seq_len, bs);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx, padding_offset, cum_offsets_out, cu_seqlens_q,
cu_seqlens_k, x_remove_padding, input_ids,
cum_offsets, seq_lens, max_seq_len, bs);
}
WRAPPER_UNIMPLEMENTED(ctx);
int get_padding_offset(Context *ctx,
int *padding_offset,
int *cum_offsets_out,
int *cu_seqlens_q,
int *cu_seqlens_k,
int64_t *x_remove_padding,
const int64_t *input_ids,
const int *cum_offsets,
const int *seq_lens,
const int max_seq_len,
const int bs) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "get_padding_offset", int);
WRAPPER_DUMP_PARAM4(
ctx, padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k);
WRAPPER_DUMP_PARAM4(ctx, x_remove_padding, input_ids, cum_offsets, seq_lens);
WRAPPER_DUMP_PARAM2(ctx, max_seq_len, bs);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
padding_offset,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
x_remove_padding,
input_ids,
cum_offsets,
seq_lens,
max_seq_len,
bs);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
padding_offset,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
x_remove_padding,
input_ids,
cum_offsets,
seq_lens,
max_seq_len,
bs);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
@@ -76,19 +76,20 @@ static int cpu_wrapper(
}
static int xpu3_wrapper(Context* ctx,
const int64_t* base_model_draft_tokens,
int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const bool* base_model_stop_flags,
int bsz,
int base_model_draft_token_len) {
xpu3::plugin::draft_model_postprocess<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const xpu3::int64_t*>(base_model_draft_tokens),
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_stop_flags,
bsz,
base_model_draft_token_len);
const int64_t* base_model_draft_tokens,
int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const bool* base_model_stop_flags,
int bsz,
int base_model_draft_token_len) {
xpu3::plugin::
draft_model_postprocess<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const xpu3::int64_t*>(base_model_draft_tokens),
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_stop_flags,
bsz,
base_model_draft_token_len);
return api::SUCCESS;
}
@@ -124,12 +125,12 @@ int draft_model_postprocess(Context* ctx,
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
base_model_draft_tokens,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_stop_flags,
bsz,
base_model_draft_token_len);
base_model_draft_tokens,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_stop_flags,
bsz,
base_model_draft_token_len);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
@@ -21,13 +21,18 @@
namespace xpu3 {
namespace plugin {
template <typename T>
__attribute__((global)) void
set_stop_value_multi_ends(bool *stop_flags, T *topk_ids, T *next_tokens,
const T *end_ids, const int *seq_lens, const int bs,
const int end_length, const bool beam_search,
const bool prefill_one_step_stop);
} // namespace plugin
} // namespace xpu3
__attribute__((global)) void set_stop_value_multi_ends(
bool *stop_flags,
T *topk_ids,
T *next_tokens,
const T *end_ids,
const int *seq_lens,
const int bs,
const int end_length,
const bool beam_search,
const bool prefill_one_step_stop);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
@@ -36,104 +41,143 @@ namespace plugin {
template <typename T>
__inline__ bool is_in_end(const T id, const T *end_ids, int length) {
for (int i = 0; i < length; i++) {
if (id == end_ids[i]) {
return true;
}
for (int i = 0; i < length; i++) {
if (id == end_ids[i]) {
return true;
}
return false;
}
return false;
}
template <typename T>
static int cpu_wrapper(Context *ctx, bool *stop_flags, T *topk_ids,
T *next_tokens, const T *end_ids, const int *seq_lens,
const int bs, const int end_length,
static int cpu_wrapper(Context *ctx,
bool *stop_flags,
T *topk_ids,
T *next_tokens,
const T *end_ids,
const int *seq_lens,
const int bs,
const int end_length,
const bool beam_search,
const bool prefill_one_step_stop) {
for (int i = 0; i < bs; i++) {
if (prefill_one_step_stop) {
stop_flags[i] = true;
if (seq_lens[i] == 0) {
topk_ids[i] = -1;
}
next_tokens[i] = topk_ids[i];
for (int i = 0; i < bs; i++) {
if (prefill_one_step_stop) {
stop_flags[i] = true;
if (seq_lens[i] == 0) {
topk_ids[i] = -1;
}
next_tokens[i] = topk_ids[i];
} else {
if (stop_flags[i]) {
if (seq_lens[i] == 0) {
topk_ids[i] = -1;
} else {
if (stop_flags[i]) {
if (seq_lens[i] == 0) {
topk_ids[i] = -1;
} else {
topk_ids[i] = end_ids[0];
next_tokens[i] = end_ids[0];
}
} else {
next_tokens[i] = topk_ids[i];
}
if (!beam_search && is_in_end(topk_ids[i], end_ids, end_length)) {
stop_flags[i] = true;
}
topk_ids[i] = end_ids[0];
next_tokens[i] = end_ids[0];
}
} else {
next_tokens[i] = topk_ids[i];
}
if (!beam_search && is_in_end(topk_ids[i], end_ids, end_length)) {
stop_flags[i] = true;
}
}
return api::SUCCESS;
}
return api::SUCCESS;
}
template <typename T>
static int xpu3_wrapper(Context *ctx, bool *stop_flags, T *topk_ids,
T *next_tokens, const T *end_ids, const int *seq_lens,
const int bs, const int end_length,
static int xpu3_wrapper(Context *ctx,
bool *stop_flags,
T *topk_ids,
T *next_tokens,
const T *end_ids,
const int *seq_lens,
const int bs,
const int end_length,
const bool beam_search,
const bool prefill_one_step_stop) {
using XPU_TID = typename XPUIndexType<T>::type;
auto set_stop_value_multi_ends =
xpu3::plugin::set_stop_value_multi_ends<XPU_TID>;
set_stop_value_multi_ends<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags, reinterpret_cast<XPU_TID *>(topk_ids),
reinterpret_cast<XPU_TID *>(next_tokens),
reinterpret_cast<const XPU_TID *>(end_ids), seq_lens, bs, end_length,
beam_search, prefill_one_step_stop);
return api::SUCCESS;
using XPU_TID = typename XPUIndexType<T>::type;
auto set_stop_value_multi_ends =
xpu3::plugin::set_stop_value_multi_ends<XPU_TID>;
set_stop_value_multi_ends<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
reinterpret_cast<XPU_TID *>(topk_ids),
reinterpret_cast<XPU_TID *>(next_tokens),
reinterpret_cast<const XPU_TID *>(end_ids),
seq_lens,
bs,
end_length,
beam_search,
prefill_one_step_stop);
return api::SUCCESS;
}
// Public entry point: validates arguments, reads the
// PREFILL_NODE_ONE_STEP_STOP env toggle ('1' enables one-step stop on
// prefill nodes), then dispatches to the CPU or XPU implementation.
template <typename T>
int set_stop_value_multi_ends(Context *ctx,
                              bool *stop_flags,
                              T *topk_ids,
                              T *next_tokens,
                              const T *end_ids,
                              const int *seq_lens,
                              const int bs,
                              const int end_length,
                              const bool beam_search) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "set_stop_value_multi_ends", T);
  WRAPPER_DUMP_PARAM5(
      ctx, stop_flags, topk_ids, next_tokens, end_ids, seq_lens);
  WRAPPER_DUMP_PARAM3(ctx, bs, end_length, beam_search);
  WRAPPER_DUMP(ctx);
  WRAPPER_CHECK_PTR(ctx, bool, bs, stop_flags);
  WRAPPER_CHECK_PTR(ctx, T, bs, topk_ids);
  WRAPPER_CHECK_PTR(ctx, T, end_length, end_ids);
  // seq_lens is `const int *`; the original checked it as T, which sizes the
  // pointer check wrongly for non-int T (e.g. int64_t).
  WRAPPER_CHECK_PTR(ctx, int, bs, seq_lens);
  WRAPPER_ASSERT_LE(ctx, end_length, 1024);  // assume end_length <= 1024
  bool prefill_one_step_stop = false;
  if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
    // std::cout << "Your PATH is: " << env_p << '\n';
    if (env_p[0] == '1') {
      prefill_one_step_stop = true;
    }
  }
  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper<T>(ctx,
                          stop_flags,
                          topk_ids,
                          next_tokens,
                          end_ids,
                          seq_lens,
                          bs,
                          end_length,
                          beam_search,
                          prefill_one_step_stop);
  }
  if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper<T>(ctx,
                           stop_flags,
                           topk_ids,
                           next_tokens,
                           end_ids,
                           seq_lens,
                           bs,
                           end_length,
                           beam_search,
                           prefill_one_step_stop);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}
// Explicit instantiation for the only token type used by callers.
template int set_stop_value_multi_ends<int64_t>(Context *ctx,
                                                bool *stop_flags,
                                                int64_t *topk_ids,
                                                int64_t *next_tokens,
                                                const int64_t *end_ids,
                                                const int *seq_lens,
                                                const int bs,
                                                const int end_length,
                                                const bool beam_search);

}  // namespace plugin
}  // namespace api
}  // namespace xpu
}  // namespace baidu
@@ -12,128 +12,173 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
// Forward declaration of the XPU device kernel; the definition lives in the
// device-side kernel source. (The merged diff duplicated the parameter list
// and the namespace closers.)
__attribute__((global)) void set_value_by_flags_and_idx(
    const bool *stop_flags,
    int64_t *pre_ids_all,
    const int64_t *input_ids,
    const int *seq_lens_encoder,
    const int *seq_lens_decoder,
    const int64_t *step_idx,
    int bs,
    int length,
    int length_input_ids);

}  // namespace plugin
}  // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context *ctx, const bool *stop_flags,
int64_t *pre_ids_all, const int64_t *pre_ids,
const int64_t *step_idx, const int bs,
static int cpu_wrapper(Context *ctx,
const bool *stop_flags,
int64_t *pre_ids_all,
const int64_t *pre_ids,
const int64_t *step_idx,
const int bs,
const int length) {
for (int i = 0; i < bs; i++) {
int64_t *pre_ids_all_now = pre_ids_all + i * length;
if (!stop_flags[i] && step_idx[i] >= 0) {
pre_ids_all_now[step_idx[i]] = pre_ids[i];
}
for (int i = 0; i < bs; i++) {
int64_t *pre_ids_all_now = pre_ids_all + i * length;
if (!stop_flags[i] && step_idx[i] >= 0) {
pre_ids_all_now[step_idx[i]] = pre_ids[i];
}
return api::SUCCESS;
}
return api::SUCCESS;
}
static int cpu_wrapper(Context *ctx, const bool *stop_flags,
int64_t *pre_ids_all, const int64_t *input_ids,
const int *seq_lens_encoder, const int *seq_lens_decoder,
const int64_t *step_idx, int bs, int length,
static int cpu_wrapper(Context *ctx,
const bool *stop_flags,
int64_t *pre_ids_all,
const int64_t *input_ids,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *step_idx,
int bs,
int length,
int length_input_ids) {
for (int i = 0; i < bs; i++) {
if (!stop_flags[i]) {
int64_t *pre_ids_all_now = pre_ids_all + i * length;
const int64_t *input_ids_now = input_ids + i * length_input_ids;
const int seq_len_dec = seq_lens_decoder[i];
const int seq_len_enc = seq_lens_encoder[i];
if (seq_len_dec == 0 && seq_len_enc == 0)
continue;
if (step_idx[i] >= 0) {
if (seq_len_enc > 0) {
// encoder, get last token accord to seq_lens_encoder
pre_ids_all_now[step_idx[i]] =
input_ids_now[seq_len_enc - 1];
} else {
// decoder, get first token
pre_ids_all_now[step_idx[i]] = input_ids_now[0];
}
}
for (int i = 0; i < bs; i++) {
if (!stop_flags[i]) {
int64_t *pre_ids_all_now = pre_ids_all + i * length;
const int64_t *input_ids_now = input_ids + i * length_input_ids;
const int seq_len_dec = seq_lens_decoder[i];
const int seq_len_enc = seq_lens_encoder[i];
if (seq_len_dec == 0 && seq_len_enc == 0) continue;
if (step_idx[i] >= 0) {
if (seq_len_enc > 0) {
// encoder, get last token accord to seq_lens_encoder
pre_ids_all_now[step_idx[i]] = input_ids_now[seq_len_enc - 1];
} else {
// decoder, get first token
pre_ids_all_now[step_idx[i]] = input_ids_now[0];
}
}
}
return api::SUCCESS;
}
return api::SUCCESS;
}
static int xpu3_wrapper(Context *ctx, const bool *stop_flags,
int64_t *pre_ids_all, const int64_t *input_ids,
static int xpu3_wrapper(Context *ctx,
const bool *stop_flags,
int64_t *pre_ids_all,
const int64_t *input_ids,
const int *seq_lens_encoder,
const int *seq_lens_decoder, const int64_t *step_idx,
int bs, int length, int length_input_ids) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto set_value_by_flags_and_idx_kernel =
xpu3::plugin::set_value_by_flags_and_idx;
set_value_by_flags_and_idx_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags, reinterpret_cast<XPU_INT64 *>(pre_ids_all),
reinterpret_cast<const XPU_INT64 *>(input_ids), seq_lens_encoder,
seq_lens_decoder, reinterpret_cast<const XPU_INT64 *>(step_idx), bs,
length, length_input_ids);
return api::SUCCESS;
const int *seq_lens_decoder,
const int64_t *step_idx,
int bs,
int length,
int length_input_ids) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto set_value_by_flags_and_idx_kernel =
xpu3::plugin::set_value_by_flags_and_idx;
set_value_by_flags_and_idx_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
reinterpret_cast<XPU_INT64 *>(pre_ids_all),
reinterpret_cast<const XPU_INT64 *>(input_ids),
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<const XPU_INT64 *>(step_idx),
bs,
length,
length_input_ids);
return api::SUCCESS;
}
int set_value_by_flags_and_idx(Context *ctx, const bool *stop_flags,
int64_t *pre_ids_all, const int64_t *input_ids,
int set_value_by_flags_and_idx(Context *ctx,
const bool *stop_flags,
int64_t *pre_ids_all,
const int64_t *input_ids,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *step_idx, int bs, int length,
const int64_t *step_idx,
int bs,
int length,
int length_input_ids) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "set_value_by_flags_and_idx", int64_t);
WRAPPER_DUMP_PARAM6(ctx, stop_flags, pre_ids_all, input_ids,
seq_lens_encoder, seq_lens_decoder, step_idx);
WRAPPER_DUMP_PARAM3(ctx, bs, length, length_input_ids);
WRAPPER_DUMP(ctx);
int64_t stop_flags_len = -1;
int64_t pre_ids_all_len = -1;
int64_t input_ids_len = -1;
int64_t seq_lens_encoder_len = -1;
int64_t seq_lens_decoder_len = -1;
int64_t step_idx_len = -1;
WRAPPER_CHECK_SHAPE(ctx, &stop_flags_len, {bs});
WRAPPER_CHECK_SHAPE(ctx, &pre_ids_all_len, {bs, length});
WRAPPER_CHECK_SHAPE(ctx, &input_ids_len, {bs, length_input_ids});
WRAPPER_CHECK_SHAPE(ctx, &seq_lens_encoder_len, {bs});
WRAPPER_CHECK_SHAPE(ctx, &seq_lens_decoder_len, {bs});
WRAPPER_CHECK_SHAPE(ctx, &step_idx_len, {bs});
WRAPPER_CHECK_PTR(ctx, int64_t, stop_flags_len, stop_flags);
WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_all_len, pre_ids_all);
WRAPPER_CHECK_PTR(ctx, int64_t, input_ids_len, input_ids);
WRAPPER_CHECK_PTR(ctx, int, seq_lens_encoder_len, seq_lens_encoder);
WRAPPER_CHECK_PTR(ctx, int, seq_lens_decoder_len, seq_lens_decoder);
WRAPPER_CHECK_PTR(ctx, int64_t, step_idx_len, step_idx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx, stop_flags, pre_ids_all, input_ids,
seq_lens_encoder, seq_lens_decoder, step_idx, bs,
length, length_input_ids);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx, stop_flags, pre_ids_all, input_ids,
seq_lens_encoder, seq_lens_decoder, step_idx, bs,
length, length_input_ids);
}
WRAPPER_UNIMPLEMENTED(ctx);
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "set_value_by_flags_and_idx", int64_t);
WRAPPER_DUMP_PARAM6(ctx,
stop_flags,
pre_ids_all,
input_ids,
seq_lens_encoder,
seq_lens_decoder,
step_idx);
WRAPPER_DUMP_PARAM3(ctx, bs, length, length_input_ids);
WRAPPER_DUMP(ctx);
int64_t stop_flags_len = -1;
int64_t pre_ids_all_len = -1;
int64_t input_ids_len = -1;
int64_t seq_lens_encoder_len = -1;
int64_t seq_lens_decoder_len = -1;
int64_t step_idx_len = -1;
WRAPPER_CHECK_SHAPE(ctx, &stop_flags_len, {bs});
WRAPPER_CHECK_SHAPE(ctx, &pre_ids_all_len, {bs, length});
WRAPPER_CHECK_SHAPE(ctx, &input_ids_len, {bs, length_input_ids});
WRAPPER_CHECK_SHAPE(ctx, &seq_lens_encoder_len, {bs});
WRAPPER_CHECK_SHAPE(ctx, &seq_lens_decoder_len, {bs});
WRAPPER_CHECK_SHAPE(ctx, &step_idx_len, {bs});
WRAPPER_CHECK_PTR(ctx, int64_t, stop_flags_len, stop_flags);
WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_all_len, pre_ids_all);
WRAPPER_CHECK_PTR(ctx, int64_t, input_ids_len, input_ids);
WRAPPER_CHECK_PTR(ctx, int, seq_lens_encoder_len, seq_lens_encoder);
WRAPPER_CHECK_PTR(ctx, int, seq_lens_decoder_len, seq_lens_decoder);
WRAPPER_CHECK_PTR(ctx, int64_t, step_idx_len, step_idx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
stop_flags,
pre_ids_all,
input_ids,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
bs,
length,
length_input_ids);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
stop_flags,
pre_ids_all,
input_ids,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
bs,
length,
length_input_ids);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
}  // namespace plugin
}  // namespace api
}  // namespace xpu
}  // namespace baidu
@@ -12,263 +12,367 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
// Forward declarations of the XPU device kernels used by
// token_penalty_multi_scores; definitions live in the device-side kernel
// source. (The merged diff duplicated each parameter list and the namespace
// closers.)
template <typename T>
__attribute__((global)) void min_length_logits_process(
    T *logits,
    const int64_t *cur_len,
    const int64_t *min_len,
    const int64_t *eos_token_id,
    const int64_t bs,
    const int64_t length,
    const int64_t length_id,
    const int64_t end_length);
__attribute__((global)) void update_repeat_times(const int64_t *pre_ids,
                                                 const int64_t *cur_len,
                                                 int *repeat_times,
                                                 const int64_t bs,
                                                 const int64_t length,
                                                 const int64_t length_id);
template <typename T>
__attribute__((global)) void update_value_by_repeat_times(
    const int *repeat_times,
    const T *penalty_scores,
    const T *frequency_score,
    const T *presence_score,
    const float *temperatures,
    T *logits,
    const int64_t bs,
    const int64_t length);
// SIMD variant; the host wrapper selects it when `length` is a multiple of
// 16.
template <typename T>
__attribute__((global)) void update_value_by_repeat_times_simd(
    const int *repeat_times,
    const T *penalty_scores,
    const T *frequency_score,
    const T *presence_score,
    const float *temperatures,
    T *logits,
    const int64_t bs,
    const int64_t length);
template <typename T>
__attribute__((global)) void ban_bad_words(T *logits,
                                           const int64_t *bad_words_list,
                                           const int64_t bs,
                                           const int64_t length,
                                           const int64_t bad_words_length);

}  // namespace plugin
}  // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
// Counts token occurrences per batch row: for every row with cur_len >= 0,
// increments repeat_times[i * length + id] for each id in pre_ids row i
// (ids outside [0, length) are ignored). repeat_times is [bs, length] and
// is accumulated into, not reset.
void update_repeat_times_cpu(const int64_t *pre_ids,
                             const int64_t *cur_len,
                             int *repeat_times,
                             const int64_t bs,
                             const int64_t length,
                             const int64_t length_id) {
  for (int64_t i = 0; i < bs; i++) {
    if (cur_len[i] >= 0) {
      for (int64_t j = 0; j < length_id; j++) {
        int64_t id = pre_ids[i * length_id + j];
        if (id < 0 || id >= length) continue;
        repeat_times[i * length + id] += 1;
      }
    }
  }
}
// Masks banned tokens: sets logits[i][token] = -1e10 for every token id in
// bad_words_list, in every batch row. Ids outside [0, length) are ignored.
// logits is [bs, length] row-major.
void ban_bad_words_cpu(float *logits,
                       const int64_t *bad_words_list,
                       const int64_t bs,
                       const int64_t length,
                       const int64_t bad_words_length) {
  for (int64_t i = 0; i < bs; i++) {
    float *logits_now = logits + i * length;
    for (int64_t j = 0; j < bad_words_length; j++) {
      int64_t bad_words_token_id = bad_words_list[j];
      if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
      logits_now[bad_words_token_id] = -1e10;
    }
  }
}
template <typename T>
static int cpu_wrapper(Context *ctx, const int64_t *pre_ids, T *logits,
const T *penalty_scores, const T *frequency_scores,
const T *presence_scores, const float *temperatures,
const int64_t *cur_len, const int64_t *min_len,
const int64_t *eos_token_id, const int64_t *bad_words,
const int64_t bs, const int64_t length,
const int64_t length_id, const int64_t end_length,
static int cpu_wrapper(Context *ctx,
const int64_t *pre_ids,
T *logits,
const T *penalty_scores,
const T *frequency_scores,
const T *presence_scores,
const float *temperatures,
const int64_t *cur_len,
const int64_t *min_len,
const int64_t *eos_token_id,
const int64_t *bad_words,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t end_length,
const int64_t length_bad_words) {
std::vector<float> logitsfp32(bs * length);
std::vector<float> penalty_scoresfp32(bs);
std::vector<float> frequency_scoresfp32(bs);
std::vector<float> presence_scoresfp32(bs);
std::vector<int> repeat_times_buffer(bs * length, 0);
int ret = api::cast<T, float>(ctx, logits, logitsfp32.data(), bs * length);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
ret =
api::cast<T, float>(ctx, penalty_scores, penalty_scoresfp32.data(), bs);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
ret = api::cast<T, float>(ctx, frequency_scores,
frequency_scoresfp32.data(), bs);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
ret = api::cast<T, float>(ctx, presence_scores, presence_scoresfp32.data(),
bs);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
for (int64_t i = 0; i < bs; i++) {
if (cur_len[i] >= 0 && cur_len[i] < min_len[i]) {
for (int64_t j = 0; j < end_length; j++) {
logitsfp32[i * length + eos_token_id[j]] = -1e4;
}
}
std::vector<float> logitsfp32(bs * length);
std::vector<float> penalty_scoresfp32(bs);
std::vector<float> frequency_scoresfp32(bs);
std::vector<float> presence_scoresfp32(bs);
std::vector<int> repeat_times_buffer(bs * length, 0);
int ret = api::cast<T, float>(ctx, logits, logitsfp32.data(), bs * length);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
ret = api::cast<T, float>(ctx, penalty_scores, penalty_scoresfp32.data(), bs);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
ret = api::cast<T, float>(
ctx, frequency_scores, frequency_scoresfp32.data(), bs);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
ret =
api::cast<T, float>(ctx, presence_scores, presence_scoresfp32.data(), bs);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
for (int64_t i = 0; i < bs; i++) {
if (cur_len[i] >= 0 && cur_len[i] < min_len[i]) {
for (int64_t j = 0; j < end_length; j++) {
logitsfp32[i * length + eos_token_id[j]] = -1e4;
}
}
int *repeat_times = &(repeat_times_buffer[0]);
update_repeat_times_cpu(pre_ids, cur_len, repeat_times, bs, length,
length_id);
for (int64_t i = 0; i < bs; i++) {
float alpha = penalty_scoresfp32[i];
float beta = frequency_scoresfp32[i];
float gamma = presence_scoresfp32[i];
float temperature = temperatures[i];
for (int64_t j = 0; j < length; j++) {
int times = repeat_times[i * length + j];
float logit_now = logitsfp32[i * length + j];
if (times != 0) {
logit_now =
logit_now < 0 ? logit_now * alpha : logit_now / alpha;
logit_now = logit_now - times * beta - gamma;
}
logitsfp32[i * length + j] = logit_now / temperature;
}
}
int *repeat_times = &(repeat_times_buffer[0]);
update_repeat_times_cpu(
pre_ids, cur_len, repeat_times, bs, length, length_id);
for (int64_t i = 0; i < bs; i++) {
float alpha = penalty_scoresfp32[i];
float beta = frequency_scoresfp32[i];
float gamma = presence_scoresfp32[i];
float temperature = temperatures[i];
for (int64_t j = 0; j < length; j++) {
int times = repeat_times[i * length + j];
float logit_now = logitsfp32[i * length + j];
if (times != 0) {
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
logit_now = logit_now - times * beta - gamma;
}
logitsfp32[i * length + j] = logit_now / temperature;
}
if (bad_words && length_bad_words > 0) {
ban_bad_words_cpu(logitsfp32.data(), bad_words, bs, length,
length_bad_words);
}
ret = api::cast<float, T>(ctx, logitsfp32.data(), logits, bs * length);
return ret;
}
if (bad_words && length_bad_words > 0) {
ban_bad_words_cpu(
logitsfp32.data(), bad_words, bs, length, length_bad_words);
}
ret = api::cast<float, T>(ctx, logitsfp32.data(), logits, bs * length);
return ret;
}
template <typename T>
static int xpu3_wrapper(Context *ctx, const int64_t *pre_ids, T *logits,
const T *penalty_scores, const T *frequency_scores,
const T *presence_scores, const float *temperatures,
const int64_t *cur_len, const int64_t *min_len,
const int64_t *eos_token_id, const int64_t *bad_words,
const int64_t bs, const int64_t length,
const int64_t length_id, const int64_t end_length,
static int xpu3_wrapper(Context *ctx,
const int64_t *pre_ids,
T *logits,
const T *penalty_scores,
const T *frequency_scores,
const T *presence_scores,
const float *temperatures,
const int64_t *cur_len,
const int64_t *min_len,
const int64_t *eos_token_id,
const int64_t *bad_words,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t end_length,
const int64_t length_bad_words) {
api::ctx_guard RAII_GUARD(ctx);
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto min_length_logits_process_kernel =
xpu3::plugin::min_length_logits_process<T>;
auto update_repeat_times_kernel = xpu3::plugin::update_repeat_times;
auto update_value_by_repeat_times_kernel =
xpu3::plugin::update_value_by_repeat_times<T>;
if (length % 16 == 0) {
update_value_by_repeat_times_kernel =
xpu3::plugin::update_value_by_repeat_times_simd<T>;
}
auto ban_bad_words_kernel = xpu3::plugin::ban_bad_words<T>;
api::ctx_guard RAII_GUARD(ctx);
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto min_length_logits_process_kernel =
xpu3::plugin::min_length_logits_process<T>;
auto update_repeat_times_kernel = xpu3::plugin::update_repeat_times;
auto update_value_by_repeat_times_kernel =
xpu3::plugin::update_value_by_repeat_times<T>;
if (length % 16 == 0) {
update_value_by_repeat_times_kernel =
xpu3::plugin::update_value_by_repeat_times_simd<T>;
}
auto ban_bad_words_kernel = xpu3::plugin::ban_bad_words<T>;
int *repeat_times = RAII_GUARD.alloc_l3_or_gm<int>(bs * length);
WRAPPER_ASSERT_WORKSPACE(ctx, repeat_times);
int ret = api::constant<int>(ctx, repeat_times, bs * length, 0);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
int *repeat_times = RAII_GUARD.alloc_l3_or_gm<int>(bs * length);
WRAPPER_ASSERT_WORKSPACE(ctx, repeat_times);
int ret = api::constant<int>(ctx, repeat_times, bs * length, 0);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
update_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const XPU_INT64 *>(pre_ids),
reinterpret_cast<const XPU_INT64 *>(cur_len), repeat_times, bs, length,
length_id);
min_length_logits_process_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
logits, reinterpret_cast<const XPU_INT64 *>(cur_len),
reinterpret_cast<const XPU_INT64 *>(min_len),
reinterpret_cast<const XPU_INT64 *>(eos_token_id), bs, length,
length_id, end_length);
update_value_by_repeat_times_kernel<<<ctx->ncluster(), 64,
ctx->xpu_stream>>>(
repeat_times, penalty_scores, frequency_scores, presence_scores,
temperatures, logits, bs, length);
update_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const XPU_INT64 *>(pre_ids),
reinterpret_cast<const XPU_INT64 *>(cur_len),
repeat_times,
bs,
length,
length_id);
min_length_logits_process_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
logits,
reinterpret_cast<const XPU_INT64 *>(cur_len),
reinterpret_cast<const XPU_INT64 *>(min_len),
reinterpret_cast<const XPU_INT64 *>(eos_token_id),
bs,
length,
length_id,
end_length);
update_value_by_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
repeat_times,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
logits,
bs,
length);
if (bad_words && length_bad_words > 0) {
ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
logits, reinterpret_cast<const XPU_INT64 *>(bad_words), bs, length,
length_bad_words);
}
return api::SUCCESS;
if (bad_words && length_bad_words > 0) {
ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
logits,
reinterpret_cast<const XPU_INT64 *>(bad_words),
bs,
length,
length_bad_words);
}
return api::SUCCESS;
}
// Public entry point: validates shapes/pointers (pre_ids: [bs, length_id];
// logits: [bs, length]; per-batch vectors: [bs]; eos_token_id:
// [end_length]; bad_words: [length_bad_words]) and dispatches to the CPU or
// XPU implementation.
template <typename T>
int token_penalty_multi_scores(Context *ctx,
                               const int64_t *pre_ids,
                               T *logits,
                               const T *penalty_scores,
                               const T *frequency_scores,
                               const T *presence_scores,
                               const float *temperatures,
                               const int64_t *cur_len,
                               const int64_t *min_len,
                               const int64_t *eos_token_id,
                               const int64_t *bad_words,
                               const int64_t bs,
                               const int64_t length,
                               const int64_t length_id,
                               const int64_t end_length,
                               const int64_t length_bad_words) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "token_penalty_multi_scores", T);
  WRAPPER_DUMP_PARAM6(ctx,
                      pre_ids,
                      logits,
                      penalty_scores,
                      frequency_scores,
                      presence_scores,
                      temperatures);
  WRAPPER_DUMP_PARAM3(ctx, cur_len, min_len, eos_token_id);
  WRAPPER_DUMP_PARAM4(ctx, bs, length, length_id, end_length);
  WRAPPER_DUMP(ctx);
  // TODO(mayang02) shape check
  int64_t pre_ids_len = -1;
  int64_t logits_len = -1;
  int64_t penalty_scores_len = -1;
  int64_t frequency_scores_len = -1;
  int64_t presence_scores_len = -1;
  int64_t temperatures_len = -1;
  int64_t cur_len_len = -1;
  int64_t min_len_len = -1;
  int64_t eos_token_id_len = -1;
  int64_t bad_words_len = -1;
  WRAPPER_CHECK_SHAPE(ctx, &pre_ids_len, {bs, length_id});
  WRAPPER_CHECK_SHAPE(ctx, &logits_len, {bs, length});
  WRAPPER_CHECK_SHAPE(ctx, &penalty_scores_len, {bs});
  WRAPPER_CHECK_SHAPE(ctx, &frequency_scores_len, {bs});
  WRAPPER_CHECK_SHAPE(ctx, &presence_scores_len, {bs});
  WRAPPER_CHECK_SHAPE(ctx, &temperatures_len, {bs});
  WRAPPER_CHECK_SHAPE(ctx, &cur_len_len, {bs});
  WRAPPER_CHECK_SHAPE(ctx, &min_len_len, {bs});
  WRAPPER_CHECK_SHAPE(ctx, &eos_token_id_len, {end_length});
  WRAPPER_CHECK_SHAPE(ctx, &bad_words_len, {length_bad_words});
  WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_len, pre_ids);
  WRAPPER_CHECK_PTR(ctx, T, logits_len, logits);
  WRAPPER_CHECK_PTR(ctx, T, penalty_scores_len, penalty_scores);
  WRAPPER_CHECK_PTR(ctx, T, frequency_scores_len, frequency_scores);
  WRAPPER_CHECK_PTR(ctx, T, presence_scores_len, presence_scores);
  WRAPPER_CHECK_PTR(ctx, float, temperatures_len, temperatures);
  WRAPPER_CHECK_PTR(ctx, int64_t, cur_len_len, cur_len);
  WRAPPER_CHECK_PTR(ctx, int64_t, min_len_len, min_len);
  WRAPPER_CHECK_PTR(ctx, int64_t, eos_token_id_len, eos_token_id);
  WRAPPER_CHECK_PTR(ctx, int64_t, bad_words_len, bad_words);
  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper<T>(ctx,
                          pre_ids,
                          logits,
                          penalty_scores,
                          frequency_scores,
                          presence_scores,
                          temperatures,
                          cur_len,
                          min_len,
                          eos_token_id,
                          bad_words,
                          bs,
                          length,
                          length_id,
                          end_length,
                          length_bad_words);
  }
  if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper<T>(ctx,
                           pre_ids,
                           logits,
                           penalty_scores,
                           frequency_scores,
                           presence_scores,
                           temperatures,
                           cur_len,
                           min_len,
                           eos_token_id,
                           bad_words,
                           bs,
                           length,
                           length_id,
                           end_length,
                           length_bad_words);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}
// Explicit instantiations for the supported logit types.
template int token_penalty_multi_scores<float>(Context *ctx,
                                               const int64_t *pre_ids,
                                               float *logits,
                                               const float *penalty_scores,
                                               const float *frequency_scores,
                                               const float *presence_scores,
                                               const float *temperatures,
                                               const int64_t *cur_len,
                                               const int64_t *min_len,
                                               const int64_t *eos_token_id,
                                               const int64_t *bad_words,
                                               const int64_t bs,
                                               const int64_t length,
                                               const int64_t length_id,
                                               const int64_t end_length,
                                               const int64_t length_bad_words);
template int token_penalty_multi_scores<float16>(
    Context *ctx,
    const int64_t *pre_ids,
    float16 *logits,
    const float16 *penalty_scores,
    const float16 *frequency_scores,
    const float16 *presence_scores,
    const float *temperatures,
    const int64_t *cur_len,
    const int64_t *min_len,
    const int64_t *eos_token_id,
    const int64_t *bad_words,
    const int64_t bs,
    const int64_t length,
    const int64_t length_id,
    const int64_t end_length,
    const int64_t length_bad_words);

}  // namespace plugin
}  // namespace api
}  // namespace xpu
}  // namespace baidu
@@ -20,21 +20,18 @@
namespace xpu3 {
namespace plugin {
template <typename TX, typename TSCALE, typename TY>
__attribute__((global)) void
quant2d_per_channel_cluster(const TX *x, const TSCALE *scale, TY *y, int64_t m,
int64_t n);
__attribute__((global)) void quant2d_per_channel_cluster(
const TX *x, const TSCALE *scale, TY *y, int64_t m, int64_t n);
template <typename TX, typename TSCALE, typename TY, int MAX_N>
__attribute__((global)) void
quant2d_per_channel_cached(const TX *input, TY *output, TSCALE *scale,
int64_t m, int64_t n);
__attribute__((global)) void quant2d_per_channel_cached(
const TX *input, TY *output, TSCALE *scale, int64_t m, int64_t n);
template <typename TX, typename TSCALE, typename TY>
__attribute__((global)) void quant2d_per_channel_bign(const TX *input,
TY *output, TSCALE *scale,
int64_t m, int64_t n);
} // namespace plugin
} // namespace xpu3
__attribute__((global)) void quant2d_per_channel_bign(
const TX *input, TY *output, TSCALE *scale, int64_t m, int64_t n);
} // namespace plugin
} // namespace xpu3
namespace api = baidu::xpu::api;
@@ -43,11 +40,17 @@ namespace xpu {
namespace api {
namespace plugin {
template <typename TX, typename TSCALE, typename TY,
template <typename TX,
typename TSCALE,
typename TY,
typename std::enable_if<std::is_same<TY, int8_t>::value, TY>::type
*ptr = nullptr>
int cpu_wrapper_input_scale(api::Context *ctx, const TX *x, const TSCALE *scale,
TY *y, int64_t m, int64_t n) {
int cpu_wrapper_input_scale(api::Context *ctx,
const TX *x,
const TSCALE *scale,
TY *y,
int64_t m,
int64_t n) {
float absmax = 1e-30f;
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
@@ -78,11 +81,13 @@ static float16 quant_int4(float x, float scale) {
return (float16)std::min(static_cast<float>(r), 7.f);
}
template <typename TX, typename TSCALE, typename TY,
template <typename TX,
typename TSCALE,
typename TY,
typename std::enable_if<std::is_same<TY, int4_t>::value, TY>::type
*ptr = nullptr>
int cpu_wrapper_input_scale(api::Context *ctx, const TX *x, const TSCALE *scale,
TY *y, int m, int n) {
int cpu_wrapper_input_scale(
api::Context *ctx, const TX *x, const TSCALE *scale, TY *y, int m, int n) {
int8_t *y_ptr = reinterpret_cast<int8_t *>(y);
float t1, t2;
for (int i = 0; i < m; ++i) {
@@ -109,11 +114,17 @@ int cpu_wrapper_input_scale(api::Context *ctx, const TX *x, const TSCALE *scale,
return api::SUCCESS;
}
template <typename TX, typename TSCALE, typename TY,
template <typename TX,
typename TSCALE,
typename TY,
typename std::enable_if<!std::is_same<TY, int4_t>::value, TY>::type
*ptr = nullptr>
int cpu_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale,
TY *y, int64_t m, int64_t n) {
int cpu_wrapper_output_scale(api::Context *ctx,
const TX *x,
TSCALE *scale,
TY *y,
int64_t m,
int64_t n) {
int64_t i, j;
for (j = 0; j < n; ++j) {
float absmax = 1e-30f;
@@ -129,11 +140,13 @@ int cpu_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale,
return api::SUCCESS;
}
template <typename TX, typename TSCALE, typename TY,
template <typename TX,
typename TSCALE,
typename TY,
typename std::enable_if<std::is_same<TY, int4_t>::value, TY>::type
*ptr = nullptr>
int cpu_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale,
TY *y, int m, int n) {
int cpu_wrapper_output_scale(
api::Context *ctx, const TX *x, TSCALE *scale, TY *y, int m, int n) {
int8_t *y_ptr = reinterpret_cast<int8_t *>(y);
float t1, t2, absmax_1, absmax_2, act_scale_1, act_scale_2;
for (int j = 0; j < n; j += 2) {
@@ -173,18 +186,28 @@ int cpu_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale,
}
template <typename TX, typename TSCALE, typename TY>
int xpu3_wrapper_input_scale(api::Context *ctx, const TX *x,
const TSCALE *scale, TY *y, int64_t m, int64_t n) {
int xpu3_wrapper_input_scale(api::Context *ctx,
const TX *x,
const TSCALE *scale,
TY *y,
int64_t m,
int64_t n) {
auto func = xpu3::plugin::quant2d_per_channel_cluster<TX, TSCALE, TY>;
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, scale, y, m, n);
return api::SUCCESS;
}
template <typename TX, typename TSCALE, typename TY,
template <typename TX,
typename TSCALE,
typename TY,
typename std::enable_if<!std::is_same<TY, int4_t>::value, TY>::type
* = nullptr>
int xpu3_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale,
TY *y, int64_t m, int64_t n) {
int xpu3_wrapper_output_scale(api::Context *ctx,
const TX *x,
TSCALE *scale,
TY *y,
int64_t m,
int64_t n) {
int64_t channel_size = m * sizeof(TX);
int64_t cluster_n = (n + ctx->ncluster() - 1) / ctx->ncluster();
auto func = xpu3::plugin::quant2d_per_channel_bign<TX, TSCALE, TY>;
@@ -210,19 +233,30 @@ int xpu3_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale,
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, scale, m, n);
return api::SUCCESS;
}
template <typename TX, typename TSCALE, typename TY,
template <typename TX,
typename TSCALE,
typename TY,
typename std::enable_if<std::is_same<TY, int4_t>::value, TY>::type * =
nullptr>
int xpu3_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale,
TY *y, int64_t m, int64_t n) {
int xpu3_wrapper_output_scale(api::Context *ctx,
const TX *x,
TSCALE *scale,
TY *y,
int64_t m,
int64_t n) {
auto func = xpu3::plugin::quant2d_per_channel_bign<TX, TSCALE, TY>;
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, scale, m, n);
return api::SUCCESS;
}
template <typename TX, typename TSCALE, typename TY>
int quant2d_per_channel(api::Context *ctx, const TX *x, const TSCALE *scale_in,
TY *y, TSCALE *scale_out, int64_t m, int64_t n) {
int quant2d_per_channel(api::Context *ctx,
const TX *x,
const TSCALE *scale_in,
TY *y,
TSCALE *scale_out,
int64_t m,
int64_t n) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T3(ctx, "quant2d_per_channel", TX, TSCALE, TY);
WRAPPER_DUMP_PARAM4(ctx, x, scale_in, y, scale_out);
@@ -251,20 +285,24 @@ int quant2d_per_channel(api::Context *ctx, const TX *x, const TSCALE *scale_in,
}
if (ctx->dev().type() == api::kXPU3) {
if (scale_in != nullptr) {
return xpu3_wrapper_input_scale<TX, TSCALE, TY>(ctx, x, scale_in, y, m,
n);
return xpu3_wrapper_input_scale<TX, TSCALE, TY>(
ctx, x, scale_in, y, m, n);
}
return xpu3_wrapper_output_scale<TX, TSCALE, TY>(ctx, x, scale_out, y, m,
n);
return xpu3_wrapper_output_scale<TX, TSCALE, TY>(
ctx, x, scale_out, y, m, n);
}
WRAPPER_UNIMPLEMENTED(ctx);
return 0;
}
#define INSTANTIATION_QUANT2D_PER_CHANNEL(TX, TSCALE, TY) \
template int quant2d_per_channel<TX, TSCALE, TY>( \
api::Context *, const TX *, const TSCALE *, TY *, TSCALE *, int64_t, \
int64_t);
#define INSTANTIATION_QUANT2D_PER_CHANNEL(TX, TSCALE, TY) \
template int quant2d_per_channel<TX, TSCALE, TY>(api::Context *, \
const TX *, \
const TSCALE *, \
TY *, \
TSCALE *, \
int64_t, \
int64_t);
INSTANTIATION_QUANT2D_PER_CHANNEL(float16, float, int8_t);
INSTANTIATION_QUANT2D_PER_CHANNEL(bfloat16, float, int8_t);
@@ -274,7 +312,7 @@ INSTANTIATION_QUANT2D_PER_CHANNEL(float16, float16, int4_t);
INSTANTIATION_QUANT2D_PER_CHANNEL(float16, float, int4_t);
INSTANTIATION_QUANT2D_PER_CHANNEL(float, float, int4_t);
INSTANTIATION_QUANT2D_PER_CHANNEL(bfloat16, float, int4_t);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
@@ -12,28 +12,38 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void
recover_block(int *recover_block_list, // [bsz]
int *recover_len, bool *stop_flags, int *seq_lens_this_time,
const int *ori_seq_lens_encoder, int *seq_lens_encoder,
const int *seq_lens_decoder, int *block_tables, int *free_list,
int *free_list_len, int64_t *input_ids, const int64_t *pre_ids,
const int64_t *step_idx, const int *encoder_block_lens,
const int *used_list_len, const int64_t *next_tokens,
const int64_t *first_token_ids, const int bsz,
const int block_num_per_seq, const int length,
const int pre_id_length);
__attribute__((global)) void recover_block(int *recover_block_list, // [bsz]
int *recover_len,
bool *stop_flags,
int *seq_lens_this_time,
const int *ori_seq_lens_encoder,
int *seq_lens_encoder,
const int *seq_lens_decoder,
int *block_tables,
int *free_list,
int *free_list_len,
int64_t *input_ids,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *encoder_block_lens,
const int *used_list_len,
const int64_t *next_tokens,
const int64_t *first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length);
} // namespace plugin
} // namespace xpu3
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
@@ -41,125 +51,207 @@ namespace api {
namespace plugin {
static int cpu_wrapper(Context *ctx,
int *recover_block_list, // [bsz]
int *recover_len, bool *stop_flags,
int *seq_lens_this_time, const int *ori_seq_lens_encoder,
int *seq_lens_encoder, const int *seq_lens_decoder,
int *block_tables, int *free_list, int *free_list_len,
int64_t *input_ids, const int64_t *pre_ids,
const int64_t *step_idx, const int *encoder_block_lens,
const int *used_list_len, const int64_t *next_tokens,
const int64_t *first_token_ids, const int bsz,
const int block_num_per_seq, const int length,
int *recover_block_list, // [bsz]
int *recover_len,
bool *stop_flags,
int *seq_lens_this_time,
const int *ori_seq_lens_encoder,
int *seq_lens_encoder,
const int *seq_lens_decoder,
int *block_tables,
int *free_list,
int *free_list_len,
int64_t *input_ids,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *encoder_block_lens,
const int *used_list_len,
const int64_t *next_tokens,
const int64_t *first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length) {
for (int bid = 0; bid < recover_len[0]; bid++) {
const int recover_id = recover_block_list[bid];
const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id];
const int step_idx_now = step_idx[recover_id];
const int seq_len = ori_seq_len_encoder + step_idx_now;
const int encoder_block_len = encoder_block_lens[recover_id];
const int decoder_used_len = used_list_len[recover_id];
int *block_table_now = block_tables + recover_id * block_num_per_seq;
int64_t *input_ids_now = input_ids + recover_id * length;
const int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length;
for (int bid = 0; bid < recover_len[0]; bid++) {
const int recover_id = recover_block_list[bid];
const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id];
const int step_idx_now = step_idx[recover_id];
const int seq_len = ori_seq_len_encoder + step_idx_now;
const int encoder_block_len = encoder_block_lens[recover_id];
const int decoder_used_len = used_list_len[recover_id];
int *block_table_now = block_tables + recover_id * block_num_per_seq;
int64_t *input_ids_now = input_ids + recover_id * length;
const int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length;
seq_lens_this_time[recover_id] = seq_len;
seq_lens_encoder[recover_id] = seq_len;
stop_flags[recover_id] = false;
input_ids_now[seq_len - 1] = next_tokens[recover_id]; // next tokens
input_ids_now[0] =
first_token_ids[recover_id]; // set first prompt token
int ori_free_list_len = free_list_len[0];
free_list_len[0] -= decoder_used_len;
seq_lens_this_time[recover_id] = seq_len;
seq_lens_encoder[recover_id] = seq_len;
stop_flags[recover_id] = false;
input_ids_now[seq_len - 1] = next_tokens[recover_id]; // next tokens
input_ids_now[0] = first_token_ids[recover_id]; // set first prompt token
int ori_free_list_len = free_list_len[0];
free_list_len[0] -= decoder_used_len;
// 恢复block table
for (int i = 0; i < decoder_used_len; i++) {
block_table_now[encoder_block_len + i] =
free_list[ori_free_list_len - i - 1];
}
// 恢复input_ids
for (int i = 0; i < step_idx_now - 1; i++) {
input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1];
}
// 恢复block table
for (int i = 0; i < decoder_used_len; i++) {
block_table_now[encoder_block_len + i] =
free_list[ori_free_list_len - i - 1];
}
recover_len[0] = 0;
return api::SUCCESS;
// 恢复input_ids
for (int i = 0; i < step_idx_now - 1; i++) {
input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1];
}
}
recover_len[0] = 0;
return api::SUCCESS;
}
static int xpu3_wrapper(Context *ctx,
int *recover_block_list, // [bsz]
int *recover_len, bool *stop_flags,
int *recover_block_list, // [bsz]
int *recover_len,
bool *stop_flags,
int *seq_lens_this_time,
const int *ori_seq_lens_encoder, int *seq_lens_encoder,
const int *seq_lens_decoder, int *block_tables,
int *free_list, int *free_list_len, int64_t *input_ids,
const int64_t *pre_ids, const int64_t *step_idx,
const int *encoder_block_lens, const int *used_list_len,
const int *ori_seq_lens_encoder,
int *seq_lens_encoder,
const int *seq_lens_decoder,
int *block_tables,
int *free_list,
int *free_list_len,
int64_t *input_ids,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *encoder_block_lens,
const int *used_list_len,
const int64_t *next_tokens,
const int64_t *first_token_ids, const int bsz,
const int block_num_per_seq, const int length,
const int64_t *first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto recover_block_kernel = xpu3::plugin::recover_block;
recover_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
recover_block_list, // [bsz]
recover_len, stop_flags, seq_lens_this_time, ori_seq_lens_encoder,
seq_lens_encoder, seq_lens_decoder, block_tables, free_list,
free_list_len, reinterpret_cast<XPU_INT64 *>(input_ids),
reinterpret_cast<const XPU_INT64 *>(pre_ids),
reinterpret_cast<const XPU_INT64 *>(step_idx), encoder_block_lens,
used_list_len, reinterpret_cast<const XPU_INT64 *>(next_tokens),
reinterpret_cast<const XPU_INT64 *>(first_token_ids), bsz,
block_num_per_seq, length, pre_id_length);
return api::SUCCESS;
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto recover_block_kernel = xpu3::plugin::recover_block;
recover_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
recover_block_list, // [bsz]
recover_len,
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
seq_lens_encoder,
seq_lens_decoder,
block_tables,
free_list,
free_list_len,
reinterpret_cast<XPU_INT64 *>(input_ids),
reinterpret_cast<const XPU_INT64 *>(pre_ids),
reinterpret_cast<const XPU_INT64 *>(step_idx),
encoder_block_lens,
used_list_len,
reinterpret_cast<const XPU_INT64 *>(next_tokens),
reinterpret_cast<const XPU_INT64 *>(first_token_ids),
bsz,
block_num_per_seq,
length,
pre_id_length);
return api::SUCCESS;
}
int recover_block(Context *ctx,
int *recover_block_list, // [bsz]
int *recover_len, bool *stop_flags, int *seq_lens_this_time,
const int *ori_seq_lens_encoder, int *seq_lens_encoder,
const int *seq_lens_decoder, int *block_tables,
int *free_list, int *free_list_len, int64_t *input_ids,
const int64_t *pre_ids, const int64_t *step_idx,
const int *encoder_block_lens, const int *used_list_len,
const int64_t *next_tokens, const int64_t *first_token_ids,
const int bsz, const int block_num_per_seq, const int length,
int *recover_block_list, // [bsz]
int *recover_len,
bool *stop_flags,
int *seq_lens_this_time,
const int *ori_seq_lens_encoder,
int *seq_lens_encoder,
const int *seq_lens_decoder,
int *block_tables,
int *free_list,
int *free_list_len,
int64_t *input_ids,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *encoder_block_lens,
const int *used_list_len,
const int64_t *next_tokens,
const int64_t *first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "recover_block", float);
WRAPPER_DUMP_PARAM6(ctx, recover_block_list, recover_len, stop_flags,
seq_lens_this_time, ori_seq_lens_encoder,
seq_lens_encoder);
WRAPPER_DUMP_PARAM6(ctx, seq_lens_decoder, block_tables, free_list,
free_list_len, input_ids, pre_ids);
WRAPPER_DUMP_PARAM5(ctx, step_idx, encoder_block_lens, used_list_len,
next_tokens, first_token_ids);
WRAPPER_DUMP_PARAM4(ctx, bsz, block_num_per_seq, length, pre_id_length);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(
ctx,
recover_block_list, // [bsz]
recover_len, stop_flags, seq_lens_this_time, ori_seq_lens_encoder,
seq_lens_encoder, seq_lens_decoder, block_tables, free_list,
free_list_len, input_ids, pre_ids, step_idx, encoder_block_lens,
used_list_len, next_tokens, first_token_ids, bsz, block_num_per_seq,
length, pre_id_length);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(
ctx,
recover_block_list, // [bsz]
recover_len, stop_flags, seq_lens_this_time, ori_seq_lens_encoder,
seq_lens_encoder, seq_lens_decoder, block_tables, free_list,
free_list_len, input_ids, pre_ids, step_idx, encoder_block_lens,
used_list_len, next_tokens, first_token_ids, bsz, block_num_per_seq,
length, pre_id_length);
}
WRAPPER_UNIMPLEMENTED(ctx);
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "recover_block", float);
WRAPPER_DUMP_PARAM6(ctx,
recover_block_list,
recover_len,
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
seq_lens_encoder);
WRAPPER_DUMP_PARAM6(ctx,
seq_lens_decoder,
block_tables,
free_list,
free_list_len,
input_ids,
pre_ids);
WRAPPER_DUMP_PARAM5(ctx,
step_idx,
encoder_block_lens,
used_list_len,
next_tokens,
first_token_ids);
WRAPPER_DUMP_PARAM4(ctx, bsz, block_num_per_seq, length, pre_id_length);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
recover_block_list, // [bsz]
recover_len,
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
seq_lens_encoder,
seq_lens_decoder,
block_tables,
free_list,
free_list_len,
input_ids,
pre_ids,
step_idx,
encoder_block_lens,
used_list_len,
next_tokens,
first_token_ids,
bsz,
block_num_per_seq,
length,
pre_id_length);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
recover_block_list, // [bsz]
recover_len,
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
seq_lens_encoder,
seq_lens_decoder,
block_tables,
free_list,
free_list_len,
input_ids,
pre_ids,
step_idx,
encoder_block_lens,
used_list_len,
next_tokens,
first_token_ids,
bsz,
block_num_per_seq,
length,
pre_id_length);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
@@ -12,96 +12,102 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void
recover_decode_task(bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size);
__attribute__((global)) void recover_decode_task(bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size);
} // namespace plugin
} // namespace xpu3
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int xpu3_wrapper(Context *ctx, bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto recover_decode_task = xpu3::plugin::recover_decode_task;
recover_decode_task<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
block_tables,
is_block_step,
bsz,
block_num_per_seq,
block_size);
return api::SUCCESS;
static int xpu3_wrapper(Context *ctx,
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto recover_decode_task = xpu3::plugin::recover_decode_task;
recover_decode_task<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
block_tables,
is_block_step,
bsz,
block_num_per_seq,
block_size);
return api::SUCCESS;
}
int recover_decode_task(Context *ctx, bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "recover_decode_task", int);
WRAPPER_DUMP_PARAM5(ctx, stop_flags, seq_lens_this_time,
seq_lens_encoder, seq_lens_decoder, step_seq_lens_decoder);
WRAPPER_DUMP_PARAM2(ctx, block_tables, is_block_step);
WRAPPER_DUMP_PARAM3(ctx, bsz, block_num_per_seq, block_size);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
assert(false);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx, stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
block_tables,
is_block_step,
bsz,
block_num_per_seq,
block_size);
}
WRAPPER_UNIMPLEMENTED(ctx);
int recover_decode_task(Context *ctx,
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "recover_decode_task", int);
WRAPPER_DUMP_PARAM5(ctx,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder);
WRAPPER_DUMP_PARAM2(ctx, block_tables, is_block_step);
WRAPPER_DUMP_PARAM3(ctx, bsz, block_num_per_seq, block_size);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
assert(false);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
block_tables,
is_block_step,
bsz,
block_num_per_seq,
block_size);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
@@ -18,18 +18,17 @@
namespace xpu3 {
namespace plugin {
template <typename T>
__attribute__((global)) void text_image_gather_scatter(
T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter);
__attribute__((global)) void text_image_gather_scatter(T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter);
} // namespace plugin
} // namespace xpu3
@@ -41,18 +40,17 @@ namespace plugin {
template <typename T>
static int cpu_wrapper(
Context* ctx,
T* input, // shape [token_num, hidden_size]
T* text_input, // shape [text_token_num, hidden_size]
T* image_input, // shape [image_token_num, hidden_size]
int* token_type_ids,// shape [token_num], 0 for text, 1 for image
int* text_index, // shape [token_num], mapping from input to text_input
int* image_index, // shape [token_num], mapping from input to image_input
T* input, // shape [token_num, hidden_size]
T* text_input, // shape [text_token_num, hidden_size]
T* image_input, // shape [image_token_num, hidden_size]
int* token_type_ids, // shape [token_num], 0 for text, 1 for image
int* text_index, // shape [token_num], mapping from input to text_input
int* image_index, // shape [token_num], mapping from input to image_input
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter) {
if (is_scatter) {
// Scatter mode: input -> text_input/image_input
for (int64_t i = 0; i < token_num; i++) {
@@ -106,36 +104,42 @@ static int cpu_wrapper(
}
template <typename T>
static int xpu3_wrapper(
Context* ctx,
T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter) {
xpu3::plugin::text_image_gather_scatter<T> <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
input, text_input, image_input, token_type_ids, text_index, image_index,
token_num, text_token_num, image_token_num, hidden_size, is_scatter
);
static int xpu3_wrapper(Context* ctx,
T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter) {
xpu3::plugin::text_image_gather_scatter<T>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(input,
text_input,
image_input,
token_type_ids,
text_index,
image_index,
token_num,
text_token_num,
image_token_num,
hidden_size,
is_scatter);
return api::SUCCESS;
}
template <typename T>
int text_image_gather_scatter(
Context* ctx,
T* input, // shape [token_num, hidden_size]
T* text_input, // shape [text_token_num, hidden_size]
T* image_input, // shape [image_token_num, hidden_size]
int* token_type_ids,// shape [token_num], 0 for text, 1 for image
int* text_index, // shape [token_num], mapping from input to text_input
int* image_index, // shape [token_num], mapping from input to image_input
T* input, // shape [token_num, hidden_size]
T* text_input, // shape [text_token_num, hidden_size]
T* image_input, // shape [image_token_num, hidden_size]
int* token_type_ids, // shape [token_num], 0 for text, 1 for image
int* text_index, // shape [token_num], mapping from input to text_input
int* image_index, // shape [token_num], mapping from input to image_input
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
@@ -143,14 +147,23 @@ int text_image_gather_scatter(
bool is_scatter) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_gather_scatter", T);
WRAPPER_DUMP_PARAM6(ctx, input, text_input, image_input, token_type_ids, text_index, image_index);
WRAPPER_DUMP_PARAM5(ctx, token_num, text_token_num, image_token_num, hidden_size, is_scatter);
WRAPPER_DUMP_PARAM6(ctx,
input,
text_input,
image_input,
token_type_ids,
text_index,
image_index);
WRAPPER_DUMP_PARAM5(
ctx, token_num, text_token_num, image_token_num, hidden_size, is_scatter);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, T, token_num * hidden_size, input);
if (text_token_num != 0) { // avoiding text_input tensor with shape [0, hidden_size]
if (text_token_num !=
0) { // avoiding text_input tensor with shape [0, hidden_size]
WRAPPER_CHECK_PTR(ctx, T, text_token_num * hidden_size, text_input);
}
if (image_token_num != 0) { // avoiding image_input tensor with shape [0, hidden_size]
if (image_token_num !=
0) { // avoiding image_input tensor with shape [0, hidden_size]
WRAPPER_CHECK_PTR(ctx, T, image_token_num * hidden_size, image_input);
}
WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids);
@@ -159,23 +172,48 @@ int text_image_gather_scatter(
WRAPPER_ASSERT_EQ(ctx, token_num, text_token_num + image_token_num);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T>(
ctx, input, text_input, image_input, token_type_ids, text_index, image_index,
token_num, text_token_num, image_token_num, hidden_size, is_scatter
);
return cpu_wrapper<T>(ctx,
input,
text_input,
image_input,
token_type_ids,
text_index,
image_index,
token_num,
text_token_num,
image_token_num,
hidden_size,
is_scatter);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper<T>(
ctx, input, text_input, image_input, token_type_ids, text_index, image_index,
token_num, text_token_num, image_token_num, hidden_size, is_scatter
);
return xpu3_wrapper<T>(ctx,
input,
text_input,
image_input,
token_type_ids,
text_index,
image_index,
token_num,
text_token_num,
image_token_num,
hidden_size,
is_scatter);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
template int text_image_gather_scatter(
Context*, bfloat16*, bfloat16*, bfloat16*, int*, int*, int*, const int64_t, const int64_t, const int64_t, const int64_t, bool);
template int text_image_gather_scatter(Context*,
bfloat16*,
bfloat16*,
bfloat16*,
int*,
int*,
int*,
const int64_t,
const int64_t,
const int64_t,
const int64_t,
bool);
} // namespace plugin
} // namespace api
} // namespace xpu
@@ -17,10 +17,11 @@
namespace xpu3 {
namespace plugin {
__attribute__((global)) void text_image_index_out_kernel(const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num);
__attribute__((global)) void text_image_index_out_kernel(
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num);
} // namespace plugin
} // namespace xpu3
@@ -30,69 +31,54 @@ namespace api {
namespace plugin {
static int cpu_wrapper(Context* ctx,
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num) {
int text_count = 0;
int text_count = 0;
int image_count = 0;
for (int64_t i = 0; i < token_num; ++i) {
if (token_type_ids[i] == 0) {
text_index[i] = text_count;
++text_count;
} else {
image_index[i] = image_count;
++image_count;
}
if (token_type_ids[i] == 0) {
text_index[i] = text_count;
++text_count;
} else {
image_index[i] = image_count;
++image_count;
}
}
return api::SUCCESS;
}
static int xpu3_wrapper(Context* ctx,
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num) {
xpu3::plugin::text_image_index_out_kernel<<<1, 1, ctx->xpu_stream>>>(
token_type_ids,
text_index,
image_index,
token_num);
token_type_ids, text_index, image_index, token_num);
return api::SUCCESS;
}
int text_image_index_out(Context* ctx,
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_index_out", int);
WRAPPER_DUMP_PARAM4(
ctx, token_type_ids, text_index, image_index, token_num);
WRAPPER_DUMP_PARAM4(ctx, token_type_ids, text_index, image_index, token_num);
WRAPPER_DUMP(ctx);
WRAPPER_ASSERT_GT(ctx, token_num, 0);
WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids);
WRAPPER_CHECK_PTR(ctx, int, token_num, text_index);
WRAPPER_CHECK_PTR(ctx, int, token_num, image_index);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
token_type_ids,
text_index,
image_index,
token_num);
return cpu_wrapper(ctx, token_type_ids, text_index, image_index, token_num);
} else if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
token_type_ids,
text_index,
image_index,
token_num);
return xpu3_wrapper(
ctx, token_type_ids, text_index, image_index, token_num);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
@@ -12,108 +12,162 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void
update_inputs(bool *not_need_stop, int *seq_lens_this_time,
int *seq_lens_encoder, int *seq_lens_decoder, int64_t *input_ids,
const int64_t *stop_nums, const bool *stop_flags,
const bool *is_block_step, const int64_t *next_tokens,
const int bsz, const int max_bsz, const int input_ids_stride);
__attribute__((global)) void update_inputs(bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int64_t *input_ids,
const int64_t *stop_nums,
const bool *stop_flags,
const bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride);
} // namespace plugin
} // namespace xpu3
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context *ctx, bool *not_need_stop,
int *seq_lens_this_time, int *seq_lens_encoder,
int *seq_lens_decoder, int64_t *input_ids,
const int64_t *stop_nums, const bool *stop_flags,
const bool *is_block_step, const int64_t *next_tokens,
const int bsz, const int max_bsz,
static int cpu_wrapper(Context *ctx,
bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int64_t *input_ids,
const int64_t *stop_nums,
const bool *stop_flags,
const bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride) {
std::vector<int64_t> stop_flag_now_int(max_bsz, 1);
for (int i = 0; i < bsz; i++) {
bool stop_flags_now = stop_flags[i];
stop_flag_now_int[i] = is_block_step[i] ? 0 : stop_flags_now;
const int seq_len_encoder = seq_lens_encoder[i];
const int seq_len_decoder = seq_lens_decoder[i];
std::vector<int64_t> stop_flag_now_int(max_bsz, 1);
for (int i = 0; i < bsz; i++) {
bool stop_flags_now = stop_flags[i];
stop_flag_now_int[i] = is_block_step[i] ? 0 : stop_flags_now;
const int seq_len_encoder = seq_lens_encoder[i];
const int seq_len_decoder = seq_lens_decoder[i];
seq_lens_decoder[i] =
stop_flags[i] ? 0
: (seq_len_decoder == 0 ? seq_len_encoder
: seq_len_decoder + 1);
seq_lens_decoder[i] =
stop_flags[i]
? 0
: (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1);
seq_lens_this_time[i] = stop_flags[i] ? 0 : 1;
seq_lens_encoder[i] = 0;
int64_t *input_ids_now = input_ids + i * input_ids_stride;
input_ids_now[0] = next_tokens[i];
}
int64_t stop_sum = 0;
for (size_t i = 0; i < stop_flag_now_int.size(); i++) {
stop_sum += stop_flag_now_int[i];
}
not_need_stop[0] = stop_sum < stop_nums[0];
return api::SUCCESS;
seq_lens_this_time[i] = stop_flags[i] ? 0 : 1;
seq_lens_encoder[i] = 0;
int64_t *input_ids_now = input_ids + i * input_ids_stride;
input_ids_now[0] = next_tokens[i];
}
int64_t stop_sum = 0;
for (size_t i = 0; i < stop_flag_now_int.size(); i++) {
stop_sum += stop_flag_now_int[i];
}
not_need_stop[0] = stop_sum < stop_nums[0];
return api::SUCCESS;
}
static int xpu3_wrapper(Context *ctx, bool *not_need_stop,
int *seq_lens_this_time, int *seq_lens_encoder,
int *seq_lens_decoder, int64_t *input_ids,
const int64_t *stop_nums, const bool *stop_flags,
const bool *is_block_step, const int64_t *next_tokens,
const int bsz, const int max_bsz,
static int xpu3_wrapper(Context *ctx,
bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int64_t *input_ids,
const int64_t *stop_nums,
const bool *stop_flags,
const bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto update_inputs = xpu3::plugin::update_inputs;
update_inputs<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
not_need_stop, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder,
reinterpret_cast<XPU_INT64 *>(input_ids),
reinterpret_cast<const XPU_INT64 *>(stop_nums), stop_flags,
is_block_step, reinterpret_cast<const XPU_INT64 *>(next_tokens), bsz,
max_bsz, input_ids_stride);
return api::SUCCESS;
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto update_inputs = xpu3::plugin::update_inputs;
update_inputs<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<XPU_INT64 *>(input_ids),
reinterpret_cast<const XPU_INT64 *>(stop_nums),
stop_flags,
is_block_step,
reinterpret_cast<const XPU_INT64 *>(next_tokens),
bsz,
max_bsz,
input_ids_stride);
return api::SUCCESS;
}
int update_inputs(Context *ctx, bool *not_need_stop, int *seq_lens_this_time,
int *seq_lens_encoder, int *seq_lens_decoder,
int64_t *input_ids, const int64_t *stop_nums,
const bool *stop_flags, const bool *is_block_step,
const int64_t *next_tokens, const int bsz, const int max_bsz,
int update_inputs(Context *ctx,
bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int64_t *input_ids,
const int64_t *stop_nums,
const bool *stop_flags,
const bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "update_inputs", int);
WRAPPER_DUMP_PARAM5(ctx, not_need_stop, seq_lens_this_time,
seq_lens_encoder, seq_lens_decoder, input_ids);
WRAPPER_DUMP_PARAM4(ctx, stop_nums, stop_flags, is_block_step, next_tokens);
WRAPPER_DUMP_PARAM3(ctx, bsz, max_bsz, input_ids_stride);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx, not_need_stop, seq_lens_this_time,
seq_lens_encoder, seq_lens_decoder, input_ids,
stop_nums, stop_flags, is_block_step, next_tokens,
bsz, max_bsz, input_ids_stride);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx, not_need_stop, seq_lens_this_time,
seq_lens_encoder, seq_lens_decoder, input_ids,
stop_nums, stop_flags, is_block_step, next_tokens,
bsz, max_bsz, input_ids_stride);
}
WRAPPER_UNIMPLEMENTED(ctx);
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "update_inputs", int);
WRAPPER_DUMP_PARAM5(ctx,
not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
input_ids);
WRAPPER_DUMP_PARAM4(ctx, stop_nums, stop_flags, is_block_step, next_tokens);
WRAPPER_DUMP_PARAM3(ctx, bsz, max_bsz, input_ids_stride);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
input_ids,
stop_nums,
stop_flags,
is_block_step,
next_tokens,
bsz,
max_bsz,
input_ids_stride);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
input_ids,
stop_nums,
stop_flags,
is_block_step,
next_tokens,
bsz,
max_bsz,
input_ids_stride);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
@@ -12,138 +12,146 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void
update_inputs_v1(bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size);
__attribute__((global)) void update_inputs_v1(bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size);
} // namespace plugin
} // namespace xpu3
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int xpu3_wrapper(Context *ctx, bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto update_inputs_v1 = xpu3::plugin::update_inputs_v1;
// kernel 内要做 reduce,只能用 1 个 cluster
update_inputs_v1<<<1, 64, ctx->xpu_stream>>>(
not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
reinterpret_cast<XPU_INT64 *>(prompt_lens),
reinterpret_cast<XPU_INT64 *>(topk_ids),
reinterpret_cast<XPU_INT64 *>(input_ids),
block_tables,
reinterpret_cast<const XPU_INT64 *>(stop_nums),
stop_flags,
is_block_step,
reinterpret_cast<const XPU_INT64 *>(next_tokens),
bsz,
max_bsz,
input_ids_stride,
block_num_per_seq,
block_size);
return api::SUCCESS;
static int xpu3_wrapper(Context *ctx,
bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto update_inputs_v1 = xpu3::plugin::update_inputs_v1;
// kernel 内要做 reduce,只能用 1 个 cluster
update_inputs_v1<<<1, 64, ctx->xpu_stream>>>(
not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
reinterpret_cast<XPU_INT64 *>(prompt_lens),
reinterpret_cast<XPU_INT64 *>(topk_ids),
reinterpret_cast<XPU_INT64 *>(input_ids),
block_tables,
reinterpret_cast<const XPU_INT64 *>(stop_nums),
stop_flags,
is_block_step,
reinterpret_cast<const XPU_INT64 *>(next_tokens),
bsz,
max_bsz,
input_ids_stride,
block_num_per_seq,
block_size);
return api::SUCCESS;
}
int update_inputs_v1(Context *ctx, bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "update_inputs_v1", int);
WRAPPER_DUMP_PARAM5(ctx, not_need_stop, seq_lens_this_time,
seq_lens_encoder, seq_lens_decoder, step_seq_lens_decoder);
WRAPPER_DUMP_PARAM5(ctx, prompt_lens, topk_ids, input_ids, block_tables, stop_nums);
WRAPPER_DUMP_PARAM3(ctx, stop_flags, is_block_step, next_tokens);
WRAPPER_DUMP_PARAM5(ctx, bsz, max_bsz, input_ids_stride, block_num_per_seq, block_size);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
assert(false);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx, not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
prompt_lens,
topk_ids,
input_ids,
block_tables,
stop_nums,
stop_flags,
is_block_step,
next_tokens,
bsz,
max_bsz,
input_ids_stride,
block_num_per_seq,
block_size);
}
WRAPPER_UNIMPLEMENTED(ctx);
int update_inputs_v1(Context *ctx,
bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "update_inputs_v1", int);
WRAPPER_DUMP_PARAM5(ctx,
not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder);
WRAPPER_DUMP_PARAM5(
ctx, prompt_lens, topk_ids, input_ids, block_tables, stop_nums);
WRAPPER_DUMP_PARAM3(ctx, stop_flags, is_block_step, next_tokens);
WRAPPER_DUMP_PARAM5(
ctx, bsz, max_bsz, input_ids_stride, block_num_per_seq, block_size);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
assert(false);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
prompt_lens,
topk_ids,
input_ids,
block_tables,
stop_nums,
stop_flags,
is_block_step,
next_tokens,
bsz,
max_bsz,
input_ids_stride,
block_num_per_seq,
block_size);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
@@ -22,32 +22,25 @@
#pragma once
#include <rdma/rdma_cma.h>
#include <rdma/rdma_verbs.h>
#include <sys/epoll.h>
#include <atomic>
#include <string>
#include <vector>
#include <netinet/tcp.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sstream>
#include <netdb.h>
#include <sstream>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <cstring>
#include <netdb.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <net/if.h>
#include <sys/ioctl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <rdma/rdma_cma.h>
#include <rdma/rdma_verbs.h>
#include <sys/epoll.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <unistd.h>
#include <atomic>
#include <cstring>
#include <memory>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "kvcache_rdma.h"
#include "util.h"
@@ -60,115 +53,115 @@
/// @brief IB device information structure
struct IbDeviceInfo {
int device;
uint64_t guid;
enum ibv_mtu mtu;
uint64_t busid;
uint8_t port;
uint8_t link;
uint8_t active_mtu;
int speed;
ibv_context* context;
char devName[64];
int realPort;
int maxQp;
int device;
uint64_t guid;
enum ibv_mtu mtu;
uint64_t busid;
uint8_t port;
uint8_t link;
uint8_t active_mtu;
int speed;
ibv_context* context;
char devName[64];
int realPort;
int maxQp;
};
/// @brief Queue Pair information for RDMA
struct QpInfo {
uint32_t lid;
uint32_t qpn;
uint32_t psn;
union ibv_gid gid;
enum ibv_mtu mtu;
uint32_t lid;
uint32_t qpn;
uint32_t psn;
union ibv_gid gid;
enum ibv_mtu mtu;
/// @brief Serialize QP info to buffer
void serialize(char* buffer) const {
uint32_t* intBuffer = reinterpret_cast<uint32_t*>(buffer);
intBuffer[0] = htonl(lid);
intBuffer[1] = htonl(qpn);
intBuffer[2] = htonl(psn);
memcpy(buffer + 12, gid.raw, sizeof(gid.raw));
intBuffer[7] = htonl(static_cast<uint32_t>(mtu));
}
/// @brief Serialize QP info to buffer
void serialize(char* buffer) const {
uint32_t* intBuffer = reinterpret_cast<uint32_t*>(buffer);
intBuffer[0] = htonl(lid);
intBuffer[1] = htonl(qpn);
intBuffer[2] = htonl(psn);
memcpy(buffer + 12, gid.raw, sizeof(gid.raw));
intBuffer[7] = htonl(static_cast<uint32_t>(mtu));
}
/// @brief Deserialize QP info from buffer
void deserialize(const char* buffer) {
const uint32_t* intBuffer = reinterpret_cast<const uint32_t*>(buffer);
lid = ntohl(intBuffer[0]);
qpn = ntohl(intBuffer[1]);
psn = ntohl(intBuffer[2]);
memcpy(gid.raw, buffer + 12, sizeof(gid.raw));
mtu = static_cast<ibv_mtu>(ntohl(intBuffer[7]));
}
/// @brief Deserialize QP info from buffer
void deserialize(const char* buffer) {
const uint32_t* intBuffer = reinterpret_cast<const uint32_t*>(buffer);
lid = ntohl(intBuffer[0]);
qpn = ntohl(intBuffer[1]);
psn = ntohl(intBuffer[2]);
memcpy(gid.raw, buffer + 12, sizeof(gid.raw));
mtu = static_cast<ibv_mtu>(ntohl(intBuffer[7]));
}
static const size_t size = 12 + sizeof(gid.raw) + 4;
static const size_t size = 12 + sizeof(gid.raw) + 4;
};
/// @brief RDMA connection context
struct Connection {
std::atomic<int> connected;
std::atomic<int> connected;
// Memory regions
struct ibv_mr *recv_mr;
struct ibv_mr *send_mr;
// Memory regions
struct ibv_mr* recv_mr;
struct ibv_mr* send_mr;
// Cache pointers
std::vector<std::vector<void*>> local_cache_key_ptr_per_layer;
std::vector<std::vector<void*>> local_cache_value_ptr_per_layer;
// Cache pointers
std::vector<std::vector<void*>> local_cache_key_ptr_per_layer;
std::vector<std::vector<void*>> local_cache_value_ptr_per_layer;
// Memory region lists
std::vector<ibv_mr*> write_cache_key_server_mr_list;
std::vector<ibv_mr*> write_cache_value_server_mr_list;
std::vector<std::vector<ibv_mr*>> write_mr_key_list;
std::vector<std::vector<ibv_mr*>> write_mr_value_list;
// Memory region lists
std::vector<ibv_mr*> write_cache_key_server_mr_list;
std::vector<ibv_mr*> write_cache_value_server_mr_list;
std::vector<std::vector<ibv_mr*>> write_mr_key_list;
std::vector<std::vector<ibv_mr*>> write_mr_value_list;
// Remote access information
std::vector<void*> write_cache_key_remote_ptr_list;
std::vector<uint32_t> write_cache_key_remote_rkey_list;
std::vector<void*> write_cache_value_remote_ptr_list;
std::vector<uint32_t> write_cache_value_remote_rkey_list;
// Remote access information
std::vector<void*> write_cache_key_remote_ptr_list;
std::vector<uint32_t> write_cache_key_remote_rkey_list;
std::vector<void*> write_cache_value_remote_ptr_list;
std::vector<uint32_t> write_cache_value_remote_rkey_list;
// Received remote memory information
std::vector<void*> receive_write_cache_key_remote_ptr_list;
std::vector<uint32_t> receive_write_cache_key_remote_rkey_list;
std::vector<void*> receive_write_cache_value_remote_ptr_list;
std::vector<uint32_t> receive_write_cache_value_remote_rkey_list;
// Received remote memory information
std::vector<void*> receive_write_cache_key_remote_ptr_list;
std::vector<uint32_t> receive_write_cache_key_remote_rkey_list;
std::vector<void*> receive_write_cache_value_remote_ptr_list;
std::vector<uint32_t> receive_write_cache_value_remote_rkey_list;
std::vector<void *> send_write_cache_key_remote_ptr_list;
std::vector<uint32_t> send_write_cache_key_remote_rkey_list;
std::vector<void *> send_write_cache_value_remote_ptr_list;
std::vector<uint32_t> send_write_cache_value_remote_rkey_list;
std::vector<void*> send_write_cache_key_remote_ptr_list;
std::vector<uint32_t> send_write_cache_key_remote_rkey_list;
std::vector<void*> send_write_cache_value_remote_ptr_list;
std::vector<uint32_t> send_write_cache_value_remote_rkey_list;
// For rdma read operations
std::vector<void*> read_bufs;
std::vector<ibv_mr*> read_mrs;
// For rdma read operations
std::vector<void*> read_bufs;
std::vector<ibv_mr*> read_mrs;
// Work completion tracking
int wc_count;
int wc_target_count;
// Work completion tracking
int wc_count;
int wc_target_count;
// Configuration
int layer_number;
int block_number;
int block_byte_size;
std::string url;
// Configuration
int layer_number;
int block_number;
int block_byte_size;
std::string url;
Connection() = default;
~Connection();
Connection() = default;
~Connection();
};
/// @brief RDMA context structure
struct RdmaContext {
int sock_fd;
struct ibv_context* context;
struct ibv_comp_channel* channel;
struct ibv_pd* pd;
struct ibv_mr* mr;
struct ibv_cq* cq;
struct ibv_qp* qp;
struct ibv_port_attr portinfo;
struct Connection conn;
int sock_fd;
struct ibv_context* context;
struct ibv_comp_channel* channel;
struct ibv_pd* pd;
struct ibv_mr* mr;
struct ibv_cq* cq;
struct ibv_qp* qp;
struct ibv_port_attr portinfo;
struct Connection conn;
};
// Global variables
@@ -176,36 +169,46 @@ extern std::vector<IbDeviceInfo> g_ib_all_devs;
static int g_kvcache_ib_dev_nums = -1;
// Connection management functions
bool client_exchange_destinations(
struct RdmaContext* ctx,
int ib_port,
unsigned int port,
int gidx,
const std::string& dst_ip);
bool client_exchange_destinations(struct RdmaContext* ctx,
int ib_port,
unsigned int port,
int gidx,
const std::string& dst_ip);
int server_exchange_qp_info(int connfd, QpInfo* local_dest, QpInfo* rem_dest);
struct RdmaContext* create_qp(struct IbDeviceInfo* ib_dev, struct ibv_pd** g_pd);
struct RdmaContext* create_qp(struct IbDeviceInfo* ib_dev,
struct ibv_pd** g_pd);
bool clear_qp_info(struct RdmaContext* ctx);
// QP modification functions
QpStatus modify_qp_to_rts(struct RdmaContext* ctx, int port, int my_psn,
struct QpInfo* dest, int sgid_id);
bool poll_cq_with_timeout(struct RdmaContext* ctx, int timeout_seconds, int cqe_count);
QpStatus modify_qp_to_rts(struct RdmaContext* ctx,
int port,
int my_psn,
struct QpInfo* dest,
int sgid_id);
bool poll_cq_with_timeout(struct RdmaContext* ctx,
int timeout_seconds,
int cqe_count);
// Utility functions
int get_port_info(struct ibv_context* Context, int port,
struct ibv_port_attr* attr);
int get_port_info(struct ibv_context* Context,
int port,
struct ibv_port_attr* attr);
int parse_port_ib_info();
// Memory region exchange
bool client_exchange_mr(struct RdmaContext* ctx);
bool server_exchange_mr(struct RdmaContext* ctx);
bool server_send_memory_region(struct RdmaContext *ctx, void *local_mr, int byte_num);
bool client_receive_memory_region(struct RdmaContext *ctx, void *remote_mr, int byte_num);
bool server_send_memory_region(struct RdmaContext* ctx,
void* local_mr,
int byte_num);
bool client_receive_memory_region(struct RdmaContext* ctx,
void* remote_mr,
int byte_num);
// Network setup
int setup_listening_socket(int port);
int configure_epoll(int sockfd);
std::vector<std::string> get_net_ifname();
#endif // FASTDEPLOY_KVCACHE_CONNECTION_H
#endif // FASTDEPLOY_KVCACHE_CONNECTION_H
@@ -4,77 +4,88 @@
#pragma once
#include <rdma/rdma_cma.h>
#include <vector>
#include <string>
#include <map>
#include <mutex>
#include "util.h" // Contains constant definitions
#include <string>
#include <vector>
#include "kvcache_connection.h"
#include "log.h"
#include "util.h" // Contains constant definitions
/**
* @brief RDMA communication handler for key-value cache
*/
class RDMACommunicator {
public:
// Construction/Destruction
RDMACommunicator(std::string &role, int gpu_idx, std::string &port,
std::vector<int64_t> local_key_cache,
std::vector<int64_t> local_value_cache,
int block_number, int block_bytes);
~RDMACommunicator();
public:
// Construction/Destruction
RDMACommunicator(std::string& role,
int gpu_idx,
std::string& port,
std::vector<int64_t> local_key_cache,
std::vector<int64_t> local_value_cache,
int block_number,
int block_bytes);
~RDMACommunicator();
// Connection management
int connect(const std::string &dst_ip, const std::string &dst_port);
bool is_connected(const std::string &dst_ip, const std::string &dst_port);
// Connection management
int connect(const std::string& dst_ip, const std::string& dst_port);
bool is_connected(const std::string& dst_ip, const std::string& dst_port);
// Core functionality
int write_cache(const std::string &ip, const std::string &port,
const std::vector<int64_t>& local_block_ids,
const std::vector<int64_t>& remote_block_ids,
int32_t layer_idx);
// Core functionality
int write_cache(const std::string& ip,
const std::string& port,
const std::vector<int64_t>& local_block_ids,
const std::vector<int64_t>& remote_block_ids,
int32_t layer_idx);
// Server Init
int init_server();
// Server Init
int init_server();
// get socket nic ip
std::string fetch_local_ip();
// get socket nic ip
std::string fetch_local_ip();
private:
// Server Core functions
int start_server(int sport, int sgid_idx, int gpu_index);
private:
// Server Core functions
int start_server(int sport, int sgid_idx, int gpu_index);
// Internal implementation methods
void resize_vectors();
void assign_pointers();
void validate_addr();
bool client_mr_register_per_layer(struct RdmaContext *ctx);
bool server_mr_register_per_layer(struct RdmaContext *ctx);
struct ibv_mr* register_memory_region(ibv_pd* pd, void* addr, size_t size,
const std::string& desc, uint32_t access_flags);
bool deregister_memory_regions(struct RdmaContext* ctx);
// Internal implementation methods
void resize_vectors();
void assign_pointers();
void validate_addr();
bool client_mr_register_per_layer(struct RdmaContext* ctx);
bool server_mr_register_per_layer(struct RdmaContext* ctx);
struct ibv_mr* register_memory_region(ibv_pd* pd,
void* addr,
size_t size,
const std::string& desc,
uint32_t access_flags);
bool deregister_memory_regions(struct RdmaContext* ctx);
bool post_block_send(struct RdmaContext* ctx, int layer_idx,
const std::vector<int64_t>& local_block_ids,
bool is_key, std::vector<uint64_t>& remote_addr,
uint32_t rkey, const std::string &ip,
const std::string &port);
bool post_block_send(struct RdmaContext* ctx,
int layer_idx,
const std::vector<int64_t>& local_block_ids,
bool is_key,
std::vector<uint64_t>& remote_addr,
uint32_t rkey,
const std::string& ip,
const std::string& port);
bool execute_rdma_writes(struct RdmaContext* ctx, int layer_idx,
bool execute_rdma_writes(struct RdmaContext* ctx,
int layer_idx,
const std::vector<int64_t>& local_block_ids,
bool is_key, std::vector<uint64_t>& remote_addr,
bool is_key,
std::vector<uint64_t>& remote_addr,
uint32_t rkey);
void prepare_write_requests(struct ibv_sge* sge_list,
struct ibv_send_wr* send_wr_list,
int layer_idx,
const std::vector<int64_t>& local_block_ids,
bool is_key,
std::vector<uint64_t>& remote_addr,
uint32_t rkey);
void prepare_write_requests(struct ibv_sge* sge_list,
struct ibv_send_wr* send_wr_list,
int layer_idx,
const std::vector<int64_t>& local_block_ids,
bool is_key,
std::vector<uint64_t>& remote_addr,
uint32_t rkey);
bool execute_read_verification(struct RdmaContext* ctx,
bool execute_read_verification(struct RdmaContext* ctx,
size_t block_idx,
uint64_t remote_addr,
uint32_t rkey,
@@ -82,46 +93,56 @@ private:
const std::string& ip,
const std::string& port);
bool post_send_with_retry(struct RdmaContext* ctx,
bool post_send_with_retry(struct RdmaContext* ctx,
struct ibv_send_wr* wr_list,
size_t inflight_wr,
bool need_poll);
// Connection management
int client_listener();
void close_server_connection(int fd, struct RdmaContext* ctx, int epollfd,
std::map<int, struct RdmaContext*>& connectionContexts);
void close_client_connection(int fd, struct RdmaContext* ctx, int epollfd);
// Connection management
int client_listener();
void close_server_connection(
int fd,
struct RdmaContext* ctx,
int epollfd,
std::map<int, struct RdmaContext*>& connectionContexts);
void close_client_connection(int fd, struct RdmaContext* ctx, int epollfd);
void remove_conn(const std::string& url);
struct RdmaContext *get_conn(const std::string &ip,
const std::string &port);
void remove_conn(const std::string& url);
struct RdmaContext* get_conn(const std::string& ip, const std::string& port);
// Member variables
std::string splitwise_role; // Role in distributed system ("decode" or other)
int gpu_idx; // GPU device index
std::string port; // Communication port
std::vector<int64_t> local_cache_key_ptr_layer_head_; // Key cache pointers
std::vector<int64_t> local_cache_value_ptr_layer_head_; // Value cache pointers
int block_number; // Number of blocks
int block_size_byte; // Size of each block in bytes
int layer_number; // Number of layers
// Member variables
std::string splitwise_role; // Role in distributed system ("decode" or other)
int gpu_idx; // GPU device index
std::string port; // Communication port
std::vector<int64_t> local_cache_key_ptr_layer_head_; // Key cache pointers
std::vector<int64_t>
local_cache_value_ptr_layer_head_; // Value cache pointers
int block_number; // Number of blocks
int block_size_byte; // Size of each block in bytes
int layer_number; // Number of layers
std::vector<std::vector<void*>> local_cache_key_ptr_per_layer; // Per-layer key pointers
std::vector<std::vector<void*>> local_cache_value_ptr_per_layer; // Per-layer value pointers
std::vector<std::vector<void*>>
local_cache_key_ptr_per_layer; // Per-layer key pointers
std::vector<std::vector<void*>>
local_cache_value_ptr_per_layer; // Per-layer value pointers
std::vector<struct ibv_mr*> write_mr_key_list; // Memory regions for key writes
std::vector<struct ibv_mr*> write_mr_value_list; // Memory regions for value writes
std::vector<struct ibv_mr*> write_cache_key_server_mr_list; // Server-side key memory regions
std::vector<struct ibv_mr*> write_cache_value_server_mr_list; // Server-side value memory regions
std::vector<struct ibv_mr*>
write_mr_key_list; // Memory regions for key writes
std::vector<struct ibv_mr*>
write_mr_value_list; // Memory regions for value writes
std::vector<struct ibv_mr*>
write_cache_key_server_mr_list; // Server-side key memory regions
std::vector<struct ibv_mr*>
write_cache_value_server_mr_list; // Server-side value memory regions
std::vector<std::string> main_ip_list; // List of local IP addresses
std::map<std::string, struct RdmaContext*> conn_map; // Active connections map
std::mutex mutex_; // Thread synchronization mutex
int rdma_event_channel_epoll_fd; // Epoll file descriptor
struct ibv_pd *g_pd = NULL; // fd
int RDMACommunicator_status; // Communicator status flag
bool start_client_listener = false; // Client listener flag
std::vector<std::string> main_ip_list; // List of local IP addresses
std::map<std::string, struct RdmaContext*>
conn_map; // Active connections map
std::mutex mutex_; // Thread synchronization mutex
int rdma_event_channel_epoll_fd; // Epoll file descriptor
struct ibv_pd* g_pd = NULL; // fd
int RDMACommunicator_status; // Communicator status flag
bool start_client_listener = false; // Client listener flag
};
#endif // KVCACHE_RDMA_H
#endif // KVCACHE_RDMA_H
@@ -19,99 +19,130 @@
* limitations under the License.
*/
#include <pthread.h>
#include <stdio.h>
#include <string.h>
#include <time.h>
#include <sys/time.h>
#include <unistd.h> //for gethostname
#include <sys/syscall.h>
#include <pthread.h>
#include <string>
#include <ctime>
#include <sys/time.h>
#include <time.h>
#include <unistd.h> //for gethostname
#include <chrono>
#include <ctime>
#include <string>
#define KV_IS_DEBUG_ENABLED (std::getenv("KVCACHE_DEBUG"))
#define FILE_NAME(x) (strrchr(x,'/') ? strrchr(x,'/')+1 : x)
#define FILE_NAME(x) (strrchr(x, '/') ? strrchr(x, '/') + 1 : x)
static thread_local char __attribute__((__unused__)) str[64];
// for log levels (C++ enum class style in C)
typedef enum {
KV_LOG_LEVEL_INFO = 0,
KV_LOG_LEVEL_DEBUG = 1,
KV_LOG_LEVEL_WARN = 2,
KV_LOG_LEVEL_ERROR = 3
KV_LOG_LEVEL_INFO = 0,
KV_LOG_LEVEL_DEBUG = 1,
KV_LOG_LEVEL_WARN = 2,
KV_LOG_LEVEL_ERROR = 3
} KVLogLevel;
void debug_log(KVLogLevel level, bool enable_to_terminal, const char *filefunc,
int line, const char *fmt, ...) __attribute__ ((format (printf, 5, 6)));
void debug_log(KVLogLevel level,
bool enable_to_terminal,
const char *filefunc,
int line,
const char *fmt,
...) __attribute__((format(printf, 5, 6)));
/**
* @brief Unified logging macro to reduce duplication and improve maintainability.
* @brief Unified logging macro to reduce duplication and improve
* maintainability.
*
* @param level Log level (e.g., INFO, DEBUG, WARN, ERR).
* @param to_terminal If true, the log will be printed to terminal.
* @param ... Format string and arguments (like printf).
*/
#define KV_LOG(level, to_terminal, ...) \
debug_log(level, to_terminal, FILE_NAME(__FILE__), __LINE__, __VA_ARGS__)
debug_log(level, to_terminal, FILE_NAME(__FILE__), __LINE__, __VA_ARGS__)
// Public logging macros with terminal output
#define WARN(...) KV_LOG(KV_LOG_LEVEL_WARN, true, __VA_ARGS__)
#define ERR(...) KV_LOG(KV_LOG_LEVEL_ERROR, true, __VA_ARGS__)
#define DEBUG(...) KV_LOG(KV_LOG_LEVEL_DEBUG, true, __VA_ARGS__)
#define INFO(...) KV_LOG(KV_LOG_LEVEL_INFO, true, __VA_ARGS__)
#define WARN(...) KV_LOG(KV_LOG_LEVEL_WARN, true, __VA_ARGS__)
#define ERR(...) KV_LOG(KV_LOG_LEVEL_ERROR, true, __VA_ARGS__)
#define DEBUG(...) KV_LOG(KV_LOG_LEVEL_DEBUG, true, __VA_ARGS__)
#define INFO(...) KV_LOG(KV_LOG_LEVEL_INFO, true, __VA_ARGS__)
#define gettid() ((pid_t)syscall(SYS_gettid))
#define GET_CURRENT_TIME() do { \
time_t timer = time(0); \
struct tm* t = localtime(&timer); \
char hostname[32]; \
gethostname(hostname, 32); \
sprintf(str, "%02d:%02d:%02d][%.32s][%d", \
t->tm_hour, t->tm_min, t->tm_sec, hostname, gettid()); \
} while (0)
#define GET_CURRENT_TIME() \
do { \
time_t timer = time(0); \
struct tm *t = localtime(&timer); \
char hostname[32]; \
gethostname(hostname, 32); \
sprintf(str, \
"%02d:%02d:%02d][%.32s][%d", \
t->tm_hour, \
t->tm_min, \
t->tm_sec, \
hostname, \
gettid()); \
} while (0)
#define LOGE(fmt, arg...) do { \
GET_CURRENT_TIME(); \
fprintf(stderr, "[%s][ERR][KV_CACHE][%s:%d] " \
fmt "\n",str, \
FILE_NAME(__FILE__), __LINE__, ## arg); \
} while (0)
#define LOGE(fmt, arg...) \
do { \
GET_CURRENT_TIME(); \
fprintf(stderr, \
"[%s][ERR][KV_CACHE][%s:%d] " fmt "\n", \
str, \
FILE_NAME(__FILE__), \
__LINE__, \
##arg); \
} while (0)
#define LOGW(fmt, arg...) do { \
GET_CURRENT_TIME(); \
fprintf(stderr, "[%s][WARN][KV_CACHE][%s:%d] " \
fmt "\n",str, \
FILE_NAME(__FILE__), __LINE__, ## arg); \
} while (0)
#define LOGW(fmt, arg...) \
do { \
GET_CURRENT_TIME(); \
fprintf(stderr, \
"[%s][WARN][KV_CACHE][%s:%d] " fmt "\n", \
str, \
FILE_NAME(__FILE__), \
__LINE__, \
##arg); \
} while (0)
#define LOGI(fmt, arg...) do { \
GET_CURRENT_TIME(); \
fprintf(stdout, "[%s][INFO][KV_CACHE][%s:%d] " \
fmt "\n",str, \
FILE_NAME(__FILE__), __LINE__, ## arg); \
} while (0)
#define LOGI(fmt, arg...) \
do { \
GET_CURRENT_TIME(); \
fprintf(stdout, \
"[%s][INFO][KV_CACHE][%s:%d] " fmt "\n", \
str, \
FILE_NAME(__FILE__), \
__LINE__, \
##arg); \
} while (0)
#define LOGD(fmt, arg...) do { \
if (KV_IS_DEBUG_ENABLED) { \
GET_CURRENT_TIME(); \
fprintf(stdout, "[%s][DBG][KV_CACHE][%s:%d] " \
fmt "\n", str, \
FILE_NAME(__FILE__), __LINE__, ## arg); \
} \
} while (0)
#define LOGD(fmt, arg...) \
do { \
if (KV_IS_DEBUG_ENABLED) { \
GET_CURRENT_TIME(); \
fprintf(stdout, \
"[%s][DBG][KV_CACHE][%s:%d] " fmt "\n", \
str, \
FILE_NAME(__FILE__), \
__LINE__, \
##arg); \
} \
} while (0)
#define LOGD_IF(cond, fmt, ...) do { \
if ((cond)) \
LOGD(fmt, __VA_ARGS__); \
} while (0)
#define LOGD_IF(cond, fmt, ...) \
do { \
if ((cond)) LOGD(fmt, __VA_ARGS__); \
} while (0)
#define LOGD_RAW(fmt, arg...) do { \
if (ENV_ENABLE_RAW("KV_IS_DEBUG_ENABLED")) { \
GET_CURRENT_TIME(); \
fprintf(stdout, "[%s][DBG][KV_CACHE][%s:%d] " \
fmt "\n", str, \
FILE_NAME(__FILE__), __LINE__, ## arg); \
} \
} while (0)
#define LOGD_RAW(fmt, arg...) \
do { \
if (ENV_ENABLE_RAW("KV_IS_DEBUG_ENABLED")) { \
GET_CURRENT_TIME(); \
fprintf(stdout, \
"[%s][DBG][KV_CACHE][%s:%d] " fmt "\n", \
str, \
FILE_NAME(__FILE__), \
__LINE__, \
##arg); \
} \
} while (0)
@@ -1,21 +1,21 @@
#ifndef KVCACHE_UTILS_H
#define KVCACHE_UTILS_H
#include <ctime>
#include <chrono>
#include <iostream>
#include <string>
#include <cstdlib>
#include <algorithm>
#include <cctype>
#include <stdexcept>
#include <cstdio>
#include <arpa/inet.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <vector>
#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
#include "log.h"
#define PATH_MAX 4096 /* # chars in a path name including nul */
@@ -28,22 +28,22 @@
/// @brief Connection status enumeration
enum class ConnStatus {
kConnected, // Connection is active
kDisconnected, // Connection is not active
kError, // Connection error occurred
kTimeout, // Connection timed out
kInvalidParameters // Invalid connection parameters
kConnected, // Connection is active
kDisconnected, // Connection is not active
kError, // Connection error occurred
kTimeout, // Connection timed out
kInvalidParameters // Invalid connection parameters
};
/// @brief Queue Pair (QP) setup result status
enum class QpStatus {
kSuccess, // Successfully transitioned QP to RTS
kInvalidParameters, // ctx or dest is null
kDeviceQueryFailed, // ibv_query_device failed
kPortQueryFailed, // ibv_query_port failed
kMtuMismatch, // Requested MTU exceeds active MTU
kModifyToRTRFailed, // Failed to modify QP to RTR
kModifyToRTSFailed // Failed to modify QP to RTS
kSuccess, // Successfully transitioned QP to RTS
kInvalidParameters, // ctx or dest is null
kDeviceQueryFailed, // ibv_query_device failed
kPortQueryFailed, // ibv_query_port failed
kMtuMismatch, // Requested MTU exceeds active MTU
kModifyToRTRFailed, // Failed to modify QP to RTR
kModifyToRTSFailed // Failed to modify QP to RTS
};
/**
@@ -51,265 +51,281 @@ enum class QpStatus {
* @param busId PCI bus ID string (e.g. "0000:3b:00.0")
* @param[out] id Converted numeric ID
*/
inline void busid_to_int64(const char *busId, int64_t *id) {
char hexStr[17] = {0};
int hexOffset = 0;
inline void busid_to_int64(const char* busId, int64_t* id) {
char hexStr[17] = {0};
int hexOffset = 0;
// Filter valid hex characters
for (int i = 0; hexOffset < sizeof(hexStr) - 1 && busId[i] != '\0'; i++) {
char c = busId[i];
if (c == '.' || c == ':') continue;
// Filter valid hex characters
for (int i = 0; hexOffset < sizeof(hexStr) - 1 && busId[i] != '\0'; i++) {
char c = busId[i];
if (c == '.' || c == ':') continue;
if ((c >= '0' && c <= '9') ||
(c >= 'A' && c <= 'F') ||
(c >= 'a' && c <= 'f')) {
hexStr[hexOffset++] = c;
}
if ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'F') ||
(c >= 'a' && c <= 'f')) {
hexStr[hexOffset++] = c;
}
}
*id = strtol(hexStr, NULL, 16);
*id = strtol(hexStr, NULL, 16);
}
class NetworkInterfaceManager {
public:
struct InterfaceInfo {
std::string name;
std::string ip;
bool is_up;
bool is_running;
bool is_loopback;
public:
struct InterfaceInfo {
std::string name;
std::string ip;
bool is_up;
bool is_running;
bool is_loopback;
bool isUsable() const {
return is_up && is_running && !is_loopback;
}
};
bool isUsable() const { return is_up && is_running && !is_loopback; }
};
static std::vector<InterfaceInfo> getAllInterfaces() {
std::vector<InterfaceInfo> interfaces;
struct ifaddrs *ifaddrs_ptr = nullptr;
static std::vector<InterfaceInfo> getAllInterfaces() {
std::vector<InterfaceInfo> interfaces;
struct ifaddrs* ifaddrs_ptr = nullptr;
if (getifaddrs(&ifaddrs_ptr) == -1) {
return interfaces;
}
for (struct ifaddrs *ifa = ifaddrs_ptr; ifa != nullptr; ifa = ifa->ifa_next) {
if (ifa->ifa_addr == nullptr) continue;
if (ifa->ifa_addr->sa_family != AF_INET) continue;
InterfaceInfo info;
info.name = ifa->ifa_name;
info.is_up = (ifa->ifa_flags & IFF_UP) != 0;
info.is_running = (ifa->ifa_flags & IFF_RUNNING) != 0;
info.is_loopback = (ifa->ifa_flags & IFF_LOOPBACK) != 0;
struct sockaddr_in* sa = (struct sockaddr_in*)ifa->ifa_addr;
char ip_str[INET_ADDRSTRLEN];
inet_ntop(AF_INET, &sa->sin_addr, ip_str, INET_ADDRSTRLEN);
info.ip = ip_str;
interfaces.push_back(info);
}
freeifaddrs(ifaddrs_ptr);
return interfaces;
if (getifaddrs(&ifaddrs_ptr) == -1) {
return interfaces;
}
static std::string getFirstUsableInterface() {
auto interfaces = getAllInterfaces();
for (struct ifaddrs* ifa = ifaddrs_ptr; ifa != nullptr;
ifa = ifa->ifa_next) {
if (ifa->ifa_addr == nullptr) continue;
if (ifa->ifa_addr->sa_family != AF_INET) continue;
for (const auto& iface : interfaces) {
if (iface.isUsable()) {
return iface.name;
}
}
return "";
InterfaceInfo info;
info.name = ifa->ifa_name;
info.is_up = (ifa->ifa_flags & IFF_UP) != 0;
info.is_running = (ifa->ifa_flags & IFF_RUNNING) != 0;
info.is_loopback = (ifa->ifa_flags & IFF_LOOPBACK) != 0;
struct sockaddr_in* sa = (struct sockaddr_in*)ifa->ifa_addr;
char ip_str[INET_ADDRSTRLEN];
inet_ntop(AF_INET, &sa->sin_addr, ip_str, INET_ADDRSTRLEN);
info.ip = ip_str;
interfaces.push_back(info);
}
static void displayAllInterfaces() {
auto interfaces = getAllInterfaces();
freeifaddrs(ifaddrs_ptr);
return interfaces;
}
printf("Available network interfaces:\n");
for (const auto& iface : interfaces) {
printf(" %s: %s [%s%s%s]\n",
iface.name.c_str(),
iface.ip.c_str(),
iface.is_up ? "UP" : "DOWN",
iface.is_running ? ",RUNNING" : "",
iface.is_loopback ? ",LOOPBACK" : "");
}
static std::string getFirstUsableInterface() {
auto interfaces = getAllInterfaces();
for (const auto& iface : interfaces) {
if (iface.isUsable()) {
return iface.name;
}
}
return "";
}
static void displayAllInterfaces() {
auto interfaces = getAllInterfaces();
printf("Available network interfaces:\n");
for (const auto& iface : interfaces) {
printf(" %s: %s [%s%s%s]\n",
iface.name.c_str(),
iface.ip.c_str(),
iface.is_up ? "UP" : "DOWN",
iface.is_running ? ",RUNNING" : "",
iface.is_loopback ? ",LOOPBACK" : "");
}
}
};
class KVCacheConfig {
private:
// Configuration values
int rdma_gid_index_;
bool has_rdma_dest_port_override_; // 替代 std::optional
int rdma_dest_port_override_;
const char* socket_interface_;
char* socket_interface_buffer_;
bool gdrcopy_flush_enabled_;
bool verify_read_enabled_;
bool debug_mode_enabled_;
bool debug_output_enabled_;
const char* debug_file_path_;
const char* error_file_path_;
bool relax_ordering_enabled_;
int ib_timeout_;
const char* rdma_nics_;
private:
// Configuration values
int rdma_gid_index_;
bool has_rdma_dest_port_override_; // 替代 std::optional
int rdma_dest_port_override_;
const char* socket_interface_;
char* socket_interface_buffer_;
bool gdrcopy_flush_enabled_;
bool verify_read_enabled_;
bool debug_mode_enabled_;
bool debug_output_enabled_;
const char* debug_file_path_;
const char* error_file_path_;
bool relax_ordering_enabled_;
int ib_timeout_;
const char* rdma_nics_;
// Private constructor for singleton pattern
KVCacheConfig() {
// Initialize configuration from environment variables
rdma_gid_index_ = parse_int_value(
std::getenv("KVCACHE_RDMA_GID_INDEX"), 3, "KVCACHE_RDMA_GID_INDEX");
// Private constructor for singleton pattern
KVCacheConfig() {
// Initialize configuration from environment variables
rdma_gid_index_ = parse_int_value(
std::getenv("KVCACHE_RDMA_GID_INDEX"), 3, "KVCACHE_RDMA_GID_INDEX");
// Parse optional RDMA port override
const char* port_value = std::getenv("SET_RDMA_DEST_PORT");
has_rdma_dest_port_override_ = false; // 默认为false
if (port_value) {
try {
rdma_dest_port_override_ = std::stoi(std::string(port_value));
has_rdma_dest_port_override_ = true;
} catch (const std::exception& e) {
fprintf(stderr, "Invalid SET_RDMA_DEST_PORT value: '%s', ignoring\n", port_value);
}
}
const char* env_interface = std::getenv("KVCACHE_SOCKET_IFNAME");
if (env_interface && env_interface[0] != '\0') {
socket_interface_ = env_interface;
printf("Using specified interface: %s\n", socket_interface_);
} else {
std::string iface = NetworkInterfaceManager::getFirstUsableInterface();
if (!iface.empty()) {
socket_interface_buffer_ = new char[iface.size() + 1];
std::strcpy(socket_interface_buffer_, iface.c_str());
socket_interface_ = socket_interface_buffer_;
printf("Auto-detected interface: %s\n", socket_interface_);
} else {
fprintf(stderr, "Warning: No usable network interface found\n");
socket_interface_ = "";
}
NetworkInterfaceManager::displayAllInterfaces();
}
socket_interface_ = std::getenv("KVCACHE_SOCKET_IFNAME");
debug_file_path_ = std::getenv("KVCACHE_DEBUG_FILE");
error_file_path_ = std::getenv("KVCACHE_ERROR_FILE");
gdrcopy_flush_enabled_ = parse_bool_value(std::getenv("KVCACHE_GDRCOPY_FLUSH_ENABLE"));
verify_read_enabled_ = parse_bool_value(std::getenv("KVCACHE_VERIFY_READ"));
debug_mode_enabled_ = parse_bool_value(std::getenv("KVCACHE_DEBUG")) ||
parse_bool_value(std::getenv("KV_IS_DEBUG_ENABLED"));
debug_output_enabled_ = parse_bool_value(std::getenv("KVCACHE_DEBUG_OUTPUT"));
relax_ordering_enabled_ = parse_bool_value(std::getenv("KVCACHE_RELAX_ORDERING"));
ib_timeout_ = parse_int_value(
std::getenv("KVCACHE_IB_TIMEOUT"),
18,
"KVCACHE_IB_TIMEOUT"
);
rdma_nics_ = std::getenv("KVCACHE_RDMA_NICS");
// Parse optional RDMA port override
const char* port_value = std::getenv("SET_RDMA_DEST_PORT");
has_rdma_dest_port_override_ = false; // 默认为false
if (port_value) {
try {
rdma_dest_port_override_ = std::stoi(std::string(port_value));
has_rdma_dest_port_override_ = true;
} catch (const std::exception& e) {
fprintf(stderr,
"Invalid SET_RDMA_DEST_PORT value: '%s', ignoring\n",
port_value);
}
}
// Helper methods
bool parse_bool_value(const char* value) {
if (!value) return false;
const char* env_interface = std::getenv("KVCACHE_SOCKET_IFNAME");
std::string str_value(value);
std::transform(str_value.begin(), str_value.end(), str_value.begin(), ::tolower);
return (str_value == "1" || str_value == "true" ||
str_value == "on" || str_value == "yes");
if (env_interface && env_interface[0] != '\0') {
socket_interface_ = env_interface;
printf("Using specified interface: %s\n", socket_interface_);
} else {
std::string iface = NetworkInterfaceManager::getFirstUsableInterface();
if (!iface.empty()) {
socket_interface_buffer_ = new char[iface.size() + 1];
std::strcpy(socket_interface_buffer_, iface.c_str());
socket_interface_ = socket_interface_buffer_;
printf("Auto-detected interface: %s\n", socket_interface_);
} else {
fprintf(stderr, "Warning: No usable network interface found\n");
socket_interface_ = "";
}
NetworkInterfaceManager::displayAllInterfaces();
}
int parse_int_value(const char* value, int default_value, const char* env_name) {
if (!value) return default_value;
socket_interface_ = std::getenv("KVCACHE_SOCKET_IFNAME");
debug_file_path_ = std::getenv("KVCACHE_DEBUG_FILE");
error_file_path_ = std::getenv("KVCACHE_ERROR_FILE");
try {
return std::stoi(std::string(value));
} catch (const std::invalid_argument& e) {
fprintf(stderr, "Invalid value for %s: '%s', using default: %d\n",
env_name, value, default_value);
return default_value;
} catch (const std::out_of_range& e) {
fprintf(stderr, "%s value out of range: '%s', using default: %d\n",
env_name, value, default_value);
return default_value;
}
gdrcopy_flush_enabled_ =
parse_bool_value(std::getenv("KVCACHE_GDRCOPY_FLUSH_ENABLE"));
verify_read_enabled_ = parse_bool_value(std::getenv("KVCACHE_VERIFY_READ"));
debug_mode_enabled_ = parse_bool_value(std::getenv("KVCACHE_DEBUG")) ||
parse_bool_value(std::getenv("KV_IS_DEBUG_ENABLED"));
debug_output_enabled_ =
parse_bool_value(std::getenv("KVCACHE_DEBUG_OUTPUT"));
relax_ordering_enabled_ =
parse_bool_value(std::getenv("KVCACHE_RELAX_ORDERING"));
ib_timeout_ = parse_int_value(
std::getenv("KVCACHE_IB_TIMEOUT"), 18, "KVCACHE_IB_TIMEOUT");
rdma_nics_ = std::getenv("KVCACHE_RDMA_NICS");
}
// Helper methods
bool parse_bool_value(const char* value) {
if (!value) return false;
std::string str_value(value);
std::transform(
str_value.begin(), str_value.end(), str_value.begin(), ::tolower);
return (str_value == "1" || str_value == "true" || str_value == "on" ||
str_value == "yes");
}
int parse_int_value(const char* value,
int default_value,
const char* env_name) {
if (!value) return default_value;
try {
return std::stoi(std::string(value));
} catch (const std::invalid_argument& e) {
fprintf(stderr,
"Invalid value for %s: '%s', using default: %d\n",
env_name,
value,
default_value);
return default_value;
} catch (const std::out_of_range& e) {
fprintf(stderr,
"%s value out of range: '%s', using default: %d\n",
env_name,
value,
default_value);
return default_value;
}
}
public:
// Prevent copying and assignment
KVCacheConfig(const KVCacheConfig&) = delete;
KVCacheConfig& operator=(const KVCacheConfig&) = delete;
// Get singleton instance
static KVCacheConfig& getInstance() {
static KVCacheConfig instance;
return instance;
}
int get_ib_timeout() const { return ib_timeout_; }
// Configuration retrieval methods
int get_rdma_gid_index() const { return rdma_gid_index_; }
int resolve_rdma_dest_port(int default_port) const {
return has_rdma_dest_port_override_ ? rdma_dest_port_override_
: default_port;
}
int resolve_rdma_dest_port(const std::string& default_port) const {
try {
return resolve_rdma_dest_port(std::stoi(default_port));
} catch (const std::exception& e) {
fprintf(
stderr, "Invalid default port string: %s\n", default_port.c_str());
return 0;
}
}
const char* get_socket_interface() const { return socket_interface_; }
const char* get_debug_file_path() const { return debug_file_path_; }
const char* get_error_file_path() const { return error_file_path_; }
const char* get_rdma_nics() const { return rdma_nics_; }
// Feature check methods
bool is_gdrcopy_flush_enabled() const { return gdrcopy_flush_enabled_; }
bool is_verify_read_enabled() const { return verify_read_enabled_; }
bool is_debug_mode_enabled() const { return debug_mode_enabled_; }
bool is_debug_output_enabled() const { return debug_output_enabled_; }
bool is_relax_ordering_enabled() const { return relax_ordering_enabled_; }
// Display configuration
void displayConfiguration() const {
INFO("KVCache Configuration:\n");
INFO("Init KVCacheConfig RDMA GID Index: %d\n", rdma_gid_index_);
if (has_rdma_dest_port_override_) {
INFO("Init KVCacheConfig RDMA Destination Port Override: %d\n",
rdma_dest_port_override_);
}
public:
// Prevent copying and assignment
KVCacheConfig(const KVCacheConfig&) = delete;
KVCacheConfig& operator=(const KVCacheConfig&) = delete;
// Get singleton instance
static KVCacheConfig& getInstance() {
static KVCacheConfig instance;
return instance;
if (socket_interface_) {
INFO("Init KVCacheConfig Socket Interface: %s\n", socket_interface_);
}
int get_ib_timeout() const { return ib_timeout_; }
INFO("Init KVCacheConfig GDRCopy Flush: %s\n",
gdrcopy_flush_enabled_ ? "enabled" : "disabled");
INFO("Init KVCacheConfig Verify Read: %s\n",
verify_read_enabled_ ? "enabled" : "disabled");
INFO("Init KVCacheConfig Debug Mode: %s\n",
debug_mode_enabled_ ? "enabled" : "disabled");
INFO("Init KVCacheConfig Debug Output: %s\n",
debug_output_enabled_ ? "enabled" : "disabled");
// Configuration retrieval methods
int get_rdma_gid_index() const { return rdma_gid_index_; }
int resolve_rdma_dest_port(int default_port) const {
return has_rdma_dest_port_override_ ? rdma_dest_port_override_ : default_port;
if (debug_file_path_) {
INFO("Init KVCacheConfig Debug File: %s\n", debug_file_path_);
}
int resolve_rdma_dest_port(const std::string& default_port) const {
try {
return resolve_rdma_dest_port(std::stoi(default_port));
} catch (const std::exception& e) {
fprintf(stderr, "Invalid default port string: %s\n", default_port.c_str());
return 0;
}
}
const char* get_socket_interface() const { return socket_interface_; }
const char* get_debug_file_path() const { return debug_file_path_; }
const char* get_error_file_path() const { return error_file_path_; }
const char* get_rdma_nics() const { return rdma_nics_; }
// Feature check methods
bool is_gdrcopy_flush_enabled() const { return gdrcopy_flush_enabled_; }
bool is_verify_read_enabled() const { return verify_read_enabled_; }
bool is_debug_mode_enabled() const { return debug_mode_enabled_; }
bool is_debug_output_enabled() const { return debug_output_enabled_; }
bool is_relax_ordering_enabled() const { return relax_ordering_enabled_; }
// Display configuration
void displayConfiguration() const {
INFO("KVCache Configuration:\n");
INFO("Init KVCacheConfig RDMA GID Index: %d\n", rdma_gid_index_);
if (has_rdma_dest_port_override_) {
INFO("Init KVCacheConfig RDMA Destination Port Override: %d\n", rdma_dest_port_override_);
}
if (socket_interface_) {
INFO("Init KVCacheConfig Socket Interface: %s\n", socket_interface_);
}
INFO("Init KVCacheConfig GDRCopy Flush: %s\n", gdrcopy_flush_enabled_ ? "enabled" : "disabled");
INFO("Init KVCacheConfig Verify Read: %s\n", verify_read_enabled_ ? "enabled" : "disabled");
INFO("Init KVCacheConfig Debug Mode: %s\n", debug_mode_enabled_ ? "enabled" : "disabled");
INFO("Init KVCacheConfig Debug Output: %s\n", debug_output_enabled_ ? "enabled" : "disabled");
if (debug_file_path_) {
INFO("Init KVCacheConfig Debug File: %s\n", debug_file_path_);
}
if (error_file_path_) {
INFO("Init KVCacheConfig Error File: %s\n", error_file_path_);
}
if (error_file_path_) {
INFO("Init KVCacheConfig Error File: %s\n", error_file_path_);
}
}
};
#endif
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -17,14 +17,14 @@
* limitations under the License.
*/
#include <stdlib.h>
#include <stdarg.h>
#include <sys/syscall.h>
#include <sys/stat.h>
#include <libgen.h>
#include <errno.h>
#include <string.h>
#include "log.h"
#include <errno.h>
#include <libgen.h>
#include <stdarg.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>
#include <sys/syscall.h>
#include "util.h"
static int pid = -1;
@@ -33,180 +33,237 @@ static char hostname[64];
char global_log_last_error[1024] = "";
FILE *global_debug_file = stdout;
FILE *global_error_file = stdout;
static char global_debug_file_name[PATH_MAX+1] = "";
static char global_err_file_name[PATH_MAX+1] = "";
static char global_debug_file_name[PATH_MAX + 1] = "";
static char global_err_file_name[PATH_MAX + 1] = "";
int global_debug_level = -1;
pthread_mutex_t global_debug_lock = PTHREAD_MUTEX_INITIALIZER;
pthread_mutex_t global_log_file_lock = PTHREAD_MUTEX_INITIALIZER;
void log_file_init(FILE **kv_cache_log_file, const char *kv_cache_log_file_env, char *logFileName) {
int c = 0;
char *dfn = logFileName;
while (c < PATH_MAX && kv_cache_log_file_env[c] != '\0') {
if (kv_cache_log_file_env[c++] != '%') {
*dfn++ = kv_cache_log_file_env[c - 1];
continue;
}
switch (kv_cache_log_file_env[c++]) {
case '%': // Double %
*dfn++ = '%';
break;
case 'h': // %h = hostname
dfn += snprintf(dfn, PATH_MAX, "%s", hostname);
break;
case 'p': // %p = pid
dfn += snprintf(dfn, PATH_MAX, "%d", pid);
break;
default: // Echo everything we don't understand
*dfn++ = '%';
*dfn++ = kv_cache_log_file_env[c - 1];
break;
}
void log_file_init(FILE **kv_cache_log_file,
const char *kv_cache_log_file_env,
char *logFileName) {
int c = 0;
char *dfn = logFileName;
while (c < PATH_MAX && kv_cache_log_file_env[c] != '\0') {
if (kv_cache_log_file_env[c++] != '%') {
*dfn++ = kv_cache_log_file_env[c - 1];
continue;
}
*dfn = '\0';
if (logFileName[0] != '\0') {
FILE *file = fopen(logFileName, "w");
if (file != nullptr) {
setbuf(file, nullptr); // disable buffering
*kv_cache_log_file = file;
}
switch (kv_cache_log_file_env[c++]) {
case '%': // Double %
*dfn++ = '%';
break;
case 'h': // %h = hostname
dfn += snprintf(dfn, PATH_MAX, "%s", hostname);
break;
case 'p': // %p = pid
dfn += snprintf(dfn, PATH_MAX, "%d", pid);
break;
default: // Echo everything we don't understand
*dfn++ = '%';
*dfn++ = kv_cache_log_file_env[c - 1];
break;
}
}
*dfn = '\0';
if (logFileName[0] != '\0') {
FILE *file = fopen(logFileName, "w");
if (file != nullptr) {
setbuf(file, nullptr); // disable buffering
*kv_cache_log_file = file;
}
}
}
void recreate_log_file(FILE **kv_cache_log_file, char *logFileName) {
if (logFileName[0] != '\0') {
pthread_mutex_lock(&global_log_file_lock);
FILE *file = fopen(logFileName, "a"); // Use "a" mode to append if file exists, otherwise create it
// close the previous log file if it exists
if (*kv_cache_log_file != NULL && *kv_cache_log_file != file) {
fclose(*kv_cache_log_file);
*kv_cache_log_file = NULL;
}
if (file != NULL) {
setbuf(file, NULL); // disable buffering
*kv_cache_log_file = file;
}
pthread_mutex_unlock(&global_log_file_lock);
if (logFileName[0] != '\0') {
pthread_mutex_lock(&global_log_file_lock);
FILE *file = fopen(
logFileName,
"a"); // Use "a" mode to append if file exists, otherwise create it
// close the previous log file if it exists
if (*kv_cache_log_file != NULL && *kv_cache_log_file != file) {
fclose(*kv_cache_log_file);
*kv_cache_log_file = NULL;
}
if (file != NULL) {
setbuf(file, NULL); // disable buffering
*kv_cache_log_file = file;
}
pthread_mutex_unlock(&global_log_file_lock);
}
}
void debug_init() {
pthread_mutex_lock(&global_debug_lock);
if (global_debug_level != -1) {
pthread_mutex_unlock(&global_debug_lock);
return;
}
const char* kv_cache_debug = std::getenv("KV_IS_DEBUG_ENABLED");
int tempg_kv_cache_debug_level = -1;
if (kv_cache_debug == NULL) {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
} else if (strcasecmp(kv_cache_debug, "0") == 0) {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
} else if (strcasecmp(kv_cache_debug, "1") == 0) {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_DEBUG;
} else if (strcasecmp(kv_cache_debug, "2") == 0) {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_WARN;
} else if (strcasecmp(kv_cache_debug, "3") == 0) {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_ERROR;
} else {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
}
gethostname(hostname, 64);
pid = getpid();
const char* g_kv_cache_debug_fileEnv = KVCacheConfig::getInstance().get_debug_file_path();
if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO && g_kv_cache_debug_fileEnv != NULL) {
log_file_init(&global_debug_file, g_kv_cache_debug_fileEnv, global_debug_file_name);
}
const char* g_kv_cache_error_fileEnv = KVCacheConfig::getInstance().get_error_file_path();
if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO && g_kv_cache_error_fileEnv != NULL) {
log_file_init(&global_error_file, g_kv_cache_error_fileEnv, global_err_file_name);
char buffer[1024];
size_t len = 0;
char timeBuffer[80]; // Buffer to hold the formatted time
std::time_t absoluteTime = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
std::strftime(timeBuffer, sizeof(timeBuffer), "%Y-%m-%d %H:%M:%S", std::localtime(&absoluteTime));
len = snprintf(buffer, sizeof(buffer), "%s KV_CACHE START ", timeBuffer);
buffer[len++] = '\n';
if (global_error_file != NULL) {
fwrite(buffer, 1, len, global_error_file);
}
}
__atomic_store_n(&global_debug_level, tempg_kv_cache_debug_level, __ATOMIC_RELEASE);
pthread_mutex_lock(&global_debug_lock);
if (global_debug_level != -1) {
pthread_mutex_unlock(&global_debug_lock);
return;
}
const char *kv_cache_debug = std::getenv("KV_IS_DEBUG_ENABLED");
int tempg_kv_cache_debug_level = -1;
if (kv_cache_debug == NULL) {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
} else if (strcasecmp(kv_cache_debug, "0") == 0) {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
} else if (strcasecmp(kv_cache_debug, "1") == 0) {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_DEBUG;
} else if (strcasecmp(kv_cache_debug, "2") == 0) {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_WARN;
} else if (strcasecmp(kv_cache_debug, "3") == 0) {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_ERROR;
} else {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
}
gethostname(hostname, 64);
pid = getpid();
const char *g_kv_cache_debug_fileEnv =
KVCacheConfig::getInstance().get_debug_file_path();
if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO &&
g_kv_cache_debug_fileEnv != NULL) {
log_file_init(
&global_debug_file, g_kv_cache_debug_fileEnv, global_debug_file_name);
}
const char *g_kv_cache_error_fileEnv =
KVCacheConfig::getInstance().get_error_file_path();
if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO &&
g_kv_cache_error_fileEnv != NULL) {
log_file_init(
&global_error_file, g_kv_cache_error_fileEnv, global_err_file_name);
char buffer[1024];
size_t len = 0;
char timeBuffer[80]; // Buffer to hold the formatted time
std::time_t absoluteTime =
std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
std::strftime(timeBuffer,
sizeof(timeBuffer),
"%Y-%m-%d %H:%M:%S",
std::localtime(&absoluteTime));
len = snprintf(buffer, sizeof(buffer), "%s KV_CACHE START ", timeBuffer);
buffer[len++] = '\n';
if (global_error_file != NULL) {
fwrite(buffer, 1, len, global_error_file);
}
}
__atomic_store_n(
&global_debug_level, tempg_kv_cache_debug_level, __ATOMIC_RELEASE);
pthread_mutex_unlock(&global_debug_lock);
}
/* Common logging function used by the INFO, DEBUG and WARN macros
* Also exported to the dynamically loadable Net transport modules so
* they can share the debugging mechanisms and output files
*/
void debug_log(KVLogLevel level, bool enable_to_terminal, const char *filefunc, int line, const char *fmt, ...) {
if (__atomic_load_n(&global_debug_level, __ATOMIC_ACQUIRE) == -1) {
debug_init();
}
void debug_log(KVLogLevel level,
bool enable_to_terminal,
const char *filefunc,
int line,
const char *fmt,
...) {
if (__atomic_load_n(&global_debug_level, __ATOMIC_ACQUIRE) == -1) {
debug_init();
}
// Save the last error (WARN) as a human readable string
if (level == KV_LOG_LEVEL_WARN) {
pthread_mutex_lock(&global_debug_lock);
va_list vargs;
va_start(vargs, fmt);
(void) vsnprintf(global_log_last_error, sizeof(global_log_last_error), fmt, vargs);
va_end(vargs);
pthread_mutex_unlock(&global_debug_lock);
}
// Save the last error (WARN) as a human readable string
if (level == KV_LOG_LEVEL_WARN) {
pthread_mutex_lock(&global_debug_lock);
va_list vargs;
va_start(vargs, fmt);
(void)vsnprintf(
global_log_last_error, sizeof(global_log_last_error), fmt, vargs);
va_end(vargs);
pthread_mutex_unlock(&global_debug_lock);
}
if (tid == -1) {
tid = syscall(SYS_gettid);
}
if (tid == -1) {
tid = syscall(SYS_gettid);
}
char buffer[1024];
size_t len = 0;
// Convert timestamp to absolute time and directly use it in the snprintf function
std::time_t absoluteTime = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
char timeBuffer[80]; // Buffer to hold the formatted time
std::strftime(timeBuffer, sizeof(timeBuffer), "%Y-%m-%d %H:%M:%S", std::localtime(&absoluteTime));
char buffer[1024];
size_t len = 0;
// Convert timestamp to absolute time and directly use it in the snprintf
// function
std::time_t absoluteTime =
std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
char timeBuffer[80]; // Buffer to hold the formatted time
std::strftime(timeBuffer,
sizeof(timeBuffer),
"%Y-%m-%d %H:%M:%S",
std::localtime(&absoluteTime));
if (level == KV_LOG_LEVEL_WARN) {
len = snprintf(buffer, sizeof(buffer), "\n%s %s:%d:%d %s:%d KV_CACHE WARN ",
timeBuffer, hostname, pid, tid, filefunc, line);
} else if (level == KV_LOG_LEVEL_INFO) {
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE INFO ", timeBuffer, hostname, pid, tid);
} else if (level == KV_LOG_LEVEL_DEBUG) {
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE DEBUG ", timeBuffer, hostname, pid, tid);
} else if (level == KV_LOG_LEVEL_ERROR) {
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE ERROR ", timeBuffer, hostname, pid, tid);
} else {
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE ", timeBuffer, hostname, pid, tid);
}
if (level == KV_LOG_LEVEL_WARN) {
len = snprintf(buffer,
sizeof(buffer),
"\n%s %s:%d:%d %s:%d KV_CACHE WARN ",
timeBuffer,
hostname,
pid,
tid,
filefunc,
line);
} else if (level == KV_LOG_LEVEL_INFO) {
len = snprintf(buffer,
sizeof(buffer),
"%s %s:%d:%d KV_CACHE INFO ",
timeBuffer,
hostname,
pid,
tid);
} else if (level == KV_LOG_LEVEL_DEBUG) {
len = snprintf(buffer,
sizeof(buffer),
"%s %s:%d:%d KV_CACHE DEBUG ",
timeBuffer,
hostname,
pid,
tid);
} else if (level == KV_LOG_LEVEL_ERROR) {
len = snprintf(buffer,
sizeof(buffer),
"%s %s:%d:%d KV_CACHE ERROR ",
timeBuffer,
hostname,
pid,
tid);
} else {
len = snprintf(buffer,
sizeof(buffer),
"%s %s:%d:%d KV_CACHE ",
timeBuffer,
hostname,
pid,
tid);
}
if (len) {
va_list vargs;
va_start(vargs, fmt);
len += vsnprintf(buffer + len, sizeof(buffer) - len, fmt, vargs);
va_end(vargs);
// vsnprintf may return len > sizeof(buffer) in the case of a truncated output.
// Rewind len so that we can replace the final \0 by \n
if (len > sizeof(buffer)) {
len = sizeof(buffer) - 1;
}
buffer[len++] = '\n';
if (access(global_debug_file_name, F_OK) != 0) {
recreate_log_file(&global_debug_file, global_debug_file_name);
}
if (enable_to_terminal) {
fwrite(buffer, 1, len, global_debug_file);
}
if (level == KV_LOG_LEVEL_WARN && global_error_file != stdout) {
if (access(global_err_file_name, F_OK) != 0) {
recreate_log_file(&global_error_file, global_err_file_name);
}
if (global_error_file != NULL) {
fwrite(buffer, 1, len, global_error_file);
}
}
if (len) {
va_list vargs;
va_start(vargs, fmt);
len += vsnprintf(buffer + len, sizeof(buffer) - len, fmt, vargs);
va_end(vargs);
// vsnprintf may return len > sizeof(buffer) in the case of a truncated
// output. Rewind len so that we can replace the final \0 by \n
if (len > sizeof(buffer)) {
len = sizeof(buffer) - 1;
}
buffer[len++] = '\n';
if (access(global_debug_file_name, F_OK) != 0) {
recreate_log_file(&global_debug_file, global_debug_file_name);
}
if (enable_to_terminal) {
fwrite(buffer, 1, len, global_debug_file);
}
if (level == KV_LOG_LEVEL_WARN && global_error_file != stdout) {
if (access(global_err_file_name, F_OK) != 0) {
recreate_log_file(&global_error_file, global_err_file_name);
}
if (global_error_file != NULL) {
fwrite(buffer, 1, len, global_error_file);
}
}
}
}
@@ -6,17 +6,22 @@
namespace py = pybind11;
// Python bindings for the RDMA-based KV-cache messager.
// Exposes RDMACommunicator to Python as `rdma_comm.RDMACommunicator`.
//
// NOTE: a botched diff merge had registered the class (and assigned
// m.doc() / m.attr("__version__")) twice; a duplicate py::class_
// registration makes pybind11 throw "type is already registered" at
// import time, so each statement must appear exactly once.
PYBIND11_MODULE(rdma_comm, m) {
  m.doc() = R"pbdoc(kv cache messager)pbdoc";
  py::class_<RDMACommunicator>(m, "RDMACommunicator")
      // ctor: (ip, port, role/nic, local block addrs, remote block addrs,
      //        two int params) -- semantics defined by RDMACommunicator;
      // TODO(review): confirm argument meanings against the class header.
      .def(py::init<std::string &,
                    int,
                    std::string &,
                    std::vector<int64_t>,
                    std::vector<int64_t>,
                    int,
                    int>())
      .def("connect", &RDMACommunicator::connect)
      .def("is_connected", &RDMACommunicator::is_connected)
      .def("write_cache", &RDMACommunicator::write_cache);
#ifdef VERSION_INFO
  // VERSION_INFO is injected by the build system when packaging a release.
  m.attr("__version__") = VERSION_INFO;
#else
  m.attr("__version__") = "dev";
#endif
}
+5
View File
@@ -28,6 +28,11 @@ if ! [[ $(python -V 2>&1 | awk '{print $2}' | awk -F '.' '{print $1$2}') -ge 36
please change the default python to higher version."
exit 1
fi
# Make sure the pinned clang-format is available: when the detected
# version string does not already contain the expected $VERSION, upgrade
# pip (old pip releases may not resolve the clang-format wheel) and
# install the pinned formatter.
case "$version" in
    *"$VERSION"*)
        : # expected clang-format already installed; nothing to do
        ;;
    *)
        # low version of pip may not have the source of clang-format whl
        pip install --upgrade pip
        pip install clang-format==13.0.0
        ;;
esac
# Collect files Added/Copied/Modified/Renamed relative to ${BRANCH},
# excluding any files under the 'tests/ce/server/' directory from code style checks.
diff_files=$(git diff --name-only --diff-filter=ACMR ${BRANCH} | grep -v '^tests/ce/server/')