mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
c++ code format (#4527)
This commit is contained in:
+1
-1
@@ -16,7 +16,7 @@
|
||||
---
|
||||
Language: Cpp
|
||||
BasedOnStyle: Google
|
||||
IndentWidth: 4
|
||||
IndentWidth: 2
|
||||
TabWidth: 2
|
||||
ContinuationIndentWidth: 4
|
||||
AccessModifierOffset: -1 # The private/protected/public has no indent in class
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
exclude: |
|
||||
(?x)^(
|
||||
dockerfiles/.+
|
||||
)$
|
||||
default_install_hook_types:
|
||||
- pre-commit
|
||||
- commit-msg
|
||||
@@ -27,6 +31,15 @@ repos:
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--output-format, github, --fix, --line-length=120, --config, pyproject.toml]
|
||||
# For C++ files
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: clang-format
|
||||
name: clang-format
|
||||
description: Format files with ClangFormat.
|
||||
entry: clang-format -i
|
||||
language: system
|
||||
files: \.(c|cc|cxx|cpp|cu|h|cuh|hpp|hxx|xpu|kps)$
|
||||
# # 拼写检查
|
||||
# - repo: https://github.com/codespell-project/codespell
|
||||
# rev: v2.4.1
|
||||
|
||||
@@ -19,28 +19,28 @@ std::vector<paddle::Tensor> InvokeAvxWeightOnly(const paddle::Tensor &x,
|
||||
const paddle::Tensor &w_bias,
|
||||
const std::string &alog,
|
||||
bool trans) {
|
||||
auto out_shape = x.shape();
|
||||
out_shape[out_shape.size() - 1] = weight.shape()[1];
|
||||
auto out = paddle::empty(out_shape, x.dtype(), paddle::CPUPlace());
|
||||
return {out};
|
||||
auto out_shape = x.shape();
|
||||
out_shape[out_shape.size() - 1] = weight.shape()[1];
|
||||
auto out = paddle::empty(out_shape, x.dtype(), paddle::CPUPlace());
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> AvxWeightOnlyInferShape(
|
||||
std::vector<int64_t> x_shape,
|
||||
std::vector<int64_t> weigh_shape,
|
||||
std::vector<int64_t> weigh_bias_shape) {
|
||||
int m = 1;
|
||||
for (int i = 0; i < x_shape.size() - 1; i++) {
|
||||
m = m * x_shape[i];
|
||||
}
|
||||
return {std::vector<int64_t>{m, weigh_shape[1]}};
|
||||
int m = 1;
|
||||
for (int i = 0; i < x_shape.size() - 1; i++) {
|
||||
m = m * x_shape[i];
|
||||
}
|
||||
return {std::vector<int64_t>{m, weigh_shape[1]}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> AvxWeightOnlyInferDtype(
|
||||
paddle::DataType x_dtype,
|
||||
paddle::DataType weight_dtype,
|
||||
paddle::DataType weight_bias_dtype) {
|
||||
return {x_dtype};
|
||||
return {x_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(avx_weight_only)
|
||||
|
||||
@@ -20,13 +20,13 @@ void remove_padding(int64_t *output_data,
|
||||
const int *cum_offsets,
|
||||
const int sequence_length,
|
||||
const int bsz) {
|
||||
for (int bi = 0; bi < bsz; ++bi) {
|
||||
for (int i = 0; i < seq_lens[bi]; ++i) {
|
||||
const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
|
||||
const int src_seq_id = bi * sequence_length + i;
|
||||
output_data[tgt_seq_id] = input_data[src_seq_id];
|
||||
}
|
||||
for (int bi = 0; bi < bsz; ++bi) {
|
||||
for (int i = 0; i < seq_lens[bi]; ++i) {
|
||||
const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
|
||||
const int src_seq_id = bi * sequence_length + i;
|
||||
output_data[tgt_seq_id] = input_data[src_seq_id];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void get_padding_offset_kernel(int *padding_offset,
|
||||
@@ -37,56 +37,53 @@ void get_padding_offset_kernel(int *padding_offset,
|
||||
const int *seq_lens,
|
||||
const int max_seq_len,
|
||||
const int bsz) {
|
||||
for (int bi = 0; bi < bsz; ++bi) {
|
||||
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
|
||||
auto seq_len_now = seq_lens[bi];
|
||||
for (int i = 0; i < seq_len_now; ++i) {
|
||||
padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
|
||||
}
|
||||
cum_offsets_out[bi] = cum_offset;
|
||||
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
|
||||
cu_seqlens_q[bi + 1] = cum_seq_len;
|
||||
cu_seqlens_k[bi + 1] = cum_seq_len;
|
||||
for (int bi = 0; bi < bsz; ++bi) {
|
||||
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
|
||||
auto seq_len_now = seq_lens[bi];
|
||||
for (int i = 0; i < seq_len_now; ++i) {
|
||||
padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
|
||||
}
|
||||
cum_offsets_out[bi] = cum_offset;
|
||||
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
|
||||
cu_seqlens_q[bi + 1] = cum_seq_len;
|
||||
cu_seqlens_k[bi + 1] = cum_seq_len;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &token_num,
|
||||
const paddle::Tensor &seq_len) {
|
||||
std::vector<int64_t> input_ids_shape = input_ids.shape();
|
||||
const int bsz = seq_len.shape()[0];
|
||||
const int seq_length = input_ids_shape[1];
|
||||
auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false);
|
||||
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
|
||||
std::vector<int64_t> input_ids_shape = input_ids.shape();
|
||||
const int bsz = seq_len.shape()[0];
|
||||
const int seq_length = input_ids_shape[1];
|
||||
auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false);
|
||||
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
|
||||
|
||||
const int token_num_data = cpu_token_num.data<int64_t>()[0];
|
||||
auto x_remove_padding = paddle::empty(
|
||||
{token_num_data}, paddle::DataType::INT64, input_ids.place());
|
||||
auto padding_offset = paddle::empty(
|
||||
{token_num_data}, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_q =
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_k =
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
get_padding_offset_kernel(padding_offset.data<int>(),
|
||||
cum_offsets_out.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_len.data<int>(),
|
||||
seq_length,
|
||||
bsz);
|
||||
remove_padding(x_remove_padding.data<int64_t>(),
|
||||
input_ids.data<int64_t>(),
|
||||
seq_len.data<int>(),
|
||||
cum_offsets_out.data<int>(),
|
||||
seq_length,
|
||||
bsz);
|
||||
return {x_remove_padding,
|
||||
padding_offset,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k};
|
||||
const int token_num_data = cpu_token_num.data<int64_t>()[0];
|
||||
auto x_remove_padding = paddle::empty(
|
||||
{token_num_data}, paddle::DataType::INT64, input_ids.place());
|
||||
auto padding_offset = paddle::empty(
|
||||
{token_num_data}, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_q =
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_k =
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
get_padding_offset_kernel(padding_offset.data<int>(),
|
||||
cum_offsets_out.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_len.data<int>(),
|
||||
seq_length,
|
||||
bsz);
|
||||
remove_padding(x_remove_padding.data<int64_t>(),
|
||||
input_ids.data<int64_t>(),
|
||||
seq_len.data<int>(),
|
||||
cum_offsets_out.data<int>(),
|
||||
seq_length,
|
||||
bsz);
|
||||
return {x_remove_padding, padding_offset, cu_seqlens_q, cu_seqlens_k};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
|
||||
@@ -94,9 +91,9 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
|
||||
const std::vector<int64_t> &cum_offsets_shape,
|
||||
const std::vector<int64_t> &token_num_shape,
|
||||
const std::vector<int64_t> &seq_len_shape) {
|
||||
int64_t bsz = seq_len_shape[0];
|
||||
int64_t seq_len = input_ids_shape[1];
|
||||
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
|
||||
int64_t bsz = seq_len_shape[0];
|
||||
int64_t seq_len = input_ids_shape[1];
|
||||
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
||||
@@ -104,18 +101,13 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
||||
const paddle::DataType &cum_offsets_dtype,
|
||||
const paddle::DataType &token_num_dtype,
|
||||
const paddle::DataType &seq_len_dtype) {
|
||||
return {input_ids_dtype,
|
||||
seq_len_dtype,
|
||||
seq_len_dtype,
|
||||
seq_len_dtype};
|
||||
return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(get_padding_offset_cpu)
|
||||
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
|
||||
.Outputs({"x_remove_padding",
|
||||
"padding_offset",
|
||||
"cu_seqlens_q",
|
||||
"cu_seqlens_k"})
|
||||
.Outputs(
|
||||
{"x_remove_padding", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"})
|
||||
.SetKernelFn(PD_KERNEL(GetPaddingOffset))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype));
|
||||
|
||||
@@ -19,7 +19,6 @@
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
|
||||
template <typename T>
|
||||
void RebuildPaddingCPUImpl(T *output_data,
|
||||
const T *input_data,
|
||||
@@ -30,27 +29,27 @@ void RebuildPaddingCPUImpl(T *output_data,
|
||||
int max_input_length,
|
||||
int dim_embed,
|
||||
const int elem_nums) {
|
||||
for (int i = 0; i < elem_nums; ++i) {
|
||||
const int bi = i / dim_embed;
|
||||
const int bias_idx = i % dim_embed;
|
||||
int seq_id = 0;
|
||||
for (int i = 0; i < elem_nums; ++i) {
|
||||
const int bi = i / dim_embed;
|
||||
const int bias_idx = i % dim_embed;
|
||||
int seq_id = 0;
|
||||
|
||||
if (seq_len_this_time_data[bi] == 0) {
|
||||
continue;
|
||||
}
|
||||
if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (seq_lens_encoder_data[bi] > 0) {
|
||||
seq_id = seq_lens_encoder_data[bi] - 1;
|
||||
}
|
||||
|
||||
const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id;
|
||||
const int src_offset = ori_token_idx * dim_embed + bias_idx;
|
||||
|
||||
output_data[i] = input_data[src_offset];
|
||||
if (seq_len_this_time_data[bi] == 0) {
|
||||
continue;
|
||||
}
|
||||
if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (seq_lens_encoder_data[bi] > 0) {
|
||||
seq_id = seq_lens_encoder_data[bi] - 1;
|
||||
}
|
||||
|
||||
const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id;
|
||||
const int src_offset = ori_token_idx * dim_embed + bias_idx;
|
||||
|
||||
output_data[i] = input_data[src_offset];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -64,27 +63,25 @@ void RebuildAppendPaddingCPUImpl(T *output_data,
|
||||
const int max_input_length,
|
||||
const int dim_embed,
|
||||
const int64_t output_elem_nums) {
|
||||
for (int i = 0; i < output_elem_nums; ++i) {
|
||||
int out_token_id = i / dim_embed;
|
||||
int ori_token_id =
|
||||
out_token_id + output_padding_offset_data[out_token_id];
|
||||
int bi = ori_token_id / max_input_length;
|
||||
if (seq_len_this_time_data[bi] == 0 ||
|
||||
(seq_lens_decoder_data[bi] == 0 &&
|
||||
seq_lens_encoder_data[bi] == 0)) {
|
||||
continue;
|
||||
}
|
||||
int seq_id = 0;
|
||||
|
||||
if (seq_lens_encoder_data[bi] > 0) {
|
||||
seq_id = seq_lens_encoder_data[bi] - 1;
|
||||
}
|
||||
int input_token_id = cu_seqlens_q_data[bi] + seq_id;
|
||||
int bias_idx = i % dim_embed;
|
||||
int src_offset = input_token_id * dim_embed + bias_idx;
|
||||
|
||||
output_data[i] = input_data[src_offset];
|
||||
for (int i = 0; i < output_elem_nums; ++i) {
|
||||
int out_token_id = i / dim_embed;
|
||||
int ori_token_id = out_token_id + output_padding_offset_data[out_token_id];
|
||||
int bi = ori_token_id / max_input_length;
|
||||
if (seq_len_this_time_data[bi] == 0 ||
|
||||
(seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0)) {
|
||||
continue;
|
||||
}
|
||||
int seq_id = 0;
|
||||
|
||||
if (seq_lens_encoder_data[bi] > 0) {
|
||||
seq_id = seq_lens_encoder_data[bi] - 1;
|
||||
}
|
||||
int input_token_id = cu_seqlens_q_data[bi] + seq_id;
|
||||
int bias_idx = i % dim_embed;
|
||||
int src_offset = input_token_id * dim_embed + bias_idx;
|
||||
|
||||
output_data[i] = input_data[src_offset];
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
@@ -95,140 +92,139 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
int max_input_length) {
|
||||
auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true);
|
||||
auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true);
|
||||
auto seq_len_this_time_cpu =
|
||||
seq_len_this_time.copy_to(paddle::CPUPlace(), true);
|
||||
auto seq_lens_decoder_cpu =
|
||||
seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
|
||||
auto seq_lens_encoder_cpu =
|
||||
seq_lens_encoder.copy_to(paddle::CPUPlace(), true);
|
||||
paddle::optional<paddle::Tensor> output_padding_offset_cpu;
|
||||
if (output_padding_offset) {
|
||||
output_padding_offset_cpu =
|
||||
output_padding_offset->copy_to(paddle::CPUPlace(), true);
|
||||
auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true);
|
||||
auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true);
|
||||
auto seq_len_this_time_cpu =
|
||||
seq_len_this_time.copy_to(paddle::CPUPlace(), true);
|
||||
auto seq_lens_decoder_cpu =
|
||||
seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
|
||||
auto seq_lens_encoder_cpu =
|
||||
seq_lens_encoder.copy_to(paddle::CPUPlace(), true);
|
||||
paddle::optional<paddle::Tensor> output_padding_offset_cpu;
|
||||
if (output_padding_offset) {
|
||||
output_padding_offset_cpu =
|
||||
output_padding_offset->copy_to(paddle::CPUPlace(), true);
|
||||
}
|
||||
|
||||
int token_num = tmp_out_cpu.shape()[0];
|
||||
int dim_embed = tmp_out_cpu.shape()[1];
|
||||
int bsz = cu_seqlens_q_cpu.shape()[0] - 1;
|
||||
|
||||
paddle::Tensor out;
|
||||
if (output_padding_offset_cpu) {
|
||||
int need_delete_token_num = 0;
|
||||
for (int i = 0; i < bsz; ++i) {
|
||||
if (seq_lens_encoder_cpu.data<int>()[i] > 0) {
|
||||
need_delete_token_num += seq_lens_encoder_cpu.data<int>()[i] - 1;
|
||||
}
|
||||
}
|
||||
int output_token_num = token_num - need_delete_token_num;
|
||||
out = paddle::full({output_token_num, dim_embed},
|
||||
0,
|
||||
tmp_out_cpu.dtype(),
|
||||
paddle::CPUPlace());
|
||||
} else {
|
||||
out = paddle::full(
|
||||
{bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace());
|
||||
}
|
||||
|
||||
int token_num = tmp_out_cpu.shape()[0];
|
||||
int dim_embed = tmp_out_cpu.shape()[1];
|
||||
int bsz = cu_seqlens_q_cpu.shape()[0] - 1;
|
||||
const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data<int>();
|
||||
const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>();
|
||||
const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
|
||||
const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>();
|
||||
int elem_nums = out.numel();
|
||||
|
||||
paddle::Tensor out;
|
||||
if (output_padding_offset_cpu) {
|
||||
int need_delete_token_num = 0;
|
||||
for (int i = 0; i < bsz; ++i) {
|
||||
if (seq_lens_encoder_cpu.data<int>()[i] > 0) {
|
||||
need_delete_token_num +=
|
||||
seq_lens_encoder_cpu.data<int>()[i] - 1;
|
||||
}
|
||||
}
|
||||
int output_token_num = token_num - need_delete_token_num;
|
||||
out = paddle::full({output_token_num, dim_embed},
|
||||
0,
|
||||
tmp_out_cpu.dtype(),
|
||||
paddle::CPUPlace());
|
||||
} else {
|
||||
out = paddle::full(
|
||||
{bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace());
|
||||
if (output_padding_offset_cpu) {
|
||||
const int *output_padding_offset_data =
|
||||
output_padding_offset_cpu->data<int>();
|
||||
switch (tmp_out_cpu.dtype()) {
|
||||
case paddle::DataType::FLOAT32:
|
||||
RebuildAppendPaddingCPUImpl<float>(out.data<float>(),
|
||||
tmp_out_cpu.data<float>(),
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
output_padding_offset_data,
|
||||
max_input_length,
|
||||
dim_embed,
|
||||
elem_nums);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
RebuildAppendPaddingCPUImpl<paddle::float16>(
|
||||
out.data<paddle::float16>(),
|
||||
tmp_out_cpu.data<paddle::float16>(),
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
output_padding_offset_data,
|
||||
max_input_length,
|
||||
dim_embed,
|
||||
elem_nums);
|
||||
break;
|
||||
case paddle::DataType::BFLOAT16:
|
||||
RebuildAppendPaddingCPUImpl<paddle::bfloat16>(
|
||||
out.data<paddle::bfloat16>(),
|
||||
tmp_out_cpu.data<paddle::bfloat16>(),
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
output_padding_offset_data,
|
||||
max_input_length,
|
||||
dim_embed,
|
||||
elem_nums);
|
||||
break;
|
||||
default:
|
||||
PD_THROW(
|
||||
"Unsupported data type for rebuild_padding_cpu. "
|
||||
"Only float32, float16, and bfloat16 are supported.");
|
||||
}
|
||||
|
||||
const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data<int>();
|
||||
const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>();
|
||||
const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
|
||||
const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>();
|
||||
int elem_nums = out.numel();
|
||||
|
||||
if (output_padding_offset_cpu) {
|
||||
const int *output_padding_offset_data =
|
||||
output_padding_offset_cpu->data<int>();
|
||||
switch (tmp_out_cpu.dtype()) {
|
||||
case paddle::DataType::FLOAT32:
|
||||
RebuildAppendPaddingCPUImpl<float>(out.data<float>(),
|
||||
tmp_out_cpu.data<float>(),
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
output_padding_offset_data,
|
||||
max_input_length,
|
||||
dim_embed,
|
||||
elem_nums);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
RebuildAppendPaddingCPUImpl<paddle::float16>(
|
||||
out.data<paddle::float16>(),
|
||||
tmp_out_cpu.data<paddle::float16>(),
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
output_padding_offset_data,
|
||||
max_input_length,
|
||||
dim_embed,
|
||||
elem_nums);
|
||||
break;
|
||||
case paddle::DataType::BFLOAT16:
|
||||
RebuildAppendPaddingCPUImpl<paddle::bfloat16>(
|
||||
out.data<paddle::bfloat16>(),
|
||||
tmp_out_cpu.data<paddle::bfloat16>(),
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
output_padding_offset_data,
|
||||
max_input_length,
|
||||
dim_embed,
|
||||
elem_nums);
|
||||
break;
|
||||
default:
|
||||
PD_THROW(
|
||||
"Unsupported data type for rebuild_padding_cpu. "
|
||||
"Only float32, float16, and bfloat16 are supported.");
|
||||
}
|
||||
} else {
|
||||
switch (tmp_out_cpu.dtype()) {
|
||||
case paddle::DataType::FLOAT32:
|
||||
RebuildPaddingCPUImpl<float>(out.data<float>(),
|
||||
tmp_out_cpu.data<float>(),
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
max_input_length,
|
||||
dim_embed,
|
||||
elem_nums);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
RebuildPaddingCPUImpl<paddle::float16>(
|
||||
out.data<paddle::float16>(),
|
||||
tmp_out_cpu.data<paddle::float16>(),
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
max_input_length,
|
||||
dim_embed,
|
||||
elem_nums);
|
||||
break;
|
||||
case paddle::DataType::BFLOAT16:
|
||||
RebuildPaddingCPUImpl<paddle::bfloat16>(
|
||||
out.data<paddle::bfloat16>(),
|
||||
tmp_out_cpu.data<paddle::bfloat16>(),
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
max_input_length,
|
||||
dim_embed,
|
||||
elem_nums);
|
||||
break;
|
||||
default:
|
||||
PD_THROW(
|
||||
"Unsupported data type for rebuild_padding_cpu. "
|
||||
"Only float32, float16, and bfloat16 are supported.");
|
||||
}
|
||||
} else {
|
||||
switch (tmp_out_cpu.dtype()) {
|
||||
case paddle::DataType::FLOAT32:
|
||||
RebuildPaddingCPUImpl<float>(out.data<float>(),
|
||||
tmp_out_cpu.data<float>(),
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
max_input_length,
|
||||
dim_embed,
|
||||
elem_nums);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
RebuildPaddingCPUImpl<paddle::float16>(
|
||||
out.data<paddle::float16>(),
|
||||
tmp_out_cpu.data<paddle::float16>(),
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
max_input_length,
|
||||
dim_embed,
|
||||
elem_nums);
|
||||
break;
|
||||
case paddle::DataType::BFLOAT16:
|
||||
RebuildPaddingCPUImpl<paddle::bfloat16>(
|
||||
out.data<paddle::bfloat16>(),
|
||||
tmp_out_cpu.data<paddle::bfloat16>(),
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
max_input_length,
|
||||
dim_embed,
|
||||
elem_nums);
|
||||
break;
|
||||
default:
|
||||
PD_THROW(
|
||||
"Unsupported data type for rebuild_padding_cpu. "
|
||||
"Only float32, float16, and bfloat16 are supported.");
|
||||
}
|
||||
return {out};
|
||||
}
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
|
||||
@@ -238,13 +234,13 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
|
||||
const std::vector<int64_t> &seq_lens_decoder_shape,
|
||||
const std::vector<int64_t> &seq_lens_encoder_shape,
|
||||
const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
|
||||
int64_t dim_embed = tmp_out_shape[1];
|
||||
if (output_padding_offset_shape) {
|
||||
return {{-1, dim_embed}};
|
||||
} else {
|
||||
int64_t bsz = cu_seqlens_q_shape[0] - 1;
|
||||
return {{bsz, dim_embed}};
|
||||
}
|
||||
int64_t dim_embed = tmp_out_shape[1];
|
||||
if (output_padding_offset_shape) {
|
||||
return {{-1, dim_embed}};
|
||||
} else {
|
||||
int64_t bsz = cu_seqlens_q_shape[0] - 1;
|
||||
return {{bsz, dim_embed}};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> RebuildPaddingInferDtype(
|
||||
@@ -254,7 +250,7 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
|
||||
const paddle::DataType &seq_lens_decoder_dtype,
|
||||
const paddle::DataType &seq_lens_encoder_dtype,
|
||||
const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
|
||||
return {tmp_out_dtype};
|
||||
return {tmp_out_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(rebuild_padding_cpu)
|
||||
|
||||
@@ -15,27 +15,27 @@
|
||||
#include "paddle/extension.h"
|
||||
|
||||
void set_value_by_flags_and_idx(const bool *stop_flags,
|
||||
int64_t *pre_ids_all,
|
||||
const int64_t *input_ids,
|
||||
const int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
const int64_t *step_idx,
|
||||
int bs,
|
||||
int length,
|
||||
int length_input_ids) {
|
||||
for (int bi = 0; bi < bs; bi++) {
|
||||
if (!stop_flags[bi]) {
|
||||
const int seq_len_dec = seq_lens_decoder[bi];
|
||||
const int seq_len_enc = seq_lens_encoder[bi];
|
||||
int64_t *pre_ids_all_now = pre_ids_all + bi * length;
|
||||
const int64_t *input_ids_now = input_ids + bi * length_input_ids;
|
||||
if (seq_len_dec == 0) {
|
||||
pre_ids_all_now[step_idx[bi]] = input_ids_now[seq_len_enc - 1];
|
||||
} else {
|
||||
pre_ids_all_now[step_idx[bi]] = input_ids_now[0];
|
||||
}
|
||||
}
|
||||
int64_t *pre_ids_all,
|
||||
const int64_t *input_ids,
|
||||
const int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
const int64_t *step_idx,
|
||||
int bs,
|
||||
int length,
|
||||
int length_input_ids) {
|
||||
for (int bi = 0; bi < bs; bi++) {
|
||||
if (!stop_flags[bi]) {
|
||||
const int seq_len_dec = seq_lens_decoder[bi];
|
||||
const int seq_len_enc = seq_lens_encoder[bi];
|
||||
int64_t *pre_ids_all_now = pre_ids_all + bi * length;
|
||||
const int64_t *input_ids_now = input_ids + bi * length_input_ids;
|
||||
if (seq_len_dec == 0) {
|
||||
pre_ids_all_now[step_idx[bi]] = input_ids_now[seq_len_enc - 1];
|
||||
} else {
|
||||
pre_ids_all_now[step_idx[bi]] = input_ids_now[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
|
||||
@@ -45,12 +45,12 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &stop_flags) {
|
||||
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
|
||||
int bs = seq_lens_this_time.shape()[0];
|
||||
int length = pre_ids_all_shape[1];
|
||||
int length_input_ids = input_ids.shape()[1];
|
||||
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
|
||||
int bs = seq_lens_this_time.shape()[0];
|
||||
int length = pre_ids_all_shape[1];
|
||||
int length_input_ids = input_ids.shape()[1];
|
||||
|
||||
set_value_by_flags_and_idx(stop_flags.data<bool>(),
|
||||
set_value_by_flags_and_idx(stop_flags.data<bool>(),
|
||||
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
|
||||
input_ids.data<int64_t>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
|
||||
@@ -21,45 +21,45 @@ void probs_sort(const float *probs,
|
||||
float *ProbsVals,
|
||||
int vocab_size,
|
||||
int bsz) {
|
||||
float cursum = 0;
|
||||
std::vector<int64_t> elementsIds(vocab_size);
|
||||
std::vector<float> elementsProbs(vocab_size);
|
||||
float cursum = 0;
|
||||
std::vector<int64_t> elementsIds(vocab_size);
|
||||
std::vector<float> elementsProbs(vocab_size);
|
||||
#pragma omp parallel for
|
||||
for (int j = 0; j < vocab_size; j++) {
|
||||
elementsIds[j] = j;
|
||||
elementsProbs[j] = probs[j];
|
||||
}
|
||||
x86simdsortStatic::keyvalue_qsort(
|
||||
elementsProbs.data(), elementsIds.data(), vocab_size, false, true);
|
||||
for (int j = 0; j < vocab_size; j++) {
|
||||
elementsIds[j] = j;
|
||||
elementsProbs[j] = probs[j];
|
||||
}
|
||||
x86simdsortStatic::keyvalue_qsort(
|
||||
elementsProbs.data(), elementsIds.data(), vocab_size, false, true);
|
||||
#pragma omp parallel for
|
||||
for (int j = 0; j < vocab_size; ++j) {
|
||||
ProbsVals[j] = elementsProbs[j];
|
||||
ProbsIds[j] = elementsIds[j];
|
||||
}
|
||||
for (int j = 0; j < vocab_size; ++j) {
|
||||
ProbsVals[j] = elementsProbs[j];
|
||||
ProbsIds[j] = elementsIds[j];
|
||||
}
|
||||
}
|
||||
std::vector<paddle::Tensor> SimdSort(const paddle::Tensor &probs) {
|
||||
const int bsz = probs.shape()[0];
|
||||
const int vocab_size = probs.shape()[1];
|
||||
auto sorted_indices = paddle::empty(
|
||||
{bsz, vocab_size}, paddle::DataType::INT64, probs.place());
|
||||
auto sorted_probs = paddle::empty(
|
||||
{bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place());
|
||||
probs_sort(probs.data<float>(),
|
||||
const_cast<int64_t *>(sorted_indices.data<int64_t>()),
|
||||
const_cast<float *>(sorted_probs.data<float>()),
|
||||
vocab_size,
|
||||
bsz);
|
||||
return {sorted_indices, sorted_probs};
|
||||
const int bsz = probs.shape()[0];
|
||||
const int vocab_size = probs.shape()[1];
|
||||
auto sorted_indices =
|
||||
paddle::empty({bsz, vocab_size}, paddle::DataType::INT64, probs.place());
|
||||
auto sorted_probs = paddle::empty(
|
||||
{bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place());
|
||||
probs_sort(probs.data<float>(),
|
||||
const_cast<int64_t *>(sorted_indices.data<int64_t>()),
|
||||
const_cast<float *>(sorted_probs.data<float>()),
|
||||
vocab_size,
|
||||
bsz);
|
||||
return {sorted_indices, sorted_probs};
|
||||
}
|
||||
std::vector<std::vector<int64_t>> SimdSortInferShape(
|
||||
const std::vector<int64_t> &probs_shape) {
|
||||
int64_t bsz = probs_shape[0];
|
||||
int64_t vocab_size = probs_shape[1];
|
||||
return {{bsz, vocab_size}, {bsz, vocab_size}};
|
||||
int64_t bsz = probs_shape[0];
|
||||
int64_t vocab_size = probs_shape[1];
|
||||
return {{bsz, vocab_size}, {bsz, vocab_size}};
|
||||
}
|
||||
std::vector<paddle::DataType> SimdSortInferDtype(
|
||||
const paddle::DataType &probs_dtype) {
|
||||
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
|
||||
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
|
||||
}
|
||||
PD_BUILD_STATIC_OP(simd_sort)
|
||||
.Inputs({"probs"})
|
||||
|
||||
@@ -16,23 +16,23 @@
|
||||
#include "paddle/extension.h"
|
||||
|
||||
std::vector<paddle::Tensor> SimdSort(const paddle::Tensor &probs) {
|
||||
const int bsz = probs.shape()[0];
|
||||
const int vocab_size = probs.shape()[1];
|
||||
auto sorted_indices = paddle::empty(
|
||||
{bsz, vocab_size}, paddle::DataType::INT64, probs.place());
|
||||
auto sorted_probs = paddle::empty(
|
||||
{bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place());
|
||||
return {sorted_indices, sorted_probs};
|
||||
const int bsz = probs.shape()[0];
|
||||
const int vocab_size = probs.shape()[1];
|
||||
auto sorted_indices =
|
||||
paddle::empty({bsz, vocab_size}, paddle::DataType::INT64, probs.place());
|
||||
auto sorted_probs = paddle::empty(
|
||||
{bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place());
|
||||
return {sorted_indices, sorted_probs};
|
||||
}
|
||||
std::vector<std::vector<int64_t>> SimdSortInferShape(
|
||||
const std::vector<int64_t> &probs_shape) {
|
||||
int64_t bsz = probs_shape[0];
|
||||
int64_t vocab_size = probs_shape[1];
|
||||
return {{bsz, vocab_size}, {bsz, vocab_size}};
|
||||
int64_t bsz = probs_shape[0];
|
||||
int64_t vocab_size = probs_shape[1];
|
||||
return {{bsz, vocab_size}, {bsz, vocab_size}};
|
||||
}
|
||||
std::vector<paddle::DataType> SimdSortInferDtype(
|
||||
const paddle::DataType &probs_dtype) {
|
||||
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
|
||||
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
|
||||
}
|
||||
PD_BUILD_STATIC_OP(simd_sort)
|
||||
.Inputs({"probs"})
|
||||
|
||||
@@ -23,13 +23,13 @@
|
||||
#endif
|
||||
|
||||
bool is_in_end(const int64_t id, const int64_t *end_ids, int length) {
|
||||
bool flag = false;
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (id == end_ids[i]) {
|
||||
return true;
|
||||
}
|
||||
bool flag = false;
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (id == end_ids[i]) {
|
||||
return true;
|
||||
}
|
||||
return flag;
|
||||
}
|
||||
return flag;
|
||||
}
|
||||
|
||||
void set_value_by_flags(bool *stop_flags,
|
||||
@@ -40,23 +40,23 @@ void set_value_by_flags(bool *stop_flags,
|
||||
const int bs,
|
||||
const int end_length,
|
||||
bool beam_search) {
|
||||
for (int bi = 0; bi < bs; bi++) {
|
||||
if (stop_flags[bi]) {
|
||||
if ((seq_lens[bi] == 0)) {
|
||||
topk_ids[bi] = -1;
|
||||
} else {
|
||||
topk_ids[bi] = end_ids[0];
|
||||
next_tokens[bi] = end_ids[0];
|
||||
}
|
||||
} else {
|
||||
next_tokens[bi] = topk_ids[bi];
|
||||
}
|
||||
if (!beam_search && is_in_end(topk_ids[bi], end_ids, end_length)) {
|
||||
stop_flags[bi] = true;
|
||||
topk_ids[bi] = end_ids[0];
|
||||
next_tokens[bi] = end_ids[0];
|
||||
}
|
||||
for (int bi = 0; bi < bs; bi++) {
|
||||
if (stop_flags[bi]) {
|
||||
if ((seq_lens[bi] == 0)) {
|
||||
topk_ids[bi] = -1;
|
||||
} else {
|
||||
topk_ids[bi] = end_ids[0];
|
||||
next_tokens[bi] = end_ids[0];
|
||||
}
|
||||
} else {
|
||||
next_tokens[bi] = topk_ids[bi];
|
||||
}
|
||||
if (!beam_search && is_in_end(topk_ids[bi], end_ids, end_length)) {
|
||||
stop_flags[bi] = true;
|
||||
topk_ids[bi] = end_ids[0];
|
||||
next_tokens[bi] = end_ids[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
|
||||
@@ -65,17 +65,17 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
|
||||
const paddle::Tensor &end_ids,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const bool beam_search) {
|
||||
std::vector<int64_t> shape = topk_ids.shape();
|
||||
int64_t bs_now = shape[0];
|
||||
int64_t end_length = end_ids.shape()[0];
|
||||
set_value_by_flags(const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
||||
const_cast<int64_t *>(next_tokens.data<int64_t>()),
|
||||
end_ids.data<int64_t>(),
|
||||
seq_lens.data<int>(),
|
||||
bs_now,
|
||||
end_length,
|
||||
false);
|
||||
std::vector<int64_t> shape = topk_ids.shape();
|
||||
int64_t bs_now = shape[0];
|
||||
int64_t end_length = end_ids.shape()[0];
|
||||
set_value_by_flags(const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
||||
const_cast<int64_t *>(next_tokens.data<int64_t>()),
|
||||
end_ids.data<int64_t>(),
|
||||
seq_lens.data<int>(),
|
||||
bs_now,
|
||||
end_length,
|
||||
false);
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(set_stop_value_multi_ends_cpu)
|
||||
|
||||
@@ -23,16 +23,16 @@ void min_length_logits_process(float *logits,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t end_length) {
|
||||
for (int bi = 0; bi < bs; ++bi) {
|
||||
if (cur_len[bi] < 0) {
|
||||
continue;
|
||||
}
|
||||
if (cur_len[bi] < min_len[bi]) {
|
||||
for (int i = 0; i < end_length; ++i) {
|
||||
logits[bi * length + eos_token_id[i]] = -1e10;
|
||||
}
|
||||
}
|
||||
for (int bi = 0; bi < bs; ++bi) {
|
||||
if (cur_len[bi] < 0) {
|
||||
continue;
|
||||
}
|
||||
if (cur_len[bi] < min_len[bi]) {
|
||||
for (int i = 0; i < end_length; ++i) {
|
||||
logits[bi * length + eos_token_id[i]] = -1e10;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void update_repeat_times(const int64_t *pre_ids,
|
||||
@@ -41,20 +41,20 @@ void update_repeat_times(const int64_t *pre_ids,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t length_id) {
|
||||
for (int bi = 0; bi < bs; ++bi) {
|
||||
if (cur_len[bi] < 0) {
|
||||
continue;
|
||||
}
|
||||
const int64_t *pre_ids_now = pre_ids + bi * length_id;
|
||||
int *repeat_times_now = repeat_times + bi * length;
|
||||
for (int i = 0; i < length_id; i++) {
|
||||
int64_t id = pre_ids_now[i];
|
||||
if (id < 0) {
|
||||
break;
|
||||
}
|
||||
repeat_times_now[id] += 1;
|
||||
}
|
||||
for (int bi = 0; bi < bs; ++bi) {
|
||||
if (cur_len[bi] < 0) {
|
||||
continue;
|
||||
}
|
||||
const int64_t *pre_ids_now = pre_ids + bi * length_id;
|
||||
int *repeat_times_now = repeat_times + bi * length;
|
||||
for (int i = 0; i < length_id; i++) {
|
||||
int64_t id = pre_ids_now[i];
|
||||
if (id < 0) {
|
||||
break;
|
||||
}
|
||||
repeat_times_now[id] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void update_value_by_repeat_times(const int *repeat_times,
|
||||
@@ -65,24 +65,22 @@ void update_value_by_repeat_times(const int *repeat_times,
|
||||
float *logits,
|
||||
const int64_t bs,
|
||||
const int64_t length) {
|
||||
for (int bi = 0; bi < bs; ++bi) {
|
||||
float *logits_now = logits + bi * length;
|
||||
const int *repeat_times_now = repeat_times + bi * length;
|
||||
float alpha = static_cast<float>(penalty_scores[bi]);
|
||||
float beta = static_cast<float>(frequency_score[bi]);
|
||||
float gamma = static_cast<float>(presence_score[bi]);
|
||||
for (int i = 0; i < length; ++i) {
|
||||
int times = repeat_times_now[i];
|
||||
float logit_now = static_cast<float>(logits_now[i]);
|
||||
if (times == 0) {
|
||||
logits_now[i] =
|
||||
static_cast<float>(logit_now / temperatures[bi]);
|
||||
}
|
||||
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
|
||||
logits_now[i] =
|
||||
static_cast<float>(logit_now - times * beta - gamma);
|
||||
}
|
||||
for (int bi = 0; bi < bs; ++bi) {
|
||||
float *logits_now = logits + bi * length;
|
||||
const int *repeat_times_now = repeat_times + bi * length;
|
||||
float alpha = static_cast<float>(penalty_scores[bi]);
|
||||
float beta = static_cast<float>(frequency_score[bi]);
|
||||
float gamma = static_cast<float>(presence_score[bi]);
|
||||
for (int i = 0; i < length; ++i) {
|
||||
int times = repeat_times_now[i];
|
||||
float logit_now = static_cast<float>(logits_now[i]);
|
||||
if (times == 0) {
|
||||
logits_now[i] = static_cast<float>(logit_now / temperatures[bi]);
|
||||
}
|
||||
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
|
||||
logits_now[i] = static_cast<float>(logit_now - times * beta - gamma);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ban_bad_words(float *logits,
|
||||
@@ -90,15 +88,14 @@ void ban_bad_words(float *logits,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t bad_words_length) {
|
||||
for (int bi = 0; bi < bs; ++bi) {
|
||||
float *logits_now = logits + bi * length;
|
||||
for (int bwid = 0; bwid < bad_words_length; ++bwid) {
|
||||
const int64_t bad_words_token_id = bad_words_list[bwid];
|
||||
if (bad_words_token_id >= length || bad_words_token_id < 0)
|
||||
continue;
|
||||
logits_now[bad_words_token_id] = -1e10;
|
||||
}
|
||||
for (int bi = 0; bi < bs; ++bi) {
|
||||
float *logits_now = logits + bi * length;
|
||||
for (int bwid = 0; bwid < bad_words_length; ++bwid) {
|
||||
const int64_t bad_words_token_id = bad_words_list[bwid];
|
||||
if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
|
||||
logits_now[bad_words_token_id] = -1e10;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <paddle::DataType D>
|
||||
@@ -112,44 +109,44 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &cur_len,
|
||||
const paddle::Tensor &min_len,
|
||||
const paddle::Tensor &eos_token_id) {
|
||||
std::vector<int64_t> shape = logits.shape();
|
||||
auto repeat_times =
|
||||
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
|
||||
int64_t bs = shape[0];
|
||||
int64_t length = shape[1];
|
||||
int64_t length_id = pre_ids.shape()[1];
|
||||
int64_t end_length = eos_token_id.shape()[0];
|
||||
int64_t length_bad_words = bad_tokens.shape()[0];
|
||||
std::vector<int64_t> shape = logits.shape();
|
||||
auto repeat_times =
|
||||
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
|
||||
int64_t bs = shape[0];
|
||||
int64_t length = shape[1];
|
||||
int64_t length_id = pre_ids.shape()[1];
|
||||
int64_t end_length = eos_token_id.shape()[0];
|
||||
int64_t length_bad_words = bad_tokens.shape()[0];
|
||||
|
||||
min_length_logits_process(const_cast<float *>(logits.data<float>()),
|
||||
cur_len.data<int64_t>(),
|
||||
min_len.data<int64_t>(),
|
||||
eos_token_id.data<int64_t>(),
|
||||
bs,
|
||||
length,
|
||||
end_length);
|
||||
min_length_logits_process(const_cast<float *>(logits.data<float>()),
|
||||
cur_len.data<int64_t>(),
|
||||
min_len.data<int64_t>(),
|
||||
eos_token_id.data<int64_t>(),
|
||||
bs,
|
||||
length,
|
||||
end_length);
|
||||
|
||||
update_repeat_times(pre_ids.data<int64_t>(),
|
||||
cur_len.data<int64_t>(),
|
||||
repeat_times.data<int>(),
|
||||
bs,
|
||||
length,
|
||||
length_id);
|
||||
update_repeat_times(pre_ids.data<int64_t>(),
|
||||
cur_len.data<int64_t>(),
|
||||
repeat_times.data<int>(),
|
||||
bs,
|
||||
length,
|
||||
length_id);
|
||||
|
||||
update_value_by_repeat_times(repeat_times.data<int>(),
|
||||
penalty_scores.data<float>(),
|
||||
frequency_score.data<float>(),
|
||||
presence_score.data<float>(),
|
||||
temperatures.data<float>(),
|
||||
const_cast<float *>(logits.data<float>()),
|
||||
bs,
|
||||
length);
|
||||
update_value_by_repeat_times(repeat_times.data<int>(),
|
||||
penalty_scores.data<float>(),
|
||||
frequency_score.data<float>(),
|
||||
presence_score.data<float>(),
|
||||
temperatures.data<float>(),
|
||||
const_cast<float *>(logits.data<float>()),
|
||||
bs,
|
||||
length);
|
||||
|
||||
ban_bad_words(const_cast<float *>(logits.data<float>()),
|
||||
bad_tokens.data<int64_t>(),
|
||||
bs,
|
||||
length,
|
||||
length_bad_words);
|
||||
ban_bad_words(const_cast<float *>(logits.data<float>()),
|
||||
bad_tokens.data<int64_t>(),
|
||||
bs,
|
||||
length,
|
||||
length_bad_words);
|
||||
}
|
||||
|
||||
void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
|
||||
@@ -162,17 +159,17 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &cur_len,
|
||||
const paddle::Tensor &min_len,
|
||||
const paddle::Tensor &eos_token_id) {
|
||||
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
|
||||
pre_ids,
|
||||
logits,
|
||||
penalty_scores,
|
||||
frequency_scores,
|
||||
presence_scores,
|
||||
temperatures,
|
||||
bad_tokens,
|
||||
cur_len,
|
||||
min_len,
|
||||
eos_token_id);
|
||||
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
|
||||
pre_ids,
|
||||
logits,
|
||||
penalty_scores,
|
||||
frequency_scores,
|
||||
presence_scores,
|
||||
temperatures,
|
||||
bad_tokens,
|
||||
cur_len,
|
||||
min_len,
|
||||
eos_token_id);
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(get_token_penalty_multi_scores_cpu)
|
||||
|
||||
@@ -24,50 +24,50 @@ void update_inputs_kernel(bool *not_need_stop,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int input_ids_stride) {
|
||||
int64_t stop_sum = 0;
|
||||
for (int bi = 0; bi < bsz; ++bi) {
|
||||
bool stop_flag_now = false;
|
||||
int64_t stop_flag_now_int = 0;
|
||||
stop_flag_now = stop_flags[bi];
|
||||
stop_flag_now_int = static_cast<int64_t>(stop_flag_now);
|
||||
auto seq_len_this_time = seq_lens_this_time[bi];
|
||||
auto seq_len_encoder = seq_lens_encoder[bi];
|
||||
auto seq_len_decoder = seq_lens_decoder[bi];
|
||||
seq_lens_decoder[bi] =
|
||||
stop_flag_now ? 0
|
||||
: (seq_len_decoder == 0 ? seq_len_encoder
|
||||
: seq_len_decoder + 1);
|
||||
seq_lens_this_time[bi] = stop_flag_now ? 0 : 1;
|
||||
seq_lens_encoder[bi] = 0;
|
||||
int64_t *input_ids_now = input_ids + bi * input_ids_stride;
|
||||
input_ids_now[0] = next_tokens[bi];
|
||||
stop_sum += stop_flag_now_int;
|
||||
}
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
int64_t stop_sum = 0;
|
||||
for (int bi = 0; bi < bsz; ++bi) {
|
||||
bool stop_flag_now = false;
|
||||
int64_t stop_flag_now_int = 0;
|
||||
stop_flag_now = stop_flags[bi];
|
||||
stop_flag_now_int = static_cast<int64_t>(stop_flag_now);
|
||||
auto seq_len_this_time = seq_lens_this_time[bi];
|
||||
auto seq_len_encoder = seq_lens_encoder[bi];
|
||||
auto seq_len_decoder = seq_lens_decoder[bi];
|
||||
seq_lens_decoder[bi] =
|
||||
stop_flag_now
|
||||
? 0
|
||||
: (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1);
|
||||
seq_lens_this_time[bi] = stop_flag_now ? 0 : 1;
|
||||
seq_lens_encoder[bi] = 0;
|
||||
int64_t *input_ids_now = input_ids + bi * input_ids_stride;
|
||||
input_ids_now[0] = next_tokens[bi];
|
||||
stop_sum += stop_flag_now_int;
|
||||
}
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
}
|
||||
|
||||
void UpdateInputs(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &not_need_stop,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &is_block_step) {
|
||||
const int bsz = input_ids.shape()[0];
|
||||
const int input_ids_stride = input_ids.shape()[1];
|
||||
update_inputs_kernel(const_cast<bool *>(not_need_stop.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(input_ids.data<int64_t>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
stop_flags.data<bool>(),
|
||||
is_block_step.data<bool>(),
|
||||
next_tokens.data<int64_t>(),
|
||||
bsz,
|
||||
input_ids_stride);
|
||||
const paddle::Tensor &not_need_stop,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &is_block_step) {
|
||||
const int bsz = input_ids.shape()[0];
|
||||
const int input_ids_stride = input_ids.shape()[1];
|
||||
update_inputs_kernel(const_cast<bool *>(not_need_stop.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(input_ids.data<int64_t>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
stop_flags.data<bool>(),
|
||||
is_block_step.data<bool>(),
|
||||
next_tokens.data<int64_t>(),
|
||||
bsz,
|
||||
input_ids_stride);
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(update_inputs_cpu)
|
||||
|
||||
@@ -45,18 +45,18 @@ std::vector<paddle::Tensor> InvokeAllLLaMALayer(
|
||||
int maxPositions,
|
||||
int maxPosEmbed,
|
||||
int intermediateSize) {
|
||||
auto out = paddle::empty_like(input);
|
||||
return {out};
|
||||
auto out = paddle::empty_like(input);
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> AllLLaMALayerInferShape(
|
||||
std::vector<int64_t> x_shape) {
|
||||
return {x_shape};
|
||||
return {x_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> AllLLaMALayerInferDtype(
|
||||
paddle::DataType x_dtype) {
|
||||
return {x_dtype};
|
||||
return {x_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(xft_llama_all_layer)
|
||||
|
||||
@@ -16,20 +16,20 @@
|
||||
#include "paddle/extension.h"
|
||||
|
||||
std::vector<paddle::Tensor> XftGreedySearch(const paddle::Tensor &probs) {
|
||||
const int bsz = probs.shape()[0];
|
||||
const int vocab_size = probs.shape()[1];
|
||||
auto next_tokens =
|
||||
paddle::empty({bsz, 1}, paddle::DataType::INT64, probs.place());
|
||||
return {next_tokens};
|
||||
const int bsz = probs.shape()[0];
|
||||
const int vocab_size = probs.shape()[1];
|
||||
auto next_tokens =
|
||||
paddle::empty({bsz, 1}, paddle::DataType::INT64, probs.place());
|
||||
return {next_tokens};
|
||||
}
|
||||
std::vector<std::vector<int64_t>> XftGreedySearchInferShape(
|
||||
const std::vector<int64_t> &probs_shape) {
|
||||
int64_t bsz = probs_shape[0];
|
||||
return {{bsz, 1}};
|
||||
int64_t bsz = probs_shape[0];
|
||||
return {{bsz, 1}};
|
||||
}
|
||||
std::vector<paddle::DataType> XftGreedySearchInferDtype(
|
||||
const paddle::DataType &probs_dtype) {
|
||||
return {paddle::DataType::INT64};
|
||||
return {paddle::DataType::INT64};
|
||||
}
|
||||
PD_BUILD_STATIC_OP(xft_greedy_search)
|
||||
.Inputs({"probs"})
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include "cub/cub.cuh"
|
||||
|
||||
namespace phi {
|
||||
|
||||
@@ -19,8 +19,8 @@
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include "fused_moe_imp_op.h"
|
||||
#include "fused_moe_helper.h"
|
||||
#include "fused_moe_imp_op.h"
|
||||
// Ignore CUTLASS warnings about type punning
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
@@ -34,8 +34,8 @@
|
||||
namespace phi {
|
||||
|
||||
struct GpuLaunchConfig {
|
||||
dim3 block_per_grid;
|
||||
dim3 thread_per_block;
|
||||
dim3 block_per_grid;
|
||||
dim3 thread_per_block;
|
||||
};
|
||||
|
||||
inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
|
||||
@@ -81,7 +81,6 @@ __launch_bounds__(TPB) __global__
|
||||
cub::Sum sum;
|
||||
float threadData(-FLT_MAX);
|
||||
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData = max(static_cast<float>(input[idx]), threadData);
|
||||
@@ -275,7 +274,8 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
|
||||
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
|
||||
const int idx = thread_read_offset + expert;
|
||||
inp_kvp.key = expert;
|
||||
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
|
||||
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert]
|
||||
: inputs_after_softmax[idx];
|
||||
|
||||
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
|
||||
const IdxT prior_winning_expert = indices[k * block_row + prior_k];
|
||||
@@ -292,7 +292,9 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
|
||||
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||
if (threadIdx.x == 0) {
|
||||
const int idx = k * block_row + k_idx;
|
||||
output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
|
||||
output[idx] =
|
||||
bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]
|
||||
: result_kvp.value;
|
||||
indices[idx] = should_process_row ? result_kvp.key : num_experts;
|
||||
source_rows[idx] = k_idx * num_rows + block_row;
|
||||
}
|
||||
@@ -301,14 +303,15 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
|
||||
}
|
||||
|
||||
template <typename T, int TPB, typename IdxT = int>
|
||||
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
|
||||
const T* bias,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
__launch_bounds__(TPB) __global__
|
||||
void moe_softmax_top_k_fused(const T* input,
|
||||
const T* bias,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
// softmax
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
@@ -321,11 +324,12 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
|
||||
return;
|
||||
}
|
||||
const int64_t thread_row_offset = globalIdx * num_experts;
|
||||
const int64_t idx = thread_row_offset+threadIdx.x;
|
||||
const int64_t idx = thread_row_offset + threadIdx.x;
|
||||
|
||||
cub::Sum sum;
|
||||
|
||||
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
|
||||
float threadData =
|
||||
(threadIdx.x < num_experts) ? static_cast<float>(input[idx]) : (-FLT_MAX);
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
if (threadIdx.x == 0) {
|
||||
@@ -377,7 +381,8 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
|
||||
BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max);
|
||||
if (threadIdx.x == 0) {
|
||||
const int cur_idx = k * globalIdx + k_idx;
|
||||
output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
|
||||
output[cur_idx] =
|
||||
bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
|
||||
indices[cur_idx] = result_kvp.key;
|
||||
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
|
||||
}
|
||||
@@ -386,14 +391,15 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
|
||||
}
|
||||
|
||||
template <typename T, int TPB, typename IdxT = int>
|
||||
__launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_softmax,
|
||||
const T* bias,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
__launch_bounds__(TPB) __global__
|
||||
void moe_top_k_normed(const T* inputs_after_softmax,
|
||||
const T* bias,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
@@ -422,7 +428,8 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so
|
||||
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
|
||||
const int idx = thread_read_offset + expert;
|
||||
inp_kvp.key = expert;
|
||||
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
|
||||
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert]
|
||||
: inputs_after_softmax[idx];
|
||||
|
||||
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
|
||||
const int prior_winning_expert = indices[k * block_row + prior_k];
|
||||
@@ -439,11 +446,14 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so
|
||||
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||
if (threadIdx.x == 0) {
|
||||
const int idx = k * block_row + k_idx;
|
||||
// output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
|
||||
// output[idx] = bias ? inputs_after_softmax[thread_read_offset +
|
||||
// result_kvp.key]: result_kvp.value;
|
||||
indices[idx] = should_process_row ? result_kvp.key : num_experts;
|
||||
source_rows[idx] = k_idx * num_rows + block_row;
|
||||
|
||||
T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
|
||||
T row_out =
|
||||
bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]
|
||||
: result_kvp.value;
|
||||
row_outputs[k_idx] = row_out;
|
||||
weight_sum += row_out;
|
||||
}
|
||||
@@ -458,16 +468,16 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, int TPB, typename IdxT = int>
|
||||
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input,
|
||||
const T* bias,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
__launch_bounds__(TPB) __global__
|
||||
void moe_softmax_top_k_normed_fused(const T* input,
|
||||
const T* bias,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
// softmax
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
@@ -480,11 +490,12 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
|
||||
return;
|
||||
}
|
||||
const int64_t thread_row_offset = globalIdx * num_experts;
|
||||
const int64_t idx = thread_row_offset+threadIdx.x;
|
||||
const int64_t idx = thread_row_offset + threadIdx.x;
|
||||
|
||||
cub::Sum sum;
|
||||
|
||||
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
|
||||
float threadData =
|
||||
(threadIdx.x < num_experts) ? static_cast<float>(input[idx]) : (-FLT_MAX);
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
if (threadIdx.x == 0) {
|
||||
@@ -542,7 +553,8 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
|
||||
if (threadIdx.x == 0) {
|
||||
const int cur_idx = k * globalIdx + k_idx;
|
||||
|
||||
T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
|
||||
T row_out =
|
||||
bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
|
||||
row_outputs[k_idx] = row_out;
|
||||
weight_sum += row_out;
|
||||
|
||||
@@ -595,29 +607,36 @@ void topk_gating_softmax_kernelLauncher(const T* input,
|
||||
if (topk_only_mode) {
|
||||
static constexpr int TPB = 256;
|
||||
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
input, gating_correction_bias, output, indices, source_row, num_experts, k, num_rows);
|
||||
moe_top_k<T, TPB>
|
||||
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
|
||||
gating_correction_bias,
|
||||
output,
|
||||
indices,
|
||||
source_row,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows);
|
||||
return;
|
||||
}
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
|
||||
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
|
||||
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
|
||||
case N: { \
|
||||
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
|
||||
input, output, indices, source_row, num_rows, num_experts, k, stream); \
|
||||
break; \
|
||||
}
|
||||
int64_t tem_num_experts = num_experts;
|
||||
if(gating_correction_bias != nullptr) tem_num_experts = 0;
|
||||
if (gating_correction_bias != nullptr) tem_num_experts = 0;
|
||||
switch (tem_num_experts) {
|
||||
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
|
||||
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
|
||||
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
|
||||
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16)
|
||||
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32)
|
||||
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64)
|
||||
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128)
|
||||
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256)
|
||||
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
|
||||
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
|
||||
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
|
||||
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16)
|
||||
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32)
|
||||
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64)
|
||||
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128)
|
||||
// LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256)
|
||||
|
||||
default: {
|
||||
static constexpr int TPB = 256;
|
||||
@@ -646,15 +665,15 @@ void topk_gating_softmax_kernelLauncher(const T* input,
|
||||
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
input, softmax, num_experts, num_rows);
|
||||
moe_top_k<T, TPB>
|
||||
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
|
||||
gating_correction_bias,
|
||||
output,
|
||||
indices,
|
||||
source_row,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows);
|
||||
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
softmax,
|
||||
gating_correction_bias,
|
||||
output,
|
||||
indices,
|
||||
source_row,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,8 +23,8 @@ void MixedFusedPagedAttnKernel(const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& decode_block_table,
|
||||
const paddle::Tensor& cu_seqlens_qkv,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::optional<paddle::Tensor> &rope_sin,
|
||||
const paddle::optional<paddle::Tensor> &rope_cos,
|
||||
const paddle::optional<paddle::Tensor>& rope_sin,
|
||||
const paddle::optional<paddle::Tensor>& rope_cos,
|
||||
int prefill_num_tokens,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
@@ -42,318 +42,354 @@ void MixedFusedPagedAttnKernel(const paddle::Tensor& qkv,
|
||||
bool enable_cuda_graph,
|
||||
bool use_sqrt_alibi,
|
||||
paddle::Tensor& out) {
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
const auto& dtype = qkv.dtype();
|
||||
cuinferDataType_t cuinfer_data_type;
|
||||
cudaDataType_t cu_data_type;
|
||||
if (dtype == paddle::DataType::FLOAT16) {
|
||||
cuinfer_data_type = CUINFER_DATA_HALF;
|
||||
cu_data_type = CUDA_R_16F;
|
||||
} else {
|
||||
cuinfer_data_type = CUINFER_DATA_BFLOAT16;
|
||||
cu_data_type = CUDA_R_16BF;
|
||||
}
|
||||
|
||||
const auto& dtype = qkv.dtype();
|
||||
cuinferDataType_t cuinfer_data_type;
|
||||
cudaDataType_t cu_data_type;
|
||||
if (dtype == paddle::DataType::FLOAT16) {
|
||||
cuinfer_data_type = CUINFER_DATA_HALF;
|
||||
cu_data_type = CUDA_R_16F;
|
||||
} else {
|
||||
cuinfer_data_type = CUINFER_DATA_BFLOAT16;
|
||||
cu_data_type = CUDA_R_16BF;
|
||||
}
|
||||
const auto& qkv_dims = qkv.dims();
|
||||
const auto& kv_cache_dims = k_cache.dims();
|
||||
const auto& prefill_block_table_dims = prefill_block_table.dims();
|
||||
const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims();
|
||||
|
||||
const auto& qkv_dims = qkv.dims();
|
||||
const auto& kv_cache_dims = k_cache.dims();
|
||||
const auto& prefill_block_table_dims = prefill_block_table.dims();
|
||||
const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims();
|
||||
int prefill_batch_size = prefill_block_table_dims[0];
|
||||
int num_tokens = qkv_dims[0];
|
||||
int decode_num_tokens = num_tokens - prefill_num_tokens;
|
||||
int num_total_heads = num_heads + 2 * num_kv_heads;
|
||||
int max_num_blocks_per_seq = prefill_block_table_dims[1];
|
||||
int qkv_stride = qkv.strides()[0];
|
||||
int num_blocks = kv_cache_dims[0];
|
||||
|
||||
int prefill_batch_size = prefill_block_table_dims[0];
|
||||
int num_tokens = qkv_dims[0];
|
||||
int decode_num_tokens = num_tokens - prefill_num_tokens;
|
||||
int num_total_heads = num_heads + 2 * num_kv_heads;
|
||||
int max_num_blocks_per_seq = prefill_block_table_dims[1];
|
||||
int qkv_stride = qkv.strides()[0];
|
||||
int num_blocks = kv_cache_dims[0];
|
||||
int kv_block_stride = k_cache.strides()[0];
|
||||
int kv_head_stride = k_cache.strides()[1];
|
||||
int block_table_stride = prefill_block_table.strides()[0];
|
||||
const float* rope_sin_ptr = rope_sin ? rope_sin.get().data<float>() : nullptr;
|
||||
const float* rope_cos_ptr = rope_cos ? rope_cos.get().data<float>() : nullptr;
|
||||
|
||||
int kv_block_stride = k_cache.strides()[0];
|
||||
int kv_head_stride = k_cache.strides()[1];
|
||||
int block_table_stride = prefill_block_table.strides()[0];
|
||||
const float *rope_sin_ptr = rope_sin ? rope_sin.get().data<float>() : nullptr;
|
||||
const float *rope_cos_ptr = rope_cos ? rope_cos.get().data<float>() : nullptr;
|
||||
|
||||
cuinferTensorDescriptor_t qkv_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cuinferTensorDescriptor_t qkv_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
qkv_desc,
|
||||
cuinfer_data_type,
|
||||
3,
|
||||
std::vector<int>({prefill_num_tokens, num_total_heads, head_dim}).data(),
|
||||
std::vector<int>({num_total_heads * head_dim, head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t qkv_seqlens_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cuinferTensorDescriptor_t qkv_seqlens_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
qkv_seqlens_desc,
|
||||
CUINFER_DATA_INT32,
|
||||
1,
|
||||
std::vector<int>({prefill_batch_size + 1}).data(),
|
||||
std::vector<int>({1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t block_table_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cuinferTensorDescriptor_t block_table_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
block_table_desc,
|
||||
CUINFER_DATA_INT32,
|
||||
2,
|
||||
std::vector<int>({prefill_batch_size, block_table_stride}).data(),
|
||||
std::vector<int>({block_table_stride, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t o_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cuinferTensorDescriptor_t o_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
o_desc,
|
||||
cuinfer_data_type,
|
||||
3,
|
||||
std::vector<int>({prefill_num_tokens, num_heads, head_dim}).data(),
|
||||
std::vector<int>({num_heads * head_dim, head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t k_cache_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cuinferTensorDescriptor_t k_cache_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
k_cache_desc,
|
||||
cuinfer_data_type,
|
||||
4,
|
||||
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
|
||||
std::vector<int>({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data()));
|
||||
std::vector<int>({num_kv_heads * block_size * head_dim,
|
||||
block_size * head_dim,
|
||||
head_dim,
|
||||
1})
|
||||
.data()));
|
||||
|
||||
cuinferTensorDescriptor_t v_cache_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cuinferTensorDescriptor_t v_cache_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
v_cache_desc,
|
||||
cuinfer_data_type,
|
||||
4,
|
||||
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
|
||||
std::vector<int>({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data()));
|
||||
std::vector<int>({num_kv_heads * block_size * head_dim,
|
||||
block_size * head_dim,
|
||||
head_dim,
|
||||
1})
|
||||
.data()));
|
||||
|
||||
cuinferTensorDescriptor_t cos_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cuinferTensorDescriptor_t cos_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cos_desc,
|
||||
CUINFER_DATA_FLOAT,
|
||||
2,
|
||||
std::vector<int>({max_seq_len, head_dim}).data(),
|
||||
std::vector<int>({head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t sin_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cuinferTensorDescriptor_t sin_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
sin_desc,
|
||||
CUINFER_DATA_FLOAT,
|
||||
2,
|
||||
std::vector<int>({max_seq_len, head_dim}).data(),
|
||||
std::vector<int>({head_dim, 1}).data()));
|
||||
|
||||
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
|
||||
cuinferHandle_t cuinfer_handle =
|
||||
iluvatar::getContextInstance()->getIxInferHandle();
|
||||
|
||||
size_t prefill_workspace_size = 0;
|
||||
CUINFER_CHECK(cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(prefill_num_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
cuinfer_data_type,
|
||||
cuinfer_data_type,
|
||||
cuinfer_data_type,
|
||||
&prefill_workspace_size));
|
||||
|
||||
auto* allocator = paddle::GetAllocator(qkv.place());
|
||||
|
||||
phi::Allocator::AllocationPtr prefill_tmp_workspace = allocator->Allocate(prefill_workspace_size);
|
||||
void* prefill_workspace_ptr = prefill_tmp_workspace->ptr();
|
||||
|
||||
CUINFER_CHECK(cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle,
|
||||
qkv_desc,
|
||||
qkv.data(),
|
||||
qkv_seqlens_desc,
|
||||
cu_seqlens_qkv.data<int32_t>(),
|
||||
block_table_desc,
|
||||
prefill_block_table.data<int32_t>(),
|
||||
o_desc,
|
||||
out.data(),
|
||||
k_cache_desc,
|
||||
k_cache.data(),
|
||||
v_cache_desc,
|
||||
v_cache.data(),
|
||||
prefill_workspace_ptr,
|
||||
prefill_workspace_size,
|
||||
cos_desc,
|
||||
rope_cos_ptr,
|
||||
sin_desc,
|
||||
rope_sin_ptr,
|
||||
prefill_batch_size,
|
||||
size_t prefill_workspace_size = 0;
|
||||
CUINFER_CHECK(
|
||||
cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(prefill_num_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
causal,
|
||||
scale,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope));
|
||||
v_rope,
|
||||
cuinfer_data_type,
|
||||
cuinfer_data_type,
|
||||
cuinfer_data_type,
|
||||
&prefill_workspace_size));
|
||||
|
||||
size_t decode_workspace_size = 0;
|
||||
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(decode_num_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
&decode_workspace_size));
|
||||
auto* allocator = paddle::GetAllocator(qkv.place());
|
||||
|
||||
phi::Allocator::AllocationPtr decode_tmp_workspace = allocator->Allocate(decode_workspace_size);
|
||||
void* decode_workspace_ptr = decode_tmp_workspace->ptr();
|
||||
phi::Allocator::AllocationPtr prefill_tmp_workspace =
|
||||
allocator->Allocate(prefill_workspace_size);
|
||||
void* prefill_workspace_ptr = prefill_tmp_workspace->ptr();
|
||||
|
||||
void* decode_qkv_ptr = (void*)(qkv.data<data_t>() + prefill_num_tokens * qkv_stride);
|
||||
void* decode_out_ptr = (void*)(out.data<data_t>() + prefill_num_tokens * out.strides()[0]);
|
||||
CUINFER_CHECK(
|
||||
cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle,
|
||||
qkv_desc,
|
||||
qkv.data(),
|
||||
qkv_seqlens_desc,
|
||||
cu_seqlens_qkv.data<int32_t>(),
|
||||
block_table_desc,
|
||||
prefill_block_table.data<int32_t>(),
|
||||
o_desc,
|
||||
out.data(),
|
||||
k_cache_desc,
|
||||
k_cache.data(),
|
||||
v_cache_desc,
|
||||
v_cache.data(),
|
||||
prefill_workspace_ptr,
|
||||
prefill_workspace_size,
|
||||
cos_desc,
|
||||
rope_cos_ptr,
|
||||
sin_desc,
|
||||
rope_sin_ptr,
|
||||
prefill_batch_size,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
causal,
|
||||
scale,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope));
|
||||
|
||||
PageAttentionWithKVCacheArguments args{
|
||||
static_cast<float>(scale), 1.0, 1.0, static_cast<float>(softcap), window_left, window_right,
|
||||
causal, use_sqrt_alibi, enable_cuda_graph, false, nullptr, decode_qkv_ptr, decode_qkv_ptr,
|
||||
decode_workspace_ptr, true, rope_sin_ptr, rope_cos_ptr};
|
||||
size_t decode_workspace_size = 0;
|
||||
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(decode_num_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
&decode_workspace_size));
|
||||
|
||||
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
|
||||
decode_out_ptr,
|
||||
cu_data_type,
|
||||
phi::Allocator::AllocationPtr decode_tmp_workspace =
|
||||
allocator->Allocate(decode_workspace_size);
|
||||
void* decode_workspace_ptr = decode_tmp_workspace->ptr();
|
||||
|
||||
void* decode_qkv_ptr =
|
||||
(void*)(qkv.data<data_t>() + prefill_num_tokens * qkv_stride);
|
||||
void* decode_out_ptr =
|
||||
(void*)(out.data<data_t>() + prefill_num_tokens * out.strides()[0]);
|
||||
|
||||
PageAttentionWithKVCacheArguments args{static_cast<float>(scale),
|
||||
1.0,
|
||||
1.0,
|
||||
static_cast<float>(softcap),
|
||||
window_left,
|
||||
window_right,
|
||||
causal,
|
||||
use_sqrt_alibi,
|
||||
enable_cuda_graph,
|
||||
false,
|
||||
nullptr,
|
||||
decode_qkv_ptr,
|
||||
cu_data_type,
|
||||
decode_num_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
qkv_stride,
|
||||
kv_block_stride,
|
||||
kv_head_stride,
|
||||
k_cache.data(),
|
||||
cu_data_type,
|
||||
v_cache.data(),
|
||||
cu_data_type,
|
||||
block_size,
|
||||
max_num_blocks_per_seq,
|
||||
max_seq_len,
|
||||
decode_block_table.data<int32_t>(),
|
||||
seq_lens.data<int32_t>(),
|
||||
args));
|
||||
decode_qkv_ptr,
|
||||
decode_workspace_ptr,
|
||||
true,
|
||||
rope_sin_ptr,
|
||||
rope_cos_ptr};
|
||||
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc));
|
||||
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
|
||||
decode_out_ptr,
|
||||
cu_data_type,
|
||||
decode_qkv_ptr,
|
||||
cu_data_type,
|
||||
decode_num_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
qkv_stride,
|
||||
kv_block_stride,
|
||||
kv_head_stride,
|
||||
k_cache.data(),
|
||||
cu_data_type,
|
||||
v_cache.data(),
|
||||
cu_data_type,
|
||||
block_size,
|
||||
max_num_blocks_per_seq,
|
||||
max_seq_len,
|
||||
decode_block_table.data<int32_t>(),
|
||||
seq_lens.data<int32_t>(),
|
||||
args));
|
||||
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc));
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MixedFusedPagedAttn(const paddle::Tensor& qkv,
|
||||
paddle::Tensor& k_cache,
|
||||
paddle::Tensor& v_cache,
|
||||
const paddle::Tensor& prefill_block_table,
|
||||
const paddle::Tensor& decode_block_table,
|
||||
const paddle::Tensor& cu_seqlens_qkv,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::optional<paddle::Tensor> &rope_sin,
|
||||
const paddle::optional<paddle::Tensor> &rope_cos,
|
||||
int prefill_num_tokens,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
float scale,
|
||||
bool causal,
|
||||
bool q_rope,
|
||||
bool k_rope,
|
||||
bool v_rope,
|
||||
int window_left,
|
||||
int window_right,
|
||||
float softcap,
|
||||
bool enable_cuda_graph,
|
||||
bool use_sqrt_alibi) {
|
||||
const auto dtype = qkv.dtype();
|
||||
auto out = paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place());
|
||||
std::vector<paddle::Tensor> MixedFusedPagedAttn(
|
||||
const paddle::Tensor& qkv,
|
||||
paddle::Tensor& k_cache,
|
||||
paddle::Tensor& v_cache,
|
||||
const paddle::Tensor& prefill_block_table,
|
||||
const paddle::Tensor& decode_block_table,
|
||||
const paddle::Tensor& cu_seqlens_qkv,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::optional<paddle::Tensor>& rope_sin,
|
||||
const paddle::optional<paddle::Tensor>& rope_cos,
|
||||
int prefill_num_tokens,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
float scale,
|
||||
bool causal,
|
||||
bool q_rope,
|
||||
bool k_rope,
|
||||
bool v_rope,
|
||||
int window_left,
|
||||
int window_right,
|
||||
float softcap,
|
||||
bool enable_cuda_graph,
|
||||
bool use_sqrt_alibi) {
|
||||
const auto dtype = qkv.dtype();
|
||||
auto out =
|
||||
paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place());
|
||||
|
||||
switch (dtype) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MixedFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(qkv,
|
||||
k_cache,
|
||||
v_cache,
|
||||
prefill_block_table,
|
||||
decode_block_table,
|
||||
cu_seqlens_qkv,
|
||||
seq_lens,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
prefill_num_tokens,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
scale,
|
||||
causal,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
window_left,
|
||||
window_right,
|
||||
softcap,
|
||||
enable_cuda_graph,
|
||||
use_sqrt_alibi,
|
||||
out);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
MixedFusedPagedAttnKernel<paddle::DataType::FLOAT16>(qkv,
|
||||
k_cache,
|
||||
v_cache,
|
||||
prefill_block_table,
|
||||
decode_block_table,
|
||||
cu_seqlens_qkv,
|
||||
seq_lens,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
prefill_num_tokens,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
scale,
|
||||
causal,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
window_left,
|
||||
window_right,
|
||||
softcap,
|
||||
enable_cuda_graph,
|
||||
use_sqrt_alibi,
|
||||
out);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for mixed paged attn");
|
||||
}
|
||||
return {out};
|
||||
switch (dtype) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MixedFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(qkv,
|
||||
k_cache,
|
||||
v_cache,
|
||||
prefill_block_table,
|
||||
decode_block_table,
|
||||
cu_seqlens_qkv,
|
||||
seq_lens,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
prefill_num_tokens,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
scale,
|
||||
causal,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
window_left,
|
||||
window_right,
|
||||
softcap,
|
||||
enable_cuda_graph,
|
||||
use_sqrt_alibi,
|
||||
out);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
MixedFusedPagedAttnKernel<paddle::DataType::FLOAT16>(qkv,
|
||||
k_cache,
|
||||
v_cache,
|
||||
prefill_block_table,
|
||||
decode_block_table,
|
||||
cu_seqlens_qkv,
|
||||
seq_lens,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
prefill_num_tokens,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
scale,
|
||||
causal,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
window_left,
|
||||
window_right,
|
||||
softcap,
|
||||
enable_cuda_graph,
|
||||
use_sqrt_alibi,
|
||||
out);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for mixed paged attn");
|
||||
}
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MixedFusedPagedAttnInferShape(const std::vector<int64_t>& qkv_shape,
|
||||
int num_heads,
|
||||
int head_dim) {
|
||||
return {{qkv_shape[0], num_heads * head_dim}};
|
||||
std::vector<std::vector<int64_t>> MixedFusedPagedAttnInferShape(
|
||||
const std::vector<int64_t>& qkv_shape, int num_heads, int head_dim) {
|
||||
return {{qkv_shape[0], num_heads * head_dim}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MixedFusedPagedAttnInferDtype(const paddle::DataType& qkv_dtype) {
|
||||
return {qkv_dtype};
|
||||
std::vector<paddle::DataType> MixedFusedPagedAttnInferDtype(
|
||||
const paddle::DataType& qkv_dtype) {
|
||||
return {qkv_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(mixed_fused_paged_attn)
|
||||
.Inputs({"qkv", "k_cache", "v_cache", "prefill_block_table", "decode_block_table",
|
||||
"cu_seqlens_qkv", "seq_lens", paddle::Optional("rope_sin"), paddle::Optional("rope_cos")})
|
||||
.Inputs({"qkv",
|
||||
"k_cache",
|
||||
"v_cache",
|
||||
"prefill_block_table",
|
||||
"decode_block_table",
|
||||
"cu_seqlens_qkv",
|
||||
"seq_lens",
|
||||
paddle::Optional("rope_sin"),
|
||||
paddle::Optional("rope_cos")})
|
||||
.Outputs({"out"})
|
||||
.Attrs({"prefill_num_tokens:int",
|
||||
"num_heads: int",
|
||||
@@ -362,14 +398,14 @@ PD_BUILD_STATIC_OP(mixed_fused_paged_attn)
|
||||
"block_size:int",
|
||||
"max_seq_len:int",
|
||||
"scale:float",
|
||||
"causal:bool",
|
||||
"q_rope:bool",
|
||||
"causal:bool",
|
||||
"q_rope:bool",
|
||||
"k_rope:bool",
|
||||
"v_rope:bool",
|
||||
"window_left:int",
|
||||
"window_right:int",
|
||||
"softcap:float",
|
||||
"enable_cuda_graph:bool",
|
||||
"enable_cuda_graph:bool",
|
||||
"use_sqrt_alibi:bool"})
|
||||
.SetKernelFn(PD_KERNEL(MixedFusedPagedAttn))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MixedFusedPagedAttnInferShape))
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
// Ignore CUTLASS warnings about type punning
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
@@ -29,10 +28,10 @@ __global__ void compute_total_rows_before_expert_kernel(
|
||||
const int64_t sorted_experts_len,
|
||||
const int64_t num_experts,
|
||||
int64_t* total_rows_before_expert) {
|
||||
const int expert = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (expert >= num_experts) return;
|
||||
total_rows_before_expert[expert] =
|
||||
phi::find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
|
||||
const int expert = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (expert >= num_experts) return;
|
||||
total_rows_before_expert[expert] = phi::find_total_elts_leq_target(
|
||||
sorted_experts, sorted_experts_len, expert);
|
||||
}
|
||||
|
||||
void compute_total_rows_before_expert(int* sorted_indices,
|
||||
@@ -40,36 +39,38 @@ void compute_total_rows_before_expert(int* sorted_indices,
|
||||
const int64_t num_experts,
|
||||
int64_t* total_rows_before_expert,
|
||||
cudaStream_t stream) {
|
||||
const int threads = std::min(int64_t(1024), num_experts);
|
||||
const int blocks = (num_experts + threads - 1) / threads;
|
||||
const int threads = std::min(int64_t(1024), num_experts);
|
||||
const int blocks = (num_experts + threads - 1) / threads;
|
||||
|
||||
compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
|
||||
sorted_indices, total_indices, num_experts, total_rows_before_expert);
|
||||
compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
|
||||
sorted_indices, total_indices, num_experts, total_rows_before_expert);
|
||||
}
|
||||
|
||||
template <paddle::DataType T>
|
||||
void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
const paddle::Tensor& gating_output,
|
||||
const paddle::optional<paddle::Tensor>& gating_correction_bias,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const std::string &moe_quant_type,
|
||||
const bool topk_only_mode,
|
||||
const int num_rows,
|
||||
const int hidden_size,
|
||||
const int expert_num,
|
||||
paddle::Tensor* permute_input,
|
||||
paddle::Tensor* tokens_expert_prefix_sum,
|
||||
paddle::Tensor* permute_indices_per_token,
|
||||
paddle::Tensor* top_k_weight,
|
||||
paddle::Tensor* top_k_indices) {
|
||||
void MoeDispatchKernel(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& gating_output,
|
||||
const paddle::optional<paddle::Tensor>& gating_correction_bias,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const std::string& moe_quant_type,
|
||||
const bool topk_only_mode,
|
||||
const int num_rows,
|
||||
const int hidden_size,
|
||||
const int expert_num,
|
||||
paddle::Tensor* permute_input,
|
||||
paddle::Tensor* tokens_expert_prefix_sum,
|
||||
paddle::Tensor* permute_indices_per_token,
|
||||
paddle::Tensor* top_k_weight,
|
||||
paddle::Tensor* top_k_indices) {
|
||||
using namespace phi;
|
||||
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
auto place = input.place();
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input.place()));
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(input.place()));
|
||||
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
||||
if (group_moe) {
|
||||
// Check if expert_num is divisible by moe_topk, else throw an error
|
||||
@@ -131,19 +132,21 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
softmax_out_ = nullptr;
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher<float>(gating_output.data<float>(),
|
||||
gating_correction_bias ? gating_correction_bias.get().data<float>() : nullptr,
|
||||
top_k_weight->data<float>(),
|
||||
softmax_out_,
|
||||
expert_for_source_row,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
expert_num,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
stream,
|
||||
topk_only_mode);
|
||||
topk_gating_softmax_kernelLauncher<float>(
|
||||
gating_output.data<float>(),
|
||||
gating_correction_bias ? gating_correction_bias.get().data<float>()
|
||||
: nullptr,
|
||||
top_k_weight->data<float>(),
|
||||
softmax_out_,
|
||||
expert_for_source_row,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
expert_num,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
stream,
|
||||
topk_only_mode);
|
||||
|
||||
sorter_.run(reinterpret_cast<void*>(sorter_ws_ptr),
|
||||
sorter_ws_size_bytes,
|
||||
@@ -155,7 +158,6 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
false,
|
||||
stream);
|
||||
|
||||
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input.data<data_t>(),
|
||||
permute_input->data<data_t>(),
|
||||
@@ -167,16 +169,13 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
moe_topk,
|
||||
stream);
|
||||
|
||||
|
||||
compute_total_rows_before_expert(
|
||||
permuted_experts_,
|
||||
moe_topk * num_rows,
|
||||
expert_num,
|
||||
tokens_expert_prefix_sum->data<int64_t>(),
|
||||
stream);
|
||||
compute_total_rows_before_expert(permuted_experts_,
|
||||
moe_topk * num_rows,
|
||||
expert_num,
|
||||
tokens_expert_prefix_sum->data<int64_t>(),
|
||||
stream);
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& gating_output,
|
||||
@@ -184,7 +183,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
const paddle::optional<paddle::Tensor>& w4a8_in_scale,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const std::string &moe_quant_type,
|
||||
const std::string& moe_quant_type,
|
||||
const bool topk_only_mode) {
|
||||
const auto input_type = input.dtype();
|
||||
auto place = input.place();
|
||||
@@ -214,7 +213,6 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
auto permute_indices_per_token =
|
||||
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
|
||||
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
|
||||
@@ -261,7 +259,6 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
top_k_indices};
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
||||
const std::vector<int64_t>& input_shape,
|
||||
const std::vector<int64_t>& gating_output_shape,
|
||||
@@ -299,17 +296,21 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
||||
paddle::DataType::INT32};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_STATIC_OP(moe_expert_dispatch)
|
||||
.Inputs({"input", "gating_output", paddle::Optional("gating_correction_bias"),
|
||||
paddle::Optional("w4a8_in_scale")})
|
||||
.Inputs({"input",
|
||||
"gating_output",
|
||||
paddle::Optional("gating_correction_bias"),
|
||||
paddle::Optional("w4a8_in_scale")})
|
||||
.Outputs({"permute_input",
|
||||
"tokens_expert_prefix_sum",
|
||||
"permute_indices_per_token",
|
||||
"top_k_weight",
|
||||
"top_k_indices",
|
||||
"expert_idx_per_token"})
|
||||
.Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"})
|
||||
.Attrs({"moe_topk:int",
|
||||
"group_moe:bool",
|
||||
"moe_quant_type:std::string",
|
||||
"topk_only_mode:bool"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
|
||||
|
||||
@@ -16,9 +16,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "helper.h"
|
||||
#include "fused_moe_helper.h"
|
||||
#include "fused_moe_op.h"
|
||||
#include "helper.h"
|
||||
|
||||
template <paddle::DataType T>
|
||||
void MoeReduceKernel(const paddle::Tensor& ffn_out,
|
||||
@@ -32,27 +32,28 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out,
|
||||
const int hidden_size,
|
||||
const int topk,
|
||||
paddle::Tensor* output) {
|
||||
using namespace phi;
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(ffn_out.place()));
|
||||
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
||||
using namespace phi;
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(ffn_out.place()));
|
||||
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
||||
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
ffn_out.data<data_t>(),
|
||||
output->data<data_t>(),
|
||||
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
|
||||
top_k_weight.data<float>(),
|
||||
permute_indices_per_token.data<int32_t>(),
|
||||
top_k_indices.data<int>(),
|
||||
num_rows,
|
||||
hidden_size,
|
||||
topk,
|
||||
static_cast<int>(1),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
ffn_out.data<data_t>(),
|
||||
output->data<data_t>(),
|
||||
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
|
||||
top_k_weight.data<float>(),
|
||||
permute_indices_per_token.data<int32_t>(),
|
||||
top_k_indices.data<int>(),
|
||||
num_rows,
|
||||
hidden_size,
|
||||
topk,
|
||||
static_cast<int>(1),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
}
|
||||
|
||||
paddle::Tensor MoeExpertReduceFunc(
|
||||
@@ -63,48 +64,46 @@ paddle::Tensor MoeExpertReduceFunc(
|
||||
const paddle::optional<paddle::Tensor>& down_proj_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor) {
|
||||
const auto input_type = ffn_out.dtype();
|
||||
auto place = ffn_out.place();
|
||||
const auto input_type = ffn_out.dtype();
|
||||
auto place = ffn_out.place();
|
||||
|
||||
const int topk = top_k_indices.dims()[1];
|
||||
const int num_rows = ffn_out.dims()[0] / topk;
|
||||
const int hidden_size = ffn_out.dims()[1];
|
||||
const int topk = top_k_indices.dims()[1];
|
||||
const int num_rows = ffn_out.dims()[0] / topk;
|
||||
const int hidden_size = ffn_out.dims()[1];
|
||||
|
||||
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
|
||||
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeReduceKernel<paddle::DataType::BFLOAT16>(
|
||||
ffn_out,
|
||||
top_k_weight,
|
||||
permute_indices_per_token,
|
||||
top_k_indices,
|
||||
down_proj_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
topk,
|
||||
&output);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
MoeReduceKernel<paddle::DataType::BFLOAT16>(
|
||||
ffn_out,
|
||||
top_k_weight,
|
||||
permute_indices_per_token,
|
||||
top_k_indices,
|
||||
down_proj_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
topk,
|
||||
&output);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for MoeDispatchKernel");
|
||||
}
|
||||
return output;
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeReduceKernel<paddle::DataType::BFLOAT16>(ffn_out,
|
||||
top_k_weight,
|
||||
permute_indices_per_token,
|
||||
top_k_indices,
|
||||
down_proj_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
topk,
|
||||
&output);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
MoeReduceKernel<paddle::DataType::BFLOAT16>(ffn_out,
|
||||
top_k_weight,
|
||||
permute_indices_per_token,
|
||||
top_k_indices,
|
||||
down_proj_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
topk,
|
||||
&output);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for MoeDispatchKernel");
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertReduce(
|
||||
@@ -115,13 +114,13 @@ std::vector<paddle::Tensor> MoeExpertReduce(
|
||||
const paddle::optional<paddle::Tensor>& down_proj_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor) {
|
||||
return {MoeExpertReduceFunc(ffn_out,
|
||||
top_k_weight,
|
||||
permute_indices_per_token,
|
||||
top_k_indices,
|
||||
down_proj_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor)};
|
||||
return {MoeExpertReduceFunc(ffn_out,
|
||||
top_k_weight,
|
||||
permute_indices_per_token,
|
||||
top_k_indices,
|
||||
down_proj_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor)};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
|
||||
@@ -130,7 +129,7 @@ std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
|
||||
const std::vector<int64_t>& permute_indices_per_token_shape,
|
||||
const std::vector<int64_t>& top_k_indices_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape) {
|
||||
return {ffn_out_shape};
|
||||
return {ffn_out_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MoeExpertReduceInferDtype(
|
||||
@@ -139,7 +138,7 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
|
||||
const paddle::DataType& permute_indices_per_token_dtype,
|
||||
const paddle::DataType& top_k_indices_dtype,
|
||||
const paddle::optional<paddle::DataType>& down_proj_bias_dtype) {
|
||||
return {ffn_out_dtype};
|
||||
return {ffn_out_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(moe_expert_reduce)
|
||||
|
||||
@@ -15,18 +15,17 @@
|
||||
#include "helper.h"
|
||||
#include "iluvatar_context.h"
|
||||
|
||||
|
||||
template <paddle::DataType T>
|
||||
void PagedAttnKernel(const paddle::Tensor& q,
|
||||
const paddle::Tensor& k_cache,
|
||||
const paddle::Tensor& v_cache,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::optional<paddle::Tensor> &alibi_slopes,
|
||||
const paddle::optional<paddle::Tensor> &k,
|
||||
const paddle::optional<paddle::Tensor> &v,
|
||||
const paddle::optional<paddle::Tensor> &rope_sin,
|
||||
const paddle::optional<paddle::Tensor> &rope_cos,
|
||||
const paddle::optional<paddle::Tensor>& alibi_slopes,
|
||||
const paddle::optional<paddle::Tensor>& k,
|
||||
const paddle::optional<paddle::Tensor>& v,
|
||||
const paddle::optional<paddle::Tensor>& rope_sin,
|
||||
const paddle::optional<paddle::Tensor>& rope_cos,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
@@ -41,298 +40,326 @@ void PagedAttnKernel(const paddle::Tensor& q,
|
||||
bool use_sqrt_alibi,
|
||||
bool merged_qkv,
|
||||
paddle::Tensor& out) {
|
||||
if (alibi_slopes) {
|
||||
PADDLE_ENFORCE_EQ(alibi_slopes.get().dtype(),
|
||||
paddle::DataType::FLOAT32,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects alibi_slopes float tensor"));
|
||||
PADDLE_ENFORCE_EQ(alibi_slopes.get().is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects alibi_slopes is contiguous"));
|
||||
}
|
||||
if (alibi_slopes) {
|
||||
PADDLE_ENFORCE_EQ(alibi_slopes.get().dtype(),
|
||||
paddle::DataType::FLOAT32,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects alibi_slopes float tensor"));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
alibi_slopes.get().is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects alibi_slopes is contiguous"));
|
||||
}
|
||||
|
||||
// check dtype and contiguous
|
||||
const auto& dtype = q.dtype();
|
||||
cudaDataType_t data_type;
|
||||
if (dtype == paddle::DataType::FLOAT16) {
|
||||
data_type = CUDA_R_16F;
|
||||
} else if (dtype == paddle::DataType::BFLOAT16) {
|
||||
data_type = CUDA_R_16BF;
|
||||
} else {
|
||||
common::errors::InvalidArgument("paged_attention support half and bfloat16 now");
|
||||
}
|
||||
// check dtype and contiguous
|
||||
const auto& dtype = q.dtype();
|
||||
cudaDataType_t data_type;
|
||||
if (dtype == paddle::DataType::FLOAT16) {
|
||||
data_type = CUDA_R_16F;
|
||||
} else if (dtype == paddle::DataType::BFLOAT16) {
|
||||
data_type = CUDA_R_16BF;
|
||||
} else {
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention support half and bfloat16 now");
|
||||
}
|
||||
|
||||
PADDLE_ENFORCE_EQ(k_cache.dtype(),
|
||||
dtype,
|
||||
common::errors::InvalidArgument(
|
||||
"k_cache dtype must be the same as query dtype"));
|
||||
PADDLE_ENFORCE_EQ(k_cache.is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects k_cache is contiguous"));
|
||||
PADDLE_ENFORCE_EQ(block_table.dtype(),
|
||||
paddle::DataType::INT32,
|
||||
common::errors::InvalidArgument(
|
||||
"block_table dtype must be int32"));
|
||||
PADDLE_ENFORCE_EQ(block_table.is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects block_table is contiguous"));
|
||||
PADDLE_ENFORCE_EQ(seq_lens.dtype(),
|
||||
paddle::DataType::INT32,
|
||||
common::errors::InvalidArgument(
|
||||
"seq_lens dtype must be int32"));
|
||||
PADDLE_ENFORCE_EQ(seq_lens.is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects seq_lens is contiguous"));
|
||||
// check dim and shape
|
||||
// k_cache: [num_blocks, kv_num_heads, block_size, head_dim]
|
||||
// v_cache: [num_blocks, kv_num_heads, block_size, head_dim]
|
||||
// block_table: [num_seqs, max_num_blocks_per_seq]
|
||||
// seq_lens: [num_seqs]
|
||||
// q and out:
|
||||
// if merged_qkv = false:
|
||||
// q:[num_seqs, hidden_size]
|
||||
// out:[num_seqs, hidden_size]
|
||||
// if merged_qkv = true:
|
||||
// q: [num_seqs, (num_heads+2*num_kv_heads)*head_dim]
|
||||
// out: [num_seqs, hidden_size]
|
||||
PADDLE_ENFORCE_EQ(k_cache.dtype(),
|
||||
dtype,
|
||||
common::errors::InvalidArgument(
|
||||
"k_cache dtype must be the same as query dtype"));
|
||||
PADDLE_ENFORCE_EQ(k_cache.is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects k_cache is contiguous"));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
block_table.dtype(),
|
||||
paddle::DataType::INT32,
|
||||
common::errors::InvalidArgument("block_table dtype must be int32"));
|
||||
PADDLE_ENFORCE_EQ(block_table.is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects block_table is contiguous"));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
seq_lens.dtype(),
|
||||
paddle::DataType::INT32,
|
||||
common::errors::InvalidArgument("seq_lens dtype must be int32"));
|
||||
PADDLE_ENFORCE_EQ(seq_lens.is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects seq_lens is contiguous"));
|
||||
// check dim and shape
|
||||
// k_cache: [num_blocks, kv_num_heads, block_size, head_dim]
|
||||
// v_cache: [num_blocks, kv_num_heads, block_size, head_dim]
|
||||
// block_table: [num_seqs, max_num_blocks_per_seq]
|
||||
// seq_lens: [num_seqs]
|
||||
// q and out:
|
||||
// if merged_qkv = false:
|
||||
// q:[num_seqs, hidden_size]
|
||||
// out:[num_seqs, hidden_size]
|
||||
// if merged_qkv = true:
|
||||
// q: [num_seqs, (num_heads+2*num_kv_heads)*head_dim]
|
||||
// out: [num_seqs, hidden_size]
|
||||
|
||||
const auto& q_dims = q.dims();
|
||||
PADDLE_ENFORCE_EQ(q_dims.size(),
|
||||
2,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive query dims is "
|
||||
"[num_seqs, (num_heads+2*num_kv_heads)*head_dim]"));
|
||||
PADDLE_ENFORCE_EQ(out.dims().size(),
|
||||
2,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive out dims is "
|
||||
"[num_seqs, hidden_size]"));
|
||||
const auto& q_dims = q.dims();
|
||||
PADDLE_ENFORCE_EQ(q_dims.size(),
|
||||
2,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive query dims is "
|
||||
"[num_seqs, (num_heads+2*num_kv_heads)*head_dim]"));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
out.dims().size(),
|
||||
2,
|
||||
common::errors::InvalidArgument("paged_attn receive out dims is "
|
||||
"[num_seqs, hidden_size]"));
|
||||
|
||||
const auto& kv_cache_dims = k_cache.dims();
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
|
||||
4,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive kv cache dims is "
|
||||
"[num_blocks, kv_num_heads, block_size, head_dim]"));
|
||||
const auto& kv_cache_dims = k_cache.dims();
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
|
||||
4,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive kv cache dims is "
|
||||
"[num_blocks, kv_num_heads, block_size, head_dim]"));
|
||||
|
||||
const auto& block_table_dims = block_table.dims();
|
||||
PADDLE_ENFORCE_EQ(block_table_dims.size(),
|
||||
2,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive block_table dims is "
|
||||
"[num_seqs, max_num_blocks_per_seq]"));
|
||||
const auto& block_table_dims = block_table.dims();
|
||||
PADDLE_ENFORCE_EQ(
|
||||
block_table_dims.size(),
|
||||
2,
|
||||
common::errors::InvalidArgument("paged_attn receive block_table dims is "
|
||||
"[num_seqs, max_num_blocks_per_seq]"));
|
||||
|
||||
const auto& seq_lens_dims = seq_lens.dims();
|
||||
PADDLE_ENFORCE_EQ(seq_lens_dims.size(),
|
||||
1,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive seq_lens dims is [num_seqs]"));
|
||||
const auto& seq_lens_dims = seq_lens.dims();
|
||||
PADDLE_ENFORCE_EQ(seq_lens_dims.size(),
|
||||
1,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive seq_lens dims is [num_seqs]"));
|
||||
|
||||
int num_seqs = q_dims[0];
|
||||
int max_num_blocks_per_seq = block_table_dims[1];
|
||||
int q_stride = q.strides()[0];
|
||||
int num_blocks = kv_cache_dims[0];
|
||||
int num_seqs = q_dims[0];
|
||||
int max_num_blocks_per_seq = block_table_dims[1];
|
||||
int q_stride = q.strides()[0];
|
||||
int num_blocks = kv_cache_dims[0];
|
||||
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[1],
|
||||
num_kv_heads,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[1] must be equal to num_kv_head"));
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[2],
|
||||
block_size,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[2] must be equal to block_size"));
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[3],
|
||||
head_dim,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[3] must be equal to head_dim"));
|
||||
PADDLE_ENFORCE_EQ(block_table_dims[0],
|
||||
num_seqs,
|
||||
common::errors::InvalidArgument(
|
||||
"block_table_dims[0] must be equal to num_seqs"));
|
||||
PADDLE_ENFORCE_EQ(seq_lens_dims[0],
|
||||
num_seqs,
|
||||
common::errors::InvalidArgument(
|
||||
"seq_lens_dims[0] must be equal to num_seqs"));
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[1],
|
||||
num_kv_heads,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[1] must be equal to num_kv_head"));
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[2],
|
||||
block_size,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[2] must be equal to block_size"));
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[3],
|
||||
head_dim,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[3] must be equal to head_dim"));
|
||||
PADDLE_ENFORCE_EQ(block_table_dims[0],
|
||||
num_seqs,
|
||||
common::errors::InvalidArgument(
|
||||
"block_table_dims[0] must be equal to num_seqs"));
|
||||
PADDLE_ENFORCE_EQ(seq_lens_dims[0],
|
||||
num_seqs,
|
||||
common::errors::InvalidArgument(
|
||||
"seq_lens_dims[0] must be equal to num_seqs"));
|
||||
|
||||
int kv_block_stride = k_cache.strides()[0];
|
||||
int kv_head_stride = k_cache.strides()[1];
|
||||
const float *alibi_slopes_ptr = alibi_slopes ? alibi_slopes.get().data<float>() : nullptr;
|
||||
const void *key_ptr = k ? k.get().data() : nullptr;
|
||||
const void *value_ptr = v ? v.get().data() : nullptr;
|
||||
const float *rope_sin_ptr = merged_qkv ? rope_sin.get().data<float>() : nullptr;
|
||||
const float *rope_cos_ptr = merged_qkv ? rope_cos.get().data<float>() : nullptr;
|
||||
int kv_block_stride = k_cache.strides()[0];
|
||||
int kv_head_stride = k_cache.strides()[1];
|
||||
const float* alibi_slopes_ptr =
|
||||
alibi_slopes ? alibi_slopes.get().data<float>() : nullptr;
|
||||
const void* key_ptr = k ? k.get().data() : nullptr;
|
||||
const void* value_ptr = v ? v.get().data() : nullptr;
|
||||
const float* rope_sin_ptr =
|
||||
merged_qkv ? rope_sin.get().data<float>() : nullptr;
|
||||
const float* rope_cos_ptr =
|
||||
merged_qkv ? rope_cos.get().data<float>() : nullptr;
|
||||
|
||||
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
|
||||
cuinferHandle_t cuinfer_handle =
|
||||
iluvatar::getContextInstance()->getIxInferHandle();
|
||||
|
||||
size_t workspace_size = 0;
|
||||
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(num_seqs,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
block_size,
|
||||
max_context_len,
|
||||
&workspace_size));
|
||||
auto* allocator = paddle::GetAllocator(q.place());
|
||||
phi::Allocator::AllocationPtr tmp_workspace = allocator->Allocate(workspace_size);
|
||||
void* workspace_ptr = tmp_workspace->ptr();
|
||||
size_t workspace_size = 0;
|
||||
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(num_seqs,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
block_size,
|
||||
max_context_len,
|
||||
&workspace_size));
|
||||
auto* allocator = paddle::GetAllocator(q.place());
|
||||
phi::Allocator::AllocationPtr tmp_workspace =
|
||||
allocator->Allocate(workspace_size);
|
||||
void* workspace_ptr = tmp_workspace->ptr();
|
||||
|
||||
PageAttentionWithKVCacheArguments args{
|
||||
static_cast<float>(scale), 1.0, 1.0, static_cast<float>(softcap), window_left, window_right,
|
||||
causal, use_sqrt_alibi, enable_cuda_graph, false, alibi_slopes_ptr, key_ptr, value_ptr,
|
||||
workspace_ptr, merged_qkv, rope_sin_ptr, rope_cos_ptr};
|
||||
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
|
||||
out.data(),
|
||||
data_type,
|
||||
q.data(),
|
||||
data_type,
|
||||
num_seqs,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
q_stride,
|
||||
kv_block_stride,
|
||||
kv_head_stride,
|
||||
k_cache.data(),
|
||||
data_type,
|
||||
v_cache.data(),
|
||||
data_type,
|
||||
block_size,
|
||||
max_num_blocks_per_seq,
|
||||
max_context_len,
|
||||
block_table.data<int32_t>(),
|
||||
seq_lens.data<int32_t>(),
|
||||
args));
|
||||
PageAttentionWithKVCacheArguments args{static_cast<float>(scale),
|
||||
1.0,
|
||||
1.0,
|
||||
static_cast<float>(softcap),
|
||||
window_left,
|
||||
window_right,
|
||||
causal,
|
||||
use_sqrt_alibi,
|
||||
enable_cuda_graph,
|
||||
false,
|
||||
alibi_slopes_ptr,
|
||||
key_ptr,
|
||||
value_ptr,
|
||||
workspace_ptr,
|
||||
merged_qkv,
|
||||
rope_sin_ptr,
|
||||
rope_cos_ptr};
|
||||
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
|
||||
out.data(),
|
||||
data_type,
|
||||
q.data(),
|
||||
data_type,
|
||||
num_seqs,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
q_stride,
|
||||
kv_block_stride,
|
||||
kv_head_stride,
|
||||
k_cache.data(),
|
||||
data_type,
|
||||
v_cache.data(),
|
||||
data_type,
|
||||
block_size,
|
||||
max_num_blocks_per_seq,
|
||||
max_context_len,
|
||||
block_table.data<int32_t>(),
|
||||
seq_lens.data<int32_t>(),
|
||||
args));
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
|
||||
const paddle::Tensor& k_cache,
|
||||
const paddle::Tensor& v_cache,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::optional<paddle::Tensor> &alibi_slopes,
|
||||
const paddle::optional<paddle::Tensor> &k,
|
||||
const paddle::optional<paddle::Tensor> &v,
|
||||
const paddle::optional<paddle::Tensor> &rope_sin,
|
||||
const paddle::optional<paddle::Tensor> &rope_cos,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
bool causal,
|
||||
int window_left,
|
||||
int window_right,
|
||||
float softcap,
|
||||
bool enable_cuda_graph,
|
||||
bool use_sqrt_alibi,
|
||||
bool merged_qkv) {
|
||||
std::vector<paddle::Tensor> PagedAttn(
|
||||
const paddle::Tensor& q,
|
||||
const paddle::Tensor& k_cache,
|
||||
const paddle::Tensor& v_cache,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::optional<paddle::Tensor>& alibi_slopes,
|
||||
const paddle::optional<paddle::Tensor>& k,
|
||||
const paddle::optional<paddle::Tensor>& v,
|
||||
const paddle::optional<paddle::Tensor>& rope_sin,
|
||||
const paddle::optional<paddle::Tensor>& rope_cos,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
bool causal,
|
||||
int window_left,
|
||||
int window_right,
|
||||
float softcap,
|
||||
bool enable_cuda_graph,
|
||||
bool use_sqrt_alibi,
|
||||
bool merged_qkv) {
|
||||
const auto dtype = q.dtype();
|
||||
auto out =
|
||||
paddle::empty({q.shape()[0], num_heads * head_dim}, dtype, q.place());
|
||||
|
||||
const auto dtype = q.dtype();
|
||||
auto out = paddle::empty({q.shape()[0], num_heads * head_dim}, dtype, q.place());
|
||||
|
||||
switch (dtype) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
PagedAttnKernel<paddle::DataType::BFLOAT16>(q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
seq_lens,
|
||||
alibi_slopes,
|
||||
k,
|
||||
v,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_size,
|
||||
max_context_len,
|
||||
causal,
|
||||
window_left,
|
||||
window_right,
|
||||
softcap,
|
||||
enable_cuda_graph,
|
||||
use_sqrt_alibi,
|
||||
merged_qkv,
|
||||
out);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
PagedAttnKernel<paddle::DataType::FLOAT16>(q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
seq_lens,
|
||||
alibi_slopes,
|
||||
k,
|
||||
v,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_size,
|
||||
max_context_len,
|
||||
causal,
|
||||
window_left,
|
||||
window_right,
|
||||
softcap,
|
||||
enable_cuda_graph,
|
||||
use_sqrt_alibi,
|
||||
merged_qkv,
|
||||
out);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for Paged attn");
|
||||
}
|
||||
return {out};
|
||||
switch (dtype) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
PagedAttnKernel<paddle::DataType::BFLOAT16>(q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
seq_lens,
|
||||
alibi_slopes,
|
||||
k,
|
||||
v,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_size,
|
||||
max_context_len,
|
||||
causal,
|
||||
window_left,
|
||||
window_right,
|
||||
softcap,
|
||||
enable_cuda_graph,
|
||||
use_sqrt_alibi,
|
||||
merged_qkv,
|
||||
out);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
PagedAttnKernel<paddle::DataType::FLOAT16>(q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
seq_lens,
|
||||
alibi_slopes,
|
||||
k,
|
||||
v,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_size,
|
||||
max_context_len,
|
||||
causal,
|
||||
window_left,
|
||||
window_right,
|
||||
softcap,
|
||||
enable_cuda_graph,
|
||||
use_sqrt_alibi,
|
||||
merged_qkv,
|
||||
out);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for Paged attn");
|
||||
}
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> PagedAttnInferShape(const std::vector<int64_t>& q_shape,
|
||||
const std::vector<int64_t>& k_cache_shape,
|
||||
const std::vector<int64_t>& v_cache_shape,
|
||||
const std::vector<int64_t>& block_table_shape,
|
||||
const std::vector<int64_t>& seq_lens_shape,
|
||||
const std::vector<int64_t>& alibi_slopes_shape,
|
||||
const std::vector<int64_t>& k_shape,
|
||||
const std::vector<int64_t>& v_shape,
|
||||
const std::vector<int64_t>& rope_sin_shape,
|
||||
const std::vector<int64_t>& rope_cos_shape,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
bool causal,
|
||||
int window_left,
|
||||
int window_right,
|
||||
float softcap,
|
||||
bool enable_cuda_graph,
|
||||
bool use_sqrt_alibi,
|
||||
bool merged_qkv) {
|
||||
if (merged_qkv) {
|
||||
return {{q_shape[0], num_heads * head_dim}};
|
||||
} else {
|
||||
return {q_shape};
|
||||
}
|
||||
std::vector<std::vector<int64_t>> PagedAttnInferShape(
|
||||
const std::vector<int64_t>& q_shape,
|
||||
const std::vector<int64_t>& k_cache_shape,
|
||||
const std::vector<int64_t>& v_cache_shape,
|
||||
const std::vector<int64_t>& block_table_shape,
|
||||
const std::vector<int64_t>& seq_lens_shape,
|
||||
const std::vector<int64_t>& alibi_slopes_shape,
|
||||
const std::vector<int64_t>& k_shape,
|
||||
const std::vector<int64_t>& v_shape,
|
||||
const std::vector<int64_t>& rope_sin_shape,
|
||||
const std::vector<int64_t>& rope_cos_shape,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
bool causal,
|
||||
int window_left,
|
||||
int window_right,
|
||||
float softcap,
|
||||
bool enable_cuda_graph,
|
||||
bool use_sqrt_alibi,
|
||||
bool merged_qkv) {
|
||||
if (merged_qkv) {
|
||||
return {{q_shape[0], num_heads * head_dim}};
|
||||
} else {
|
||||
return {q_shape};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtype) {
|
||||
return {q_dtype};
|
||||
std::vector<paddle::DataType> PagedAttnInferDtype(
|
||||
const paddle::DataType& q_dtype) {
|
||||
return {q_dtype};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_STATIC_OP(paged_attn)
|
||||
.Inputs({"q", "k_cache", "v_cache", "block_table", "seq_lens",
|
||||
paddle::Optional("alibi_slopes"), paddle::Optional("k"),
|
||||
paddle::Optional("v"), paddle::Optional("rope_sin"),
|
||||
.Inputs({"q",
|
||||
"k_cache",
|
||||
"v_cache",
|
||||
"block_table",
|
||||
"seq_lens",
|
||||
paddle::Optional("alibi_slopes"),
|
||||
paddle::Optional("k"),
|
||||
paddle::Optional("v"),
|
||||
paddle::Optional("rope_sin"),
|
||||
paddle::Optional("rope_cos")})
|
||||
.Outputs({"out"})
|
||||
.Attrs({"num_heads:int",
|
||||
@@ -341,11 +368,11 @@ PD_BUILD_STATIC_OP(paged_attn)
|
||||
"scale:float",
|
||||
"block_size:int",
|
||||
"max_context_len:int",
|
||||
"causal:bool",
|
||||
"causal:bool",
|
||||
"window_left:int",
|
||||
"window_right:int",
|
||||
"softcap:float",
|
||||
"enable_cuda_graph:bool",
|
||||
"enable_cuda_graph:bool",
|
||||
"use_sqrt_alibi:bool",
|
||||
"merged_qkv:bool"})
|
||||
.SetKernelFn(PD_KERNEL(PagedAttn))
|
||||
|
||||
@@ -16,352 +16,374 @@
|
||||
#include "iluvatar_context.h"
|
||||
|
||||
template <paddle::DataType T>
|
||||
void PrefillFusedPagedAttnKernel(const paddle::Tensor& qkv,
|
||||
paddle::Tensor& k_cache,
|
||||
paddle::Tensor& v_cache,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& cu_seqlens_qkv,
|
||||
const paddle::optional<paddle::Tensor> &rope_sin,
|
||||
const paddle::optional<paddle::Tensor> &rope_cos,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
float scale,
|
||||
bool causal,
|
||||
bool q_rope,
|
||||
bool k_rope,
|
||||
bool v_rope,
|
||||
paddle::Tensor& out) {
|
||||
void PrefillFusedPagedAttnKernel(
|
||||
const paddle::Tensor& qkv,
|
||||
paddle::Tensor& k_cache,
|
||||
paddle::Tensor& v_cache,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& cu_seqlens_qkv,
|
||||
const paddle::optional<paddle::Tensor>& rope_sin,
|
||||
const paddle::optional<paddle::Tensor>& rope_cos,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
float scale,
|
||||
bool causal,
|
||||
bool q_rope,
|
||||
bool k_rope,
|
||||
bool v_rope,
|
||||
paddle::Tensor& out) {
|
||||
// check dtype and contiguous
|
||||
const auto& dtype = qkv.dtype();
|
||||
cuinferDataType_t data_type;
|
||||
if (dtype == paddle::DataType::FLOAT16) {
|
||||
data_type = CUINFER_DATA_HALF;
|
||||
|
||||
// check dtype and contiguous
|
||||
const auto& dtype = qkv.dtype();
|
||||
cuinferDataType_t data_type;
|
||||
if (dtype == paddle::DataType::FLOAT16) {
|
||||
data_type = CUINFER_DATA_HALF;
|
||||
} else if (dtype == paddle::DataType::BFLOAT16) {
|
||||
data_type = CUINFER_DATA_BFLOAT16;
|
||||
} else {
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention support half and bfloat16 now");
|
||||
}
|
||||
|
||||
} else if (dtype == paddle::DataType::BFLOAT16) {
|
||||
data_type = CUINFER_DATA_BFLOAT16;
|
||||
} else {
|
||||
common::errors::InvalidArgument("paged_attention support half and bfloat16 now");
|
||||
}
|
||||
PADDLE_ENFORCE_EQ(k_cache.dtype(),
|
||||
dtype,
|
||||
common::errors::InvalidArgument(
|
||||
"k_cache dtype must be the same as query dtype"));
|
||||
PADDLE_ENFORCE_EQ(k_cache.is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects k_cache is contiguous"));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
block_table.dtype(),
|
||||
paddle::DataType::INT32,
|
||||
common::errors::InvalidArgument("block_table dtype must be int32"));
|
||||
PADDLE_ENFORCE_EQ(block_table.is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects block_table is contiguous"));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
cu_seqlens_qkv.dtype(),
|
||||
paddle::DataType::INT32,
|
||||
common::errors::InvalidArgument("cu_seqlens_qkv dtype must be int32"));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
cu_seqlens_qkv.is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects cu_seqlens_qkv is contiguous"));
|
||||
// check dim and shape
|
||||
// k_cache: [num_blocks, kv_num_heads, block_size, head_dim]
|
||||
// v_cache: [num_blocks, kv_num_heads, block_size, head_dim]
|
||||
// block_table: [batch_size, max_num_blocks_per_seq]
|
||||
// seq_lens: [batch_size]
|
||||
// qkv: [num_tokens, (num_heads+2*num_kv_heads)*head_dim]
|
||||
// out: [num_tokens, hidden_size]
|
||||
|
||||
PADDLE_ENFORCE_EQ(k_cache.dtype(),
|
||||
dtype,
|
||||
common::errors::InvalidArgument(
|
||||
"k_cache dtype must be the same as query dtype"));
|
||||
PADDLE_ENFORCE_EQ(k_cache.is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects k_cache is contiguous"));
|
||||
PADDLE_ENFORCE_EQ(block_table.dtype(),
|
||||
paddle::DataType::INT32,
|
||||
common::errors::InvalidArgument(
|
||||
"block_table dtype must be int32"));
|
||||
PADDLE_ENFORCE_EQ(block_table.is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects block_table is contiguous"));
|
||||
PADDLE_ENFORCE_EQ(cu_seqlens_qkv.dtype(),
|
||||
paddle::DataType::INT32,
|
||||
common::errors::InvalidArgument(
|
||||
"cu_seqlens_qkv dtype must be int32"));
|
||||
PADDLE_ENFORCE_EQ(cu_seqlens_qkv.is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects cu_seqlens_qkv is contiguous"));
|
||||
// check dim and shape
|
||||
// k_cache: [num_blocks, kv_num_heads, block_size, head_dim]
|
||||
// v_cache: [num_blocks, kv_num_heads, block_size, head_dim]
|
||||
// block_table: [batch_size, max_num_blocks_per_seq]
|
||||
// seq_lens: [batch_size]
|
||||
// qkv: [num_tokens, (num_heads+2*num_kv_heads)*head_dim]
|
||||
// out: [num_tokens, hidden_size]
|
||||
const auto& qkv_dims = qkv.dims();
|
||||
PADDLE_ENFORCE_EQ(qkv_dims.size(),
|
||||
2,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive query dims is "
|
||||
"[num_tokens, (num_heads+2*num_kv_heads)*head_dim]"));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
out.dims().size(),
|
||||
2,
|
||||
common::errors::InvalidArgument("paged_attn receive out dims is "
|
||||
"[num_tokens, hidden_size]"));
|
||||
|
||||
const auto& qkv_dims = qkv.dims();
|
||||
PADDLE_ENFORCE_EQ(qkv_dims.size(),
|
||||
2,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive query dims is "
|
||||
"[num_tokens, (num_heads+2*num_kv_heads)*head_dim]"));
|
||||
PADDLE_ENFORCE_EQ(out.dims().size(),
|
||||
2,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive out dims is "
|
||||
"[num_tokens, hidden_size]"));
|
||||
const auto& kv_cache_dims = k_cache.dims();
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
|
||||
4,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive kv cache dims is "
|
||||
"[num_blocks, kv_num_heads, block_size, head_dim]"));
|
||||
|
||||
const auto& kv_cache_dims = k_cache.dims();
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
|
||||
4,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive kv cache dims is "
|
||||
"[num_blocks, kv_num_heads, block_size, head_dim]"));
|
||||
const auto& block_table_dims = block_table.dims();
|
||||
PADDLE_ENFORCE_EQ(
|
||||
block_table_dims.size(),
|
||||
2,
|
||||
common::errors::InvalidArgument("paged_attn receive block_table dims is "
|
||||
"[batch_size, max_num_blocks_per_seq]"));
|
||||
|
||||
const auto& block_table_dims = block_table.dims();
|
||||
PADDLE_ENFORCE_EQ(block_table_dims.size(),
|
||||
2,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive block_table dims is "
|
||||
"[batch_size, max_num_blocks_per_seq]"));
|
||||
const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims();
|
||||
PADDLE_ENFORCE_EQ(
|
||||
cu_seqlens_qkv_dims.size(),
|
||||
1,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive cu_seqlens_qkv dims is [batch_size]"));
|
||||
|
||||
const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims();
|
||||
PADDLE_ENFORCE_EQ(cu_seqlens_qkv_dims.size(),
|
||||
1,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive cu_seqlens_qkv dims is [batch_size]"));
|
||||
int batch_size = block_table_dims[0];
|
||||
int num_tokens = qkv_dims[0];
|
||||
int num_total_heads = num_heads + 2 * num_kv_heads;
|
||||
int qkv_stride = qkv.strides()[0];
|
||||
int num_blocks = kv_cache_dims[0];
|
||||
|
||||
int batch_size = block_table_dims[0];
|
||||
int num_tokens = qkv_dims[0];
|
||||
int num_total_heads = num_heads + 2 * num_kv_heads;
|
||||
int qkv_stride = qkv.strides()[0];
|
||||
int num_blocks = kv_cache_dims[0];
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[1],
|
||||
num_kv_heads,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[1] must be equal to num_kv_head"));
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[2],
|
||||
block_size,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[2] must be equal to block_size"));
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[3],
|
||||
head_dim,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[3] must be equal to head_dim"));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
cu_seqlens_qkv_dims[0],
|
||||
batch_size + 1,
|
||||
common::errors::InvalidArgument(
|
||||
"cu_seqlens_qkv_dims[0] must be equal to batch_size + 1"));
|
||||
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[1],
|
||||
num_kv_heads,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[1] must be equal to num_kv_head"));
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[2],
|
||||
block_size,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[2] must be equal to block_size"));
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[3],
|
||||
head_dim,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[3] must be equal to head_dim"));
|
||||
PADDLE_ENFORCE_EQ(cu_seqlens_qkv_dims[0],
|
||||
batch_size + 1,
|
||||
common::errors::InvalidArgument(
|
||||
"cu_seqlens_qkv_dims[0] must be equal to batch_size + 1"));
|
||||
int block_table_stride = block_table.strides()[0];
|
||||
const float* rope_sin_ptr = rope_sin ? rope_sin.get().data<float>() : nullptr;
|
||||
const float* rope_cos_ptr = rope_cos ? rope_cos.get().data<float>() : nullptr;
|
||||
|
||||
int block_table_stride = block_table.strides()[0];
|
||||
const float *rope_sin_ptr = rope_sin ? rope_sin.get().data<float>() : nullptr;
|
||||
const float *rope_cos_ptr = rope_cos ? rope_cos.get().data<float>() : nullptr;
|
||||
cuinferHandle_t cuinfer_handle =
|
||||
iluvatar::getContextInstance()->getIxInferHandle();
|
||||
|
||||
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
|
||||
size_t workspace_size = 0;
|
||||
CUINFER_CHECK(cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(num_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
data_type,
|
||||
data_type,
|
||||
data_type,
|
||||
&workspace_size));
|
||||
auto* allocator = paddle::GetAllocator(qkv.place());
|
||||
phi::Allocator::AllocationPtr tmp_workspace =
|
||||
allocator->Allocate(workspace_size);
|
||||
void* workspace_ptr = tmp_workspace->ptr();
|
||||
|
||||
size_t workspace_size = 0;
|
||||
CUINFER_CHECK(cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(num_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
data_type,
|
||||
data_type,
|
||||
data_type,
|
||||
&workspace_size));
|
||||
auto* allocator = paddle::GetAllocator(qkv.place());
|
||||
phi::Allocator::AllocationPtr tmp_workspace = allocator->Allocate(workspace_size);
|
||||
void* workspace_ptr = tmp_workspace->ptr();
|
||||
|
||||
cuinferTensorDescriptor_t qkv_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cuinferTensorDescriptor_t qkv_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
qkv_desc,
|
||||
data_type,
|
||||
3,
|
||||
std::vector<int>({num_tokens, num_total_heads, head_dim}).data(),
|
||||
std::vector<int>({num_total_heads * head_dim, head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t qkv_seqlens_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
qkv_seqlens_desc,
|
||||
CUINFER_DATA_INT32,
|
||||
1,
|
||||
std::vector<int>({batch_size + 1}).data(),
|
||||
std::vector<int>({1}).data()));
|
||||
cuinferTensorDescriptor_t qkv_seqlens_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc));
|
||||
CUINFER_CHECK(
|
||||
cuinferSetTensorNdDescriptor(qkv_seqlens_desc,
|
||||
CUINFER_DATA_INT32,
|
||||
1,
|
||||
std::vector<int>({batch_size + 1}).data(),
|
||||
std::vector<int>({1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t block_table_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cuinferTensorDescriptor_t block_table_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
block_table_desc,
|
||||
CUINFER_DATA_INT32,
|
||||
2,
|
||||
std::vector<int>({batch_size, block_table_stride}).data(),
|
||||
std::vector<int>({block_table_stride, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t o_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cuinferTensorDescriptor_t o_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
o_desc,
|
||||
data_type,
|
||||
3,
|
||||
std::vector<int>({num_tokens, num_heads, head_dim}).data(),
|
||||
std::vector<int>({num_heads * head_dim, head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t k_cache_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cuinferTensorDescriptor_t k_cache_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
k_cache_desc,
|
||||
data_type,
|
||||
4,
|
||||
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
|
||||
std::vector<int>({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data()));
|
||||
std::vector<int>({num_kv_heads * block_size * head_dim,
|
||||
block_size * head_dim,
|
||||
head_dim,
|
||||
1})
|
||||
.data()));
|
||||
|
||||
cuinferTensorDescriptor_t v_cache_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cuinferTensorDescriptor_t v_cache_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
v_cache_desc,
|
||||
data_type,
|
||||
4,
|
||||
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
|
||||
std::vector<int>({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data()));
|
||||
std::vector<int>({num_kv_heads * block_size * head_dim,
|
||||
block_size * head_dim,
|
||||
head_dim,
|
||||
1})
|
||||
.data()));
|
||||
|
||||
cuinferTensorDescriptor_t cos_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cuinferTensorDescriptor_t cos_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cos_desc,
|
||||
CUINFER_DATA_FLOAT,
|
||||
2,
|
||||
std::vector<int>({max_seq_len, head_dim}).data(),
|
||||
std::vector<int>({head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t sin_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cuinferTensorDescriptor_t sin_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
sin_desc,
|
||||
CUINFER_DATA_FLOAT,
|
||||
2,
|
||||
std::vector<int>({max_seq_len, head_dim}).data(),
|
||||
std::vector<int>({head_dim, 1}).data()));
|
||||
|
||||
CUINFER_CHECK(cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle,
|
||||
qkv_desc,
|
||||
qkv.data(),
|
||||
qkv_seqlens_desc,
|
||||
cu_seqlens_qkv.data<int32_t>(),
|
||||
block_table_desc,
|
||||
block_table.data<int32_t>(),
|
||||
o_desc,
|
||||
out.data(),
|
||||
k_cache_desc,
|
||||
k_cache.data(),
|
||||
v_cache_desc,
|
||||
v_cache.data(),
|
||||
workspace_ptr,
|
||||
workspace_size,
|
||||
cos_desc,
|
||||
rope_cos_ptr,
|
||||
sin_desc,
|
||||
rope_sin_ptr,
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
causal,
|
||||
scale,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope));
|
||||
CUINFER_CHECK(cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle,
|
||||
qkv_desc,
|
||||
qkv.data(),
|
||||
qkv_seqlens_desc,
|
||||
cu_seqlens_qkv.data<int32_t>(),
|
||||
block_table_desc,
|
||||
block_table.data<int32_t>(),
|
||||
o_desc,
|
||||
out.data(),
|
||||
k_cache_desc,
|
||||
k_cache.data(),
|
||||
v_cache_desc,
|
||||
v_cache.data(),
|
||||
workspace_ptr,
|
||||
workspace_size,
|
||||
cos_desc,
|
||||
rope_cos_ptr,
|
||||
sin_desc,
|
||||
rope_sin_ptr,
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
causal,
|
||||
scale,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope));
|
||||
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc));
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> PrefillFusedPagedAttn(const paddle::Tensor& qkv,
|
||||
paddle::Tensor& k_cache,
|
||||
paddle::Tensor& v_cache,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& cu_seqlens_qkv,
|
||||
const paddle::optional<paddle::Tensor> &rope_sin,
|
||||
const paddle::optional<paddle::Tensor> &rope_cos,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
float scale,
|
||||
bool causal,
|
||||
bool q_rope,
|
||||
bool k_rope,
|
||||
bool v_rope) {
|
||||
std::vector<paddle::Tensor> PrefillFusedPagedAttn(
|
||||
const paddle::Tensor& qkv,
|
||||
paddle::Tensor& k_cache,
|
||||
paddle::Tensor& v_cache,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& cu_seqlens_qkv,
|
||||
const paddle::optional<paddle::Tensor>& rope_sin,
|
||||
const paddle::optional<paddle::Tensor>& rope_cos,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
float scale,
|
||||
bool causal,
|
||||
bool q_rope,
|
||||
bool k_rope,
|
||||
bool v_rope) {
|
||||
const auto dtype = qkv.dtype();
|
||||
auto out =
|
||||
paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place());
|
||||
|
||||
const auto dtype = qkv.dtype();
|
||||
auto out = paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place());
|
||||
|
||||
switch (dtype) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
PrefillFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(qkv,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
cu_seqlens_qkv,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
scale,
|
||||
causal,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
out);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
PrefillFusedPagedAttnKernel<paddle::DataType::FLOAT16>(qkv,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
cu_seqlens_qkv,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
scale,
|
||||
causal,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
out);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for Paged attn");
|
||||
}
|
||||
return {out};
|
||||
switch (dtype) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
PrefillFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(qkv,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
cu_seqlens_qkv,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
scale,
|
||||
causal,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
out);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
PrefillFusedPagedAttnKernel<paddle::DataType::FLOAT16>(qkv,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
cu_seqlens_qkv,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
scale,
|
||||
causal,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
out);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for Paged attn");
|
||||
}
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> PrefillFusedPagedAttnInferShape(const std::vector<int64_t>& qkv_shape,
|
||||
const std::vector<int64_t>& k_cache_shape,
|
||||
const std::vector<int64_t>& v_cache_shape,
|
||||
const std::vector<int64_t>& block_table_shape,
|
||||
const std::vector<int64_t>& cu_seqlens_qkv_shape,
|
||||
const std::vector<int64_t>& rope_sin_shape,
|
||||
const std::vector<int64_t>& rope_cos_shape,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
float scale,
|
||||
bool causal,
|
||||
bool q_rope,
|
||||
bool k_rope,
|
||||
bool v_rope) {
|
||||
return {{qkv_shape[0], num_heads * head_dim}};
|
||||
std::vector<std::vector<int64_t>> PrefillFusedPagedAttnInferShape(
|
||||
const std::vector<int64_t>& qkv_shape,
|
||||
const std::vector<int64_t>& k_cache_shape,
|
||||
const std::vector<int64_t>& v_cache_shape,
|
||||
const std::vector<int64_t>& block_table_shape,
|
||||
const std::vector<int64_t>& cu_seqlens_qkv_shape,
|
||||
const std::vector<int64_t>& rope_sin_shape,
|
||||
const std::vector<int64_t>& rope_cos_shape,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
float scale,
|
||||
bool causal,
|
||||
bool q_rope,
|
||||
bool k_rope,
|
||||
bool v_rope) {
|
||||
return {{qkv_shape[0], num_heads * head_dim}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> PrefillFusedPagedAttnInferDtype(const paddle::DataType& qkv_dtype) {
|
||||
return {qkv_dtype};
|
||||
std::vector<paddle::DataType> PrefillFusedPagedAttnInferDtype(
|
||||
const paddle::DataType& qkv_dtype) {
|
||||
return {qkv_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(prefill_fused_paged_attn)
|
||||
.Inputs({"qkv", "k_cache", "v_cache", "block_table", "cu_seqlens_qkv",
|
||||
paddle::Optional("rope_sin"), paddle::Optional("rope_cos")})
|
||||
.Inputs({"qkv",
|
||||
"k_cache",
|
||||
"v_cache",
|
||||
"block_table",
|
||||
"cu_seqlens_qkv",
|
||||
paddle::Optional("rope_sin"),
|
||||
paddle::Optional("rope_cos")})
|
||||
.Outputs({"out"})
|
||||
.Attrs({"num_heads:int",
|
||||
"head_dim:int",
|
||||
@@ -369,8 +391,8 @@ PD_BUILD_STATIC_OP(prefill_fused_paged_attn)
|
||||
"block_size:int",
|
||||
"max_seq_len:int",
|
||||
"scale:float",
|
||||
"causal:bool",
|
||||
"q_rope:bool",
|
||||
"causal:bool",
|
||||
"q_rope:bool",
|
||||
"k_rope:bool",
|
||||
"v_rope:bool"})
|
||||
.SetKernelFn(PD_KERNEL(PrefillFusedPagedAttn))
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include "iluvatar_context.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -33,27 +32,26 @@
|
||||
#include <vector>
|
||||
|
||||
#define CUINFER_CHECK(func) \
|
||||
do { \
|
||||
cuinferStatus_t status = (func); \
|
||||
if (status != CUINFER_STATUS_SUCCESS) { \
|
||||
std::cerr << "Error in file " << __FILE__ << " on line " \
|
||||
<< __LINE__ << ": " << cuinferGetErrorString(status) \
|
||||
<< std::endl; \
|
||||
throw std::runtime_error("CUINFER_CHECK ERROR"); \
|
||||
} \
|
||||
} while (0)
|
||||
do { \
|
||||
cuinferStatus_t status = (func); \
|
||||
if (status != CUINFER_STATUS_SUCCESS) { \
|
||||
std::cerr << "Error in file " << __FILE__ << " on line " << __LINE__ \
|
||||
<< ": " << cuinferGetErrorString(status) << std::endl; \
|
||||
throw std::runtime_error("CUINFER_CHECK ERROR"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace iluvatar {
|
||||
|
||||
class IluvatarContext {
|
||||
public:
|
||||
IluvatarContext() = default;
|
||||
~IluvatarContext();
|
||||
public:
|
||||
IluvatarContext() = default;
|
||||
~IluvatarContext();
|
||||
|
||||
cuinferHandle_t getIxInferHandle();
|
||||
cuinferHandle_t getIxInferHandle();
|
||||
|
||||
private:
|
||||
cuinferHandle_t ixinfer_handle_{nullptr};
|
||||
private:
|
||||
cuinferHandle_t ixinfer_handle_{nullptr};
|
||||
};
|
||||
IluvatarContext* getContextInstance();
|
||||
|
||||
|
||||
@@ -20,157 +20,157 @@ std::vector<paddle::Tensor> GroupGemm(const paddle::Tensor& x,
|
||||
const paddle::Tensor& weight_scale,
|
||||
const paddle::Tensor& prefix_sum,
|
||||
const int32_t group_size) {
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(x.place()));
|
||||
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
||||
const auto& x_dims = x.dims();
|
||||
const auto& w_dims = weight.dims();
|
||||
const auto& ws_dims = weight_scale.dims();
|
||||
const auto& prefix_sum_dims = prefix_sum.dims();
|
||||
// [m, k]
|
||||
PD_CHECK(x_dims.size() == 2, "x should be 2D");
|
||||
// [n_experts, n, k]
|
||||
PD_CHECK(w_dims.size() == 3, "weight should be 3D");
|
||||
// [n_experts, n]
|
||||
PD_CHECK(ws_dims.size() == 2, "weight_scale should be 2D");
|
||||
// [n_experts]
|
||||
PD_CHECK(prefix_sum_dims.size() == 1, "prefix_sum should be 1D");
|
||||
PD_CHECK(group_size == -1);
|
||||
auto m = x_dims[0];
|
||||
auto k = x_dims[1];
|
||||
auto n_experts = w_dims[0];
|
||||
auto n = w_dims[1];
|
||||
PD_CHECK(w_dims[2] == k);
|
||||
PD_CHECK(ws_dims[0] == n_experts);
|
||||
PD_CHECK(ws_dims[1] == n);
|
||||
PD_CHECK(prefix_sum_dims[0] == n_experts);
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(x.place()));
|
||||
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
||||
const auto& x_dims = x.dims();
|
||||
const auto& w_dims = weight.dims();
|
||||
const auto& ws_dims = weight_scale.dims();
|
||||
const auto& prefix_sum_dims = prefix_sum.dims();
|
||||
// [m, k]
|
||||
PD_CHECK(x_dims.size() == 2, "x should be 2D");
|
||||
// [n_experts, n, k]
|
||||
PD_CHECK(w_dims.size() == 3, "weight should be 3D");
|
||||
// [n_experts, n]
|
||||
PD_CHECK(ws_dims.size() == 2, "weight_scale should be 2D");
|
||||
// [n_experts]
|
||||
PD_CHECK(prefix_sum_dims.size() == 1, "prefix_sum should be 1D");
|
||||
PD_CHECK(group_size == -1);
|
||||
auto m = x_dims[0];
|
||||
auto k = x_dims[1];
|
||||
auto n_experts = w_dims[0];
|
||||
auto n = w_dims[1];
|
||||
PD_CHECK(w_dims[2] == k);
|
||||
PD_CHECK(ws_dims[0] == n_experts);
|
||||
PD_CHECK(ws_dims[1] == n);
|
||||
PD_CHECK(prefix_sum_dims[0] == n_experts);
|
||||
|
||||
PD_CHECK(prefix_sum.dtype() == paddle::DataType::INT64);
|
||||
PD_CHECK(prefix_sum.is_cpu());
|
||||
PD_CHECK(x.dtype() == paddle::DataType::BFLOAT16 ||
|
||||
x.dtype() == paddle::DataType::FLOAT16);
|
||||
PD_CHECK(weight.dtype() == paddle::DataType::INT8);
|
||||
PD_CHECK(weight_scale.dtype() == x.dtype());
|
||||
PD_CHECK(x.is_contiguous());
|
||||
PD_CHECK(weight.is_contiguous());
|
||||
PD_CHECK(weight_scale.is_contiguous());
|
||||
PD_CHECK(prefix_sum.dtype() == paddle::DataType::INT64);
|
||||
PD_CHECK(prefix_sum.is_cpu());
|
||||
PD_CHECK(x.dtype() == paddle::DataType::BFLOAT16 ||
|
||||
x.dtype() == paddle::DataType::FLOAT16);
|
||||
PD_CHECK(weight.dtype() == paddle::DataType::INT8);
|
||||
PD_CHECK(weight_scale.dtype() == x.dtype());
|
||||
PD_CHECK(x.is_contiguous());
|
||||
PD_CHECK(weight.is_contiguous());
|
||||
PD_CHECK(weight_scale.is_contiguous());
|
||||
|
||||
const int64_t* prefix_sum_ptr = prefix_sum.data<int64_t>();
|
||||
auto output = GetEmptyTensor({m, n}, x.dtype(), x.place());
|
||||
int16_t* out_data = static_cast<int16_t*>(output.data());
|
||||
const int16_t* x_data = static_cast<const int16_t*>(x.data());
|
||||
const int8_t* weight_data = weight.data<int8_t>();
|
||||
const int16_t* weight_scale_data =
|
||||
static_cast<const int16_t*>(weight_scale.data());
|
||||
const int64_t* prefix_sum_ptr = prefix_sum.data<int64_t>();
|
||||
auto output = GetEmptyTensor({m, n}, x.dtype(), x.place());
|
||||
int16_t* out_data = static_cast<int16_t*>(output.data());
|
||||
const int16_t* x_data = static_cast<const int16_t*>(x.data());
|
||||
const int8_t* weight_data = weight.data<int8_t>();
|
||||
const int16_t* weight_scale_data =
|
||||
static_cast<const int16_t*>(weight_scale.data());
|
||||
|
||||
cuinferHandle_t handle = iluvatar::getContextInstance()->getIxInferHandle();
|
||||
cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST;
|
||||
cuinferOperation_t transa = CUINFER_OP_T;
|
||||
cuinferOperation_t transb = CUINFER_OP_N;
|
||||
cudaDataType_t a_type = CUDA_R_8I;
|
||||
cudaDataType_t b_type;
|
||||
cudaDataType_t c_type;
|
||||
if (x.dtype() == paddle::DataType::FLOAT16) {
|
||||
b_type = CUDA_R_16F;
|
||||
} else if (x.dtype() == paddle::DataType::BFLOAT16) {
|
||||
b_type = CUDA_R_16BF;
|
||||
} else {
|
||||
PADDLE_THROW(common::errors::Unimplemented("Unsupported input dtype."));
|
||||
cuinferHandle_t handle = iluvatar::getContextInstance()->getIxInferHandle();
|
||||
cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST;
|
||||
cuinferOperation_t transa = CUINFER_OP_T;
|
||||
cuinferOperation_t transb = CUINFER_OP_N;
|
||||
cudaDataType_t a_type = CUDA_R_8I;
|
||||
cudaDataType_t b_type;
|
||||
cudaDataType_t c_type;
|
||||
if (x.dtype() == paddle::DataType::FLOAT16) {
|
||||
b_type = CUDA_R_16F;
|
||||
} else if (x.dtype() == paddle::DataType::BFLOAT16) {
|
||||
b_type = CUDA_R_16BF;
|
||||
} else {
|
||||
PADDLE_THROW(common::errors::Unimplemented("Unsupported input dtype."));
|
||||
}
|
||||
c_type = b_type;
|
||||
cudaDataType_t Atype = a_type;
|
||||
cudaDataType_t Btype = b_type;
|
||||
cudaDataType_t Ctype = c_type;
|
||||
cudaDataType_t computeType = CUDA_R_32F;
|
||||
cudaDataType_t scaleType = CUDA_R_32F;
|
||||
cuinferGEMMCustomOption_t customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE;
|
||||
|
||||
cuinferQuantGEMMHostParam cust_host_param;
|
||||
cust_host_param.size = sizeof(cuinferQuantGEMMHostParam);
|
||||
cust_host_param.persistent = 0;
|
||||
cust_host_param.groupSize = group_size;
|
||||
cuinferQuantGEMMDeviceParam cust_device_param;
|
||||
cust_device_param.bias = nullptr;
|
||||
cust_device_param.workspace = nullptr;
|
||||
|
||||
int lda = k;
|
||||
int ldb = k;
|
||||
int ldc = n;
|
||||
float beta = 0.f;
|
||||
float alpha = 1.f;
|
||||
int batch_count = 1;
|
||||
size_t pre = 0;
|
||||
|
||||
auto* allocator = paddle::GetAllocator(x.place());
|
||||
phi::Allocator::AllocationPtr tmp_workspace;
|
||||
for (int i = 0; i < n_experts; i++) {
|
||||
size_t expert_i_end = prefix_sum_ptr[i];
|
||||
size_t cur_len = expert_i_end - pre;
|
||||
pre = expert_i_end;
|
||||
if (cur_len != 0) {
|
||||
cust_device_param.scale = weight_scale_data;
|
||||
|
||||
if (k % 64 != 0) {
|
||||
size_t workspace_size;
|
||||
CUINFER_CHECK(cuinferGetCustomGemmWorkspace(transa,
|
||||
transb,
|
||||
n,
|
||||
cur_len,
|
||||
k,
|
||||
Atype,
|
||||
lda,
|
||||
lda,
|
||||
Btype,
|
||||
ldb,
|
||||
ldb,
|
||||
Ctype,
|
||||
ldc,
|
||||
ldc,
|
||||
batch_count,
|
||||
computeType,
|
||||
scaleType,
|
||||
&workspace_size));
|
||||
tmp_workspace = allocator->Allocate(workspace_size);
|
||||
cust_device_param.workspace = tmp_workspace->ptr();
|
||||
} else {
|
||||
cust_device_param.workspace = nullptr;
|
||||
}
|
||||
|
||||
CUINFER_CHECK(cuinferCustomGemm(handle,
|
||||
stream,
|
||||
cuinfer_ptr_mode,
|
||||
transa,
|
||||
transb,
|
||||
n,
|
||||
cur_len,
|
||||
k,
|
||||
&alpha,
|
||||
weight_data,
|
||||
Atype,
|
||||
lda,
|
||||
lda,
|
||||
x_data,
|
||||
Btype,
|
||||
ldb,
|
||||
ldb,
|
||||
&beta,
|
||||
out_data,
|
||||
Ctype,
|
||||
ldc,
|
||||
ldc,
|
||||
batch_count,
|
||||
computeType,
|
||||
scaleType,
|
||||
&cust_host_param,
|
||||
&cust_device_param,
|
||||
customOption));
|
||||
}
|
||||
c_type = b_type;
|
||||
cudaDataType_t Atype = a_type;
|
||||
cudaDataType_t Btype = b_type;
|
||||
cudaDataType_t Ctype = c_type;
|
||||
cudaDataType_t computeType = CUDA_R_32F;
|
||||
cudaDataType_t scaleType = CUDA_R_32F;
|
||||
cuinferGEMMCustomOption_t customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE;
|
||||
|
||||
cuinferQuantGEMMHostParam cust_host_param;
|
||||
cust_host_param.size = sizeof(cuinferQuantGEMMHostParam);
|
||||
cust_host_param.persistent = 0;
|
||||
cust_host_param.groupSize = group_size;
|
||||
cuinferQuantGEMMDeviceParam cust_device_param;
|
||||
cust_device_param.bias = nullptr;
|
||||
cust_device_param.workspace = nullptr;
|
||||
|
||||
int lda = k;
|
||||
int ldb = k;
|
||||
int ldc = n;
|
||||
float beta = 0.f;
|
||||
float alpha = 1.f;
|
||||
int batch_count = 1;
|
||||
size_t pre = 0;
|
||||
|
||||
auto* allocator = paddle::GetAllocator(x.place());
|
||||
phi::Allocator::AllocationPtr tmp_workspace;
|
||||
for (int i = 0; i < n_experts; i++) {
|
||||
size_t expert_i_end = prefix_sum_ptr[i];
|
||||
size_t cur_len = expert_i_end - pre;
|
||||
pre = expert_i_end;
|
||||
if (cur_len != 0) {
|
||||
cust_device_param.scale = weight_scale_data;
|
||||
|
||||
if (k % 64 != 0) {
|
||||
size_t workspace_size;
|
||||
CUINFER_CHECK(cuinferGetCustomGemmWorkspace(transa,
|
||||
transb,
|
||||
n,
|
||||
cur_len,
|
||||
k,
|
||||
Atype,
|
||||
lda,
|
||||
lda,
|
||||
Btype,
|
||||
ldb,
|
||||
ldb,
|
||||
Ctype,
|
||||
ldc,
|
||||
ldc,
|
||||
batch_count,
|
||||
computeType,
|
||||
scaleType,
|
||||
&workspace_size));
|
||||
tmp_workspace = allocator->Allocate(workspace_size);
|
||||
cust_device_param.workspace = tmp_workspace->ptr();
|
||||
} else {
|
||||
cust_device_param.workspace = nullptr;
|
||||
}
|
||||
|
||||
CUINFER_CHECK(cuinferCustomGemm(handle,
|
||||
stream,
|
||||
cuinfer_ptr_mode,
|
||||
transa,
|
||||
transb,
|
||||
n,
|
||||
cur_len,
|
||||
k,
|
||||
&alpha,
|
||||
weight_data,
|
||||
Atype,
|
||||
lda,
|
||||
lda,
|
||||
x_data,
|
||||
Btype,
|
||||
ldb,
|
||||
ldb,
|
||||
&beta,
|
||||
out_data,
|
||||
Ctype,
|
||||
ldc,
|
||||
ldc,
|
||||
batch_count,
|
||||
computeType,
|
||||
scaleType,
|
||||
&cust_host_param,
|
||||
&cust_device_param,
|
||||
customOption));
|
||||
}
|
||||
x_data += cur_len * k;
|
||||
weight_data += k * n;
|
||||
weight_scale_data += n;
|
||||
out_data += cur_len * n;
|
||||
}
|
||||
return {output};
|
||||
x_data += cur_len * k;
|
||||
weight_data += k * n;
|
||||
weight_scale_data += n;
|
||||
out_data += cur_len * n;
|
||||
}
|
||||
return {output};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> GroupGemmInferShape(
|
||||
@@ -178,7 +178,7 @@ std::vector<std::vector<int64_t>> GroupGemmInferShape(
|
||||
const std::vector<int64_t>& weight_shape,
|
||||
const std::vector<int64_t>& weight_scale_shape,
|
||||
const std::vector<int64_t>& prefix_sum_shape) {
|
||||
return {{x_shape[0], weight_shape[1]}};
|
||||
return {{x_shape[0], weight_shape[1]}};
|
||||
}
|
||||
std::vector<paddle::DataType> GroupGemmInferDtype(
|
||||
const paddle::DataType& input_dtype,
|
||||
@@ -186,7 +186,7 @@ std::vector<paddle::DataType> GroupGemmInferDtype(
|
||||
const paddle::DataType& weight_scale_dtype,
|
||||
const paddle::DataType& prefix_sum_dtype,
|
||||
const int moe_topk) {
|
||||
return {input_dtype};
|
||||
return {input_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(w8a16_group_gemm)
|
||||
|
||||
@@ -12,12 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fused_moe_op.h"
|
||||
#include "helper.h"
|
||||
#include "mc_fused_moe_helper.h"
|
||||
#include "fused_moe_op.h"
|
||||
|
||||
__global__ void compute_total_rows_before_expert_kernel(
|
||||
int* sorted_experts,
|
||||
@@ -43,7 +42,10 @@ void compute_total_rows_before_expert(int* sorted_indices,
|
||||
sorted_indices, total_indices, num_experts, total_rows_before_expert);
|
||||
}
|
||||
|
||||
template <paddle::DataType T, typename ElementA, typename ElementB, typename ElementC>
|
||||
template <paddle::DataType T,
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC>
|
||||
void FusedMoeKernel(const paddle::Tensor& input,
|
||||
const paddle::Tensor& gate_weight,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
@@ -63,27 +65,26 @@ void FusedMoeKernel(const paddle::Tensor& input,
|
||||
|
||||
auto* output_data = output->data<data_t>();
|
||||
|
||||
auto moe_compute = McMoeHelper<data_t, ElementA, ElementB, ElementC>(quant_method);
|
||||
auto moe_compute =
|
||||
McMoeHelper<data_t, ElementA, ElementB, ElementC>(quant_method);
|
||||
|
||||
moe_compute.computeFFN(
|
||||
&input,
|
||||
&gate_weight,
|
||||
&ffn1_weight,
|
||||
ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
|
||||
ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
|
||||
&ffn2_weight,
|
||||
ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
|
||||
ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
|
||||
nullptr,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
1.0, // ComputeFFN
|
||||
"ffn",
|
||||
output);
|
||||
moe_compute.computeFFN(&input,
|
||||
&gate_weight,
|
||||
&ffn1_weight,
|
||||
ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
|
||||
ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
|
||||
&ffn2_weight,
|
||||
ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
|
||||
ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
|
||||
nullptr,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
1.0, // ComputeFFN
|
||||
"ffn",
|
||||
output);
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> FusedExpertMoe(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& gate_weight,
|
||||
@@ -102,19 +103,22 @@ std::vector<paddle::Tensor> FusedExpertMoe(
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
FusedMoeKernel<paddle::DataType::BFLOAT16, maca_bfloat16, int8_t, maca_bfloat16>(input,
|
||||
gate_weight,
|
||||
ffn1_weight,
|
||||
ffn1_scale,
|
||||
ffn1_bias,
|
||||
ffn2_weight,
|
||||
ffn2_scale,
|
||||
ffn2_bias,
|
||||
quant_method,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
&output);
|
||||
FusedMoeKernel<paddle::DataType::BFLOAT16,
|
||||
maca_bfloat16,
|
||||
int8_t,
|
||||
maca_bfloat16>(input,
|
||||
gate_weight,
|
||||
ffn1_weight,
|
||||
ffn1_scale,
|
||||
ffn1_bias,
|
||||
ffn2_weight,
|
||||
ffn2_scale,
|
||||
ffn2_bias,
|
||||
quant_method,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
&output);
|
||||
break;
|
||||
// case paddle::DataType::FLOAT16:
|
||||
// FusedMoeKernel<paddle::DataType::FLOAT16>(input,
|
||||
@@ -161,7 +165,6 @@ std::vector<paddle::DataType> FusedExpertMoeInferDtype(
|
||||
return {input_dtype};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(fused_expert_moe)
|
||||
.Inputs({"input",
|
||||
"gate_weight",
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include "cub/cub.cuh"
|
||||
|
||||
static const float HALF_FLT_MAX = 65504.F;
|
||||
|
||||
@@ -19,9 +19,9 @@
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include "fused_moe_imp_op.h"
|
||||
#include "fused_moe_helper.h"
|
||||
#include "mctlass/numeric_conversion.h" // BUILD_MARK
|
||||
#include "fused_moe_imp_op.h"
|
||||
#include "mctlass/numeric_conversion.h" // BUILD_MARK
|
||||
// Ignore mctlass warnings about type punning
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
@@ -35,8 +35,8 @@
|
||||
#define WARP_SIZE 32
|
||||
|
||||
struct GpuLaunchConfig {
|
||||
dim3 block_per_grid;
|
||||
dim3 thread_per_block;
|
||||
dim3 block_per_grid;
|
||||
dim3 thread_per_block;
|
||||
};
|
||||
|
||||
inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
|
||||
@@ -82,7 +82,6 @@ __launch_bounds__(TPB) __global__
|
||||
cub::Sum sum;
|
||||
float threadData(-FLT_MAX);
|
||||
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData = max(static_cast<float>(input[idx]), threadData);
|
||||
@@ -603,7 +602,7 @@ void topk_gating_softmax_kernelLauncher(const T* input,
|
||||
}
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
|
||||
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
|
||||
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
|
||||
case N: { \
|
||||
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
|
||||
input, output, indices, source_row, num_rows, num_experts, k, stream); \
|
||||
@@ -646,14 +645,8 @@ void topk_gating_softmax_kernelLauncher(const T* input,
|
||||
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
input, softmax, num_experts, num_rows);
|
||||
moe_top_k<T, TPB>
|
||||
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
|
||||
output,
|
||||
indices,
|
||||
source_row,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows);
|
||||
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
softmax, output, indices, source_row, num_experts, k, num_rows);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,52 +1,71 @@
|
||||
#include "fused_moe_helper.h"
|
||||
#include "mctlass/numeric_conversion.h"
|
||||
#include "mctlassEx/mctlassEx.h"
|
||||
#include "fused_moe_helper.h"
|
||||
|
||||
|
||||
template <typename ElementA, typename ElementB, typename ElementC>
|
||||
void mc_grouped_gemm_basic_kernel(
|
||||
const ElementA* ptrA,
|
||||
mctlassExOrder_t majorA,
|
||||
const ElementB* ptrB,
|
||||
mctlassExOrder_t majorB,
|
||||
const ElementA* ptrScale,
|
||||
const ElementA* ptrBias,
|
||||
ElementC* ptrC,
|
||||
mctlassExOrder_t majorC,
|
||||
const int *ptrSegInd,
|
||||
int numExperts,
|
||||
int m, // expanded_active_expert_rows
|
||||
int n, // inter_dim
|
||||
int k, // hidden_size
|
||||
mcStream_t stream) {
|
||||
void mc_grouped_gemm_basic_kernel(const ElementA *ptrA,
|
||||
mctlassExOrder_t majorA,
|
||||
const ElementB *ptrB,
|
||||
mctlassExOrder_t majorB,
|
||||
const ElementA *ptrScale,
|
||||
const ElementA *ptrBias,
|
||||
ElementC *ptrC,
|
||||
mctlassExOrder_t majorC,
|
||||
const int *ptrSegInd,
|
||||
int numExperts,
|
||||
int m, // expanded_active_expert_rows
|
||||
int n, // inter_dim
|
||||
int k, // hidden_size
|
||||
mcStream_t stream) {
|
||||
mctlassExHandle_t handle;
|
||||
mctlassExHandleCreate(&handle);
|
||||
|
||||
int* ptrMNumTilesInd;
|
||||
mcMallocAsync((void**)&ptrMNumTilesInd, sizeof(int) * numExperts, stream);
|
||||
int *ptrMNumTilesInd;
|
||||
mcMallocAsync((void **)&ptrMNumTilesInd, sizeof(int) * numExperts, stream);
|
||||
|
||||
mctlassExMatrixLayout_t matLayoutA;
|
||||
mctlassExMatrixLayout_t matLayoutB;
|
||||
mctlassExMatrixLayout_t matLayoutC;
|
||||
|
||||
// mat A: (m, k)
|
||||
mctlassExMatrixLayoutCreate(&matLayoutA, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, k, k);
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorA, sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts, sizeof(int));
|
||||
mctlassExMatrixLayoutCreate(
|
||||
&matLayoutA, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, k, k);
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutA,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorA,
|
||||
sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutA,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts,
|
||||
sizeof(int));
|
||||
// mat B: (num_experts, n, k)
|
||||
mctlassExMatrixLayoutCreate(&matLayoutB, mctlassExDataType::MCTLASS_EX_DATATYPE_INT8, k, n, k);
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorB, sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts, sizeof(int));
|
||||
mctlassExMatrixLayoutCreate(
|
||||
&matLayoutB, mctlassExDataType::MCTLASS_EX_DATATYPE_INT8, k, n, k);
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutB,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorB,
|
||||
sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutB,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts,
|
||||
sizeof(int));
|
||||
// mat C: (m, n)
|
||||
mctlassExMatrixLayoutCreate(&matLayoutC, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, n, n);
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorC, sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts, sizeof(int));
|
||||
mctlassExMatrixLayoutCreate(
|
||||
&matLayoutC, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, n, n);
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutC,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorC,
|
||||
sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutC,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts,
|
||||
sizeof(int));
|
||||
// bias: (num_experts, n)
|
||||
// scale: (num, n)
|
||||
|
||||
@@ -55,44 +74,81 @@ void mc_grouped_gemm_basic_kernel(
|
||||
mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_DATATYPE_BF16;
|
||||
mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_DATATYPE_INT8;
|
||||
mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_DATATYPE_FP32;
|
||||
mctlassExEpilogueType epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT;
|
||||
mctlassExEpilogueType epilogue_type =
|
||||
mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT;
|
||||
if (ptrBias) {
|
||||
epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_BIAS;
|
||||
}
|
||||
// set scale
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER,
|
||||
&ptrScale, sizeof(ptrScale));
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE,
|
||||
&input_type, sizeof(mctlassExDataType));
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER,
|
||||
&ptrScale,
|
||||
sizeof(ptrScale));
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE,
|
||||
&input_type,
|
||||
sizeof(mctlassExDataType));
|
||||
// set bias
|
||||
if (ptrBias) {
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER,
|
||||
&ptrBias, sizeof(ptrBias));
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER,
|
||||
&ptrBias,
|
||||
sizeof(ptrBias));
|
||||
}
|
||||
// set coumpute type
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE,
|
||||
&compute_type, sizeof(mctlassExDataType));
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE,
|
||||
&compute_type,
|
||||
sizeof(mctlassExDataType));
|
||||
// set epilogue type
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE,
|
||||
&epilogue_type, sizeof(mctlassExEpilogueType));
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE,
|
||||
&epilogue_type,
|
||||
sizeof(mctlassExEpilogueType));
|
||||
|
||||
const mctlassExContiguousGroupedGemmAlgo_t algo = mctlassExContiguousGroupedGemmAlgo_t::MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;
|
||||
const mctlassExContiguousGroupedGemmAlgo_t algo =
|
||||
mctlassExContiguousGroupedGemmAlgo_t::
|
||||
MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;
|
||||
mctlassExContiguousGroupedDesc_t contiguous_group_desc;
|
||||
mctlassExContiguousGroupedDescCreate(&contiguous_group_desc,
|
||||
ptrSegInd,
|
||||
nullptr,
|
||||
ptrMNumTilesInd,
|
||||
1);
|
||||
mctlassExContiguousGroupedDescCreate(
|
||||
&contiguous_group_desc, ptrSegInd, nullptr, ptrMNumTilesInd, 1);
|
||||
int blocksizeM;
|
||||
mctlassExContiguousGroupedGemmGetBlocksizeM(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, &blocksizeM);
|
||||
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, contiguous_group_desc, numExperts, blocksizeM, stream);
|
||||
mctlassExContiguousGroupedGemmGetBlocksizeM(handle,
|
||||
mctlass_desc,
|
||||
matLayoutA,
|
||||
matLayoutB,
|
||||
matLayoutC,
|
||||
&algo,
|
||||
&blocksizeM);
|
||||
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle,
|
||||
mctlass_desc,
|
||||
matLayoutA,
|
||||
matLayoutB,
|
||||
matLayoutC,
|
||||
&algo,
|
||||
contiguous_group_desc,
|
||||
numExperts,
|
||||
blocksizeM,
|
||||
stream);
|
||||
|
||||
mctlassExContiguousGroupedGemmBasic(handle, mctlass_desc,
|
||||
ptrA, matLayoutA,
|
||||
ptrB, matLayoutB,
|
||||
ptrC, matLayoutC,
|
||||
mctlassExContiguousGroupedGemmBasic(handle,
|
||||
mctlass_desc,
|
||||
ptrA,
|
||||
matLayoutA,
|
||||
ptrB,
|
||||
matLayoutB,
|
||||
ptrC,
|
||||
matLayoutC,
|
||||
contiguous_group_desc,
|
||||
&algo, nullptr, 0, stream);
|
||||
&algo,
|
||||
nullptr,
|
||||
0,
|
||||
stream);
|
||||
|
||||
mctlassExHandleDestroy(handle);
|
||||
mctlassExMatrixLayoutDestroy(matLayoutA);
|
||||
@@ -103,312 +159,312 @@ void mc_grouped_gemm_basic_kernel(
|
||||
mcFreeAsync(ptrMNumTilesInd, stream);
|
||||
}
|
||||
|
||||
template<typename T, typename ElementA, typename ElementB, typename ElementC>
|
||||
template <typename T, typename ElementA, typename ElementB, typename ElementC>
|
||||
class McMoeHelper {
|
||||
public:
|
||||
McMoeHelper(const std::string gemm_method): gemm_method_(gemm_method) {}
|
||||
public:
|
||||
McMoeHelper(const std::string gemm_method) : gemm_method_(gemm_method) {}
|
||||
|
||||
// -------- getWorkspaceSize -------- //
|
||||
template <typename KeyT>
|
||||
size_t getWorkspaceSize(const int64_t num_rows,
|
||||
const int64_t hidden_size,
|
||||
const int64_t inter_size,
|
||||
const int64_t num_experts,
|
||||
const int64_t k) {
|
||||
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
const size_t padded_experts = AlignTo16(num_experts);
|
||||
const size_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||
// softmax output, permuted_rows and permuted_experts have moved to outside
|
||||
// of moe kernel, allocate them in Encoder or Decoder before invoking
|
||||
// FfnLayer forward.
|
||||
size_t total_ws_bytes =
|
||||
5 * num_moe_inputs *
|
||||
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
||||
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
||||
total_ws_bytes +=
|
||||
padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_
|
||||
// -------- getWorkspaceSize -------- //
|
||||
template <typename KeyT>
|
||||
size_t getWorkspaceSize(const int64_t num_rows,
|
||||
const int64_t hidden_size,
|
||||
const int64_t inter_size,
|
||||
const int64_t num_experts,
|
||||
const int64_t k) {
|
||||
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
const size_t padded_experts = AlignTo16(num_experts);
|
||||
const size_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||
// softmax output, permuted_rows and permuted_experts have moved to outside
|
||||
// of moe kernel, allocate them in Encoder or Decoder before invoking
|
||||
// FfnLayer forward.
|
||||
size_t total_ws_bytes =
|
||||
5 * num_moe_inputs *
|
||||
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
||||
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
||||
total_ws_bytes +=
|
||||
padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_
|
||||
|
||||
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
|
||||
const size_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(num_rows));
|
||||
sorter_.update_num_experts(num_experts);
|
||||
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
|
||||
const size_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(num_rows));
|
||||
sorter_.update_num_experts(num_experts);
|
||||
|
||||
int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
|
||||
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
|
||||
int64_t remaining_bytes =
|
||||
AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
|
||||
bytes_for_intermediate_and_sorting += remaining_bytes;
|
||||
}
|
||||
|
||||
total_ws_bytes +=
|
||||
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
||||
// sorting workspace
|
||||
|
||||
int64_t num_softmax_outs = 0;
|
||||
const bool is_pow_2 =
|
||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256) {
|
||||
num_softmax_outs = AlignTo16(num_rows * num_experts);
|
||||
}
|
||||
|
||||
total_ws_bytes += num_softmax_outs * sizeof(float);
|
||||
|
||||
return total_ws_bytes;
|
||||
int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
|
||||
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
|
||||
int64_t remaining_bytes =
|
||||
AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
|
||||
bytes_for_intermediate_and_sorting += remaining_bytes;
|
||||
}
|
||||
|
||||
void computeFFN(const paddle::Tensor *input,
|
||||
const paddle::Tensor *gate_weight,
|
||||
const paddle::Tensor *ffn1_weight,
|
||||
const paddle::Tensor *ffn1_scale,
|
||||
const paddle::Tensor *ffn1_bias,
|
||||
const paddle::Tensor *ffn2_weight,
|
||||
const paddle::Tensor *ffn2_scale,
|
||||
const paddle::Tensor *ffn2_bias,
|
||||
const paddle::Tensor *moe_token_type_ids,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
const std::string moe_type,
|
||||
paddle::Tensor *output) {
|
||||
auto *input_activations = input->data<T>();
|
||||
auto *gating_weights = gate_weight->data<float>();
|
||||
const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
|
||||
const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
|
||||
total_ws_bytes +=
|
||||
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
||||
// sorting workspace
|
||||
|
||||
auto *output_ = output->data<T>();
|
||||
auto stream = input->stream();
|
||||
auto place = input->place();
|
||||
auto input_type = input->dtype();
|
||||
int64_t num_softmax_outs = 0;
|
||||
const bool is_pow_2 =
|
||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256) {
|
||||
num_softmax_outs = AlignTo16(num_rows * num_experts);
|
||||
}
|
||||
|
||||
auto input_dims = input->dims();
|
||||
auto ffn1_dims = ffn1_weight->dims();
|
||||
int64_t token_num = 0;
|
||||
if (input_dims.size() == 3) {
|
||||
token_num = input_dims[0] * input_dims[1];
|
||||
} else {
|
||||
token_num = input_dims[0];
|
||||
}
|
||||
const int64_t num_rows = token_num;
|
||||
total_ws_bytes += num_softmax_outs * sizeof(float);
|
||||
|
||||
const int64_t hidden_size = ffn1_dims[2];
|
||||
int64_t inter_dim = 0;
|
||||
if (moe_type == "qkv") {
|
||||
inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
|
||||
} else {
|
||||
inter_dim = ffn1_dims[1];
|
||||
}
|
||||
return total_ws_bytes;
|
||||
}
|
||||
|
||||
// if (gemm_method == "weight_only_int4") {
|
||||
// inter_dim = inter_dim * 2;
|
||||
// }
|
||||
void computeFFN(const paddle::Tensor *input,
|
||||
const paddle::Tensor *gate_weight,
|
||||
const paddle::Tensor *ffn1_weight,
|
||||
const paddle::Tensor *ffn1_scale,
|
||||
const paddle::Tensor *ffn1_bias,
|
||||
const paddle::Tensor *ffn2_weight,
|
||||
const paddle::Tensor *ffn2_scale,
|
||||
const paddle::Tensor *ffn2_bias,
|
||||
const paddle::Tensor *moe_token_type_ids,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
const std::string moe_type,
|
||||
paddle::Tensor *output) {
|
||||
auto *input_activations = input->data<T>();
|
||||
auto *gating_weights = gate_weight->data<float>();
|
||||
const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
|
||||
const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
|
||||
|
||||
const int64_t inter_size = inter_dim;
|
||||
const int64_t num_experts = ffn1_dims[0];
|
||||
const int64_t k = moe_topk;
|
||||
auto *output_ = output->data<T>();
|
||||
auto stream = input->stream();
|
||||
auto place = input->place();
|
||||
auto input_type = input->dtype();
|
||||
|
||||
auto input_dims = input->dims();
|
||||
auto ffn1_dims = ffn1_weight->dims();
|
||||
int64_t token_num = 0;
|
||||
if (input_dims.size() == 3) {
|
||||
token_num = input_dims[0] * input_dims[1];
|
||||
} else {
|
||||
token_num = input_dims[0];
|
||||
}
|
||||
const int64_t num_rows = token_num;
|
||||
|
||||
int64_t bytes =
|
||||
getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
|
||||
const int64_t hidden_size = ffn1_dims[2];
|
||||
int64_t inter_dim = 0;
|
||||
if (moe_type == "qkv") {
|
||||
inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
|
||||
} else {
|
||||
inter_dim = ffn1_dims[1];
|
||||
}
|
||||
|
||||
// Pointers
|
||||
int *expert_for_source_row;
|
||||
int *source_rows_;
|
||||
int *permuted_rows_;
|
||||
int *permuted_experts_;
|
||||
int *expanded_source_row_to_expanded_dest_row;
|
||||
// if (gemm_method == "weight_only_int4") {
|
||||
// inter_dim = inter_dim * 2;
|
||||
// }
|
||||
|
||||
T *permuted_data_;
|
||||
int32_t *total_rows_before_expert_;
|
||||
T *fc1_result_;
|
||||
float *softmax_out_;
|
||||
const int64_t inter_size = inter_dim;
|
||||
const int64_t num_experts = ffn1_dims[0];
|
||||
const int64_t k = moe_topk;
|
||||
|
||||
paddle::Tensor ws_ptr_tensor =
|
||||
GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
|
||||
int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
|
||||
int64_t bytes =
|
||||
getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
|
||||
|
||||
const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
const int64_t padded_experts = AlignTo16(num_experts);
|
||||
const int64_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||
// Pointers
|
||||
int *expert_for_source_row;
|
||||
int *source_rows_;
|
||||
int *permuted_rows_;
|
||||
int *permuted_experts_;
|
||||
int *expanded_source_row_to_expanded_dest_row;
|
||||
|
||||
expert_for_source_row = reinterpret_cast<int *>(ws_ptr);
|
||||
source_rows_ = expert_for_source_row + num_moe_inputs;
|
||||
permuted_rows_ = source_rows_ + num_moe_inputs;
|
||||
permuted_experts_ = permuted_rows_ + num_moe_inputs;
|
||||
expanded_source_row_to_expanded_dest_row =
|
||||
permuted_experts_ + num_moe_inputs;
|
||||
permuted_data_ = reinterpret_cast<T *>(
|
||||
expanded_source_row_to_expanded_dest_row + num_moe_inputs);
|
||||
total_rows_before_expert_ =
|
||||
reinterpret_cast<int32_t *>(permuted_data_ + buf_size);
|
||||
fc1_result_ =
|
||||
reinterpret_cast<T *>(total_rows_before_expert_ + padded_experts);
|
||||
T *permuted_data_;
|
||||
int32_t *total_rows_before_expert_;
|
||||
T *fc1_result_;
|
||||
float *softmax_out_;
|
||||
|
||||
const bool is_pow_2 =
|
||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256) {
|
||||
softmax_out_ = reinterpret_cast<float *>(fc1_result_ + interbuf_size);
|
||||
} else {
|
||||
softmax_out_ = nullptr;
|
||||
}
|
||||
paddle::Tensor ws_ptr_tensor =
|
||||
GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
|
||||
int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
|
||||
|
||||
paddle::Tensor expert_scales_float_tensor =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
float *expert_scales_float = expert_scales_float_tensor.data<float>();
|
||||
const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
const int64_t padded_experts = AlignTo16(num_experts);
|
||||
const int64_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||
|
||||
float *softmax_max_prob = nullptr;
|
||||
if (group_moe) {
|
||||
paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
|
||||
{num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
// (TODO: check fill success ?)
|
||||
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
||||
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
||||
}
|
||||
expert_for_source_row = reinterpret_cast<int *>(ws_ptr);
|
||||
source_rows_ = expert_for_source_row + num_moe_inputs;
|
||||
permuted_rows_ = source_rows_ + num_moe_inputs;
|
||||
permuted_experts_ = permuted_rows_ + num_moe_inputs;
|
||||
expanded_source_row_to_expanded_dest_row =
|
||||
permuted_experts_ + num_moe_inputs;
|
||||
permuted_data_ = reinterpret_cast<T *>(
|
||||
expanded_source_row_to_expanded_dest_row + num_moe_inputs);
|
||||
total_rows_before_expert_ =
|
||||
reinterpret_cast<int32_t *>(permuted_data_ + buf_size);
|
||||
fc1_result_ =
|
||||
reinterpret_cast<T *>(total_rows_before_expert_ + padded_experts);
|
||||
|
||||
paddle::Tensor fc1_out_tensor =
|
||||
GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
|
||||
T *fc1_out = fc1_out_tensor.data<T>();
|
||||
const bool is_pow_2 =
|
||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256) {
|
||||
softmax_out_ = reinterpret_cast<float *>(fc1_result_ + interbuf_size);
|
||||
} else {
|
||||
softmax_out_ = nullptr;
|
||||
}
|
||||
|
||||
auto input_cast_tensor =
|
||||
paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
|
||||
auto gate_tensor =
|
||||
paddle::experimental::matmul(input_cast_tensor, *gate_weight);
|
||||
float *gating_output = gate_tensor.data<float>();
|
||||
paddle::Tensor expert_scales_float_tensor =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
float *expert_scales_float = expert_scales_float_tensor.data<float>();
|
||||
|
||||
if (moe_token_type_ids) {
|
||||
auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
|
||||
moe_token_type_ids_kernelLauncher<float>(gating_output,
|
||||
moe_token_type_ids_out,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
stream);
|
||||
}
|
||||
float *softmax_max_prob = nullptr;
|
||||
if (group_moe) {
|
||||
paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
|
||||
{num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
// (TODO: check fill success ?)
|
||||
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
||||
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher<float>(gating_output,
|
||||
expert_scales_float,
|
||||
softmax_out_,
|
||||
expert_for_source_row,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
group_moe,
|
||||
stream);
|
||||
paddle::Tensor fc1_out_tensor =
|
||||
GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
|
||||
T *fc1_out = fc1_out_tensor.data<T>();
|
||||
|
||||
const int64_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
|
||||
auto input_cast_tensor =
|
||||
paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
|
||||
auto gate_tensor =
|
||||
paddle::experimental::matmul(input_cast_tensor, *gate_weight);
|
||||
float *gating_output = gate_tensor.data<float>();
|
||||
|
||||
sorter_.run(fc1_result_,
|
||||
sorter_ws_size_bytes,
|
||||
expert_for_source_row,
|
||||
permuted_experts_,
|
||||
source_rows_,
|
||||
permuted_rows_,
|
||||
k * num_rows,
|
||||
false,
|
||||
stream);
|
||||
if (moe_token_type_ids) {
|
||||
auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
|
||||
moe_token_type_ids_kernelLauncher<float>(gating_output,
|
||||
moe_token_type_ids_out,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
stream);
|
||||
}
|
||||
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input_activations,
|
||||
permuted_data_,
|
||||
permuted_rows_,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
num_rows,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
stream);
|
||||
topk_gating_softmax_kernelLauncher<float>(gating_output,
|
||||
expert_scales_float,
|
||||
softmax_out_,
|
||||
expert_for_source_row,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
group_moe,
|
||||
stream);
|
||||
|
||||
const int64_t expanded_active_expert_rows = k * num_rows;
|
||||
const int64_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
|
||||
|
||||
compute_total_rows_before_expert(permuted_experts_,
|
||||
expanded_active_expert_rows,
|
||||
num_experts,
|
||||
total_rows_before_expert_,
|
||||
stream);
|
||||
sorter_.run(fc1_result_,
|
||||
sorter_ws_size_bytes,
|
||||
expert_for_source_row,
|
||||
permuted_experts_,
|
||||
source_rows_,
|
||||
permuted_rows_,
|
||||
k * num_rows,
|
||||
false,
|
||||
stream);
|
||||
|
||||
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
|
||||
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input_activations,
|
||||
permuted_data_,
|
||||
permuted_rows_,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
num_rows,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
stream);
|
||||
|
||||
const int64_t expanded_active_expert_rows = k * num_rows;
|
||||
|
||||
compute_total_rows_before_expert(permuted_experts_,
|
||||
expanded_active_expert_rows,
|
||||
num_experts,
|
||||
total_rows_before_expert_,
|
||||
stream);
|
||||
|
||||
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
|
||||
mctlassExOrder_t column_major =
|
||||
mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
|
||||
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA *>(permuted_data_),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB *>(ffn1_weight->data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA *>(ffn1_scale->data<T>()),
|
||||
reinterpret_cast<const ElementA *>(fc1_expert_biases),
|
||||
reinterpret_cast<ElementC *>(fc1_out),
|
||||
row_major,
|
||||
total_rows_before_expert_,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
stream);
|
||||
|
||||
if (moe_type == "ffn") {
|
||||
auto act_out_tensor =
|
||||
paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||
auto act_out = act_out_tensor.data<T>();
|
||||
|
||||
paddle::Tensor fc2_output_tensor =
|
||||
GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
|
||||
T *fc2_result = fc2_output_tensor.data<T>();
|
||||
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA *>(permuted_data_),
|
||||
reinterpret_cast<const ElementA *>(act_out),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB *>(ffn1_weight->data<ElementB>()),
|
||||
reinterpret_cast<const ElementB *>(ffn2_weight->data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA *>(ffn1_scale->data<T>()),
|
||||
reinterpret_cast<const ElementA *>(fc1_expert_biases),
|
||||
reinterpret_cast<ElementC *>(fc1_out),
|
||||
reinterpret_cast<const ElementA *>(ffn2_scale->data<T>()),
|
||||
nullptr,
|
||||
reinterpret_cast<ElementC *>(fc2_result),
|
||||
row_major,
|
||||
total_rows_before_expert_,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
inter_size / 2,
|
||||
stream);
|
||||
|
||||
if (moe_type == "ffn") {
|
||||
auto act_out_tensor =
|
||||
paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||
auto act_out = act_out_tensor.data<T>();
|
||||
|
||||
paddle::Tensor fc2_output_tensor =
|
||||
GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
|
||||
T *fc2_result = fc2_output_tensor.data<T>();
|
||||
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA *>(act_out),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB *>(ffn2_weight->data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA *>(ffn2_scale->data<T>()),
|
||||
nullptr,
|
||||
reinterpret_cast<ElementC *>(fc2_result),
|
||||
row_major,
|
||||
total_rows_before_expert_,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_size / 2,
|
||||
stream);
|
||||
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
fc2_result,
|
||||
output_,
|
||||
fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
static_cast<int>(1),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
} else {
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
// fc2_result,
|
||||
fc1_out,
|
||||
output_,
|
||||
fc1_expert_biases, // fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
inter_size,
|
||||
k,
|
||||
static_cast<int>(0),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
}
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
fc2_result,
|
||||
output_,
|
||||
fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
static_cast<int>(1),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
} else {
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
// fc2_result,
|
||||
fc1_out,
|
||||
output_,
|
||||
fc1_expert_biases, // fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
inter_size,
|
||||
k,
|
||||
static_cast<int>(0),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
std::string gemm_method_;
|
||||
CubKeyValueSorter sorter_;
|
||||
};
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
#pragma GCC diagnostic ignored "-Wunused-function"
|
||||
@@ -24,7 +23,6 @@
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
|
||||
template <paddle::DataType T>
|
||||
void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
const paddle::Tensor& gating_output,
|
||||
@@ -128,7 +126,6 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
false,
|
||||
stream);
|
||||
|
||||
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input.data<data_t>(),
|
||||
permute_input->data<data_t>(),
|
||||
@@ -140,16 +137,13 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
moe_topk,
|
||||
stream);
|
||||
|
||||
|
||||
compute_total_rows_before_expert(
|
||||
permuted_experts_,
|
||||
moe_topk * num_rows,
|
||||
expert_num,
|
||||
tokens_expert_prefix_sum->data<int32_t>(),
|
||||
stream);
|
||||
compute_total_rows_before_expert(permuted_experts_,
|
||||
moe_topk * num_rows,
|
||||
expert_num,
|
||||
tokens_expert_prefix_sum->data<int32_t>(),
|
||||
stream);
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& gating_output,
|
||||
@@ -184,7 +178,6 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
auto permute_indices_per_token =
|
||||
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
|
||||
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
|
||||
@@ -226,7 +219,6 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
top_k_indices};
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
||||
const std::vector<int64_t>& input_shape,
|
||||
const std::vector<int64_t>& gating_output_shape,
|
||||
@@ -260,7 +252,6 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
||||
paddle::DataType::INT32};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(moe_expert_dispatch)
|
||||
.Inputs({"input", "gating_output"})
|
||||
.Outputs({"permute_input",
|
||||
|
||||
@@ -14,19 +14,22 @@
|
||||
|
||||
// BUILD_MARK
|
||||
#pragma once
|
||||
#include "mc_fused_moe_helper.h"
|
||||
#include "helper.h"
|
||||
#include "mc_fused_moe_helper.h"
|
||||
|
||||
template <paddle::DataType T, typename ElementA, typename ElementB, typename ElementC>
|
||||
template <paddle::DataType T,
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC>
|
||||
void McMoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const std::string& quant_method,
|
||||
paddle::Tensor ffn_out) {
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const std::string& quant_method,
|
||||
paddle::Tensor ffn_out) {
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
@@ -37,61 +40,65 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
auto input_type = permute_input.dtype();
|
||||
auto stream = permute_input.stream();
|
||||
|
||||
const int expanded_active_expert_rows = permute_input.dims()[0]; // permute_input.dims(): m, k
|
||||
const int num_experts = ffn1_weight.dims()[0]; // batchsize
|
||||
const int hidden_size = ffn1_weight.dims()[2]; // n
|
||||
int inter_dim = ffn1_weight.dims()[1]; // k
|
||||
const int expanded_active_expert_rows =
|
||||
permute_input.dims()[0]; // permute_input.dims(): m, k
|
||||
const int num_experts = ffn1_weight.dims()[0]; // batchsize
|
||||
const int hidden_size = ffn1_weight.dims()[2]; // n
|
||||
int inter_dim = ffn1_weight.dims()[1]; // k
|
||||
|
||||
const int64_t inter_size = inter_dim; // since weight_only_int_8
|
||||
const int64_t inter_size = inter_dim; // since weight_only_int_8
|
||||
paddle::Tensor fc1_out_tensor = GetEmptyTensor(
|
||||
{expanded_active_expert_rows, inter_size}, input_type, place);
|
||||
auto fc1_out_ptr = fc1_out_tensor.data<data_t>();
|
||||
|
||||
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
|
||||
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
|
||||
mctlassExOrder_t column_major =
|
||||
mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
|
||||
|
||||
// ffn1
|
||||
auto fc1_expert_biases =
|
||||
ffn1_bias
|
||||
? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
|
||||
: nullptr;
|
||||
auto fc1_expert_scales = const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())->data<data_t>();
|
||||
ffn1_bias
|
||||
? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
|
||||
: nullptr;
|
||||
auto fc1_expert_scales =
|
||||
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())->data<data_t>();
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA *>(permuted_input_ptr),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB *>(ffn1_weight.data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA *>(fc1_expert_scales),
|
||||
reinterpret_cast<const ElementA *>(fc1_expert_biases),
|
||||
reinterpret_cast<ElementC *>(fc1_out_ptr),
|
||||
row_major,
|
||||
tokens_expert_prefix_sum.data<int>(),
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
inter_dim,
|
||||
hidden_size,
|
||||
stream);
|
||||
reinterpret_cast<const ElementA*>(permuted_input_ptr),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB*>(ffn1_weight.data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA*>(fc1_expert_scales),
|
||||
reinterpret_cast<const ElementA*>(fc1_expert_biases),
|
||||
reinterpret_cast<ElementC*>(fc1_out_ptr),
|
||||
row_major,
|
||||
tokens_expert_prefix_sum.data<int>(),
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
inter_dim,
|
||||
hidden_size,
|
||||
stream);
|
||||
|
||||
// swiglu
|
||||
auto act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||
auto act_out = act_out_tensor.data<data_t>();
|
||||
|
||||
auto fc2_expert_scales = const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())->data<data_t>();
|
||||
auto fc2_expert_scales =
|
||||
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())->data<data_t>();
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA *>(act_out),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB *>(ffn2_weight.data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA *>(fc2_expert_scales),
|
||||
nullptr,
|
||||
reinterpret_cast<ElementC *>(ffn_out_ptr),
|
||||
row_major,
|
||||
tokens_expert_prefix_sum.data<int>(),
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_dim / 2,
|
||||
stream);
|
||||
reinterpret_cast<const ElementA*>(act_out),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB*>(ffn2_weight.data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA*>(fc2_expert_scales),
|
||||
nullptr,
|
||||
reinterpret_cast<ElementC*>(ffn_out_ptr),
|
||||
row_major,
|
||||
tokens_expert_prefix_sum.data<int>(),
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_dim / 2,
|
||||
stream);
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
@@ -109,15 +116,18 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
McMoeFFNKernel<paddle::DataType::BFLOAT16, maca_bfloat16, int8_t, maca_bfloat16>(permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
ffn1_weight,
|
||||
ffn2_weight,
|
||||
ffn1_bias,
|
||||
ffn1_scale,
|
||||
ffn2_scale,
|
||||
quant_method,
|
||||
ffn_out);
|
||||
McMoeFFNKernel<paddle::DataType::BFLOAT16,
|
||||
maca_bfloat16,
|
||||
int8_t,
|
||||
maca_bfloat16>(permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
ffn1_weight,
|
||||
ffn2_weight,
|
||||
ffn1_bias,
|
||||
ffn1_scale,
|
||||
ffn2_scale,
|
||||
quant_method,
|
||||
ffn_out);
|
||||
break;
|
||||
// case paddle::DataType::FLOAT16:
|
||||
// MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
|
||||
|
||||
@@ -12,12 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "helper.h"
|
||||
#include "fused_moe_helper.h"
|
||||
#include "fused_moe_op.h"
|
||||
#include "helper.h"
|
||||
|
||||
template <paddle::DataType T>
|
||||
void MoeReduceKernel(const paddle::Tensor& ffn_out,
|
||||
@@ -52,7 +51,6 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out,
|
||||
stream);
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertReduce(
|
||||
const paddle::Tensor& ffn_out,
|
||||
const paddle::Tensor& top_k_weight,
|
||||
@@ -106,7 +104,6 @@ std::vector<paddle::Tensor> MoeExpertReduce(
|
||||
return {output};
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
|
||||
const std::vector<int64_t>& ffn_out_shape,
|
||||
const std::vector<int64_t>& top_k_weight_shape,
|
||||
@@ -129,7 +126,6 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
|
||||
return {ffn_out_dtype};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(moe_expert_reduce)
|
||||
.Inputs({"ffn_out",
|
||||
"top_k_weight",
|
||||
|
||||
@@ -12,64 +12,69 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/enforce.h"
|
||||
#include "utility/helper.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
|
||||
template <paddle::DataType T>
|
||||
std::vector<paddle::Tensor>
|
||||
AdjustBatchKernel(const paddle::Tensor &x, // [token_num, dim_embed]
|
||||
const paddle::Tensor &cum_offsets, // [bsz, 1]
|
||||
const paddle::Tensor &encoder_seq_lod,
|
||||
const paddle::Tensor &encoder_batch_idx,
|
||||
const paddle::Tensor &decoder_batch_idx,
|
||||
const paddle::Tensor &encoder_seq_lod_cpu,
|
||||
const paddle::Tensor &encoder_batch_idx_cpu,
|
||||
const paddle::Tensor &decoder_batch_idx_cpu,
|
||||
const paddle::Tensor &enc_batch_tensor,
|
||||
const paddle::Tensor &dec_batch_tensor,
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
int max_input_length) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
PD_CHECK(x.dtype() == T);
|
||||
PD_CHECK(x.dims().size() == 2);
|
||||
std::vector<paddle::Tensor> AdjustBatchKernel(
|
||||
const paddle::Tensor &x, // [token_num, dim_embed]
|
||||
const paddle::Tensor &cum_offsets, // [bsz, 1]
|
||||
const paddle::Tensor &encoder_seq_lod,
|
||||
const paddle::Tensor &encoder_batch_idx,
|
||||
const paddle::Tensor &decoder_batch_idx,
|
||||
const paddle::Tensor &encoder_seq_lod_cpu,
|
||||
const paddle::Tensor &encoder_batch_idx_cpu,
|
||||
const paddle::Tensor &decoder_batch_idx_cpu,
|
||||
const paddle::Tensor &enc_batch_tensor,
|
||||
const paddle::Tensor &dec_batch_tensor,
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
int max_input_length) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
PD_CHECK(x.dtype() == T);
|
||||
PD_CHECK(x.dims().size() == 2);
|
||||
|
||||
using XPUType = typename XPUTypeTrait<typename PDTraits<T>::DataType>::Type;
|
||||
using data_t = typename PDTraits<T>::data_t;
|
||||
const int token_num = x.dims()[0];
|
||||
const int dim = x.dims()[1];
|
||||
const int bsz = cum_offsets.shape()[0];
|
||||
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
|
||||
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
|
||||
using XPUType = typename XPUTypeTrait<typename PDTraits<T>::DataType>::Type;
|
||||
using data_t = typename PDTraits<T>::data_t;
|
||||
const int token_num = x.dims()[0];
|
||||
const int dim = x.dims()[1];
|
||||
const int bsz = cum_offsets.shape()[0];
|
||||
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
|
||||
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
|
||||
|
||||
baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
|
||||
const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
|
||||
enc_batch + 1, const_cast<int32_t *>(encoder_seq_lod.data<int32_t>())};
|
||||
baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
|
||||
const_cast<int32_t *>(encoder_batch_idx_cpu.data<int32_t>()), enc_batch,
|
||||
const_cast<int32_t *>(encoder_batch_idx.data<int32_t>())};
|
||||
baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp{
|
||||
const_cast<int32_t *>(decoder_batch_idx_cpu.data<int32_t>()), dec_batch,
|
||||
const_cast<int32_t *>(decoder_batch_idx.data<int32_t>())};
|
||||
baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
|
||||
const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
|
||||
enc_batch + 1,
|
||||
const_cast<int32_t *>(encoder_seq_lod.data<int32_t>())};
|
||||
baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
|
||||
const_cast<int32_t *>(encoder_batch_idx_cpu.data<int32_t>()),
|
||||
enc_batch,
|
||||
const_cast<int32_t *>(encoder_batch_idx.data<int32_t>())};
|
||||
baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp{
|
||||
const_cast<int32_t *>(decoder_batch_idx_cpu.data<int32_t>()),
|
||||
dec_batch,
|
||||
const_cast<int32_t *>(decoder_batch_idx.data<int32_t>())};
|
||||
|
||||
auto out = paddle::full({token_num, dim}, -2, x.type(), x.place());
|
||||
auto out = paddle::full({token_num, dim}, -2, x.type(), x.place());
|
||||
|
||||
int r = baidu::xpu::api::plugin::eb_adjust_batch<XPUType, XPUType>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPUType *>(x.data<data_t>()),
|
||||
reinterpret_cast<XPUType *>(out.data<data_t>()), encoder_seqs_lods_vp,
|
||||
encoder_batch_map_vp, decoder_batch_map_vp, dim);
|
||||
return {out};
|
||||
int r = baidu::xpu::api::plugin::eb_adjust_batch<XPUType, XPUType>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPUType *>(x.data<data_t>()),
|
||||
reinterpret_cast<XPUType *>(out.data<data_t>()),
|
||||
encoder_seqs_lods_vp,
|
||||
encoder_batch_map_vp,
|
||||
decoder_batch_map_vp,
|
||||
dim);
|
||||
return {out};
|
||||
}
|
||||
|
||||
using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)(
|
||||
const paddle::Tensor &x, // [token_num, dim_embed]
|
||||
const paddle::Tensor &cum_offsets, // [bsz, 1]
|
||||
const paddle::Tensor &x, // [token_num, dim_embed]
|
||||
const paddle::Tensor &cum_offsets, // [bsz, 1]
|
||||
const paddle::Tensor &encoder_seq_lod,
|
||||
const paddle::Tensor &encoder_batch_idx,
|
||||
const paddle::Tensor &decoder_batch_idx,
|
||||
@@ -81,42 +86,50 @@ using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)(
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
int max_input_length);
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
AdjustBatch(const paddle::Tensor &x, // [token_num, dim_embed]
|
||||
const paddle::Tensor &cum_offsets, // [bsz, 1]
|
||||
const paddle::Tensor &encoder_seq_lod,
|
||||
const paddle::Tensor &encoder_batch_idx,
|
||||
const paddle::Tensor &decoder_batch_idx,
|
||||
const paddle::Tensor &encoder_seq_lod_cpu,
|
||||
const paddle::Tensor &encoder_batch_idx_cpu,
|
||||
const paddle::Tensor &decoder_batch_idx_cpu,
|
||||
const paddle::Tensor &enc_batch_tensor,
|
||||
const paddle::Tensor &dec_batch_tensor,
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
int max_input_length) {
|
||||
AdjustBatchKernelFuncPtr func = nullptr;
|
||||
std::vector<paddle::Tensor> AdjustBatch(
|
||||
const paddle::Tensor &x, // [token_num, dim_embed]
|
||||
const paddle::Tensor &cum_offsets, // [bsz, 1]
|
||||
const paddle::Tensor &encoder_seq_lod,
|
||||
const paddle::Tensor &encoder_batch_idx,
|
||||
const paddle::Tensor &decoder_batch_idx,
|
||||
const paddle::Tensor &encoder_seq_lod_cpu,
|
||||
const paddle::Tensor &encoder_batch_idx_cpu,
|
||||
const paddle::Tensor &decoder_batch_idx_cpu,
|
||||
const paddle::Tensor &enc_batch_tensor,
|
||||
const paddle::Tensor &dec_batch_tensor,
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
int max_input_length) {
|
||||
AdjustBatchKernelFuncPtr func = nullptr;
|
||||
|
||||
switch (x.dtype()) {
|
||||
switch (x.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
func = &AdjustBatchKernel<paddle::DataType::BFLOAT16>;
|
||||
break;
|
||||
func = &AdjustBatchKernel<paddle::DataType::BFLOAT16>;
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
func = &AdjustBatchKernel<paddle::DataType::FLOAT16>;
|
||||
break;
|
||||
func = &AdjustBatchKernel<paddle::DataType::FLOAT16>;
|
||||
break;
|
||||
case paddle::DataType::FLOAT32:
|
||||
func = &AdjustBatchKernel<paddle::DataType::FLOAT32>;
|
||||
break;
|
||||
func = &AdjustBatchKernel<paddle::DataType::FLOAT32>;
|
||||
break;
|
||||
case paddle::DataType::INT64:
|
||||
func = &AdjustBatchKernel<paddle::DataType::INT64>;
|
||||
break;
|
||||
func = &AdjustBatchKernel<paddle::DataType::INT64>;
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type: ", x.dtype());
|
||||
}
|
||||
PD_THROW("Unsupported data type: ", x.dtype());
|
||||
}
|
||||
|
||||
return func(x, cum_offsets, encoder_seq_lod, encoder_batch_idx,
|
||||
decoder_batch_idx, encoder_seq_lod_cpu, encoder_batch_idx_cpu,
|
||||
decoder_batch_idx_cpu, enc_batch_tensor, dec_batch_tensor,
|
||||
output_padding_offset, max_input_length);
|
||||
return func(x,
|
||||
cum_offsets,
|
||||
encoder_seq_lod,
|
||||
encoder_batch_idx,
|
||||
decoder_batch_idx,
|
||||
encoder_seq_lod_cpu,
|
||||
encoder_batch_idx_cpu,
|
||||
decoder_batch_idx_cpu,
|
||||
enc_batch_tensor,
|
||||
dec_batch_tensor,
|
||||
output_padding_offset,
|
||||
max_input_length);
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> AdjustBatchInferShape(
|
||||
@@ -131,16 +144,17 @@ std::vector<std::vector<int64_t>> AdjustBatchInferShape(
|
||||
const std::vector<int64_t> &enc_batch_tensor_shape,
|
||||
const std::vector<int64_t> &dec_batch_tensor_shape,
|
||||
const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
|
||||
if (output_padding_offset_shape) {
|
||||
PD_THROW("speculative decoding is not supported in XPU.");
|
||||
}
|
||||
int64_t token_num = x_shape[0];
|
||||
int64_t dim_embed = x_shape[1];
|
||||
return {{token_num, dim_embed}};
|
||||
if (output_padding_offset_shape) {
|
||||
PD_THROW("speculative decoding is not supported in XPU.");
|
||||
}
|
||||
int64_t token_num = x_shape[0];
|
||||
int64_t dim_embed = x_shape[1];
|
||||
return {{token_num, dim_embed}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> AdjustBatchInferDtype(
|
||||
const paddle::DataType &x_dtype, const paddle::DataType &cum_offsets_dtype,
|
||||
const paddle::DataType &x_dtype,
|
||||
const paddle::DataType &cum_offsets_dtype,
|
||||
const paddle::DataType &encoder_seq_lod_dtype,
|
||||
const paddle::DataType &encoder_batch_idx_dtype,
|
||||
const paddle::DataType &decoder_batch_idx_dtype,
|
||||
@@ -150,14 +164,20 @@ std::vector<paddle::DataType> AdjustBatchInferDtype(
|
||||
const paddle::DataType &enc_batch_tensor_dtype,
|
||||
const paddle::DataType &dec_batch_tensor_dtype,
|
||||
const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
|
||||
return {x_dtype};
|
||||
return {x_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(adjust_batch)
|
||||
.Inputs({"x", "cum_offsets", "encoder_seq_lod", "encoder_batch_idx",
|
||||
"decoder_batch_idx", "encoder_seq_lod_cpu",
|
||||
"encoder_batch_idx_cpu", "decoder_batch_idx_cpu",
|
||||
"enc_batch_tensor", "dec_batch_tensor",
|
||||
.Inputs({"x",
|
||||
"cum_offsets",
|
||||
"encoder_seq_lod",
|
||||
"encoder_batch_idx",
|
||||
"decoder_batch_idx",
|
||||
"encoder_seq_lod_cpu",
|
||||
"encoder_batch_idx_cpu",
|
||||
"decoder_batch_idx_cpu",
|
||||
"enc_batch_tensor",
|
||||
"dec_batch_tensor",
|
||||
paddle::Optional("output_padding_offset")})
|
||||
.Outputs({"out"})
|
||||
.Attrs({"max_input_length: int"})
|
||||
|
||||
@@ -89,7 +89,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
const paddle::optional<paddle::Tensor>& smooth,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
|
||||
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
|
||||
const std::string &pos_emb_type,
|
||||
const std::string& pos_emb_type,
|
||||
bool rope_3d) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
@@ -215,8 +215,8 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
param.prefill_len = is_prefix_cache ? param.max_valid_seqlen : -1;
|
||||
param.page_attn.block_size = block_size;
|
||||
param.page_attn.max_num_blocks_per_seq = prefix_block_num_per_seq;
|
||||
// prefix_block_tables is a subset of block_tables, which is used for prefix
|
||||
// cache
|
||||
// prefix_block_tables is a subset of block_tables, which is used for
|
||||
// prefix cache
|
||||
xftblock::Tensor prefix_block_tables_tensor(
|
||||
is_prefix_cache ? reinterpret_cast<void*>(const_cast<int32_t*>(
|
||||
prefix_block_tables.data<int32_t>()))
|
||||
@@ -306,12 +306,12 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
||||
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||
value_cache.data<cdata_t>())),
|
||||
vsl.usual_lod_vp, // seq_lod
|
||||
vsl.slot_mapping_vp, // real_batch
|
||||
prefix_lens_vp, // start_tokens
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size
|
||||
rope_max_seqlen, // max_seqlen
|
||||
vsl.usual_lod_vp, // seq_lod
|
||||
vsl.slot_mapping_vp, // real_batch
|
||||
prefix_lens_vp, // start_tokens
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num,
|
||||
param.kv_head_num,
|
||||
param.head_dim,
|
||||
@@ -480,7 +480,8 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
api::VectorParam<int32_t> decoder_context_len_vp = {
|
||||
const_cast<int32_t*>(decoder_context_len_cpu.data<int32_t>()),
|
||||
dec_batch,
|
||||
nullptr}; // use for speculative_attention_decoder seq_len in MTP
|
||||
nullptr}; // use for speculative_attention_decoder seq_len in
|
||||
// MTP
|
||||
api::VectorParam<int32_t> decoder_context_len_cache_vp = {
|
||||
const_cast<int32_t*>(decoder_context_len_cache_cpu.data<int32_t>()),
|
||||
dec_batch,
|
||||
@@ -597,49 +598,49 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
tfloat32,
|
||||
int8_wo_t>;
|
||||
constexpr int quant_mode = std::is_same_v<XPU_CType, int8_t> ? 3 : 0;
|
||||
ret = baidu::xpu::xfa::speculative_attention_decoder<XPU_XType,
|
||||
XPU_CType,
|
||||
XPU_XType,
|
||||
TGEMM,
|
||||
TGEMM,
|
||||
float,
|
||||
int32_t,
|
||||
quant_mode>(
|
||||
xpu_ctx->x_context(),
|
||||
decode_output_ptr, // out
|
||||
q_buf_ptr, // q
|
||||
nullptr, // k
|
||||
nullptr, // v
|
||||
reinterpret_cast<const XPU_CType*>(
|
||||
key_cache.data<cdata_t>()), // k_cache
|
||||
reinterpret_cast<const XPU_CType*>(
|
||||
value_cache.data<cdata_t>()), // v_cache
|
||||
reinterpret_cast<const int32_t*>(
|
||||
block_tables.data<int32_t>()), // block_tables
|
||||
decoder_context_len_vp, // seq_lengths
|
||||
decoder_batch_map_vp, // valid_batch
|
||||
param.max_batch_size, // batch_num
|
||||
q_len, // qlen
|
||||
max_seq_len, // max_seq_len
|
||||
param.head_num, // head_num
|
||||
param.head_dim, // head_dim
|
||||
param.kv_head_num, // kv_head_num
|
||||
nullptr, // attn_mask
|
||||
1.0f /
|
||||
std::sqrt(static_cast<float>(param.head_dim)), // scale 【check】
|
||||
block_size, // block_size
|
||||
max_block_per_seq, // max_blocks_per_seq
|
||||
-1, // max_window_size
|
||||
nullptr, // q_maxptr
|
||||
has_zp // k_cache_maxptr
|
||||
? fake_perhead_scale
|
||||
: quant_k_scale_inv,
|
||||
has_zp // v_cache_maxptr
|
||||
? fake_perhead_scale
|
||||
: quant_v_scale_inv,
|
||||
nullptr, // o_maxptr
|
||||
param.head_dim); // vo_head_dim
|
||||
PD_CHECK(0, "speculative_attention unimplemented");
|
||||
ret = baidu::xpu::xfa::speculative_attention_decoder<XPU_XType,
|
||||
XPU_CType,
|
||||
XPU_XType,
|
||||
TGEMM,
|
||||
TGEMM,
|
||||
float,
|
||||
int32_t,
|
||||
quant_mode>(
|
||||
xpu_ctx->x_context(),
|
||||
decode_output_ptr, // out
|
||||
q_buf_ptr, // q
|
||||
nullptr, // k
|
||||
nullptr, // v
|
||||
reinterpret_cast<const XPU_CType*>(
|
||||
key_cache.data<cdata_t>()), // k_cache
|
||||
reinterpret_cast<const XPU_CType*>(
|
||||
value_cache.data<cdata_t>()), // v_cache
|
||||
reinterpret_cast<const int32_t*>(
|
||||
block_tables.data<int32_t>()), // block_tables
|
||||
decoder_context_len_vp, // seq_lengths
|
||||
decoder_batch_map_vp, // valid_batch
|
||||
param.max_batch_size, // batch_num
|
||||
q_len, // qlen
|
||||
max_seq_len, // max_seq_len
|
||||
param.head_num, // head_num
|
||||
param.head_dim, // head_dim
|
||||
param.kv_head_num, // kv_head_num
|
||||
nullptr, // attn_mask
|
||||
1.0f /
|
||||
std::sqrt(static_cast<float>(param.head_dim)), // scale 【check】
|
||||
block_size, // block_size
|
||||
max_block_per_seq, // max_blocks_per_seq
|
||||
-1, // max_window_size
|
||||
nullptr, // q_maxptr
|
||||
has_zp // k_cache_maxptr
|
||||
? fake_perhead_scale
|
||||
: quant_k_scale_inv,
|
||||
has_zp // v_cache_maxptr
|
||||
? fake_perhead_scale
|
||||
: quant_v_scale_inv,
|
||||
nullptr, // o_maxptr
|
||||
param.head_dim); // vo_head_dim
|
||||
PD_CHECK(0, "speculative_attention unimplemented");
|
||||
PD_CHECK(ret == api::SUCCESS,
|
||||
"xfa::speculative_attention_decoder failed.");
|
||||
if (!Eq_len) {
|
||||
@@ -702,11 +703,11 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
||||
const_cast<XPU_CType*>(
|
||||
reinterpret_cast<const XPU_CType*>(value_cache.data<cdata_t>())),
|
||||
vsl.usual_lod_vp, // seq_lod
|
||||
vsl.slot_mapping_vp, // real_batch
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size = rotary_embs.dims()[1] = 1
|
||||
rope_max_seqlen, // max_seqlen
|
||||
vsl.usual_lod_vp, // seq_lod
|
||||
vsl.slot_mapping_vp, // real_batch
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size = rotary_embs.dims()[1] = 1
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num,
|
||||
param.kv_head_num,
|
||||
param.head_dim,
|
||||
@@ -777,7 +778,8 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
ret = xftblock::xft_decoder_core_attenion_block<
|
||||
XPU_XType,
|
||||
XPU_CType,
|
||||
XPU_XType>( // TGEMM = XPU_XType TODOlizan03: used high precision
|
||||
XPU_XType>( // TGEMM = XPU_XType TODOlizan03: used high
|
||||
// precision
|
||||
&xctx,
|
||||
&q_buf,
|
||||
&key_cache_tensor,
|
||||
@@ -867,8 +869,8 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
const paddle::optional<paddle::Tensor>& smooth,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
|
||||
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
|
||||
const std::string &pos_emb_type="NORMAL",
|
||||
bool rope_3d=false) {
|
||||
const std::string& pos_emb_type = "NORMAL",
|
||||
bool rope_3d = false) {
|
||||
#define APPLY_KERNEL(TX, TC, TS) \
|
||||
return BlockAttnKernel<TX, TC, TS>(qkv, \
|
||||
key_cache, \
|
||||
|
||||
@@ -12,13 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/xpuml.h"
|
||||
#include <cstdlib>
|
||||
#include <fcntl.h>
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include <random>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
@@ -26,27 +21,31 @@
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <cstdlib>
|
||||
#include <random>
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/xpuml.h"
|
||||
|
||||
std::vector<paddle::Tensor> GetMaxMemDemand(int64_t device_id) {
|
||||
if (device_id == -1) {
|
||||
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
|
||||
}
|
||||
phi::XPUPlace place(device_id);
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
if (device_id == -1) {
|
||||
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
|
||||
}
|
||||
phi::XPUPlace place(device_id);
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
|
||||
paddle::Tensor max_mem_demand = paddle::zeros({1}, paddle::DataType::INT64);
|
||||
paddle::Tensor max_mem_demand = paddle::zeros({1}, paddle::DataType::INT64);
|
||||
|
||||
max_mem_demand.data<int64_t>()[0] =
|
||||
xpu_ctx->x_context()->_gm_mgr.get_max_mem_demand();
|
||||
return {max_mem_demand};
|
||||
max_mem_demand.data<int64_t>()[0] =
|
||||
xpu_ctx->x_context()->_gm_mgr.get_max_mem_demand();
|
||||
return {max_mem_demand};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> GetMaxMemDemandInferShape() { return {{1}}; }
|
||||
|
||||
std::vector<paddle::DataType> GetMaxMemDemandInferDtype() {
|
||||
return {paddle::DataType::INT64};
|
||||
return {paddle::DataType::INT64};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(xpu_get_context_gm_max_mem_demand)
|
||||
|
||||
@@ -12,13 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/xpuml.h"
|
||||
#include <cstdlib>
|
||||
#include <fcntl.h>
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include <random>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
@@ -26,30 +21,35 @@
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <cstdlib>
|
||||
#include <random>
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/xpuml.h"
|
||||
|
||||
std::vector<paddle::Tensor> GetFreeGlobalMemory(int64_t device_id) {
|
||||
if (device_id == -1) {
|
||||
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
|
||||
}
|
||||
if (device_id == -1) {
|
||||
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
|
||||
}
|
||||
|
||||
paddle::Tensor free_global_memory =
|
||||
paddle::zeros({1}, paddle::DataType::INT64);
|
||||
paddle::Tensor free_global_memory =
|
||||
paddle::zeros({1}, paddle::DataType::INT64);
|
||||
|
||||
xpumlDevice_t device_handle;
|
||||
xpumlInit();
|
||||
xpumlDeviceGetHandleByIndex(device_id, &device_handle);
|
||||
xpumlMemory_t device_memory;
|
||||
xpumlDeviceGetMemoryInfo(device_handle, &device_memory);
|
||||
free_global_memory.data<int64_t>()[0] = device_memory.freeGlobalMemory;
|
||||
return {free_global_memory};
|
||||
xpumlDevice_t device_handle;
|
||||
xpumlInit();
|
||||
xpumlDeviceGetHandleByIndex(device_id, &device_handle);
|
||||
xpumlMemory_t device_memory;
|
||||
xpumlDeviceGetMemoryInfo(device_handle, &device_memory);
|
||||
free_global_memory.data<int64_t>()[0] = device_memory.freeGlobalMemory;
|
||||
return {free_global_memory};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> GetFreeGlobalMemoryInferShape() {
|
||||
return {{1}};
|
||||
return {{1}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> GetFreeGlobalMemoryInferDtype() {
|
||||
return {paddle::DataType::INT64};
|
||||
return {paddle::DataType::INT64};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(xpu_get_free_global_memory)
|
||||
|
||||
@@ -12,13 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/xpuml.h"
|
||||
#include <cstdlib>
|
||||
#include <fcntl.h>
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include <random>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
@@ -26,29 +21,34 @@
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <cstdlib>
|
||||
#include <random>
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/xpuml.h"
|
||||
|
||||
std::vector<paddle::Tensor> GetTotalGlobalMemory(int64_t device_id) {
|
||||
if (device_id == -1) {
|
||||
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
|
||||
}
|
||||
if (device_id == -1) {
|
||||
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
|
||||
}
|
||||
|
||||
paddle::Tensor total_global_memory =
|
||||
paddle::zeros({1}, paddle::DataType::INT64);
|
||||
xpumlDevice_t device_handle;
|
||||
xpumlInit();
|
||||
xpumlDeviceGetHandleByIndex(device_id, &device_handle);
|
||||
xpumlMemory_t device_memory;
|
||||
xpumlDeviceGetMemoryInfo(device_handle, &device_memory);
|
||||
total_global_memory.data<int64_t>()[0] = device_memory.totalGlobalMemory;
|
||||
return {total_global_memory};
|
||||
paddle::Tensor total_global_memory =
|
||||
paddle::zeros({1}, paddle::DataType::INT64);
|
||||
xpumlDevice_t device_handle;
|
||||
xpumlInit();
|
||||
xpumlDeviceGetHandleByIndex(device_id, &device_handle);
|
||||
xpumlMemory_t device_memory;
|
||||
xpumlDeviceGetMemoryInfo(device_handle, &device_memory);
|
||||
total_global_memory.data<int64_t>()[0] = device_memory.totalGlobalMemory;
|
||||
return {total_global_memory};
|
||||
}
|
||||
|
||||
// Shape inference for xpu_get_total_global_memory: the op has a single output
// whose shape is always [1].
std::vector<std::vector<int64_t>> GetTotalGlobalMemoryInferShape() {
  return {{1}};
}
|
||||
|
||||
std::vector<paddle::DataType> GetTotalGlobalMemoryInferDtype() {
|
||||
return {paddle::DataType::INT64};
|
||||
return {paddle::DataType::INT64};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(xpu_get_total_global_memory)
|
||||
|
||||
@@ -12,13 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/xpuml.h"
|
||||
#include <cstdlib>
|
||||
#include <fcntl.h>
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include <random>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
@@ -26,29 +21,34 @@
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <cstdlib>
|
||||
#include <random>
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/xpuml.h"
|
||||
|
||||
std::vector<paddle::Tensor> GetUsedGlobalMemory(int64_t device_id) {
|
||||
if (device_id == -1) {
|
||||
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
|
||||
}
|
||||
if (device_id == -1) {
|
||||
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
|
||||
}
|
||||
|
||||
paddle::Tensor used_global_memory =
|
||||
paddle::zeros({1}, paddle::DataType::INT64);
|
||||
xpumlDevice_t device_handle;
|
||||
xpumlInit();
|
||||
xpumlDeviceGetHandleByIndex(device_id, &device_handle);
|
||||
xpumlMemory_t device_memory;
|
||||
xpumlDeviceGetMemoryInfo(device_handle, &device_memory);
|
||||
used_global_memory.data<int64_t>()[0] = device_memory.usedGlobalMemory;
|
||||
return {used_global_memory};
|
||||
paddle::Tensor used_global_memory =
|
||||
paddle::zeros({1}, paddle::DataType::INT64);
|
||||
xpumlDevice_t device_handle;
|
||||
xpumlInit();
|
||||
xpumlDeviceGetHandleByIndex(device_id, &device_handle);
|
||||
xpumlMemory_t device_memory;
|
||||
xpumlDeviceGetMemoryInfo(device_handle, &device_memory);
|
||||
used_global_memory.data<int64_t>()[0] = device_memory.usedGlobalMemory;
|
||||
return {used_global_memory};
|
||||
}
|
||||
|
||||
// Shape inference for xpu_get_used_global_memory: the op has a single output
// whose shape is always [1].
std::vector<std::vector<int64_t>> GetUsedGlobalMemoryInferShape() {
  return {{1}};
}
|
||||
|
||||
std::vector<paddle::DataType> GetUsedGlobalMemoryInferDtype() {
|
||||
return {paddle::DataType::INT64};
|
||||
return {paddle::DataType::INT64};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(xpu_get_used_global_memory)
|
||||
|
||||
@@ -12,52 +12,57 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
std::vector<paddle::Tensor>
|
||||
GatherNextToken(const paddle::Tensor &tmp_out, // [token_num, dim_embed]
|
||||
const paddle::Tensor &cum_offsets, // [bsz, 1]
|
||||
const paddle::Tensor &encoder_seq_lod,
|
||||
const paddle::Tensor &encoder_batch_map,
|
||||
const paddle::Tensor &decoder_batch_map,
|
||||
const paddle::Tensor &encoder_seq_lod_cpu,
|
||||
const paddle::Tensor &encoder_batch_map_cpu,
|
||||
const paddle::Tensor &decoder_batch_map_cpu,
|
||||
const paddle::Tensor &enc_batch_tensor,
|
||||
const paddle::Tensor &dec_batch_tensor,
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
int max_input_length) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
using XPUType =
|
||||
typename XPUTypeTrait<bfloat16>::Type; // only support bfloat16
|
||||
typedef paddle::bfloat16 data_t;
|
||||
const int dim = tmp_out.dims()[1];
|
||||
const int bsz = cum_offsets.shape()[0];
|
||||
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
|
||||
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
|
||||
std::vector<paddle::Tensor> GatherNextToken(
|
||||
const paddle::Tensor &tmp_out, // [token_num, dim_embed]
|
||||
const paddle::Tensor &cum_offsets, // [bsz, 1]
|
||||
const paddle::Tensor &encoder_seq_lod,
|
||||
const paddle::Tensor &encoder_batch_map,
|
||||
const paddle::Tensor &decoder_batch_map,
|
||||
const paddle::Tensor &encoder_seq_lod_cpu,
|
||||
const paddle::Tensor &encoder_batch_map_cpu,
|
||||
const paddle::Tensor &decoder_batch_map_cpu,
|
||||
const paddle::Tensor &enc_batch_tensor,
|
||||
const paddle::Tensor &dec_batch_tensor,
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
int max_input_length) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
using XPUType =
|
||||
typename XPUTypeTrait<bfloat16>::Type; // only support bfloat16
|
||||
typedef paddle::bfloat16 data_t;
|
||||
const int dim = tmp_out.dims()[1];
|
||||
const int bsz = cum_offsets.shape()[0];
|
||||
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
|
||||
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
|
||||
|
||||
baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
|
||||
const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
|
||||
enc_batch + 1, const_cast<int32_t *>(encoder_seq_lod.data<int32_t>())};
|
||||
baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
|
||||
const_cast<int32_t *>(encoder_batch_map_cpu.data<int32_t>()), enc_batch,
|
||||
const_cast<int32_t *>(encoder_batch_map.data<int32_t>())};
|
||||
baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp{
|
||||
const_cast<int32_t *>(decoder_batch_map_cpu.data<int32_t>()), dec_batch,
|
||||
const_cast<int32_t *>(decoder_batch_map.data<int32_t>())};
|
||||
baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
|
||||
const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
|
||||
enc_batch + 1,
|
||||
const_cast<int32_t *>(encoder_seq_lod.data<int32_t>())};
|
||||
baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
|
||||
const_cast<int32_t *>(encoder_batch_map_cpu.data<int32_t>()),
|
||||
enc_batch,
|
||||
const_cast<int32_t *>(encoder_batch_map.data<int32_t>())};
|
||||
baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp{
|
||||
const_cast<int32_t *>(decoder_batch_map_cpu.data<int32_t>()),
|
||||
dec_batch,
|
||||
const_cast<int32_t *>(decoder_batch_map.data<int32_t>())};
|
||||
|
||||
auto out = paddle::full({bsz, dim}, -2, tmp_out.type(), tmp_out.place());
|
||||
auto out = paddle::full({bsz, dim}, -2, tmp_out.type(), tmp_out.place());
|
||||
|
||||
int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPUType *>(tmp_out.data<data_t>()),
|
||||
reinterpret_cast<XPUType *>(out.data<data_t>()), encoder_seqs_lods_vp,
|
||||
encoder_batch_map_vp, decoder_batch_map_vp, dim);
|
||||
return {out};
|
||||
int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPUType *>(tmp_out.data<data_t>()),
|
||||
reinterpret_cast<XPUType *>(out.data<data_t>()),
|
||||
encoder_seqs_lods_vp,
|
||||
encoder_batch_map_vp,
|
||||
decoder_batch_map_vp,
|
||||
dim);
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
|
||||
@@ -72,12 +77,12 @@ std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
|
||||
const std::vector<int64_t> &enc_batch_tensor_shape,
|
||||
const std::vector<int64_t> &dec_batch_tensor_shape,
|
||||
const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
|
||||
if (output_padding_offset_shape) {
|
||||
PD_THROW("speculative decoding is not supported in XPU.");
|
||||
}
|
||||
int64_t bsz = cum_offsets_shape[0];
|
||||
int64_t dim_embed = tmp_out_shape[1];
|
||||
return {{bsz, dim_embed}};
|
||||
if (output_padding_offset_shape) {
|
||||
PD_THROW("speculative decoding is not supported in XPU.");
|
||||
}
|
||||
int64_t bsz = cum_offsets_shape[0];
|
||||
int64_t dim_embed = tmp_out_shape[1];
|
||||
return {{bsz, dim_embed}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> GatherNextTokenInferDtype(
|
||||
@@ -92,14 +97,20 @@ std::vector<paddle::DataType> GatherNextTokenInferDtype(
|
||||
const paddle::DataType &enc_batch_tensor_dtype,
|
||||
const paddle::DataType &dec_batch_tensor_dtype,
|
||||
const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
|
||||
return {tmp_out_dtype};
|
||||
return {tmp_out_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(gather_next_token)
|
||||
.Inputs({"tmp_out", "cum_offsets", "encoder_seq_lod", "encoder_batch_map",
|
||||
"decoder_batch_map", "encoder_seq_lod_cpu",
|
||||
"encoder_batch_map_cpu", "decoder_batch_map_cpu",
|
||||
"enc_batch_tensor", "dec_batch_tensor",
|
||||
.Inputs({"tmp_out",
|
||||
"cum_offsets",
|
||||
"encoder_seq_lod",
|
||||
"encoder_batch_map",
|
||||
"decoder_batch_map",
|
||||
"encoder_seq_lod_cpu",
|
||||
"encoder_batch_map_cpu",
|
||||
"decoder_batch_map_cpu",
|
||||
"enc_batch_tensor",
|
||||
"dec_batch_tensor",
|
||||
paddle::Optional("output_padding_offset")})
|
||||
.Outputs({"out"})
|
||||
.Attrs({"max_input_length: int"})
|
||||
|
||||
@@ -14,43 +14,48 @@
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
// Splits a token sequence into alternating runs of non-image tokens and image
// patch tokens and records where every run ends.
//
// task_input_ids: INT64 token ids, read on the host (all tensors are expected
//                 to live on the CPU).
// grid_thw:       INT64 values, three per image. For image k the values at
//                 offsets 3*k+1 and 3*k+2 are multiplied and divided by 4 to
//                 obtain the number of tokens that image occupies.
//                 NOTE(review): presumably (t, h, w) per image - confirm.
// image_patch_id: token id that marks the start of an image run.
//
// Returns one INT64 CPU tensor with 2 * N elements (shape {2, N}): the first
// N elements are the run boundaries (starting with 0), the next N elements
// are the number of images consumed up to each boundary.
std::vector<paddle::Tensor> GetImgBoundaries(
    const paddle::Tensor& task_input_ids,
    const paddle::Tensor& grid_thw,
    const int64_t image_patch_id) {
  // All tensor in cpu
  auto input_ids_ptr = task_input_ids.data<int64_t>();
  const int64_t seq_lens_origin = task_input_ids.numel();
  auto grid_thw_ptr = grid_thw.data<int64_t>();
  const int64_t grid_thw_numel = grid_thw.numel();

  const int token_times = 4;
  int token_idx = 0;
  int image_idx = 0;
  std::vector<int> img_boundaries, img_nums;
  img_boundaries.emplace_back(0);
  img_nums.emplace_back(0);
  while (token_idx < seq_lens_origin) {
    if (input_ids_ptr[token_idx] != image_patch_id) {
      // Skip the whole run of non-image tokens.
      do {
        token_idx++;
      } while (token_idx < seq_lens_origin &&
               input_ids_ptr[token_idx] != image_patch_id);
    } else {
      // Every image run needs its own grid_thw entry; without this check the
      // reads below would run past the end of grid_thw.
      PD_CHECK(static_cast<int64_t>(image_idx) * 3 + 2 < grid_thw_numel,
               "get_img_boundaries: more image runs in input_ids than "
               "entries in grid_thw.");
      int cur_image_token_len =
          (grid_thw_ptr[image_idx * 3 + 1] * grid_thw_ptr[image_idx * 3 + 2]) /
          token_times;
      // A non-positive length would leave token_idx where it is and the loop
      // would never terminate.
      PD_CHECK(cur_image_token_len > 0,
               "get_img_boundaries: image token length must be positive.");
      image_idx++;
      token_idx += cur_image_token_len;
    }
    img_boundaries.emplace_back(token_idx);
    img_nums.emplace_back(image_idx);
  }

  int64_t num_img_boundaries = static_cast<int64_t>(img_boundaries.size());
  auto out = paddle::full(
      {2, num_img_boundaries}, 0, paddle::DataType::INT64, paddle::CPUPlace());

  // Fetch the output pointer once instead of on every iteration.
  int64_t* out_ptr = out.data<int64_t>();
  for (int64_t i = 0; i < num_img_boundaries; i++) {
    out_ptr[i] = img_boundaries[i];
    out_ptr[num_img_boundaries + i] = img_nums[i];
  }

  return {out};
}
|
||||
|
||||
PD_BUILD_OP(get_img_boundaries)
|
||||
|
||||
@@ -12,15 +12,15 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <sys/ipc.h>
|
||||
#include <sys/msg.h>
|
||||
#include <sys/types.h>
|
||||
#include "msg_utils.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
void GetOutputKVSignal(const paddle::Tensor& x,
|
||||
void GetOutputKVSignal(const paddle::Tensor &x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag) {
|
||||
int msg_queue_id = 1024 + rank_id;
|
||||
@@ -28,7 +28,7 @@ void GetOutputKVSignal(const paddle::Tensor& x,
|
||||
static key_t key = ftok("/opt/", msg_queue_id);
|
||||
static int msgid = msgget(key, IPC_CREAT | 0666);
|
||||
|
||||
int* out_data = const_cast<int*>(x.data<int>());
|
||||
int *out_data = const_cast<int *>(x.data<int>());
|
||||
int ret = -1;
|
||||
if (!wait_flag) {
|
||||
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
|
||||
@@ -48,69 +48,72 @@ void GetOutputKVSignal(const paddle::Tensor& x,
|
||||
return;
|
||||
}
|
||||
|
||||
// Receives one message from a System V message queue and copies its payload
// into x.
//
// x:            INT64 tensor that is written in place (const is cast away).
// rank_id:      calls with rank_id > 0 return immediately without touching x.
// wait_flag:    false -> msgrcv is called with IPC_NOWAIT;
//               true  -> msgrcv is called with flags 0.
// msg_queue_id: project id passed to ftok("/dev/shm", ...). It is replaced by
//               the value of the INFERENCE_MSG_QUEUE_ID environment variable
//               whenever that variable is set.
//
// When msgrcv fails, x[0] is set to -2 and x[1] to 0. Otherwise mtext[1] of
// the received message is taken as the batch size bsz and the first bsz + 2
// entries of mtext are copied into x.
//
// NOTE: key and msgid are function-local statics. They are computed by the
// first call that reaches them and reused afterwards, so later calls keep
// using that queue whatever msg_queue_id they pass.
void GetOutput(const paddle::Tensor &x,
               int64_t rank_id,
               bool wait_flag,
               int msg_queue_id) {
  if (rank_id > 0) {
    return;
  }
  static struct msgdata msg_rcv;
  if (const char *inference_msg_queue_id_env_p =
          std::getenv("INFERENCE_MSG_QUEUE_ID")) {
    std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
    int inference_msg_queue_id_from_env =
        std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
    std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
              << inference_msg_queue_id_from_env << std::endl;
#endif
    msg_queue_id = inference_msg_queue_id_from_env;
  }
  static key_t key = ftok("/dev/shm", msg_queue_id);
  static int msgid = msgget(key, IPC_CREAT | 0666);

#ifdef GET_OUTPUT_DEBUG
  std::cout << "get_output msg_queue_id: " << msg_queue_id << std::endl;
  std::cout << "get_output key: " << key << std::endl;
  std::cout << "get_output msgid: " << msgid << std::endl;
  std::cout << "get_output wait_flag: " << wait_flag << std::endl;
#endif

  int64_t *out_data = const_cast<int64_t *>(x.data<int64_t>());
  int ret = -1;
  if (!wait_flag) {
    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
  } else {
    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
  }

#ifdef GET_OUTPUT_DEBUG
  std::cout << "get_output finish msgrcv" << std::endl;
#endif
  if (ret == -1) {
    out_data[0] = -2;
    out_data[1] = 0;
    return;
  }
  int bsz = msg_rcv.mtext[1];
  // The receive size above is (MAX_BSZ + 2) * 4 bytes, so a batch size above
  // MAX_BSZ would index past the received payload; one above x's capacity
  // would write past the output buffer.
  PD_CHECK(bsz <= MAX_BSZ,
           "get_output: batch size in message exceeds MAX_BSZ.");
  PD_CHECK(static_cast<int64_t>(bsz) + 2 <= x.numel(),
           "get_output: output tensor is too small for the received batch.");

  for (int64_t i = 0; i < bsz + 2; i++) {
    out_data[i] = static_cast<int64_t>(msg_rcv.mtext[i]);
  }
#ifdef GET_OUTPUT_DEBUG
  std::cout << "get_output finished: " << msgid << std::endl;
#endif
}
|
||||
|
||||
// Variant of GetOutput that always passes message queue id 1.
void GetOutputStatic(const paddle::Tensor &x, int64_t rank_id, bool wait_flag) {
  GetOutput(x, rank_id, wait_flag, 1);
}
|
||||
|
||||
// Variant of GetOutput that forwards a caller-supplied message queue id.
void GetOutputDynamic(const paddle::Tensor &x,
                      int64_t rank_id,
                      bool wait_flag,
                      int msg_queue_id) {
  GetOutput(x, rank_id, wait_flag, msg_queue_id);
}
|
||||
|
||||
PD_BUILD_OP(get_output)
|
||||
|
||||
@@ -20,44 +20,43 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &token_num,
|
||||
const paddle::Tensor &seq_len) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
|
||||
std::vector<int64_t> input_ids_shape = input_ids.shape();
|
||||
const int bsz = seq_len.shape()[0];
|
||||
const int seq_length = input_ids_shape[1];
|
||||
auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
|
||||
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
|
||||
std::vector<int64_t> input_ids_shape = input_ids.shape();
|
||||
const int bsz = seq_len.shape()[0];
|
||||
const int seq_length = input_ids_shape[1];
|
||||
auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
|
||||
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
|
||||
|
||||
const int token_num_data = cpu_token_num.data<int64_t>()[0];
|
||||
auto x_remove_padding = paddle::full(
|
||||
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
|
||||
auto batch_id_per_token = paddle::full(
|
||||
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_q =
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_k =
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
int r = baidu::xpu::api::plugin::get_padding_offset(
|
||||
xpu_ctx->x_context(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cum_offsets_out.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
x_remove_padding.data<int64_t>(),
|
||||
input_ids.data<int64_t>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_len.data<int>(),
|
||||
seq_length,
|
||||
bsz);
|
||||
PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed.");
|
||||
return {x_remove_padding,
|
||||
cum_offsets_out,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k};
|
||||
const int token_num_data = cpu_token_num.data<int64_t>()[0];
|
||||
auto x_remove_padding = paddle::full(
|
||||
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
|
||||
auto batch_id_per_token = paddle::full(
|
||||
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_q =
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_k =
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
int r = baidu::xpu::api::plugin::get_padding_offset(
|
||||
xpu_ctx->x_context(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cum_offsets_out.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
x_remove_padding.data<int64_t>(),
|
||||
input_ids.data<int64_t>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_len.data<int>(),
|
||||
seq_length,
|
||||
bsz);
|
||||
PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed.");
|
||||
return {x_remove_padding,
|
||||
cum_offsets_out,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
|
||||
@@ -65,9 +64,9 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
|
||||
const std::vector<int64_t> &cum_offsets_shape,
|
||||
const std::vector<int64_t> &token_num_shape,
|
||||
const std::vector<int64_t> &seq_len_shape) {
|
||||
int64_t bsz = seq_len_shape[0];
|
||||
int64_t seq_len = input_ids_shape[1];
|
||||
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
|
||||
int64_t bsz = seq_len_shape[0];
|
||||
int64_t seq_len = input_ids_shape[1];
|
||||
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
||||
@@ -75,11 +74,11 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
||||
const paddle::DataType &cum_offsets_dtype,
|
||||
const paddle::DataType &token_num_dtype,
|
||||
const paddle::DataType &seq_len_dtype) {
|
||||
return {input_ids_dtype,
|
||||
seq_len_dtype,
|
||||
seq_len_dtype,
|
||||
seq_len_dtype,
|
||||
seq_len_dtype};
|
||||
return {input_ids_dtype,
|
||||
seq_len_dtype,
|
||||
seq_len_dtype,
|
||||
seq_len_dtype,
|
||||
seq_len_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(get_padding_offset)
|
||||
|
||||
@@ -12,70 +12,97 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/enforce.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
|
||||
void TokenPenaltyMultiScores(
|
||||
const paddle::Tensor &pre_ids, const paddle::Tensor &logits,
|
||||
const paddle::Tensor &penalty_scores,
|
||||
const paddle::Tensor &frequency_scores,
|
||||
const paddle::Tensor &presence_scores, const paddle::Tensor &temperatures,
|
||||
const paddle::Tensor &bad_tokens, const paddle::Tensor &cur_len,
|
||||
const paddle::Tensor &min_len, const paddle::Tensor &eos_token_id) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
int64_t bs = logits.shape()[0];
|
||||
PADDLE_ENFORCE_LE(
|
||||
bs, 640,
|
||||
phi::errors::InvalidArgument(
|
||||
"Only support bsz <= 1024, but received bsz is %d", bs));
|
||||
int64_t length = logits.shape()[1];
|
||||
int64_t length_id = pre_ids.shape()[1];
|
||||
int64_t length_bad_words = bad_tokens.shape()[0];
|
||||
int64_t end_length = eos_token_id.shape()[0];
|
||||
switch (logits.type()) {
|
||||
void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &logits,
|
||||
const paddle::Tensor &penalty_scores,
|
||||
const paddle::Tensor &frequency_scores,
|
||||
const paddle::Tensor &presence_scores,
|
||||
const paddle::Tensor &temperatures,
|
||||
const paddle::Tensor &bad_tokens,
|
||||
const paddle::Tensor &cur_len,
|
||||
const paddle::Tensor &min_len,
|
||||
const paddle::Tensor &eos_token_id) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
int64_t bs = logits.shape()[0];
|
||||
PADDLE_ENFORCE_LE(
|
||||
bs,
|
||||
640,
|
||||
phi::errors::InvalidArgument(
|
||||
"Only support bsz <= 1024, but received bsz is %d", bs));
|
||||
int64_t length = logits.shape()[1];
|
||||
int64_t length_id = pre_ids.shape()[1];
|
||||
int64_t length_bad_words = bad_tokens.shape()[0];
|
||||
int64_t end_length = eos_token_id.shape()[0];
|
||||
switch (logits.type()) {
|
||||
case paddle::DataType::FLOAT16: {
|
||||
using XPUType = typename XPUTypeTrait<float16>::Type;
|
||||
typedef paddle::float16 data_t;
|
||||
int r = baidu::xpu::api::plugin::token_penalty_multi_scores(
|
||||
xpu_ctx->x_context(), pre_ids.data<int64_t>(),
|
||||
reinterpret_cast<XPUType *>(
|
||||
const_cast<data_t *>(logits.data<data_t>())),
|
||||
reinterpret_cast<const XPUType *>(penalty_scores.data<data_t>()),
|
||||
reinterpret_cast<const XPUType *>(frequency_scores.data<data_t>()),
|
||||
reinterpret_cast<const XPUType *>(presence_scores.data<data_t>()),
|
||||
temperatures.data<float>(), cur_len.data<int64_t>(),
|
||||
min_len.data<int64_t>(), eos_token_id.data<int64_t>(),
|
||||
bad_tokens.data<int64_t>(), bs, length, length_id, end_length,
|
||||
length_bad_words);
|
||||
PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
|
||||
using XPUType = typename XPUTypeTrait<float16>::Type;
|
||||
typedef paddle::float16 data_t;
|
||||
int r = baidu::xpu::api::plugin::token_penalty_multi_scores(
|
||||
xpu_ctx->x_context(),
|
||||
pre_ids.data<int64_t>(),
|
||||
reinterpret_cast<XPUType *>(
|
||||
const_cast<data_t *>(logits.data<data_t>())),
|
||||
reinterpret_cast<const XPUType *>(penalty_scores.data<data_t>()),
|
||||
reinterpret_cast<const XPUType *>(frequency_scores.data<data_t>()),
|
||||
reinterpret_cast<const XPUType *>(presence_scores.data<data_t>()),
|
||||
temperatures.data<float>(),
|
||||
cur_len.data<int64_t>(),
|
||||
min_len.data<int64_t>(),
|
||||
eos_token_id.data<int64_t>(),
|
||||
bad_tokens.data<int64_t>(),
|
||||
bs,
|
||||
length,
|
||||
length_id,
|
||||
end_length,
|
||||
length_bad_words);
|
||||
PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
|
||||
} break;
|
||||
case paddle::DataType::FLOAT32: {
|
||||
int r = baidu::xpu::api::plugin::token_penalty_multi_scores(
|
||||
xpu_ctx->x_context(), pre_ids.data<int64_t>(),
|
||||
const_cast<float *>(logits.data<float>()),
|
||||
penalty_scores.data<float>(), frequency_scores.data<float>(),
|
||||
presence_scores.data<float>(), temperatures.data<float>(),
|
||||
cur_len.data<int64_t>(), min_len.data<int64_t>(),
|
||||
eos_token_id.data<int64_t>(), bad_tokens.data<int64_t>(), bs,
|
||||
length, length_id, end_length, length_bad_words);
|
||||
PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
|
||||
int r = baidu::xpu::api::plugin::token_penalty_multi_scores(
|
||||
xpu_ctx->x_context(),
|
||||
pre_ids.data<int64_t>(),
|
||||
const_cast<float *>(logits.data<float>()),
|
||||
penalty_scores.data<float>(),
|
||||
frequency_scores.data<float>(),
|
||||
presence_scores.data<float>(),
|
||||
temperatures.data<float>(),
|
||||
cur_len.data<int64_t>(),
|
||||
min_len.data<int64_t>(),
|
||||
eos_token_id.data<int64_t>(),
|
||||
bad_tokens.data<int64_t>(),
|
||||
bs,
|
||||
length,
|
||||
length_id,
|
||||
end_length,
|
||||
length_bad_words);
|
||||
PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
|
||||
} break;
|
||||
default:
|
||||
PD_THROW("NOT supported data type. "
|
||||
"Only float16 and float32 are supported. ");
|
||||
break;
|
||||
}
|
||||
PD_THROW(
|
||||
"NOT supported data type. "
|
||||
"Only float16 and float32 are supported. ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Registers the custom op get_token_penalty_multi_scores. "logits" is updated
// in place and exposed as the output "logits_out".
PD_BUILD_OP(get_token_penalty_multi_scores)
    .Inputs({"pre_ids",
             "logits",
             "penalty_scores",
             "frequency_scores",
             "presence_scores",
             "temperatures",
             "bad_tokens",
             "cur_len",
             "min_len",
             "eos_token_id"})
    .Outputs({"logits_out"})
    .SetInplaceMap({{"logits", "logits_out"}})
    .SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));
|
||||
|
||||
@@ -72,7 +72,7 @@ void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
|
||||
is_padding_input ? token_num_info : nullptr,
|
||||
expert_num,
|
||||
1, // moe_topk
|
||||
0, // group_size
|
||||
0, // group_size
|
||||
ffn1_out_shape.size() == 2 ? xftblock::MoeFCInputMode::DENSE
|
||||
: xftblock::MoeFCInputMode::SPARSE);
|
||||
PD_CHECK(ret == 0);
|
||||
|
||||
@@ -29,210 +29,246 @@
|
||||
namespace xftblock = baidu::xpu::xftblock;
|
||||
namespace api = baidu::xpu::api;
|
||||
|
||||
// Maps an (activation type TX, weight type TW) pair to the GEMM type used for
// it. The primary template uses the weight type itself; TX only matters for
// the specialisations.
template <typename TX, typename TW>
struct fused_moe_ffn_trait {
  using GEMM_TYPE = TW;
};
|
||||
template <> struct fused_moe_ffn_trait<bfloat16, bfloat16> {
|
||||
using GEMM_TYPE = float;
|
||||
template <>
|
||||
struct fused_moe_ffn_trait<bfloat16, bfloat16> {
|
||||
using GEMM_TYPE = float;
|
||||
};
|
||||
template <> struct fused_moe_ffn_trait<bfloat16, int8_t> {
|
||||
using GEMM_TYPE = float;
|
||||
template <>
|
||||
struct fused_moe_ffn_trait<bfloat16, int8_t> {
|
||||
using GEMM_TYPE = float;
|
||||
};
|
||||
template <> struct fused_moe_ffn_trait<bfloat16, int4_t> {
|
||||
using GEMM_TYPE = int4_wo_int15;
|
||||
template <>
|
||||
struct fused_moe_ffn_trait<bfloat16, int4_t> {
|
||||
using GEMM_TYPE = int4_wo_int15;
|
||||
};
|
||||
|
||||
template <typename TX, typename TW>
|
||||
std::vector<paddle::Tensor> MoeLayerKernel(
|
||||
const paddle::Tensor &x, const paddle::Tensor &gate_weight,
|
||||
const paddle::Tensor &x,
|
||||
const paddle::Tensor &gate_weight,
|
||||
const paddle::optional<paddle::Tensor> &gate_correction_bias,
|
||||
const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight,
|
||||
const paddle::Tensor &up_gate_proj_weight,
|
||||
const paddle::Tensor &down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor> &up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_bias,
|
||||
const paddle::optional<paddle::Tensor> &up_gate_proj_weight_scale,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_weight_scale,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_in_scale, // not support
|
||||
const std::string &quant_method, const int moe_top_k,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_in_scale, // not support
|
||||
const std::string &quant_method,
|
||||
const int moe_top_k,
|
||||
const bool moe_group) {
|
||||
// std::cout << "[Op Debug] enter moe layer" << std::endl;
|
||||
using XPU_TX = typename XPUTypeTrait<TX>::Type;
|
||||
using XPU_TW = typename XPUTypeTrait<TW>::Type;
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
|
||||
auto rt_guard = xctx.get_rt_guard();
|
||||
// std::cout << "[Op Debug] enter moe layer" << std::endl;
|
||||
using XPU_TX = typename XPUTypeTrait<TX>::Type;
|
||||
using XPU_TW = typename XPUTypeTrait<TW>::Type;
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
|
||||
auto rt_guard = xctx.get_rt_guard();
|
||||
|
||||
const auto xtype = x.dtype();
|
||||
auto x_dims = x.shape();
|
||||
auto up_gate_proj_dims = up_gate_proj_weight.shape();
|
||||
PD_CHECK(x_dims.size() == 2, "x_dims.size() should be 2.");
|
||||
PD_CHECK(up_gate_proj_dims.size() == 3,
|
||||
"up_gate_proj_dims.size() should be 3.");
|
||||
PD_CHECK(down_proj_in_scale.get_ptr() == nullptr,
|
||||
"down_proj_in_scale not support.");
|
||||
if (quant_method == "weight_only_int4") {
|
||||
PD_CHECK(x_dims[1] == up_gate_proj_dims[2] * 2,
|
||||
"x_dims[1] should equal to up_gate_proj_dims[2], (weight must be "
|
||||
"[e,n,k]).");
|
||||
} else {
|
||||
PD_CHECK(x_dims[1] == up_gate_proj_dims[2],
|
||||
"x_dims[1] should equal to up_gate_proj_dims[2], (weight must be "
|
||||
"[e,n,k]).");
|
||||
}
|
||||
|
||||
int token_num = x_dims[0];
|
||||
int hidden_dim = x_dims[1];
|
||||
int expert_num = up_gate_proj_dims[0];
|
||||
int inter_dim = up_gate_proj_dims[1];
|
||||
int outer_dim = inter_dim / 2;
|
||||
|
||||
paddle::Tensor fused_moe_out = paddle::empty_like(x);
|
||||
|
||||
auto x_mpart_shape = x_dims;
|
||||
int MPART_SIZE = 2048;
|
||||
if (const char *env_val = std::getenv("XPU_MPART_SIZE")) {
|
||||
MPART_SIZE = std::atoi(env_val);
|
||||
}
|
||||
int bsz = x_dims[0];
|
||||
for (int m_part_start = 0; m_part_start < bsz; m_part_start += MPART_SIZE) {
|
||||
auto m_part_end = std::min(m_part_start + MPART_SIZE, bsz);
|
||||
auto x_offset = m_part_start * hidden_dim;
|
||||
x_mpart_shape[0] = m_part_end - m_part_start;
|
||||
int ret = -1;
|
||||
auto xftblock_tx = xftblock::DataTypeToEnum<XPU_TX>::value;
|
||||
auto xftblock_tw = xftblock::DataTypeToEnum<XPU_TW>::value;
|
||||
// input + output
|
||||
xftblock::Tensor xin(
|
||||
const_cast<TX *>(x.data<TX>() + x_offset), xftblock_tx, x_mpart_shape);
|
||||
|
||||
xftblock::Tensor xout(fused_moe_out.mutable_data<TX>() + x_offset,
|
||||
xftblock_tx,
|
||||
x_mpart_shape);
|
||||
// gate
|
||||
xftblock::Tensor xgate_w(const_cast<float *>(gate_weight.data<float>()),
|
||||
xftblock::DataType::DT_FLOAT,
|
||||
gate_weight.shape());
|
||||
std::shared_ptr<xftblock::Tensor> xgate_correct_bias;
|
||||
if (gate_correction_bias.get_ptr()) {
|
||||
xgate_correct_bias = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<float *>(gate_correction_bias.get_ptr()->data<float>()),
|
||||
xftblock::DataType::DT_FLOAT,
|
||||
gate_correction_bias.get_ptr()->shape());
|
||||
}
|
||||
|
||||
// up_gate_proj + down_proj
|
||||
std::shared_ptr<xftblock::Tensor> xup_gate_proj_w, xdown_proj_w;
|
||||
|
||||
if (std::is_same<TW, int4_t>::value) {
|
||||
xup_gate_proj_w = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<int8_t *>(up_gate_proj_weight.data<int8_t>()),
|
||||
nullptr,
|
||||
const_cast<float *>(
|
||||
up_gate_proj_weight_scale.get_ptr()
|
||||
? up_gate_proj_weight_scale.get_ptr()->data<float>()
|
||||
: nullptr),
|
||||
xftblock_tw,
|
||||
std::vector<int64_t>{expert_num, inter_dim, hidden_dim});
|
||||
|
||||
xdown_proj_w = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<int8_t *>(down_proj_weight.data<int8_t>()),
|
||||
nullptr,
|
||||
const_cast<float *>(
|
||||
down_proj_weight_scale.get_ptr()
|
||||
? down_proj_weight_scale.get_ptr()->data<float>()
|
||||
: nullptr),
|
||||
xftblock_tw,
|
||||
std::vector<int64_t>{expert_num, hidden_dim, outer_dim});
|
||||
|
||||
const auto xtype = x.dtype();
|
||||
auto x_dims = x.shape();
|
||||
auto up_gate_proj_dims = up_gate_proj_weight.shape();
|
||||
PD_CHECK(x_dims.size() == 2, "x_dims.size() should be 2.");
|
||||
PD_CHECK(up_gate_proj_dims.size() == 3, "up_gate_proj_dims.size() should be 3.");
|
||||
PD_CHECK(down_proj_in_scale.get_ptr() == nullptr, "down_proj_in_scale not support.");
|
||||
if (quant_method == "weight_only_int4") {
|
||||
PD_CHECK(x_dims[1] == up_gate_proj_dims[2] * 2,
|
||||
"x_dims[1] should equal to up_gate_proj_dims[2], (weight must be "
|
||||
"[e,n,k]).");
|
||||
} else {
|
||||
PD_CHECK(x_dims[1] == up_gate_proj_dims[2],
|
||||
"x_dims[1] should equal to up_gate_proj_dims[2], (weight must be "
|
||||
"[e,n,k]).");
|
||||
xup_gate_proj_w = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<TW *>(up_gate_proj_weight.data<TW>()),
|
||||
nullptr,
|
||||
const_cast<float *>(
|
||||
up_gate_proj_weight_scale.get_ptr()
|
||||
? up_gate_proj_weight_scale.get_ptr()->data<float>()
|
||||
: nullptr),
|
||||
xftblock_tw,
|
||||
std::vector<int64_t>{expert_num, inter_dim, hidden_dim});
|
||||
|
||||
xdown_proj_w = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<TW *>(down_proj_weight.data<TW>()),
|
||||
nullptr,
|
||||
const_cast<float *>(
|
||||
down_proj_weight_scale.get_ptr()
|
||||
? down_proj_weight_scale.get_ptr()->data<float>()
|
||||
: nullptr),
|
||||
xftblock_tw,
|
||||
std::vector<int64_t>{expert_num, hidden_dim, outer_dim});
|
||||
}
|
||||
|
||||
int token_num = x_dims[0];
|
||||
int hidden_dim = x_dims[1];
|
||||
int expert_num = up_gate_proj_dims[0];
|
||||
int inter_dim = up_gate_proj_dims[1];
|
||||
int outer_dim = inter_dim / 2;
|
||||
|
||||
paddle::Tensor fused_moe_out = paddle::empty_like(x);
|
||||
|
||||
auto x_mpart_shape = x_dims;
|
||||
int MPART_SIZE = 2048;
|
||||
if (const char* env_val = std::getenv("XPU_MPART_SIZE")) {
|
||||
MPART_SIZE = std::atoi(env_val);
|
||||
std::shared_ptr<xftblock::Tensor> xup_gate_proj_bias;
|
||||
std::shared_ptr<xftblock::Tensor> xdown_proj_bias;
|
||||
if (up_gate_proj_bias.get_ptr()) {
|
||||
xup_gate_proj_bias = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<float *>(up_gate_proj_bias.get_ptr()->data<float>()),
|
||||
xftblock::DataType::DT_FLOAT,
|
||||
up_gate_proj_bias.get_ptr()->shape());
|
||||
}
|
||||
int bsz = x_dims[0];
|
||||
for (int m_part_start = 0; m_part_start < bsz; m_part_start += MPART_SIZE) {
|
||||
auto m_part_end = std::min(m_part_start + MPART_SIZE, bsz);
|
||||
auto x_offset = m_part_start * hidden_dim;
|
||||
x_mpart_shape[0] = m_part_end - m_part_start;
|
||||
int ret = -1;
|
||||
auto xftblock_tx = xftblock::DataTypeToEnum<XPU_TX>::value;
|
||||
auto xftblock_tw = xftblock::DataTypeToEnum<XPU_TW>::value;
|
||||
// input + output
|
||||
xftblock::Tensor xin(const_cast<TX *>(x.data<TX>() + x_offset), xftblock_tx,
|
||||
x_mpart_shape);
|
||||
|
||||
xftblock::Tensor xout(fused_moe_out.mutable_data<TX>() + x_offset, xftblock_tx,
|
||||
x_mpart_shape);
|
||||
// gate
|
||||
xftblock::Tensor xgate_w(const_cast<float *>(gate_weight.data<float>()),
|
||||
xftblock::DataType::DT_FLOAT, gate_weight.shape());
|
||||
std::shared_ptr<xftblock::Tensor> xgate_correct_bias;
|
||||
if (gate_correction_bias.get_ptr()) {
|
||||
xgate_correct_bias = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<float *>(gate_correction_bias.get_ptr()->data<float>()),
|
||||
xftblock::DataType::DT_FLOAT,
|
||||
gate_correction_bias.get_ptr()->shape());
|
||||
}
|
||||
|
||||
// up_gate_proj + down_proj
|
||||
std::shared_ptr<xftblock::Tensor> xup_gate_proj_w, xdown_proj_w;
|
||||
|
||||
if (std::is_same<TW, int4_t>::value) {
|
||||
xup_gate_proj_w = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<int8_t *>(up_gate_proj_weight.data<int8_t>()), nullptr,
|
||||
const_cast<float *>(up_gate_proj_weight_scale.get_ptr()
|
||||
? up_gate_proj_weight_scale.get_ptr()->data<float>()
|
||||
: nullptr),
|
||||
xftblock_tw,
|
||||
std::vector<int64_t>{expert_num, inter_dim, hidden_dim});
|
||||
|
||||
xdown_proj_w = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<int8_t *>(down_proj_weight.data<int8_t>()), nullptr,
|
||||
const_cast<float *>(down_proj_weight_scale.get_ptr()
|
||||
? down_proj_weight_scale.get_ptr()->data<float>()
|
||||
: nullptr),
|
||||
xftblock_tw,
|
||||
std::vector<int64_t>{expert_num, hidden_dim, outer_dim});
|
||||
|
||||
} else {
|
||||
xup_gate_proj_w = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<TW *>(up_gate_proj_weight.data<TW>()), nullptr,
|
||||
const_cast<float *>(up_gate_proj_weight_scale.get_ptr()
|
||||
? up_gate_proj_weight_scale.get_ptr()->data<float>()
|
||||
: nullptr),
|
||||
xftblock_tw,
|
||||
std::vector<int64_t>{expert_num, inter_dim, hidden_dim}
|
||||
);
|
||||
|
||||
xdown_proj_w = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<TW *>(down_proj_weight.data<TW>()), nullptr,
|
||||
const_cast<float *>(down_proj_weight_scale.get_ptr()
|
||||
? down_proj_weight_scale.get_ptr()->data<float>()
|
||||
: nullptr),
|
||||
xftblock_tw,
|
||||
std::vector<int64_t>{expert_num, hidden_dim, outer_dim}
|
||||
);
|
||||
}
|
||||
std::shared_ptr<xftblock::Tensor> xup_gate_proj_bias;
|
||||
std::shared_ptr<xftblock::Tensor> xdown_proj_bias;
|
||||
if (up_gate_proj_bias.get_ptr()) {
|
||||
xup_gate_proj_bias = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<float *>(up_gate_proj_bias.get_ptr()->data<float>()),
|
||||
xftblock::DataType::DT_FLOAT, up_gate_proj_bias.get_ptr()->shape());
|
||||
}
|
||||
if (down_proj_bias.get_ptr()) {
|
||||
xdown_proj_bias = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<float *>(down_proj_bias.get_ptr()->data<float>()),
|
||||
xftblock::DataType::DT_FLOAT, down_proj_bias.get_ptr()->shape());
|
||||
}
|
||||
// std::cout << "[Op Debug] start init moe_ffn weight and bias" <<
|
||||
// std::endl; MoeFFNWeight
|
||||
xftblock::MoeFFNWeight moe_ffn_w_struct;
|
||||
moe_ffn_w_struct.gate_weight = &xgate_w;
|
||||
moe_ffn_w_struct.ffn_inter_weights = xup_gate_proj_w.get();
|
||||
moe_ffn_w_struct.ffn_inter_bias = xup_gate_proj_bias.get();
|
||||
moe_ffn_w_struct.ffn_outer_weights = xdown_proj_w.get();
|
||||
moe_ffn_w_struct.ffn_outer_bias = xdown_proj_bias.get();
|
||||
moe_ffn_w_struct.score_bias = xgate_correct_bias.get();
|
||||
// MoeFFNParam
|
||||
xftblock::MoeFFNParam moe_ffn_param;
|
||||
moe_ffn_param.expert_num = expert_num;
|
||||
moe_ffn_param.moe_top_k = moe_top_k;
|
||||
moe_ffn_param.fast_swiglu = true;
|
||||
|
||||
// std::cout << "[Op Debug] pre in xvfblock moe_ffn" << std::endl;
|
||||
|
||||
using XPU_TGEMM = typename fused_moe_ffn_trait<XPU_TX, XPU_TW>::GEMM_TYPE;
|
||||
ret = baidu::xpu::xftblock::moe_ffn_block_sorted_castte_per_token<
|
||||
XPU_TX, XPU_TW, XPU_TX, XPU_TGEMM>(&xctx, &xin, &xout, moe_ffn_w_struct,
|
||||
moe_ffn_param);
|
||||
PD_CHECK(ret == 0,
|
||||
"xftblock::moe_ffn_block_sorted_castte_per_token failed");
|
||||
if (down_proj_bias.get_ptr()) {
|
||||
xdown_proj_bias = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<float *>(down_proj_bias.get_ptr()->data<float>()),
|
||||
xftblock::DataType::DT_FLOAT,
|
||||
down_proj_bias.get_ptr()->shape());
|
||||
}
|
||||
// std::cout << "[Op Debug] start init moe_ffn weight and bias" <<
|
||||
// std::endl; MoeFFNWeight
|
||||
xftblock::MoeFFNWeight moe_ffn_w_struct;
|
||||
moe_ffn_w_struct.gate_weight = &xgate_w;
|
||||
moe_ffn_w_struct.ffn_inter_weights = xup_gate_proj_w.get();
|
||||
moe_ffn_w_struct.ffn_inter_bias = xup_gate_proj_bias.get();
|
||||
moe_ffn_w_struct.ffn_outer_weights = xdown_proj_w.get();
|
||||
moe_ffn_w_struct.ffn_outer_bias = xdown_proj_bias.get();
|
||||
moe_ffn_w_struct.score_bias = xgate_correct_bias.get();
|
||||
// MoeFFNParam
|
||||
xftblock::MoeFFNParam moe_ffn_param;
|
||||
moe_ffn_param.expert_num = expert_num;
|
||||
moe_ffn_param.moe_top_k = moe_top_k;
|
||||
moe_ffn_param.fast_swiglu = true;
|
||||
|
||||
return {fused_moe_out};
|
||||
// std::cout << "[Op Debug] pre in xvfblock moe_ffn" << std::endl;
|
||||
|
||||
using XPU_TGEMM = typename fused_moe_ffn_trait<XPU_TX, XPU_TW>::GEMM_TYPE;
|
||||
ret =
|
||||
baidu::xpu::xftblock::moe_ffn_block_sorted_castte_per_token<XPU_TX,
|
||||
XPU_TW,
|
||||
XPU_TX,
|
||||
XPU_TGEMM>(
|
||||
&xctx, &xin, &xout, moe_ffn_w_struct, moe_ffn_param);
|
||||
PD_CHECK(ret == 0,
|
||||
"xftblock::moe_ffn_block_sorted_castte_per_token failed");
|
||||
}
|
||||
|
||||
return {fused_moe_out};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
MoeLayer(const paddle::Tensor &x, const paddle::Tensor &gate_weight,
|
||||
const paddle::optional<paddle::Tensor> &gate_correction_bias,
|
||||
const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor> &up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_bias,
|
||||
const paddle::optional<paddle::Tensor> &up_gate_proj_weight_scale,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_weight_scale,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_in_scale,
|
||||
const std::string &quant_method, const int moe_top_k,
|
||||
const bool moe_group) {
|
||||
const auto x_type = x.dtype();
|
||||
const auto w_type = up_gate_proj_weight.dtype();
|
||||
std::vector<paddle::Tensor> MoeLayer(
|
||||
const paddle::Tensor &x,
|
||||
const paddle::Tensor &gate_weight,
|
||||
const paddle::optional<paddle::Tensor> &gate_correction_bias,
|
||||
const paddle::Tensor &up_gate_proj_weight,
|
||||
const paddle::Tensor &down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor> &up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_bias,
|
||||
const paddle::optional<paddle::Tensor> &up_gate_proj_weight_scale,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_weight_scale,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_in_scale,
|
||||
const std::string &quant_method,
|
||||
const int moe_top_k,
|
||||
const bool moe_group) {
|
||||
const auto x_type = x.dtype();
|
||||
const auto w_type = up_gate_proj_weight.dtype();
|
||||
|
||||
#define APPLY_MOE_LAYER_KERNEL(TX, TW) \
|
||||
return MoeLayerKernel<TX, TW>( \
|
||||
x, gate_weight, gate_correction_bias, up_gate_proj_weight, down_proj_weight, \
|
||||
up_gate_proj_bias, down_proj_bias, up_gate_proj_weight_scale, down_proj_weight_scale, \
|
||||
down_proj_in_scale, quant_method, moe_top_k, moe_group);
|
||||
#define APPLY_MOE_LAYER_KERNEL(TX, TW) \
|
||||
return MoeLayerKernel<TX, TW>(x, \
|
||||
gate_weight, \
|
||||
gate_correction_bias, \
|
||||
up_gate_proj_weight, \
|
||||
down_proj_weight, \
|
||||
up_gate_proj_bias, \
|
||||
down_proj_bias, \
|
||||
up_gate_proj_weight_scale, \
|
||||
down_proj_weight_scale, \
|
||||
down_proj_in_scale, \
|
||||
quant_method, \
|
||||
moe_top_k, \
|
||||
moe_group);
|
||||
|
||||
// TODO(mayang02): how to use quant_method?
|
||||
if (x_type == paddle::DataType::BFLOAT16 &&
|
||||
w_type == paddle::DataType::BFLOAT16) {
|
||||
APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, paddle::bfloat16);
|
||||
} else if (x_type == paddle::DataType::BFLOAT16 &&
|
||||
quant_method == "weight_only_int8") {
|
||||
APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, int8_t);
|
||||
} else if (x_type == paddle::DataType::BFLOAT16 &&
|
||||
quant_method == "weight_only_int4") {
|
||||
APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, int4_t);
|
||||
} else {
|
||||
PD_THROW("MoeLayer not support x_type=", static_cast<int>(x_type),
|
||||
", w_type=", static_cast<int>(w_type),
|
||||
", quant_method=", quant_method);
|
||||
return {};
|
||||
}
|
||||
// TODO(mayang02): how to use quant_method?
|
||||
if (x_type == paddle::DataType::BFLOAT16 &&
|
||||
w_type == paddle::DataType::BFLOAT16) {
|
||||
APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, paddle::bfloat16);
|
||||
} else if (x_type == paddle::DataType::BFLOAT16 &&
|
||||
quant_method == "weight_only_int8") {
|
||||
APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, int8_t);
|
||||
} else if (x_type == paddle::DataType::BFLOAT16 &&
|
||||
quant_method == "weight_only_int4") {
|
||||
APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, int4_t);
|
||||
} else {
|
||||
PD_THROW("MoeLayer not support x_type=",
|
||||
static_cast<int>(x_type),
|
||||
", w_type=",
|
||||
static_cast<int>(w_type),
|
||||
", quant_method=",
|
||||
quant_method);
|
||||
return {};
|
||||
}
|
||||
#undef APPLY_MOE_LAYER_KERNEL
|
||||
}
|
||||
|
||||
@@ -244,14 +280,16 @@ std::vector<std::vector<int64_t>> MoeLayerInferShape(
|
||||
const std::vector<int64_t> &down_proj_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>> &up_gate_proj_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>> &down_proj_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>> &up_gate_proj_weight_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>
|
||||
&up_gate_proj_weight_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>> &down_proj_weight_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>> &down_proj_in_scale_shape) {
|
||||
return {x_shape};
|
||||
return {x_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MoeLayerInferDtype(
|
||||
const paddle::DataType &x_dtype, const paddle::DataType &gate_weight_dtype,
|
||||
const paddle::DataType &x_dtype,
|
||||
const paddle::DataType &gate_weight_dtype,
|
||||
const paddle::optional<paddle::DataType> &gate_correction_bias_dtype,
|
||||
const paddle::DataType &up_gate_proj_weight_dtype,
|
||||
const paddle::DataType &down_proj_weight_dtype,
|
||||
@@ -260,12 +298,16 @@ std::vector<paddle::DataType> MoeLayerInferDtype(
|
||||
const paddle::optional<paddle::DataType> &up_gate_proj_weight_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &down_proj_weight_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype) {
|
||||
return {x_dtype};
|
||||
return {x_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(xpu_moe_layer) // fused_moe
|
||||
.Inputs({"x", "gate_weight", paddle::Optional("gate_correction_bias"),
|
||||
"up_gate_proj_weight", "down_proj_weight", paddle::Optional("up_gate_proj_bias"),
|
||||
PD_BUILD_OP(xpu_moe_layer) // fused_moe
|
||||
.Inputs({"x",
|
||||
"gate_weight",
|
||||
paddle::Optional("gate_correction_bias"),
|
||||
"up_gate_proj_weight",
|
||||
"down_proj_weight",
|
||||
paddle::Optional("up_gate_proj_bias"),
|
||||
paddle::Optional("down_proj_bias"),
|
||||
paddle::Optional("up_gate_proj_weight_scale"),
|
||||
paddle::Optional("down_proj_weight_scale"),
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
|
||||
#include <sys/mman.h> // NOLINT
|
||||
#include "cuda_runtime_api.h" // NOLINT
|
||||
#include "ops/pybind/pybind.h"
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/runtime.h"
|
||||
#include "ops/pybind/pybind.h"
|
||||
|
||||
void check_xpu_error(int error) {
|
||||
if (error != XPU_SUCCESS) {
|
||||
|
||||
@@ -33,13 +33,13 @@ void prof_start();
|
||||
|
||||
void prof_stop();
|
||||
|
||||
void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
|
||||
const paddle::Tensor &seq_lens_this_time_tensor,
|
||||
const paddle::Tensor &seq_lens_decoder_tensor,
|
||||
void InitKVSignalPerQuery(const paddle::Tensor& seq_lens_encoder_tensor,
|
||||
const paddle::Tensor& seq_lens_this_time_tensor,
|
||||
const paddle::Tensor& seq_lens_decoder_tensor,
|
||||
const int rank,
|
||||
const int num_layers);
|
||||
|
||||
void GetOutputKVSignal(const paddle::Tensor &x,
|
||||
void GetOutputKVSignal(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag);
|
||||
|
||||
@@ -70,8 +70,8 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
const paddle::optional<paddle::Tensor>& smooth,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
|
||||
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
|
||||
const std::string &pos_emb_type="NORMAL",
|
||||
bool rope_3d=false);
|
||||
const std::string& pos_emb_type = "NORMAL",
|
||||
bool rope_3d = false);
|
||||
|
||||
std::vector<paddle::Tensor> MoERedundantTopKSelect(
|
||||
const paddle::Tensor& gating_logits,
|
||||
@@ -477,7 +477,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("bias"),
|
||||
py::arg("weight_dtype"),
|
||||
py::arg("arch"),
|
||||
py::arg("group_size")=-1);
|
||||
py::arg("group_size") = -1);
|
||||
|
||||
m.def("ep_moe_expert_combine",
|
||||
&MoeEPCombine,
|
||||
|
||||
@@ -18,32 +18,31 @@
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
const int bsz = seq_lens_this_time.shape()[0];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
int r = baidu::xpu::api::plugin::recover_decode_task(
|
||||
xpu_ctx->x_context(),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
PD_CHECK(r == 0, "baidu::xpu::api::plugin::recover_decode_task failed.");
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
const int bsz = seq_lens_this_time.shape()[0];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
int r = baidu::xpu::api::plugin::recover_decode_task(
|
||||
xpu_ctx->x_context(),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
PD_CHECK(r == 0, "baidu::xpu::api::plugin::recover_decode_task failed.");
|
||||
}
|
||||
|
||||
PD_BUILD_OP(recover_decode_task)
|
||||
|
||||
@@ -74,8 +74,8 @@ RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
|
||||
using type_meta_data =
|
||||
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;
|
||||
|
||||
// std::printf("#### open_shm_and_get_complete_signal_meta_data layer idx:%d,
|
||||
// to ptx:%p \n",
|
||||
// std::printf("#### open_shm_and_get_complete_signal_meta_data layer
|
||||
// idx:%d, to ptx:%p \n",
|
||||
// -1, signal_ptr);
|
||||
|
||||
type_meta_data meta_data(-1, signal_ptr, signal_shm_fd);
|
||||
@@ -102,8 +102,8 @@ void RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise(
|
||||
int32_t layer_id = meta_data_ptr[0];
|
||||
int32_t* ptr = reinterpret_cast<int32_t*>(meta_data_ptr[1]);
|
||||
*ptr = layer_id;
|
||||
// std::printf("#### save_cache_kv_complete_signal_layerwise layer idx:%d, to
|
||||
// ptx:%p \n",
|
||||
// std::printf("#### save_cache_kv_complete_signal_layerwise layer idx:%d,
|
||||
// to ptx:%p \n",
|
||||
// *ptr, meta_data_ptr[1]);
|
||||
}
|
||||
|
||||
|
||||
@@ -12,114 +12,118 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <sys/ipc.h>
|
||||
#include <sys/msg.h>
|
||||
#include <sys/types.h>
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#define MAX_BSZ 256
|
||||
|
||||
// #define SAVE_WITH_OUTPUT_DEBUG
|
||||
struct msgdata {
|
||||
long mtype;
|
||||
int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens
|
||||
long mtype;
|
||||
int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens
|
||||
};
|
||||
|
||||
// #define SAVE_WITH_OUTPUT_DEBUG
|
||||
void SaveOutMmsg(const paddle::Tensor &x, const paddle::Tensor &not_need_stop,
|
||||
int64_t rank_id, int msg_queue_id, bool save_each_rank) {
|
||||
if (!save_each_rank && rank_id > 0) {
|
||||
return;
|
||||
}
|
||||
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
|
||||
int64_t *x_data = x_cpu.data<int64_t>();
|
||||
static struct msgdata msg_sed;
|
||||
|
||||
if (const char *inference_msg_queue_id_env_p =
|
||||
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
|
||||
std::string inference_msg_queue_id_env_str(
|
||||
inference_msg_queue_id_env_p);
|
||||
int inference_msg_queue_id_from_env =
|
||||
std::stoi(inference_msg_queue_id_env_str);
|
||||
msg_queue_id = inference_msg_queue_id_from_env;
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
|
||||
<< inference_msg_queue_id_from_env << std::endl;
|
||||
#endif
|
||||
} else {
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default."
|
||||
<< std::endl;
|
||||
#endif
|
||||
}
|
||||
int inference_msg_id_from_env = 1;
|
||||
if (const char *inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
|
||||
std::string inference_msg_id_env_str(inference_msg_id_env_p);
|
||||
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
|
||||
if (inference_msg_id_from_env == 2) {
|
||||
// 2 and -2 is preserve for no-output indication.
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be 2, please use other number.");
|
||||
}
|
||||
if (inference_msg_id_from_env < 0) {
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be negative, please use other "
|
||||
"number.");
|
||||
}
|
||||
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
|
||||
<< std::endl;
|
||||
#endif
|
||||
} else {
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
std::cout
|
||||
<< "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default."
|
||||
<< std::endl;
|
||||
#endif
|
||||
}
|
||||
static key_t key = ftok("/dev/shm", msg_queue_id);
|
||||
|
||||
static int msgid = msgget(key, IPC_CREAT | 0666);
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
std::cout << "save_output key: " << key << std::endl;
|
||||
std::cout << "save_output msgid: " << msgid << std::endl;
|
||||
#endif
|
||||
msg_sed.mtype = 1;
|
||||
bool not_need_stop_data = not_need_stop.data<bool>()[0];
|
||||
// printf("not_need_stop_data %d\n", (int)not_need_stop_data);
|
||||
msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
|
||||
: -inference_msg_id_from_env;
|
||||
int bsz = x.shape()[0];
|
||||
msg_sed.mtext[1] = bsz;
|
||||
for (int i = 2; i < bsz + 2; i++) {
|
||||
msg_sed.mtext[i] = (int)x_data[i - 2];
|
||||
}
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
std::cout << "save_output msg data: ";
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
std::cout << " " << (int)x_data[i];
|
||||
}
|
||||
std::cout << std::endl;
|
||||
#endif
|
||||
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) {
|
||||
printf("save_output full msg buffer\n");
|
||||
}
|
||||
void SaveOutMmsg(const paddle::Tensor &x,
|
||||
const paddle::Tensor &not_need_stop,
|
||||
int64_t rank_id,
|
||||
int msg_queue_id,
|
||||
bool save_each_rank) {
|
||||
if (!save_each_rank && rank_id > 0) {
|
||||
return;
|
||||
}
|
||||
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
|
||||
int64_t *x_data = x_cpu.data<int64_t>();
|
||||
static struct msgdata msg_sed;
|
||||
|
||||
if (const char *inference_msg_queue_id_env_p =
|
||||
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
|
||||
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
|
||||
int inference_msg_queue_id_from_env =
|
||||
std::stoi(inference_msg_queue_id_env_str);
|
||||
msg_queue_id = inference_msg_queue_id_from_env;
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
|
||||
<< inference_msg_queue_id_from_env << std::endl;
|
||||
#endif
|
||||
} else {
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default."
|
||||
<< std::endl;
|
||||
#endif
|
||||
}
|
||||
int inference_msg_id_from_env = 1;
|
||||
if (const char *inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
|
||||
std::string inference_msg_id_env_str(inference_msg_id_env_p);
|
||||
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
|
||||
if (inference_msg_id_from_env == 2) {
|
||||
// 2 and -2 is preserve for no-output indication.
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be 2, please use other number.");
|
||||
}
|
||||
if (inference_msg_id_from_env < 0) {
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be negative, please use other "
|
||||
"number.");
|
||||
}
|
||||
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
|
||||
<< std::endl;
|
||||
#endif
|
||||
} else {
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
std::cout << "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default."
|
||||
<< std::endl;
|
||||
#endif
|
||||
}
|
||||
static key_t key = ftok("/dev/shm", msg_queue_id);
|
||||
|
||||
static int msgid = msgget(key, IPC_CREAT | 0666);
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
std::cout << "save_output key: " << key << std::endl;
|
||||
std::cout << "save_output msgid: " << msgid << std::endl;
|
||||
#endif
|
||||
msg_sed.mtype = 1;
|
||||
bool not_need_stop_data = not_need_stop.data<bool>()[0];
|
||||
// printf("not_need_stop_data %d\n", (int)not_need_stop_data);
|
||||
msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
|
||||
: -inference_msg_id_from_env;
|
||||
int bsz = x.shape()[0];
|
||||
msg_sed.mtext[1] = bsz;
|
||||
for (int i = 2; i < bsz + 2; i++) {
|
||||
msg_sed.mtext[i] = (int)x_data[i - 2];
|
||||
}
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
std::cout << "save_output msg data: ";
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
std::cout << " " << (int)x_data[i];
|
||||
}
|
||||
std::cout << std::endl;
|
||||
#endif
|
||||
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) {
|
||||
printf("save_output full msg buffer\n");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void SaveOutMmsgStatic(const paddle::Tensor &x,
|
||||
                       const paddle::Tensor &not_need_stop, int64_t rank_id,
|
||||
                       const paddle::Tensor &not_need_stop,
|
||||
int64_t rank_id,
|
||||
bool save_each_rank) {
|
||||
SaveOutMmsg(x, not_need_stop, rank_id, 1, save_each_rank);
|
||||
SaveOutMmsg(x, not_need_stop, rank_id, 1, save_each_rank);
|
||||
}
|
||||
|
||||
void SaveOutMmsgDynamic(const paddle::Tensor &x,
|
||||
                        const paddle::Tensor &not_need_stop, int64_t rank_id,
|
||||
int msg_queue_id, bool save_each_rank) {
|
||||
SaveOutMmsg(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
|
||||
                        const paddle::Tensor &not_need_stop,
|
||||
int64_t rank_id,
|
||||
int msg_queue_id,
|
||||
bool save_each_rank) {
|
||||
SaveOutMmsg(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
|
||||
}
|
||||
|
||||
PD_BUILD_OP(save_output)
|
||||
|
||||
@@ -12,9 +12,9 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
|
||||
void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
|
||||
const paddle::Tensor &input_ids,
|
||||
@@ -23,26 +23,35 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &stop_flags) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
|
||||
int bs = seq_lens_this_time.shape()[0];
|
||||
int length = pre_ids_all.shape()[1];
|
||||
int length_input_ids = input_ids.shape()[1];
|
||||
int r = baidu::xpu::api::plugin::set_value_by_flags_and_idx(
|
||||
xpu_ctx->x_context(), stop_flags.data<bool>(),
|
||||
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
|
||||
input_ids.data<int64_t>(), seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(), step_idx.data<int64_t>(), bs, length,
|
||||
length_input_ids);
|
||||
PD_CHECK(r == 0, "xpu::plugin::set_value_by_flags_and_idx failed.");
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
|
||||
int bs = seq_lens_this_time.shape()[0];
|
||||
int length = pre_ids_all.shape()[1];
|
||||
int length_input_ids = input_ids.shape()[1];
|
||||
int r = baidu::xpu::api::plugin::set_value_by_flags_and_idx(
|
||||
xpu_ctx->x_context(),
|
||||
stop_flags.data<bool>(),
|
||||
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
|
||||
input_ids.data<int64_t>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
step_idx.data<int64_t>(),
|
||||
bs,
|
||||
length,
|
||||
length_input_ids);
|
||||
PD_CHECK(r == 0, "xpu::plugin::set_value_by_flags_and_idx failed.");
|
||||
}
|
||||
|
||||
PD_BUILD_OP(set_value_by_flags_and_idx)
|
||||
.Inputs({"pre_ids_all", "input_ids", "seq_lens_this_time",
|
||||
"seq_lens_encoder", "seq_lens_decoder", "step_idx", "stop_flags"})
|
||||
.Inputs({"pre_ids_all",
|
||||
"input_ids",
|
||||
"seq_lens_this_time",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"step_idx",
|
||||
"stop_flags"})
|
||||
.Outputs({"pre_ids_all_out"})
|
||||
.SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
|
||||
.SetKernelFn(PD_KERNEL(SetValueByFlagsAndIdx));
|
||||
|
||||
@@ -30,8 +30,9 @@ std::vector<paddle::Tensor> ShareExternalData(const paddle::Tensor &input,
|
||||
void *data_ptr_addr = nullptr;
|
||||
if (use_ipc) {
|
||||
#if XPURT_VERSION_MAJOR == 5
|
||||
int ret = xpu_ipc_open_memhandle(
|
||||
&data_ptr_addr, *(XPUIpcMemHandle *)&shm->memHandle, 0x01); // NOLINT
|
||||
int ret = xpu_ipc_open_memhandle(&data_ptr_addr,
|
||||
*(XPUIpcMemHandle *)&shm->memHandle,
|
||||
0x01); // NOLINT
|
||||
PD_CHECK(ret == XPU_SUCCESS, "xpu_ipc_open_memhandle failed");
|
||||
#elif XPURT_VERSION_MAJOR == 4
|
||||
PD_THROW("kl2 not support prefix cache");
|
||||
|
||||
@@ -12,82 +12,100 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/enforce.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
|
||||
void StepPaddle(
|
||||
const paddle::Tensor &stop_flags, const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &ori_seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
|
||||
const paddle::Tensor &encoder_block_lens,
|
||||
const paddle::Tensor &is_block_step, const paddle::Tensor &step_block_list,
|
||||
const paddle::Tensor &step_lens, const paddle::Tensor &recover_block_list,
|
||||
const paddle::Tensor &recover_lens, const paddle::Tensor &need_block_list,
|
||||
const paddle::Tensor &need_block_len, const paddle::Tensor &used_list_len,
|
||||
const paddle::Tensor &free_list, const paddle::Tensor &free_list_len,
|
||||
const paddle::Tensor &input_ids, const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &step_idx, const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &first_token_ids, const int block_size,
|
||||
const int encoder_decoder_block_num) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
void StepPaddle(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &ori_seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
|
||||
const paddle::Tensor &encoder_block_lens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor &step_block_list,
|
||||
const paddle::Tensor &step_lens,
|
||||
const paddle::Tensor &recover_block_list,
|
||||
const paddle::Tensor &recover_lens,
|
||||
const paddle::Tensor &need_block_list,
|
||||
const paddle::Tensor &need_block_len,
|
||||
const paddle::Tensor &used_list_len,
|
||||
const paddle::Tensor &free_list,
|
||||
const paddle::Tensor &free_list_len,
|
||||
const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &first_token_ids,
|
||||
const int block_size,
|
||||
const int encoder_decoder_block_num) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
|
||||
const int bsz = seq_lens_this_time.shape()[0];
|
||||
PADDLE_ENFORCE_LE(
|
||||
bsz, 640,
|
||||
phi::errors::InvalidArgument(
|
||||
"Only support bsz <= 640, but received bsz is %d", bsz));
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
const int length = input_ids.shape()[1];
|
||||
const int pre_id_length = pre_ids.shape()[1];
|
||||
const int max_decoder_block_num = pre_id_length / block_size;
|
||||
int r = baidu::xpu::api::plugin::free_and_dispatch_block(
|
||||
xpu_ctx->x_context(), const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<int *>(encoder_block_lens.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
const_cast<int *>(step_block_list.data<int>()),
|
||||
const_cast<int *>(step_lens.data<int>()),
|
||||
const int bsz = seq_lens_this_time.shape()[0];
|
||||
PADDLE_ENFORCE_LE(
|
||||
bsz,
|
||||
640,
|
||||
phi::errors::InvalidArgument(
|
||||
"Only support bsz <= 640, but received bsz is %d", bsz));
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
const int length = input_ids.shape()[1];
|
||||
const int pre_id_length = pre_ids.shape()[1];
|
||||
const int max_decoder_block_num = pre_id_length / block_size;
|
||||
int r = baidu::xpu::api::plugin::free_and_dispatch_block(
|
||||
xpu_ctx->x_context(),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<int *>(encoder_block_lens.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
const_cast<int *>(step_block_list.data<int>()),
|
||||
const_cast<int *>(step_lens.data<int>()),
|
||||
const_cast<int *>(recover_block_list.data<int>()),
|
||||
const_cast<int *>(recover_lens.data<int>()),
|
||||
const_cast<int *>(need_block_list.data<int>()),
|
||||
const_cast<int *>(need_block_len.data<int>()),
|
||||
const_cast<int *>(used_list_len.data<int>()),
|
||||
const_cast<int *>(free_list.data<int>()),
|
||||
const_cast<int *>(free_list_len.data<int>()),
|
||||
const_cast<int64_t *>(first_token_ids.data<int64_t>()),
|
||||
bsz,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
max_decoder_block_num);
|
||||
PD_CHECK(r == 0, "free_and_dispatch_block failed.");
|
||||
auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false);
|
||||
int recover_lens_cpu_data = recover_lens_cpu.data<int>()[0];
|
||||
if (recover_lens_cpu_data > 0) {
|
||||
r = baidu::xpu::api::plugin::recover_block(
|
||||
xpu_ctx->x_context(),
|
||||
const_cast<int *>(recover_block_list.data<int>()),
|
||||
const_cast<int *>(recover_lens.data<int>()),
|
||||
const_cast<int *>(need_block_list.data<int>()),
|
||||
const_cast<int *>(need_block_len.data<int>()),
|
||||
const_cast<int *>(used_list_len.data<int>()),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
ori_seq_lens_encoder.data<int>(),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
seq_lens_decoder.data<int>(),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<int *>(free_list.data<int>()),
|
||||
const_cast<int *>(free_list_len.data<int>()),
|
||||
const_cast<int64_t *>(first_token_ids.data<int64_t>()), bsz, block_size,
|
||||
block_num_per_seq, max_decoder_block_num);
|
||||
PD_CHECK(r == 0, "free_and_dispatch_block failed.");
|
||||
auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false);
|
||||
int recover_lens_cpu_data = recover_lens_cpu.data<int>()[0];
|
||||
if (recover_lens_cpu_data > 0) {
|
||||
r = baidu::xpu::api::plugin::recover_block(
|
||||
xpu_ctx->x_context(),
|
||||
const_cast<int *>(recover_block_list.data<int>()),
|
||||
const_cast<int *>(recover_lens.data<int>()),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
ori_seq_lens_encoder.data<int>(),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
seq_lens_decoder.data<int>(),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<int *>(free_list.data<int>()),
|
||||
const_cast<int *>(free_list_len.data<int>()),
|
||||
const_cast<int64_t *>(input_ids.data<int64_t>()),
|
||||
pre_ids.data<int64_t>(), step_idx.data<int64_t>(),
|
||||
encoder_block_lens.data<int>(), used_list_len.data<int>(),
|
||||
next_tokens.data<int64_t>(), first_token_ids.data<int64_t>(), bsz,
|
||||
block_num_per_seq, length, pre_id_length);
|
||||
PD_CHECK(r == 0, "recover_block failed.");
|
||||
}
|
||||
const_cast<int64_t *>(input_ids.data<int64_t>()),
|
||||
pre_ids.data<int64_t>(),
|
||||
step_idx.data<int64_t>(),
|
||||
encoder_block_lens.data<int>(),
|
||||
used_list_len.data<int>(),
|
||||
next_tokens.data<int64_t>(),
|
||||
first_token_ids.data<int64_t>(),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
length,
|
||||
pre_id_length);
|
||||
PD_CHECK(r == 0, "recover_block failed.");
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_OP(step_paddle)
|
||||
@@ -114,13 +132,24 @@ PD_BUILD_OP(step_paddle)
|
||||
"next_tokens",
|
||||
"first_token_ids"})
|
||||
.Attrs({"block_size: int", "encoder_decoder_block_num: int"})
|
||||
.Outputs({"stop_flags_out", "seq_lens_this_time_out",
|
||||
"seq_lens_encoder_out", "seq_lens_decoder_out",
|
||||
"block_tables_out", "encoder_block_lens_out", "is_block_step_out",
|
||||
"step_block_list_out", "step_lens_out", "recover_block_list_out",
|
||||
"recover_lens_out", "need_block_list_out", "need_block_len_out",
|
||||
"used_list_len_out", "free_list_out", "free_list_len_out",
|
||||
"input_ids_out", "first_token_ids_out"})
|
||||
.Outputs({"stop_flags_out",
|
||||
"seq_lens_this_time_out",
|
||||
"seq_lens_encoder_out",
|
||||
"seq_lens_decoder_out",
|
||||
"block_tables_out",
|
||||
"encoder_block_lens_out",
|
||||
"is_block_step_out",
|
||||
"step_block_list_out",
|
||||
"step_lens_out",
|
||||
"recover_block_list_out",
|
||||
"recover_lens_out",
|
||||
"need_block_list_out",
|
||||
"need_block_len_out",
|
||||
"used_list_len_out",
|
||||
"free_list_out",
|
||||
"free_list_len_out",
|
||||
"input_ids_out",
|
||||
"first_token_ids_out"})
|
||||
.SetInplaceMap({{"stop_flags", "stop_flags_out"},
|
||||
{"seq_lens_this_time", "seq_lens_this_time_out"},
|
||||
{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||
|
||||
@@ -12,9 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include <cstdlib>
|
||||
#include <fcntl.h>
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include <stdio.h>
|
||||
@@ -24,6 +21,9 @@
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <cstdlib>
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
|
||||
const paddle::Tensor &stop_flags,
|
||||
@@ -31,22 +31,25 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
|
||||
const paddle::Tensor &end_ids,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const bool beam_search) {
|
||||
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
|
||||
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
std::vector<int64_t> shape = topk_ids.shape();
|
||||
int64_t bs_now = shape[0];
|
||||
int64_t end_length = end_ids.shape()[0];
|
||||
int r = baidu::xpu::api::plugin::set_stop_value_multi_ends<int64_t>(
|
||||
xpu_ctx->x_context(), const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
||||
const_cast<int64_t *>(next_tokens.data<int64_t>()),
|
||||
end_ids.data<int64_t>(), seq_lens.data<int>(), bs_now, end_length,
|
||||
beam_search);
|
||||
PD_CHECK(r == 0, "xpu::plugin::set_stop_value_multi_ends failed.");
|
||||
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
|
||||
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
std::vector<int64_t> shape = topk_ids.shape();
|
||||
int64_t bs_now = shape[0];
|
||||
int64_t end_length = end_ids.shape()[0];
|
||||
int r = baidu::xpu::api::plugin::set_stop_value_multi_ends<int64_t>(
|
||||
xpu_ctx->x_context(),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
||||
const_cast<int64_t *>(next_tokens.data<int64_t>()),
|
||||
end_ids.data<int64_t>(),
|
||||
seq_lens.data<int>(),
|
||||
bs_now,
|
||||
end_length,
|
||||
beam_search);
|
||||
PD_CHECK(r == 0, "xpu::plugin::set_stop_value_multi_ends failed.");
|
||||
}
|
||||
|
||||
PD_BUILD_OP(set_stop_value_multi_ends)
|
||||
|
||||
@@ -17,53 +17,49 @@
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
void TextImageGatherScatter(
|
||||
paddle::Tensor& input,
|
||||
paddle::Tensor& text_input,
|
||||
paddle::Tensor& image_input,
|
||||
paddle::Tensor& token_type_ids,
|
||||
paddle::Tensor& text_index,
|
||||
paddle::Tensor& image_index,
|
||||
const bool is_scatter) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
void TextImageGatherScatter(paddle::Tensor& input,
|
||||
paddle::Tensor& text_input,
|
||||
paddle::Tensor& image_input,
|
||||
paddle::Tensor& token_type_ids,
|
||||
paddle::Tensor& text_index,
|
||||
paddle::Tensor& image_index,
|
||||
const bool is_scatter) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
|
||||
const int64_t token_num = input.dims()[0];
|
||||
const int64_t hidden_size = input.dims()[1];
|
||||
const int64_t text_token_num = text_input.dims()[0];
|
||||
const int64_t image_token_num = image_input.dims()[0];
|
||||
const int64_t token_num = input.dims()[0];
|
||||
const int64_t hidden_size = input.dims()[1];
|
||||
const int64_t text_token_num = text_input.dims()[0];
|
||||
const int64_t image_token_num = image_input.dims()[0];
|
||||
|
||||
switch (input.type()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
using XPUType = typename XPUTypeTrait<bfloat16>::Type;
|
||||
typedef paddle::bfloat16 data_t;
|
||||
int r = baidu::xpu::api::plugin::text_image_gather_scatter<XPUType>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<XPUType*>(input.data<data_t>()),
|
||||
reinterpret_cast<XPUType*>(text_input.data<data_t>()),
|
||||
reinterpret_cast<XPUType*>(image_input.data<data_t>()),
|
||||
reinterpret_cast<int*>(token_type_ids.data<int>()),
|
||||
reinterpret_cast<int*>(text_index.data<int>()),
|
||||
reinterpret_cast<int*>(image_index.data<int>()),
|
||||
token_num,
|
||||
text_token_num,
|
||||
image_token_num,
|
||||
hidden_size,
|
||||
is_scatter
|
||||
);
|
||||
PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_gather_scatter");
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PD_THROW(
|
||||
"NOT supported data type. Only support BFLOAT16. ");
|
||||
break;
|
||||
}
|
||||
switch (input.type()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
using XPUType = typename XPUTypeTrait<bfloat16>::Type;
|
||||
typedef paddle::bfloat16 data_t;
|
||||
int r = baidu::xpu::api::plugin::text_image_gather_scatter<XPUType>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<XPUType*>(input.data<data_t>()),
|
||||
reinterpret_cast<XPUType*>(text_input.data<data_t>()),
|
||||
reinterpret_cast<XPUType*>(image_input.data<data_t>()),
|
||||
reinterpret_cast<int*>(token_type_ids.data<int>()),
|
||||
reinterpret_cast<int*>(text_index.data<int>()),
|
||||
reinterpret_cast<int*>(image_index.data<int>()),
|
||||
token_num,
|
||||
text_token_num,
|
||||
image_token_num,
|
||||
hidden_size,
|
||||
is_scatter);
|
||||
PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_gather_scatter");
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PD_THROW("NOT supported data type. Only support BFLOAT16. ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(text_image_gather_scatter)
|
||||
.Inputs({"input",
|
||||
"text_input",
|
||||
|
||||
@@ -16,33 +16,30 @@
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
void TextImageIndexOut(
|
||||
const paddle::Tensor& token_type_ids,
|
||||
const paddle::Tensor& text_index,
|
||||
const paddle::Tensor& image_index) {
|
||||
if (token_type_ids.type() != paddle::DataType::INT32 || text_index.type()
|
||||
!= paddle::DataType::INT32 || image_index.type() != paddle::DataType::INT32) {
|
||||
PD_THROW("NOT supported data type. Only support BFLOAT16. ");
|
||||
}
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
const int64_t token_num = token_type_ids.shape()[0];
|
||||
int r = baidu::xpu::api::plugin::text_image_index_out(xpu_ctx->x_context(),
|
||||
token_type_ids.data<int32_t>(),
|
||||
const_cast<int32_t*>(text_index.data<int32_t>()),
|
||||
const_cast<int32_t*>(image_index.data<int32_t>()),
|
||||
token_num);
|
||||
PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_index_out");
|
||||
// Runs the XPU plugin kernel `text_image_index_out` on the current XPU
// device.
//
// Params:
//   token_type_ids  INT32 tensor, read-only input; its first dimension is
//                   passed to the kernel as token_num
//   text_index      INT32 tensor handed to the kernel through a non-const
//                   pointer (registered in place as text_index_out)
//   image_index     INT32 tensor handed to the kernel through a non-const
//                   pointer (registered in place as image_index_out)
//
// Throws via PD_THROW when any of the three tensors is not INT32, and via
// PADDLE_ENFORCE_XDNN_SUCCESS when the kernel reports a failure.
void TextImageIndexOut(const paddle::Tensor& token_type_ids,
                       const paddle::Tensor& text_index,
                       const paddle::Tensor& image_index) {
  // All three tensors are accessed as int32_t below, so reject anything else.
  // The message names INT32 because that is the dtype this guard enforces.
  if (token_type_ids.type() != paddle::DataType::INT32 ||
      text_index.type() != paddle::DataType::INT32 ||
      image_index.type() != paddle::DataType::INT32) {
    PD_THROW("NOT supported data type. Only support INT32. ");
  }
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
  const int64_t token_num = token_type_ids.shape()[0];
  // const_cast: the op takes its tensors by const reference but the two index
  // tensors are passed to the kernel as mutable buffers.
  int r = baidu::xpu::api::plugin::text_image_index_out(
      xpu_ctx->x_context(),
      token_type_ids.data<int32_t>(),
      const_cast<int32_t*>(text_index.data<int32_t>()),
      const_cast<int32_t*>(image_index.data<int32_t>()),
      token_num);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_index_out");
}
|
||||
|
||||
|
||||
PD_BUILD_OP(text_image_index_out)
|
||||
.Inputs({"token_type_ids",
|
||||
"text_index",
|
||||
"image_index"})
|
||||
.Outputs({"text_index_out",
|
||||
"image_index_out"})
|
||||
.Inputs({"token_type_ids", "text_index", "image_index"})
|
||||
.Outputs({"text_index_out", "image_index_out"})
|
||||
.SetInplaceMap({{"text_index", "text_index_out"},
|
||||
{"image_index", "image_index_out"}})
|
||||
.SetKernelFn(PD_KERNEL(TextImageIndexOut));
|
||||
|
||||
@@ -26,40 +26,39 @@ void UpdateInputes(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &is_block_step) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
PADDLE_ENFORCE_LE(
|
||||
max_bsz,
|
||||
1024,
|
||||
phi::errors::InvalidArgument(
|
||||
"Only support max_bs <= 1024, but received max_bs is %d", max_bsz));
|
||||
const int now_bsz = seq_lens_this_time.shape()[0];
|
||||
const int input_ids_stride = input_ids.shape()[1];
|
||||
auto not_need_stop_xpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
PADDLE_ENFORCE_LE(
|
||||
max_bsz,
|
||||
1024,
|
||||
phi::errors::InvalidArgument(
|
||||
"Only support max_bs <= 1024, but received max_bs is %d", max_bsz));
|
||||
const int now_bsz = seq_lens_this_time.shape()[0];
|
||||
const int input_ids_stride = input_ids.shape()[1];
|
||||
auto not_need_stop_xpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
|
||||
int r = baidu::xpu::api::plugin::update_inputs(
|
||||
xpu_ctx->x_context(),
|
||||
const_cast<bool *>(not_need_stop_xpu.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(input_ids.data<int64_t>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
stop_flags.data<bool>(),
|
||||
is_block_step.data<bool>(),
|
||||
next_tokens.data<int64_t>(),
|
||||
now_bsz,
|
||||
max_bsz,
|
||||
input_ids_stride);
|
||||
PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs failed.");
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_xpu.copy_to(not_need_stop.place(), false);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
int r = baidu::xpu::api::plugin::update_inputs(
|
||||
xpu_ctx->x_context(),
|
||||
const_cast<bool *>(not_need_stop_xpu.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(input_ids.data<int64_t>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
stop_flags.data<bool>(),
|
||||
is_block_step.data<bool>(),
|
||||
next_tokens.data<int64_t>(),
|
||||
now_bsz,
|
||||
max_bsz,
|
||||
input_ids_stride);
|
||||
PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs failed.");
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_xpu.copy_to(not_need_stop.place(), false);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_OP(update_inputs)
|
||||
|
||||
@@ -18,55 +18,54 @@
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
void UpdateInputesV1(const paddle::Tensor &stop_flags,
|
||||
                     const paddle::Tensor &not_need_stop,  // only on cpu
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &prompt_lens,
|
||||
const paddle::Tensor &topk_ids,
|
||||
const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
                     const paddle::Tensor &not_need_stop,  // only on cpu
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &prompt_lens,
|
||||
const paddle::Tensor &topk_ids,
|
||||
const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int now_bsz = seq_lens_this_time.shape()[0];
|
||||
// std::cout << "now_bsz: " << now_bsz << std::endl;
|
||||
const int input_ids_stride = input_ids.shape()[1];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
int r = baidu::xpu::api::plugin::update_inputs_v1(
|
||||
xpu_ctx->x_context(),
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(prompt_lens.data<int64_t>()),
|
||||
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
||||
const_cast<int64_t *>(input_ids.data<int64_t>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
next_tokens.data<int64_t>(),
|
||||
now_bsz,
|
||||
max_bsz,
|
||||
input_ids_stride,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs_kernel_v1 failed.");
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int now_bsz = seq_lens_this_time.shape()[0];
|
||||
// std::cout << "now_bsz: " << now_bsz << std::endl;
|
||||
const int input_ids_stride = input_ids.shape()[1];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
int r = baidu::xpu::api::plugin::update_inputs_v1(
|
||||
xpu_ctx->x_context(),
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(prompt_lens.data<int64_t>()),
|
||||
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
||||
const_cast<int64_t *>(input_ids.data<int64_t>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
next_tokens.data<int64_t>(),
|
||||
now_bsz,
|
||||
max_bsz,
|
||||
input_ids_stride,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs_kernel_v1 failed.");
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_OP(update_inputs_v1)
|
||||
|
||||
Executable → Regular
@@ -15,8 +15,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
@@ -24,47 +22,56 @@
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/allocator.h"
|
||||
#include "paddle/phi/core/dense_tensor.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
|
||||
template <paddle::DataType D> class PDTraits;
|
||||
template <paddle::DataType D>
|
||||
class PDTraits;
|
||||
|
||||
template <> class PDTraits<paddle::DataType::FLOAT32> {
|
||||
public:
|
||||
typedef float DataType;
|
||||
typedef float data_t;
|
||||
template <>
|
||||
class PDTraits<paddle::DataType::FLOAT32> {
|
||||
public:
|
||||
typedef float DataType;
|
||||
typedef float data_t;
|
||||
};
|
||||
|
||||
template <> class PDTraits<paddle::DataType::FLOAT16> {
|
||||
public:
|
||||
typedef float16 DataType;
|
||||
typedef paddle::float16 data_t;
|
||||
template <>
|
||||
class PDTraits<paddle::DataType::FLOAT16> {
|
||||
public:
|
||||
typedef float16 DataType;
|
||||
typedef paddle::float16 data_t;
|
||||
};
|
||||
|
||||
template <> class PDTraits<paddle::DataType::BFLOAT16> {
|
||||
public:
|
||||
typedef bfloat16 DataType;
|
||||
typedef paddle::bfloat16 data_t;
|
||||
template <>
|
||||
class PDTraits<paddle::DataType::BFLOAT16> {
|
||||
public:
|
||||
typedef bfloat16 DataType;
|
||||
typedef paddle::bfloat16 data_t;
|
||||
};
|
||||
|
||||
template <> class PDTraits<paddle::DataType::INT8> {
|
||||
public:
|
||||
typedef int8_t DataType;
|
||||
typedef int8_t data_t;
|
||||
template <>
|
||||
class PDTraits<paddle::DataType::INT8> {
|
||||
public:
|
||||
typedef int8_t DataType;
|
||||
typedef int8_t data_t;
|
||||
};
|
||||
|
||||
template <> class PDTraits<paddle::DataType::UINT8> {
|
||||
public:
|
||||
typedef uint8_t DataType;
|
||||
typedef uint8_t data_t;
|
||||
template <>
|
||||
class PDTraits<paddle::DataType::UINT8> {
|
||||
public:
|
||||
typedef uint8_t DataType;
|
||||
typedef uint8_t data_t;
|
||||
};
|
||||
|
||||
template <> class PDTraits<paddle::DataType::INT64> {
|
||||
public:
|
||||
typedef int64_t DataType;
|
||||
typedef int64_t data_t;
|
||||
template <>
|
||||
class PDTraits<paddle::DataType::INT64> {
|
||||
public:
|
||||
typedef int64_t DataType;
|
||||
typedef int64_t data_t;
|
||||
};
|
||||
|
||||
@@ -11,110 +11,124 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include "xpu/plugin.h"
|
||||
#include <infer_ops.h>
|
||||
#include <infer_ops_eb.h>
|
||||
#include <paddle/extension.h>
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
template <typename T>
|
||||
std::vector<paddle::Tensor>
|
||||
WeightQuantizeKernel(const paddle::Tensor &x, const std::string &algo,
|
||||
const int32_t arch, const int32_t group_size) {
|
||||
using XPUType = typename XPUTypeTrait<T>::Type;
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
int64_t k = x.shape()[0];
|
||||
int64_t n = x.shape()[1];
|
||||
std::vector<paddle::Tensor> WeightQuantizeKernel(const paddle::Tensor &x,
|
||||
const std::string &algo,
|
||||
const int32_t arch,
|
||||
const int32_t group_size) {
|
||||
using XPUType = typename XPUTypeTrait<T>::Type;
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
int64_t k = x.shape()[0];
|
||||
int64_t n = x.shape()[1];
|
||||
|
||||
paddle::Tensor scale =
|
||||
paddle::full({n}, 0, paddle::DataType::FLOAT32, x.place());
|
||||
if (algo == "weight_only_int8") {
|
||||
paddle::Tensor out =
|
||||
paddle::full({k, n}, 0, paddle::DataType::INT8, x.place());
|
||||
int ret = baidu::xpu::api::plugin::quant2d_per_channel<XPUType, float,
|
||||
int8_t>(
|
||||
paddle::Tensor scale =
|
||||
paddle::full({n}, 0, paddle::DataType::FLOAT32, x.place());
|
||||
if (algo == "weight_only_int8") {
|
||||
paddle::Tensor out =
|
||||
paddle::full({k, n}, 0, paddle::DataType::INT8, x.place());
|
||||
int ret =
|
||||
baidu::xpu::api::plugin::quant2d_per_channel<XPUType, float, int8_t>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPUType *>(x.template data<T>()), nullptr,
|
||||
out.data<int8_t>(), scale.data<float>(), k, n);
|
||||
PD_CHECK(ret == 0);
|
||||
return {out, scale};
|
||||
} else if (algo == "weight_only_int4") {
|
||||
// TODO(mayang02): fix quant2d_per_channel int4 bugs, use transpose +
|
||||
// quant2d_per_token + transpose at now
|
||||
PD_CHECK(k % 2 == 0);
|
||||
paddle::Tensor out = paddle::full({(k + 1) / 2, n}, 0,
|
||||
paddle::DataType::INT8, x.place());
|
||||
xpu::ctx_guard RAII_GUARD(xpu_ctx->x_context());
|
||||
XPUType *x_trans = RAII_GUARD.alloc<XPUType>(k * n);
|
||||
int8_t *out_trans = RAII_GUARD.alloc<int8_t>(k * n / 2);
|
||||
PD_CHECK(x_trans != nullptr);
|
||||
PD_CHECK(out_trans != nullptr);
|
||||
int ret = baidu::xpu::api::transpose<XPUType>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPUType *>(x.data<T>()), x_trans, {k, n},
|
||||
{1, 0});
|
||||
PD_CHECK(ret == 0);
|
||||
ret = infer_ops::quant2d_per_token<XPUType, float, int4_t>(
|
||||
xpu_ctx->x_context(), x_trans, nullptr,
|
||||
reinterpret_cast<int4_t *>(out_trans), scale.data<float>(), n, k);
|
||||
PD_CHECK(ret == 0);
|
||||
ret = baidu::xpu::api::transpose<int8_t>(xpu_ctx->x_context(),
|
||||
out_trans, out.data<int8_t>(),
|
||||
{n, k / 2}, {1, 0});
|
||||
PD_CHECK(ret == 0);
|
||||
return {out, scale};
|
||||
} else {
|
||||
PD_THROW("Weight quantize only supports weight_only_int8 on XPU now.");
|
||||
return {};
|
||||
}
|
||||
reinterpret_cast<const XPUType *>(x.template data<T>()),
|
||||
nullptr,
|
||||
out.data<int8_t>(),
|
||||
scale.data<float>(),
|
||||
k,
|
||||
n);
|
||||
PD_CHECK(ret == 0);
|
||||
return {out, scale};
|
||||
} else if (algo == "weight_only_int4") {
|
||||
// TODO(mayang02): fix quant2d_per_channel int4 bugs, use transpose +
|
||||
// quant2d_per_token + transpose at now
|
||||
PD_CHECK(k % 2 == 0);
|
||||
paddle::Tensor out =
|
||||
paddle::full({(k + 1) / 2, n}, 0, paddle::DataType::INT8, x.place());
|
||||
xpu::ctx_guard RAII_GUARD(xpu_ctx->x_context());
|
||||
XPUType *x_trans = RAII_GUARD.alloc<XPUType>(k * n);
|
||||
int8_t *out_trans = RAII_GUARD.alloc<int8_t>(k * n / 2);
|
||||
PD_CHECK(x_trans != nullptr);
|
||||
PD_CHECK(out_trans != nullptr);
|
||||
int ret = baidu::xpu::api::transpose<XPUType>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPUType *>(x.data<T>()),
|
||||
x_trans,
|
||||
{k, n},
|
||||
{1, 0});
|
||||
PD_CHECK(ret == 0);
|
||||
ret = infer_ops::quant2d_per_token<XPUType, float, int4_t>(
|
||||
xpu_ctx->x_context(),
|
||||
x_trans,
|
||||
nullptr,
|
||||
reinterpret_cast<int4_t *>(out_trans),
|
||||
scale.data<float>(),
|
||||
n,
|
||||
k);
|
||||
PD_CHECK(ret == 0);
|
||||
ret = baidu::xpu::api::transpose<int8_t>(xpu_ctx->x_context(),
|
||||
out_trans,
|
||||
out.data<int8_t>(),
|
||||
{n, k / 2},
|
||||
{1, 0});
|
||||
PD_CHECK(ret == 0);
|
||||
return {out, scale};
|
||||
} else {
|
||||
PD_THROW("Weight quantize only supports weight_only_int8 on XPU now.");
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
// Op entry point: dispatches on the dtype of `x` to the matching
// WeightQuantizeKernel<T> instantiation (bfloat16 or float) and returns its
// {quantized weight, scale} result. Any other dtype raises through PD_THROW.
std::vector<paddle::Tensor> WeightQuantize(const paddle::Tensor &x,
                                           const std::string &algo,
                                           const int32_t arch,
                                           const int32_t group_size) {
  const auto x_type = x.dtype();
#define APPLY_WEIGHT_QUANTIZE_KERNEL(TX) \
  return WeightQuantizeKernel<TX>(x, algo, arch, group_size);

  if (x_type == paddle::DataType::BFLOAT16) {
    APPLY_WEIGHT_QUANTIZE_KERNEL(paddle::bfloat16);
  } else if (x_type == paddle::DataType::FLOAT32) {
    APPLY_WEIGHT_QUANTIZE_KERNEL(float);
  } else {
    // PD_THROW streams its arguments into the message rather than expanding
    // printf-style placeholders, so the dtype is passed as its own argument.
    PD_THROW("WeightQuantize not support x_type==", static_cast<int>(x_type));
    return {};
  }
}
|
||||
|
||||
std::vector<std::vector<int64_t>>
|
||||
WeightQuantizeInferShape(const std::vector<int64_t> &x_shape,
|
||||
const std::string &algo, const int32_t arch,
|
||||
const int32_t group_size) {
|
||||
if (algo == "weight_only_int8") {
|
||||
return {x_shape, {x_shape[1]}};
|
||||
} else if (algo == "weight_only_int4") {
|
||||
return {{x_shape[0] / 2, x_shape[1]}, {x_shape[1]}};
|
||||
} else {
|
||||
PD_THROW("weight_quantize not support algo=%s", algo);
|
||||
}
|
||||
std::vector<std::vector<int64_t>> WeightQuantizeInferShape(
|
||||
const std::vector<int64_t> &x_shape,
|
||||
const std::string &algo,
|
||||
const int32_t arch,
|
||||
const int32_t group_size) {
|
||||
if (algo == "weight_only_int8") {
|
||||
return {x_shape, {x_shape[1]}};
|
||||
} else if (algo == "weight_only_int4") {
|
||||
return {{x_shape[0] / 2, x_shape[1]}, {x_shape[1]}};
|
||||
} else {
|
||||
PD_THROW("weight_quantize not support algo=%s", algo);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType>
|
||||
WeightQuantizeInferDtype(const paddle::DataType &x_dtype,
|
||||
const std::string &algo, const int32_t arch,
|
||||
const int32_t group_size) {
|
||||
if (algo == "weight_only_int8") {
|
||||
return {paddle::DataType::INT8, paddle::DataType::FLOAT32};
|
||||
} else if (algo == "weight_only_int4") {
|
||||
return {paddle::DataType::INT8, paddle::DataType::FLOAT32};
|
||||
} else {
|
||||
PD_THROW("weight_quantize not support algo=%s", algo);
|
||||
}
|
||||
std::vector<paddle::DataType> WeightQuantizeInferDtype(
|
||||
const paddle::DataType &x_dtype,
|
||||
const std::string &algo,
|
||||
const int32_t arch,
|
||||
const int32_t group_size) {
|
||||
if (algo == "weight_only_int8") {
|
||||
return {paddle::DataType::INT8, paddle::DataType::FLOAT32};
|
||||
} else if (algo == "weight_only_int4") {
|
||||
return {paddle::DataType::INT8, paddle::DataType::FLOAT32};
|
||||
} else {
|
||||
PD_THROW("weight_quantize not support algo=%s", algo);
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_OP(weight_quantize_xpu)
|
||||
|
||||
@@ -24,60 +24,60 @@
|
||||
#include <sys/un.h>
|
||||
#include <sys/wait.h>
|
||||
#include <unistd.h>
|
||||
#include <vector>
|
||||
#include <xpu/runtime.h>
|
||||
#include <xpu/version.h>
|
||||
#include <vector>
|
||||
|
||||
// Plain record exchanged through shared memory.
// NOTE(review): field meanings below are inferred from their names - confirm
// against the code that reads and writes this struct.
struct shmStruct {
  // Presumably the number of processes sharing the segment.
  size_t nprocesses;
#if XPURT_VERSION_MAJOR == 5
  // IPC memory handle; this member only exists when building against XPU
  // runtime major version 5.
  XPUIpcMemHandle memHandle;
#endif
  // Presumably a device data pointer stored as a 64-bit integer.
  uint64_t data_ptr_addr;
};
|
||||
|
||||
// Bookkeeping for one mapped POSIX shared-memory object. Filled in by
// sharedMemoryCreate / sharedMemoryOpen and released by sharedMemoryClose.
struct sharedMemoryInfo {
  // Address returned by mmap for the mapping.
  void *addr;
  // Length of the mapping in bytes (the `sz` given at create/open time).
  size_t size;
  // File descriptor returned by shm_open.
  int shmFd;
};
|
||||
|
||||
static int sharedMemoryCreate(const char *name, size_t sz,
|
||||
static int sharedMemoryCreate(const char *name,
|
||||
size_t sz,
|
||||
sharedMemoryInfo *info) {
|
||||
info->size = sz;
|
||||
info->size = sz;
|
||||
|
||||
info->shmFd = shm_open(name, O_RDWR | O_CREAT, 0777);
|
||||
PD_CHECK(info->shmFd >= 0, "shm_open failed");
|
||||
info->shmFd = shm_open(name, O_RDWR | O_CREAT, 0777);
|
||||
PD_CHECK(info->shmFd >= 0, "shm_open failed");
|
||||
|
||||
int status = ftruncate(info->shmFd, sz);
|
||||
PD_CHECK(status == 0, "ftruncate failed");
|
||||
int status = ftruncate(info->shmFd, sz);
|
||||
PD_CHECK(status == 0, "ftruncate failed");
|
||||
|
||||
info->addr =
|
||||
mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0);
|
||||
PD_CHECK(info->addr != NULL, "mmap failed");
|
||||
info->addr = mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0);
|
||||
PD_CHECK(info->addr != NULL, "mmap failed");
|
||||
|
||||
return 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int sharedMemoryOpen(const char *name, size_t sz,
|
||||
static int sharedMemoryOpen(const char *name,
|
||||
size_t sz,
|
||||
sharedMemoryInfo *info) {
|
||||
info->size = sz;
|
||||
info->size = sz;
|
||||
|
||||
info->shmFd = shm_open(name, O_RDWR, 0777);
|
||||
PD_CHECK(info->shmFd >= 0, "shm_open failed");
|
||||
info->shmFd = shm_open(name, O_RDWR, 0777);
|
||||
PD_CHECK(info->shmFd >= 0, "shm_open failed");
|
||||
|
||||
info->addr =
|
||||
mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0);
|
||||
PD_CHECK(info->addr != nullptr, "mmap failed");
|
||||
info->addr = mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0);
|
||||
PD_CHECK(info->addr != nullptr, "mmap failed");
|
||||
|
||||
return 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Releases what sharedMemoryCreate / sharedMemoryOpen set up: unmaps the
// region when an address is recorded and closes the descriptor when it is
// non-zero. A descriptor value of 0 is treated as "nothing to close".
static void sharedMemoryClose(sharedMemoryInfo *info) {
  void *const mapped = info->addr;
  if (mapped != nullptr) {
    munmap(mapped, info->size);
  }

  const int fd = info->shmFd;
  if (fd != 0) {
    close(fd);
  }
}
|
||||
|
||||
@@ -24,121 +24,176 @@ namespace api {
|
||||
namespace plugin {
|
||||
|
||||
template <typename T>
|
||||
DLL_EXPORT int set_stop_value_multi_ends(Context *ctx, bool *stop_flags,
|
||||
T *topk_ids, T *next_tokens,
|
||||
const T *end_ids, const int *seq_lens,
|
||||
const int bs, const int end_length,
|
||||
DLL_EXPORT int set_stop_value_multi_ends(Context* ctx,
|
||||
bool* stop_flags,
|
||||
T* topk_ids,
|
||||
T* next_tokens,
|
||||
const T* end_ids,
|
||||
const int* seq_lens,
|
||||
const int bs,
|
||||
const int end_length,
|
||||
const bool beam_search);
|
||||
|
||||
// Plugin entry point: set_value_by_flags_and_idx.
// pre_ids_all is the only non-const buffer and may be written by the callee;
// every other pointer is a read-only input.
// NOTE(review): presumably returns 0 on success - confirm.
DLL_EXPORT int set_value_by_flags_and_idx(Context* ctx,
                                          const bool* stop_flags,
                                          int64_t* pre_ids_all,
                                          const int64_t* input_ids,
                                          const int* seq_lens_encoder,
                                          const int* seq_lens_decoder,
                                          const int64_t* step_idx,
                                          int bs,
                                          int length,
                                          int length_input_ids);
|
||||
|
||||
template <typename T>
|
||||
DLL_EXPORT int token_penalty_multi_scores(
|
||||
Context *ctx, const int64_t *pre_ids, T *logits, const T *penalty_scores,
|
||||
const T *frequency_scores, const T *presence_scores,
|
||||
const float *temperatures, const int64_t *cur_len, const int64_t *min_len,
|
||||
const int64_t *eos_token_id, const int64_t *bad_words, const int64_t bs,
|
||||
const int64_t length, const int64_t length_id, const int64_t end_length,
|
||||
const int64_t length_bad_words);
|
||||
DLL_EXPORT int token_penalty_multi_scores(Context* ctx,
|
||||
const int64_t* pre_ids,
|
||||
T* logits,
|
||||
const T* penalty_scores,
|
||||
const T* frequency_scores,
|
||||
const T* presence_scores,
|
||||
const float* temperatures,
|
||||
const int64_t* cur_len,
|
||||
const int64_t* min_len,
|
||||
const int64_t* eos_token_id,
|
||||
const int64_t* bad_words,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t length_id,
|
||||
const int64_t end_length,
|
||||
const int64_t length_bad_words);
|
||||
|
||||
// Plugin entry point: get_padding_offset.
// The non-const pointers (padding_offset, cum_offsets_out, cu_seqlens_q,
// cu_seqlens_k, x_remove_padding) may be written by the callee; input_ids,
// cum_offsets and seq_lens are read-only inputs.
// NOTE(review): presumably returns 0 on success - confirm.
DLL_EXPORT int get_padding_offset(Context* ctx,
                                  int* padding_offset,
                                  int* cum_offsets_out,
                                  int* cu_seqlens_q,
                                  int* cu_seqlens_k,
                                  int64_t* x_remove_padding,
                                  const int64_t* input_ids,
                                  const int* cum_offsets,
                                  const int* seq_lens,
                                  const int max_seq_len,
                                  const int bs);
|
||||
|
||||
// Plugin entry point: update_inputs.
// The non-const pointers (not_need_stop, seq_lens_this_time,
// seq_lens_encoder, seq_lens_decoder, input_ids) may be written by the
// callee; stop_nums, stop_flags, is_block_step and next_tokens are read-only.
// NOTE(review): presumably returns 0 on success - confirm.
DLL_EXPORT int update_inputs(Context* ctx,
                             bool* not_need_stop,
                             int* seq_lens_this_time,
                             int* seq_lens_encoder,
                             int* seq_lens_decoder,
                             int64_t* input_ids,
                             const int64_t* stop_nums,
                             const bool* stop_flags,
                             const bool* is_block_step,
                             const int64_t* next_tokens,
                             const int bsz,
                             const int max_bsz,
                             const int input_ids_stride);
|
||||
|
||||
// Plugin entry point: free_and_dispatch_block.
// Every pointer argument is non-const, so any of the buffers may be written
// by the callee; the trailing ints are sizes and limits.
// NOTE(review): presumably returns 0 on success - confirm.
DLL_EXPORT int free_and_dispatch_block(Context* ctx,
                                       bool* stop_flags,
                                       int* seq_lens_this_time,
                                       int* seq_lens_decoder,
                                       int* block_tables,
                                       int* encoder_block_lens,
                                       bool* is_block_step,
                                       int* step_block_list,  // [bsz]
                                       int* step_len,
                                       int* recover_block_list,
                                       int* recover_len,
                                       int* need_block_list,
                                       int* need_block_len,
                                       int* used_list_len,
                                       int* free_list,
                                       int* free_list_len,
                                       int64_t* first_token_ids,
                                       const int bsz,
                                       const int block_size,
                                       const int block_num_per_seq,
                                       const int max_decoder_block_num);
|
||||
|
||||
// Plugin entry point: recover_block.
// Const pointers (ori_seq_lens_encoder, seq_lens_decoder, pre_ids, step_idx,
// encoder_block_lens, used_list_len, next_tokens, first_token_ids) are
// read-only inputs; the remaining buffers may be written by the callee.
// NOTE(review): presumably returns 0 on success - confirm.
DLL_EXPORT int recover_block(Context* ctx,
                             int* recover_block_list,  // [bsz]
                             int* recover_len,
                             bool* stop_flags,
                             int* seq_lens_this_time,
                             const int* ori_seq_lens_encoder,
                             int* seq_lens_encoder,
                             const int* seq_lens_decoder,
                             int* block_tables,
                             int* free_list,
                             int* free_list_len,
                             int64_t* input_ids,
                             const int64_t* pre_ids,
                             const int64_t* step_idx,
                             const int* encoder_block_lens,
                             const int* used_list_len,
                             const int64_t* next_tokens,
                             const int64_t* first_token_ids,
                             const int bsz,
                             const int block_num_per_seq,
                             const int length,
                             const int pre_id_length);
|
||||
|
||||
|
||||
// Plugin entry point: recover_decode_task.
// Every pointer argument is non-const, so any of the buffers may be written
// by the callee; bsz, block_num_per_seq and block_size are sizes.
// NOTE(review): presumably returns 0 on success - confirm.
DLL_EXPORT int recover_decode_task(Context* ctx,
                                   bool* stop_flags,
                                   int* seq_lens_this_time,
                                   int* seq_lens_encoder,
                                   int* seq_lens_decoder,
                                   int* step_seq_lens_decoder,
                                   int* block_tables,
                                   bool* is_block_step,
                                   const int bsz,
                                   const int block_num_per_seq,
                                   const int block_size);
|
||||
|
||||
// Plugin entry point: update_inputs_v1.
// stop_nums and next_tokens are read-only inputs; every other pointer is
// non-const and may be written by the callee. Callers treat a return value of
// 0 as success.
DLL_EXPORT int update_inputs_v1(Context* ctx,
                                bool* not_need_stop,
                                int* seq_lens_this_time,
                                int* seq_lens_encoder,
                                int* seq_lens_decoder,
                                int* step_seq_lens_decoder,
                                int64_t* prompt_lens,
                                int64_t* topk_ids,
                                int64_t* input_ids,
                                int* block_tables,
                                const int64_t* stop_nums,
                                bool* stop_flags,
                                bool* is_block_step,
                                const int64_t* next_tokens,
                                const int bsz,
                                const int max_bsz,
                                const int input_ids_stride,
                                const int block_num_per_seq,
                                const int block_size);
|
||||
|
||||
template <typename TX, typename TY>
|
||||
DLL_EXPORT int
|
||||
eb_adjust_batch(Context *ctx, const TX *x, TY *y,
|
||||
VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t> &encoder_batch_map, // NOLINT
|
||||
VectorParam<int32_t> &decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
DLL_EXPORT int eb_adjust_batch(
|
||||
Context* ctx,
|
||||
const TX* x,
|
||||
TY* y,
|
||||
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& encoder_batch_map, // NOLINT
|
||||
VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
|
||||
template <typename TX, typename TY>
|
||||
DLL_EXPORT int
|
||||
eb_gather_next_token(Context *ctx, const TX *x, TY *y,
|
||||
VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t> &encoder_batch_map, // NOLINT
|
||||
VectorParam<int32_t> &decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
DLL_EXPORT int eb_gather_next_token(
|
||||
Context* ctx,
|
||||
const TX* x,
|
||||
TY* y,
|
||||
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& encoder_batch_map, // NOLINT
|
||||
VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
|
||||
template <typename TX, typename TSCALE = float, typename TY = int8_t>
|
||||
DLL_EXPORT int quant2d_per_channel(api::Context *ctx, const TX *x,
|
||||
const TSCALE *scale_in, TY *y,
|
||||
TSCALE *scale_out, int64_t m, int64_t n);
|
||||
DLL_EXPORT int quant2d_per_channel(api::Context* ctx,
|
||||
const TX* x,
|
||||
const TSCALE* scale_in,
|
||||
TY* y,
|
||||
TSCALE* scale_out,
|
||||
int64_t m,
|
||||
int64_t n);
|
||||
|
||||
DLL_EXPORT int text_image_index_out(Context* ctx,
|
||||
const int* token_type_ids, // x
|
||||
@@ -160,7 +215,8 @@ DLL_EXPORT int text_image_gather_scatter(api::Context* ctx,
|
||||
int64_t hidden_size,
|
||||
bool is_scatter);
|
||||
|
||||
/*--------------------------------------- MTP begin --------------------------------------------*/
|
||||
/*--------------------------------------- MTP begin
|
||||
* --------------------------------------------*/
|
||||
|
||||
template <typename T>
|
||||
DLL_EXPORT int speculate_token_penalty_multi_scores(
|
||||
@@ -200,7 +256,6 @@ DLL_EXPORT int mtp_free_and_dispatch_block(Context* ctx,
|
||||
const int block_num_per_seq,
|
||||
const int max_draft_tokens);
|
||||
|
||||
|
||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||
DLL_EXPORT int speculate_verify(Context* ctx,
|
||||
int64_t* accept_tokens,
|
||||
@@ -457,9 +512,10 @@ DLL_EXPORT int rebuild_self_hidden_states(api::Context* ctx,
|
||||
T* output,
|
||||
int dim_embed,
|
||||
int elem_cnt);
|
||||
/*--------------------------------------- MTP end --------------------------------------------*/
|
||||
/*--------------------------------------- MTP end
|
||||
* --------------------------------------------*/
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
|
||||
@@ -36,8 +36,8 @@ __global__ void get_padding_offset(int *batch_id_per_token,
|
||||
}
|
||||
mfence_lm();
|
||||
LM2GM(batch_id_per_token_lm,
|
||||
batch_id_per_token + i * max_seq_len - cum_offsets_lm[0] + j,
|
||||
cur_len * sizeof(int));
|
||||
batch_id_per_token + i * max_seq_len - cum_offsets_lm[0] + j,
|
||||
cur_len * sizeof(int));
|
||||
}
|
||||
if (cid == 0) {
|
||||
int cum_seq_len = (i + 1) * max_seq_len - cum_offsets_lm[1];
|
||||
|
||||
+61
-61
@@ -15,72 +15,72 @@ __global__ void RebuildAppendPaddingKernel(const T *full_hidden_states,
|
||||
int dim_embed,
|
||||
int elem_nums,
|
||||
T *out) {
|
||||
int ncores = core_num();
|
||||
int cid = core_id();
|
||||
int tid = cid * cluster_num() + cluster_id();
|
||||
int nthreads = cluster_num() * ncores;
|
||||
int64_t mstart = -1;
|
||||
int64_t mend = -1;
|
||||
int64_t nstart = -1;
|
||||
int64_t nend = -1;
|
||||
partition2d<int64_t>(tid,
|
||||
nthreads,
|
||||
elem_nums / dim_embed,
|
||||
dim_embed,
|
||||
&mstart,
|
||||
&mend,
|
||||
&nstart,
|
||||
&nend);
|
||||
int ncores = core_num();
|
||||
int cid = core_id();
|
||||
int tid = cid * cluster_num() + cluster_id();
|
||||
int nthreads = cluster_num() * ncores;
|
||||
int64_t mstart = -1;
|
||||
int64_t mend = -1;
|
||||
int64_t nstart = -1;
|
||||
int64_t nend = -1;
|
||||
partition2d<int64_t>(tid,
|
||||
nthreads,
|
||||
elem_nums / dim_embed,
|
||||
dim_embed,
|
||||
&mstart,
|
||||
&mend,
|
||||
&nstart,
|
||||
&nend);
|
||||
|
||||
const int64_t BUFFER_LEN = rounddown(6144 / sizeof(T), 64);
|
||||
__simd__ T lm_full_hidden_states[BUFFER_LEN];
|
||||
int output_padding_offset_val, cum_offset_val, seq_len_encoder_val,
|
||||
seq_len_decoder_val;
|
||||
const int64_t BUFFER_LEN = rounddown(6144 / sizeof(T), 64);
|
||||
__simd__ T lm_full_hidden_states[BUFFER_LEN];
|
||||
int output_padding_offset_val, cum_offset_val, seq_len_encoder_val,
|
||||
seq_len_decoder_val;
|
||||
|
||||
for (int64_t _m = mstart; _m < mend; _m++) {
|
||||
int out_token_id = _m;
|
||||
GM2LM(output_padding_offset + out_token_id,
|
||||
&output_padding_offset_val,
|
||||
sizeof(int));
|
||||
int ori_token_id = out_token_id + output_padding_offset_val;
|
||||
int bi = ori_token_id / max_seq_len;
|
||||
GM2LM_ASYNC(seq_len_encoder + bi, &seq_len_encoder_val, sizeof(int));
|
||||
GM2LM(seq_len_decoder + bi, &seq_len_decoder_val, sizeof(int));
|
||||
int seq_id = 0;
|
||||
if (seq_len_encoder_val == 0 and seq_len_decoder_val == 0) {
|
||||
continue;
|
||||
} else if (seq_len_encoder_val != 0) {
|
||||
seq_id = seq_len_encoder_val - 1;
|
||||
}
|
||||
GM2LM(cum_offset + bi, &cum_offset_val, sizeof(int));
|
||||
int input_token_id = ori_token_id - cum_offset_val + seq_id;
|
||||
for (int64_t _n = nstart; _n < nend; _n += BUFFER_LEN) {
|
||||
int64_t read_size = min(BUFFER_LEN, nend - _n);
|
||||
// out[i] = full_hidden_states[(i / dim_embed +
|
||||
// output_padding_offset[i / dim_embed] - cum_offset[(i / dim_embed
|
||||
// + output_padding_offset[i / dim_embed]) / max_seq_len] + seq_id)
|
||||
// * dim_embed + i % dim_embed]
|
||||
GM2LM(full_hidden_states + input_token_id * dim_embed + _n,
|
||||
lm_full_hidden_states,
|
||||
read_size * sizeof(T));
|
||||
LM2GM(lm_full_hidden_states,
|
||||
out + _m * dim_embed + _n,
|
||||
read_size * sizeof(T));
|
||||
}
|
||||
for (int64_t _m = mstart; _m < mend; _m++) {
|
||||
int out_token_id = _m;
|
||||
GM2LM(output_padding_offset + out_token_id,
|
||||
&output_padding_offset_val,
|
||||
sizeof(int));
|
||||
int ori_token_id = out_token_id + output_padding_offset_val;
|
||||
int bi = ori_token_id / max_seq_len;
|
||||
GM2LM_ASYNC(seq_len_encoder + bi, &seq_len_encoder_val, sizeof(int));
|
||||
GM2LM(seq_len_decoder + bi, &seq_len_decoder_val, sizeof(int));
|
||||
int seq_id = 0;
|
||||
if (seq_len_encoder_val == 0 and seq_len_decoder_val == 0) {
|
||||
continue;
|
||||
} else if (seq_len_encoder_val != 0) {
|
||||
seq_id = seq_len_encoder_val - 1;
|
||||
}
|
||||
GM2LM(cum_offset + bi, &cum_offset_val, sizeof(int));
|
||||
int input_token_id = ori_token_id - cum_offset_val + seq_id;
|
||||
for (int64_t _n = nstart; _n < nend; _n += BUFFER_LEN) {
|
||||
int64_t read_size = min(BUFFER_LEN, nend - _n);
|
||||
// out[i] = full_hidden_states[(i / dim_embed +
|
||||
// output_padding_offset[i / dim_embed] - cum_offset[(i / dim_embed
|
||||
// + output_padding_offset[i / dim_embed]) / max_seq_len] + seq_id)
|
||||
// * dim_embed + i % dim_embed]
|
||||
GM2LM(full_hidden_states + input_token_id * dim_embed + _n,
|
||||
lm_full_hidden_states,
|
||||
read_size * sizeof(T));
|
||||
LM2GM(lm_full_hidden_states,
|
||||
out + _m * dim_embed + _n,
|
||||
read_size * sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define _XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(T) \
|
||||
template __global__ void RebuildAppendPaddingKernel<T>( \
|
||||
const T *full_hidden_states, \
|
||||
const int *cum_offset, \
|
||||
const int *seq_len_encoder, \
|
||||
const int *seq_len_decoder, \
|
||||
const int *output_padding_offset, \
|
||||
int max_seq_len, \
|
||||
int dim_embed, \
|
||||
int elem_nums, \
|
||||
T *out);
|
||||
#define _XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(T) \
|
||||
template __global__ void RebuildAppendPaddingKernel<T>( \
|
||||
const T *full_hidden_states, \
|
||||
const int *cum_offset, \
|
||||
const int *seq_len_encoder, \
|
||||
const int *seq_len_decoder, \
|
||||
const int *output_padding_offset, \
|
||||
int max_seq_len, \
|
||||
int dim_embed, \
|
||||
int elem_nums, \
|
||||
T *out);
|
||||
|
||||
_XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(bfloat16);
|
||||
_XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(float16);
|
||||
|
||||
+4
-4
@@ -152,8 +152,8 @@ __device__ void speculate_update_repeat_times_optimized(
|
||||
repeat_times_read_size_per_core * sizeof(int));
|
||||
}
|
||||
sync_all();
|
||||
// each core loads pre_ids step by step and record the index of pre_ids
|
||||
// which is less than zero, and store the index to boundary
|
||||
// each core loads pre_ids step by step and record the index of
|
||||
// pre_ids which is less than zero, and store the index to boundary
|
||||
if (repeat_times_start == 0) {
|
||||
bool do_prone = false;
|
||||
int64_t j = cid * pre_ids_lm_len;
|
||||
@@ -190,8 +190,8 @@ __device__ void speculate_update_repeat_times_optimized(
|
||||
buffer_ptr_pre_ids.toggle();
|
||||
}
|
||||
}
|
||||
// each core loads all the needed pre_ids into lm without mfence in between
|
||||
// according to the index recorded by previous iteration
|
||||
// each core loads all the needed pre_ids into lm without mfence in
|
||||
// between according to the index recorded by previous iteration
|
||||
else {
|
||||
int cnt = -1;
|
||||
int64_t pre_ids_read_size = 0;
|
||||
|
||||
+9
-7
@@ -240,18 +240,20 @@ __global__ void speculate_update_value_by_repeat_times_simd(
|
||||
alpha,
|
||||
logits_,
|
||||
logits_,
|
||||
(time_mask &
|
||||
~logit_mask)); // when time != 0 && logit < 0, do alpha * logit
|
||||
(time_mask & ~logit_mask)); // when time != 0 && logit < 0, do
|
||||
// alpha * logit
|
||||
logits_ = svmul_float32x16_mh(
|
||||
1.0f / alpha,
|
||||
logits_,
|
||||
logits_,
|
||||
(time_mask & logit_mask)); // when time != 0 && >=0, do logit / alpha
|
||||
logits_ = vvsub_float32x16_mh(
|
||||
logits_, time_, logits_, time_mask); // when time != 0, do logit =
|
||||
// logit - time * beta - gamma;
|
||||
logits_ =
|
||||
svmul_float32x16(1.0f / temperature, logits_); // logit / temperature
|
||||
logits_ = vvsub_float32x16_mh(logits_,
|
||||
time_,
|
||||
logits_,
|
||||
time_mask); // when time != 0, do logit =
|
||||
// logit - time * beta - gamma;
|
||||
logits_ = svmul_float32x16(1.0f / temperature,
|
||||
logits_); // logit / temperature
|
||||
vstore_lm_float32x16(logits_lm + j, logits_);
|
||||
}
|
||||
mfence_lm();
|
||||
|
||||
@@ -6,15 +6,15 @@ namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
__global__ void recover_decode_task(bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
int cid = core_id();
|
||||
int ncores = core_num();
|
||||
int clusterid = cluster_id();
|
||||
@@ -23,15 +23,17 @@ __global__ void recover_decode_task(bool *stop_flags,
|
||||
int nthreads = nclusters * ncores;
|
||||
// if (clusterid != 0) return;
|
||||
for (; thread_idx < bsz; thread_idx += nthreads) {
|
||||
if(is_block_step[thread_idx] == true) {
|
||||
// int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
if (block_tables[thread_idx * block_num_per_seq + step_seq_lens_decoder[thread_idx] / block_size] != -1) {
|
||||
// can be recovered for decoding
|
||||
is_block_step[thread_idx] = false;
|
||||
seq_lens_this_time[thread_idx]= 1;
|
||||
stop_flags[thread_idx] = false;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
|
||||
if (is_block_step[thread_idx] == true) {
|
||||
// int *block_table_now = block_tables + thread_idx *
|
||||
// block_num_per_seq;
|
||||
if (block_tables[thread_idx * block_num_per_seq +
|
||||
step_seq_lens_decoder[thread_idx] / block_size] != -1) {
|
||||
// can be recovered for decoding
|
||||
is_block_step[thread_idx] = false;
|
||||
seq_lens_this_time[thread_idx] = 1;
|
||||
stop_flags[thread_idx] = false;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,8 +30,8 @@ __global__ void remove_padding(int64_t *x_remove_padding,
|
||||
input_lm,
|
||||
sizeof(int64_t) * cur_len);
|
||||
LM2GM(input_lm,
|
||||
x_remove_padding + i * sequence_length - cum_offset_lm + j,
|
||||
sizeof(int64_t) * cur_len);
|
||||
x_remove_padding + i * sequence_length - cum_offset_lm + j,
|
||||
sizeof(int64_t) * cur_len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,14 +54,13 @@ __global__ void set_stop_value_multi_ends(bool* stop_flags,
|
||||
GM2LM_ASYNC(seq_lens + i, seq_lens_lm, sizeof(int) * readlen);
|
||||
mfence();
|
||||
for (int j = 0; j < readlen; j++) {
|
||||
if(prefill_one_step_stop){
|
||||
if (prefill_one_step_stop) {
|
||||
stop_flags_lm[j] = true;
|
||||
if (seq_lens_lm[j] == 0) {
|
||||
topk_ids_lm[j] = -1;
|
||||
}
|
||||
next_tokens_lm[j] = topk_ids_lm[j];
|
||||
}
|
||||
else{
|
||||
} else {
|
||||
if (stop_flags_lm[j]) {
|
||||
if (seq_lens_lm[j] == 0) {
|
||||
topk_ids_lm[j] = -1;
|
||||
|
||||
+183
-143
@@ -8,166 +8,206 @@ namespace plugin {
|
||||
|
||||
template <typename T>
|
||||
static __device__ inline void text_image_gather(
|
||||
__global_ptr__ T* input,
|
||||
__global_ptr__ T* text_input,
|
||||
__global_ptr__ T* image_input,
|
||||
__global_ptr__ int* token_type_ids,
|
||||
__global_ptr__ int* text_index,
|
||||
__global_ptr__ int* image_index,
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
T* input_lm) {
|
||||
int cid = core_id();
|
||||
int clusterid = cluster_id();
|
||||
int token_start_cluster;
|
||||
int token_end_cluster;
|
||||
int token_start_core;
|
||||
int token_end_core;
|
||||
__global_ptr__ T* input,
|
||||
__global_ptr__ T* text_input,
|
||||
__global_ptr__ T* image_input,
|
||||
__global_ptr__ int* token_type_ids,
|
||||
__global_ptr__ int* text_index,
|
||||
__global_ptr__ int* image_index,
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
T* input_lm) {
|
||||
int cid = core_id();
|
||||
int clusterid = cluster_id();
|
||||
int token_start_cluster;
|
||||
int token_end_cluster;
|
||||
int token_start_core;
|
||||
int token_end_core;
|
||||
|
||||
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
|
||||
// cluster partition
|
||||
partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster);
|
||||
if (token_start_cluster >= token_end_cluster) {
|
||||
return;
|
||||
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
|
||||
// cluster partition
|
||||
partition(cluster_id(),
|
||||
cluster_num(),
|
||||
(int)token_num,
|
||||
1,
|
||||
&token_start_cluster,
|
||||
&token_end_cluster);
|
||||
if (token_start_cluster >= token_end_cluster) {
|
||||
return;
|
||||
}
|
||||
int rows_cluster =
|
||||
token_end_cluster - token_start_cluster; // total rows for a cluster
|
||||
// core partition
|
||||
partition(core_id(),
|
||||
core_num(),
|
||||
rows_cluster,
|
||||
1,
|
||||
&token_start_core,
|
||||
&token_end_core);
|
||||
int rows_core = token_end_core - token_start_core; // total rows for a core
|
||||
token_start_core += token_start_cluster;
|
||||
token_end_core += token_start_cluster;
|
||||
|
||||
int read_len;
|
||||
for (int i = token_start_core; i < token_end_core; i += 1) {
|
||||
int token_type, text_image_token_idx;
|
||||
__global_ptr__ T* text_image_input = nullptr;
|
||||
__global_ptr__ int* text_image_index = nullptr;
|
||||
|
||||
GM2LM(token_type_ids + i, &token_type, sizeof(int));
|
||||
if (token_type == 0) {
|
||||
text_image_input = text_input;
|
||||
text_image_index = text_index;
|
||||
} else {
|
||||
text_image_input = image_input;
|
||||
text_image_index = image_index;
|
||||
}
|
||||
int rows_cluster = token_end_cluster - token_start_cluster; // total rows for a cluster
|
||||
// core partition
|
||||
partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core);
|
||||
int rows_core = token_end_core - token_start_core; // total rows for a core
|
||||
token_start_core += token_start_cluster;
|
||||
token_end_core += token_start_cluster;
|
||||
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
|
||||
int input_offset = i * hidden_size;
|
||||
int text_image_offset = text_image_token_idx * hidden_size;
|
||||
|
||||
int read_len;
|
||||
for (int i = token_start_core; i < token_end_core; i += 1) {
|
||||
int token_type, text_image_token_idx;
|
||||
__global_ptr__ T* text_image_input = nullptr;
|
||||
__global_ptr__ int* text_image_index = nullptr;
|
||||
|
||||
GM2LM(token_type_ids + i, &token_type, sizeof(int));
|
||||
if (token_type == 0) {
|
||||
text_image_input = text_input;
|
||||
text_image_index = text_index;
|
||||
} else {
|
||||
text_image_input = image_input;
|
||||
text_image_index = image_index;
|
||||
}
|
||||
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
|
||||
int input_offset = i * hidden_size;
|
||||
int text_image_offset = text_image_token_idx * hidden_size;
|
||||
|
||||
for (int j = 0; j < hidden_size; j += BUFSIZE) {
|
||||
read_len = min(hidden_size - j, BUFSIZE);
|
||||
GM2LM(text_image_input + text_image_offset + j, input_lm, sizeof(T) * read_len);
|
||||
LM2GM(input_lm, input + input_offset + j, sizeof(T) * read_len);
|
||||
}
|
||||
for (int j = 0; j < hidden_size; j += BUFSIZE) {
|
||||
read_len = min(hidden_size - j, BUFSIZE);
|
||||
GM2LM(text_image_input + text_image_offset + j,
|
||||
input_lm,
|
||||
sizeof(T) * read_len);
|
||||
LM2GM(input_lm, input + input_offset + j, sizeof(T) * read_len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ inline void text_image_scatter(
|
||||
__global_ptr__ T* input,
|
||||
__global_ptr__ T* text_input,
|
||||
__global_ptr__ T* image_input,
|
||||
__global_ptr__ int* token_type_ids,
|
||||
__global_ptr__ int* text_index,
|
||||
__global_ptr__ int* image_index,
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
T* input_lm) {
|
||||
int cid = core_id();
|
||||
int clusterid = cluster_id();
|
||||
int token_start_cluster;
|
||||
int token_end_cluster;
|
||||
int token_start_core;
|
||||
int token_end_core;
|
||||
__global_ptr__ T* input,
|
||||
__global_ptr__ T* text_input,
|
||||
__global_ptr__ T* image_input,
|
||||
__global_ptr__ int* token_type_ids,
|
||||
__global_ptr__ int* text_index,
|
||||
__global_ptr__ int* image_index,
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
T* input_lm) {
|
||||
int cid = core_id();
|
||||
int clusterid = cluster_id();
|
||||
int token_start_cluster;
|
||||
int token_end_cluster;
|
||||
int token_start_core;
|
||||
int token_end_core;
|
||||
|
||||
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
|
||||
// cluster partition
|
||||
partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster);
|
||||
if (token_start_cluster >= token_end_cluster) {
|
||||
return;
|
||||
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
|
||||
// cluster partition
|
||||
partition(cluster_id(),
|
||||
cluster_num(),
|
||||
(int)token_num,
|
||||
1,
|
||||
&token_start_cluster,
|
||||
&token_end_cluster);
|
||||
if (token_start_cluster >= token_end_cluster) {
|
||||
return;
|
||||
}
|
||||
int rows_cluster =
|
||||
token_end_cluster - token_start_cluster; // total rows for a cluster
|
||||
// core partition
|
||||
partition(core_id(),
|
||||
core_num(),
|
||||
rows_cluster,
|
||||
1,
|
||||
&token_start_core,
|
||||
&token_end_core);
|
||||
int rows_core = token_end_core - token_start_core; // total rows for a core
|
||||
token_start_core += token_start_cluster;
|
||||
token_end_core += token_start_cluster;
|
||||
|
||||
int read_len;
|
||||
for (int i = token_start_core; i < token_end_core; i += 1) {
|
||||
int token_type, text_image_token_idx;
|
||||
__global_ptr__ T* text_image_input = nullptr;
|
||||
__global_ptr__ int* text_image_index = nullptr;
|
||||
|
||||
GM2LM(token_type_ids + i, &token_type, sizeof(int));
|
||||
if (token_type == 0) {
|
||||
text_image_input = text_input;
|
||||
text_image_index = text_index;
|
||||
} else {
|
||||
text_image_input = image_input;
|
||||
text_image_index = image_index;
|
||||
}
|
||||
int rows_cluster = token_end_cluster - token_start_cluster; // total rows for a cluster
|
||||
// core partition
|
||||
partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core);
|
||||
int rows_core = token_end_core - token_start_core; // total rows for a core
|
||||
token_start_core += token_start_cluster;
|
||||
token_end_core += token_start_cluster;
|
||||
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
|
||||
int input_offset = i * hidden_size;
|
||||
int text_image_offset = text_image_token_idx * hidden_size;
|
||||
|
||||
int read_len;
|
||||
for (int i = token_start_core; i < token_end_core; i += 1) {
|
||||
int token_type, text_image_token_idx;
|
||||
__global_ptr__ T* text_image_input = nullptr;
|
||||
__global_ptr__ int* text_image_index = nullptr;
|
||||
|
||||
GM2LM(token_type_ids + i, &token_type, sizeof(int));
|
||||
if (token_type == 0) {
|
||||
text_image_input = text_input;
|
||||
text_image_index = text_index;
|
||||
} else {
|
||||
text_image_input = image_input;
|
||||
text_image_index = image_index;
|
||||
}
|
||||
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
|
||||
int input_offset = i * hidden_size;
|
||||
int text_image_offset = text_image_token_idx * hidden_size;
|
||||
|
||||
for (int j = 0; j < hidden_size; j += BUFSIZE) {
|
||||
read_len = min(hidden_size - j, BUFSIZE);
|
||||
GM2LM(input + input_offset + j, input_lm, sizeof(T) * read_len);
|
||||
LM2GM(input_lm, text_image_input + text_image_offset + j, sizeof(T) * read_len);
|
||||
}
|
||||
for (int j = 0; j < hidden_size; j += BUFSIZE) {
|
||||
read_len = min(hidden_size - j, BUFSIZE);
|
||||
GM2LM(input + input_offset + j, input_lm, sizeof(T) * read_len);
|
||||
LM2GM(input_lm,
|
||||
text_image_input + text_image_offset + j,
|
||||
sizeof(T) * read_len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void text_image_gather_scatter(
|
||||
T* input,
|
||||
T* text_input,
|
||||
T* image_input,
|
||||
int* token_type_ids,
|
||||
int* text_index,
|
||||
int* image_index,
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
bool is_scatter) {
|
||||
int cid = core_id();
|
||||
int ncores = core_num();
|
||||
int clusterid = cluster_id();
|
||||
int nclusters = cluster_num();
|
||||
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
|
||||
__simd__ T input_lm[BUFSIZE]; // 2KB for bf16 and fp32
|
||||
if (is_scatter) {
|
||||
text_image_scatter(
|
||||
input, text_input, image_input, token_type_ids, text_index, image_index,
|
||||
token_num, text_token_num, image_token_num, hidden_size, input_lm);
|
||||
} else {
|
||||
text_image_gather(
|
||||
input, text_input, image_input, token_type_ids, text_index, image_index,
|
||||
token_num, text_token_num, image_token_num, hidden_size, input_lm);
|
||||
}
|
||||
__global__ void text_image_gather_scatter(T* input,
|
||||
T* text_input,
|
||||
T* image_input,
|
||||
int* token_type_ids,
|
||||
int* text_index,
|
||||
int* image_index,
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
bool is_scatter) {
|
||||
int cid = core_id();
|
||||
int ncores = core_num();
|
||||
int clusterid = cluster_id();
|
||||
int nclusters = cluster_num();
|
||||
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
|
||||
__simd__ T input_lm[BUFSIZE]; // 2KB for bf16 and fp32
|
||||
if (is_scatter) {
|
||||
text_image_scatter(input,
|
||||
text_input,
|
||||
image_input,
|
||||
token_type_ids,
|
||||
text_index,
|
||||
image_index,
|
||||
token_num,
|
||||
text_token_num,
|
||||
image_token_num,
|
||||
hidden_size,
|
||||
input_lm);
|
||||
} else {
|
||||
text_image_gather(input,
|
||||
text_input,
|
||||
image_input,
|
||||
token_type_ids,
|
||||
text_index,
|
||||
image_index,
|
||||
token_num,
|
||||
text_token_num,
|
||||
image_token_num,
|
||||
hidden_size,
|
||||
input_lm);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#define _XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(T) \
|
||||
template __global__ void text_image_gather_scatter<T>( \
|
||||
T* input, \
|
||||
T* text_input, \
|
||||
T* image_input, \
|
||||
int* token_type_ids, \
|
||||
int* text_index, \
|
||||
int* image_index, \
|
||||
int64_t token_num, \
|
||||
int64_t text_token_num, \
|
||||
int64_t image_token_num, \
|
||||
int64_t hidden_size, \
|
||||
bool is_scatter);
|
||||
#define _XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(T) \
|
||||
template __global__ void text_image_gather_scatter<T>( \
|
||||
T * input, \
|
||||
T * text_input, \
|
||||
T * image_input, \
|
||||
int* token_type_ids, \
|
||||
int* text_index, \
|
||||
int* image_index, \
|
||||
int64_t token_num, \
|
||||
int64_t text_token_num, \
|
||||
int64_t image_token_num, \
|
||||
int64_t hidden_size, \
|
||||
bool is_scatter);
|
||||
|
||||
_XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(bfloat16);
|
||||
|
||||
|
||||
@@ -23,75 +23,92 @@
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
static __device__ void do_calc(const _shared_ptr_ int* lm_x, int* lm_y1, int* lm_y2, int64_t size, int& text_count, int& images_count) {
|
||||
for (int j = 0; j < size; j++) {
|
||||
if (lm_x[j] == 0) {
|
||||
lm_y1[j] = text_count;
|
||||
text_count += 1;
|
||||
} else {
|
||||
lm_y2[j] = images_count;
|
||||
images_count += 1;
|
||||
}
|
||||
static __device__ void do_calc(const _shared_ptr_ int* lm_x,
|
||||
int* lm_y1,
|
||||
int* lm_y2,
|
||||
int64_t size,
|
||||
int& text_count,
|
||||
int& images_count) {
|
||||
for (int j = 0; j < size; j++) {
|
||||
if (lm_x[j] == 0) {
|
||||
lm_y1[j] = text_count;
|
||||
text_count += 1;
|
||||
} else {
|
||||
lm_y2[j] = images_count;
|
||||
images_count += 1;
|
||||
}
|
||||
mfence_lm_sm();
|
||||
}
|
||||
mfence_lm_sm();
|
||||
}
|
||||
|
||||
__global__ void text_image_index_out_kernel(
|
||||
const int* token_type_ids, // x
|
||||
int* text_index, // y1
|
||||
int* image_index, // y2
|
||||
const int64_t token_num) {
|
||||
const int cid = core_id();
|
||||
const int tid = core_id() * cluster_num() + cluster_id();
|
||||
const int nthreads = core_num() * cluster_num();
|
||||
if (tid >= 1) return;
|
||||
constexpr int BUFSIZE = 1024;
|
||||
constexpr int READ_MAX_SIZE = BUFSIZE / sizeof(int);
|
||||
const int64_t len = token_num;
|
||||
__global__ void text_image_index_out_kernel(const int* token_type_ids, // x
|
||||
int* text_index, // y1
|
||||
int* image_index, // y2
|
||||
const int64_t token_num) {
|
||||
const int cid = core_id();
|
||||
const int tid = core_id() * cluster_num() + cluster_id();
|
||||
const int nthreads = core_num() * cluster_num();
|
||||
if (tid >= 1) return;
|
||||
constexpr int BUFSIZE = 1024;
|
||||
constexpr int READ_MAX_SIZE = BUFSIZE / sizeof(int);
|
||||
const int64_t len = token_num;
|
||||
|
||||
__simd__ char buffer0[BUFSIZE * 3];
|
||||
__simd__ char buffer1[BUFSIZE * 3];
|
||||
__simd__ __shared__ char buffer2[64][BUFSIZE * 2];
|
||||
__simd__ char buffer0[BUFSIZE * 3];
|
||||
__simd__ char buffer1[BUFSIZE * 3];
|
||||
__simd__ __shared__ char buffer2[64][BUFSIZE * 2];
|
||||
|
||||
DoublePtr<READ_MAX_SIZE, SmPtr<int>> buffer_ptr_x((SmPtr<int>((_shared_ptr_ int*)buffer2[cid])));
|
||||
TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y1((LmPtr<int>((int*)buffer0)));
|
||||
TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y2((LmPtr<int>((int*)buffer1)));
|
||||
int64_t buflen = get_1d_buflen(len, nthreads, READ_MAX_SIZE, 64);
|
||||
int64_t i = tid * buflen;
|
||||
int read_size = 0;
|
||||
int offset = nthreads * buflen;
|
||||
DoublePtr<READ_MAX_SIZE, SmPtr<int>> buffer_ptr_x(
|
||||
(SmPtr<int>((_shared_ptr_ int*)buffer2[cid])));
|
||||
TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y1(
|
||||
(LmPtr<int>((int*)buffer0)));
|
||||
TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y2(
|
||||
(LmPtr<int>((int*)buffer1)));
|
||||
int64_t buflen = get_1d_buflen(len, nthreads, READ_MAX_SIZE, 64);
|
||||
int64_t i = tid * buflen;
|
||||
int read_size = 0;
|
||||
int offset = nthreads * buflen;
|
||||
|
||||
int text_count = 0;
|
||||
int images_count = 0;
|
||||
int text_count = 0;
|
||||
int images_count = 0;
|
||||
|
||||
if (i < len) {
|
||||
read_size = min<int64_t>(buflen, len - i);
|
||||
buffer_ptr_y1.gm_load_async(text_index + tid * buflen, read_size);
|
||||
buffer_ptr_y2.gm_load_async(image_index + tid * buflen, read_size);
|
||||
buffer_ptr_x.gm_load_async(token_type_ids + tid * buflen, read_size);
|
||||
mfence();
|
||||
}
|
||||
while (i < len && i + offset < len) {
|
||||
i = i + offset;
|
||||
int read_size_next = min<int64_t>(buflen, len - i);
|
||||
buffer_ptr_x.next().gm_load_async(token_type_ids + i, read_size_next);
|
||||
buffer_ptr_y1.next().gm_load_async(text_index + i, read_size_next);
|
||||
buffer_ptr_y2.next().gm_load_async(image_index + i, read_size_next);
|
||||
if (i < len) {
|
||||
read_size = min<int64_t>(buflen, len - i);
|
||||
buffer_ptr_y1.gm_load_async(text_index + tid * buflen, read_size);
|
||||
buffer_ptr_y2.gm_load_async(image_index + tid * buflen, read_size);
|
||||
buffer_ptr_x.gm_load_async(token_type_ids + tid * buflen, read_size);
|
||||
mfence();
|
||||
}
|
||||
while (i < len && i + offset < len) {
|
||||
i = i + offset;
|
||||
int read_size_next = min<int64_t>(buflen, len - i);
|
||||
buffer_ptr_x.next().gm_load_async(token_type_ids + i, read_size_next);
|
||||
buffer_ptr_y1.next().gm_load_async(text_index + i, read_size_next);
|
||||
buffer_ptr_y2.next().gm_load_async(image_index + i, read_size_next);
|
||||
|
||||
do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, read_size, text_count, images_count);
|
||||
do_calc(buffer_ptr_x.ptr,
|
||||
buffer_ptr_y1.ptr,
|
||||
buffer_ptr_y2.ptr,
|
||||
read_size,
|
||||
text_count,
|
||||
images_count);
|
||||
|
||||
buffer_ptr_y1.gm_store_async(text_index + i - offset, read_size);
|
||||
buffer_ptr_y2.gm_store_async(image_index + i - offset, read_size);
|
||||
buffer_ptr_x.toggle();
|
||||
buffer_ptr_y1.toggle();
|
||||
buffer_ptr_y2.toggle();
|
||||
read_size = read_size_next;
|
||||
}
|
||||
if (i < len) {
|
||||
do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, read_size, text_count, images_count);
|
||||
buffer_ptr_y1.gm_store_async(text_index + i, read_size);
|
||||
buffer_ptr_y2.gm_store(image_index + i, read_size);
|
||||
}
|
||||
buffer_ptr_y1.gm_store_async(text_index + i - offset, read_size);
|
||||
buffer_ptr_y2.gm_store_async(image_index + i - offset, read_size);
|
||||
buffer_ptr_x.toggle();
|
||||
buffer_ptr_y1.toggle();
|
||||
buffer_ptr_y2.toggle();
|
||||
read_size = read_size_next;
|
||||
}
|
||||
if (i < len) {
|
||||
do_calc(buffer_ptr_x.ptr,
|
||||
buffer_ptr_y1.ptr,
|
||||
buffer_ptr_y2.ptr,
|
||||
read_size,
|
||||
text_count,
|
||||
images_count);
|
||||
buffer_ptr_y1.gm_store_async(text_index + i, read_size);
|
||||
buffer_ptr_y2.gm_store(image_index + i, read_size);
|
||||
}
|
||||
}
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
@@ -46,7 +46,8 @@ __global__ void update_inputs(bool *not_need_stop,
|
||||
int seq_len_decoder_update =
|
||||
stop_flag_now
|
||||
? 0
|
||||
: (seq_len_encoder > 0 ? (seq_len_encoder + seq_len_decoder) : seq_len_decoder + 1);
|
||||
: (seq_len_encoder > 0 ? (seq_len_encoder + seq_len_decoder)
|
||||
: seq_len_decoder + 1);
|
||||
int seq_len_this_time_update = !stop_flag_now;
|
||||
int seq_len_encoder_update = 0;
|
||||
mfence_lm();
|
||||
|
||||
@@ -4,32 +4,30 @@
|
||||
// #include <stdio.h>
|
||||
// using namespace std;
|
||||
|
||||
#include "xpu/kernel/xtdk_io.h"
|
||||
#include "xpu/kernel/xtdk.h"
|
||||
#include "xpu/kernel/xtdk_io.h"
|
||||
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
__global__ void update_inputs_v1(bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *prompt_lens,
|
||||
int64_t *topk_ids,
|
||||
int64_t *input_ids,
|
||||
int *block_tables,
|
||||
const int64_t *stop_nums,
|
||||
bool *stop_flags,
|
||||
bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
|
||||
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *prompt_lens,
|
||||
int64_t *topk_ids,
|
||||
int64_t *input_ids,
|
||||
int *block_tables,
|
||||
const int64_t *stop_nums,
|
||||
bool *stop_flags,
|
||||
bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
// std::cout << "seq_lens_this_time " << seq_lens_this_time[0] << std::endl;
|
||||
int cid = core_id();
|
||||
int ncores = core_num();
|
||||
@@ -41,74 +39,83 @@ __global__ void update_inputs_v1(bool *not_need_stop,
|
||||
const int max_bs = 1024;
|
||||
__shared__ bool stop_flags_sm[max_bs];
|
||||
__shared__ int stop_flags_int_sm[max_bs];
|
||||
if(cid == 0){
|
||||
if (cid == 0) {
|
||||
GM2SM(stop_flags, stop_flags_sm, sizeof(bool) * bsz);
|
||||
}
|
||||
sync_all();
|
||||
|
||||
for(int i = cid; i < bsz; i+= ncores){
|
||||
if(i < bsz){
|
||||
stop_flags_sm[i] = stop_flags[i];
|
||||
stop_flags_int_sm[i] = static_cast<int64_t>(stop_flags_sm[i]);
|
||||
}else{
|
||||
stop_flags_sm[i] = true;
|
||||
stop_flags_int_sm[i] = 1;
|
||||
for (int i = cid; i < bsz; i += ncores) {
|
||||
if (i < bsz) {
|
||||
stop_flags_sm[i] = stop_flags[i];
|
||||
stop_flags_int_sm[i] = static_cast<int64_t>(stop_flags_sm[i]);
|
||||
} else {
|
||||
stop_flags_sm[i] = true;
|
||||
stop_flags_int_sm[i] = 1;
|
||||
}
|
||||
if(i<bsz){
|
||||
int seq_len_this_time_update = 0;
|
||||
int seq_len_decoder_update = 0;
|
||||
int seq_lens_encoder_update = 0;
|
||||
if(stop_flags_sm[i]){
|
||||
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
|
||||
if (i < bsz) {
|
||||
int seq_len_this_time_update = 0;
|
||||
int seq_len_decoder_update = 0;
|
||||
int seq_lens_encoder_update = 0;
|
||||
if (stop_flags_sm[i]) {
|
||||
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
|
||||
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
|
||||
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
|
||||
} else {
|
||||
GM2LM(seq_lens_this_time + i, &seq_len_this_time_update, sizeof(int));
|
||||
GM2LM(seq_lens_decoder + i, &seq_len_decoder_update, sizeof(int));
|
||||
GM2LM(seq_lens_encoder + i, &seq_lens_encoder_update, sizeof(int));
|
||||
int sum_of_seq_lens_this_time_and_seq_lens_decoder =
|
||||
seq_len_this_time_update + seq_len_decoder_update;
|
||||
int prompt_lens_update = 0;
|
||||
GM2LM(prompt_lens + i, &prompt_lens_update, sizeof(int64_t));
|
||||
// decoding
|
||||
if (sum_of_seq_lens_this_time_and_seq_lens_decoder >=
|
||||
prompt_lens_update) {
|
||||
seq_len_decoder_update =
|
||||
seq_len_this_time_update + seq_len_decoder_update;
|
||||
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
|
||||
seq_len_this_time_update = 1;
|
||||
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
|
||||
seq_lens_encoder_update = 0;
|
||||
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
|
||||
int64_t input_ids_update;
|
||||
GM2LM(next_tokens + i, &input_ids_update, sizeof(int64_t));
|
||||
LM2GM(&input_ids_update,
|
||||
input_ids + i * input_ids_stride,
|
||||
sizeof(int64_t));
|
||||
// to judge whether block is not enough
|
||||
if (seq_len_this_time_update != 0 &&
|
||||
block_tables[i * block_num_per_seq +
|
||||
seq_len_decoder_update / block_size] == -1) {
|
||||
is_block_step[i] = true;
|
||||
seq_len_this_time_update = 0;
|
||||
LM2GM(
|
||||
&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
|
||||
stop_flags_sm[i] = true;
|
||||
SM2GM(stop_flags_sm + i, stop_flags + i, sizeof(bool));
|
||||
LM2GM(&seq_len_decoder_update,
|
||||
step_seq_lens_decoder + i,
|
||||
sizeof(int));
|
||||
seq_len_decoder_update = 0;
|
||||
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
|
||||
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
|
||||
}else{
|
||||
GM2LM(seq_lens_this_time+i, &seq_len_this_time_update, sizeof(int));
|
||||
GM2LM(seq_lens_decoder+i, &seq_len_decoder_update, sizeof(int));
|
||||
GM2LM(seq_lens_encoder+i, &seq_lens_encoder_update, sizeof(int));
|
||||
int sum_of_seq_lens_this_time_and_seq_lens_decoder = seq_len_this_time_update + seq_len_decoder_update;
|
||||
int prompt_lens_update = 0;
|
||||
GM2LM(prompt_lens+i, &prompt_lens_update, sizeof(int64_t));
|
||||
// decoding
|
||||
if(sum_of_seq_lens_this_time_and_seq_lens_decoder >= prompt_lens_update){
|
||||
seq_len_decoder_update = seq_len_this_time_update + seq_len_decoder_update;
|
||||
LM2GM(&seq_len_decoder_update, seq_lens_decoder+i, sizeof(int));
|
||||
seq_len_this_time_update = 1;
|
||||
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
|
||||
seq_lens_encoder_update = 0;
|
||||
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
|
||||
int64_t input_ids_update;
|
||||
GM2LM(next_tokens + i, &input_ids_update, sizeof(int64_t));
|
||||
LM2GM(&input_ids_update, input_ids + i * input_ids_stride, sizeof(int64_t));
|
||||
// to judge whether block is not enough
|
||||
if(seq_len_this_time_update != 0 && block_tables[i * block_num_per_seq + seq_len_decoder_update/block_size] == -1){
|
||||
is_block_step[i] = true;
|
||||
seq_len_this_time_update = 0;
|
||||
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
|
||||
stop_flags_sm[i] = true;
|
||||
SM2GM(stop_flags_sm+i, stop_flags+i, sizeof(bool));
|
||||
LM2GM(&seq_len_decoder_update, step_seq_lens_decoder+i, sizeof(int));
|
||||
seq_len_decoder_update = 0;
|
||||
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
|
||||
seq_len_decoder_update = 0;
|
||||
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
|
||||
stop_flags_int_sm[i] = 1;
|
||||
}
|
||||
}else{
|
||||
stop_flags_sm[i] = true;
|
||||
SM2GM(stop_flags_sm+i, stop_flags+i, sizeof(bool));
|
||||
seq_len_this_time_update = 0;
|
||||
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
|
||||
seq_len_decoder_update = 0;
|
||||
seq_lens_encoder_update = 0;
|
||||
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
|
||||
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
|
||||
int64_t topk_ids_update = -1;
|
||||
LM2GM(&topk_ids_update, topk_ids + i, sizeof(int64_t));
|
||||
stop_flags_int_sm[i] = 1;
|
||||
}
|
||||
|
||||
seq_len_decoder_update = 0;
|
||||
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
|
||||
stop_flags_int_sm[i] = 1;
|
||||
}
|
||||
} else {
|
||||
stop_flags_sm[i] = true;
|
||||
SM2GM(stop_flags_sm + i, stop_flags + i, sizeof(bool));
|
||||
seq_len_this_time_update = 0;
|
||||
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
|
||||
seq_len_decoder_update = 0;
|
||||
seq_lens_encoder_update = 0;
|
||||
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
|
||||
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
|
||||
int64_t topk_ids_update = -1;
|
||||
LM2GM(&topk_ids_update, topk_ids + i, sizeof(int64_t));
|
||||
stop_flags_int_sm[i] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
sync_all();
|
||||
|
||||
+42
-27
@@ -6,16 +6,16 @@
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
// do_cast: converts int32 values read from `xlm` into floats written to `ylm`.
// NOTE(review): this span comes from a formatting diff, so the signature and
// body appear twice (`int* xlm` spelling, then `int *xlm`) and share the single
// closing brace at the end.
// Each loop iteration handles 32 elements as two 16-lane vectors (offsets i and
// i + 16): load, vfix2float, store. Because the loop steps by 32 while testing
// i < len, elements up to the next multiple of 32 are touched when `len` is not
// a multiple of 32 -- presumably callers size both buffers accordingly; TODO
// confirm.
__device__ void do_cast(const int* xlm, float* ylm, int64_t len) {
|
||||
for (int64_t i = 0; i < len; i += 32) {
|
||||
int32x16_t xl = vload_lm_int32x16(xlm + i);
|
||||
int32x16_t xh = vload_lm_int32x16(xlm + i + 16);
|
||||
float32x16_t yl = vfix2float(xl);
|
||||
float32x16_t yh = vfix2float(xh);
|
||||
vstore_lm_float32x16(ylm + i, yl);
|
||||
vstore_lm_float32x16(ylm + i + 16, yh);
|
||||
}
|
||||
// Issued once after all stores; presumably a local-memory fence -- confirm.
mfence_lm();
|
||||
__device__ void do_cast(const int *xlm, float *ylm, int64_t len) {
|
||||
for (int64_t i = 0; i < len; i += 32) {
|
||||
int32x16_t xl = vload_lm_int32x16(xlm + i);
|
||||
int32x16_t xh = vload_lm_int32x16(xlm + i + 16);
|
||||
float32x16_t yl = vfix2float(xl);
|
||||
float32x16_t yh = vfix2float(xh);
|
||||
vstore_lm_float32x16(ylm + i, yl);
|
||||
vstore_lm_float32x16(ylm + i + 16, yh);
|
||||
}
|
||||
// Issued once after all stores; presumably a local-memory fence -- confirm.
mfence_lm();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -124,7 +124,8 @@ __global__ void update_value_by_repeat_times_simd(
|
||||
int nthreads = cluster_num() * ncores;
|
||||
int start = -1;
|
||||
int end = -1;
|
||||
partition(thread_id, nthreads, static_cast<int>(bs * length), 16, &start, &end);
|
||||
partition(
|
||||
thread_id, nthreads, static_cast<int>(bs * length), 16, &start, &end);
|
||||
|
||||
const int param_len = 256;
|
||||
// ncores = 64 for xpu3
|
||||
@@ -178,14 +179,28 @@ __global__ void update_value_by_repeat_times_simd(
|
||||
alpha = alpha_buf[param_idx];
|
||||
beta = beta_buf[param_idx];
|
||||
gamma = gamma_buf[param_idx];
|
||||
time_mask = svneq_float32x16(0.f, time_); // time != 0 mask
|
||||
logit_mask = svle_float32x16(0.f, logits_); // logit >= 0 mask
|
||||
time_ = svmul_float32x16(beta, time_); // time * beta
|
||||
time_ = svadd_float32x16(gamma, time_); // time * beta + gamma
|
||||
logits_ = svmul_float32x16_mh(alpha, logits_, logits_, (time_mask & ~logit_mask)); // when time != 0 && logit < 0, do alpha * logit
|
||||
logits_ = svmul_float32x16_mh(1.0f / alpha, logits_, logits_, (time_mask & logit_mask)); // when time != 0 && >=0, do logit / alpha
|
||||
logits_ = vvsub_float32x16_mh(logits_, time_, logits_, time_mask); // when time != 0, do logit = logit - time * beta - gamma;
|
||||
logits_ = svmul_float32x16(1.0f / temperature, logits_); // logit / temperature
|
||||
time_mask = svneq_float32x16(0.f, time_); // time != 0 mask
|
||||
logit_mask = svle_float32x16(0.f, logits_); // logit >= 0 mask
|
||||
time_ = svmul_float32x16(beta, time_); // time * beta
|
||||
time_ = svadd_float32x16(gamma, time_); // time * beta + gamma
|
||||
logits_ = svmul_float32x16_mh(
|
||||
alpha,
|
||||
logits_,
|
||||
logits_,
|
||||
(time_mask & ~logit_mask)); // when time != 0 && logit < 0, do
|
||||
// alpha * logit
|
||||
logits_ = svmul_float32x16_mh(
|
||||
1.0f / alpha,
|
||||
logits_,
|
||||
logits_,
|
||||
(time_mask & logit_mask)); // when time != 0 && >=0, do logit / alpha
|
||||
logits_ = vvsub_float32x16_mh(logits_,
|
||||
time_,
|
||||
logits_,
|
||||
time_mask); // when time != 0, do logit =
|
||||
// logit - time * beta - gamma;
|
||||
logits_ = svmul_float32x16(1.0f / temperature,
|
||||
logits_); // logit / temperature
|
||||
vstore_lm_float32x16(logits_lm + j, logits_);
|
||||
}
|
||||
mfence_lm();
|
||||
@@ -195,14 +210,14 @@ __global__ void update_value_by_repeat_times_simd(
|
||||
}
|
||||
|
||||
// _XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(DATA_TYPE): expands to an explicit
// instantiation of the update_value_by_repeat_times_simd kernel template for the
// given element type. DATA_TYPE is used for the penalty/frequency/presence score
// pointers and for `logits`; `repeat_times` stays int and `temperatures` stays
// float.
// NOTE(review): this span comes from a formatting diff, so the parameter list is
// shown twice (old and new layout) ahead of the shared final
// `const int64_t length);` line.
#define _XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(DATA_TYPE) \
|
||||
template __global__ void update_value_by_repeat_times_simd( \
|
||||
const int *repeat_times, \
|
||||
const DATA_TYPE *penalty_scores, \
|
||||
const DATA_TYPE *frequency_score, \
|
||||
const DATA_TYPE *presence_score, \
|
||||
const float *temperatures, \
|
||||
DATA_TYPE *logits, \
|
||||
const int64_t bs, \
|
||||
template __global__ void update_value_by_repeat_times_simd( \
|
||||
const int *repeat_times, \
|
||||
const DATA_TYPE *penalty_scores, \
|
||||
const DATA_TYPE *frequency_score, \
|
||||
const DATA_TYPE *presence_score, \
|
||||
const float *temperatures, \
|
||||
DATA_TYPE *logits, \
|
||||
const int64_t bs, \
|
||||
const int64_t length);
|
||||
_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(float);
|
||||
_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(float16);
|
||||
|
||||
@@ -20,12 +20,16 @@
|
||||
// Declaration-only namespace block: forward-declares the device kernel template
// eb_adjust_batch<TX, TY> (no body here; the definition is not part of this
// span). The host wrapper later in this file launches it with <<<...>>>.
// NOTE(review): this span comes from a formatting diff, so the declaration and
// the two namespace closers are each listed twice (old layout, then new layout).
namespace xpu3 {
|
||||
namespace plugin {
|
||||
template <typename TX, typename TY>
|
||||
__attribute__((global)) void
|
||||
eb_adjust_batch(TX *src, TY *dst, int *encoder_seqs_lods,
|
||||
int *encoder_batch_map, int *decoder_batch_map, int en_batch,
|
||||
int de_batch, int64_t copy_size);
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
__attribute__((global)) void eb_adjust_batch(TX *src,
|
||||
TY *dst,
|
||||
int *encoder_seqs_lods,
|
||||
int *encoder_batch_map,
|
||||
int *decoder_batch_map,
|
||||
int en_batch,
|
||||
int de_batch,
|
||||
int64_t copy_size);
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
@@ -33,10 +37,15 @@ namespace api {
|
||||
namespace plugin {
|
||||
|
||||
template <typename TX, typename TY>
|
||||
static int
|
||||
cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods,
|
||||
const int *encoder_batch_map, const int *decoder_batch_map,
|
||||
int en_batch, int de_batch, int64_t hidden_dim) {
|
||||
static int cpu_wrapper(api::Context *ctx,
|
||||
const TX *x,
|
||||
TY *y,
|
||||
const int *encoder_seqs_lods,
|
||||
const int *encoder_batch_map,
|
||||
const int *decoder_batch_map,
|
||||
int en_batch,
|
||||
int de_batch,
|
||||
int64_t hidden_dim) {
|
||||
int ret = 0;
|
||||
int cur_offset = 0;
|
||||
int en_idx = 0;
|
||||
@@ -48,7 +57,8 @@ cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods,
|
||||
int cpy_m = 0;
|
||||
if (de_batch > 0 && decoder_batch_map[de_idx] == i) {
|
||||
cpy_m = 1;
|
||||
ret = api::cast<TX, TY>(ctx, x + cur_offset * hidden_dim,
|
||||
ret = api::cast<TX, TY>(ctx,
|
||||
x + cur_offset * hidden_dim,
|
||||
y + (encoder_len_total + de_idx) * hidden_dim,
|
||||
cpy_m * hidden_dim);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
@@ -56,7 +66,8 @@ cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods,
|
||||
}
|
||||
if (en_batch > 0 && encoder_batch_map[en_idx] == i) {
|
||||
cpy_m = encoder_seqs_lods[en_idx + 1] - encoder_seqs_lods[en_idx];
|
||||
ret = api::cast<TX, TY>(ctx, x + cur_offset * hidden_dim,
|
||||
ret = api::cast<TX, TY>(ctx,
|
||||
x + cur_offset * hidden_dim,
|
||||
y + encoder_seqs_lods[en_idx] * hidden_dim,
|
||||
cpy_m * hidden_dim);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
@@ -69,11 +80,15 @@ cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods,
|
||||
}
|
||||
|
||||
template <typename TX, typename TY>
|
||||
static int xpu3_wrapper(api::Context *ctx, const TX *x, TY *y,
|
||||
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
|
||||
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
|
||||
int en_batch, int de_batch, int64_t hidden_dim) {
|
||||
static int xpu3_wrapper(api::Context *ctx,
|
||||
const TX *x,
|
||||
TY *y,
|
||||
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
|
||||
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
|
||||
int en_batch,
|
||||
int de_batch,
|
||||
int64_t hidden_dim) {
|
||||
using XPU_INDEX_TYPE_TX = typename XPUIndexType<TX>::type;
|
||||
using XPU_INDEX_TYPE_TY = typename XPUIndexType<TY>::type;
|
||||
auto eb_adjust_batch_kernel =
|
||||
@@ -81,17 +96,23 @@ static int xpu3_wrapper(api::Context *ctx, const TX *x, TY *y,
|
||||
// NOTE: Don't change 16 to 64, because kernel use gsm
|
||||
eb_adjust_batch_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INDEX_TYPE_TX *>(const_cast<TX *>(x)),
|
||||
reinterpret_cast<XPU_INDEX_TYPE_TY *>(y), encoder_seqs_lods.xpu,
|
||||
encoder_batch_map.xpu, decoder_batch_map.xpu, en_batch, de_batch,
|
||||
reinterpret_cast<XPU_INDEX_TYPE_TY *>(y),
|
||||
encoder_seqs_lods.xpu,
|
||||
encoder_batch_map.xpu,
|
||||
decoder_batch_map.xpu,
|
||||
en_batch,
|
||||
de_batch,
|
||||
hidden_dim);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
template <typename TX, typename TY>
|
||||
int eb_adjust_batch(api::Context *ctx, const TX *x, TY *y,
|
||||
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
|
||||
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
|
||||
int eb_adjust_batch(api::Context *ctx,
|
||||
const TX *x,
|
||||
TY *y,
|
||||
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
|
||||
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim) {
|
||||
// int dev_id = -1;
|
||||
// xpu_current_device(&dev_id);
|
||||
@@ -101,8 +122,13 @@ int eb_adjust_batch(api::Context *ctx, const TX *x, TY *y,
|
||||
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T2(ctx, "eb_adjust_batch", TX, TY);
|
||||
WRAPPER_DUMP_PARAM6(ctx, x, y, encoder_seqs_lods, encoder_batch_map,
|
||||
decoder_batch_map, hidden_dim);
|
||||
WRAPPER_DUMP_PARAM6(ctx,
|
||||
x,
|
||||
y,
|
||||
encoder_seqs_lods,
|
||||
encoder_batch_map,
|
||||
decoder_batch_map,
|
||||
hidden_dim);
|
||||
WRAPPER_DUMP(ctx);
|
||||
int encoder_batch = encoder_batch_map.len;
|
||||
int total_batch = encoder_batch + decoder_batch_map.len;
|
||||
@@ -126,9 +152,14 @@ int eb_adjust_batch(api::Context *ctx, const TX *x, TY *y,
|
||||
WRAPPER_ASSERT_LT(ctx, decoder_batch_map.cpu[i], total_batch)
|
||||
}
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper<TX, TY>(ctx, x, y, encoder_seqs_lods.cpu,
|
||||
encoder_batch_map.cpu, decoder_batch_map.cpu,
|
||||
encoder_batch_map.len, decoder_batch_map.len,
|
||||
return cpu_wrapper<TX, TY>(ctx,
|
||||
x,
|
||||
y,
|
||||
encoder_seqs_lods.cpu,
|
||||
encoder_batch_map.cpu,
|
||||
decoder_batch_map.cpu,
|
||||
encoder_batch_map.len,
|
||||
decoder_batch_map.len,
|
||||
hidden_dim);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU3) {
|
||||
@@ -139,18 +170,27 @@ int eb_adjust_batch(api::Context *ctx, const TX *x, TY *y,
|
||||
encoder_batch_map.to_xpu(RAII_GUARD);
|
||||
api::VectorParam<int32_t> decoder_batch_map_xpu =
|
||||
decoder_batch_map.to_xpu(RAII_GUARD);
|
||||
return xpu3_wrapper<TX, TY>(ctx, x, y, encoder_seqs_lods_xpu,
|
||||
encoder_batch_map_xpu, decoder_batch_map_xpu,
|
||||
encoder_batch_map.len, decoder_batch_map.len,
|
||||
return xpu3_wrapper<TX, TY>(ctx,
|
||||
x,
|
||||
y,
|
||||
encoder_seqs_lods_xpu,
|
||||
encoder_batch_map_xpu,
|
||||
decoder_batch_map_xpu,
|
||||
encoder_batch_map.len,
|
||||
decoder_batch_map.len,
|
||||
hidden_dim);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
// INSTANTIATION_EB_ADJUST_BATCH(TX, TY): expands to an explicit instantiation of
// the host-side eb_adjust_batch<TX, TY> template with the signature
// (api::Context *, const TX *, TY *, three api::VectorParam<int32_t> &, int64_t).
// NOTE(review): this span comes from a formatting diff, so the macro is shown
// twice (old layout, then new layout); the expansion is the same in both.
#define INSTANTIATION_EB_ADJUST_BATCH(TX, TY) \
|
||||
template int eb_adjust_batch<TX, TY>( \
|
||||
api::Context *, const TX *, TY *, api::VectorParam<int32_t> &, \
|
||||
api::VectorParam<int32_t> &, api::VectorParam<int32_t> &, int64_t);
|
||||
#define INSTANTIATION_EB_ADJUST_BATCH(TX, TY) \
|
||||
template int eb_adjust_batch<TX, TY>(api::Context *, \
|
||||
const TX *, \
|
||||
TY *, \
|
||||
api::VectorParam<int32_t> &, \
|
||||
api::VectorParam<int32_t> &, \
|
||||
api::VectorParam<int32_t> &, \
|
||||
int64_t);
|
||||
|
||||
INSTANTIATION_EB_ADJUST_BATCH(float16, float16);
|
||||
INSTANTIATION_EB_ADJUST_BATCH(bfloat16, bfloat16);
|
||||
@@ -163,7 +203,7 @@ INSTANTIATION_EB_ADJUST_BATCH(bfloat16, float);
|
||||
INSTANTIATION_EB_ADJUST_BATCH(float, bfloat16);
|
||||
INSTANTIATION_EB_ADJUST_BATCH(int32_t, int32_t);
|
||||
INSTANTIATION_EB_ADJUST_BATCH(int64_t, int64_t);
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
|
||||
@@ -20,62 +20,92 @@
|
||||
// Declaration-only namespace block: forward-declares the device kernel template
// eb_gather_next_token<TX, TY> (no body here; the definition is not part of this
// span). The xpu3_wrapper below takes its address and launches it.
// NOTE(review): this span comes from a formatting diff, so the declaration and
// the two namespace closers are each listed twice (old layout, then new layout).
namespace xpu3 {
|
||||
namespace plugin {
|
||||
template <typename TX, typename TY>
|
||||
__attribute__((global)) void
|
||||
eb_gather_next_token(TX *src, TY *dst, int *encoder_seqs_lods,
|
||||
int *encoder_batch_map, int *decoder_batch_map,
|
||||
int en_batch, int de_batch, int64_t copy_size);
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
__attribute__((global)) void eb_gather_next_token(TX *src,
|
||||
TY *dst,
|
||||
int *encoder_seqs_lods,
|
||||
int *encoder_batch_map,
|
||||
int *decoder_batch_map,
|
||||
int en_batch,
|
||||
int de_batch,
|
||||
int64_t copy_size);
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
// cpu_wrapper (eb_gather_next_token): host-side path that copies one row of
// `hidden_dim` elements per batch entry from `x` to `y`, converting TX -> TY
// through api::cast.
//  - Encoder entries (i in [0, en_batch)): source row is
//    encoder_seqs_lods[i + 1] - 1, i.e. the row just before the next offset;
//    destination row is encoder_batch_map[i].
//  - Decoder entries (i in [0, de_batch)): source row is
//    encoder_len_total + i, where encoder_len_total = encoder_seqs_lods[en_batch];
//    destination row is decoder_batch_map[i].
// Returns api::SUCCESS after both loops. WRAPPER_ASSERT_SUCCESS is defined
// elsewhere; presumably it returns early when a cast fails -- confirm.
// NOTE(review): this span comes from a formatting diff, so the signature and
// each api::cast call appear twice (old layout, then new layout).
template <typename TX, typename TY>
|
||||
static int
|
||||
cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods,
|
||||
const int *encoder_batch_map, const int *decoder_batch_map,
|
||||
int en_batch, int de_batch, int64_t hidden_dim) {
|
||||
static int cpu_wrapper(api::Context *ctx,
|
||||
const TX *x,
|
||||
TY *y,
|
||||
const int *encoder_seqs_lods,
|
||||
const int *encoder_batch_map,
|
||||
const int *decoder_batch_map,
|
||||
int en_batch,
|
||||
int de_batch,
|
||||
int64_t hidden_dim) {
|
||||
int ret = 0;
|
||||
// Last offset in the lod array: decoder rows in `x` start at this row index.
int encoder_len_total = encoder_seqs_lods[en_batch];
|
||||
// Encoder entries: gather the final row of each sequence.
for (int i = 0; i < en_batch; i++) {
|
||||
ret =
|
||||
api::cast<TX, TY>(ctx, x + (encoder_seqs_lods[i + 1] - 1) * hidden_dim,
|
||||
y + encoder_batch_map[i] * hidden_dim, hidden_dim);
|
||||
ret = api::cast<TX, TY>(ctx,
|
||||
x + (encoder_seqs_lods[i + 1] - 1) * hidden_dim,
|
||||
y + encoder_batch_map[i] * hidden_dim,
|
||||
hidden_dim);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
}
|
||||
// Decoder entries: one row each, stored after all encoder rows in `x`.
for (int i = 0; i < de_batch; i++) {
|
||||
ret = api::cast<TX, TY>(ctx, x + (encoder_len_total + i) * hidden_dim,
|
||||
y + decoder_batch_map[i] * hidden_dim, hidden_dim);
|
||||
ret = api::cast<TX, TY>(ctx,
|
||||
x + (encoder_len_total + i) * hidden_dim,
|
||||
y + decoder_batch_map[i] * hidden_dim,
|
||||
hidden_dim);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
}
|
||||
|
||||
return api::SUCCESS;
|
||||
}
|
||||
// xpu3_wrapper (eb_gather_next_token): device path. Launches the
// xpu3::plugin::eb_gather_next_token<TX, TY> kernel on ctx->xpu_stream with
// ctx->ncluster() as the first launch argument and 16 as the second, passing the
// `.xpu` pointers of the three VectorParam arguments.
// `x` is const_cast because the kernel declaration takes a non-const TX *.
// Always returns api::SUCCESS; no launch status is checked here.
// NOTE(review): this span comes from a formatting diff, so the signature and the
// kernel argument list appear twice (old layout, then new layout).
template <typename TX, typename TY>
|
||||
static int xpu3_wrapper(api::Context *ctx, const TX *x, TY *y,
|
||||
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
|
||||
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
|
||||
int en_batch, int de_batch, int64_t hidden_dim) {
|
||||
static int xpu3_wrapper(api::Context *ctx,
|
||||
const TX *x,
|
||||
TY *y,
|
||||
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
|
||||
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
|
||||
int en_batch,
|
||||
int de_batch,
|
||||
int64_t hidden_dim) {
|
||||
auto eb_gather_next_token_kernel = xpu3::plugin::eb_gather_next_token<TX, TY>;
|
||||
// NOTE: Don't change 16 to 64, because kernel use gsm
|
||||
eb_gather_next_token_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
|
||||
const_cast<TX *>(x), y, encoder_seqs_lods.xpu, encoder_batch_map.xpu,
|
||||
decoder_batch_map.xpu, en_batch, de_batch, hidden_dim);
|
||||
const_cast<TX *>(x),
|
||||
y,
|
||||
encoder_seqs_lods.xpu,
|
||||
encoder_batch_map.xpu,
|
||||
decoder_batch_map.xpu,
|
||||
en_batch,
|
||||
de_batch,
|
||||
hidden_dim);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
template <typename TX, typename TY>
|
||||
int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y,
|
||||
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
|
||||
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim) {
|
||||
int eb_gather_next_token(
|
||||
api::Context *ctx,
|
||||
const TX *x,
|
||||
TY *y,
|
||||
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
|
||||
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T2(ctx, "eb_gather_next_token", TX, TY);
|
||||
WRAPPER_DUMP_PARAM6(ctx, x, y, encoder_seqs_lods, encoder_batch_map,
|
||||
decoder_batch_map, hidden_dim);
|
||||
WRAPPER_DUMP_PARAM6(ctx,
|
||||
x,
|
||||
y,
|
||||
encoder_seqs_lods,
|
||||
encoder_batch_map,
|
||||
decoder_batch_map,
|
||||
hidden_dim);
|
||||
WRAPPER_DUMP(ctx);
|
||||
int encoder_batch = encoder_batch_map.len;
|
||||
int batch = encoder_batch + decoder_batch_map.len;
|
||||
@@ -99,9 +129,14 @@ int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y,
|
||||
WRAPPER_ASSERT_GE(ctx, decoder_batch_map.cpu[i], 0);
|
||||
}
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper<TX, TY>(ctx, x, y, encoder_seqs_lods.cpu,
|
||||
encoder_batch_map.cpu, decoder_batch_map.cpu,
|
||||
encoder_batch_map.len, decoder_batch_map.len,
|
||||
return cpu_wrapper<TX, TY>(ctx,
|
||||
x,
|
||||
y,
|
||||
encoder_seqs_lods.cpu,
|
||||
encoder_batch_map.cpu,
|
||||
decoder_batch_map.cpu,
|
||||
encoder_batch_map.len,
|
||||
decoder_batch_map.len,
|
||||
hidden_dim);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU3) {
|
||||
@@ -112,17 +147,26 @@ int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y,
|
||||
encoder_batch_map.to_xpu(RAII_GUARD);
|
||||
api::VectorParam<int32_t> decoder_batch_map_xpu =
|
||||
decoder_batch_map.to_xpu(RAII_GUARD);
|
||||
return xpu3_wrapper<TX, TY>(ctx, x, y, encoder_seqs_lods_xpu,
|
||||
encoder_batch_map_xpu, decoder_batch_map_xpu,
|
||||
encoder_batch_map.len, decoder_batch_map.len,
|
||||
return xpu3_wrapper<TX, TY>(ctx,
|
||||
x,
|
||||
y,
|
||||
encoder_seqs_lods_xpu,
|
||||
encoder_batch_map_xpu,
|
||||
decoder_batch_map_xpu,
|
||||
encoder_batch_map.len,
|
||||
decoder_batch_map.len,
|
||||
hidden_dim);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
// INSTANTIATION_EB_GATHER_NEXT_TOKEN(TX, TY): expands to an explicit
// instantiation of the host-side eb_gather_next_token<TX, TY> template with the
// signature (api::Context *, const TX *, TY *, three api::VectorParam<int32_t> &,
// int64_t).
// NOTE(review): this span comes from a formatting diff, so the macro is shown
// twice (old layout, then new layout); the expansion is the same in both.
#define INSTANTIATION_EB_GATHER_NEXT_TOKEN(TX, TY) \
|
||||
template int eb_gather_next_token<TX, TY>( \
|
||||
api::Context *, const TX *, TY *, api::VectorParam<int32_t> &, \
|
||||
api::VectorParam<int32_t> &, api::VectorParam<int32_t> &, int64_t);
|
||||
#define INSTANTIATION_EB_GATHER_NEXT_TOKEN(TX, TY) \
|
||||
template int eb_gather_next_token<TX, TY>(api::Context *, \
|
||||
const TX *, \
|
||||
TY *, \
|
||||
api::VectorParam<int32_t> &, \
|
||||
api::VectorParam<int32_t> &, \
|
||||
api::VectorParam<int32_t> &, \
|
||||
int64_t);
|
||||
|
||||
INSTANTIATION_EB_GATHER_NEXT_TOKEN(float16, float16);
|
||||
INSTANTIATION_EB_GATHER_NEXT_TOKEN(bfloat16, bfloat16);
|
||||
@@ -133,7 +177,7 @@ INSTANTIATION_EB_GATHER_NEXT_TOKEN(bfloat16, float16);
|
||||
INSTANTIATION_EB_GATHER_NEXT_TOKEN(float16, bfloat16);
|
||||
INSTANTIATION_EB_GATHER_NEXT_TOKEN(bfloat16, float);
|
||||
INSTANTIATION_EB_GATHER_NEXT_TOKEN(float, bfloat16);
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
|
||||
@@ -12,211 +12,304 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
|
||||
// Declaration-only namespace block: forward-declares the device kernel
// free_and_dispatch_block (no body here; the definition is not part of this
// span). The parameter list mirrors the host cpu_wrapper/xpu3_wrapper below,
// minus the leading Context *.
// NOTE(review): this span comes from a formatting diff, so the parameter list
// and the two namespace closers are each listed twice (old layout, then new
// layout), sharing the first and last declaration lines.
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
__attribute__((global)) void free_and_dispatch_block(
|
||||
bool *stop_flags, int *seq_lens_this_time, int *seq_lens_decoder,
|
||||
int *block_tables, int *encoder_block_lens, bool *is_block_step,
|
||||
int *step_block_list, // [bsz]
|
||||
int *step_len, int *recover_block_list, int *recover_len,
|
||||
int *need_block_list, int *need_block_len, int *used_list_len,
|
||||
int *free_list, int *free_list_len, int64_t *first_token_ids, const int bsz,
|
||||
const int block_size, const int block_num_per_seq,
|
||||
bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_decoder,
|
||||
int *block_tables,
|
||||
int *encoder_block_lens,
|
||||
bool *is_block_step,
|
||||
int *step_block_list, // [bsz]
|
||||
int *step_len,
|
||||
int *recover_block_list,
|
||||
int *recover_len,
|
||||
int *need_block_list,
|
||||
int *need_block_len,
|
||||
int *used_list_len,
|
||||
int *free_list,
|
||||
int *free_list_len,
|
||||
int64_t *first_token_ids,
|
||||
const int bsz,
|
||||
const int block_size,
|
||||
const int block_num_per_seq,
|
||||
const int max_decoder_block_num);
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
|
||||
// cpu_wrapper (free_and_dispatch_block): host-side implementation of the
// block-table bookkeeping. All counters (free_list_len, need_block_len,
// step_len, recover_len) are read and written through element [0].
// Steps, in order:
//  1. For every batch slot i: if it is stopped and not a "block step" slot, move
//     its used_list_len[i] table entries (starting at encoder_block_lens[i]) to
//     the end of free_list, mark them -1 and zero both lengths. Otherwise, if
//     the table entry at seq_lens_decoder[i] / block_size is -1, append i to
//     need_block_list.
//  2. While more blocks are needed than are free: pick the non-block-step slot
//     with the largest used_list_len, release its blocks to free_list, push it
//     on step_block_list, set stop_flags/is_block_step and zero its seq lens.
//  3. Give each slot in need_block_list that is still running one block taken
//     from the end of free_list, then clear the list entry.
//  4. Starting from the end of step_block_list, move slots to recover_block_list
//     while enough free blocks remain for them.
// Returns api::SUCCESS. `ctx` and `first_token_ids` are not referenced in the
// visible body.
// NOTE(review): this span comes from a formatting diff with old and new lines
// interleaved and indentation stripped, so several statements appear twice and
// the nesting shown here is not always the real nesting.
// NOTE(review): in step 2, if no slot holds any blocks, max_used_list_len stays
// 0 and free_list_len[0] does not grow -- looks like the loop would not
// terminate; presumably callers guarantee enough blocks. Confirm.
static int cpu_wrapper(Context *ctx, bool *stop_flags, int *seq_lens_this_time,
|
||||
int *seq_lens_decoder, int *block_tables,
|
||||
int *encoder_block_lens, bool *is_block_step,
|
||||
int *step_block_list, // [bsz]
|
||||
int *step_len, int *recover_block_list, int *recover_len,
|
||||
int *need_block_list, int *need_block_len,
|
||||
int *used_list_len, int *free_list, int *free_list_len,
|
||||
int64_t *first_token_ids, const int bsz,
|
||||
const int block_size, const int block_num_per_seq,
|
||||
static int cpu_wrapper(Context *ctx,
|
||||
bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_decoder,
|
||||
int *block_tables,
|
||||
int *encoder_block_lens,
|
||||
bool *is_block_step,
|
||||
int *step_block_list, // [bsz]
|
||||
int *step_len,
|
||||
int *recover_block_list,
|
||||
int *recover_len,
|
||||
int *need_block_list,
|
||||
int *need_block_len,
|
||||
int *used_list_len,
|
||||
int *free_list,
|
||||
int *free_list_len,
|
||||
int64_t *first_token_ids,
|
||||
const int bsz,
|
||||
const int block_size,
|
||||
const int block_num_per_seq,
|
||||
const int max_decoder_block_num) {
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
int *block_table_now = block_tables + i * block_num_per_seq;
|
||||
if (stop_flags[i] && !is_block_step[i]) {
|
||||
// Reclaim (free) this slot's blocks
|
||||
const int encoder_block_len = encoder_block_lens[i];
|
||||
const int decoder_used_len = used_list_len[i];
|
||||
if (decoder_used_len > 0) {
|
||||
const int ori_free_list_len = free_list_len[0];
|
||||
free_list_len[0] += decoder_used_len;
|
||||
for (int j = 0; j < decoder_used_len; j++) {
|
||||
free_list[ori_free_list_len + j] =
|
||||
block_table_now[encoder_block_len + j];
|
||||
block_table_now[encoder_block_len + j] = -1;
|
||||
}
|
||||
encoder_block_lens[i] = 0;
|
||||
used_list_len[i] = 0;
|
||||
}
|
||||
} else if (block_table_now[seq_lens_decoder[i] / block_size] == -1) {
|
||||
// Record the positions that need a block allocated, and count them
|
||||
const int ori_need_block_len = need_block_len[0];
|
||||
need_block_len[0] += 1;
|
||||
need_block_list[ori_need_block_len] = i;
|
||||
}
|
||||
}
|
||||
|
||||
while (need_block_len[0] > free_list_len[0]) {
|
||||
// Schedule blocks: reclaim blocks in descending order of used_list_len until need_block_len is satisfied
|
||||
int max_used_list_len_id = 0;
|
||||
int max_used_list_len = 0;
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
int *block_table_now = block_tables + i * block_num_per_seq;
|
||||
if (stop_flags[i] && !is_block_step[i]) {
|
||||
// Reclaim (free) this slot's blocks
|
||||
const int encoder_block_len = encoder_block_lens[i];
|
||||
const int decoder_used_len = used_list_len[i];
|
||||
if (decoder_used_len > 0) {
|
||||
const int ori_free_list_len = free_list_len[0];
|
||||
free_list_len[0] += decoder_used_len;
|
||||
for (int j = 0; j < decoder_used_len; j++) {
|
||||
free_list[ori_free_list_len + j] =
|
||||
block_table_now[encoder_block_len + j];
|
||||
block_table_now[encoder_block_len + j] = -1;
|
||||
}
|
||||
encoder_block_lens[i] = 0;
|
||||
used_list_len[i] = 0;
|
||||
}
|
||||
} else if (block_table_now[seq_lens_decoder[i] / block_size] == -1) {
|
||||
// Record the positions that need a block allocated, and count them
|
||||
const int ori_need_block_len = need_block_len[0];
|
||||
need_block_len[0] += 1;
|
||||
need_block_list[ori_need_block_len] = i;
|
||||
}
|
||||
const int used_block_num = !is_block_step[i] ? used_list_len[i] : 0;
|
||||
if (used_block_num > max_used_list_len) {
|
||||
max_used_list_len_id = i;
|
||||
max_used_list_len = used_block_num;
|
||||
}
|
||||
}
|
||||
|
||||
while (need_block_len[0] > free_list_len[0]) {
|
||||
// Schedule blocks: reclaim blocks in descending order of used_list_len until need_block_len is satisfied
|
||||
int max_used_list_len_id = 0;
|
||||
int max_used_list_len = 0;
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
const int used_block_num = !is_block_step[i] ? used_list_len[i] : 0;
|
||||
if (used_block_num > max_used_list_len) {
|
||||
max_used_list_len_id = i;
|
||||
max_used_list_len = used_block_num;
|
||||
}
|
||||
}
|
||||
|
||||
const int encoder_block_len = encoder_block_lens[max_used_list_len_id];
|
||||
int *block_table_now =
|
||||
block_tables + max_used_list_len_id * block_num_per_seq;
|
||||
for (int i = 0; i < max_used_list_len; i++) {
|
||||
free_list[free_list_len[0] + i] =
|
||||
block_table_now[encoder_block_len + i];
|
||||
block_table_now[encoder_block_len + i] = -1;
|
||||
}
|
||||
step_block_list[step_len[0]] = max_used_list_len_id;
|
||||
step_len[0] += 1;
|
||||
free_list_len[0] += max_used_list_len;
|
||||
stop_flags[max_used_list_len_id] = true;
|
||||
is_block_step[max_used_list_len_id] = true;
|
||||
seq_lens_this_time[max_used_list_len_id] = 0;
|
||||
seq_lens_decoder[max_used_list_len_id] = 0;
|
||||
const int encoder_block_len = encoder_block_lens[max_used_list_len_id];
|
||||
int *block_table_now =
|
||||
block_tables + max_used_list_len_id * block_num_per_seq;
|
||||
for (int i = 0; i < max_used_list_len; i++) {
|
||||
free_list[free_list_len[0] + i] = block_table_now[encoder_block_len + i];
|
||||
block_table_now[encoder_block_len + i] = -1;
|
||||
}
|
||||
step_block_list[step_len[0]] = max_used_list_len_id;
|
||||
step_len[0] += 1;
|
||||
free_list_len[0] += max_used_list_len;
|
||||
stop_flags[max_used_list_len_id] = true;
|
||||
is_block_step[max_used_list_len_id] = true;
|
||||
seq_lens_this_time[max_used_list_len_id] = 0;
|
||||
seq_lens_decoder[max_used_list_len_id] = 0;
|
||||
}
|
||||
|
||||
// Allocate blocks to the positions that need one; one block per position
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
if (i < need_block_len[0]) {
|
||||
const int need_block_id = need_block_list[i];
|
||||
if (!stop_flags[need_block_id]) {
|
||||
// If the position that needs a block is exactly one released in the previous step, nothing is done for it
|
||||
used_list_len[need_block_id] += 1;
|
||||
const int ori_free_list_len = free_list_len[0];
|
||||
free_list_len[0]--;
|
||||
int *block_table_now =
|
||||
block_tables + need_block_id * block_num_per_seq;
|
||||
block_table_now[seq_lens_decoder[need_block_id] / block_size] =
|
||||
free_list[ori_free_list_len - 1];
|
||||
}
|
||||
need_block_list[i] = -1;
|
||||
}
|
||||
// Allocate blocks to the positions that need one; one block per position
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
if (i < need_block_len[0]) {
|
||||
const int need_block_id = need_block_list[i];
|
||||
if (!stop_flags[need_block_id]) {
|
||||
// If the position that needs a block is exactly one released in the previous step, nothing is done for it
|
||||
used_list_len[need_block_id] += 1;
|
||||
const int ori_free_list_len = free_list_len[0];
|
||||
free_list_len[0]--;
|
||||
int *block_table_now = block_tables + need_block_id * block_num_per_seq;
|
||||
block_table_now[seq_lens_decoder[need_block_id] / block_size] =
|
||||
free_list[ori_free_list_len - 1];
|
||||
}
|
||||
need_block_list[i] = -1;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the query ids that can be recovered
|
||||
int ori_step_len = step_len[0];
|
||||
if (ori_step_len > 0) {
|
||||
int ori_free_list_len = free_list_len[0];
|
||||
int ori_step_block_id = step_block_list[ori_step_len - 1];
|
||||
int tmp_used_len = used_list_len[ori_step_block_id];
|
||||
// Allocate one more block than at scheduling time, to avoid immediately recovering a query that was just scheduled out (e.g. when the reclaimed seq_id is in need_block_list)
|
||||
int used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1
|
||||
: tmp_used_len;
|
||||
while (ori_step_len > 0 && ori_free_list_len >= used_len) {
|
||||
recover_block_list[recover_len[0]] = ori_step_block_id;
|
||||
is_block_step[ori_step_block_id] = false;
|
||||
used_list_len[ori_step_block_id] = used_len;
|
||||
ori_free_list_len -= used_len;
|
||||
step_block_list[ori_step_len - 1] = -1;
|
||||
step_len[0] -= 1;
|
||||
recover_len[0] += 1;
|
||||
ori_step_len = step_len[0];
|
||||
if (ori_step_len > 0) {
|
||||
ori_step_block_id = step_block_list[ori_step_len - 1];
|
||||
tmp_used_len = used_list_len[ori_step_block_id];
|
||||
used_len = tmp_used_len < max_decoder_block_num
|
||||
? tmp_used_len + 1
|
||||
: tmp_used_len;
|
||||
}
|
||||
}
|
||||
need_block_len[0] = 0;
|
||||
// Compute the query ids that can be recovered
|
||||
int ori_step_len = step_len[0];
|
||||
if (ori_step_len > 0) {
|
||||
int ori_free_list_len = free_list_len[0];
|
||||
int ori_step_block_id = step_block_list[ori_step_len - 1];
|
||||
int tmp_used_len = used_list_len[ori_step_block_id];
|
||||
// Allocate one more block than at scheduling time, to avoid immediately recovering a query that was just scheduled out (e.g. when the reclaimed seq_id is in need_block_list)
|
||||
int used_len =
|
||||
tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len;
|
||||
while (ori_step_len > 0 && ori_free_list_len >= used_len) {
|
||||
recover_block_list[recover_len[0]] = ori_step_block_id;
|
||||
is_block_step[ori_step_block_id] = false;
|
||||
used_list_len[ori_step_block_id] = used_len;
|
||||
ori_free_list_len -= used_len;
|
||||
step_block_list[ori_step_len - 1] = -1;
|
||||
step_len[0] -= 1;
|
||||
recover_len[0] += 1;
|
||||
ori_step_len = step_len[0];
|
||||
if (ori_step_len > 0) {
|
||||
ori_step_block_id = step_block_list[ori_step_len - 1];
|
||||
tmp_used_len = used_list_len[ori_step_block_id];
|
||||
used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1
|
||||
: tmp_used_len;
|
||||
}
|
||||
}
|
||||
return api::SUCCESS;
|
||||
// NOTE(review): by brace position this reset sits inside the
// `if (ori_step_len > 0)` branch, so need_block_len[0] looks like it is only
// cleared when at least one slot was stepped out -- confirm this is intended.
need_block_len[0] = 0;
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
// xpu3_wrapper (free_and_dispatch_block): device path. Launches the
// xpu3::plugin::free_and_dispatch_block kernel on ctx->xpu_stream with
// ctx->ncluster() as the first launch argument and 64 as the second, forwarding
// every argument unchanged except `first_token_ids`, which is
// reinterpret_cast to XPUIndexType<int64_t>::type *.
// Always returns api::SUCCESS; no launch status is checked here.
// NOTE(review): this span comes from a formatting diff, so the signature and the
// whole body appear twice (old layout, then new layout) before the single
// closing brace.
static int xpu3_wrapper(Context *ctx, bool *stop_flags, int *seq_lens_this_time,
|
||||
int *seq_lens_decoder, int *block_tables,
|
||||
int *encoder_block_lens, bool *is_block_step,
|
||||
int *step_block_list, // [bsz]
|
||||
int *step_len, int *recover_block_list,
|
||||
int *recover_len, int *need_block_list,
|
||||
int *need_block_len, int *used_list_len, int *free_list,
|
||||
int *free_list_len, int64_t *first_token_ids,
|
||||
const int bsz, const int block_size,
|
||||
static int xpu3_wrapper(Context *ctx,
|
||||
bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_decoder,
|
||||
int *block_tables,
|
||||
int *encoder_block_lens,
|
||||
bool *is_block_step,
|
||||
int *step_block_list, // [bsz]
|
||||
int *step_len,
|
||||
int *recover_block_list,
|
||||
int *recover_len,
|
||||
int *need_block_list,
|
||||
int *need_block_len,
|
||||
int *used_list_len,
|
||||
int *free_list,
|
||||
int *free_list_len,
|
||||
int64_t *first_token_ids,
|
||||
const int bsz,
|
||||
const int block_size,
|
||||
const int block_num_per_seq,
|
||||
const int max_decoder_block_num) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto free_and_dispatch_block_kernel = xpu3::plugin::free_and_dispatch_block;
|
||||
free_and_dispatch_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags, seq_lens_this_time, seq_lens_decoder, block_tables,
|
||||
encoder_block_lens, is_block_step, step_block_list, step_len,
|
||||
recover_block_list, recover_len, need_block_list, need_block_len,
|
||||
used_list_len, free_list, free_list_len,
|
||||
reinterpret_cast<XPU_INT64 *>(first_token_ids), bsz, block_size,
|
||||
block_num_per_seq, max_decoder_block_num);
|
||||
return api::SUCCESS;
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto free_and_dispatch_block_kernel = xpu3::plugin::free_and_dispatch_block;
|
||||
free_and_dispatch_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
encoder_block_lens,
|
||||
is_block_step,
|
||||
step_block_list,
|
||||
step_len,
|
||||
recover_block_list,
|
||||
recover_len,
|
||||
need_block_list,
|
||||
need_block_len,
|
||||
used_list_len,
|
||||
free_list,
|
||||
free_list_len,
|
||||
reinterpret_cast<XPU_INT64 *>(first_token_ids),
|
||||
bsz,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
max_decoder_block_num);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
int free_and_dispatch_block(Context *ctx, bool *stop_flags,
|
||||
int *seq_lens_this_time, int *seq_lens_decoder,
|
||||
int *block_tables, int *encoder_block_lens,
|
||||
int free_and_dispatch_block(Context *ctx,
|
||||
bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_decoder,
|
||||
int *block_tables,
|
||||
int *encoder_block_lens,
|
||||
bool *is_block_step,
|
||||
int *step_block_list, // [bsz]
|
||||
int *step_len, int *recover_block_list,
|
||||
int *recover_len, int *need_block_list,
|
||||
int *need_block_len, int *used_list_len,
|
||||
int *free_list, int *free_list_len,
|
||||
int64_t *first_token_ids, const int bsz,
|
||||
const int block_size, const int block_num_per_seq,
|
||||
int *step_block_list, // [bsz]
|
||||
int *step_len,
|
||||
int *recover_block_list,
|
||||
int *recover_len,
|
||||
int *need_block_list,
|
||||
int *need_block_len,
|
||||
int *used_list_len,
|
||||
int *free_list,
|
||||
int *free_list_len,
|
||||
int64_t *first_token_ids,
|
||||
const int bsz,
|
||||
const int block_size,
|
||||
const int block_num_per_seq,
|
||||
const int max_decoder_block_num) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "free_and_dispatch_block", float);
|
||||
WRAPPER_DUMP_PARAM6(ctx, stop_flags, seq_lens_this_time, seq_lens_decoder,
|
||||
block_tables, encoder_block_lens, is_block_step);
|
||||
WRAPPER_DUMP_PARAM6(ctx, step_block_list, step_len, recover_block_list,
|
||||
recover_len, need_block_list, need_block_len);
|
||||
WRAPPER_DUMP_PARAM4(ctx, used_list_len, free_list, free_list_len,
|
||||
first_token_ids);
|
||||
WRAPPER_DUMP_PARAM4(ctx, bsz, block_size, block_num_per_seq,
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "free_and_dispatch_block", float);
|
||||
WRAPPER_DUMP_PARAM6(ctx,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
encoder_block_lens,
|
||||
is_block_step);
|
||||
WRAPPER_DUMP_PARAM6(ctx,
|
||||
step_block_list,
|
||||
step_len,
|
||||
recover_block_list,
|
||||
recover_len,
|
||||
need_block_list,
|
||||
need_block_len);
|
||||
WRAPPER_DUMP_PARAM4(
|
||||
ctx, used_list_len, free_list, free_list_len, first_token_ids);
|
||||
WRAPPER_DUMP_PARAM4(
|
||||
ctx, bsz, block_size, block_num_per_seq, max_decoder_block_num);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper(ctx,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
encoder_block_lens,
|
||||
is_block_step,
|
||||
step_block_list,
|
||||
step_len,
|
||||
recover_block_list,
|
||||
recover_len,
|
||||
need_block_list,
|
||||
need_block_len,
|
||||
used_list_len,
|
||||
free_list,
|
||||
free_list_len,
|
||||
first_token_ids,
|
||||
bsz,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
max_decoder_block_num);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
encoder_block_lens,
|
||||
is_block_step,
|
||||
step_block_list,
|
||||
step_len,
|
||||
recover_block_list,
|
||||
recover_len,
|
||||
need_block_list,
|
||||
need_block_len,
|
||||
used_list_len,
|
||||
free_list,
|
||||
free_list_len,
|
||||
first_token_ids,
|
||||
bsz,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
max_decoder_block_num);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper(
|
||||
ctx, stop_flags, seq_lens_this_time, seq_lens_decoder, block_tables,
|
||||
encoder_block_lens, is_block_step, step_block_list, step_len,
|
||||
recover_block_list, recover_len, need_block_list, need_block_len,
|
||||
used_list_len, free_list, free_list_len, first_token_ids, bsz,
|
||||
block_size, block_num_per_seq, max_decoder_block_num);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(
|
||||
ctx, stop_flags, seq_lens_this_time, seq_lens_decoder, block_tables,
|
||||
encoder_block_lens, is_block_step, step_block_list, step_len,
|
||||
recover_block_list, recover_len, need_block_list, need_block_len,
|
||||
used_list_len, free_list, free_list_len, first_token_ids, bsz,
|
||||
block_size, block_num_per_seq, max_decoder_block_num);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
|
||||
@@ -12,120 +12,178 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
__attribute__((global)) void
|
||||
get_padding_offset(int *padding_offset, int *cum_offsets_out, int *cu_seqlens_q,
|
||||
int *cu_seqlens_k, const int *cum_offsets,
|
||||
const int *seq_lens, const int max_seq_len, const int bs);
|
||||
__attribute__((global)) void
|
||||
remove_padding(int64_t *x_remove_padding, const int64_t *input_data,
|
||||
const int *seq_lens, const int *cum_offsets,
|
||||
const int sequence_length, const int bs);
|
||||
__attribute__((global)) void get_padding_offset(int *padding_offset,
|
||||
int *cum_offsets_out,
|
||||
int *cu_seqlens_q,
|
||||
int *cu_seqlens_k,
|
||||
const int *cum_offsets,
|
||||
const int *seq_lens,
|
||||
const int max_seq_len,
|
||||
const int bs);
|
||||
__attribute__((global)) void remove_padding(int64_t *x_remove_padding,
|
||||
const int64_t *input_data,
|
||||
const int *seq_lens,
|
||||
const int *cum_offsets,
|
||||
const int sequence_length,
|
||||
const int bs);
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
|
||||
static int get_padding_offset_cpu(int *padding_offset, int *cum_offsets_out,
|
||||
int *cu_seqlens_q, int *cu_seqlens_k,
|
||||
const int *cum_offsets, const int *seq_lens,
|
||||
const int max_seq_len, const int bs) {
|
||||
for (int i = 0; i < bs; i++) {
|
||||
int cum_offset = i == 0 ? 0 : cum_offsets[i - 1];
|
||||
for (int j = 0; j < seq_lens[i]; j++) {
|
||||
padding_offset[i * max_seq_len - cum_offset + j] = cum_offset;
|
||||
}
|
||||
cum_offsets_out[i] = cum_offset;
|
||||
int cum_seq_len = (i + 1) * max_seq_len - cum_offsets[i];
|
||||
cu_seqlens_q[i + 1] = cum_seq_len;
|
||||
cu_seqlens_k[i + 1] = cum_seq_len;
|
||||
static int get_padding_offset_cpu(int *padding_offset,
|
||||
int *cum_offsets_out,
|
||||
int *cu_seqlens_q,
|
||||
int *cu_seqlens_k,
|
||||
const int *cum_offsets,
|
||||
const int *seq_lens,
|
||||
const int max_seq_len,
|
||||
const int bs) {
|
||||
for (int i = 0; i < bs; i++) {
|
||||
int cum_offset = i == 0 ? 0 : cum_offsets[i - 1];
|
||||
for (int j = 0; j < seq_lens[i]; j++) {
|
||||
padding_offset[i * max_seq_len - cum_offset + j] = cum_offset;
|
||||
}
|
||||
return api::SUCCESS;
|
||||
cum_offsets_out[i] = cum_offset;
|
||||
int cum_seq_len = (i + 1) * max_seq_len - cum_offsets[i];
|
||||
cu_seqlens_q[i + 1] = cum_seq_len;
|
||||
cu_seqlens_k[i + 1] = cum_seq_len;
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
static int remove_padding_cpu(int64_t *x_remove_padding,
|
||||
const int64_t *input_data, const int *seq_lens,
|
||||
const int *cum_offsets, const int sequence_length,
|
||||
const int64_t *input_data,
|
||||
const int *seq_lens,
|
||||
const int *cum_offsets,
|
||||
const int sequence_length,
|
||||
const int bs) {
|
||||
for (int i = 0; i < bs; i++) {
|
||||
for (int j = 0; j < seq_lens[i]; j++) {
|
||||
const int tgt_seq_id = i * sequence_length - cum_offsets[i] + j;
|
||||
const int src_seq_id = i * sequence_length + j;
|
||||
x_remove_padding[tgt_seq_id] = input_data[src_seq_id];
|
||||
}
|
||||
for (int i = 0; i < bs; i++) {
|
||||
for (int j = 0; j < seq_lens[i]; j++) {
|
||||
const int tgt_seq_id = i * sequence_length - cum_offsets[i] + j;
|
||||
const int src_seq_id = i * sequence_length + j;
|
||||
x_remove_padding[tgt_seq_id] = input_data[src_seq_id];
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
static int cpu_wrapper(Context *ctx, int *padding_offset, int *cum_offsets_out,
|
||||
int *cu_seqlens_q, int *cu_seqlens_k,
|
||||
int64_t *x_remove_padding, const int64_t *input_ids,
|
||||
const int *cum_offsets, const int *seq_lens,
|
||||
const int max_seq_len, const int bs) {
|
||||
get_padding_offset_cpu(padding_offset, cum_offsets_out, cu_seqlens_q,
|
||||
cu_seqlens_k, cum_offsets, seq_lens, max_seq_len,
|
||||
bs);
|
||||
remove_padding_cpu(x_remove_padding, input_ids, seq_lens, cum_offsets_out,
|
||||
max_seq_len, bs);
|
||||
return api::SUCCESS;
|
||||
static int cpu_wrapper(Context *ctx,
|
||||
int *padding_offset,
|
||||
int *cum_offsets_out,
|
||||
int *cu_seqlens_q,
|
||||
int *cu_seqlens_k,
|
||||
int64_t *x_remove_padding,
|
||||
const int64_t *input_ids,
|
||||
const int *cum_offsets,
|
||||
const int *seq_lens,
|
||||
const int max_seq_len,
|
||||
const int bs) {
|
||||
get_padding_offset_cpu(padding_offset,
|
||||
cum_offsets_out,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
max_seq_len,
|
||||
bs);
|
||||
remove_padding_cpu(
|
||||
x_remove_padding, input_ids, seq_lens, cum_offsets_out, max_seq_len, bs);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
static int xpu3_wrapper(Context *ctx, int *padding_offset, int *cum_offsets_out,
|
||||
int *cu_seqlens_q, int *cu_seqlens_k,
|
||||
int64_t *x_remove_padding, const int64_t *input_ids,
|
||||
const int *cum_offsets, const int *seq_lens,
|
||||
const int max_seq_len, const int bs) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto get_padding_offset = xpu3::plugin::get_padding_offset;
|
||||
auto remove_padding = xpu3::plugin::remove_padding;
|
||||
get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k,
|
||||
cum_offsets, seq_lens, max_seq_len, bs);
|
||||
remove_padding<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INT64 *>(x_remove_padding),
|
||||
reinterpret_cast<const XPU_INT64 *>(input_ids), seq_lens,
|
||||
cum_offsets_out, max_seq_len, bs);
|
||||
return api::SUCCESS;
|
||||
static int xpu3_wrapper(Context *ctx,
|
||||
int *padding_offset,
|
||||
int *cum_offsets_out,
|
||||
int *cu_seqlens_q,
|
||||
int *cu_seqlens_k,
|
||||
int64_t *x_remove_padding,
|
||||
const int64_t *input_ids,
|
||||
const int *cum_offsets,
|
||||
const int *seq_lens,
|
||||
const int max_seq_len,
|
||||
const int bs) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto get_padding_offset = xpu3::plugin::get_padding_offset;
|
||||
auto remove_padding = xpu3::plugin::remove_padding;
|
||||
get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(padding_offset,
|
||||
cum_offsets_out,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
max_seq_len,
|
||||
bs);
|
||||
remove_padding<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INT64 *>(x_remove_padding),
|
||||
reinterpret_cast<const XPU_INT64 *>(input_ids),
|
||||
seq_lens,
|
||||
cum_offsets_out,
|
||||
max_seq_len,
|
||||
bs);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
int get_padding_offset(Context *ctx, int *padding_offset, int *cum_offsets_out,
|
||||
int *cu_seqlens_q, int *cu_seqlens_k,
|
||||
int64_t *x_remove_padding, const int64_t *input_ids,
|
||||
const int *cum_offsets, const int *seq_lens,
|
||||
const int max_seq_len, const int bs) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "get_padding_offset", int);
|
||||
WRAPPER_DUMP_PARAM4(ctx, padding_offset, cum_offsets_out, cu_seqlens_q,
|
||||
cu_seqlens_k);
|
||||
WRAPPER_DUMP_PARAM4(ctx, x_remove_padding, input_ids, cum_offsets,
|
||||
seq_lens);
|
||||
WRAPPER_DUMP_PARAM2(ctx, max_seq_len, bs);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper(ctx, padding_offset, cum_offsets_out, cu_seqlens_q,
|
||||
cu_seqlens_k, x_remove_padding, input_ids,
|
||||
cum_offsets, seq_lens, max_seq_len, bs);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx, padding_offset, cum_offsets_out, cu_seqlens_q,
|
||||
cu_seqlens_k, x_remove_padding, input_ids,
|
||||
cum_offsets, seq_lens, max_seq_len, bs);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
int get_padding_offset(Context *ctx,
|
||||
int *padding_offset,
|
||||
int *cum_offsets_out,
|
||||
int *cu_seqlens_q,
|
||||
int *cu_seqlens_k,
|
||||
int64_t *x_remove_padding,
|
||||
const int64_t *input_ids,
|
||||
const int *cum_offsets,
|
||||
const int *seq_lens,
|
||||
const int max_seq_len,
|
||||
const int bs) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "get_padding_offset", int);
|
||||
WRAPPER_DUMP_PARAM4(
|
||||
ctx, padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k);
|
||||
WRAPPER_DUMP_PARAM4(ctx, x_remove_padding, input_ids, cum_offsets, seq_lens);
|
||||
WRAPPER_DUMP_PARAM2(ctx, max_seq_len, bs);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper(ctx,
|
||||
padding_offset,
|
||||
cum_offsets_out,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
x_remove_padding,
|
||||
input_ids,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
max_seq_len,
|
||||
bs);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx,
|
||||
padding_offset,
|
||||
cum_offsets_out,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
x_remove_padding,
|
||||
input_ids,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
max_seq_len,
|
||||
bs);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
|
||||
@@ -76,19 +76,20 @@ static int cpu_wrapper(
|
||||
}
|
||||
|
||||
static int xpu3_wrapper(Context* ctx,
|
||||
const int64_t* base_model_draft_tokens,
|
||||
int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const bool* base_model_stop_flags,
|
||||
int bsz,
|
||||
int base_model_draft_token_len) {
|
||||
xpu3::plugin::draft_model_postprocess<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<const xpu3::int64_t*>(base_model_draft_tokens),
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_stop_flags,
|
||||
bsz,
|
||||
base_model_draft_token_len);
|
||||
const int64_t* base_model_draft_tokens,
|
||||
int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const bool* base_model_stop_flags,
|
||||
int bsz,
|
||||
int base_model_draft_token_len) {
|
||||
xpu3::plugin::
|
||||
draft_model_postprocess<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<const xpu3::int64_t*>(base_model_draft_tokens),
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_stop_flags,
|
||||
bsz,
|
||||
base_model_draft_token_len);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
@@ -124,12 +125,12 @@ int draft_model_postprocess(Context* ctx,
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx,
|
||||
base_model_draft_tokens,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_stop_flags,
|
||||
bsz,
|
||||
base_model_draft_token_len);
|
||||
base_model_draft_tokens,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_stop_flags,
|
||||
bsz,
|
||||
base_model_draft_token_len);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
@@ -21,13 +21,18 @@
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
template <typename T>
|
||||
__attribute__((global)) void
|
||||
set_stop_value_multi_ends(bool *stop_flags, T *topk_ids, T *next_tokens,
|
||||
const T *end_ids, const int *seq_lens, const int bs,
|
||||
const int end_length, const bool beam_search,
|
||||
const bool prefill_one_step_stop);
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
__attribute__((global)) void set_stop_value_multi_ends(
|
||||
bool *stop_flags,
|
||||
T *topk_ids,
|
||||
T *next_tokens,
|
||||
const T *end_ids,
|
||||
const int *seq_lens,
|
||||
const int bs,
|
||||
const int end_length,
|
||||
const bool beam_search,
|
||||
const bool prefill_one_step_stop);
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
@@ -36,104 +41,143 @@ namespace plugin {
|
||||
|
||||
template <typename T>
|
||||
__inline__ bool is_in_end(const T id, const T *end_ids, int length) {
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (id == end_ids[i]) {
|
||||
return true;
|
||||
}
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (id == end_ids[i]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static int cpu_wrapper(Context *ctx, bool *stop_flags, T *topk_ids,
|
||||
T *next_tokens, const T *end_ids, const int *seq_lens,
|
||||
const int bs, const int end_length,
|
||||
static int cpu_wrapper(Context *ctx,
|
||||
bool *stop_flags,
|
||||
T *topk_ids,
|
||||
T *next_tokens,
|
||||
const T *end_ids,
|
||||
const int *seq_lens,
|
||||
const int bs,
|
||||
const int end_length,
|
||||
const bool beam_search,
|
||||
const bool prefill_one_step_stop) {
|
||||
for (int i = 0; i < bs; i++) {
|
||||
if (prefill_one_step_stop) {
|
||||
stop_flags[i] = true;
|
||||
if (seq_lens[i] == 0) {
|
||||
topk_ids[i] = -1;
|
||||
}
|
||||
next_tokens[i] = topk_ids[i];
|
||||
for (int i = 0; i < bs; i++) {
|
||||
if (prefill_one_step_stop) {
|
||||
stop_flags[i] = true;
|
||||
if (seq_lens[i] == 0) {
|
||||
topk_ids[i] = -1;
|
||||
}
|
||||
next_tokens[i] = topk_ids[i];
|
||||
} else {
|
||||
if (stop_flags[i]) {
|
||||
if (seq_lens[i] == 0) {
|
||||
topk_ids[i] = -1;
|
||||
} else {
|
||||
if (stop_flags[i]) {
|
||||
if (seq_lens[i] == 0) {
|
||||
topk_ids[i] = -1;
|
||||
} else {
|
||||
topk_ids[i] = end_ids[0];
|
||||
next_tokens[i] = end_ids[0];
|
||||
}
|
||||
} else {
|
||||
next_tokens[i] = topk_ids[i];
|
||||
}
|
||||
if (!beam_search && is_in_end(topk_ids[i], end_ids, end_length)) {
|
||||
stop_flags[i] = true;
|
||||
}
|
||||
topk_ids[i] = end_ids[0];
|
||||
next_tokens[i] = end_ids[0];
|
||||
}
|
||||
} else {
|
||||
next_tokens[i] = topk_ids[i];
|
||||
}
|
||||
if (!beam_search && is_in_end(topk_ids[i], end_ids, end_length)) {
|
||||
stop_flags[i] = true;
|
||||
}
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static int xpu3_wrapper(Context *ctx, bool *stop_flags, T *topk_ids,
|
||||
T *next_tokens, const T *end_ids, const int *seq_lens,
|
||||
const int bs, const int end_length,
|
||||
static int xpu3_wrapper(Context *ctx,
|
||||
bool *stop_flags,
|
||||
T *topk_ids,
|
||||
T *next_tokens,
|
||||
const T *end_ids,
|
||||
const int *seq_lens,
|
||||
const int bs,
|
||||
const int end_length,
|
||||
const bool beam_search,
|
||||
const bool prefill_one_step_stop) {
|
||||
using XPU_TID = typename XPUIndexType<T>::type;
|
||||
auto set_stop_value_multi_ends =
|
||||
xpu3::plugin::set_stop_value_multi_ends<XPU_TID>;
|
||||
set_stop_value_multi_ends<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags, reinterpret_cast<XPU_TID *>(topk_ids),
|
||||
reinterpret_cast<XPU_TID *>(next_tokens),
|
||||
reinterpret_cast<const XPU_TID *>(end_ids), seq_lens, bs, end_length,
|
||||
beam_search, prefill_one_step_stop);
|
||||
return api::SUCCESS;
|
||||
using XPU_TID = typename XPUIndexType<T>::type;
|
||||
auto set_stop_value_multi_ends =
|
||||
xpu3::plugin::set_stop_value_multi_ends<XPU_TID>;
|
||||
set_stop_value_multi_ends<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
reinterpret_cast<XPU_TID *>(topk_ids),
|
||||
reinterpret_cast<XPU_TID *>(next_tokens),
|
||||
reinterpret_cast<const XPU_TID *>(end_ids),
|
||||
seq_lens,
|
||||
bs,
|
||||
end_length,
|
||||
beam_search,
|
||||
prefill_one_step_stop);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int set_stop_value_multi_ends(Context *ctx, bool *stop_flags, T *topk_ids,
|
||||
T *next_tokens, const T *end_ids,
|
||||
const int *seq_lens, const int bs,
|
||||
const int end_length, const bool beam_search) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "set_stop_value_multi_ends", T);
|
||||
WRAPPER_DUMP_PARAM5(ctx, stop_flags, topk_ids, next_tokens, end_ids,
|
||||
seq_lens);
|
||||
WRAPPER_DUMP_PARAM3(ctx, bs, end_length, beam_search);
|
||||
WRAPPER_DUMP(ctx);
|
||||
WRAPPER_CHECK_PTR(ctx, bool, bs, stop_flags);
|
||||
WRAPPER_CHECK_PTR(ctx, T, bs, topk_ids);
|
||||
WRAPPER_CHECK_PTR(ctx, T, end_length, end_ids);
|
||||
WRAPPER_CHECK_PTR(ctx, T, bs, seq_lens);
|
||||
WRAPPER_ASSERT_LE(ctx, end_length, 1024); // assume end_length <= 1024
|
||||
bool prefill_one_step_stop = false;
|
||||
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
|
||||
// std::cout << "Your PATH is: " << env_p << '\n';
|
||||
if (env_p[0] == '1') {
|
||||
prefill_one_step_stop = true;
|
||||
}
|
||||
int set_stop_value_multi_ends(Context *ctx,
|
||||
bool *stop_flags,
|
||||
T *topk_ids,
|
||||
T *next_tokens,
|
||||
const T *end_ids,
|
||||
const int *seq_lens,
|
||||
const int bs,
|
||||
const int end_length,
|
||||
const bool beam_search) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "set_stop_value_multi_ends", T);
|
||||
WRAPPER_DUMP_PARAM5(
|
||||
ctx, stop_flags, topk_ids, next_tokens, end_ids, seq_lens);
|
||||
WRAPPER_DUMP_PARAM3(ctx, bs, end_length, beam_search);
|
||||
WRAPPER_DUMP(ctx);
|
||||
WRAPPER_CHECK_PTR(ctx, bool, bs, stop_flags);
|
||||
WRAPPER_CHECK_PTR(ctx, T, bs, topk_ids);
|
||||
WRAPPER_CHECK_PTR(ctx, T, end_length, end_ids);
|
||||
WRAPPER_CHECK_PTR(ctx, T, bs, seq_lens);
|
||||
WRAPPER_ASSERT_LE(ctx, end_length, 1024); // assume end_length <= 1024
|
||||
bool prefill_one_step_stop = false;
|
||||
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
|
||||
// std::cout << "Your PATH is: " << env_p << '\n';
|
||||
if (env_p[0] == '1') {
|
||||
prefill_one_step_stop = true;
|
||||
}
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper<T>(ctx, stop_flags, topk_ids, next_tokens, end_ids,
|
||||
seq_lens, bs, end_length, beam_search,
|
||||
prefill_one_step_stop);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper<T>(ctx, stop_flags, topk_ids, next_tokens, end_ids,
|
||||
seq_lens, bs, end_length, beam_search,
|
||||
prefill_one_step_stop);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper<T>(ctx,
|
||||
stop_flags,
|
||||
topk_ids,
|
||||
next_tokens,
|
||||
end_ids,
|
||||
seq_lens,
|
||||
bs,
|
||||
end_length,
|
||||
beam_search,
|
||||
prefill_one_step_stop);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper<T>(ctx,
|
||||
stop_flags,
|
||||
topk_ids,
|
||||
next_tokens,
|
||||
end_ids,
|
||||
seq_lens,
|
||||
bs,
|
||||
end_length,
|
||||
beam_search,
|
||||
prefill_one_step_stop);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
template int set_stop_value_multi_ends<int64_t>(
|
||||
Context *ctx, bool *stop_flags, int64_t *topk_ids, int64_t *next_tokens,
|
||||
const int64_t *end_ids, const int *seq_lens, const int bs,
|
||||
const int end_length, const bool beam_search);
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
template int set_stop_value_multi_ends<int64_t>(Context *ctx,
|
||||
bool *stop_flags,
|
||||
int64_t *topk_ids,
|
||||
int64_t *next_tokens,
|
||||
const int64_t *end_ids,
|
||||
const int *seq_lens,
|
||||
const int bs,
|
||||
const int end_length,
|
||||
const bool beam_search);
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
|
||||
@@ -12,128 +12,173 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
__attribute__((global)) void set_value_by_flags_and_idx(
|
||||
const bool *stop_flags, int64_t *pre_ids_all, const int64_t *input_ids,
|
||||
const int *seq_lens_encoder, const int *seq_lens_decoder,
|
||||
const int64_t *step_idx, int bs, int length, int length_input_ids);
|
||||
const bool *stop_flags,
|
||||
int64_t *pre_ids_all,
|
||||
const int64_t *input_ids,
|
||||
const int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
const int64_t *step_idx,
|
||||
int bs,
|
||||
int length,
|
||||
int length_input_ids);
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
|
||||
static int cpu_wrapper(Context *ctx, const bool *stop_flags,
|
||||
int64_t *pre_ids_all, const int64_t *pre_ids,
|
||||
const int64_t *step_idx, const int bs,
|
||||
static int cpu_wrapper(Context *ctx,
|
||||
const bool *stop_flags,
|
||||
int64_t *pre_ids_all,
|
||||
const int64_t *pre_ids,
|
||||
const int64_t *step_idx,
|
||||
const int bs,
|
||||
const int length) {
|
||||
for (int i = 0; i < bs; i++) {
|
||||
int64_t *pre_ids_all_now = pre_ids_all + i * length;
|
||||
if (!stop_flags[i] && step_idx[i] >= 0) {
|
||||
pre_ids_all_now[step_idx[i]] = pre_ids[i];
|
||||
}
|
||||
for (int i = 0; i < bs; i++) {
|
||||
int64_t *pre_ids_all_now = pre_ids_all + i * length;
|
||||
if (!stop_flags[i] && step_idx[i] >= 0) {
|
||||
pre_ids_all_now[step_idx[i]] = pre_ids[i];
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
static int cpu_wrapper(Context *ctx, const bool *stop_flags,
|
||||
int64_t *pre_ids_all, const int64_t *input_ids,
|
||||
const int *seq_lens_encoder, const int *seq_lens_decoder,
|
||||
const int64_t *step_idx, int bs, int length,
|
||||
static int cpu_wrapper(Context *ctx,
|
||||
const bool *stop_flags,
|
||||
int64_t *pre_ids_all,
|
||||
const int64_t *input_ids,
|
||||
const int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
const int64_t *step_idx,
|
||||
int bs,
|
||||
int length,
|
||||
int length_input_ids) {
|
||||
for (int i = 0; i < bs; i++) {
|
||||
if (!stop_flags[i]) {
|
||||
int64_t *pre_ids_all_now = pre_ids_all + i * length;
|
||||
const int64_t *input_ids_now = input_ids + i * length_input_ids;
|
||||
const int seq_len_dec = seq_lens_decoder[i];
|
||||
const int seq_len_enc = seq_lens_encoder[i];
|
||||
if (seq_len_dec == 0 && seq_len_enc == 0)
|
||||
continue;
|
||||
if (step_idx[i] >= 0) {
|
||||
if (seq_len_enc > 0) {
|
||||
// encoder, get last token accord to seq_lens_encoder
|
||||
pre_ids_all_now[step_idx[i]] =
|
||||
input_ids_now[seq_len_enc - 1];
|
||||
} else {
|
||||
// decoder, get first token
|
||||
pre_ids_all_now[step_idx[i]] = input_ids_now[0];
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < bs; i++) {
|
||||
if (!stop_flags[i]) {
|
||||
int64_t *pre_ids_all_now = pre_ids_all + i * length;
|
||||
const int64_t *input_ids_now = input_ids + i * length_input_ids;
|
||||
const int seq_len_dec = seq_lens_decoder[i];
|
||||
const int seq_len_enc = seq_lens_encoder[i];
|
||||
if (seq_len_dec == 0 && seq_len_enc == 0) continue;
|
||||
if (step_idx[i] >= 0) {
|
||||
if (seq_len_enc > 0) {
|
||||
// encoder, get last token accord to seq_lens_encoder
|
||||
pre_ids_all_now[step_idx[i]] = input_ids_now[seq_len_enc - 1];
|
||||
} else {
|
||||
// decoder, get first token
|
||||
pre_ids_all_now[step_idx[i]] = input_ids_now[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
static int xpu3_wrapper(Context *ctx, const bool *stop_flags,
|
||||
int64_t *pre_ids_all, const int64_t *input_ids,
|
||||
static int xpu3_wrapper(Context *ctx,
|
||||
const bool *stop_flags,
|
||||
int64_t *pre_ids_all,
|
||||
const int64_t *input_ids,
|
||||
const int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder, const int64_t *step_idx,
|
||||
int bs, int length, int length_input_ids) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto set_value_by_flags_and_idx_kernel =
|
||||
xpu3::plugin::set_value_by_flags_and_idx;
|
||||
set_value_by_flags_and_idx_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags, reinterpret_cast<XPU_INT64 *>(pre_ids_all),
|
||||
reinterpret_cast<const XPU_INT64 *>(input_ids), seq_lens_encoder,
|
||||
seq_lens_decoder, reinterpret_cast<const XPU_INT64 *>(step_idx), bs,
|
||||
length, length_input_ids);
|
||||
return api::SUCCESS;
|
||||
const int *seq_lens_decoder,
|
||||
const int64_t *step_idx,
|
||||
int bs,
|
||||
int length,
|
||||
int length_input_ids) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto set_value_by_flags_and_idx_kernel =
|
||||
xpu3::plugin::set_value_by_flags_and_idx;
|
||||
set_value_by_flags_and_idx_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
reinterpret_cast<XPU_INT64 *>(pre_ids_all),
|
||||
reinterpret_cast<const XPU_INT64 *>(input_ids),
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
reinterpret_cast<const XPU_INT64 *>(step_idx),
|
||||
bs,
|
||||
length,
|
||||
length_input_ids);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
int set_value_by_flags_and_idx(Context *ctx, const bool *stop_flags,
|
||||
int64_t *pre_ids_all, const int64_t *input_ids,
|
||||
int set_value_by_flags_and_idx(Context *ctx,
|
||||
const bool *stop_flags,
|
||||
int64_t *pre_ids_all,
|
||||
const int64_t *input_ids,
|
||||
const int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
const int64_t *step_idx, int bs, int length,
|
||||
const int64_t *step_idx,
|
||||
int bs,
|
||||
int length,
|
||||
int length_input_ids) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "set_value_by_flags_and_idx", int64_t);
|
||||
WRAPPER_DUMP_PARAM6(ctx, stop_flags, pre_ids_all, input_ids,
|
||||
seq_lens_encoder, seq_lens_decoder, step_idx);
|
||||
WRAPPER_DUMP_PARAM3(ctx, bs, length, length_input_ids);
|
||||
WRAPPER_DUMP(ctx);
|
||||
int64_t stop_flags_len = -1;
|
||||
int64_t pre_ids_all_len = -1;
|
||||
int64_t input_ids_len = -1;
|
||||
int64_t seq_lens_encoder_len = -1;
|
||||
int64_t seq_lens_decoder_len = -1;
|
||||
int64_t step_idx_len = -1;
|
||||
WRAPPER_CHECK_SHAPE(ctx, &stop_flags_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &pre_ids_all_len, {bs, length});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &input_ids_len, {bs, length_input_ids});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &seq_lens_encoder_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &seq_lens_decoder_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &step_idx_len, {bs});
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, stop_flags_len, stop_flags);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_all_len, pre_ids_all);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, input_ids_len, input_ids);
|
||||
WRAPPER_CHECK_PTR(ctx, int, seq_lens_encoder_len, seq_lens_encoder);
|
||||
WRAPPER_CHECK_PTR(ctx, int, seq_lens_decoder_len, seq_lens_decoder);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, step_idx_len, step_idx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper(ctx, stop_flags, pre_ids_all, input_ids,
|
||||
seq_lens_encoder, seq_lens_decoder, step_idx, bs,
|
||||
length, length_input_ids);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx, stop_flags, pre_ids_all, input_ids,
|
||||
seq_lens_encoder, seq_lens_decoder, step_idx, bs,
|
||||
length, length_input_ids);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "set_value_by_flags_and_idx", int64_t);
|
||||
WRAPPER_DUMP_PARAM6(ctx,
|
||||
stop_flags,
|
||||
pre_ids_all,
|
||||
input_ids,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx);
|
||||
WRAPPER_DUMP_PARAM3(ctx, bs, length, length_input_ids);
|
||||
WRAPPER_DUMP(ctx);
|
||||
int64_t stop_flags_len = -1;
|
||||
int64_t pre_ids_all_len = -1;
|
||||
int64_t input_ids_len = -1;
|
||||
int64_t seq_lens_encoder_len = -1;
|
||||
int64_t seq_lens_decoder_len = -1;
|
||||
int64_t step_idx_len = -1;
|
||||
WRAPPER_CHECK_SHAPE(ctx, &stop_flags_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &pre_ids_all_len, {bs, length});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &input_ids_len, {bs, length_input_ids});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &seq_lens_encoder_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &seq_lens_decoder_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &step_idx_len, {bs});
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, stop_flags_len, stop_flags);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_all_len, pre_ids_all);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, input_ids_len, input_ids);
|
||||
WRAPPER_CHECK_PTR(ctx, int, seq_lens_encoder_len, seq_lens_encoder);
|
||||
WRAPPER_CHECK_PTR(ctx, int, seq_lens_decoder_len, seq_lens_decoder);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, step_idx_len, step_idx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper(ctx,
|
||||
stop_flags,
|
||||
pre_ids_all,
|
||||
input_ids,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
bs,
|
||||
length,
|
||||
length_input_ids);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx,
|
||||
stop_flags,
|
||||
pre_ids_all,
|
||||
input_ids,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
bs,
|
||||
length,
|
||||
length_input_ids);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
|
||||
@@ -12,263 +12,367 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
template <typename T>
|
||||
__attribute__((global)) void
|
||||
min_length_logits_process(T *logits, const int64_t *cur_len,
|
||||
const int64_t *min_len, const int64_t *eos_token_id,
|
||||
const int64_t bs, const int64_t length,
|
||||
const int64_t length_id, const int64_t end_length);
|
||||
__attribute__((global)) void
|
||||
update_repeat_times(const int64_t *pre_ids, const int64_t *cur_len,
|
||||
int *repeat_times, const int64_t bs, const int64_t length,
|
||||
const int64_t length_id);
|
||||
__attribute__((global)) void min_length_logits_process(
|
||||
T *logits,
|
||||
const int64_t *cur_len,
|
||||
const int64_t *min_len,
|
||||
const int64_t *eos_token_id,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t length_id,
|
||||
const int64_t end_length);
|
||||
__attribute__((global)) void update_repeat_times(const int64_t *pre_ids,
|
||||
const int64_t *cur_len,
|
||||
int *repeat_times,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t length_id);
|
||||
template <typename T>
|
||||
__attribute__((global)) void
|
||||
update_value_by_repeat_times(const int *repeat_times, const T *penalty_scores,
|
||||
const T *frequency_score, const T *presence_score,
|
||||
const float *temperatures, T *logits,
|
||||
const int64_t bs, const int64_t length);
|
||||
__attribute__((global)) void update_value_by_repeat_times(
|
||||
const int *repeat_times,
|
||||
const T *penalty_scores,
|
||||
const T *frequency_score,
|
||||
const T *presence_score,
|
||||
const float *temperatures,
|
||||
T *logits,
|
||||
const int64_t bs,
|
||||
const int64_t length);
|
||||
template <typename T>
|
||||
__attribute__((global)) void update_value_by_repeat_times_simd(
|
||||
const int *repeat_times, const T *penalty_scores, const T *frequency_score,
|
||||
const T *presence_score, const float *temperatures, T *logits,
|
||||
const int64_t bs, const int64_t length);
|
||||
const int *repeat_times,
|
||||
const T *penalty_scores,
|
||||
const T *frequency_score,
|
||||
const T *presence_score,
|
||||
const float *temperatures,
|
||||
T *logits,
|
||||
const int64_t bs,
|
||||
const int64_t length);
|
||||
template <typename T>
|
||||
__attribute__((global)) void
|
||||
ban_bad_words(T *logits, const int64_t *bad_words_list, const int64_t bs,
|
||||
const int64_t length, const int64_t bad_words_length);
|
||||
__attribute__((global)) void ban_bad_words(T *logits,
|
||||
const int64_t *bad_words_list,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t bad_words_length);
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
|
||||
void update_repeat_times_cpu(const int64_t *pre_ids, const int64_t *cur_len,
|
||||
int *repeat_times, const int64_t bs,
|
||||
const int64_t length, const int64_t length_id) {
|
||||
for (int64_t i = 0; i < bs; i++) {
|
||||
if (cur_len[i] >= 0) {
|
||||
for (int64_t j = 0; j < length_id; j++) {
|
||||
int64_t id = pre_ids[i * length_id + j];
|
||||
if (id < 0 || id >= length)
|
||||
continue;
|
||||
repeat_times[i * length + id] += 1;
|
||||
}
|
||||
}
|
||||
void update_repeat_times_cpu(const int64_t *pre_ids,
|
||||
const int64_t *cur_len,
|
||||
int *repeat_times,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t length_id) {
|
||||
for (int64_t i = 0; i < bs; i++) {
|
||||
if (cur_len[i] >= 0) {
|
||||
for (int64_t j = 0; j < length_id; j++) {
|
||||
int64_t id = pre_ids[i * length_id + j];
|
||||
if (id < 0 || id >= length) continue;
|
||||
repeat_times[i * length + id] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ban_bad_words_cpu(float *logits, const int64_t *bad_words_list,
|
||||
const int64_t bs, const int64_t length,
|
||||
void ban_bad_words_cpu(float *logits,
|
||||
const int64_t *bad_words_list,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t bad_words_length) {
|
||||
for (int64_t i = 0; i < bs; i++) {
|
||||
float *logits_now = logits + i * length;
|
||||
for (int64_t j = 0; j < bad_words_length; j++) {
|
||||
int64_t bad_words_token_id = bad_words_list[j];
|
||||
if (bad_words_token_id >= length || bad_words_token_id < 0)
|
||||
continue;
|
||||
logits_now[bad_words_token_id] = -1e10;
|
||||
}
|
||||
for (int64_t i = 0; i < bs; i++) {
|
||||
float *logits_now = logits + i * length;
|
||||
for (int64_t j = 0; j < bad_words_length; j++) {
|
||||
int64_t bad_words_token_id = bad_words_list[j];
|
||||
if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
|
||||
logits_now[bad_words_token_id] = -1e10;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static int cpu_wrapper(Context *ctx, const int64_t *pre_ids, T *logits,
|
||||
const T *penalty_scores, const T *frequency_scores,
|
||||
const T *presence_scores, const float *temperatures,
|
||||
const int64_t *cur_len, const int64_t *min_len,
|
||||
const int64_t *eos_token_id, const int64_t *bad_words,
|
||||
const int64_t bs, const int64_t length,
|
||||
const int64_t length_id, const int64_t end_length,
|
||||
static int cpu_wrapper(Context *ctx,
|
||||
const int64_t *pre_ids,
|
||||
T *logits,
|
||||
const T *penalty_scores,
|
||||
const T *frequency_scores,
|
||||
const T *presence_scores,
|
||||
const float *temperatures,
|
||||
const int64_t *cur_len,
|
||||
const int64_t *min_len,
|
||||
const int64_t *eos_token_id,
|
||||
const int64_t *bad_words,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t length_id,
|
||||
const int64_t end_length,
|
||||
const int64_t length_bad_words) {
|
||||
std::vector<float> logitsfp32(bs * length);
|
||||
std::vector<float> penalty_scoresfp32(bs);
|
||||
std::vector<float> frequency_scoresfp32(bs);
|
||||
std::vector<float> presence_scoresfp32(bs);
|
||||
std::vector<int> repeat_times_buffer(bs * length, 0);
|
||||
int ret = api::cast<T, float>(ctx, logits, logitsfp32.data(), bs * length);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
ret =
|
||||
api::cast<T, float>(ctx, penalty_scores, penalty_scoresfp32.data(), bs);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
ret = api::cast<T, float>(ctx, frequency_scores,
|
||||
frequency_scoresfp32.data(), bs);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
ret = api::cast<T, float>(ctx, presence_scores, presence_scoresfp32.data(),
|
||||
bs);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
for (int64_t i = 0; i < bs; i++) {
|
||||
if (cur_len[i] >= 0 && cur_len[i] < min_len[i]) {
|
||||
for (int64_t j = 0; j < end_length; j++) {
|
||||
logitsfp32[i * length + eos_token_id[j]] = -1e4;
|
||||
}
|
||||
}
|
||||
std::vector<float> logitsfp32(bs * length);
|
||||
std::vector<float> penalty_scoresfp32(bs);
|
||||
std::vector<float> frequency_scoresfp32(bs);
|
||||
std::vector<float> presence_scoresfp32(bs);
|
||||
std::vector<int> repeat_times_buffer(bs * length, 0);
|
||||
int ret = api::cast<T, float>(ctx, logits, logitsfp32.data(), bs * length);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
ret = api::cast<T, float>(ctx, penalty_scores, penalty_scoresfp32.data(), bs);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
ret = api::cast<T, float>(
|
||||
ctx, frequency_scores, frequency_scoresfp32.data(), bs);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
ret =
|
||||
api::cast<T, float>(ctx, presence_scores, presence_scoresfp32.data(), bs);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
for (int64_t i = 0; i < bs; i++) {
|
||||
if (cur_len[i] >= 0 && cur_len[i] < min_len[i]) {
|
||||
for (int64_t j = 0; j < end_length; j++) {
|
||||
logitsfp32[i * length + eos_token_id[j]] = -1e4;
|
||||
}
|
||||
}
|
||||
int *repeat_times = &(repeat_times_buffer[0]);
|
||||
update_repeat_times_cpu(pre_ids, cur_len, repeat_times, bs, length,
|
||||
length_id);
|
||||
for (int64_t i = 0; i < bs; i++) {
|
||||
float alpha = penalty_scoresfp32[i];
|
||||
float beta = frequency_scoresfp32[i];
|
||||
float gamma = presence_scoresfp32[i];
|
||||
float temperature = temperatures[i];
|
||||
for (int64_t j = 0; j < length; j++) {
|
||||
int times = repeat_times[i * length + j];
|
||||
float logit_now = logitsfp32[i * length + j];
|
||||
if (times != 0) {
|
||||
logit_now =
|
||||
logit_now < 0 ? logit_now * alpha : logit_now / alpha;
|
||||
logit_now = logit_now - times * beta - gamma;
|
||||
}
|
||||
logitsfp32[i * length + j] = logit_now / temperature;
|
||||
}
|
||||
}
|
||||
int *repeat_times = &(repeat_times_buffer[0]);
|
||||
update_repeat_times_cpu(
|
||||
pre_ids, cur_len, repeat_times, bs, length, length_id);
|
||||
for (int64_t i = 0; i < bs; i++) {
|
||||
float alpha = penalty_scoresfp32[i];
|
||||
float beta = frequency_scoresfp32[i];
|
||||
float gamma = presence_scoresfp32[i];
|
||||
float temperature = temperatures[i];
|
||||
for (int64_t j = 0; j < length; j++) {
|
||||
int times = repeat_times[i * length + j];
|
||||
float logit_now = logitsfp32[i * length + j];
|
||||
if (times != 0) {
|
||||
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
|
||||
logit_now = logit_now - times * beta - gamma;
|
||||
}
|
||||
logitsfp32[i * length + j] = logit_now / temperature;
|
||||
}
|
||||
if (bad_words && length_bad_words > 0) {
|
||||
ban_bad_words_cpu(logitsfp32.data(), bad_words, bs, length,
|
||||
length_bad_words);
|
||||
}
|
||||
ret = api::cast<float, T>(ctx, logitsfp32.data(), logits, bs * length);
|
||||
return ret;
|
||||
}
|
||||
if (bad_words && length_bad_words > 0) {
|
||||
ban_bad_words_cpu(
|
||||
logitsfp32.data(), bad_words, bs, length, length_bad_words);
|
||||
}
|
||||
ret = api::cast<float, T>(ctx, logitsfp32.data(), logits, bs * length);
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static int xpu3_wrapper(Context *ctx, const int64_t *pre_ids, T *logits,
|
||||
const T *penalty_scores, const T *frequency_scores,
|
||||
const T *presence_scores, const float *temperatures,
|
||||
const int64_t *cur_len, const int64_t *min_len,
|
||||
const int64_t *eos_token_id, const int64_t *bad_words,
|
||||
const int64_t bs, const int64_t length,
|
||||
const int64_t length_id, const int64_t end_length,
|
||||
static int xpu3_wrapper(Context *ctx,
|
||||
const int64_t *pre_ids,
|
||||
T *logits,
|
||||
const T *penalty_scores,
|
||||
const T *frequency_scores,
|
||||
const T *presence_scores,
|
||||
const float *temperatures,
|
||||
const int64_t *cur_len,
|
||||
const int64_t *min_len,
|
||||
const int64_t *eos_token_id,
|
||||
const int64_t *bad_words,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t length_id,
|
||||
const int64_t end_length,
|
||||
const int64_t length_bad_words) {
|
||||
api::ctx_guard RAII_GUARD(ctx);
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto min_length_logits_process_kernel =
|
||||
xpu3::plugin::min_length_logits_process<T>;
|
||||
auto update_repeat_times_kernel = xpu3::plugin::update_repeat_times;
|
||||
auto update_value_by_repeat_times_kernel =
|
||||
xpu3::plugin::update_value_by_repeat_times<T>;
|
||||
if (length % 16 == 0) {
|
||||
update_value_by_repeat_times_kernel =
|
||||
xpu3::plugin::update_value_by_repeat_times_simd<T>;
|
||||
}
|
||||
auto ban_bad_words_kernel = xpu3::plugin::ban_bad_words<T>;
|
||||
api::ctx_guard RAII_GUARD(ctx);
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto min_length_logits_process_kernel =
|
||||
xpu3::plugin::min_length_logits_process<T>;
|
||||
auto update_repeat_times_kernel = xpu3::plugin::update_repeat_times;
|
||||
auto update_value_by_repeat_times_kernel =
|
||||
xpu3::plugin::update_value_by_repeat_times<T>;
|
||||
if (length % 16 == 0) {
|
||||
update_value_by_repeat_times_kernel =
|
||||
xpu3::plugin::update_value_by_repeat_times_simd<T>;
|
||||
}
|
||||
auto ban_bad_words_kernel = xpu3::plugin::ban_bad_words<T>;
|
||||
|
||||
int *repeat_times = RAII_GUARD.alloc_l3_or_gm<int>(bs * length);
|
||||
WRAPPER_ASSERT_WORKSPACE(ctx, repeat_times);
|
||||
int ret = api::constant<int>(ctx, repeat_times, bs * length, 0);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
int *repeat_times = RAII_GUARD.alloc_l3_or_gm<int>(bs * length);
|
||||
WRAPPER_ASSERT_WORKSPACE(ctx, repeat_times);
|
||||
int ret = api::constant<int>(ctx, repeat_times, bs * length, 0);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
|
||||
update_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<const XPU_INT64 *>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(cur_len), repeat_times, bs, length,
|
||||
length_id);
|
||||
min_length_logits_process_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
logits, reinterpret_cast<const XPU_INT64 *>(cur_len),
|
||||
reinterpret_cast<const XPU_INT64 *>(min_len),
|
||||
reinterpret_cast<const XPU_INT64 *>(eos_token_id), bs, length,
|
||||
length_id, end_length);
|
||||
update_value_by_repeat_times_kernel<<<ctx->ncluster(), 64,
|
||||
ctx->xpu_stream>>>(
|
||||
repeat_times, penalty_scores, frequency_scores, presence_scores,
|
||||
temperatures, logits, bs, length);
|
||||
update_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<const XPU_INT64 *>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(cur_len),
|
||||
repeat_times,
|
||||
bs,
|
||||
length,
|
||||
length_id);
|
||||
min_length_logits_process_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
logits,
|
||||
reinterpret_cast<const XPU_INT64 *>(cur_len),
|
||||
reinterpret_cast<const XPU_INT64 *>(min_len),
|
||||
reinterpret_cast<const XPU_INT64 *>(eos_token_id),
|
||||
bs,
|
||||
length,
|
||||
length_id,
|
||||
end_length);
|
||||
update_value_by_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
repeat_times,
|
||||
penalty_scores,
|
||||
frequency_scores,
|
||||
presence_scores,
|
||||
temperatures,
|
||||
logits,
|
||||
bs,
|
||||
length);
|
||||
|
||||
if (bad_words && length_bad_words > 0) {
|
||||
ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
logits, reinterpret_cast<const XPU_INT64 *>(bad_words), bs, length,
|
||||
length_bad_words);
|
||||
}
|
||||
return api::SUCCESS;
|
||||
if (bad_words && length_bad_words > 0) {
|
||||
ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
logits,
|
||||
reinterpret_cast<const XPU_INT64 *>(bad_words),
|
||||
bs,
|
||||
length,
|
||||
length_bad_words);
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int token_penalty_multi_scores(
|
||||
Context *ctx, const int64_t *pre_ids, T *logits, const T *penalty_scores,
|
||||
const T *frequency_scores, const T *presence_scores,
|
||||
const float *temperatures, const int64_t *cur_len, const int64_t *min_len,
|
||||
const int64_t *eos_token_id, const int64_t *bad_words, const int64_t bs,
|
||||
const int64_t length, const int64_t length_id, const int64_t end_length,
|
||||
const int64_t length_bad_words) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "token_penalty_multi_scores", T);
|
||||
WRAPPER_DUMP_PARAM6(ctx, pre_ids, logits, penalty_scores, frequency_scores,
|
||||
presence_scores, temperatures);
|
||||
WRAPPER_DUMP_PARAM3(ctx, cur_len, min_len, eos_token_id);
|
||||
WRAPPER_DUMP_PARAM4(ctx, bs, length, length_id, end_length);
|
||||
WRAPPER_DUMP(ctx);
|
||||
// TODO(mayang02) shape check
|
||||
int64_t pre_ids_len = -1;
|
||||
int64_t logits_len = -1;
|
||||
int64_t penalty_scores_len = -1;
|
||||
int64_t frequency_scores_len = -1;
|
||||
int64_t presence_scores_len = -1;
|
||||
int64_t temperatures_len = -1;
|
||||
int64_t cur_len_len = -1;
|
||||
int64_t min_len_len = -1;
|
||||
int64_t eos_token_id_len = -1;
|
||||
int64_t bad_words_len = -1;
|
||||
WRAPPER_CHECK_SHAPE(ctx, &pre_ids_len, {bs, length_id});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &logits_len, {bs, length});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &penalty_scores_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &frequency_scores_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &presence_scores_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &temperatures_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &cur_len_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &min_len_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &eos_token_id_len, {end_length});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &bad_words_len, {length_bad_words});
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_len, pre_ids);
|
||||
WRAPPER_CHECK_PTR(ctx, T, logits_len, logits);
|
||||
WRAPPER_CHECK_PTR(ctx, T, penalty_scores_len, penalty_scores);
|
||||
WRAPPER_CHECK_PTR(ctx, T, frequency_scores_len, frequency_scores);
|
||||
WRAPPER_CHECK_PTR(ctx, T, presence_scores_len, presence_scores);
|
||||
WRAPPER_CHECK_PTR(ctx, float, temperatures_len, temperatures);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, cur_len_len, cur_len);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, min_len_len, min_len);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, eos_token_id_len, eos_token_id);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, bad_words_len, bad_words);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper<T>(ctx, pre_ids, logits, penalty_scores,
|
||||
frequency_scores, presence_scores, temperatures,
|
||||
cur_len, min_len, eos_token_id, bad_words, bs,
|
||||
length, length_id, end_length, length_bad_words);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper<T>(ctx, pre_ids, logits, penalty_scores,
|
||||
frequency_scores, presence_scores, temperatures,
|
||||
cur_len, min_len, eos_token_id, bad_words, bs,
|
||||
length, length_id, end_length, length_bad_words);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
int token_penalty_multi_scores(Context *ctx,
|
||||
const int64_t *pre_ids,
|
||||
T *logits,
|
||||
const T *penalty_scores,
|
||||
const T *frequency_scores,
|
||||
const T *presence_scores,
|
||||
const float *temperatures,
|
||||
const int64_t *cur_len,
|
||||
const int64_t *min_len,
|
||||
const int64_t *eos_token_id,
|
||||
const int64_t *bad_words,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t length_id,
|
||||
const int64_t end_length,
|
||||
const int64_t length_bad_words) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "token_penalty_multi_scores", T);
|
||||
WRAPPER_DUMP_PARAM6(ctx,
|
||||
pre_ids,
|
||||
logits,
|
||||
penalty_scores,
|
||||
frequency_scores,
|
||||
presence_scores,
|
||||
temperatures);
|
||||
WRAPPER_DUMP_PARAM3(ctx, cur_len, min_len, eos_token_id);
|
||||
WRAPPER_DUMP_PARAM4(ctx, bs, length, length_id, end_length);
|
||||
WRAPPER_DUMP(ctx);
|
||||
// TODO(mayang02) shape check
|
||||
int64_t pre_ids_len = -1;
|
||||
int64_t logits_len = -1;
|
||||
int64_t penalty_scores_len = -1;
|
||||
int64_t frequency_scores_len = -1;
|
||||
int64_t presence_scores_len = -1;
|
||||
int64_t temperatures_len = -1;
|
||||
int64_t cur_len_len = -1;
|
||||
int64_t min_len_len = -1;
|
||||
int64_t eos_token_id_len = -1;
|
||||
int64_t bad_words_len = -1;
|
||||
WRAPPER_CHECK_SHAPE(ctx, &pre_ids_len, {bs, length_id});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &logits_len, {bs, length});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &penalty_scores_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &frequency_scores_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &presence_scores_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &temperatures_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &cur_len_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &min_len_len, {bs});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &eos_token_id_len, {end_length});
|
||||
WRAPPER_CHECK_SHAPE(ctx, &bad_words_len, {length_bad_words});
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_len, pre_ids);
|
||||
WRAPPER_CHECK_PTR(ctx, T, logits_len, logits);
|
||||
WRAPPER_CHECK_PTR(ctx, T, penalty_scores_len, penalty_scores);
|
||||
WRAPPER_CHECK_PTR(ctx, T, frequency_scores_len, frequency_scores);
|
||||
WRAPPER_CHECK_PTR(ctx, T, presence_scores_len, presence_scores);
|
||||
WRAPPER_CHECK_PTR(ctx, float, temperatures_len, temperatures);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, cur_len_len, cur_len);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, min_len_len, min_len);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, eos_token_id_len, eos_token_id);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, bad_words_len, bad_words);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper<T>(ctx,
|
||||
pre_ids,
|
||||
logits,
|
||||
penalty_scores,
|
||||
frequency_scores,
|
||||
presence_scores,
|
||||
temperatures,
|
||||
cur_len,
|
||||
min_len,
|
||||
eos_token_id,
|
||||
bad_words,
|
||||
bs,
|
||||
length,
|
||||
length_id,
|
||||
end_length,
|
||||
length_bad_words);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper<T>(ctx,
|
||||
pre_ids,
|
||||
logits,
|
||||
penalty_scores,
|
||||
frequency_scores,
|
||||
presence_scores,
|
||||
temperatures,
|
||||
cur_len,
|
||||
min_len,
|
||||
eos_token_id,
|
||||
bad_words,
|
||||
bs,
|
||||
length,
|
||||
length_id,
|
||||
end_length,
|
||||
length_bad_words);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
template int token_penalty_multi_scores<float>(
|
||||
Context *ctx, const int64_t *pre_ids, float *logits,
|
||||
const float *penalty_scores, const float *frequency_scores,
|
||||
const float *presence_scores, const float *temperatures,
|
||||
const int64_t *cur_len, const int64_t *min_len, const int64_t *eos_token_id,
|
||||
const int64_t *bad_words, const int64_t bs, const int64_t length,
|
||||
const int64_t length_id, const int64_t end_length,
|
||||
const int64_t length_bad_words);
|
||||
template int token_penalty_multi_scores<float>(Context *ctx,
|
||||
const int64_t *pre_ids,
|
||||
float *logits,
|
||||
const float *penalty_scores,
|
||||
const float *frequency_scores,
|
||||
const float *presence_scores,
|
||||
const float *temperatures,
|
||||
const int64_t *cur_len,
|
||||
const int64_t *min_len,
|
||||
const int64_t *eos_token_id,
|
||||
const int64_t *bad_words,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t length_id,
|
||||
const int64_t end_length,
|
||||
const int64_t length_bad_words);
|
||||
template int token_penalty_multi_scores<float16>(
|
||||
Context *ctx, const int64_t *pre_ids, float16 *logits,
|
||||
const float16 *penalty_scores, const float16 *frequency_scores,
|
||||
const float16 *presence_scores, const float *temperatures,
|
||||
const int64_t *cur_len, const int64_t *min_len, const int64_t *eos_token_id,
|
||||
const int64_t *bad_words, const int64_t bs, const int64_t length,
|
||||
const int64_t length_id, const int64_t end_length,
|
||||
Context *ctx,
|
||||
const int64_t *pre_ids,
|
||||
float16 *logits,
|
||||
const float16 *penalty_scores,
|
||||
const float16 *frequency_scores,
|
||||
const float16 *presence_scores,
|
||||
const float *temperatures,
|
||||
const int64_t *cur_len,
|
||||
const int64_t *min_len,
|
||||
const int64_t *eos_token_id,
|
||||
const int64_t *bad_words,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t length_id,
|
||||
const int64_t end_length,
|
||||
const int64_t length_bad_words);
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
|
||||
@@ -20,21 +20,18 @@
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
template <typename TX, typename TSCALE, typename TY>
|
||||
__attribute__((global)) void
|
||||
quant2d_per_channel_cluster(const TX *x, const TSCALE *scale, TY *y, int64_t m,
|
||||
int64_t n);
|
||||
__attribute__((global)) void quant2d_per_channel_cluster(
|
||||
const TX *x, const TSCALE *scale, TY *y, int64_t m, int64_t n);
|
||||
|
||||
template <typename TX, typename TSCALE, typename TY, int MAX_N>
|
||||
__attribute__((global)) void
|
||||
quant2d_per_channel_cached(const TX *input, TY *output, TSCALE *scale,
|
||||
int64_t m, int64_t n);
|
||||
__attribute__((global)) void quant2d_per_channel_cached(
|
||||
const TX *input, TY *output, TSCALE *scale, int64_t m, int64_t n);
|
||||
|
||||
template <typename TX, typename TSCALE, typename TY>
|
||||
__attribute__((global)) void quant2d_per_channel_bign(const TX *input,
|
||||
TY *output, TSCALE *scale,
|
||||
int64_t m, int64_t n);
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
__attribute__((global)) void quant2d_per_channel_bign(
|
||||
const TX *input, TY *output, TSCALE *scale, int64_t m, int64_t n);
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
namespace api = baidu::xpu::api;
|
||||
|
||||
@@ -43,11 +40,17 @@ namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
|
||||
template <typename TX, typename TSCALE, typename TY,
|
||||
template <typename TX,
|
||||
typename TSCALE,
|
||||
typename TY,
|
||||
typename std::enable_if<std::is_same<TY, int8_t>::value, TY>::type
|
||||
*ptr = nullptr>
|
||||
int cpu_wrapper_input_scale(api::Context *ctx, const TX *x, const TSCALE *scale,
|
||||
TY *y, int64_t m, int64_t n) {
|
||||
int cpu_wrapper_input_scale(api::Context *ctx,
|
||||
const TX *x,
|
||||
const TSCALE *scale,
|
||||
TY *y,
|
||||
int64_t m,
|
||||
int64_t n) {
|
||||
float absmax = 1e-30f;
|
||||
for (int i = 0; i < m; ++i) {
|
||||
for (int j = 0; j < n; ++j) {
|
||||
@@ -78,11 +81,13 @@ static float16 quant_int4(float x, float scale) {
|
||||
return (float16)std::min(static_cast<float>(r), 7.f);
|
||||
}
|
||||
|
||||
template <typename TX, typename TSCALE, typename TY,
|
||||
template <typename TX,
|
||||
typename TSCALE,
|
||||
typename TY,
|
||||
typename std::enable_if<std::is_same<TY, int4_t>::value, TY>::type
|
||||
*ptr = nullptr>
|
||||
int cpu_wrapper_input_scale(api::Context *ctx, const TX *x, const TSCALE *scale,
|
||||
TY *y, int m, int n) {
|
||||
int cpu_wrapper_input_scale(
|
||||
api::Context *ctx, const TX *x, const TSCALE *scale, TY *y, int m, int n) {
|
||||
int8_t *y_ptr = reinterpret_cast<int8_t *>(y);
|
||||
float t1, t2;
|
||||
for (int i = 0; i < m; ++i) {
|
||||
@@ -109,11 +114,17 @@ int cpu_wrapper_input_scale(api::Context *ctx, const TX *x, const TSCALE *scale,
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
template <typename TX, typename TSCALE, typename TY,
|
||||
template <typename TX,
|
||||
typename TSCALE,
|
||||
typename TY,
|
||||
typename std::enable_if<!std::is_same<TY, int4_t>::value, TY>::type
|
||||
*ptr = nullptr>
|
||||
int cpu_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale,
|
||||
TY *y, int64_t m, int64_t n) {
|
||||
int cpu_wrapper_output_scale(api::Context *ctx,
|
||||
const TX *x,
|
||||
TSCALE *scale,
|
||||
TY *y,
|
||||
int64_t m,
|
||||
int64_t n) {
|
||||
int64_t i, j;
|
||||
for (j = 0; j < n; ++j) {
|
||||
float absmax = 1e-30f;
|
||||
@@ -129,11 +140,13 @@ int cpu_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale,
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
template <typename TX, typename TSCALE, typename TY,
|
||||
template <typename TX,
|
||||
typename TSCALE,
|
||||
typename TY,
|
||||
typename std::enable_if<std::is_same<TY, int4_t>::value, TY>::type
|
||||
*ptr = nullptr>
|
||||
int cpu_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale,
|
||||
TY *y, int m, int n) {
|
||||
int cpu_wrapper_output_scale(
|
||||
api::Context *ctx, const TX *x, TSCALE *scale, TY *y, int m, int n) {
|
||||
int8_t *y_ptr = reinterpret_cast<int8_t *>(y);
|
||||
float t1, t2, absmax_1, absmax_2, act_scale_1, act_scale_2;
|
||||
for (int j = 0; j < n; j += 2) {
|
||||
@@ -173,18 +186,28 @@ int cpu_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale,
|
||||
}
|
||||
|
||||
template <typename TX, typename TSCALE, typename TY>
|
||||
int xpu3_wrapper_input_scale(api::Context *ctx, const TX *x,
|
||||
const TSCALE *scale, TY *y, int64_t m, int64_t n) {
|
||||
int xpu3_wrapper_input_scale(api::Context *ctx,
|
||||
const TX *x,
|
||||
const TSCALE *scale,
|
||||
TY *y,
|
||||
int64_t m,
|
||||
int64_t n) {
|
||||
auto func = xpu3::plugin::quant2d_per_channel_cluster<TX, TSCALE, TY>;
|
||||
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, scale, y, m, n);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
template <typename TX, typename TSCALE, typename TY,
|
||||
template <typename TX,
|
||||
typename TSCALE,
|
||||
typename TY,
|
||||
typename std::enable_if<!std::is_same<TY, int4_t>::value, TY>::type
|
||||
* = nullptr>
|
||||
int xpu3_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale,
|
||||
TY *y, int64_t m, int64_t n) {
|
||||
int xpu3_wrapper_output_scale(api::Context *ctx,
|
||||
const TX *x,
|
||||
TSCALE *scale,
|
||||
TY *y,
|
||||
int64_t m,
|
||||
int64_t n) {
|
||||
int64_t channel_size = m * sizeof(TX);
|
||||
int64_t cluster_n = (n + ctx->ncluster() - 1) / ctx->ncluster();
|
||||
auto func = xpu3::plugin::quant2d_per_channel_bign<TX, TSCALE, TY>;
|
||||
@@ -210,19 +233,30 @@ int xpu3_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale,
|
||||
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, scale, m, n);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
template <typename TX, typename TSCALE, typename TY,
|
||||
template <typename TX,
|
||||
typename TSCALE,
|
||||
typename TY,
|
||||
typename std::enable_if<std::is_same<TY, int4_t>::value, TY>::type * =
|
||||
nullptr>
|
||||
int xpu3_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale,
|
||||
TY *y, int64_t m, int64_t n) {
|
||||
int xpu3_wrapper_output_scale(api::Context *ctx,
|
||||
const TX *x,
|
||||
TSCALE *scale,
|
||||
TY *y,
|
||||
int64_t m,
|
||||
int64_t n) {
|
||||
auto func = xpu3::plugin::quant2d_per_channel_bign<TX, TSCALE, TY>;
|
||||
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, scale, m, n);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
template <typename TX, typename TSCALE, typename TY>
|
||||
int quant2d_per_channel(api::Context *ctx, const TX *x, const TSCALE *scale_in,
|
||||
TY *y, TSCALE *scale_out, int64_t m, int64_t n) {
|
||||
int quant2d_per_channel(api::Context *ctx,
|
||||
const TX *x,
|
||||
const TSCALE *scale_in,
|
||||
TY *y,
|
||||
TSCALE *scale_out,
|
||||
int64_t m,
|
||||
int64_t n) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T3(ctx, "quant2d_per_channel", TX, TSCALE, TY);
|
||||
WRAPPER_DUMP_PARAM4(ctx, x, scale_in, y, scale_out);
|
||||
@@ -251,20 +285,24 @@ int quant2d_per_channel(api::Context *ctx, const TX *x, const TSCALE *scale_in,
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU3) {
|
||||
if (scale_in != nullptr) {
|
||||
return xpu3_wrapper_input_scale<TX, TSCALE, TY>(ctx, x, scale_in, y, m,
|
||||
n);
|
||||
return xpu3_wrapper_input_scale<TX, TSCALE, TY>(
|
||||
ctx, x, scale_in, y, m, n);
|
||||
}
|
||||
return xpu3_wrapper_output_scale<TX, TSCALE, TY>(ctx, x, scale_out, y, m,
|
||||
n);
|
||||
return xpu3_wrapper_output_scale<TX, TSCALE, TY>(
|
||||
ctx, x, scale_out, y, m, n);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
return 0;
|
||||
}
|
||||
|
||||
#define INSTANTIATION_QUANT2D_PER_CHANNEL(TX, TSCALE, TY) \
|
||||
template int quant2d_per_channel<TX, TSCALE, TY>( \
|
||||
api::Context *, const TX *, const TSCALE *, TY *, TSCALE *, int64_t, \
|
||||
int64_t);
|
||||
#define INSTANTIATION_QUANT2D_PER_CHANNEL(TX, TSCALE, TY) \
|
||||
template int quant2d_per_channel<TX, TSCALE, TY>(api::Context *, \
|
||||
const TX *, \
|
||||
const TSCALE *, \
|
||||
TY *, \
|
||||
TSCALE *, \
|
||||
int64_t, \
|
||||
int64_t);
|
||||
|
||||
INSTANTIATION_QUANT2D_PER_CHANNEL(float16, float, int8_t);
|
||||
INSTANTIATION_QUANT2D_PER_CHANNEL(bfloat16, float, int8_t);
|
||||
@@ -274,7 +312,7 @@ INSTANTIATION_QUANT2D_PER_CHANNEL(float16, float16, int4_t);
|
||||
INSTANTIATION_QUANT2D_PER_CHANNEL(float16, float, int4_t);
|
||||
INSTANTIATION_QUANT2D_PER_CHANNEL(float, float, int4_t);
|
||||
INSTANTIATION_QUANT2D_PER_CHANNEL(bfloat16, float, int4_t);
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
|
||||
@@ -12,28 +12,38 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
__attribute__((global)) void
|
||||
recover_block(int *recover_block_list, // [bsz]
|
||||
int *recover_len, bool *stop_flags, int *seq_lens_this_time,
|
||||
const int *ori_seq_lens_encoder, int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder, int *block_tables, int *free_list,
|
||||
int *free_list_len, int64_t *input_ids, const int64_t *pre_ids,
|
||||
const int64_t *step_idx, const int *encoder_block_lens,
|
||||
const int *used_list_len, const int64_t *next_tokens,
|
||||
const int64_t *first_token_ids, const int bsz,
|
||||
const int block_num_per_seq, const int length,
|
||||
const int pre_id_length);
|
||||
__attribute__((global)) void recover_block(int *recover_block_list, // [bsz]
|
||||
int *recover_len,
|
||||
bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
const int *ori_seq_lens_encoder,
|
||||
int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
int *block_tables,
|
||||
int *free_list,
|
||||
int *free_list_len,
|
||||
int64_t *input_ids,
|
||||
const int64_t *pre_ids,
|
||||
const int64_t *step_idx,
|
||||
const int *encoder_block_lens,
|
||||
const int *used_list_len,
|
||||
const int64_t *next_tokens,
|
||||
const int64_t *first_token_ids,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int length,
|
||||
const int pre_id_length);
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
@@ -41,125 +51,207 @@ namespace api {
|
||||
namespace plugin {
|
||||
|
||||
static int cpu_wrapper(Context *ctx,
|
||||
int *recover_block_list, // [bsz]
|
||||
int *recover_len, bool *stop_flags,
|
||||
int *seq_lens_this_time, const int *ori_seq_lens_encoder,
|
||||
int *seq_lens_encoder, const int *seq_lens_decoder,
|
||||
int *block_tables, int *free_list, int *free_list_len,
|
||||
int64_t *input_ids, const int64_t *pre_ids,
|
||||
const int64_t *step_idx, const int *encoder_block_lens,
|
||||
const int *used_list_len, const int64_t *next_tokens,
|
||||
const int64_t *first_token_ids, const int bsz,
|
||||
const int block_num_per_seq, const int length,
|
||||
int *recover_block_list, // [bsz]
|
||||
int *recover_len,
|
||||
bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
const int *ori_seq_lens_encoder,
|
||||
int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
int *block_tables,
|
||||
int *free_list,
|
||||
int *free_list_len,
|
||||
int64_t *input_ids,
|
||||
const int64_t *pre_ids,
|
||||
const int64_t *step_idx,
|
||||
const int *encoder_block_lens,
|
||||
const int *used_list_len,
|
||||
const int64_t *next_tokens,
|
||||
const int64_t *first_token_ids,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int length,
|
||||
const int pre_id_length) {
|
||||
for (int bid = 0; bid < recover_len[0]; bid++) {
|
||||
const int recover_id = recover_block_list[bid];
|
||||
const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id];
|
||||
const int step_idx_now = step_idx[recover_id];
|
||||
const int seq_len = ori_seq_len_encoder + step_idx_now;
|
||||
const int encoder_block_len = encoder_block_lens[recover_id];
|
||||
const int decoder_used_len = used_list_len[recover_id];
|
||||
int *block_table_now = block_tables + recover_id * block_num_per_seq;
|
||||
int64_t *input_ids_now = input_ids + recover_id * length;
|
||||
const int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length;
|
||||
for (int bid = 0; bid < recover_len[0]; bid++) {
|
||||
const int recover_id = recover_block_list[bid];
|
||||
const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id];
|
||||
const int step_idx_now = step_idx[recover_id];
|
||||
const int seq_len = ori_seq_len_encoder + step_idx_now;
|
||||
const int encoder_block_len = encoder_block_lens[recover_id];
|
||||
const int decoder_used_len = used_list_len[recover_id];
|
||||
int *block_table_now = block_tables + recover_id * block_num_per_seq;
|
||||
int64_t *input_ids_now = input_ids + recover_id * length;
|
||||
const int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length;
|
||||
|
||||
seq_lens_this_time[recover_id] = seq_len;
|
||||
seq_lens_encoder[recover_id] = seq_len;
|
||||
stop_flags[recover_id] = false;
|
||||
input_ids_now[seq_len - 1] = next_tokens[recover_id]; // next tokens
|
||||
input_ids_now[0] =
|
||||
first_token_ids[recover_id]; // set first prompt token
|
||||
int ori_free_list_len = free_list_len[0];
|
||||
free_list_len[0] -= decoder_used_len;
|
||||
seq_lens_this_time[recover_id] = seq_len;
|
||||
seq_lens_encoder[recover_id] = seq_len;
|
||||
stop_flags[recover_id] = false;
|
||||
input_ids_now[seq_len - 1] = next_tokens[recover_id]; // next tokens
|
||||
input_ids_now[0] = first_token_ids[recover_id]; // set first prompt token
|
||||
int ori_free_list_len = free_list_len[0];
|
||||
free_list_len[0] -= decoder_used_len;
|
||||
|
||||
// 恢复block table
|
||||
for (int i = 0; i < decoder_used_len; i++) {
|
||||
block_table_now[encoder_block_len + i] =
|
||||
free_list[ori_free_list_len - i - 1];
|
||||
}
|
||||
// 恢复input_ids
|
||||
for (int i = 0; i < step_idx_now - 1; i++) {
|
||||
input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1];
|
||||
}
|
||||
// 恢复block table
|
||||
for (int i = 0; i < decoder_used_len; i++) {
|
||||
block_table_now[encoder_block_len + i] =
|
||||
free_list[ori_free_list_len - i - 1];
|
||||
}
|
||||
recover_len[0] = 0;
|
||||
return api::SUCCESS;
|
||||
// 恢复input_ids
|
||||
for (int i = 0; i < step_idx_now - 1; i++) {
|
||||
input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1];
|
||||
}
|
||||
}
|
||||
recover_len[0] = 0;
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
static int xpu3_wrapper(Context *ctx,
|
||||
int *recover_block_list, // [bsz]
|
||||
int *recover_len, bool *stop_flags,
|
||||
int *recover_block_list, // [bsz]
|
||||
int *recover_len,
|
||||
bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
const int *ori_seq_lens_encoder, int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder, int *block_tables,
|
||||
int *free_list, int *free_list_len, int64_t *input_ids,
|
||||
const int64_t *pre_ids, const int64_t *step_idx,
|
||||
const int *encoder_block_lens, const int *used_list_len,
|
||||
const int *ori_seq_lens_encoder,
|
||||
int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
int *block_tables,
|
||||
int *free_list,
|
||||
int *free_list_len,
|
||||
int64_t *input_ids,
|
||||
const int64_t *pre_ids,
|
||||
const int64_t *step_idx,
|
||||
const int *encoder_block_lens,
|
||||
const int *used_list_len,
|
||||
const int64_t *next_tokens,
|
||||
const int64_t *first_token_ids, const int bsz,
|
||||
const int block_num_per_seq, const int length,
|
||||
const int64_t *first_token_ids,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int length,
|
||||
const int pre_id_length) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto recover_block_kernel = xpu3::plugin::recover_block;
|
||||
recover_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
recover_block_list, // [bsz]
|
||||
recover_len, stop_flags, seq_lens_this_time, ori_seq_lens_encoder,
|
||||
seq_lens_encoder, seq_lens_decoder, block_tables, free_list,
|
||||
free_list_len, reinterpret_cast<XPU_INT64 *>(input_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(step_idx), encoder_block_lens,
|
||||
used_list_len, reinterpret_cast<const XPU_INT64 *>(next_tokens),
|
||||
reinterpret_cast<const XPU_INT64 *>(first_token_ids), bsz,
|
||||
block_num_per_seq, length, pre_id_length);
|
||||
return api::SUCCESS;
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto recover_block_kernel = xpu3::plugin::recover_block;
|
||||
recover_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
recover_block_list, // [bsz]
|
||||
recover_len,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
ori_seq_lens_encoder,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
free_list,
|
||||
free_list_len,
|
||||
reinterpret_cast<XPU_INT64 *>(input_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(step_idx),
|
||||
encoder_block_lens,
|
||||
used_list_len,
|
||||
reinterpret_cast<const XPU_INT64 *>(next_tokens),
|
||||
reinterpret_cast<const XPU_INT64 *>(first_token_ids),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
length,
|
||||
pre_id_length);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
int recover_block(Context *ctx,
|
||||
int *recover_block_list, // [bsz]
|
||||
int *recover_len, bool *stop_flags, int *seq_lens_this_time,
|
||||
const int *ori_seq_lens_encoder, int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder, int *block_tables,
|
||||
int *free_list, int *free_list_len, int64_t *input_ids,
|
||||
const int64_t *pre_ids, const int64_t *step_idx,
|
||||
const int *encoder_block_lens, const int *used_list_len,
|
||||
const int64_t *next_tokens, const int64_t *first_token_ids,
|
||||
const int bsz, const int block_num_per_seq, const int length,
|
||||
int *recover_block_list, // [bsz]
|
||||
int *recover_len,
|
||||
bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
const int *ori_seq_lens_encoder,
|
||||
int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
int *block_tables,
|
||||
int *free_list,
|
||||
int *free_list_len,
|
||||
int64_t *input_ids,
|
||||
const int64_t *pre_ids,
|
||||
const int64_t *step_idx,
|
||||
const int *encoder_block_lens,
|
||||
const int *used_list_len,
|
||||
const int64_t *next_tokens,
|
||||
const int64_t *first_token_ids,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int length,
|
||||
const int pre_id_length) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "recover_block", float);
|
||||
WRAPPER_DUMP_PARAM6(ctx, recover_block_list, recover_len, stop_flags,
|
||||
seq_lens_this_time, ori_seq_lens_encoder,
|
||||
seq_lens_encoder);
|
||||
WRAPPER_DUMP_PARAM6(ctx, seq_lens_decoder, block_tables, free_list,
|
||||
free_list_len, input_ids, pre_ids);
|
||||
WRAPPER_DUMP_PARAM5(ctx, step_idx, encoder_block_lens, used_list_len,
|
||||
next_tokens, first_token_ids);
|
||||
WRAPPER_DUMP_PARAM4(ctx, bsz, block_num_per_seq, length, pre_id_length);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper(
|
||||
ctx,
|
||||
recover_block_list, // [bsz]
|
||||
recover_len, stop_flags, seq_lens_this_time, ori_seq_lens_encoder,
|
||||
seq_lens_encoder, seq_lens_decoder, block_tables, free_list,
|
||||
free_list_len, input_ids, pre_ids, step_idx, encoder_block_lens,
|
||||
used_list_len, next_tokens, first_token_ids, bsz, block_num_per_seq,
|
||||
length, pre_id_length);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(
|
||||
ctx,
|
||||
recover_block_list, // [bsz]
|
||||
recover_len, stop_flags, seq_lens_this_time, ori_seq_lens_encoder,
|
||||
seq_lens_encoder, seq_lens_decoder, block_tables, free_list,
|
||||
free_list_len, input_ids, pre_ids, step_idx, encoder_block_lens,
|
||||
used_list_len, next_tokens, first_token_ids, bsz, block_num_per_seq,
|
||||
length, pre_id_length);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "recover_block", float);
|
||||
WRAPPER_DUMP_PARAM6(ctx,
|
||||
recover_block_list,
|
||||
recover_len,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
ori_seq_lens_encoder,
|
||||
seq_lens_encoder);
|
||||
WRAPPER_DUMP_PARAM6(ctx,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
free_list,
|
||||
free_list_len,
|
||||
input_ids,
|
||||
pre_ids);
|
||||
WRAPPER_DUMP_PARAM5(ctx,
|
||||
step_idx,
|
||||
encoder_block_lens,
|
||||
used_list_len,
|
||||
next_tokens,
|
||||
first_token_ids);
|
||||
WRAPPER_DUMP_PARAM4(ctx, bsz, block_num_per_seq, length, pre_id_length);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper(ctx,
|
||||
recover_block_list, // [bsz]
|
||||
recover_len,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
ori_seq_lens_encoder,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
free_list,
|
||||
free_list_len,
|
||||
input_ids,
|
||||
pre_ids,
|
||||
step_idx,
|
||||
encoder_block_lens,
|
||||
used_list_len,
|
||||
next_tokens,
|
||||
first_token_ids,
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
length,
|
||||
pre_id_length);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx,
|
||||
recover_block_list, // [bsz]
|
||||
recover_len,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
ori_seq_lens_encoder,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
free_list,
|
||||
free_list_len,
|
||||
input_ids,
|
||||
pre_ids,
|
||||
step_idx,
|
||||
encoder_block_lens,
|
||||
used_list_len,
|
||||
next_tokens,
|
||||
first_token_ids,
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
length,
|
||||
pre_id_length);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
|
||||
@@ -12,96 +12,102 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
__attribute__((global)) void
|
||||
recover_decode_task(bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size);
|
||||
__attribute__((global)) void recover_decode_task(bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size);
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
|
||||
static int xpu3_wrapper(Context *ctx, bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto recover_decode_task = xpu3::plugin::recover_decode_task;
|
||||
recover_decode_task<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
block_tables,
|
||||
is_block_step,
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
return api::SUCCESS;
|
||||
static int xpu3_wrapper(Context *ctx,
|
||||
bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto recover_decode_task = xpu3::plugin::recover_decode_task;
|
||||
recover_decode_task<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
block_tables,
|
||||
is_block_step,
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
int recover_decode_task(Context *ctx, bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "recover_decode_task", int);
|
||||
WRAPPER_DUMP_PARAM5(ctx, stop_flags, seq_lens_this_time,
|
||||
seq_lens_encoder, seq_lens_decoder, step_seq_lens_decoder);
|
||||
WRAPPER_DUMP_PARAM2(ctx, block_tables, is_block_step);
|
||||
WRAPPER_DUMP_PARAM3(ctx, bsz, block_num_per_seq, block_size);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
assert(false);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx, stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
block_tables,
|
||||
is_block_step,
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
int recover_decode_task(Context *ctx,
|
||||
bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "recover_decode_task", int);
|
||||
WRAPPER_DUMP_PARAM5(ctx,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder);
|
||||
WRAPPER_DUMP_PARAM2(ctx, block_tables, is_block_step);
|
||||
WRAPPER_DUMP_PARAM3(ctx, bsz, block_num_per_seq, block_size);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
assert(false);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
block_tables,
|
||||
is_block_step,
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
|
||||
@@ -18,18 +18,17 @@
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
template <typename T>
|
||||
__attribute__((global)) void text_image_gather_scatter(
|
||||
T* input,
|
||||
T* text_input,
|
||||
T* image_input,
|
||||
int* token_type_ids,
|
||||
int* text_index,
|
||||
int* image_index,
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
bool is_scatter);
|
||||
__attribute__((global)) void text_image_gather_scatter(T* input,
|
||||
T* text_input,
|
||||
T* image_input,
|
||||
int* token_type_ids,
|
||||
int* text_index,
|
||||
int* image_index,
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
bool is_scatter);
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
@@ -41,18 +40,17 @@ namespace plugin {
|
||||
template <typename T>
|
||||
static int cpu_wrapper(
|
||||
Context* ctx,
|
||||
T* input, // shape [token_num, hidden_size]
|
||||
T* text_input, // shape [text_token_num, hidden_size]
|
||||
T* image_input, // shape [image_token_num, hidden_size]
|
||||
int* token_type_ids,// shape [token_num], 0 for text, 1 for image
|
||||
int* text_index, // shape [token_num], mapping from input to text_input
|
||||
int* image_index, // shape [token_num], mapping from input to image_input
|
||||
T* input, // shape [token_num, hidden_size]
|
||||
T* text_input, // shape [text_token_num, hidden_size]
|
||||
T* image_input, // shape [image_token_num, hidden_size]
|
||||
int* token_type_ids, // shape [token_num], 0 for text, 1 for image
|
||||
int* text_index, // shape [token_num], mapping from input to text_input
|
||||
int* image_index, // shape [token_num], mapping from input to image_input
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
bool is_scatter) {
|
||||
|
||||
if (is_scatter) {
|
||||
// Scatter mode: input -> text_input/image_input
|
||||
for (int64_t i = 0; i < token_num; i++) {
|
||||
@@ -106,36 +104,42 @@ static int cpu_wrapper(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static int xpu3_wrapper(
|
||||
Context* ctx,
|
||||
T* input,
|
||||
T* text_input,
|
||||
T* image_input,
|
||||
int* token_type_ids,
|
||||
int* text_index,
|
||||
int* image_index,
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
bool is_scatter) {
|
||||
xpu3::plugin::text_image_gather_scatter<T> <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
input, text_input, image_input, token_type_ids, text_index, image_index,
|
||||
token_num, text_token_num, image_token_num, hidden_size, is_scatter
|
||||
);
|
||||
static int xpu3_wrapper(Context* ctx,
|
||||
T* input,
|
||||
T* text_input,
|
||||
T* image_input,
|
||||
int* token_type_ids,
|
||||
int* text_index,
|
||||
int* image_index,
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
bool is_scatter) {
|
||||
xpu3::plugin::text_image_gather_scatter<T>
|
||||
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(input,
|
||||
text_input,
|
||||
image_input,
|
||||
token_type_ids,
|
||||
text_index,
|
||||
image_index,
|
||||
token_num,
|
||||
text_token_num,
|
||||
image_token_num,
|
||||
hidden_size,
|
||||
is_scatter);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
int text_image_gather_scatter(
|
||||
Context* ctx,
|
||||
T* input, // shape [token_num, hidden_size]
|
||||
T* text_input, // shape [text_token_num, hidden_size]
|
||||
T* image_input, // shape [image_token_num, hidden_size]
|
||||
int* token_type_ids,// shape [token_num], 0 for text, 1 for image
|
||||
int* text_index, // shape [token_num], mapping from input to text_input
|
||||
int* image_index, // shape [token_num], mapping from input to image_input
|
||||
T* input, // shape [token_num, hidden_size]
|
||||
T* text_input, // shape [text_token_num, hidden_size]
|
||||
T* image_input, // shape [image_token_num, hidden_size]
|
||||
int* token_type_ids, // shape [token_num], 0 for text, 1 for image
|
||||
int* text_index, // shape [token_num], mapping from input to text_input
|
||||
int* image_index, // shape [token_num], mapping from input to image_input
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
@@ -143,14 +147,23 @@ int text_image_gather_scatter(
|
||||
bool is_scatter) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_gather_scatter", T);
|
||||
WRAPPER_DUMP_PARAM6(ctx, input, text_input, image_input, token_type_ids, text_index, image_index);
|
||||
WRAPPER_DUMP_PARAM5(ctx, token_num, text_token_num, image_token_num, hidden_size, is_scatter);
|
||||
WRAPPER_DUMP_PARAM6(ctx,
|
||||
input,
|
||||
text_input,
|
||||
image_input,
|
||||
token_type_ids,
|
||||
text_index,
|
||||
image_index);
|
||||
WRAPPER_DUMP_PARAM5(
|
||||
ctx, token_num, text_token_num, image_token_num, hidden_size, is_scatter);
|
||||
WRAPPER_DUMP(ctx);
|
||||
WRAPPER_CHECK_PTR(ctx, T, token_num * hidden_size, input);
|
||||
if (text_token_num != 0) { // avoiding text_input tensor with shape [0, hidden_size]
|
||||
if (text_token_num !=
|
||||
0) { // avoiding text_input tensor with shape [0, hidden_size]
|
||||
WRAPPER_CHECK_PTR(ctx, T, text_token_num * hidden_size, text_input);
|
||||
}
|
||||
if (image_token_num != 0) { // avoiding image_input tensor with shape [0, hidden_size]
|
||||
if (image_token_num !=
|
||||
0) { // avoiding image_input tensor with shape [0, hidden_size]
|
||||
WRAPPER_CHECK_PTR(ctx, T, image_token_num * hidden_size, image_input);
|
||||
}
|
||||
WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids);
|
||||
@@ -159,23 +172,48 @@ int text_image_gather_scatter(
|
||||
WRAPPER_ASSERT_EQ(ctx, token_num, text_token_num + image_token_num);
|
||||
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper<T>(
|
||||
ctx, input, text_input, image_input, token_type_ids, text_index, image_index,
|
||||
token_num, text_token_num, image_token_num, hidden_size, is_scatter
|
||||
);
|
||||
return cpu_wrapper<T>(ctx,
|
||||
input,
|
||||
text_input,
|
||||
image_input,
|
||||
token_type_ids,
|
||||
text_index,
|
||||
image_index,
|
||||
token_num,
|
||||
text_token_num,
|
||||
image_token_num,
|
||||
hidden_size,
|
||||
is_scatter);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper<T>(
|
||||
ctx, input, text_input, image_input, token_type_ids, text_index, image_index,
|
||||
token_num, text_token_num, image_token_num, hidden_size, is_scatter
|
||||
);
|
||||
return xpu3_wrapper<T>(ctx,
|
||||
input,
|
||||
text_input,
|
||||
image_input,
|
||||
token_type_ids,
|
||||
text_index,
|
||||
image_index,
|
||||
token_num,
|
||||
text_token_num,
|
||||
image_token_num,
|
||||
hidden_size,
|
||||
is_scatter);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
|
||||
template int text_image_gather_scatter(
|
||||
Context*, bfloat16*, bfloat16*, bfloat16*, int*, int*, int*, const int64_t, const int64_t, const int64_t, const int64_t, bool);
|
||||
template int text_image_gather_scatter(Context*,
|
||||
bfloat16*,
|
||||
bfloat16*,
|
||||
bfloat16*,
|
||||
int*,
|
||||
int*,
|
||||
int*,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
bool);
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
|
||||
@@ -17,10 +17,11 @@
|
||||
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
__attribute__((global)) void text_image_index_out_kernel(const int* token_type_ids, // x
|
||||
int* text_index, // y1
|
||||
int* image_index, // y2
|
||||
const int64_t token_num);
|
||||
__attribute__((global)) void text_image_index_out_kernel(
|
||||
const int* token_type_ids, // x
|
||||
int* text_index, // y1
|
||||
int* image_index, // y2
|
||||
const int64_t token_num);
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
@@ -30,69 +31,54 @@ namespace api {
|
||||
namespace plugin {
|
||||
|
||||
static int cpu_wrapper(Context* ctx,
|
||||
const int* token_type_ids, // x
|
||||
int* text_index, // y1
|
||||
int* image_index, // y2
|
||||
const int* token_type_ids, // x
|
||||
int* text_index, // y1
|
||||
int* image_index, // y2
|
||||
const int64_t token_num) {
|
||||
int text_count = 0;
|
||||
int text_count = 0;
|
||||
int image_count = 0;
|
||||
|
||||
for (int64_t i = 0; i < token_num; ++i) {
|
||||
if (token_type_ids[i] == 0) {
|
||||
text_index[i] = text_count;
|
||||
++text_count;
|
||||
} else {
|
||||
image_index[i] = image_count;
|
||||
++image_count;
|
||||
}
|
||||
if (token_type_ids[i] == 0) {
|
||||
text_index[i] = text_count;
|
||||
++text_count;
|
||||
} else {
|
||||
image_index[i] = image_count;
|
||||
++image_count;
|
||||
}
|
||||
}
|
||||
return api::SUCCESS;
|
||||
|
||||
}
|
||||
|
||||
static int xpu3_wrapper(Context* ctx,
|
||||
const int* token_type_ids, // x
|
||||
int* text_index, // y1
|
||||
int* image_index, // y2
|
||||
const int* token_type_ids, // x
|
||||
int* text_index, // y1
|
||||
int* image_index, // y2
|
||||
const int64_t token_num) {
|
||||
|
||||
xpu3::plugin::text_image_index_out_kernel<<<1, 1, ctx->xpu_stream>>>(
|
||||
token_type_ids,
|
||||
text_index,
|
||||
image_index,
|
||||
token_num);
|
||||
token_type_ids, text_index, image_index, token_num);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
int text_image_index_out(Context* ctx,
|
||||
const int* token_type_ids, // x
|
||||
int* text_index, // y1
|
||||
int* image_index, // y2
|
||||
const int* token_type_ids, // x
|
||||
int* text_index, // y1
|
||||
int* image_index, // y2
|
||||
const int64_t token_num) {
|
||||
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_index_out", int);
|
||||
WRAPPER_DUMP_PARAM4(
|
||||
ctx, token_type_ids, text_index, image_index, token_num);
|
||||
WRAPPER_DUMP_PARAM4(ctx, token_type_ids, text_index, image_index, token_num);
|
||||
WRAPPER_DUMP(ctx);
|
||||
WRAPPER_ASSERT_GT(ctx, token_num, 0);
|
||||
WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids);
|
||||
WRAPPER_CHECK_PTR(ctx, int, token_num, text_index);
|
||||
WRAPPER_CHECK_PTR(ctx, int, token_num, image_index);
|
||||
|
||||
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper(ctx,
|
||||
token_type_ids,
|
||||
text_index,
|
||||
image_index,
|
||||
token_num);
|
||||
return cpu_wrapper(ctx, token_type_ids, text_index, image_index, token_num);
|
||||
} else if (ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx,
|
||||
token_type_ids,
|
||||
text_index,
|
||||
image_index,
|
||||
token_num);
|
||||
return xpu3_wrapper(
|
||||
ctx, token_type_ids, text_index, image_index, token_num);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
@@ -12,108 +12,162 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
__attribute__((global)) void
|
||||
update_inputs(bool *not_need_stop, int *seq_lens_this_time,
|
||||
int *seq_lens_encoder, int *seq_lens_decoder, int64_t *input_ids,
|
||||
const int64_t *stop_nums, const bool *stop_flags,
|
||||
const bool *is_block_step, const int64_t *next_tokens,
|
||||
const int bsz, const int max_bsz, const int input_ids_stride);
|
||||
__attribute__((global)) void update_inputs(bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int64_t *input_ids,
|
||||
const int64_t *stop_nums,
|
||||
const bool *stop_flags,
|
||||
const bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride);
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
|
||||
static int cpu_wrapper(Context *ctx, bool *not_need_stop,
|
||||
int *seq_lens_this_time, int *seq_lens_encoder,
|
||||
int *seq_lens_decoder, int64_t *input_ids,
|
||||
const int64_t *stop_nums, const bool *stop_flags,
|
||||
const bool *is_block_step, const int64_t *next_tokens,
|
||||
const int bsz, const int max_bsz,
|
||||
static int cpu_wrapper(Context *ctx,
|
||||
bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int64_t *input_ids,
|
||||
const int64_t *stop_nums,
|
||||
const bool *stop_flags,
|
||||
const bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride) {
|
||||
std::vector<int64_t> stop_flag_now_int(max_bsz, 1);
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
bool stop_flags_now = stop_flags[i];
|
||||
stop_flag_now_int[i] = is_block_step[i] ? 0 : stop_flags_now;
|
||||
const int seq_len_encoder = seq_lens_encoder[i];
|
||||
const int seq_len_decoder = seq_lens_decoder[i];
|
||||
std::vector<int64_t> stop_flag_now_int(max_bsz, 1);
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
bool stop_flags_now = stop_flags[i];
|
||||
stop_flag_now_int[i] = is_block_step[i] ? 0 : stop_flags_now;
|
||||
const int seq_len_encoder = seq_lens_encoder[i];
|
||||
const int seq_len_decoder = seq_lens_decoder[i];
|
||||
|
||||
seq_lens_decoder[i] =
|
||||
stop_flags[i] ? 0
|
||||
: (seq_len_decoder == 0 ? seq_len_encoder
|
||||
: seq_len_decoder + 1);
|
||||
seq_lens_decoder[i] =
|
||||
stop_flags[i]
|
||||
? 0
|
||||
: (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1);
|
||||
|
||||
seq_lens_this_time[i] = stop_flags[i] ? 0 : 1;
|
||||
seq_lens_encoder[i] = 0;
|
||||
int64_t *input_ids_now = input_ids + i * input_ids_stride;
|
||||
input_ids_now[0] = next_tokens[i];
|
||||
}
|
||||
int64_t stop_sum = 0;
|
||||
for (size_t i = 0; i < stop_flag_now_int.size(); i++) {
|
||||
stop_sum += stop_flag_now_int[i];
|
||||
}
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
return api::SUCCESS;
|
||||
seq_lens_this_time[i] = stop_flags[i] ? 0 : 1;
|
||||
seq_lens_encoder[i] = 0;
|
||||
int64_t *input_ids_now = input_ids + i * input_ids_stride;
|
||||
input_ids_now[0] = next_tokens[i];
|
||||
}
|
||||
int64_t stop_sum = 0;
|
||||
for (size_t i = 0; i < stop_flag_now_int.size(); i++) {
|
||||
stop_sum += stop_flag_now_int[i];
|
||||
}
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
static int xpu3_wrapper(Context *ctx, bool *not_need_stop,
|
||||
int *seq_lens_this_time, int *seq_lens_encoder,
|
||||
int *seq_lens_decoder, int64_t *input_ids,
|
||||
const int64_t *stop_nums, const bool *stop_flags,
|
||||
const bool *is_block_step, const int64_t *next_tokens,
|
||||
const int bsz, const int max_bsz,
|
||||
static int xpu3_wrapper(Context *ctx,
|
||||
bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int64_t *input_ids,
|
||||
const int64_t *stop_nums,
|
||||
const bool *stop_flags,
|
||||
const bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto update_inputs = xpu3::plugin::update_inputs;
|
||||
update_inputs<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
not_need_stop, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder,
|
||||
reinterpret_cast<XPU_INT64 *>(input_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(stop_nums), stop_flags,
|
||||
is_block_step, reinterpret_cast<const XPU_INT64 *>(next_tokens), bsz,
|
||||
max_bsz, input_ids_stride);
|
||||
return api::SUCCESS;
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto update_inputs = xpu3::plugin::update_inputs;
|
||||
update_inputs<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
not_need_stop,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
reinterpret_cast<XPU_INT64 *>(input_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(stop_nums),
|
||||
stop_flags,
|
||||
is_block_step,
|
||||
reinterpret_cast<const XPU_INT64 *>(next_tokens),
|
||||
bsz,
|
||||
max_bsz,
|
||||
input_ids_stride);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
int update_inputs(Context *ctx, bool *not_need_stop, int *seq_lens_this_time,
|
||||
int *seq_lens_encoder, int *seq_lens_decoder,
|
||||
int64_t *input_ids, const int64_t *stop_nums,
|
||||
const bool *stop_flags, const bool *is_block_step,
|
||||
const int64_t *next_tokens, const int bsz, const int max_bsz,
|
||||
int update_inputs(Context *ctx,
|
||||
bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int64_t *input_ids,
|
||||
const int64_t *stop_nums,
|
||||
const bool *stop_flags,
|
||||
const bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "update_inputs", int);
|
||||
WRAPPER_DUMP_PARAM5(ctx, not_need_stop, seq_lens_this_time,
|
||||
seq_lens_encoder, seq_lens_decoder, input_ids);
|
||||
WRAPPER_DUMP_PARAM4(ctx, stop_nums, stop_flags, is_block_step, next_tokens);
|
||||
WRAPPER_DUMP_PARAM3(ctx, bsz, max_bsz, input_ids_stride);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper(ctx, not_need_stop, seq_lens_this_time,
|
||||
seq_lens_encoder, seq_lens_decoder, input_ids,
|
||||
stop_nums, stop_flags, is_block_step, next_tokens,
|
||||
bsz, max_bsz, input_ids_stride);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx, not_need_stop, seq_lens_this_time,
|
||||
seq_lens_encoder, seq_lens_decoder, input_ids,
|
||||
stop_nums, stop_flags, is_block_step, next_tokens,
|
||||
bsz, max_bsz, input_ids_stride);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "update_inputs", int);
|
||||
WRAPPER_DUMP_PARAM5(ctx,
|
||||
not_need_stop,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
input_ids);
|
||||
WRAPPER_DUMP_PARAM4(ctx, stop_nums, stop_flags, is_block_step, next_tokens);
|
||||
WRAPPER_DUMP_PARAM3(ctx, bsz, max_bsz, input_ids_stride);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper(ctx,
|
||||
not_need_stop,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
input_ids,
|
||||
stop_nums,
|
||||
stop_flags,
|
||||
is_block_step,
|
||||
next_tokens,
|
||||
bsz,
|
||||
max_bsz,
|
||||
input_ids_stride);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx,
|
||||
not_need_stop,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
input_ids,
|
||||
stop_nums,
|
||||
stop_flags,
|
||||
is_block_step,
|
||||
next_tokens,
|
||||
bsz,
|
||||
max_bsz,
|
||||
input_ids_stride);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
|
||||
@@ -12,138 +12,146 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
__attribute__((global)) void
|
||||
update_inputs_v1(bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *prompt_lens,
|
||||
int64_t *topk_ids,
|
||||
int64_t *input_ids,
|
||||
int *block_tables,
|
||||
const int64_t *stop_nums,
|
||||
bool *stop_flags,
|
||||
bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size);
|
||||
__attribute__((global)) void update_inputs_v1(bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *prompt_lens,
|
||||
int64_t *topk_ids,
|
||||
int64_t *input_ids,
|
||||
int *block_tables,
|
||||
const int64_t *stop_nums,
|
||||
bool *stop_flags,
|
||||
bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size);
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
|
||||
static int xpu3_wrapper(Context *ctx, bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *prompt_lens,
|
||||
int64_t *topk_ids,
|
||||
int64_t *input_ids,
|
||||
int *block_tables,
|
||||
const int64_t *stop_nums,
|
||||
bool *stop_flags,
|
||||
bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto update_inputs_v1 = xpu3::plugin::update_inputs_v1;
|
||||
// kernel 内要做 reduce,只能用 1 个 cluster
|
||||
update_inputs_v1<<<1, 64, ctx->xpu_stream>>>(
|
||||
not_need_stop,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
reinterpret_cast<XPU_INT64 *>(prompt_lens),
|
||||
reinterpret_cast<XPU_INT64 *>(topk_ids),
|
||||
reinterpret_cast<XPU_INT64 *>(input_ids),
|
||||
block_tables,
|
||||
reinterpret_cast<const XPU_INT64 *>(stop_nums),
|
||||
stop_flags,
|
||||
is_block_step,
|
||||
reinterpret_cast<const XPU_INT64 *>(next_tokens),
|
||||
bsz,
|
||||
max_bsz,
|
||||
input_ids_stride,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
return api::SUCCESS;
|
||||
static int xpu3_wrapper(Context *ctx,
|
||||
bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *prompt_lens,
|
||||
int64_t *topk_ids,
|
||||
int64_t *input_ids,
|
||||
int *block_tables,
|
||||
const int64_t *stop_nums,
|
||||
bool *stop_flags,
|
||||
bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto update_inputs_v1 = xpu3::plugin::update_inputs_v1;
|
||||
// kernel 内要做 reduce,只能用 1 个 cluster
|
||||
update_inputs_v1<<<1, 64, ctx->xpu_stream>>>(
|
||||
not_need_stop,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
reinterpret_cast<XPU_INT64 *>(prompt_lens),
|
||||
reinterpret_cast<XPU_INT64 *>(topk_ids),
|
||||
reinterpret_cast<XPU_INT64 *>(input_ids),
|
||||
block_tables,
|
||||
reinterpret_cast<const XPU_INT64 *>(stop_nums),
|
||||
stop_flags,
|
||||
is_block_step,
|
||||
reinterpret_cast<const XPU_INT64 *>(next_tokens),
|
||||
bsz,
|
||||
max_bsz,
|
||||
input_ids_stride,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
int update_inputs_v1(Context *ctx, bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *prompt_lens,
|
||||
int64_t *topk_ids,
|
||||
int64_t *input_ids,
|
||||
int *block_tables,
|
||||
const int64_t *stop_nums,
|
||||
bool *stop_flags,
|
||||
bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "update_inputs_v1", int);
|
||||
WRAPPER_DUMP_PARAM5(ctx, not_need_stop, seq_lens_this_time,
|
||||
seq_lens_encoder, seq_lens_decoder, step_seq_lens_decoder);
|
||||
WRAPPER_DUMP_PARAM5(ctx, prompt_lens, topk_ids, input_ids, block_tables, stop_nums);
|
||||
WRAPPER_DUMP_PARAM3(ctx, stop_flags, is_block_step, next_tokens);
|
||||
WRAPPER_DUMP_PARAM5(ctx, bsz, max_bsz, input_ids_stride, block_num_per_seq, block_size);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
assert(false);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx, not_need_stop,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
prompt_lens,
|
||||
topk_ids,
|
||||
input_ids,
|
||||
block_tables,
|
||||
stop_nums,
|
||||
stop_flags,
|
||||
is_block_step,
|
||||
next_tokens,
|
||||
bsz,
|
||||
max_bsz,
|
||||
input_ids_stride,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
int update_inputs_v1(Context *ctx,
|
||||
bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *prompt_lens,
|
||||
int64_t *topk_ids,
|
||||
int64_t *input_ids,
|
||||
int *block_tables,
|
||||
const int64_t *stop_nums,
|
||||
bool *stop_flags,
|
||||
bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "update_inputs_v1", int);
|
||||
WRAPPER_DUMP_PARAM5(ctx,
|
||||
not_need_stop,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder);
|
||||
WRAPPER_DUMP_PARAM5(
|
||||
ctx, prompt_lens, topk_ids, input_ids, block_tables, stop_nums);
|
||||
WRAPPER_DUMP_PARAM3(ctx, stop_flags, is_block_step, next_tokens);
|
||||
WRAPPER_DUMP_PARAM5(
|
||||
ctx, bsz, max_bsz, input_ids_stride, block_num_per_seq, block_size);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
assert(false);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx,
|
||||
not_need_stop,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
prompt_lens,
|
||||
topk_ids,
|
||||
input_ids,
|
||||
block_tables,
|
||||
stop_nums,
|
||||
stop_flags,
|
||||
is_block_step,
|
||||
next_tokens,
|
||||
bsz,
|
||||
max_bsz,
|
||||
input_ids_stride,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
|
||||
+121
-118
@@ -22,32 +22,25 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <rdma/rdma_cma.h>
|
||||
#include <rdma/rdma_verbs.h>
|
||||
#include <sys/epoll.h>
|
||||
#include <atomic>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <netinet/tcp.h>
|
||||
#include <netinet/in.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sstream>
|
||||
#include <netdb.h>
|
||||
#include <sstream>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <sys/socket.h>
|
||||
#include <cstring>
|
||||
#include <netdb.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <fcntl.h>
|
||||
#include <net/if.h>
|
||||
#include <sys/ioctl.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <rdma/rdma_cma.h>
|
||||
#include <rdma/rdma_verbs.h>
|
||||
#include <sys/epoll.h>
|
||||
#include <sys/ioctl.h>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
#include <atomic>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "kvcache_rdma.h"
|
||||
#include "util.h"
|
||||
@@ -60,115 +53,115 @@
|
||||
|
||||
/// @brief IB device information structure
|
||||
struct IbDeviceInfo {
|
||||
int device;
|
||||
uint64_t guid;
|
||||
enum ibv_mtu mtu;
|
||||
uint64_t busid;
|
||||
uint8_t port;
|
||||
uint8_t link;
|
||||
uint8_t active_mtu;
|
||||
int speed;
|
||||
ibv_context* context;
|
||||
char devName[64];
|
||||
int realPort;
|
||||
int maxQp;
|
||||
int device;
|
||||
uint64_t guid;
|
||||
enum ibv_mtu mtu;
|
||||
uint64_t busid;
|
||||
uint8_t port;
|
||||
uint8_t link;
|
||||
uint8_t active_mtu;
|
||||
int speed;
|
||||
ibv_context* context;
|
||||
char devName[64];
|
||||
int realPort;
|
||||
int maxQp;
|
||||
};
|
||||
|
||||
/// @brief Queue Pair information for RDMA
|
||||
struct QpInfo {
|
||||
uint32_t lid;
|
||||
uint32_t qpn;
|
||||
uint32_t psn;
|
||||
union ibv_gid gid;
|
||||
enum ibv_mtu mtu;
|
||||
uint32_t lid;
|
||||
uint32_t qpn;
|
||||
uint32_t psn;
|
||||
union ibv_gid gid;
|
||||
enum ibv_mtu mtu;
|
||||
|
||||
/// @brief Serialize QP info to buffer
|
||||
void serialize(char* buffer) const {
|
||||
uint32_t* intBuffer = reinterpret_cast<uint32_t*>(buffer);
|
||||
intBuffer[0] = htonl(lid);
|
||||
intBuffer[1] = htonl(qpn);
|
||||
intBuffer[2] = htonl(psn);
|
||||
memcpy(buffer + 12, gid.raw, sizeof(gid.raw));
|
||||
intBuffer[7] = htonl(static_cast<uint32_t>(mtu));
|
||||
}
|
||||
/// @brief Serialize QP info to buffer
|
||||
void serialize(char* buffer) const {
|
||||
uint32_t* intBuffer = reinterpret_cast<uint32_t*>(buffer);
|
||||
intBuffer[0] = htonl(lid);
|
||||
intBuffer[1] = htonl(qpn);
|
||||
intBuffer[2] = htonl(psn);
|
||||
memcpy(buffer + 12, gid.raw, sizeof(gid.raw));
|
||||
intBuffer[7] = htonl(static_cast<uint32_t>(mtu));
|
||||
}
|
||||
|
||||
/// @brief Deserialize QP info from buffer
|
||||
void deserialize(const char* buffer) {
|
||||
const uint32_t* intBuffer = reinterpret_cast<const uint32_t*>(buffer);
|
||||
lid = ntohl(intBuffer[0]);
|
||||
qpn = ntohl(intBuffer[1]);
|
||||
psn = ntohl(intBuffer[2]);
|
||||
memcpy(gid.raw, buffer + 12, sizeof(gid.raw));
|
||||
mtu = static_cast<ibv_mtu>(ntohl(intBuffer[7]));
|
||||
}
|
||||
/// @brief Deserialize QP info from buffer
|
||||
void deserialize(const char* buffer) {
|
||||
const uint32_t* intBuffer = reinterpret_cast<const uint32_t*>(buffer);
|
||||
lid = ntohl(intBuffer[0]);
|
||||
qpn = ntohl(intBuffer[1]);
|
||||
psn = ntohl(intBuffer[2]);
|
||||
memcpy(gid.raw, buffer + 12, sizeof(gid.raw));
|
||||
mtu = static_cast<ibv_mtu>(ntohl(intBuffer[7]));
|
||||
}
|
||||
|
||||
static const size_t size = 12 + sizeof(gid.raw) + 4;
|
||||
static const size_t size = 12 + sizeof(gid.raw) + 4;
|
||||
};
|
||||
|
||||
/// @brief RDMA connection context
|
||||
struct Connection {
|
||||
std::atomic<int> connected;
|
||||
std::atomic<int> connected;
|
||||
|
||||
// Memory regions
|
||||
struct ibv_mr *recv_mr;
|
||||
struct ibv_mr *send_mr;
|
||||
// Memory regions
|
||||
struct ibv_mr* recv_mr;
|
||||
struct ibv_mr* send_mr;
|
||||
|
||||
// Cache pointers
|
||||
std::vector<std::vector<void*>> local_cache_key_ptr_per_layer;
|
||||
std::vector<std::vector<void*>> local_cache_value_ptr_per_layer;
|
||||
// Cache pointers
|
||||
std::vector<std::vector<void*>> local_cache_key_ptr_per_layer;
|
||||
std::vector<std::vector<void*>> local_cache_value_ptr_per_layer;
|
||||
|
||||
// Memory region lists
|
||||
std::vector<ibv_mr*> write_cache_key_server_mr_list;
|
||||
std::vector<ibv_mr*> write_cache_value_server_mr_list;
|
||||
std::vector<std::vector<ibv_mr*>> write_mr_key_list;
|
||||
std::vector<std::vector<ibv_mr*>> write_mr_value_list;
|
||||
// Memory region lists
|
||||
std::vector<ibv_mr*> write_cache_key_server_mr_list;
|
||||
std::vector<ibv_mr*> write_cache_value_server_mr_list;
|
||||
std::vector<std::vector<ibv_mr*>> write_mr_key_list;
|
||||
std::vector<std::vector<ibv_mr*>> write_mr_value_list;
|
||||
|
||||
// Remote access information
|
||||
std::vector<void*> write_cache_key_remote_ptr_list;
|
||||
std::vector<uint32_t> write_cache_key_remote_rkey_list;
|
||||
std::vector<void*> write_cache_value_remote_ptr_list;
|
||||
std::vector<uint32_t> write_cache_value_remote_rkey_list;
|
||||
// Remote access information
|
||||
std::vector<void*> write_cache_key_remote_ptr_list;
|
||||
std::vector<uint32_t> write_cache_key_remote_rkey_list;
|
||||
std::vector<void*> write_cache_value_remote_ptr_list;
|
||||
std::vector<uint32_t> write_cache_value_remote_rkey_list;
|
||||
|
||||
// Received remote memory information
|
||||
std::vector<void*> receive_write_cache_key_remote_ptr_list;
|
||||
std::vector<uint32_t> receive_write_cache_key_remote_rkey_list;
|
||||
std::vector<void*> receive_write_cache_value_remote_ptr_list;
|
||||
std::vector<uint32_t> receive_write_cache_value_remote_rkey_list;
|
||||
// Received remote memory information
|
||||
std::vector<void*> receive_write_cache_key_remote_ptr_list;
|
||||
std::vector<uint32_t> receive_write_cache_key_remote_rkey_list;
|
||||
std::vector<void*> receive_write_cache_value_remote_ptr_list;
|
||||
std::vector<uint32_t> receive_write_cache_value_remote_rkey_list;
|
||||
|
||||
std::vector<void *> send_write_cache_key_remote_ptr_list;
|
||||
std::vector<uint32_t> send_write_cache_key_remote_rkey_list;
|
||||
std::vector<void *> send_write_cache_value_remote_ptr_list;
|
||||
std::vector<uint32_t> send_write_cache_value_remote_rkey_list;
|
||||
std::vector<void*> send_write_cache_key_remote_ptr_list;
|
||||
std::vector<uint32_t> send_write_cache_key_remote_rkey_list;
|
||||
std::vector<void*> send_write_cache_value_remote_ptr_list;
|
||||
std::vector<uint32_t> send_write_cache_value_remote_rkey_list;
|
||||
|
||||
// For rdma read operations
|
||||
std::vector<void*> read_bufs;
|
||||
std::vector<ibv_mr*> read_mrs;
|
||||
// For rdma read operations
|
||||
std::vector<void*> read_bufs;
|
||||
std::vector<ibv_mr*> read_mrs;
|
||||
|
||||
// Work completion tracking
|
||||
int wc_count;
|
||||
int wc_target_count;
|
||||
// Work completion tracking
|
||||
int wc_count;
|
||||
int wc_target_count;
|
||||
|
||||
// Configuration
|
||||
int layer_number;
|
||||
int block_number;
|
||||
int block_byte_size;
|
||||
std::string url;
|
||||
// Configuration
|
||||
int layer_number;
|
||||
int block_number;
|
||||
int block_byte_size;
|
||||
std::string url;
|
||||
|
||||
Connection() = default;
|
||||
~Connection();
|
||||
Connection() = default;
|
||||
~Connection();
|
||||
};
|
||||
|
||||
/// @brief RDMA context structure
|
||||
struct RdmaContext {
|
||||
int sock_fd;
|
||||
struct ibv_context* context;
|
||||
struct ibv_comp_channel* channel;
|
||||
struct ibv_pd* pd;
|
||||
struct ibv_mr* mr;
|
||||
struct ibv_cq* cq;
|
||||
struct ibv_qp* qp;
|
||||
struct ibv_port_attr portinfo;
|
||||
struct Connection conn;
|
||||
int sock_fd;
|
||||
struct ibv_context* context;
|
||||
struct ibv_comp_channel* channel;
|
||||
struct ibv_pd* pd;
|
||||
struct ibv_mr* mr;
|
||||
struct ibv_cq* cq;
|
||||
struct ibv_qp* qp;
|
||||
struct ibv_port_attr portinfo;
|
||||
struct Connection conn;
|
||||
};
|
||||
|
||||
// Global variables
|
||||
@@ -176,36 +169,46 @@ extern std::vector<IbDeviceInfo> g_ib_all_devs;
|
||||
static int g_kvcache_ib_dev_nums = -1;
|
||||
|
||||
// Connection management functions
|
||||
bool client_exchange_destinations(
|
||||
struct RdmaContext* ctx,
|
||||
int ib_port,
|
||||
unsigned int port,
|
||||
int gidx,
|
||||
const std::string& dst_ip);
|
||||
bool client_exchange_destinations(struct RdmaContext* ctx,
|
||||
int ib_port,
|
||||
unsigned int port,
|
||||
int gidx,
|
||||
const std::string& dst_ip);
|
||||
|
||||
int server_exchange_qp_info(int connfd, QpInfo* local_dest, QpInfo* rem_dest);
|
||||
struct RdmaContext* create_qp(struct IbDeviceInfo* ib_dev, struct ibv_pd** g_pd);
|
||||
struct RdmaContext* create_qp(struct IbDeviceInfo* ib_dev,
|
||||
struct ibv_pd** g_pd);
|
||||
bool clear_qp_info(struct RdmaContext* ctx);
|
||||
|
||||
// QP modification functions
|
||||
QpStatus modify_qp_to_rts(struct RdmaContext* ctx, int port, int my_psn,
|
||||
struct QpInfo* dest, int sgid_id);
|
||||
bool poll_cq_with_timeout(struct RdmaContext* ctx, int timeout_seconds, int cqe_count);
|
||||
QpStatus modify_qp_to_rts(struct RdmaContext* ctx,
|
||||
int port,
|
||||
int my_psn,
|
||||
struct QpInfo* dest,
|
||||
int sgid_id);
|
||||
bool poll_cq_with_timeout(struct RdmaContext* ctx,
|
||||
int timeout_seconds,
|
||||
int cqe_count);
|
||||
|
||||
// Utility functions
|
||||
int get_port_info(struct ibv_context* Context, int port,
|
||||
struct ibv_port_attr* attr);
|
||||
int get_port_info(struct ibv_context* Context,
|
||||
int port,
|
||||
struct ibv_port_attr* attr);
|
||||
int parse_port_ib_info();
|
||||
|
||||
// Memory region exchange
|
||||
bool client_exchange_mr(struct RdmaContext* ctx);
|
||||
bool server_exchange_mr(struct RdmaContext* ctx);
|
||||
bool server_send_memory_region(struct RdmaContext *ctx, void *local_mr, int byte_num);
|
||||
bool client_receive_memory_region(struct RdmaContext *ctx, void *remote_mr, int byte_num);
|
||||
bool server_send_memory_region(struct RdmaContext* ctx,
|
||||
void* local_mr,
|
||||
int byte_num);
|
||||
bool client_receive_memory_region(struct RdmaContext* ctx,
|
||||
void* remote_mr,
|
||||
int byte_num);
|
||||
|
||||
// Network setup
|
||||
int setup_listening_socket(int port);
|
||||
int configure_epoll(int sockfd);
|
||||
std::vector<std::string> get_net_ifname();
|
||||
|
||||
#endif // FASTDEPLOY_KVCACHE_CONNECTION_H
|
||||
#endif // FASTDEPLOY_KVCACHE_CONNECTION_H
|
||||
|
||||
+103
-82
@@ -4,77 +4,88 @@
|
||||
#pragma once
|
||||
|
||||
#include <rdma/rdma_cma.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include "util.h" // Contains constant definitions
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "kvcache_connection.h"
|
||||
#include "log.h"
|
||||
|
||||
#include "util.h" // Contains constant definitions
|
||||
|
||||
/**
|
||||
* @brief RDMA communication handler for key-value cache
|
||||
*/
|
||||
class RDMACommunicator {
|
||||
public:
|
||||
// Construction/Destruction
|
||||
RDMACommunicator(std::string &role, int gpu_idx, std::string &port,
|
||||
std::vector<int64_t> local_key_cache,
|
||||
std::vector<int64_t> local_value_cache,
|
||||
int block_number, int block_bytes);
|
||||
~RDMACommunicator();
|
||||
public:
|
||||
// Construction/Destruction
|
||||
RDMACommunicator(std::string& role,
|
||||
int gpu_idx,
|
||||
std::string& port,
|
||||
std::vector<int64_t> local_key_cache,
|
||||
std::vector<int64_t> local_value_cache,
|
||||
int block_number,
|
||||
int block_bytes);
|
||||
~RDMACommunicator();
|
||||
|
||||
// Connection management
|
||||
int connect(const std::string &dst_ip, const std::string &dst_port);
|
||||
bool is_connected(const std::string &dst_ip, const std::string &dst_port);
|
||||
// Connection management
|
||||
int connect(const std::string& dst_ip, const std::string& dst_port);
|
||||
bool is_connected(const std::string& dst_ip, const std::string& dst_port);
|
||||
|
||||
// Core functionality
|
||||
int write_cache(const std::string &ip, const std::string &port,
|
||||
const std::vector<int64_t>& local_block_ids,
|
||||
const std::vector<int64_t>& remote_block_ids,
|
||||
int32_t layer_idx);
|
||||
// Core functionality
|
||||
int write_cache(const std::string& ip,
|
||||
const std::string& port,
|
||||
const std::vector<int64_t>& local_block_ids,
|
||||
const std::vector<int64_t>& remote_block_ids,
|
||||
int32_t layer_idx);
|
||||
|
||||
// Server Init
|
||||
int init_server();
|
||||
// Server Init
|
||||
int init_server();
|
||||
|
||||
// get socket nic ip
|
||||
std::string fetch_local_ip();
|
||||
// get socket nic ip
|
||||
std::string fetch_local_ip();
|
||||
|
||||
private:
|
||||
// Server Core functions
|
||||
int start_server(int sport, int sgid_idx, int gpu_index);
|
||||
private:
|
||||
// Server Core functions
|
||||
int start_server(int sport, int sgid_idx, int gpu_index);
|
||||
|
||||
// Internal implementation methods
|
||||
void resize_vectors();
|
||||
void assign_pointers();
|
||||
void validate_addr();
|
||||
bool client_mr_register_per_layer(struct RdmaContext *ctx);
|
||||
bool server_mr_register_per_layer(struct RdmaContext *ctx);
|
||||
struct ibv_mr* register_memory_region(ibv_pd* pd, void* addr, size_t size,
|
||||
const std::string& desc, uint32_t access_flags);
|
||||
bool deregister_memory_regions(struct RdmaContext* ctx);
|
||||
// Internal implementation methods
|
||||
void resize_vectors();
|
||||
void assign_pointers();
|
||||
void validate_addr();
|
||||
bool client_mr_register_per_layer(struct RdmaContext* ctx);
|
||||
bool server_mr_register_per_layer(struct RdmaContext* ctx);
|
||||
struct ibv_mr* register_memory_region(ibv_pd* pd,
|
||||
void* addr,
|
||||
size_t size,
|
||||
const std::string& desc,
|
||||
uint32_t access_flags);
|
||||
bool deregister_memory_regions(struct RdmaContext* ctx);
|
||||
|
||||
bool post_block_send(struct RdmaContext* ctx, int layer_idx,
|
||||
const std::vector<int64_t>& local_block_ids,
|
||||
bool is_key, std::vector<uint64_t>& remote_addr,
|
||||
uint32_t rkey, const std::string &ip,
|
||||
const std::string &port);
|
||||
bool post_block_send(struct RdmaContext* ctx,
|
||||
int layer_idx,
|
||||
const std::vector<int64_t>& local_block_ids,
|
||||
bool is_key,
|
||||
std::vector<uint64_t>& remote_addr,
|
||||
uint32_t rkey,
|
||||
const std::string& ip,
|
||||
const std::string& port);
|
||||
|
||||
bool execute_rdma_writes(struct RdmaContext* ctx, int layer_idx,
|
||||
bool execute_rdma_writes(struct RdmaContext* ctx,
|
||||
int layer_idx,
|
||||
const std::vector<int64_t>& local_block_ids,
|
||||
bool is_key, std::vector<uint64_t>& remote_addr,
|
||||
bool is_key,
|
||||
std::vector<uint64_t>& remote_addr,
|
||||
uint32_t rkey);
|
||||
|
||||
void prepare_write_requests(struct ibv_sge* sge_list,
|
||||
struct ibv_send_wr* send_wr_list,
|
||||
int layer_idx,
|
||||
const std::vector<int64_t>& local_block_ids,
|
||||
bool is_key,
|
||||
std::vector<uint64_t>& remote_addr,
|
||||
uint32_t rkey);
|
||||
void prepare_write_requests(struct ibv_sge* sge_list,
|
||||
struct ibv_send_wr* send_wr_list,
|
||||
int layer_idx,
|
||||
const std::vector<int64_t>& local_block_ids,
|
||||
bool is_key,
|
||||
std::vector<uint64_t>& remote_addr,
|
||||
uint32_t rkey);
|
||||
|
||||
bool execute_read_verification(struct RdmaContext* ctx,
|
||||
bool execute_read_verification(struct RdmaContext* ctx,
|
||||
size_t block_idx,
|
||||
uint64_t remote_addr,
|
||||
uint32_t rkey,
|
||||
@@ -82,46 +93,56 @@ private:
|
||||
const std::string& ip,
|
||||
const std::string& port);
|
||||
|
||||
bool post_send_with_retry(struct RdmaContext* ctx,
|
||||
bool post_send_with_retry(struct RdmaContext* ctx,
|
||||
struct ibv_send_wr* wr_list,
|
||||
size_t inflight_wr,
|
||||
bool need_poll);
|
||||
|
||||
// Connection management
|
||||
int client_listener();
|
||||
void close_server_connection(int fd, struct RdmaContext* ctx, int epollfd,
|
||||
std::map<int, struct RdmaContext*>& connectionContexts);
|
||||
void close_client_connection(int fd, struct RdmaContext* ctx, int epollfd);
|
||||
// Connection management
|
||||
int client_listener();
|
||||
void close_server_connection(
|
||||
int fd,
|
||||
struct RdmaContext* ctx,
|
||||
int epollfd,
|
||||
std::map<int, struct RdmaContext*>& connectionContexts);
|
||||
void close_client_connection(int fd, struct RdmaContext* ctx, int epollfd);
|
||||
|
||||
void remove_conn(const std::string& url);
|
||||
struct RdmaContext *get_conn(const std::string &ip,
|
||||
const std::string &port);
|
||||
void remove_conn(const std::string& url);
|
||||
struct RdmaContext* get_conn(const std::string& ip, const std::string& port);
|
||||
|
||||
// Member variables
|
||||
std::string splitwise_role; // Role in distributed system ("decode" or other)
|
||||
int gpu_idx; // GPU device index
|
||||
std::string port; // Communication port
|
||||
std::vector<int64_t> local_cache_key_ptr_layer_head_; // Key cache pointers
|
||||
std::vector<int64_t> local_cache_value_ptr_layer_head_; // Value cache pointers
|
||||
int block_number; // Number of blocks
|
||||
int block_size_byte; // Size of each block in bytes
|
||||
int layer_number; // Number of layers
|
||||
// Member variables
|
||||
std::string splitwise_role; // Role in distributed system ("decode" or other)
|
||||
int gpu_idx; // GPU device index
|
||||
std::string port; // Communication port
|
||||
std::vector<int64_t> local_cache_key_ptr_layer_head_; // Key cache pointers
|
||||
std::vector<int64_t>
|
||||
local_cache_value_ptr_layer_head_; // Value cache pointers
|
||||
int block_number; // Number of blocks
|
||||
int block_size_byte; // Size of each block in bytes
|
||||
int layer_number; // Number of layers
|
||||
|
||||
std::vector<std::vector<void*>> local_cache_key_ptr_per_layer; // Per-layer key pointers
|
||||
std::vector<std::vector<void*>> local_cache_value_ptr_per_layer; // Per-layer value pointers
|
||||
std::vector<std::vector<void*>>
|
||||
local_cache_key_ptr_per_layer; // Per-layer key pointers
|
||||
std::vector<std::vector<void*>>
|
||||
local_cache_value_ptr_per_layer; // Per-layer value pointers
|
||||
|
||||
std::vector<struct ibv_mr*> write_mr_key_list; // Memory regions for key writes
|
||||
std::vector<struct ibv_mr*> write_mr_value_list; // Memory regions for value writes
|
||||
std::vector<struct ibv_mr*> write_cache_key_server_mr_list; // Server-side key memory regions
|
||||
std::vector<struct ibv_mr*> write_cache_value_server_mr_list; // Server-side value memory regions
|
||||
std::vector<struct ibv_mr*>
|
||||
write_mr_key_list; // Memory regions for key writes
|
||||
std::vector<struct ibv_mr*>
|
||||
write_mr_value_list; // Memory regions for value writes
|
||||
std::vector<struct ibv_mr*>
|
||||
write_cache_key_server_mr_list; // Server-side key memory regions
|
||||
std::vector<struct ibv_mr*>
|
||||
write_cache_value_server_mr_list; // Server-side value memory regions
|
||||
|
||||
std::vector<std::string> main_ip_list; // List of local IP addresses
|
||||
std::map<std::string, struct RdmaContext*> conn_map; // Active connections map
|
||||
std::mutex mutex_; // Thread synchronization mutex
|
||||
int rdma_event_channel_epoll_fd; // Epoll file descriptor
|
||||
struct ibv_pd *g_pd = NULL; // fd
|
||||
int RDMACommunicator_status; // Communicator status flag
|
||||
bool start_client_listener = false; // Client listener flag
|
||||
std::vector<std::string> main_ip_list; // List of local IP addresses
|
||||
std::map<std::string, struct RdmaContext*>
|
||||
conn_map; // Active connections map
|
||||
std::mutex mutex_; // Thread synchronization mutex
|
||||
int rdma_event_channel_epoll_fd; // Epoll file descriptor
|
||||
struct ibv_pd* g_pd = NULL; // fd
|
||||
int RDMACommunicator_status; // Communicator status flag
|
||||
bool start_client_listener = false; // Client listener flag
|
||||
};
|
||||
|
||||
#endif // KVCACHE_RDMA_H
|
||||
#endif // KVCACHE_RDMA_H
|
||||
|
||||
@@ -19,99 +19,130 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <pthread.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
#include <sys/time.h>
|
||||
#include <unistd.h> //for gethostname
|
||||
#include <sys/syscall.h>
|
||||
#include <pthread.h>
|
||||
#include <string>
|
||||
#include <ctime>
|
||||
#include <sys/time.h>
|
||||
#include <time.h>
|
||||
#include <unistd.h> //for gethostname
|
||||
#include <chrono>
|
||||
#include <ctime>
|
||||
#include <string>
|
||||
|
||||
#define KV_IS_DEBUG_ENABLED (std::getenv("KVCACHE_DEBUG"))
|
||||
#define FILE_NAME(x) (strrchr(x,'/') ? strrchr(x,'/')+1 : x)
|
||||
#define FILE_NAME(x) (strrchr(x, '/') ? strrchr(x, '/') + 1 : x)
|
||||
|
||||
static thread_local char __attribute__((__unused__)) str[64];
|
||||
|
||||
// Severity levels passed as the first argument of debug_log().
// Declared as a plain (unscoped) enum so the enumerators can be used
// directly by the logging macros below.
typedef enum {
  KV_LOG_LEVEL_INFO = 0,   // used by INFO(...)
  KV_LOG_LEVEL_DEBUG = 1,  // used by DEBUG(...)
  KV_LOG_LEVEL_WARN = 2,   // used by WARN(...)
  KV_LOG_LEVEL_ERROR = 3   // used by ERR(...)
} KVLogLevel;
|
||||
|
||||
void debug_log(KVLogLevel level, bool enable_to_terminal, const char *filefunc,
|
||||
int line, const char *fmt, ...) __attribute__ ((format (printf, 5, 6)));
|
||||
void debug_log(KVLogLevel level,
|
||||
bool enable_to_terminal,
|
||||
const char *filefunc,
|
||||
int line,
|
||||
const char *fmt,
|
||||
...) __attribute__((format(printf, 5, 6)));
|
||||
|
||||
/**
|
||||
* @brief Unified logging macro to reduce duplication and improve maintainability.
|
||||
* @brief Unified logging macro to reduce duplication and improve
|
||||
* maintainability.
|
||||
*
|
||||
* @param level Log level (e.g., INFO, DEBUG, WARN, ERR).
|
||||
* @param to_terminal If true, the log will be printed to terminal.
|
||||
* @param ... Format string and arguments (like printf).
|
||||
*/
|
||||
#define KV_LOG(level, to_terminal, ...) \
|
||||
debug_log(level, to_terminal, FILE_NAME(__FILE__), __LINE__, __VA_ARGS__)
|
||||
debug_log(level, to_terminal, FILE_NAME(__FILE__), __LINE__, __VA_ARGS__)
|
||||
|
||||
// Public logging macros with terminal output
|
||||
#define WARN(...) KV_LOG(KV_LOG_LEVEL_WARN, true, __VA_ARGS__)
|
||||
#define ERR(...) KV_LOG(KV_LOG_LEVEL_ERROR, true, __VA_ARGS__)
|
||||
#define DEBUG(...) KV_LOG(KV_LOG_LEVEL_DEBUG, true, __VA_ARGS__)
|
||||
#define INFO(...) KV_LOG(KV_LOG_LEVEL_INFO, true, __VA_ARGS__)
|
||||
#define WARN(...) KV_LOG(KV_LOG_LEVEL_WARN, true, __VA_ARGS__)
|
||||
#define ERR(...) KV_LOG(KV_LOG_LEVEL_ERROR, true, __VA_ARGS__)
|
||||
#define DEBUG(...) KV_LOG(KV_LOG_LEVEL_DEBUG, true, __VA_ARGS__)
|
||||
#define INFO(...) KV_LOG(KV_LOG_LEVEL_INFO, true, __VA_ARGS__)
|
||||
|
||||
#define gettid() ((pid_t)syscall(SYS_gettid))
|
||||
/**
 * @brief Fill the thread-local `str` buffer with the log-line prefix
 *        "HH:MM:SS][hostname][tid".
 *
 * Uses localtime_r() instead of localtime(): localtime() returns a pointer
 * to shared static storage (a data race when several threads log at once)
 * and may return NULL, which would then be dereferenced. The broken-down
 * time is zero-initialised, so a failed conversion prints 00:00:00.
 *
 * The hostname buffer is zero-initialised so a failing gethostname() cannot
 * expose uninitialised bytes, and snprintf() bounds the write to `str`.
 */
#define GET_CURRENT_TIME()                          \
  do {                                              \
    time_t timer = time(0);                         \
    struct tm tm_buf = {};                          \
    localtime_r(&timer, &tm_buf);                   \
    char hostname[32] = {0};                        \
    gethostname(hostname, sizeof(hostname));        \
    snprintf(str,                                   \
             sizeof(str),                           \
             "%02d:%02d:%02d][%.32s][%d",           \
             tm_buf.tm_hour,                        \
             tm_buf.tm_min,                         \
             tm_buf.tm_sec,                         \
             hostname,                              \
             gettid());                             \
  } while (0)
|
||||
|
||||
/**
 * @brief Print an error line to stderr.
 *
 * The line is prefixed with the GET_CURRENT_TIME() stamp held in `str`,
 * the "ERR" tag, and the calling file name / line number; `fmt` must be a
 * string literal because it is concatenated with the prefix.
 */
#define LOGE(fmt, ...)                                            \
  do {                                                            \
    GET_CURRENT_TIME();                                           \
    fprintf(stderr, "[%s][ERR][KV_CACHE][%s:%d] " fmt "\n", str,  \
            FILE_NAME(__FILE__), __LINE__, ##__VA_ARGS__);        \
  } while (0)
|
||||
|
||||
/**
 * @brief Print a warning line to stderr.
 *
 * The line is prefixed with the GET_CURRENT_TIME() stamp held in `str`,
 * the "WARN" tag, and the calling file name / line number; `fmt` must be a
 * string literal because it is concatenated with the prefix.
 */
#define LOGW(fmt, ...)                                             \
  do {                                                             \
    GET_CURRENT_TIME();                                            \
    fprintf(stderr, "[%s][WARN][KV_CACHE][%s:%d] " fmt "\n", str,  \
            FILE_NAME(__FILE__), __LINE__, ##__VA_ARGS__);         \
  } while (0)
|
||||
|
||||
/**
 * @brief Print an informational line to stdout.
 *
 * The line is prefixed with the GET_CURRENT_TIME() stamp held in `str`,
 * the "INFO" tag, and the calling file name / line number; `fmt` must be a
 * string literal because it is concatenated with the prefix.
 */
#define LOGI(fmt, ...)                                             \
  do {                                                             \
    GET_CURRENT_TIME();                                            \
    fprintf(stdout, "[%s][INFO][KV_CACHE][%s:%d] " fmt "\n", str,  \
            FILE_NAME(__FILE__), __LINE__, ##__VA_ARGS__);         \
  } while (0)
|
||||
|
||||
/**
 * @brief Print a debug line to stdout, but only when KV_IS_DEBUG_ENABLED
 *        evaluates to true.
 *
 * The line is prefixed with the GET_CURRENT_TIME() stamp held in `str`,
 * the "DBG" tag, and the calling file name / line number; `fmt` must be a
 * string literal because it is concatenated with the prefix.
 */
#define LOGD(fmt, ...)                                              \
  do {                                                              \
    if (KV_IS_DEBUG_ENABLED) {                                      \
      GET_CURRENT_TIME();                                           \
      fprintf(stdout, "[%s][DBG][KV_CACHE][%s:%d] " fmt "\n", str,  \
              FILE_NAME(__FILE__), __LINE__, ##__VA_ARGS__);        \
    }                                                               \
  } while (0)
|
||||
|
||||
/**
 * @brief Forward to LOGD only when `cond` is true.
 *
 * The do/while wrapper keeps the macro usable as a single statement, e.g.
 * inside an unbraced if/else.
 */
#define LOGD_IF(cond, fmt, ...) \
  do {                          \
    if (cond) {                 \
      LOGD(fmt, __VA_ARGS__);   \
    }                           \
  } while (0)
|
||||
|
||||
/**
 * @brief Debug line on stdout, gated by ENV_ENABLE_RAW("KV_IS_DEBUG_ENABLED").
 *
 * NOTE(review): ENV_ENABLE_RAW is not defined in the visible part of this
 * header - confirm where it comes from before using this macro.
 * Output format matches LOGD: GET_CURRENT_TIME() stamp, "DBG" tag, file
 * name and line number, followed by the caller's message.
 */
#define LOGD_RAW(fmt, ...)                                          \
  do {                                                              \
    if (ENV_ENABLE_RAW("KV_IS_DEBUG_ENABLED")) {                    \
      GET_CURRENT_TIME();                                           \
      fprintf(stdout, "[%s][DBG][KV_CACHE][%s:%d] " fmt "\n", str,  \
              FILE_NAME(__FILE__), __LINE__, ##__VA_ARGS__);        \
    }                                                               \
  } while (0)
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
#ifndef KVCACHE_UTILS_H
|
||||
#define KVCACHE_UTILS_H
|
||||
|
||||
#include <ctime>
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <stdexcept>
|
||||
#include <cstdio>
|
||||
#include <arpa/inet.h>
|
||||
#include <ifaddrs.h>
|
||||
#include <net/if.h>
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <chrono>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "log.h"
|
||||
|
||||
#define PATH_MAX 4096 /* # chars in a path name including nul */
|
||||
@@ -28,22 +28,22 @@
|
||||
|
||||
/// @brief Connection status enumeration
enum class ConnStatus {
  kConnected,         ///< Connection is active
  kDisconnected,      ///< Connection is not active
  kError,             ///< Connection error occurred
  kTimeout,           ///< Connection timed out
  kInvalidParameters  ///< Invalid connection parameters
};
|
||||
|
||||
/// @brief Queue Pair (QP) setup result status
enum class QpStatus {
  kSuccess,            ///< Successfully transitioned QP to RTS
  kInvalidParameters,  ///< ctx or dest is null
  kDeviceQueryFailed,  ///< ibv_query_device failed
  kPortQueryFailed,    ///< ibv_query_port failed
  kMtuMismatch,        ///< Requested MTU exceeds active MTU
  kModifyToRTRFailed,  ///< Failed to modify QP to RTR
  kModifyToRTSFailed   ///< Failed to modify QP to RTS
};
|
||||
|
||||
/**
|
||||
@@ -51,265 +51,281 @@ enum class QpStatus {
|
||||
* @param busId PCI bus ID string (e.g. "0000:3b:00.0")
|
||||
* @param[out] id Converted numeric ID
|
||||
*/
|
||||
// Collects the hexadecimal digits of `busId` (separators such as ':' and '.'
// and any other non-hex characters are skipped), keeps at most the first 16
// of them, and stores their value in `*id`.
//
// A null `id` is ignored; a null or digit-free `busId` yields 0.
inline void busid_to_int64(const char* busId, int64_t* id) {
  if (id == nullptr) return;
  *id = 0;
  if (busId == nullptr) return;

  // 16 hex digits cover all 64 bits; the 17th slot keeps the string
  // NUL-terminated.
  char hexStr[17] = {0};
  size_t hexOffset = 0;

  // Filter valid hex characters
  for (size_t i = 0; hexOffset < sizeof(hexStr) - 1 && busId[i] != '\0'; ++i) {
    const unsigned char c = static_cast<unsigned char>(busId[i]);
    if (std::isxdigit(c)) {
      hexStr[hexOffset++] = static_cast<char>(c);
    }
  }

  // Parse as unsigned 64-bit so a full 16-digit value keeps its bit pattern
  // instead of saturating at LONG_MAX (strtol), and so the result does not
  // depend on the width of `long`.
  *id = static_cast<int64_t>(std::strtoull(hexStr, nullptr, 16));
}
|
||||
|
||||
/**
 * @brief Static helpers that enumerate the host's IPv4 interface addresses
 *        via getifaddrs().
 */
class NetworkInterfaceManager {
 public:
  /// One IPv4 address entry returned by getifaddrs().
  struct InterfaceInfo {
    std::string name;  ///< Interface name (ifa_name)
    std::string ip;    ///< Dotted-quad address; empty if conversion failed
    bool is_up;        ///< IFF_UP flag
    bool is_running;   ///< IFF_RUNNING flag
    bool is_loopback;  ///< IFF_LOOPBACK flag

    /// An interface is usable when it is up, running and not a loopback.
    bool isUsable() const { return is_up && is_running && !is_loopback; }
  };

  /**
   * @brief List every AF_INET entry reported by getifaddrs().
   * @return The collected entries; empty when getifaddrs() fails.
   */
  static std::vector<InterfaceInfo> getAllInterfaces() {
    std::vector<InterfaceInfo> interfaces;
    struct ifaddrs* ifaddrs_ptr = nullptr;

    if (getifaddrs(&ifaddrs_ptr) == -1) {
      return interfaces;
    }

    for (struct ifaddrs* ifa = ifaddrs_ptr; ifa != nullptr;
         ifa = ifa->ifa_next) {
      // Entries without an address, or with a non-IPv4 address, are skipped.
      if (ifa->ifa_addr == nullptr) continue;
      if (ifa->ifa_addr->sa_family != AF_INET) continue;

      InterfaceInfo info;
      info.name = ifa->ifa_name;
      info.is_up = (ifa->ifa_flags & IFF_UP) != 0;
      info.is_running = (ifa->ifa_flags & IFF_RUNNING) != 0;
      info.is_loopback = (ifa->ifa_flags & IFF_LOOPBACK) != 0;

      const struct sockaddr_in* sa =
          reinterpret_cast<const struct sockaddr_in*>(ifa->ifa_addr);
      char ip_str[INET_ADDRSTRLEN] = {0};
      // Only copy the text when inet_ntop() succeeded; otherwise the buffer
      // contents are not a valid address and `ip` is left empty.
      if (inet_ntop(AF_INET, &sa->sin_addr, ip_str, sizeof(ip_str)) !=
          nullptr) {
        info.ip = ip_str;
      }

      interfaces.push_back(info);
    }

    freeifaddrs(ifaddrs_ptr);
    return interfaces;
  }

  /**
   * @brief Name of the first interface for which isUsable() is true.
   * @return The interface name, or an empty string when none qualifies.
   */
  static std::string getFirstUsableInterface() {
    const std::vector<InterfaceInfo> interfaces = getAllInterfaces();

    for (const auto& iface : interfaces) {
      if (iface.isUsable()) {
        return iface.name;
      }
    }
    return "";
  }

  /// Print every discovered interface with its address and flags to stdout.
  static void displayAllInterfaces() {
    const std::vector<InterfaceInfo> interfaces = getAllInterfaces();

    printf("Available network interfaces:\n");
    for (const auto& iface : interfaces) {
      printf(" %s: %s [%s%s%s]\n",
             iface.name.c_str(),
             iface.ip.c_str(),
             iface.is_up ? "UP" : "DOWN",
             iface.is_running ? ",RUNNING" : "",
             iface.is_loopback ? ",LOOPBACK" : "");
    }
  }
};
|
||||
|
||||
class KVCacheConfig {
|
||||
private:
|
||||
// Configuration values
|
||||
int rdma_gid_index_;
|
||||
bool has_rdma_dest_port_override_; // 替代 std::optional
|
||||
int rdma_dest_port_override_;
|
||||
const char* socket_interface_;
|
||||
char* socket_interface_buffer_;
|
||||
bool gdrcopy_flush_enabled_;
|
||||
bool verify_read_enabled_;
|
||||
bool debug_mode_enabled_;
|
||||
bool debug_output_enabled_;
|
||||
const char* debug_file_path_;
|
||||
const char* error_file_path_;
|
||||
bool relax_ordering_enabled_;
|
||||
int ib_timeout_;
|
||||
const char* rdma_nics_;
|
||||
private:
|
||||
// Configuration values
|
||||
int rdma_gid_index_;
|
||||
bool has_rdma_dest_port_override_; // 替代 std::optional
|
||||
int rdma_dest_port_override_;
|
||||
const char* socket_interface_;
|
||||
char* socket_interface_buffer_;
|
||||
bool gdrcopy_flush_enabled_;
|
||||
bool verify_read_enabled_;
|
||||
bool debug_mode_enabled_;
|
||||
bool debug_output_enabled_;
|
||||
const char* debug_file_path_;
|
||||
const char* error_file_path_;
|
||||
bool relax_ordering_enabled_;
|
||||
int ib_timeout_;
|
||||
const char* rdma_nics_;
|
||||
|
||||
// Private constructor for singleton pattern
|
||||
KVCacheConfig() {
|
||||
// Initialize configuration from environment variables
|
||||
rdma_gid_index_ = parse_int_value(
|
||||
std::getenv("KVCACHE_RDMA_GID_INDEX"), 3, "KVCACHE_RDMA_GID_INDEX");
|
||||
// Private constructor for singleton pattern
|
||||
KVCacheConfig() {
|
||||
// Initialize configuration from environment variables
|
||||
rdma_gid_index_ = parse_int_value(
|
||||
std::getenv("KVCACHE_RDMA_GID_INDEX"), 3, "KVCACHE_RDMA_GID_INDEX");
|
||||
|
||||
// Parse optional RDMA port override
|
||||
const char* port_value = std::getenv("SET_RDMA_DEST_PORT");
|
||||
has_rdma_dest_port_override_ = false; // 默认为false
|
||||
if (port_value) {
|
||||
try {
|
||||
rdma_dest_port_override_ = std::stoi(std::string(port_value));
|
||||
has_rdma_dest_port_override_ = true;
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, "Invalid SET_RDMA_DEST_PORT value: '%s', ignoring\n", port_value);
|
||||
}
|
||||
}
|
||||
|
||||
const char* env_interface = std::getenv("KVCACHE_SOCKET_IFNAME");
|
||||
|
||||
if (env_interface && env_interface[0] != '\0') {
|
||||
socket_interface_ = env_interface;
|
||||
printf("Using specified interface: %s\n", socket_interface_);
|
||||
} else {
|
||||
std::string iface = NetworkInterfaceManager::getFirstUsableInterface();
|
||||
if (!iface.empty()) {
|
||||
socket_interface_buffer_ = new char[iface.size() + 1];
|
||||
std::strcpy(socket_interface_buffer_, iface.c_str());
|
||||
socket_interface_ = socket_interface_buffer_;
|
||||
printf("Auto-detected interface: %s\n", socket_interface_);
|
||||
} else {
|
||||
fprintf(stderr, "Warning: No usable network interface found\n");
|
||||
socket_interface_ = "";
|
||||
}
|
||||
NetworkInterfaceManager::displayAllInterfaces();
|
||||
}
|
||||
|
||||
socket_interface_ = std::getenv("KVCACHE_SOCKET_IFNAME");
|
||||
debug_file_path_ = std::getenv("KVCACHE_DEBUG_FILE");
|
||||
error_file_path_ = std::getenv("KVCACHE_ERROR_FILE");
|
||||
|
||||
gdrcopy_flush_enabled_ = parse_bool_value(std::getenv("KVCACHE_GDRCOPY_FLUSH_ENABLE"));
|
||||
verify_read_enabled_ = parse_bool_value(std::getenv("KVCACHE_VERIFY_READ"));
|
||||
debug_mode_enabled_ = parse_bool_value(std::getenv("KVCACHE_DEBUG")) ||
|
||||
parse_bool_value(std::getenv("KV_IS_DEBUG_ENABLED"));
|
||||
debug_output_enabled_ = parse_bool_value(std::getenv("KVCACHE_DEBUG_OUTPUT"));
|
||||
|
||||
relax_ordering_enabled_ = parse_bool_value(std::getenv("KVCACHE_RELAX_ORDERING"));
|
||||
|
||||
ib_timeout_ = parse_int_value(
|
||||
std::getenv("KVCACHE_IB_TIMEOUT"),
|
||||
18,
|
||||
"KVCACHE_IB_TIMEOUT"
|
||||
);
|
||||
|
||||
rdma_nics_ = std::getenv("KVCACHE_RDMA_NICS");
|
||||
// Parse optional RDMA port override
|
||||
const char* port_value = std::getenv("SET_RDMA_DEST_PORT");
|
||||
has_rdma_dest_port_override_ = false; // 默认为false
|
||||
if (port_value) {
|
||||
try {
|
||||
rdma_dest_port_override_ = std::stoi(std::string(port_value));
|
||||
has_rdma_dest_port_override_ = true;
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr,
|
||||
"Invalid SET_RDMA_DEST_PORT value: '%s', ignoring\n",
|
||||
port_value);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
bool parse_bool_value(const char* value) {
|
||||
if (!value) return false;
|
||||
const char* env_interface = std::getenv("KVCACHE_SOCKET_IFNAME");
|
||||
|
||||
std::string str_value(value);
|
||||
std::transform(str_value.begin(), str_value.end(), str_value.begin(), ::tolower);
|
||||
|
||||
return (str_value == "1" || str_value == "true" ||
|
||||
str_value == "on" || str_value == "yes");
|
||||
if (env_interface && env_interface[0] != '\0') {
|
||||
socket_interface_ = env_interface;
|
||||
printf("Using specified interface: %s\n", socket_interface_);
|
||||
} else {
|
||||
std::string iface = NetworkInterfaceManager::getFirstUsableInterface();
|
||||
if (!iface.empty()) {
|
||||
socket_interface_buffer_ = new char[iface.size() + 1];
|
||||
std::strcpy(socket_interface_buffer_, iface.c_str());
|
||||
socket_interface_ = socket_interface_buffer_;
|
||||
printf("Auto-detected interface: %s\n", socket_interface_);
|
||||
} else {
|
||||
fprintf(stderr, "Warning: No usable network interface found\n");
|
||||
socket_interface_ = "";
|
||||
}
|
||||
NetworkInterfaceManager::displayAllInterfaces();
|
||||
}
|
||||
|
||||
int parse_int_value(const char* value, int default_value, const char* env_name) {
|
||||
if (!value) return default_value;
|
||||
socket_interface_ = std::getenv("KVCACHE_SOCKET_IFNAME");
|
||||
debug_file_path_ = std::getenv("KVCACHE_DEBUG_FILE");
|
||||
error_file_path_ = std::getenv("KVCACHE_ERROR_FILE");
|
||||
|
||||
try {
|
||||
return std::stoi(std::string(value));
|
||||
} catch (const std::invalid_argument& e) {
|
||||
fprintf(stderr, "Invalid value for %s: '%s', using default: %d\n",
|
||||
env_name, value, default_value);
|
||||
return default_value;
|
||||
} catch (const std::out_of_range& e) {
|
||||
fprintf(stderr, "%s value out of range: '%s', using default: %d\n",
|
||||
env_name, value, default_value);
|
||||
return default_value;
|
||||
}
|
||||
gdrcopy_flush_enabled_ =
|
||||
parse_bool_value(std::getenv("KVCACHE_GDRCOPY_FLUSH_ENABLE"));
|
||||
verify_read_enabled_ = parse_bool_value(std::getenv("KVCACHE_VERIFY_READ"));
|
||||
debug_mode_enabled_ = parse_bool_value(std::getenv("KVCACHE_DEBUG")) ||
|
||||
parse_bool_value(std::getenv("KV_IS_DEBUG_ENABLED"));
|
||||
debug_output_enabled_ =
|
||||
parse_bool_value(std::getenv("KVCACHE_DEBUG_OUTPUT"));
|
||||
|
||||
relax_ordering_enabled_ =
|
||||
parse_bool_value(std::getenv("KVCACHE_RELAX_ORDERING"));
|
||||
|
||||
ib_timeout_ = parse_int_value(
|
||||
std::getenv("KVCACHE_IB_TIMEOUT"), 18, "KVCACHE_IB_TIMEOUT");
|
||||
|
||||
rdma_nics_ = std::getenv("KVCACHE_RDMA_NICS");
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
// Interprets a (possibly null) environment-variable string as a flag.
// Returns true only for "1", "true", "on" or "yes", compared
// case-insensitively; a null pointer or any other text yields false.
bool parse_bool_value(const char* value) {
  if (!value) return false;

  std::string str_value(value);
  // std::tolower is undefined for negative char values, so every character
  // is routed through unsigned char before lowering.
  std::transform(str_value.begin(),
                 str_value.end(),
                 str_value.begin(),
                 [](unsigned char ch) {
                   return static_cast<char>(std::tolower(ch));
                 });

  return (str_value == "1" || str_value == "true" || str_value == "on" ||
          str_value == "yes");
}
|
||||
|
||||
// Converts an environment-variable string to int with std::stoi.
//
// value         - raw text; a null pointer means "not set".
// default_value - returned when `value` is null or cannot be converted.
// env_name      - variable name, used only in the diagnostics on stderr.
//
// Text that is not a number, or a number outside the int range, is reported
// on stderr and replaced by `default_value`.
int parse_int_value(const char* value,
                    int default_value,
                    const char* env_name) {
  if (value == nullptr) {
    return default_value;
  }

  int parsed = default_value;
  try {
    parsed = std::stoi(std::string(value));
  } catch (const std::invalid_argument&) {
    fprintf(stderr,
            "Invalid value for %s: '%s', using default: %d\n",
            env_name,
            value,
            default_value);
  } catch (const std::out_of_range&) {
    fprintf(stderr,
            "%s value out of range: '%s', using default: %d\n",
            env_name,
            value,
            default_value);
  }
  return parsed;
}
|
||||
|
||||
public:
|
||||
// Prevent copying and assignment
|
||||
KVCacheConfig(const KVCacheConfig&) = delete;
|
||||
KVCacheConfig& operator=(const KVCacheConfig&) = delete;
|
||||
|
||||
// Get singleton instance
|
||||
static KVCacheConfig& getInstance() {
|
||||
static KVCacheConfig instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
int get_ib_timeout() const { return ib_timeout_; }
|
||||
|
||||
// Configuration retrieval methods
|
||||
int get_rdma_gid_index() const { return rdma_gid_index_; }
|
||||
|
||||
int resolve_rdma_dest_port(int default_port) const {
|
||||
return has_rdma_dest_port_override_ ? rdma_dest_port_override_
|
||||
: default_port;
|
||||
}
|
||||
|
||||
int resolve_rdma_dest_port(const std::string& default_port) const {
|
||||
try {
|
||||
return resolve_rdma_dest_port(std::stoi(default_port));
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(
|
||||
stderr, "Invalid default port string: %s\n", default_port.c_str());
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
const char* get_socket_interface() const { return socket_interface_; }
|
||||
const char* get_debug_file_path() const { return debug_file_path_; }
|
||||
const char* get_error_file_path() const { return error_file_path_; }
|
||||
const char* get_rdma_nics() const { return rdma_nics_; }
|
||||
|
||||
// Feature check methods
|
||||
bool is_gdrcopy_flush_enabled() const { return gdrcopy_flush_enabled_; }
|
||||
bool is_verify_read_enabled() const { return verify_read_enabled_; }
|
||||
bool is_debug_mode_enabled() const { return debug_mode_enabled_; }
|
||||
bool is_debug_output_enabled() const { return debug_output_enabled_; }
|
||||
bool is_relax_ordering_enabled() const { return relax_ordering_enabled_; }
|
||||
|
||||
// Display configuration
|
||||
void displayConfiguration() const {
|
||||
INFO("KVCache Configuration:\n");
|
||||
INFO("Init KVCacheConfig RDMA GID Index: %d\n", rdma_gid_index_);
|
||||
|
||||
if (has_rdma_dest_port_override_) {
|
||||
INFO("Init KVCacheConfig RDMA Destination Port Override: %d\n",
|
||||
rdma_dest_port_override_);
|
||||
}
|
||||
|
||||
public:
|
||||
// Prevent copying and assignment
|
||||
KVCacheConfig(const KVCacheConfig&) = delete;
|
||||
KVCacheConfig& operator=(const KVCacheConfig&) = delete;
|
||||
|
||||
// Get singleton instance
|
||||
static KVCacheConfig& getInstance() {
|
||||
static KVCacheConfig instance;
|
||||
return instance;
|
||||
if (socket_interface_) {
|
||||
INFO("Init KVCacheConfig Socket Interface: %s\n", socket_interface_);
|
||||
}
|
||||
|
||||
int get_ib_timeout() const { return ib_timeout_; }
|
||||
INFO("Init KVCacheConfig GDRCopy Flush: %s\n",
|
||||
gdrcopy_flush_enabled_ ? "enabled" : "disabled");
|
||||
INFO("Init KVCacheConfig Verify Read: %s\n",
|
||||
verify_read_enabled_ ? "enabled" : "disabled");
|
||||
INFO("Init KVCacheConfig Debug Mode: %s\n",
|
||||
debug_mode_enabled_ ? "enabled" : "disabled");
|
||||
INFO("Init KVCacheConfig Debug Output: %s\n",
|
||||
debug_output_enabled_ ? "enabled" : "disabled");
|
||||
|
||||
// Configuration retrieval methods
|
||||
int get_rdma_gid_index() const { return rdma_gid_index_; }
|
||||
|
||||
int resolve_rdma_dest_port(int default_port) const {
|
||||
return has_rdma_dest_port_override_ ? rdma_dest_port_override_ : default_port;
|
||||
if (debug_file_path_) {
|
||||
INFO("Init KVCacheConfig Debug File: %s\n", debug_file_path_);
|
||||
}
|
||||
|
||||
int resolve_rdma_dest_port(const std::string& default_port) const {
|
||||
try {
|
||||
return resolve_rdma_dest_port(std::stoi(default_port));
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, "Invalid default port string: %s\n", default_port.c_str());
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
const char* get_socket_interface() const { return socket_interface_; }
|
||||
const char* get_debug_file_path() const { return debug_file_path_; }
|
||||
const char* get_error_file_path() const { return error_file_path_; }
|
||||
const char* get_rdma_nics() const { return rdma_nics_; }
|
||||
|
||||
// Feature check methods
|
||||
bool is_gdrcopy_flush_enabled() const { return gdrcopy_flush_enabled_; }
|
||||
bool is_verify_read_enabled() const { return verify_read_enabled_; }
|
||||
bool is_debug_mode_enabled() const { return debug_mode_enabled_; }
|
||||
bool is_debug_output_enabled() const { return debug_output_enabled_; }
|
||||
bool is_relax_ordering_enabled() const { return relax_ordering_enabled_; }
|
||||
|
||||
// Display configuration
|
||||
void displayConfiguration() const {
|
||||
INFO("KVCache Configuration:\n");
|
||||
INFO("Init KVCacheConfig RDMA GID Index: %d\n", rdma_gid_index_);
|
||||
|
||||
if (has_rdma_dest_port_override_) {
|
||||
INFO("Init KVCacheConfig RDMA Destination Port Override: %d\n", rdma_dest_port_override_);
|
||||
}
|
||||
|
||||
if (socket_interface_) {
|
||||
INFO("Init KVCacheConfig Socket Interface: %s\n", socket_interface_);
|
||||
}
|
||||
|
||||
INFO("Init KVCacheConfig GDRCopy Flush: %s\n", gdrcopy_flush_enabled_ ? "enabled" : "disabled");
|
||||
INFO("Init KVCacheConfig Verify Read: %s\n", verify_read_enabled_ ? "enabled" : "disabled");
|
||||
INFO("Init KVCacheConfig Debug Mode: %s\n", debug_mode_enabled_ ? "enabled" : "disabled");
|
||||
INFO("Init KVCacheConfig Debug Output: %s\n", debug_output_enabled_ ? "enabled" : "disabled");
|
||||
|
||||
if (debug_file_path_) {
|
||||
INFO("Init KVCacheConfig Debug File: %s\n", debug_file_path_);
|
||||
}
|
||||
|
||||
if (error_file_path_) {
|
||||
INFO("Init KVCacheConfig Error File: %s\n", error_file_path_);
|
||||
}
|
||||
if (error_file_path_) {
|
||||
INFO("Init KVCacheConfig Error File: %s\n", error_file_path_);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
+841
-794
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -17,14 +17,14 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdarg.h>
|
||||
#include <sys/syscall.h>
|
||||
#include <sys/stat.h>
|
||||
#include <libgen.h>
|
||||
#include <errno.h>
|
||||
#include <string.h>
|
||||
#include "log.h"
|
||||
#include <errno.h>
|
||||
#include <libgen.h>
|
||||
#include <stdarg.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/syscall.h>
|
||||
#include "util.h"
|
||||
|
||||
static int pid = -1;
|
||||
@@ -33,180 +33,237 @@ static char hostname[64];
|
||||
char global_log_last_error[1024] = "";
|
||||
FILE *global_debug_file = stdout;
|
||||
FILE *global_error_file = stdout;
|
||||
static char global_debug_file_name[PATH_MAX+1] = "";
|
||||
static char global_err_file_name[PATH_MAX+1] = "";
|
||||
static char global_debug_file_name[PATH_MAX + 1] = "";
|
||||
static char global_err_file_name[PATH_MAX + 1] = "";
|
||||
int global_debug_level = -1;
|
||||
pthread_mutex_t global_debug_lock = PTHREAD_MUTEX_INITIALIZER;
|
||||
pthread_mutex_t global_log_file_lock = PTHREAD_MUTEX_INITIALIZER;
|
||||
|
||||
void log_file_init(FILE **kv_cache_log_file, const char *kv_cache_log_file_env, char *logFileName) {
|
||||
int c = 0;
|
||||
char *dfn = logFileName;
|
||||
while (c < PATH_MAX && kv_cache_log_file_env[c] != '\0') {
|
||||
if (kv_cache_log_file_env[c++] != '%') {
|
||||
*dfn++ = kv_cache_log_file_env[c - 1];
|
||||
continue;
|
||||
}
|
||||
switch (kv_cache_log_file_env[c++]) {
|
||||
case '%': // Double %
|
||||
*dfn++ = '%';
|
||||
break;
|
||||
case 'h': // %h = hostname
|
||||
dfn += snprintf(dfn, PATH_MAX, "%s", hostname);
|
||||
break;
|
||||
case 'p': // %p = pid
|
||||
dfn += snprintf(dfn, PATH_MAX, "%d", pid);
|
||||
break;
|
||||
default: // Echo everything we don't understand
|
||||
*dfn++ = '%';
|
||||
*dfn++ = kv_cache_log_file_env[c - 1];
|
||||
break;
|
||||
}
|
||||
void log_file_init(FILE **kv_cache_log_file,
|
||||
const char *kv_cache_log_file_env,
|
||||
char *logFileName) {
|
||||
int c = 0;
|
||||
char *dfn = logFileName;
|
||||
while (c < PATH_MAX && kv_cache_log_file_env[c] != '\0') {
|
||||
if (kv_cache_log_file_env[c++] != '%') {
|
||||
*dfn++ = kv_cache_log_file_env[c - 1];
|
||||
continue;
|
||||
}
|
||||
*dfn = '\0';
|
||||
if (logFileName[0] != '\0') {
|
||||
FILE *file = fopen(logFileName, "w");
|
||||
if (file != nullptr) {
|
||||
setbuf(file, nullptr); // disable buffering
|
||||
*kv_cache_log_file = file;
|
||||
}
|
||||
switch (kv_cache_log_file_env[c++]) {
|
||||
case '%': // Double %
|
||||
*dfn++ = '%';
|
||||
break;
|
||||
case 'h': // %h = hostname
|
||||
dfn += snprintf(dfn, PATH_MAX, "%s", hostname);
|
||||
break;
|
||||
case 'p': // %p = pid
|
||||
dfn += snprintf(dfn, PATH_MAX, "%d", pid);
|
||||
break;
|
||||
default: // Echo everything we don't understand
|
||||
*dfn++ = '%';
|
||||
*dfn++ = kv_cache_log_file_env[c - 1];
|
||||
break;
|
||||
}
|
||||
}
|
||||
*dfn = '\0';
|
||||
if (logFileName[0] != '\0') {
|
||||
FILE *file = fopen(logFileName, "w");
|
||||
if (file != nullptr) {
|
||||
setbuf(file, nullptr); // disable buffering
|
||||
*kv_cache_log_file = file;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void recreate_log_file(FILE **kv_cache_log_file, char *logFileName) {
|
||||
if (logFileName[0] != '\0') {
|
||||
pthread_mutex_lock(&global_log_file_lock);
|
||||
FILE *file = fopen(logFileName, "a"); // Use "a" mode to append if file exists, otherwise create it
|
||||
// close the previous log file if it exists
|
||||
if (*kv_cache_log_file != NULL && *kv_cache_log_file != file) {
|
||||
fclose(*kv_cache_log_file);
|
||||
*kv_cache_log_file = NULL;
|
||||
}
|
||||
if (file != NULL) {
|
||||
setbuf(file, NULL); // disable buffering
|
||||
*kv_cache_log_file = file;
|
||||
}
|
||||
pthread_mutex_unlock(&global_log_file_lock);
|
||||
if (logFileName[0] != '\0') {
|
||||
pthread_mutex_lock(&global_log_file_lock);
|
||||
FILE *file = fopen(
|
||||
logFileName,
|
||||
"a"); // Use "a" mode to append if file exists, otherwise create it
|
||||
// close the previous log file if it exists
|
||||
if (*kv_cache_log_file != NULL && *kv_cache_log_file != file) {
|
||||
fclose(*kv_cache_log_file);
|
||||
*kv_cache_log_file = NULL;
|
||||
}
|
||||
if (file != NULL) {
|
||||
setbuf(file, NULL); // disable buffering
|
||||
*kv_cache_log_file = file;
|
||||
}
|
||||
pthread_mutex_unlock(&global_log_file_lock);
|
||||
}
|
||||
}
|
||||
|
||||
void debug_init() {
|
||||
pthread_mutex_lock(&global_debug_lock);
|
||||
if (global_debug_level != -1) {
|
||||
pthread_mutex_unlock(&global_debug_lock);
|
||||
return;
|
||||
}
|
||||
|
||||
const char* kv_cache_debug = std::getenv("KV_IS_DEBUG_ENABLED");
|
||||
int tempg_kv_cache_debug_level = -1;
|
||||
|
||||
if (kv_cache_debug == NULL) {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
|
||||
} else if (strcasecmp(kv_cache_debug, "0") == 0) {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
|
||||
} else if (strcasecmp(kv_cache_debug, "1") == 0) {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_DEBUG;
|
||||
} else if (strcasecmp(kv_cache_debug, "2") == 0) {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_WARN;
|
||||
} else if (strcasecmp(kv_cache_debug, "3") == 0) {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_ERROR;
|
||||
} else {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
|
||||
}
|
||||
|
||||
gethostname(hostname, 64);
|
||||
pid = getpid();
|
||||
|
||||
const char* g_kv_cache_debug_fileEnv = KVCacheConfig::getInstance().get_debug_file_path();
|
||||
if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO && g_kv_cache_debug_fileEnv != NULL) {
|
||||
log_file_init(&global_debug_file, g_kv_cache_debug_fileEnv, global_debug_file_name);
|
||||
}
|
||||
|
||||
const char* g_kv_cache_error_fileEnv = KVCacheConfig::getInstance().get_error_file_path();
|
||||
if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO && g_kv_cache_error_fileEnv != NULL) {
|
||||
log_file_init(&global_error_file, g_kv_cache_error_fileEnv, global_err_file_name);
|
||||
char buffer[1024];
|
||||
size_t len = 0;
|
||||
char timeBuffer[80]; // Buffer to hold the formatted time
|
||||
std::time_t absoluteTime = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
|
||||
std::strftime(timeBuffer, sizeof(timeBuffer), "%Y-%m-%d %H:%M:%S", std::localtime(&absoluteTime));
|
||||
len = snprintf(buffer, sizeof(buffer), "%s KV_CACHE START ", timeBuffer);
|
||||
buffer[len++] = '\n';
|
||||
if (global_error_file != NULL) {
|
||||
fwrite(buffer, 1, len, global_error_file);
|
||||
}
|
||||
}
|
||||
__atomic_store_n(&global_debug_level, tempg_kv_cache_debug_level, __ATOMIC_RELEASE);
|
||||
pthread_mutex_lock(&global_debug_lock);
|
||||
if (global_debug_level != -1) {
|
||||
pthread_mutex_unlock(&global_debug_lock);
|
||||
return;
|
||||
}
|
||||
|
||||
const char *kv_cache_debug = std::getenv("KV_IS_DEBUG_ENABLED");
|
||||
int tempg_kv_cache_debug_level = -1;
|
||||
|
||||
if (kv_cache_debug == NULL) {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
|
||||
} else if (strcasecmp(kv_cache_debug, "0") == 0) {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
|
||||
} else if (strcasecmp(kv_cache_debug, "1") == 0) {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_DEBUG;
|
||||
} else if (strcasecmp(kv_cache_debug, "2") == 0) {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_WARN;
|
||||
} else if (strcasecmp(kv_cache_debug, "3") == 0) {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_ERROR;
|
||||
} else {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
|
||||
}
|
||||
|
||||
gethostname(hostname, 64);
|
||||
pid = getpid();
|
||||
|
||||
const char *g_kv_cache_debug_fileEnv =
|
||||
KVCacheConfig::getInstance().get_debug_file_path();
|
||||
if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO &&
|
||||
g_kv_cache_debug_fileEnv != NULL) {
|
||||
log_file_init(
|
||||
&global_debug_file, g_kv_cache_debug_fileEnv, global_debug_file_name);
|
||||
}
|
||||
|
||||
const char *g_kv_cache_error_fileEnv =
|
||||
KVCacheConfig::getInstance().get_error_file_path();
|
||||
if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO &&
|
||||
g_kv_cache_error_fileEnv != NULL) {
|
||||
log_file_init(
|
||||
&global_error_file, g_kv_cache_error_fileEnv, global_err_file_name);
|
||||
char buffer[1024];
|
||||
size_t len = 0;
|
||||
char timeBuffer[80]; // Buffer to hold the formatted time
|
||||
std::time_t absoluteTime =
|
||||
std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
|
||||
std::strftime(timeBuffer,
|
||||
sizeof(timeBuffer),
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
std::localtime(&absoluteTime));
|
||||
len = snprintf(buffer, sizeof(buffer), "%s KV_CACHE START ", timeBuffer);
|
||||
buffer[len++] = '\n';
|
||||
if (global_error_file != NULL) {
|
||||
fwrite(buffer, 1, len, global_error_file);
|
||||
}
|
||||
}
|
||||
__atomic_store_n(
|
||||
&global_debug_level, tempg_kv_cache_debug_level, __ATOMIC_RELEASE);
|
||||
pthread_mutex_unlock(&global_debug_lock);
|
||||
}
|
||||
|
||||
/* Common logging function used by the INFO, DEBUG and WARN macros
|
||||
* Also exported to the dynamically loadable Net transport modules so
|
||||
* they can share the debugging mechanisms and output files
|
||||
*/
|
||||
void debug_log(KVLogLevel level, bool enable_to_terminal, const char *filefunc, int line, const char *fmt, ...) {
|
||||
if (__atomic_load_n(&global_debug_level, __ATOMIC_ACQUIRE) == -1) {
|
||||
debug_init();
|
||||
}
|
||||
void debug_log(KVLogLevel level,
|
||||
bool enable_to_terminal,
|
||||
const char *filefunc,
|
||||
int line,
|
||||
const char *fmt,
|
||||
...) {
|
||||
if (__atomic_load_n(&global_debug_level, __ATOMIC_ACQUIRE) == -1) {
|
||||
debug_init();
|
||||
}
|
||||
|
||||
// Save the last error (WARN) as a human readable string
|
||||
if (level == KV_LOG_LEVEL_WARN) {
|
||||
pthread_mutex_lock(&global_debug_lock);
|
||||
va_list vargs;
|
||||
va_start(vargs, fmt);
|
||||
(void) vsnprintf(global_log_last_error, sizeof(global_log_last_error), fmt, vargs);
|
||||
va_end(vargs);
|
||||
pthread_mutex_unlock(&global_debug_lock);
|
||||
}
|
||||
// Save the last error (WARN) as a human readable string
|
||||
if (level == KV_LOG_LEVEL_WARN) {
|
||||
pthread_mutex_lock(&global_debug_lock);
|
||||
va_list vargs;
|
||||
va_start(vargs, fmt);
|
||||
(void)vsnprintf(
|
||||
global_log_last_error, sizeof(global_log_last_error), fmt, vargs);
|
||||
va_end(vargs);
|
||||
pthread_mutex_unlock(&global_debug_lock);
|
||||
}
|
||||
|
||||
if (tid == -1) {
|
||||
tid = syscall(SYS_gettid);
|
||||
}
|
||||
if (tid == -1) {
|
||||
tid = syscall(SYS_gettid);
|
||||
}
|
||||
|
||||
char buffer[1024];
|
||||
size_t len = 0;
|
||||
// Convert timestamp to absolute time and directly use it in the snprintf function
|
||||
std::time_t absoluteTime = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
|
||||
char timeBuffer[80]; // Buffer to hold the formatted time
|
||||
std::strftime(timeBuffer, sizeof(timeBuffer), "%Y-%m-%d %H:%M:%S", std::localtime(&absoluteTime));
|
||||
char buffer[1024];
|
||||
size_t len = 0;
|
||||
// Convert timestamp to absolute time and directly use it in the snprintf
|
||||
// function
|
||||
std::time_t absoluteTime =
|
||||
std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
|
||||
char timeBuffer[80]; // Buffer to hold the formatted time
|
||||
std::strftime(timeBuffer,
|
||||
sizeof(timeBuffer),
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
std::localtime(&absoluteTime));
|
||||
|
||||
if (level == KV_LOG_LEVEL_WARN) {
|
||||
len = snprintf(buffer, sizeof(buffer), "\n%s %s:%d:%d %s:%d KV_CACHE WARN ",
|
||||
timeBuffer, hostname, pid, tid, filefunc, line);
|
||||
} else if (level == KV_LOG_LEVEL_INFO) {
|
||||
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE INFO ", timeBuffer, hostname, pid, tid);
|
||||
} else if (level == KV_LOG_LEVEL_DEBUG) {
|
||||
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE DEBUG ", timeBuffer, hostname, pid, tid);
|
||||
} else if (level == KV_LOG_LEVEL_ERROR) {
|
||||
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE ERROR ", timeBuffer, hostname, pid, tid);
|
||||
} else {
|
||||
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE ", timeBuffer, hostname, pid, tid);
|
||||
}
|
||||
if (level == KV_LOG_LEVEL_WARN) {
|
||||
len = snprintf(buffer,
|
||||
sizeof(buffer),
|
||||
"\n%s %s:%d:%d %s:%d KV_CACHE WARN ",
|
||||
timeBuffer,
|
||||
hostname,
|
||||
pid,
|
||||
tid,
|
||||
filefunc,
|
||||
line);
|
||||
} else if (level == KV_LOG_LEVEL_INFO) {
|
||||
len = snprintf(buffer,
|
||||
sizeof(buffer),
|
||||
"%s %s:%d:%d KV_CACHE INFO ",
|
||||
timeBuffer,
|
||||
hostname,
|
||||
pid,
|
||||
tid);
|
||||
} else if (level == KV_LOG_LEVEL_DEBUG) {
|
||||
len = snprintf(buffer,
|
||||
sizeof(buffer),
|
||||
"%s %s:%d:%d KV_CACHE DEBUG ",
|
||||
timeBuffer,
|
||||
hostname,
|
||||
pid,
|
||||
tid);
|
||||
} else if (level == KV_LOG_LEVEL_ERROR) {
|
||||
len = snprintf(buffer,
|
||||
sizeof(buffer),
|
||||
"%s %s:%d:%d KV_CACHE ERROR ",
|
||||
timeBuffer,
|
||||
hostname,
|
||||
pid,
|
||||
tid);
|
||||
} else {
|
||||
len = snprintf(buffer,
|
||||
sizeof(buffer),
|
||||
"%s %s:%d:%d KV_CACHE ",
|
||||
timeBuffer,
|
||||
hostname,
|
||||
pid,
|
||||
tid);
|
||||
}
|
||||
|
||||
if (len) {
|
||||
va_list vargs;
|
||||
va_start(vargs, fmt);
|
||||
len += vsnprintf(buffer + len, sizeof(buffer) - len, fmt, vargs);
|
||||
va_end(vargs);
|
||||
// vsnprintf may return len > sizeof(buffer) in the case of a truncated output.
|
||||
// Rewind len so that we can replace the final \0 by \n
|
||||
if (len > sizeof(buffer)) {
|
||||
len = sizeof(buffer) - 1;
|
||||
}
|
||||
buffer[len++] = '\n';
|
||||
if (access(global_debug_file_name, F_OK) != 0) {
|
||||
recreate_log_file(&global_debug_file, global_debug_file_name);
|
||||
}
|
||||
if (enable_to_terminal) {
|
||||
fwrite(buffer, 1, len, global_debug_file);
|
||||
}
|
||||
if (level == KV_LOG_LEVEL_WARN && global_error_file != stdout) {
|
||||
if (access(global_err_file_name, F_OK) != 0) {
|
||||
recreate_log_file(&global_error_file, global_err_file_name);
|
||||
}
|
||||
if (global_error_file != NULL) {
|
||||
fwrite(buffer, 1, len, global_error_file);
|
||||
}
|
||||
}
|
||||
if (len) {
|
||||
va_list vargs;
|
||||
va_start(vargs, fmt);
|
||||
len += vsnprintf(buffer + len, sizeof(buffer) - len, fmt, vargs);
|
||||
va_end(vargs);
|
||||
// vsnprintf may return len > sizeof(buffer) in the case of a truncated
|
||||
// output. Rewind len so that we can replace the final \0 by \n
|
||||
if (len > sizeof(buffer)) {
|
||||
len = sizeof(buffer) - 1;
|
||||
}
|
||||
buffer[len++] = '\n';
|
||||
if (access(global_debug_file_name, F_OK) != 0) {
|
||||
recreate_log_file(&global_debug_file, global_debug_file_name);
|
||||
}
|
||||
if (enable_to_terminal) {
|
||||
fwrite(buffer, 1, len, global_debug_file);
|
||||
}
|
||||
if (level == KV_LOG_LEVEL_WARN && global_error_file != stdout) {
|
||||
if (access(global_err_file_name, F_OK) != 0) {
|
||||
recreate_log_file(&global_error_file, global_err_file_name);
|
||||
}
|
||||
if (global_error_file != NULL) {
|
||||
fwrite(buffer, 1, len, global_error_file);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,17 +6,22 @@
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(rdma_comm, m) {
|
||||
m.doc() = R"pbdoc(kv cache messager)pbdoc";
|
||||
py::class_<RDMACommunicator>(m, "RDMACommunicator")
|
||||
.def(py::init<std::string &, int, std::string &, std::vector<int64_t>,
|
||||
std::vector<int64_t>, int, int>())
|
||||
.def("connect", &RDMACommunicator::connect)
|
||||
.def("is_connected", &RDMACommunicator::is_connected)
|
||||
.def("write_cache", &RDMACommunicator::write_cache);
|
||||
m.doc() = R"pbdoc(kv cache messager)pbdoc";
|
||||
py::class_<RDMACommunicator>(m, "RDMACommunicator")
|
||||
.def(py::init<std::string &,
|
||||
int,
|
||||
std::string &,
|
||||
std::vector<int64_t>,
|
||||
std::vector<int64_t>,
|
||||
int,
|
||||
int>())
|
||||
.def("connect", &RDMACommunicator::connect)
|
||||
.def("is_connected", &RDMACommunicator::is_connected)
|
||||
.def("write_cache", &RDMACommunicator::write_cache);
|
||||
|
||||
#ifdef VERSION_INFO
|
||||
m.attr("__version__") = VERSION_INFO;
|
||||
m.attr("__version__") = VERSION_INFO;
|
||||
#else
|
||||
m.attr("__version__") = "dev";
|
||||
m.attr("__version__") = "dev";
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -28,6 +28,11 @@ if ! [[ $(python -V 2>&1 | awk '{print $2}' | awk -F '.' '{print $1$2}') -ge 36
|
||||
please change the default python to higher version."
|
||||
exit 1
|
||||
fi
|
||||
if ! [[ $version == *"$VERSION"* ]]; then
|
||||
# low version of pip may not have the source of clang-format whl
|
||||
pip install --upgrade pip
|
||||
pip install clang-format==13.0.0
|
||||
fi
|
||||
|
||||
# Exclude any files under the 'test/ce/server/' directory from code style checks.
|
||||
diff_files=$(git diff --name-only --diff-filter=ACMR ${BRANCH} | grep -v '^tests/ce/server/')
|
||||
|
||||
Reference in New Issue
Block a user