mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] use quant2d_per_token for weight quant int8 && fix some XPU Kernel check (#6869)
This commit is contained in:
@@ -51,9 +51,6 @@ std::vector<paddle::Tensor> WeightOnlyLinearKernel(
|
||||
int64_t n = w_shape[0];
|
||||
int64_t k = w_shape[1];
|
||||
int64_t m = x.numel() / k;
|
||||
if (weight_dtype == "int4_t") {
|
||||
n = n * 2;
|
||||
}
|
||||
paddle::Tensor out = paddle::empty({m, n}, x.dtype(), x.place());
|
||||
if (m == 0) {
|
||||
return {out};
|
||||
@@ -148,7 +145,7 @@ std::vector<paddle::Tensor> WeightOnlyLinear(
|
||||
#define APPLY_FFN_KERNEL(TX, TW) \
|
||||
return WeightOnlyLinearKernel<TX, TW>( \
|
||||
x, weight, weight_scale, bias, weight_dtype);
|
||||
|
||||
PD_CHECK(weight_dtype != "int4_t", "WeightOnlyLinear not support wint4");
|
||||
if (x_type == paddle::DataType::BFLOAT16 &&
|
||||
w_type == paddle::DataType::INT8) {
|
||||
APPLY_FFN_KERNEL(paddle::bfloat16, int8_t);
|
||||
|
||||
@@ -30,18 +30,36 @@ std::vector<paddle::Tensor> WeightQuantizeKernel(const paddle::Tensor &x,
|
||||
int64_t n = x.shape()[1];
|
||||
|
||||
paddle::Tensor scale =
|
||||
paddle::full({n}, 0, paddle::DataType::FLOAT32, x.place());
|
||||
paddle::empty({n}, paddle::DataType::FLOAT32, x.place());
|
||||
if (algo == "weight_only_int8") {
|
||||
paddle::Tensor out =
|
||||
paddle::full({k, n}, 0, paddle::DataType::INT8, x.place());
|
||||
int ret = fastdeploy::plugin::quant2d_per_channel<XPUType, float, int8_t>(
|
||||
paddle::empty({k, n}, paddle::DataType::INT8, x.place());
|
||||
paddle::Tensor x_trans = paddle::empty({k, n}, x.dtype(), x.place());
|
||||
paddle::Tensor out_trans =
|
||||
paddle::empty({k, n}, paddle::DataType::INT8, x.place());
|
||||
XPUType *x_trans_ptr = const_cast<XPUType *>(
|
||||
reinterpret_cast<const XPUType *>(x_trans.data<T>()));
|
||||
int ret = baidu::xpu::api::transpose<XPUType>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPUType *>(x.template data<T>()),
|
||||
reinterpret_cast<const XPUType *>(x.data<T>()),
|
||||
x_trans_ptr,
|
||||
{k, n},
|
||||
{1, 0});
|
||||
PD_CHECK(ret == 0);
|
||||
ret = infer_ops::quant2d_per_token<XPUType, float, int8_t>(
|
||||
xpu_ctx->x_context(),
|
||||
x_trans_ptr,
|
||||
nullptr,
|
||||
out.data<int8_t>(),
|
||||
out_trans.data<int8_t>(),
|
||||
scale.data<float>(),
|
||||
k,
|
||||
n);
|
||||
n,
|
||||
k);
|
||||
PD_CHECK(ret == 0);
|
||||
ret = baidu::xpu::api::transpose<int8_t>(xpu_ctx->x_context(),
|
||||
out_trans.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
{n, k},
|
||||
{1, 0});
|
||||
PD_CHECK(ret == 0);
|
||||
return {out, scale};
|
||||
} else if (algo == "weight_only_int4") {
|
||||
@@ -49,30 +67,30 @@ std::vector<paddle::Tensor> WeightQuantizeKernel(const paddle::Tensor &x,
|
||||
// quant2d_per_token + transpose at now
|
||||
PD_CHECK(k % 2 == 0);
|
||||
paddle::Tensor out =
|
||||
paddle::full({(k + 1) / 2, n}, 0, paddle::DataType::INT8, x.place());
|
||||
xpu::ctx_guard RAII_GUARD(xpu_ctx->x_context());
|
||||
XPUType *x_trans = RAII_GUARD.alloc<XPUType>(k * n);
|
||||
int8_t *out_trans = RAII_GUARD.alloc<int8_t>(k * n / 2);
|
||||
PD_CHECK(x_trans != nullptr);
|
||||
PD_CHECK(out_trans != nullptr);
|
||||
paddle::empty({(k + 1) / 2, n}, paddle::DataType::INT8, x.place());
|
||||
paddle::Tensor x_trans = paddle::empty({k, n}, x.dtype(), x.place());
|
||||
paddle::Tensor out_trans =
|
||||
paddle::empty({(k + 1) / 2, n}, paddle::DataType::INT8, x.place());
|
||||
XPUType *x_trans_ptr = const_cast<XPUType *>(
|
||||
reinterpret_cast<const XPUType *>(x_trans.data<T>()));
|
||||
int ret = baidu::xpu::api::transpose<XPUType>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPUType *>(x.data<T>()),
|
||||
x_trans,
|
||||
x_trans_ptr,
|
||||
{k, n},
|
||||
{1, 0});
|
||||
PD_CHECK(ret == 0);
|
||||
ret = infer_ops::quant2d_per_token<XPUType, float, int4_t>(
|
||||
xpu_ctx->x_context(),
|
||||
x_trans,
|
||||
x_trans_ptr,
|
||||
nullptr,
|
||||
reinterpret_cast<int4_t *>(out_trans),
|
||||
reinterpret_cast<int4_t *>(out_trans.data<int8_t>()),
|
||||
scale.data<float>(),
|
||||
n,
|
||||
k);
|
||||
PD_CHECK(ret == 0);
|
||||
ret = baidu::xpu::api::transpose<int8_t>(xpu_ctx->x_context(),
|
||||
out_trans,
|
||||
out_trans.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
{n, k / 2},
|
||||
{1, 0});
|
||||
|
||||
@@ -217,7 +217,7 @@ int speculate_recover_block(api::Context *ctx,
|
||||
WRAPPER_CHECK_PTR(ctx, bool, bsz, stop_flags);
|
||||
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time);
|
||||
WRAPPER_CHECK_PTR(ctx, int, bsz, ori_seq_lens_encoder);
|
||||
WRAPPER_CHECK_PTR(ctx, int, bsz, ori_seq_lens_decoder);
|
||||
WRAPPER_CHECK_PTR_OR_NULL(ctx, int, bsz, ori_seq_lens_decoder);
|
||||
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_encoder);
|
||||
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_decoder);
|
||||
WRAPPER_CHECK_PTR(ctx, int, bsz *block_num_per_seq, block_tables);
|
||||
|
||||
@@ -36,7 +36,6 @@ for i in range(bs):
|
||||
seq_lens_decoder[i] = i
|
||||
seq_lens_this_time[i] = 1
|
||||
input_ids_np = np.random.randint(1, 10, [max_bs, max_input_length], "int64")
|
||||
stop_nums = np.array([max_bs], "int64")
|
||||
next_tokens = np.random.randint(1, 10, [max_bs], "int64")
|
||||
is_block_step = np.random.randint(0, 2, [max_bs]).astype("bool")
|
||||
|
||||
@@ -46,7 +45,6 @@ seq_lens_this_time = paddle.to_tensor(seq_lens_this_time)
|
||||
seq_lens_encoder = paddle.to_tensor(seq_lens_encoder)
|
||||
seq_lens_decoder = paddle.to_tensor(seq_lens_decoder)
|
||||
input_ids = paddle.to_tensor(input_ids_np)
|
||||
stop_nums = paddle.to_tensor(stop_nums)
|
||||
next_tokens = paddle.to_tensor(next_tokens)
|
||||
is_block_step = paddle.to_tensor(is_block_step)
|
||||
|
||||
@@ -56,7 +54,6 @@ print("seq_lens_this_time:\n", seq_lens_this_time)
|
||||
print("seq_lens_encoder:\n", seq_lens_encoder)
|
||||
print("seq_lens_decoder:\n", seq_lens_decoder)
|
||||
print("input_ids:\n", input_ids)
|
||||
print("stop_nums:\n", stop_nums)
|
||||
print("next_tokens:\n", next_tokens)
|
||||
print("is_block_step:\n", is_block_step)
|
||||
|
||||
@@ -67,7 +64,6 @@ update_inputs(
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
input_ids,
|
||||
stop_nums,
|
||||
next_tokens,
|
||||
is_block_step,
|
||||
)
|
||||
@@ -79,7 +75,6 @@ print("seq_lens_this_time:\n", seq_lens_this_time)
|
||||
print("seq_lens_encoder:\n", seq_lens_encoder)
|
||||
print("seq_lens_decoder:\n", seq_lens_decoder)
|
||||
print("input_ids:\n", input_ids)
|
||||
print("stop_nums:\n", stop_nums)
|
||||
print("next_tokens:\n", next_tokens)
|
||||
|
||||
ref_not_need_stop_out = np.array([True])
|
||||
|
||||
Reference in New Issue
Block a user