mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Iluvatar] Support wi4a16 group_gemm (#7078)
This commit is contained in:
@@ -15,11 +15,12 @@
|
|||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "iluvatar_context.h"
|
#include "iluvatar_context.h"
|
||||||
|
|
||||||
std::vector<paddle::Tensor> GroupGemm(const paddle::Tensor& x,
|
std::vector<paddle::Tensor> W8A16GroupGemm(const paddle::Tensor& x,
|
||||||
const paddle::Tensor& weight,
|
const paddle::Tensor& weight,
|
||||||
const paddle::Tensor& weight_scale,
|
const paddle::Tensor& weight_scale,
|
||||||
const paddle::Tensor& prefix_sum,
|
const paddle::Tensor& weight_zeros,
|
||||||
const int32_t group_size) {
|
const paddle::Tensor& prefix_sum,
|
||||||
|
const int32_t group_size) {
|
||||||
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
||||||
paddle::experimental::DeviceContextPool::Instance().Get(x.place()));
|
paddle::experimental::DeviceContextPool::Instance().Get(x.place()));
|
||||||
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
||||||
@@ -174,28 +175,22 @@ std::vector<paddle::Tensor> GroupGemm(const paddle::Tensor& x,
|
|||||||
return {output};
|
return {output};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> GroupGemmInferShape(
|
std::vector<std::vector<int64_t>> W8A16GroupGemmInferShape(
|
||||||
const std::vector<int64_t>& x_shape,
|
const std::vector<int64_t>& x_shape,
|
||||||
const std::vector<int64_t>& weight_shape,
|
const std::vector<int64_t>& weight_shape) {
|
||||||
const std::vector<int64_t>& weight_scale_shape,
|
|
||||||
const std::vector<int64_t>& prefix_sum_shape) {
|
|
||||||
return {{x_shape[0], weight_shape[1]}};
|
return {{x_shape[0], weight_shape[1]}};
|
||||||
}
|
}
|
||||||
std::vector<paddle::DataType> GroupGemmInferDtype(
|
std::vector<paddle::DataType> W8A16GroupGemmInferDtype(
|
||||||
const paddle::DataType& input_dtype,
|
const paddle::DataType& input_dtype) {
|
||||||
const paddle::DataType& weight_output_dtype,
|
|
||||||
const paddle::DataType& weight_scale_dtype,
|
|
||||||
const paddle::DataType& prefix_sum_dtype,
|
|
||||||
const int moe_topk) {
|
|
||||||
return {input_dtype};
|
return {input_dtype};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(w8a16_group_gemm)
|
PD_BUILD_STATIC_OP(w8a16_group_gemm)
|
||||||
.Inputs({"x", "weight", "weight_scale", "prefix_sum"})
|
.Inputs({"x", "weight", "weight_scale", "weight_zeros", "prefix_sum"})
|
||||||
.Outputs({"output"})
|
.Outputs({"output"})
|
||||||
.Attrs({
|
.Attrs({
|
||||||
"group_size:int",
|
"group_size:int",
|
||||||
})
|
})
|
||||||
.SetKernelFn(PD_KERNEL(GroupGemm))
|
.SetKernelFn(PD_KERNEL(W8A16GroupGemm))
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(GroupGemmInferShape))
|
.SetInferShapeFn(PD_INFER_SHAPE(W8A16GroupGemmInferShape))
|
||||||
.SetInferDtypeFn(PD_INFER_DTYPE(GroupGemmInferDtype));
|
.SetInferDtypeFn(PD_INFER_DTYPE(W8A16GroupGemmInferDtype));
|
||||||
|
|||||||
@@ -15,18 +15,20 @@
|
|||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "iluvatar_context.h"
|
#include "iluvatar_context.h"
|
||||||
|
|
||||||
std::vector<paddle::Tensor> GroupGemv(const paddle::Tensor& x,
|
std::vector<paddle::Tensor> W8A16GroupGemv(
|
||||||
const paddle::Tensor& weight,
|
const paddle::Tensor& x,
|
||||||
const paddle::Tensor& weight_scale,
|
const paddle::Tensor& weight,
|
||||||
const paddle::Tensor& prefix_sum,
|
const paddle::Tensor& weight_scale,
|
||||||
const int32_t group_size) {
|
const paddle::Tensor& weight_zeros,
|
||||||
|
const paddle::Tensor& tokens_per_expert,
|
||||||
|
const int32_t group_size) {
|
||||||
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
||||||
paddle::experimental::DeviceContextPool::Instance().Get(x.place()));
|
paddle::experimental::DeviceContextPool::Instance().Get(x.place()));
|
||||||
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
||||||
const auto& x_dims = x.dims();
|
const auto& x_dims = x.dims();
|
||||||
const auto& w_dims = weight.dims();
|
const auto& w_dims = weight.dims();
|
||||||
const auto& ws_dims = weight_scale.dims();
|
const auto& ws_dims = weight_scale.dims();
|
||||||
const auto& prefix_sum_dims = prefix_sum.dims();
|
const auto& tokens_per_expert_dims = tokens_per_expert.dims();
|
||||||
// [m, k]
|
// [m, k]
|
||||||
PD_CHECK(x_dims.size() == 2, "x should be 2D");
|
PD_CHECK(x_dims.size() == 2, "x should be 2D");
|
||||||
// [n_experts, n, k]
|
// [n_experts, n, k]
|
||||||
@@ -34,7 +36,8 @@ std::vector<paddle::Tensor> GroupGemv(const paddle::Tensor& x,
|
|||||||
// [n_experts, n]
|
// [n_experts, n]
|
||||||
PD_CHECK(ws_dims.size() == 2, "weight_scale should be 2D");
|
PD_CHECK(ws_dims.size() == 2, "weight_scale should be 2D");
|
||||||
// [n_experts]
|
// [n_experts]
|
||||||
PD_CHECK(prefix_sum_dims.size() == 1, "prefix_sum should be 1D");
|
PD_CHECK(tokens_per_expert_dims.size() == 1,
|
||||||
|
"tokens_per_expert should be 1D");
|
||||||
PD_CHECK(group_size == -1);
|
PD_CHECK(group_size == -1);
|
||||||
auto m = x_dims[0];
|
auto m = x_dims[0];
|
||||||
auto k = x_dims[1];
|
auto k = x_dims[1];
|
||||||
@@ -43,9 +46,9 @@ std::vector<paddle::Tensor> GroupGemv(const paddle::Tensor& x,
|
|||||||
PD_CHECK(w_dims[2] == k);
|
PD_CHECK(w_dims[2] == k);
|
||||||
PD_CHECK(ws_dims[0] == n_experts);
|
PD_CHECK(ws_dims[0] == n_experts);
|
||||||
PD_CHECK(ws_dims[1] == n);
|
PD_CHECK(ws_dims[1] == n);
|
||||||
PD_CHECK(prefix_sum_dims[0] == n_experts);
|
PD_CHECK(tokens_per_expert_dims[0] == n_experts);
|
||||||
|
|
||||||
PD_CHECK(prefix_sum.dtype() == paddle::DataType::INT32);
|
PD_CHECK(tokens_per_expert.dtype() == paddle::DataType::INT32);
|
||||||
PD_CHECK(x.dtype() == paddle::DataType::BFLOAT16 ||
|
PD_CHECK(x.dtype() == paddle::DataType::BFLOAT16 ||
|
||||||
x.dtype() == paddle::DataType::FLOAT16);
|
x.dtype() == paddle::DataType::FLOAT16);
|
||||||
PD_CHECK(weight.dtype() == paddle::DataType::INT8);
|
PD_CHECK(weight.dtype() == paddle::DataType::INT8);
|
||||||
@@ -87,7 +90,7 @@ std::vector<paddle::Tensor> GroupGemv(const paddle::Tensor& x,
|
|||||||
cust_device_param.sortedId = nullptr;
|
cust_device_param.sortedId = nullptr;
|
||||||
cust_device_param.bias = nullptr;
|
cust_device_param.bias = nullptr;
|
||||||
cust_device_param.scale = weight_scale.data();
|
cust_device_param.scale = weight_scale.data();
|
||||||
cust_device_param.nSize = prefix_sum.data<int32_t>();
|
cust_device_param.nSize = tokens_per_expert.data<int32_t>();
|
||||||
|
|
||||||
int lda = k;
|
int lda = k;
|
||||||
int ldb = k;
|
int ldb = k;
|
||||||
@@ -152,28 +155,23 @@ std::vector<paddle::Tensor> GroupGemv(const paddle::Tensor& x,
|
|||||||
return {output};
|
return {output};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> GroupGemvInferShape(
|
std::vector<std::vector<int64_t>> W8A16GroupGemvInferShape(
|
||||||
const std::vector<int64_t>& x_shape,
|
const std::vector<int64_t>& x_shape,
|
||||||
const std::vector<int64_t>& weight_shape,
|
const std::vector<int64_t>& weight_shape) {
|
||||||
const std::vector<int64_t>& weight_scale_shape,
|
|
||||||
const std::vector<int64_t>& prefix_sum_shape) {
|
|
||||||
return {{x_shape[0], weight_shape[1]}};
|
return {{x_shape[0], weight_shape[1]}};
|
||||||
}
|
}
|
||||||
std::vector<paddle::DataType> GroupGemvInferDtype(
|
std::vector<paddle::DataType> W8A16GroupGemvInferDtype(
|
||||||
const paddle::DataType& input_dtype,
|
const paddle::DataType& input_dtype) {
|
||||||
const paddle::DataType& weight_output_dtype,
|
|
||||||
const paddle::DataType& weight_scale_dtype,
|
|
||||||
const paddle::DataType& prefix_sum_dtype,
|
|
||||||
const int moe_topk) {
|
|
||||||
return {input_dtype};
|
return {input_dtype};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(w8a16_group_gemv)
|
PD_BUILD_STATIC_OP(w8a16_group_gemv)
|
||||||
.Inputs({"x", "weight", "weight_scale", "prefix_sum"})
|
.Inputs(
|
||||||
|
{"x", "weight", "weight_scale", "weight_zeros", "tokens_per_expert"})
|
||||||
.Outputs({"output"})
|
.Outputs({"output"})
|
||||||
.Attrs({
|
.Attrs({
|
||||||
"group_size:int",
|
"group_size:int",
|
||||||
})
|
})
|
||||||
.SetKernelFn(PD_KERNEL(GroupGemv))
|
.SetKernelFn(PD_KERNEL(W8A16GroupGemv))
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(GroupGemvInferShape))
|
.SetInferShapeFn(PD_INFER_SHAPE(W8A16GroupGemvInferShape))
|
||||||
.SetInferDtypeFn(PD_INFER_DTYPE(GroupGemvInferDtype));
|
.SetInferDtypeFn(PD_INFER_DTYPE(W8A16GroupGemvInferDtype));
|
||||||
|
|||||||
@@ -0,0 +1,241 @@
|
|||||||
|
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "helper.h"
|
||||||
|
#include "iluvatar_context.h"
|
||||||
|
|
||||||
|
// Grouped (per-expert) GEMM with int4-packed weights and 16-bit activations.
//
// Shapes and dtypes below are the ones enforced by the PD_CHECKs in the body:
//   x            [m, k]                          float16 or bfloat16
//   weight       [n_experts, n / 2, k]           int8 storage (two int4 per byte)
//   weight_scale [n_experts, k / group_size, n]  same dtype as x
//   weight_zeros [n_experts, k / group_size, n]  same dtype as x
//   prefix_sum   [n_experts]                     int64; entry i is the end row
//                                                (exclusive) of expert i in x
//   group_size   must be 128
//
// Rows of x are consumed expert by expert in the order given by prefix_sum;
// one cuinfer GEMM is issued per expert that owns at least one row.
// Returns a single tensor of shape [m, n] with the dtype of x.
std::vector<paddle::Tensor> WI4A16GroupGemm(const paddle::Tensor& x,
                                            const paddle::Tensor& weight,
                                            const paddle::Tensor& weight_scale,
                                            const paddle::Tensor& weight_zeros,
                                            const paddle::Tensor& prefix_sum,
                                            const int32_t group_size) {
  auto dev_ctx = static_cast<const phi::CustomContext*>(
      paddle::experimental::DeviceContextPool::Instance().Get(x.place()));
  auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
  // NOTE(review): the second argument (false) presumably requests a
  // non-blocking copy; the host reads this buffer in the loop below, so
  // confirm the copy is complete by then on this backend.
  auto prefix_sum_cpu = prefix_sum.copy_to(paddle::CPUPlace(), false);

  const auto& x_dims = x.dims();
  const auto& w_dims = weight.dims();
  const auto& ws_dims = weight_scale.dims();
  const auto& prefix_sum_dims = prefix_sum.dims();
  const auto& zeros_dims = weight_zeros.dims();
  // [m, k]
  PD_CHECK(x_dims.size() == 2, "x should be 2D");
  // [n_experts, n // 2, k]
  PD_CHECK(w_dims.size() == 3, "weight should be 3D");
  // [n_experts, k // group_size, n]
  PD_CHECK(ws_dims.size() == 3, "weight_scale should be 3D");
  // [n_experts, k // group_size, n]
  PD_CHECK(zeros_dims.size() == 3, "weight_zeros should be 3D");
  // [n_experts]
  PD_CHECK(prefix_sum_dims.size() == 1, "prefix_sum should be 1D");
  PD_CHECK(group_size == 128);
  auto m = x_dims[0];
  auto k = x_dims[1];
  auto n_experts = w_dims[0];
  auto n = w_dims[1] * 2;
  // Without this, the integer division in the shape checks below would
  // silently accept a k that does not split into whole groups.
  PD_CHECK(k % group_size == 0, "k should be divisible by group_size");
  PD_CHECK(w_dims[2] == k);
  PD_CHECK(ws_dims[0] == n_experts);
  PD_CHECK(ws_dims[1] == k / group_size);
  PD_CHECK(ws_dims[2] == n);
  PD_CHECK(zeros_dims[0] == n_experts);
  PD_CHECK(zeros_dims[1] == k / group_size);
  PD_CHECK(zeros_dims[2] == n);
  PD_CHECK(prefix_sum_dims[0] == n_experts);

  PD_CHECK(x.dtype() == paddle::DataType::BFLOAT16 ||
           x.dtype() == paddle::DataType::FLOAT16);
  PD_CHECK(weight.dtype() == paddle::DataType::INT8);
  PD_CHECK(weight_scale.dtype() == x.dtype());
  PD_CHECK(weight_zeros.dtype() == x.dtype());
  PD_CHECK(prefix_sum.dtype() == paddle::DataType::INT64);

  PD_CHECK(x.is_contiguous());
  PD_CHECK(weight.is_contiguous());
  PD_CHECK(weight_scale.is_contiguous());
  PD_CHECK(weight_zeros.is_contiguous());
  PD_CHECK(prefix_sum.is_contiguous());

  const int64_t* prefix_sum_cpu_ptr = prefix_sum_cpu.data<int64_t>();
  auto output = GetEmptyTensor({m, n}, x.dtype(), x.place());
  // 16-bit element pointers are only used for per-expert pointer arithmetic;
  // the real element type is passed to cuinfer through Btype / Ctype.
  int16_t* out_data = static_cast<int16_t*>(output.data());
  const int16_t* x_data = static_cast<const int16_t*>(x.data());
  const int8_t* weight_data = weight.data<int8_t>();
  const int16_t* weight_scale_data =
      static_cast<const int16_t*>(weight_scale.data());
  const int16_t* weight_zeros_data =
      static_cast<const int16_t*>(weight_zeros.data());

  cuinferHandle_t handle = iluvatar::getContextInstance()->getIxInferHandle();
  cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST;
  cuinferOperation_t transa = CUINFER_OP_T;
  cuinferOperation_t transb = CUINFER_OP_N;
  cudaDataType_t Atype = CUDA_R_4I;
  cudaDataType_t Btype;
  if (x.dtype() == paddle::DataType::FLOAT16) {
    Btype = CUDA_R_16F;
  } else if (x.dtype() == paddle::DataType::BFLOAT16) {
    Btype = CUDA_R_16BF;
  } else {
    PADDLE_THROW(common::errors::Unimplemented("Unsupported input dtype."));
  }
  cudaDataType_t Ctype = Btype;
  cudaDataType_t computeType = CUDA_R_32F;
  cudaDataType_t scaleType = CUDA_R_32F;
  cuinferGEMMCustomOption_t customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE;

  cuinferQuantGEMMHostParam cust_host_param;
  cuinferCustomGemmHostParamInit(&cust_host_param);
  cust_host_param.size = sizeof(cuinferQuantGEMMHostParam);
  cust_host_param.persistent = 0;
  cust_host_param.groupSize = group_size;

  // Value-initialize so that every field this function does not set
  // explicitly is zero/null instead of indeterminate stack contents.
  cuinferQuantGEMMDeviceParam cust_device_param{};
  cust_device_param.size = sizeof(cuinferQuantGEMMDeviceParam);
  cust_device_param.bias = nullptr;

  int lda = k;
  int ldb = k;
  int ldc = n;
  float beta = 0.f;
  float alpha = 1.f;
  int batch_count = 1;
  size_t pre = 0;

  auto* allocator = paddle::GetAllocator(x.place());
  phi::Allocator::AllocationPtr tmp_workspace;
  for (int i = 0; i < n_experts; i++) {
    const int64_t expert_i_end = prefix_sum_cpu_ptr[i];
    // A decreasing entry would underflow cur_len, and an entry beyond m would
    // push x_data / out_data past the end of their buffers.
    PD_CHECK(expert_i_end >= static_cast<int64_t>(pre) && expert_i_end <= m,
             "prefix_sum should be non-decreasing and not exceed x.shape[0]");
    size_t cur_len = expert_i_end - pre;
    pre = expert_i_end;
    if (cur_len != 0) {
      cust_device_param.scale = weight_scale_data;
      cust_device_param.zero = weight_zeros_data;

      size_t workspace_size = 0;
      CUINFER_CHECK(cuinferGetCustomGemmWorkspace(transa,
                                                  transb,
                                                  n,
                                                  cur_len,
                                                  k,
                                                  Atype,
                                                  lda,
                                                  lda,
                                                  Btype,
                                                  ldb,
                                                  ldb,
                                                  Ctype,
                                                  ldc,
                                                  ldc,
                                                  batch_count,
                                                  computeType,
                                                  scaleType,
                                                  &workspace_size));
      if (workspace_size > 0) {
        tmp_workspace = allocator->Allocate(workspace_size);
        cust_device_param.workspace = tmp_workspace->ptr();
      } else {
        cust_device_param.workspace = nullptr;
      }

      // A single-row expert goes through the Ex entry point, which also
      // receives the workspace pointer as an explicit trailing argument.
      if (cur_len <= 1) {
        CUINFER_CHECK(cuinferCustomGemmEx(handle,
                                          stream,
                                          cuinfer_ptr_mode,
                                          transa,
                                          transb,
                                          n,
                                          cur_len,
                                          k,
                                          &alpha,
                                          weight_data,
                                          Atype,
                                          lda,
                                          lda,
                                          x_data,
                                          Btype,
                                          ldb,
                                          ldb,
                                          &beta,
                                          out_data,
                                          Ctype,
                                          ldc,
                                          ldc,
                                          batch_count,
                                          computeType,
                                          scaleType,
                                          &cust_host_param,
                                          &cust_device_param,
                                          customOption,
                                          cust_device_param.workspace));
      } else {
        CUINFER_CHECK(cuinferCustomGemm(handle,
                                        stream,
                                        cuinfer_ptr_mode,
                                        transa,
                                        transb,
                                        n,
                                        cur_len,
                                        k,
                                        &alpha,
                                        weight_data,
                                        Atype,
                                        lda,
                                        lda,
                                        x_data,
                                        Btype,
                                        ldb,
                                        ldb,
                                        &beta,
                                        out_data,
                                        Ctype,
                                        ldc,
                                        ldc,
                                        batch_count,
                                        computeType,
                                        scaleType,
                                        &cust_host_param,
                                        &cust_device_param,
                                        customOption));
      }
    }
    // Advance every cursor to the next expert, including experts with no rows
    // (their weights still occupy a slot in the stacked tensors).
    x_data += cur_len * k;
    weight_data += k * n / 2;
    weight_scale_data += k * n / group_size;
    weight_zeros_data += k * n / group_size;
    out_data += cur_len * n;
  }
  return {output};
}
|
||||||
|
|
||||||
|
// Static shape inference for wi4a16_group_gemm: the single output has as many
// rows as x and twice weight's second dimension as columns (the weight stores
// two int4 values per byte along that dimension).
std::vector<std::vector<int64_t>> WI4A16GroupGemmInferShape(
    const std::vector<int64_t>& x_shape,
    const std::vector<int64_t>& weight_shape) {
  const int64_t out_rows = x_shape[0];
  const int64_t out_cols = weight_shape[1] * 2;
  std::vector<std::vector<int64_t>> out_shapes;
  out_shapes.push_back(std::vector<int64_t>{out_rows, out_cols});
  return out_shapes;
}
|
||||||
|
std::vector<paddle::DataType> WI4A16GroupGemmInferDtype(
|
||||||
|
const paddle::DataType& input_dtype) {
|
||||||
|
return {input_dtype};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registers the wi4a16_group_gemm custom op: five tensor inputs, one output,
// one integer attribute, wired to the kernel and inference functions above.
PD_BUILD_STATIC_OP(wi4a16_group_gemm)
    .Inputs({"x", "weight", "weight_scale", "weight_zeros", "prefix_sum"})
    .Outputs({"output"})
    .Attrs({"group_size:int"})
    .SetKernelFn(PD_KERNEL(WI4A16GroupGemm))
    .SetInferShapeFn(PD_INFER_SHAPE(WI4A16GroupGemmInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(WI4A16GroupGemmInferDtype));
|
||||||
@@ -0,0 +1,198 @@
|
|||||||
|
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Wi4A16 weight quantization: per-group symmetric int4 with scale = max|w|/7,
|
||||||
|
// packed two int4 per int8 along the output dimension, matching Python
|
||||||
|
// fastdeploy.model_executor.ops.iluvatar.utils.wi4a16_weight_quantize_cuda.
|
||||||
|
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
#include "helper.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Widens a half-precision value to fp32.
__device__ __forceinline__ float ToFloat(__half v) {
  const float widened = __half2float(v);
  return widened;
}

// Widens a bfloat16 value to fp32.
__device__ __forceinline__ float ToFloat(__nv_bfloat16 v) {
  const float widened = __bfloat162float(v);
  return widened;
}

// Stores fp32 `s` into scales[idx] as half precision.
__device__ __forceinline__ void WriteScale(__half* scales, int idx, float s) {
  const __half narrowed = __float2half(s);
  scales[idx] = narrowed;
}

// Stores fp32 `s` into scales[idx] as bfloat16, using the `_rn` conversion.
__device__ __forceinline__ void WriteScale(__nv_bfloat16* scales,
                                           int idx,
                                           float s) {
  const __nv_bfloat16 narrowed = __float2bfloat16_rn(s);
  scales[idx] = narrowed;
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Wi4A16QuantizeGroupsKernel(const T* __restrict__ w,
|
||||||
|
int8_t* __restrict__ q,
|
||||||
|
T* __restrict__ scales,
|
||||||
|
int k,
|
||||||
|
int n,
|
||||||
|
int group_size,
|
||||||
|
int num_groups_per_row) {
|
||||||
|
const int gid = blockIdx.x;
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
if (tid >= group_size) return;
|
||||||
|
|
||||||
|
const int nn = gid / num_groups_per_row;
|
||||||
|
const int gj = gid % num_groups_per_row;
|
||||||
|
const int kk = gj * group_size + tid;
|
||||||
|
|
||||||
|
extern __shared__ float sdata[];
|
||||||
|
const float v = fabsf(ToFloat(w[kk * n + nn]));
|
||||||
|
sdata[tid] = v;
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int s = group_size >> 1; s > 0; s >>= 1) {
|
||||||
|
if (tid < s) {
|
||||||
|
sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
__shared__ float s_scale;
|
||||||
|
if (tid == 0) {
|
||||||
|
const float max_abs = sdata[0];
|
||||||
|
s_scale = max_abs / 7.0f;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float scale = s_scale;
|
||||||
|
if (tid == 0) {
|
||||||
|
WriteScale(scales, gj * n + nn, scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
const float wval = ToFloat(w[kk * n + nn]);
|
||||||
|
float qf = roundf(wval / scale);
|
||||||
|
if (qf > 7.f) qf = 7.f;
|
||||||
|
if (qf < -8.f) qf = -8.f;
|
||||||
|
q[kk * n + nn] = static_cast<int8_t>(static_cast<int>(qf));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Packs pairs of adjacent columns of q ([k, n], one value per int8) into one
// byte each and writes them transposed into packed ([n / 2, k]). The even
// column goes to the low nibble, the odd column to the high nibble.
// Thread x-index selects the column pair, blockIdx.y selects the row.
__global__ void Wi4A16PackInt4Kernel(const int8_t* __restrict__ q,
                                     int8_t* __restrict__ packed,
                                     int k,
                                     int n) {
  const int pair_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int row = blockIdx.y;
  const int num_pairs = n >> 1;
  const bool in_range = (pair_idx < num_pairs) && (row < k);
  if (!in_range) return;

  const int8_t* src = q + (row * n + (pair_idx << 1));
  const int8_t even_val = src[0];
  const int8_t odd_val = src[1];
  // Keep only the low 4 bits of each value (two's-complement nibble).
  const uint32_t lo_nibble =
      static_cast<uint32_t>(static_cast<uint8_t>(even_val)) & 0xFU;
  const uint32_t hi_nibble =
      static_cast<uint32_t>(static_cast<uint8_t>(odd_val)) & 0xFU;
  packed[pair_idx * k + row] = static_cast<int8_t>(lo_nibble | (hi_nibble << 4));
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Quantizes a 2-D weight matrix to packed int4 with per-group scales.
//
//   w           [k, n], float16 or bfloat16, contiguous; n even,
//               k divisible by group_size
//   group_size  must be 128
//
// Returns {packed, scales, zeros}:
//   packed  [n / 2, k]           int8, two 4-bit values per byte
//   scales  [k / group_size, n]  dtype of w
//   zeros   [k / group_size, n]  dtype of w, filled with zero bytes
// All work is enqueued on the stream of w's device context.
std::vector<paddle::Tensor> Wi4A16Quantize(const paddle::Tensor& w,
                                           int32_t group_size) {
  PD_CHECK(w.dims().size() == 2,
           "wi4a16_weight_quantize: weight must be 2D [k, n]");
  PD_CHECK(group_size == 128,
           "wi4a16_weight_quantize CUDA: group_size must be 128");
  const int64_t k = w.dims()[0];
  const int64_t n = w.dims()[1];
  PD_CHECK(n % 2 == 0, "wi4a16_weight_quantize: n (dim 1) must be even");
  PD_CHECK(k % group_size == 0,
           "wi4a16_weight_quantize: k must be divisible by group_size");

  PD_CHECK(w.dtype() == paddle::DataType::FLOAT16 ||
               w.dtype() == paddle::DataType::BFLOAT16,
           "wi4a16_weight_quantize: weight dtype must be float16 or bfloat16");
  PD_CHECK(w.is_contiguous(),
           "wi4a16_weight_quantize: weight must be contiguous");

  auto dev_ctx = static_cast<const phi::CustomContext*>(
      paddle::experimental::DeviceContextPool::Instance().Get(w.place()));
  auto stream = static_cast<cudaStream_t>(dev_ctx->stream());

  auto packed = GetEmptyTensor({n / 2, k}, paddle::DataType::INT8, w.place());
  auto scales = GetEmptyTensor({k / group_size, n}, w.dtype(), w.place());
  auto zeros = GetEmptyTensor({k / group_size, n}, w.dtype(), w.place());

  // k == 0 and n == 0 both pass the checks above; every output is then empty
  // and launching a kernel with a zero-sized grid would be a launch error.
  if (k == 0 || n == 0) {
    return {packed, scales, zeros};
  }

  CUDA_CHECK(cudaMemsetAsync(
      zeros.data(),
      0,
      static_cast<size_t>(zeros.numel()) * phi::SizeOf(zeros.dtype()),
      stream));

  // Unpacked intermediate: one quantized value per int8, same layout as w.
  auto q_tmp = GetEmptyTensor({k, n}, paddle::DataType::INT8, w.place());
  int8_t* q_ptr = q_tmp.data<int8_t>();
  int8_t* packed_ptr = packed.data<int8_t>();

  // One block per (column, k-group), one thread per row inside the group.
  const int num_groups_per_row = static_cast<int>(k / group_size);
  const int total_groups = static_cast<int>(n * num_groups_per_row);
  const int threads = group_size;
  const size_t shmem = static_cast<size_t>(group_size) * sizeof(float);

  if (w.dtype() == paddle::DataType::FLOAT16) {
    Wi4A16QuantizeGroupsKernel<__half>
        <<<total_groups, threads, shmem, stream>>>(
            reinterpret_cast<const __half*>(w.data()),
            q_ptr,
            reinterpret_cast<__half*>(scales.data()),
            static_cast<int>(k),
            static_cast<int>(n),
            group_size,
            num_groups_per_row);
  } else {
    Wi4A16QuantizeGroupsKernel<__nv_bfloat16>
        <<<total_groups, threads, shmem, stream>>>(
            reinterpret_cast<const __nv_bfloat16*>(w.data()),
            q_ptr,
            reinterpret_cast<__nv_bfloat16*>(scales.data()),
            static_cast<int>(k),
            static_cast<int>(n),
            group_size,
            num_groups_per_row);
  }
  // Surface an invalid launch configuration here instead of letting it be
  // picked up by some unrelated later call.
  CUDA_CHECK(cudaGetLastError());

  const int nhalf = static_cast<int>(n >> 1);
  dim3 block(256);
  dim3 grid((nhalf + block.x - 1) / block.x, static_cast<unsigned>(k));
  Wi4A16PackInt4Kernel<<<grid, block, 0, stream>>>(
      q_ptr, packed_ptr, static_cast<int>(k), static_cast<int>(n));
  CUDA_CHECK(cudaGetLastError());

  return {packed, scales, zeros};
}
|
||||||
|
|
||||||
|
// Static shape inference for wi4a16_weight_quantize_cuda. For a [k, n] weight
// the three outputs are [n / 2, k], [k / group_size, n] and
// [k / group_size, n], in that order.
std::vector<std::vector<int64_t>> Wi4A16QuantizeInferShape(
    const std::vector<int64_t>& w_shape, int32_t group_size) {
  const int64_t rows = w_shape[0];
  const int64_t cols = w_shape[1];
  const int64_t groups = rows / group_size;

  const std::vector<int64_t> packed_shape{cols / 2, rows};
  const std::vector<int64_t> per_group_shape{groups, cols};

  std::vector<std::vector<int64_t>> out_shapes;
  out_shapes.push_back(packed_shape);
  out_shapes.push_back(per_group_shape);
  out_shapes.push_back(per_group_shape);
  return out_shapes;
}
|
||||||
|
|
||||||
|
std::vector<paddle::DataType> Wi4A16QuantizeInferDtype(
|
||||||
|
const paddle::DataType& w_dtype, int32_t group_size) {
|
||||||
|
return {paddle::DataType::INT8, w_dtype, w_dtype};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registers the wi4a16_weight_quantize_cuda custom op: one weight input,
// three outputs, one integer attribute, wired to the functions above.
PD_BUILD_STATIC_OP(wi4a16_weight_quantize_cuda)
    .Inputs({"w"})
    .Outputs({"quant_weight", "scales", "zeros"})
    .Attrs({"group_size: int"})
    .SetKernelFn(PD_KERNEL(Wi4A16Quantize))
    .SetInferShapeFn(PD_INFER_SHAPE(Wi4A16QuantizeInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(Wi4A16QuantizeInferDtype));
|
||||||
@@ -584,14 +584,13 @@ elif paddle.is_compiled_with_cuda():
|
|||||||
elif paddle.is_compiled_with_xpu():
|
elif paddle.is_compiled_with_xpu():
|
||||||
assert False, "For XPU, please use setup_ops.py in the xpu_ops directory to compile custom ops."
|
assert False, "For XPU, please use setup_ops.py in the xpu_ops directory to compile custom ops."
|
||||||
elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||||
|
_iluvatar_clang_cuda_flags = ["-Wno-non-pod-varargs", "-DPADDLE_DEV", "-DPADDLE_WITH_CUSTOM_DEVICE"]
|
||||||
setup(
|
setup(
|
||||||
name="fastdeploy_ops",
|
name="fastdeploy_ops",
|
||||||
ext_modules=CUDAExtension(
|
ext_modules=CUDAExtension(
|
||||||
extra_compile_args={
|
extra_compile_args={
|
||||||
"nvcc": [
|
"cxx": _iluvatar_clang_cuda_flags,
|
||||||
"-DPADDLE_DEV",
|
"nvcc": _iluvatar_clang_cuda_flags,
|
||||||
"-DPADDLE_WITH_CUSTOM_DEVICE",
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
sources=[
|
sources=[
|
||||||
"gpu_ops/save_with_output_msg.cc",
|
"gpu_ops/save_with_output_msg.cc",
|
||||||
@@ -625,6 +624,8 @@ elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
|||||||
"iluvatar_ops/mixed_fused_attn.cu",
|
"iluvatar_ops/mixed_fused_attn.cu",
|
||||||
"iluvatar_ops/w8a16_group_gemm.cu",
|
"iluvatar_ops/w8a16_group_gemm.cu",
|
||||||
"iluvatar_ops/w8a16_group_gemv.cu",
|
"iluvatar_ops/w8a16_group_gemv.cu",
|
||||||
|
"iluvatar_ops/wi4a16_group_gemm.cu",
|
||||||
|
"iluvatar_ops/wi4a16_weight_quantize.cu",
|
||||||
"iluvatar_ops/restore_tokens_per_expert.cu",
|
"iluvatar_ops/restore_tokens_per_expert.cu",
|
||||||
"iluvatar_ops/runtime/iluvatar_context.cc",
|
"iluvatar_ops/runtime/iluvatar_context.cc",
|
||||||
"iluvatar_ops/cpp_extensions.cc",
|
"iluvatar_ops/cpp_extensions.cc",
|
||||||
|
|||||||
@@ -34,8 +34,7 @@ Note: Because the 4.3.8 SDK in the image is incompatible with KMD, paddle cannot
|
|||||||
### 3.2 Install paddle
|
### 3.2 Install paddle
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install paddlepaddle==3.4.0.dev20260226 -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
|
pip3 install paddlepaddle-iluvatar==3.4.0.dev20260326 -i https://www.paddlepaddle.org.cn/packages/nightly/ixuca/
|
||||||
pip3 install paddle-iluvatar-gpu==3.0.0.dev20260226 -i https://www.paddlepaddle.org.cn/packages/nightly/ixuca/
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3.3 Install or build FastDeploy
|
### 3.3 Install or build FastDeploy
|
||||||
@@ -59,7 +58,6 @@ script list below:
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
python3 run_demo.py
|
python3 run_demo.py
|
||||||
@@ -136,7 +134,6 @@ server:
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
python3 -m fastdeploy.entrypoints.openai.api_server \
|
python3 -m fastdeploy.entrypoints.openai.api_server \
|
||||||
@@ -193,7 +190,6 @@ server:
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
python3 -m fastdeploy.entrypoints.openai.api_server \
|
python3 -m fastdeploy.entrypoints.openai.api_server \
|
||||||
@@ -230,7 +226,6 @@ server:
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
python3 -m fastdeploy.entrypoints.openai.api_server \
|
python3 -m fastdeploy.entrypoints.openai.api_server \
|
||||||
@@ -285,7 +280,6 @@ The script is as below:
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
python3 run_demo_vl.py
|
python3 run_demo_vl.py
|
||||||
@@ -384,7 +378,6 @@ server:
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
python3 -m fastdeploy.entrypoints.openai.api_server \
|
python3 -m fastdeploy.entrypoints.openai.api_server \
|
||||||
@@ -424,7 +417,6 @@ server:
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
python3 -m fastdeploy.entrypoints.openai.api_server \
|
python3 -m fastdeploy.entrypoints.openai.api_server \
|
||||||
@@ -475,7 +467,6 @@ server:
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
export CUDA_VISIBLE_DEVICES=1
|
export CUDA_VISIBLE_DEVICES=1
|
||||||
@@ -541,3 +532,7 @@ python3 infer_ocr_vl_benchmark.py
|
|||||||
```
|
```
|
||||||
|
|
||||||
After each image is inferred, a corresponding `md` file will be generated in the `output` path. Running the entire benchmark (1355 images) takes approximately 1.8 hours.
|
After each image is inferred, a corresponding `md` file will be generated in the `output` path. Running the entire benchmark (1355 images) takes approximately 1.8 hours.
|
||||||
|
|
||||||
|
## 5. Quantization Format Support
|
||||||
|
- `W8A16`: `--quantization wint8`
|
||||||
|
- `W4A16`: `--quantization wint4`
|
||||||
|
|||||||
@@ -34,8 +34,7 @@ docker exec -it paddle_infer bash
|
|||||||
### 3.2 安装paddle
|
### 3.2 安装paddle
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install paddlepaddle==3.4.0.dev20260226 -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
|
pip3 install paddlepaddle-iluvatar==3.4.0.dev20260326 -i https://www.paddlepaddle.org.cn/packages/nightly/ixuca/
|
||||||
pip3 install paddle-iluvatar-gpu==3.0.0.dev20260226 -i https://www.paddlepaddle.org.cn/packages/nightly/ixuca/
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3.3 安装fastdeploy
|
### 3.3 安装fastdeploy
|
||||||
@@ -59,7 +58,6 @@ bash build.sh
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
python3 run_demo.py
|
python3 run_demo.py
|
||||||
@@ -136,7 +134,6 @@ The largest ocean is the Pacific Ocean, covering an area of approximately …
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
python3 -m fastdeploy.entrypoints.openai.api_server \
|
python3 -m fastdeploy.entrypoints.openai.api_server \
|
||||||
@@ -193,7 +190,6 @@ Latency: 1539.625 s
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
python3 -m fastdeploy.entrypoints.openai.api_server \
|
python3 -m fastdeploy.entrypoints.openai.api_server \
|
||||||
@@ -230,7 +226,6 @@ curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
python3 -m fastdeploy.entrypoints.openai.api_server \
|
python3 -m fastdeploy.entrypoints.openai.api_server \
|
||||||
@@ -285,7 +280,6 @@ python3 -u bench_gsm8k.py --port 8180 --num-questions 1319 --num-shots 5 --paral
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
python3 run_demo_vl.py
|
python3 run_demo_vl.py
|
||||||
@@ -384,7 +378,6 @@ generated_text=
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
python3 -m fastdeploy.entrypoints.openai.api_server \
|
python3 -m fastdeploy.entrypoints.openai.api_server \
|
||||||
@@ -424,7 +417,6 @@ curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
python3 -m fastdeploy.entrypoints.openai.api_server \
|
python3 -m fastdeploy.entrypoints.openai.api_server \
|
||||||
@@ -475,7 +467,6 @@ pip3 install -e ".[doc-parser]"
|
|||||||
```bash
|
```bash
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
export FD_SAMPLING_CLASS=rejection
|
export FD_SAMPLING_CLASS=rejection
|
||||||
export CUDA_VISIBLE_DEVICES=1
|
export CUDA_VISIBLE_DEVICES=1
|
||||||
@@ -538,3 +529,7 @@ python3 infer_ocr_vl_benchmark.py
|
|||||||
```
|
```
|
||||||
|
|
||||||
每推理完一张图片,会在`output`路径下生成一个对应的`md`文件,跑完整个benchmark(1355张图片)大概需要1.8个小时。
|
每推理完一张图片,会在`output`路径下生成一个对应的`md`文件,跑完整个benchmark(1355张图片)大概需要1.8个小时。
|
||||||
|
|
||||||
|
## 5. 支持的量化策略
|
||||||
|
- `W8A16`: `--quantization wint8`
|
||||||
|
- `W4A16`: `--quantization wint4`
|
||||||
|
|||||||
+47
-16
@@ -30,6 +30,7 @@ from fastdeploy.model_executor.ops.iluvatar import (
|
|||||||
moe_expert_ffn,
|
moe_expert_ffn,
|
||||||
moe_expert_reduce,
|
moe_expert_reduce,
|
||||||
)
|
)
|
||||||
|
from fastdeploy.model_executor.ops.iluvatar.utils import wi4a16_weight_quantize
|
||||||
from fastdeploy.model_executor.utils import (
|
from fastdeploy.model_executor.utils import (
|
||||||
TensorTracker,
|
TensorTracker,
|
||||||
free_tensor,
|
free_tensor,
|
||||||
@@ -80,8 +81,11 @@ class IluvatarCutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
getattr(layer, self.added_weight_attrs[1]),
|
getattr(layer, self.added_weight_attrs[1]),
|
||||||
(layer.up_gate_proj_bias if hasattr(layer, "up_gate_proj_bias") else None),
|
(layer.up_gate_proj_bias if hasattr(layer, "up_gate_proj_bias") else None),
|
||||||
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
|
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
|
||||||
|
(layer.up_gate_proj_weight_zeros if hasattr(layer, "up_gate_proj_weight_zeros") else None),
|
||||||
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
|
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
|
||||||
|
(layer.down_proj_weight_zeros if hasattr(layer, "down_proj_weight_zeros") else None),
|
||||||
self.moe_quant_type,
|
self.moe_quant_type,
|
||||||
|
self.quant_config.group_size,
|
||||||
layer.fd_config.model_config.moe_phase.phase,
|
layer.fd_config.model_config.moe_phase.phase,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -214,7 +218,12 @@ class IluvatarCutlassWeightOnlyMoEMethod(IluvatarCutlassMoEMethod):
|
|||||||
super().__init__(quant_config)
|
super().__init__(quant_config)
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.moe_quant_type = self.quant_config.algo
|
self.moe_quant_type = self.quant_config.algo
|
||||||
self.pack_num = 1
|
if self.moe_quant_type == "weight_only_int8":
|
||||||
|
self.quant_config.group_size = -1
|
||||||
|
elif self.moe_quant_type == "weight_only_int4":
|
||||||
|
self.quant_config.group_size = 128
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Iluvarar only support wint8 nand wint4 yet.")
|
||||||
|
|
||||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
|
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
|
||||||
"""
|
"""
|
||||||
@@ -427,24 +436,33 @@ class IluvatarCutlassWeightOnlyMoEMethod(IluvatarCutlassMoEMethod):
|
|||||||
# quantized_weight_name
|
# quantized_weight_name
|
||||||
weight_name = self.added_weight_attrs[weight_idx]
|
weight_name = self.added_weight_attrs[weight_idx]
|
||||||
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
|
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
|
||||||
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
|
|
||||||
weight_dtype = "int8"
|
|
||||||
# scale
|
# scale
|
||||||
scale_name = self.added_scale_attrs[weight_idx]
|
scale_name = self.added_scale_attrs[weight_idx]
|
||||||
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
|
|
||||||
scale_dtype = self.default_dtype
|
if self.moe_quant_type == "weight_only_int4":
|
||||||
|
# zeros for int4
|
||||||
|
zeros = []
|
||||||
|
zeros_name = scale_name.replace("weight_scale", "weight_zeros")
|
||||||
|
|
||||||
# 2.crate tmp tensor
|
# 2.crate tmp tensor
|
||||||
|
weight, scale = [], []
|
||||||
weight = paddle.empty(weight_shape, dtype=weight_dtype)
|
|
||||||
scale = paddle.empty(scale_shape, dtype=scale_dtype)
|
|
||||||
|
|
||||||
# 3.quantize weight
|
# 3.quantize weight
|
||||||
|
|
||||||
for expert_id in range(layer.num_local_experts):
|
for expert_id in range(layer.num_local_experts):
|
||||||
weight[expert_id], scale[expert_id] = weight_quantize(
|
unquantized_weight = getattr(layer, unquantized_weight_name)[expert_id]
|
||||||
getattr(layer, unquantized_weight_name)[expert_id], algo=self.moe_quant_type
|
if self.moe_quant_type == "weight_only_int8":
|
||||||
)
|
w, s = weight_quantize(unquantized_weight, algo=self.moe_quant_type)
|
||||||
|
else:
|
||||||
|
w, s, z = wi4a16_weight_quantize(unquantized_weight)
|
||||||
|
zeros.append(z)
|
||||||
|
weight.append(w)
|
||||||
|
scale.append(s)
|
||||||
|
|
||||||
|
weight = paddle.stack(weight, axis=0)
|
||||||
|
scale = paddle.stack(scale, axis=0)
|
||||||
|
if self.moe_quant_type == "weight_only_int4":
|
||||||
|
zeros = paddle.stack(zeros, axis=0)
|
||||||
|
|
||||||
free_tensor(getattr(layer, unquantized_weight_name))
|
free_tensor(getattr(layer, unquantized_weight_name))
|
||||||
|
|
||||||
@@ -453,8 +471,8 @@ class IluvatarCutlassWeightOnlyMoEMethod(IluvatarCutlassMoEMethod):
|
|||||||
layer,
|
layer,
|
||||||
weight_name,
|
weight_name,
|
||||||
layer.create_parameter(
|
layer.create_parameter(
|
||||||
shape=weight_shape,
|
shape=weight.shape,
|
||||||
dtype=weight_dtype,
|
dtype=weight.dtype,
|
||||||
default_initializer=paddle.nn.initializer.Constant(0),
|
default_initializer=paddle.nn.initializer.Constant(0),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -463,14 +481,27 @@ class IluvatarCutlassWeightOnlyMoEMethod(IluvatarCutlassMoEMethod):
|
|||||||
layer,
|
layer,
|
||||||
scale_name,
|
scale_name,
|
||||||
layer.create_parameter(
|
layer.create_parameter(
|
||||||
shape=scale_shape,
|
shape=scale.shape,
|
||||||
dtype=scale_dtype,
|
dtype=scale.dtype,
|
||||||
default_initializer=paddle.nn.initializer.Constant(0),
|
default_initializer=paddle.nn.initializer.Constant(0),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
getattr(layer, weight_name).copy_(weight, False)
|
getattr(layer, weight_name).copy_(weight, False)
|
||||||
getattr(layer, scale_name).copy_(scale, False)
|
getattr(layer, scale_name).copy_(scale, False)
|
||||||
|
|
||||||
|
if self.moe_quant_type == "weight_only_int4":
|
||||||
|
# create zeros
|
||||||
|
setattr(
|
||||||
|
layer,
|
||||||
|
zeros_name,
|
||||||
|
layer.create_parameter(
|
||||||
|
shape=zeros.shape,
|
||||||
|
dtype=zeros.dtype,
|
||||||
|
default_initializer=paddle.nn.initializer.Constant(0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
getattr(layer, zeros_name).copy_(zeros, False)
|
||||||
|
|
||||||
if self.quant_config.is_checkpoint_bf16:
|
if self.quant_config.is_checkpoint_bf16:
|
||||||
weight_id_map = {"gate_up": 0, "down": 1}
|
weight_id_map = {"gate_up": 0, "down": 1}
|
||||||
if weight_fully_copied(layer.up_gate_proj_weight):
|
if weight_fully_copied(layer.up_gate_proj_weight):
|
||||||
|
|||||||
@@ -18,76 +18,35 @@ from typing import Optional
|
|||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
from paddle.nn.functional import swiglu
|
from paddle.nn.functional import swiglu
|
||||||
from paddle.nn.quant import weight_only_linear
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from fastdeploy.model_executor.ops.iluvatar import (
|
from fastdeploy.model_executor.ops.iluvatar import (
|
||||||
restore_tokens_per_expert,
|
restore_tokens_per_expert,
|
||||||
w8a16_group_gemm,
|
w8a16_group_gemm,
|
||||||
w8a16_group_gemv,
|
w8a16_group_gemv,
|
||||||
|
wi4a16_group_gemm,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except:
|
||||||
w8a16_group_gemm = None
|
w8a16_group_gemm = None
|
||||||
w8a16_group_gemv = None
|
w8a16_group_gemv = None
|
||||||
|
wi4a16_group_gemm = None
|
||||||
restore_tokens_per_expert = None
|
restore_tokens_per_expert = None
|
||||||
|
|
||||||
|
|
||||||
def group_gemm(
|
def _pre_process_expert_ffn(
|
||||||
input: paddle.Tensor,
|
moe_phase: str,
|
||||||
|
quant_method: str,
|
||||||
tokens_expert_prefix_sum: paddle.Tensor,
|
tokens_expert_prefix_sum: paddle.Tensor,
|
||||||
weight: paddle.Tensor,
|
|
||||||
scale: paddle.Tensor,
|
|
||||||
output: paddle.Tensor,
|
|
||||||
):
|
):
|
||||||
assert (
|
if quant_method == "weight_only_int8":
|
||||||
input.dim() == 2
|
if moe_phase == "decode":
|
||||||
and tokens_expert_prefix_sum.dim() == 1
|
group_gemm_func = w8a16_group_gemv
|
||||||
and weight.dim() == 3
|
tokens_per_expert = restore_tokens_per_expert(tokens_expert_prefix_sum).to("int32")
|
||||||
and scale.dim() == 2
|
else:
|
||||||
and output.dim() == 2
|
group_gemm_func = w8a16_group_gemm
|
||||||
)
|
tokens_per_expert = tokens_expert_prefix_sum
|
||||||
num_tokens = input.shape[0]
|
|
||||||
dim_in = input.shape[1]
|
|
||||||
dim_out = weight.shape[1]
|
|
||||||
num_experts = weight.shape[0]
|
|
||||||
|
|
||||||
# check shape
|
|
||||||
assert tokens_expert_prefix_sum.shape == [
|
|
||||||
num_experts,
|
|
||||||
]
|
|
||||||
assert weight.shape == [num_experts, dim_out, dim_in]
|
|
||||||
assert scale.shape == [num_experts, dim_out]
|
|
||||||
assert output.shape == [num_tokens, dim_out]
|
|
||||||
|
|
||||||
# check dtype
|
|
||||||
assert input.dtype in (paddle.float16, paddle.bfloat16)
|
|
||||||
assert scale.dtype == input.dtype and output.dtype == input.dtype
|
|
||||||
assert tokens_expert_prefix_sum.dtype == paddle.int64
|
|
||||||
assert weight.dtype == paddle.int8
|
|
||||||
|
|
||||||
# check others
|
|
||||||
assert tokens_expert_prefix_sum.place.is_cpu_place()
|
|
||||||
assert tokens_expert_prefix_sum[-1] == num_tokens
|
|
||||||
for i in range(num_experts):
|
|
||||||
expert_start = 0 if i == 0 else tokens_expert_prefix_sum[i - 1]
|
|
||||||
expert_end = tokens_expert_prefix_sum[i]
|
|
||||||
if expert_start == expert_end:
|
|
||||||
continue
|
|
||||||
input_i = input[expert_start:expert_end]
|
|
||||||
weight_i = weight[i]
|
|
||||||
scale_i = scale[i]
|
|
||||||
# avoid d2d?
|
|
||||||
output[expert_start:expert_end] = weight_only_linear(
|
|
||||||
input_i, weight_i, weight_scale=scale_i, weight_dtype="int8", group_size=-1
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _pre_process_expert_ffn(moe_phase: str, tokens_expert_prefix_sum: paddle.Tensor):
|
|
||||||
if moe_phase == "decode":
|
|
||||||
group_gemm_func = w8a16_group_gemv
|
|
||||||
tokens_per_expert = restore_tokens_per_expert(tokens_expert_prefix_sum).to("int32")
|
|
||||||
else:
|
else:
|
||||||
group_gemm_func = w8a16_group_gemm
|
group_gemm_func = wi4a16_group_gemm
|
||||||
tokens_per_expert = tokens_expert_prefix_sum
|
tokens_per_expert = tokens_expert_prefix_sum
|
||||||
return group_gemm_func, tokens_per_expert
|
return group_gemm_func, tokens_per_expert
|
||||||
|
|
||||||
@@ -99,16 +58,25 @@ def iluvatar_moe_expert_ffn(
|
|||||||
down_proj_weight: paddle.Tensor,
|
down_proj_weight: paddle.Tensor,
|
||||||
up_gate_proj_bias: Optional[paddle.Tensor],
|
up_gate_proj_bias: Optional[paddle.Tensor],
|
||||||
up_gate_proj_scale: Optional[paddle.Tensor],
|
up_gate_proj_scale: Optional[paddle.Tensor],
|
||||||
|
up_gate_proj_zeros: Optional[paddle.Tensor],
|
||||||
down_proj_scale: Optional[paddle.Tensor],
|
down_proj_scale: Optional[paddle.Tensor],
|
||||||
|
down_proj_zeros: Optional[paddle.Tensor],
|
||||||
quant_method: str,
|
quant_method: str,
|
||||||
|
group_size: int,
|
||||||
moe_phase: str,
|
moe_phase: str,
|
||||||
):
|
):
|
||||||
assert up_gate_proj_bias is None
|
assert up_gate_proj_bias is None
|
||||||
assert up_gate_proj_scale is not None
|
assert up_gate_proj_scale is not None
|
||||||
assert down_proj_scale is not None
|
assert down_proj_scale is not None
|
||||||
assert quant_method in ("weight_only_int8")
|
if quant_method == "weight_only_int4":
|
||||||
group_gemm_func, tokens_per_expert = _pre_process_expert_ffn(moe_phase, tokens_expert_prefix_sum)
|
assert up_gate_proj_zeros is not None
|
||||||
ffn1_output = group_gemm_func(permute_input, up_gate_proj_weight, up_gate_proj_scale, tokens_per_expert, -1)
|
assert down_proj_zeros is not None
|
||||||
|
group_gemm_func, tokens_per_expert = _pre_process_expert_ffn(moe_phase, quant_method, tokens_expert_prefix_sum)
|
||||||
|
ffn1_output = group_gemm_func(
|
||||||
|
permute_input, up_gate_proj_weight, up_gate_proj_scale, up_gate_proj_zeros, tokens_per_expert, group_size
|
||||||
|
)
|
||||||
act_out = swiglu(ffn1_output)
|
act_out = swiglu(ffn1_output)
|
||||||
output = group_gemm_func(act_out, down_proj_weight, down_proj_scale, tokens_per_expert, -1)
|
output = group_gemm_func(
|
||||||
|
act_out, down_proj_weight, down_proj_scale, down_proj_zeros, tokens_per_expert, group_size
|
||||||
|
)
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -0,0 +1,51 @@
|
|||||||
|
import paddle
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fastdeploy.model_executor.ops.iluvatar import wi4a16_weight_quantize_cuda
|
||||||
|
except:
|
||||||
|
wi4a16_weight_quantize_cuda = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_weight_by_group_size(w, group_size):
|
||||||
|
assert w.dim() == 2
|
||||||
|
assert group_size in (-1, 32, 64, 128)
|
||||||
|
if group_size == -1:
|
||||||
|
quant_weight = w
|
||||||
|
else:
|
||||||
|
assert w.shape[-1] % group_size == 0
|
||||||
|
quant_weight = w.reshape(-1, group_size)
|
||||||
|
assert paddle.isnan(quant_weight).sum() == 0
|
||||||
|
return quant_weight
|
||||||
|
|
||||||
|
|
||||||
|
def _pack_int4_to_int8(weight):
|
||||||
|
return ((weight[:, 1::2] & 0xF) << 4) | (weight[:, 0::2] & 0xF)
|
||||||
|
|
||||||
|
|
||||||
|
def wi4a16_weight_quantize(w, group_size=128):
|
||||||
|
"""Quantize [k, n] weight to packed int4, scales, zeros (MoE wi4a16)."""
|
||||||
|
k, n = w.shape
|
||||||
|
assert k % group_size == 0 and n % 2 == 0
|
||||||
|
if wi4a16_weight_quantize_cuda is not None:
|
||||||
|
return wi4a16_weight_quantize_cuda(w.contiguous(), group_size)
|
||||||
|
else:
|
||||||
|
# [k, n] -> [n, k]
|
||||||
|
w = w.T.contiguous()
|
||||||
|
quant_weight = _get_weight_by_group_size(w, group_size)
|
||||||
|
|
||||||
|
wmax = quant_weight.abs().max(axis=1, keepdim=True)
|
||||||
|
scales = wmax / 7
|
||||||
|
out = paddle.round(quant_weight.to(paddle.float32) / scales).clamp(-8, 7).to(paddle.int8)
|
||||||
|
|
||||||
|
out = _pack_int4_to_int8(
|
||||||
|
# NOTE: convert to numpy since paddle cannot support &
|
||||||
|
out.view(w.shape[0], -1)
|
||||||
|
.T.contiguous()
|
||||||
|
.cpu()
|
||||||
|
.numpy(),
|
||||||
|
)
|
||||||
|
out = paddle.from_numpy(out).T.contiguous()
|
||||||
|
|
||||||
|
scales = scales.view(w.shape[0], -1).T.contiguous()
|
||||||
|
zeros = paddle.zeros_like(scales)
|
||||||
|
return out, scales, zeros
|
||||||
@@ -66,6 +66,9 @@ class IluvatarModelRunner(GPUModelRunner):
|
|||||||
not self.cache_config.enable_chunked_prefill
|
not self.cache_config.enable_chunked_prefill
|
||||||
), "Iluvatar does not support chunked prefill for VL model"
|
), "Iluvatar does not support chunked prefill for VL model"
|
||||||
|
|
||||||
|
if hasattr(self.quant_config, "moe_quant_type") and self.quant_config.moe_quant_type == "wint4":
|
||||||
|
assert not self.use_cudagraph, "Iluvatar does not support cuda graph for weight_only_int4"
|
||||||
|
|
||||||
print(f"self.use_cudagraph={self.use_cudagraph}")
|
print(f"self.use_cudagraph={self.use_cudagraph}")
|
||||||
# VL neox style = True
|
# VL neox style = True
|
||||||
emb_shape = self.share_inputs["rope_emb"].shape
|
emb_shape = self.share_inputs["rope_emb"].shape
|
||||||
|
|||||||
@@ -34,8 +34,7 @@ function pip_install_with_retry() {
|
|||||||
echo "pip requirements"
|
echo "pip requirements"
|
||||||
pip_install_with_retry -r requirements_iluvatar.txt
|
pip_install_with_retry -r requirements_iluvatar.txt
|
||||||
echo "install paddle cpu and custom device"
|
echo "install paddle cpu and custom device"
|
||||||
pip_install_with_retry --pre paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
|
pip_install_with_retry --pre paddlepaddle-iluvatar -i https://www.paddlepaddle.org.cn/packages/nightly/ixuca/
|
||||||
pip_install_with_retry --pre paddle-iluvatar-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/ixuca/
|
|
||||||
|
|
||||||
echo "Run paddle.utils.run_check()"
|
echo "Run paddle.utils.run_check()"
|
||||||
python -c "import paddle; paddle.utils.run_check()"
|
python -c "import paddle; paddle.utils.run_check()"
|
||||||
@@ -101,6 +100,7 @@ export FD_SAMPLING_CLASS=rejection
|
|||||||
offline_ci_list=(
|
offline_ci_list=(
|
||||||
${CI_PATH}/run_ernie_21b.py
|
${CI_PATH}/run_ernie_21b.py
|
||||||
${CI_PATH}/run_ernie_vl_28B.py
|
${CI_PATH}/run_ernie_vl_28B.py
|
||||||
|
${CI_PATH}/run_ernie_vl_28B_wint4.py
|
||||||
)
|
)
|
||||||
echo "test offline ci files: ${offline_ci_list[@]}"
|
echo "test offline ci files: ${offline_ci_list[@]}"
|
||||||
for cur_test_file in ${offline_ci_list[@]}
|
for cur_test_file in ${offline_ci_list[@]}
|
||||||
|
|||||||
@@ -0,0 +1,102 @@
|
|||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from fastdeploy import LLM, SamplingParams
|
||||||
|
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
|
||||||
|
from fastdeploy.utils import set_random_seed
|
||||||
|
|
||||||
|
tests_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||||
|
sys.path.insert(0, tests_dir)
|
||||||
|
|
||||||
|
from ci_use.iluvatar_UT.utils import timeout
|
||||||
|
|
||||||
|
|
||||||
|
@timeout(240)
|
||||||
|
def offline_infer_check():
|
||||||
|
set_random_seed(123)
|
||||||
|
|
||||||
|
PATH = "/model_data/ERNIE-4.5-VL-28B-A3B-Paddle"
|
||||||
|
tokenizer = Ernie4_5Tokenizer.from_pretrained(PATH)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_images/example2.jpg"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "图中的文物属于哪个年代"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
images, videos = [], []
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
if not isinstance(content, list):
|
||||||
|
continue
|
||||||
|
for part in content:
|
||||||
|
if part["type"] == "image_url":
|
||||||
|
url = part["image_url"]["url"]
|
||||||
|
image_bytes = requests.get(url).content
|
||||||
|
img = Image.open(io.BytesIO(image_bytes))
|
||||||
|
images.append(img)
|
||||||
|
elif part["type"] == "video_url":
|
||||||
|
url = part["video_url"]["url"]
|
||||||
|
video_bytes = requests.get(url).content
|
||||||
|
videos.append({"video": video_bytes, "max_frames": 30})
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=0.1, max_tokens=128)
|
||||||
|
graph_optimization_config = {"use_cudagraph": False}
|
||||||
|
llm = LLM(
|
||||||
|
model=PATH,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
max_model_len=32768,
|
||||||
|
block_size=16,
|
||||||
|
quantization="wint4",
|
||||||
|
limit_mm_per_prompt={"image": 100},
|
||||||
|
reasoning_parser="ernie-45-vl",
|
||||||
|
graph_optimization_config=graph_optimization_config,
|
||||||
|
)
|
||||||
|
outputs = llm.generate(
|
||||||
|
prompts={"prompt": prompt, "multimodal_data": {"image": images, "video": videos}},
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
for output in outputs:
|
||||||
|
generated_text = output.outputs.text
|
||||||
|
print(f"generated_text={generated_text}")
|
||||||
|
assert any(keyword in generated_text for keyword in ["北魏", "北齐", "释迦牟尼", "北朝"])
|
||||||
|
|
||||||
|
print("PASSED")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
result = offline_infer_check()
|
||||||
|
sys.exit(0)
|
||||||
|
except TimeoutError:
|
||||||
|
sys.exit(124)
|
||||||
|
except Exception:
|
||||||
|
sys.exit(1)
|
||||||
Reference in New Issue
Block a user