[XPU] use quant2d_per_token for int8 weight quantization and fix some XPU kernel checks (#6869)

This commit is contained in:
lizan1999
2026-03-17 19:44:48 +08:00
committed by GitHub
parent aa9deb6ad4
commit 148eee84c6
4 changed files with 37 additions and 27 deletions
@@ -51,9 +51,6 @@ std::vector<paddle::Tensor> WeightOnlyLinearKernel(
int64_t n = w_shape[0];
int64_t k = w_shape[1];
int64_t m = x.numel() / k;
if (weight_dtype == "int4_t") {
n = n * 2;
}
paddle::Tensor out = paddle::empty({m, n}, x.dtype(), x.place());
if (m == 0) {
return {out};
@@ -148,7 +145,7 @@ std::vector<paddle::Tensor> WeightOnlyLinear(
#define APPLY_FFN_KERNEL(TX, TW) \
return WeightOnlyLinearKernel<TX, TW>( \
x, weight, weight_scale, bias, weight_dtype);
PD_CHECK(weight_dtype != "int4_t", "WeightOnlyLinear not support wint4");
if (x_type == paddle::DataType::BFLOAT16 &&
w_type == paddle::DataType::INT8) {
APPLY_FFN_KERNEL(paddle::bfloat16, int8_t);
@@ -30,18 +30,36 @@ std::vector<paddle::Tensor> WeightQuantizeKernel(const paddle::Tensor &x,
int64_t n = x.shape()[1];
paddle::Tensor scale =
paddle::full({n}, 0, paddle::DataType::FLOAT32, x.place());
paddle::empty({n}, paddle::DataType::FLOAT32, x.place());
if (algo == "weight_only_int8") {
paddle::Tensor out =
paddle::full({k, n}, 0, paddle::DataType::INT8, x.place());
int ret = fastdeploy::plugin::quant2d_per_channel<XPUType, float, int8_t>(
paddle::empty({k, n}, paddle::DataType::INT8, x.place());
paddle::Tensor x_trans = paddle::empty({k, n}, x.dtype(), x.place());
paddle::Tensor out_trans =
paddle::empty({k, n}, paddle::DataType::INT8, x.place());
XPUType *x_trans_ptr = const_cast<XPUType *>(
reinterpret_cast<const XPUType *>(x_trans.data<T>()));
int ret = baidu::xpu::api::transpose<XPUType>(
xpu_ctx->x_context(),
reinterpret_cast<const XPUType *>(x.template data<T>()),
reinterpret_cast<const XPUType *>(x.data<T>()),
x_trans_ptr,
{k, n},
{1, 0});
PD_CHECK(ret == 0);
ret = infer_ops::quant2d_per_token<XPUType, float, int8_t>(
xpu_ctx->x_context(),
x_trans_ptr,
nullptr,
out.data<int8_t>(),
out_trans.data<int8_t>(),
scale.data<float>(),
k,
n);
n,
k);
PD_CHECK(ret == 0);
ret = baidu::xpu::api::transpose<int8_t>(xpu_ctx->x_context(),
out_trans.data<int8_t>(),
out.data<int8_t>(),
{n, k},
{1, 0});
PD_CHECK(ret == 0);
return {out, scale};
} else if (algo == "weight_only_int4") {
@@ -49,30 +67,30 @@ std::vector<paddle::Tensor> WeightQuantizeKernel(const paddle::Tensor &x,
// quant2d_per_token + transpose at now
PD_CHECK(k % 2 == 0);
paddle::Tensor out =
paddle::full({(k + 1) / 2, n}, 0, paddle::DataType::INT8, x.place());
xpu::ctx_guard RAII_GUARD(xpu_ctx->x_context());
XPUType *x_trans = RAII_GUARD.alloc<XPUType>(k * n);
int8_t *out_trans = RAII_GUARD.alloc<int8_t>(k * n / 2);
PD_CHECK(x_trans != nullptr);
PD_CHECK(out_trans != nullptr);
paddle::empty({(k + 1) / 2, n}, paddle::DataType::INT8, x.place());
paddle::Tensor x_trans = paddle::empty({k, n}, x.dtype(), x.place());
paddle::Tensor out_trans =
paddle::empty({(k + 1) / 2, n}, paddle::DataType::INT8, x.place());
XPUType *x_trans_ptr = const_cast<XPUType *>(
reinterpret_cast<const XPUType *>(x_trans.data<T>()));
int ret = baidu::xpu::api::transpose<XPUType>(
xpu_ctx->x_context(),
reinterpret_cast<const XPUType *>(x.data<T>()),
x_trans,
x_trans_ptr,
{k, n},
{1, 0});
PD_CHECK(ret == 0);
ret = infer_ops::quant2d_per_token<XPUType, float, int4_t>(
xpu_ctx->x_context(),
x_trans,
x_trans_ptr,
nullptr,
reinterpret_cast<int4_t *>(out_trans),
reinterpret_cast<int4_t *>(out_trans.data<int8_t>()),
scale.data<float>(),
n,
k);
PD_CHECK(ret == 0);
ret = baidu::xpu::api::transpose<int8_t>(xpu_ctx->x_context(),
out_trans,
out_trans.data<int8_t>(),
out.data<int8_t>(),
{n, k / 2},
{1, 0});
@@ -217,7 +217,7 @@ int speculate_recover_block(api::Context *ctx,
WRAPPER_CHECK_PTR(ctx, bool, bsz, stop_flags);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, int, bsz, ori_seq_lens_encoder);
WRAPPER_CHECK_PTR(ctx, int, bsz, ori_seq_lens_decoder);
WRAPPER_CHECK_PTR_OR_NULL(ctx, int, bsz, ori_seq_lens_decoder);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_encoder);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_decoder);
WRAPPER_CHECK_PTR(ctx, int, bsz *block_num_per_seq, block_tables);
@@ -36,7 +36,6 @@ for i in range(bs):
seq_lens_decoder[i] = i
seq_lens_this_time[i] = 1
input_ids_np = np.random.randint(1, 10, [max_bs, max_input_length], "int64")
stop_nums = np.array([max_bs], "int64")
next_tokens = np.random.randint(1, 10, [max_bs], "int64")
is_block_step = np.random.randint(0, 2, [max_bs]).astype("bool")
@@ -46,7 +45,6 @@ seq_lens_this_time = paddle.to_tensor(seq_lens_this_time)
seq_lens_encoder = paddle.to_tensor(seq_lens_encoder)
seq_lens_decoder = paddle.to_tensor(seq_lens_decoder)
input_ids = paddle.to_tensor(input_ids_np)
stop_nums = paddle.to_tensor(stop_nums)
next_tokens = paddle.to_tensor(next_tokens)
is_block_step = paddle.to_tensor(is_block_step)
@@ -56,7 +54,6 @@ print("seq_lens_this_time:\n", seq_lens_this_time)
print("seq_lens_encoder:\n", seq_lens_encoder)
print("seq_lens_decoder:\n", seq_lens_decoder)
print("input_ids:\n", input_ids)
print("stop_nums:\n", stop_nums)
print("next_tokens:\n", next_tokens)
print("is_block_step:\n", is_block_step)
@@ -67,7 +64,6 @@ update_inputs(
seq_lens_encoder,
seq_lens_decoder,
input_ids,
stop_nums,
next_tokens,
is_block_step,
)
@@ -79,7 +75,6 @@ print("seq_lens_this_time:\n", seq_lens_this_time)
print("seq_lens_encoder:\n", seq_lens_encoder)
print("seq_lens_decoder:\n", seq_lens_decoder)
print("input_ids:\n", input_ids)
print("stop_nums:\n", stop_nums)
print("next_tokens:\n", next_tokens)
ref_not_need_stop_out = np.array([True])