add more ppdet models

This commit is contained in:
jiangjiajun
2022-08-03 08:15:23 +00:00
parent 379c58cae3
commit 56bdbe5b0d
23 changed files with 883 additions and 63 deletions
@@ -0,0 +1,141 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/common/processors/pad_to_size.h"
namespace fastdeploy {
namespace vision {
bool PadToSize::CpuRun(Mat* mat) {
  // Pads the image at the bottom/right (content stays anchored at the
  // left-top corner) so it becomes exactly width_ x height_.
  if (mat->layout != Layout::HWC) {
    FDERROR << "PadToSize: The input data must be Layout::HWC format!"
            << std::endl;
    return false;
  }
  if (mat->Channels() > 4) {
    FDERROR << "PadToSize: Only support channels <= 4." << std::endl;
    return false;
  }
  if (mat->Channels() != value_.size()) {
    FDERROR
        << "PadToSize: Require input channels equals to size of padding value, "
           "but now channels = "
        << mat->Channels() << ", the size of padding values = " << value_.size()
        << "." << std::endl;
    return false;
  }
  const int w = mat->Width();
  const int h = mat->Height();
  if (w > width_) {
    FDERROR << "PadToSize: the input width:" << w
            << " is greater than the target width: " << width_ << "."
            << std::endl;
    return false;
  }
  if (h > height_) {
    FDERROR << "PadToSize: the input height:" << h
            << " is greater than the target height: " << height_ << "."
            << std::endl;
    return false;
  }
  // Already the target size: nothing to do.
  if (w == width_ && h == height_) {
    return true;
  }
  // One padding value per channel; unset components stay zero, matching the
  // behavior of the cv::Scalar single/double/triple-argument constructors.
  cv::Scalar border_value;
  for (size_t i = 0; i < value_.size(); ++i) {
    border_value[i] = value_[i];
  }
  cv::Mat* im = mat->GetCpuMat();
  // top, bottom, left, right
  cv::copyMakeBorder(*im, *im, 0, height_ - h, 0, width_ - w,
                     cv::BORDER_CONSTANT, border_value);
  mat->SetHeight(height_);
  mat->SetWidth(width_);
  return true;
}
#ifdef ENABLE_OPENCV_CUDA
bool PadToSize::GpuRun(Mat* mat) {
  // GPU counterpart of CpuRun: bottom/right constant-border padding until the
  // image is exactly width_ x height_.
  if (mat->layout != Layout::HWC) {
    FDERROR << "PadToSize: The input data must be Layout::HWC format!"
            << std::endl;
    return false;
  }
  if (mat->Channels() > 4) {
    FDERROR << "PadToSize: Only support channels <= 4." << std::endl;
    return false;
  }
  if (mat->Channels() != value_.size()) {
    FDERROR
        << "PadToSize: Require input channels equals to size of padding value, "
           "but now channels = "
        << mat->Channels() << ", the size of padding values = " << value_.size()
        << "." << std::endl;
    return false;
  }
  const int w = mat->Width();
  const int h = mat->Height();
  if (w > width_) {
    FDERROR << "PadToSize: the input width:" << w
            << " is greater than the target width: " << width_ << "."
            << std::endl;
    return false;
  }
  if (h > height_) {
    FDERROR << "PadToSize: the input height:" << h
            << " is greater than the target height: " << height_ << "."
            << std::endl;
    return false;
  }
  // Already the target size: nothing to do.
  if (w == width_ && h == height_) {
    return true;
  }
  // One padding value per channel; unset components stay zero.
  cv::Scalar border_value;
  for (size_t i = 0; i < value_.size(); ++i) {
    border_value[i] = value_[i];
  }
  cv::cuda::GpuMat* im = mat->GetGpuMat();
  // top, bottom, left, right
  cv::cuda::copyMakeBorder(*im, *im, 0, height_ - h, 0, width_ - w,
                           cv::BORDER_CONSTANT, border_value);
  mat->SetHeight(height_);
  mat->SetWidth(width_);
  return true;
}
#endif
bool PadToSize::Run(Mat* mat, int width, int height,
                    const std::vector<float>& value, ProcLib lib) {
  // Convenience wrapper: build a one-shot processor and invoke it with the
  // requested processing library.
  PadToSize op(width, height, value);
  return op(mat, lib);
}
} // namespace vision
} // namespace fastdeploy
@@ -0,0 +1,46 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/base.h"
namespace fastdeploy {
namespace vision {
// Pads an image at the bottom/right (content anchored at the left-top
// corner) until it reaches exactly `width` x `height`.
class PadToSize : public Processor {
 public:
  // `value` holds one constant padding value per channel and must match the
  // channel count of the image being processed.
  PadToSize(int width, int height, const std::vector<float>& value)
      : width_(width), height_(height), value_(value) {}

  bool CpuRun(Mat* mat);
#ifdef ENABLE_OPENCV_CUDA
  bool GpuRun(Mat* mat);
#endif
  std::string Name() { return "PadToSize"; }

  // One-shot helper: pad `mat` to (width, height) using backend `lib`.
  static bool Run(Mat* mat, int width, int height,
                  const std::vector<float>& value,
                  ProcLib lib = ProcLib::OPENCV_CPU);

 private:
  int width_;
  int height_;
  std::vector<float> value_;
};
} // namespace vision
} // namespace fastdeploy
@@ -0,0 +1,124 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/common/processors/stride_pad.h"
namespace fastdeploy {
namespace vision {
bool StridePad::CpuRun(Mat* mat) {
  // Pads at the bottom/right so height and width become multiples of
  // stride_ (content stays anchored at the left-top corner).
  if (mat->layout != Layout::HWC) {
    FDERROR << "StridePad: The input data must be Layout::HWC format!"
            << std::endl;
    return false;
  }
  if (mat->Channels() > 4) {
    FDERROR << "StridePad: Only support channels <= 4." << std::endl;
    return false;
  }
  if (mat->Channels() != value_.size()) {
    FDERROR
        << "StridePad: Require input channels equals to size of padding value, "
           "but now channels = "
        << mat->Channels() << ", the size of padding values = " << value_.size()
        << "." << std::endl;
    return false;
  }
  const int w = mat->Width();
  const int h = mat->Height();
  // Distance from each side to the next multiple of stride_ (0 when the side
  // is already aligned) — equivalent to round-up-to-multiple minus size.
  const int pad_h = (stride_ - h % stride_) % stride_;
  const int pad_w = (stride_ - w % stride_) % stride_;
  if (pad_h == 0 && pad_w == 0) {
    return true;
  }
  // One padding value per channel; unset components stay zero.
  cv::Scalar border_value;
  for (size_t i = 0; i < value_.size(); ++i) {
    border_value[i] = value_[i];
  }
  cv::Mat* im = mat->GetCpuMat();
  // top, bottom, left, right
  cv::copyMakeBorder(*im, *im, 0, pad_h, 0, pad_w, cv::BORDER_CONSTANT,
                     border_value);
  mat->SetHeight(h + pad_h);
  mat->SetWidth(w + pad_w);
  return true;
}
#ifdef ENABLE_OPENCV_CUDA
bool StridePad::GpuRun(Mat* mat) {
  // GPU counterpart of CpuRun: pads bottom/right so height and width become
  // multiples of stride_ (content anchored at the left-top corner).
  if (mat->layout != Layout::HWC) {
    FDERROR << "StridePad: The input data must be Layout::HWC format!"
            << std::endl;
    return false;
  }
  if (mat->Channels() > 4) {
    FDERROR << "StridePad: Only support channels <= 4." << std::endl;
    return false;
  }
  if (mat->Channels() != value_.size()) {
    FDERROR
        << "StridePad: Require input channels equals to size of padding value, "
           "but now channels = "
        << mat->Channels() << ", the size of padding values = " << value_.size()
        << "." << std::endl;
    return false;
  }
  int origin_w = mat->Width();
  int origin_h = mat->Height();
  // BUGFIX: the GPU path forgot to subtract the current size, so pad_h/pad_w
  // were the full rounded-up dimensions instead of the padding amounts,
  // nearly doubling the image. Compute the remainder to the next multiple of
  // stride_, exactly as CpuRun does.
  int pad_h = (stride_ - origin_h % stride_) % stride_;
  int pad_w = (stride_ - origin_w % stride_) % stride_;
  if (pad_h == 0 && pad_w == 0) {
    return true;
  }
  cv::cuda::GpuMat* im = mat->GetGpuMat();
  cv::Scalar value;
  if (value_.size() == 1) {
    value = cv::Scalar(value_[0]);
  } else if (value_.size() == 2) {
    value = cv::Scalar(value_[0], value_[1]);
  } else if (value_.size() == 3) {
    value = cv::Scalar(value_[0], value_[1], value_[2]);
  } else {
    value = cv::Scalar(value_[0], value_[1], value_[2], value_[3]);
  }
  // top, bottom, left, right
  cv::cuda::copyMakeBorder(*im, *im, 0, pad_h, 0, pad_w, cv::BORDER_CONSTANT,
                           value);
  mat->SetHeight(origin_h + pad_h);
  mat->SetWidth(origin_w + pad_w);
  return true;
}
#endif
bool StridePad::Run(Mat* mat, int stride, const std::vector<float>& value,
                    ProcLib lib) {
  // Convenience wrapper: build a one-shot processor and invoke it with the
  // requested processing library.
  StridePad op(stride, value);
  return op(mat, lib);
}
} // namespace vision
} // namespace fastdeploy
@@ -0,0 +1,44 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/base.h"
namespace fastdeploy {
namespace vision {
// Pads an image at the bottom/right (content anchored at the left-top
// corner) until both sides are multiples of `stride`.
class StridePad : public Processor {
 public:
  // `value` holds one constant padding value per channel and must match the
  // channel count of the image being processed.
  StridePad(int stride, const std::vector<float>& value)
      : stride_(stride), value_(value) {}

  bool CpuRun(Mat* mat);
#ifdef ENABLE_OPENCV_CUDA
  bool GpuRun(Mat* mat);
#endif
  std::string Name() { return "StridePad"; }

  // One-shot helper: pad `mat` so its sides are multiples of `stride`.
  static bool Run(Mat* mat, int stride,
                  const std::vector<float>& value = std::vector<float>(),
                  ProcLib lib = ProcLib::OPENCV_CPU);

 private:
  int stride_ = 32;
  std::vector<float> value_;
};
} // namespace vision
} // namespace fastdeploy
@@ -21,5 +21,7 @@
#include "fastdeploy/vision/common/processors/hwc2chw.h"
#include "fastdeploy/vision/common/processors/normalize.h"
#include "fastdeploy/vision/common/processors/pad.h"
#include "fastdeploy/vision/common/processors/pad_to_size.h"
#include "fastdeploy/vision/common/processors/resize.h"
#include "fastdeploy/vision/common/processors/resize_by_short.h"
#include "fastdeploy/vision/common/processors/stride_pad.h"
@@ -0,0 +1,78 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ppdet/ppyoloe.h"
// Builds the image preprocess pipeline described by a PaddleDetection
// exported YAML configuration file (infer_cfg.yml).
//
// processors:  output list, cleared and refilled on every call.
// config_file: path to the exported YAML configuration.
// Returns false when the file cannot be loaded or contains an unknown
// preprocess operator.
bool BuildPreprocessPipelineFromConfig(
    std::vector<std::shared_ptr<Processor>>* processors,
    const std::string& config_file) {
  processors->clear();
  YAML::Node cfg;
  try {
    cfg = YAML::LoadFile(config_file);
  } catch (YAML::BadFile& e) {
    // BUGFIX: this free function has no `config_file_` member; report the
    // actual `config_file` parameter.
    FDERROR << "Failed to load yaml file " << config_file
            << ", maybe you should check this file." << std::endl;
    return false;
  }
  // Models are trained on RGB input while OpenCV decodes to BGR.
  processors->push_back(std::make_shared<BGR2RGB>());
  for (const auto& op : cfg["Preprocess"]) {
    std::string op_name = op["type"].as<std::string>();
    if (op_name == "NormalizeImage") {
      auto mean = op["mean"].as<std::vector<float>>();
      auto std = op["std"].as<std::vector<float>>();
      bool is_scale = op["is_scale"].as<bool>();
      processors->push_back(std::make_shared<Normalize>(mean, std, is_scale));
    } else if (op_name == "Resize") {
      bool keep_ratio = op["keep_ratio"].as<bool>();
      auto target_size = op["target_size"].as<std::vector<int>>();
      int interp = op["interp"].as<int>();
      // BUGFIX: the original assertion only checked that target_size was
      // non-empty; it must contain exactly (height, width).
      FDASSERT(target_size.size() == 2,
               "Require size of target_size be 2, but now it's " +
                   std::to_string(target_size.size()) + ".");
      if (!keep_ratio) {
        int width = target_size[1];
        int height = target_size[0];
        processors->push_back(
            std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
      } else {
        // Keep aspect ratio: scale the short side to the smaller target while
        // capping the long side at the larger one.
        int min_target_size = std::min(target_size[0], target_size[1]);
        int max_target_size = std::max(target_size[0], target_size[1]);
        processors->push_back(std::make_shared<ResizeByShort>(
            min_target_size, interp, true, max_target_size));
      }
    } else if (op_name == "Permute") {
      // Do nothing, do permute as the last operation
      continue;
    } else if (op_name == "Pad") {
      auto size = op["size"].as<std::vector<int>>();
      auto value = op["fill_value"].as<std::vector<float>>();
      processors->push_back(std::make_shared<Cast>("float"));
      processors->push_back(
          std::make_shared<PadToSize>(size[1], size[0], value));
    } else if (op_name == "PadStride") {
      auto stride = op["stride"].as<int>();
      processors->push_back(
          std::make_shared<StridePad>(stride, std::vector<float>(3, 0)));
    } else {
      FDERROR << "Unexcepted preprocess operator: " << op_name << "."
              << std::endl;
      return false;
    }
  }
  // HWC -> CHW always runs last, standing in for the YAML "Permute" op.
  processors->push_back(std::make_shared<HWC2CHW>());
  return true;
}
@@ -1,25 +0,0 @@
#include "fastdeploy/vision/ppdet/centernet.h"
namespace fastdeploy {
namespace vision {
namespace ppdet {
// Constructs a CenterNet detector from an exported PaddleDetection model.
//
// model_file / params_file: paths to the inference model and its weights.
// config_file: the exported YAML file describing the preprocess pipeline.
// custom_option: runtime backend options (copied, then overridden below).
// model_format: format of the model files (Paddle by default in the header).
CenterNet::CenterNet(const std::string& model_file,
                     const std::string& params_file,
                     const std::string& config_file,
                     const RuntimeOption& custom_option,
                     const Frontend& model_format) {
  config_file_ = config_file;
  // Only the Paddle Inference backend is listed as valid for this model.
  valid_cpu_backends = {Backend::PDINFER};
  valid_gpu_backends = {Backend::PDINFER};
  // NOTE(review): has_nms_ = true presumably means the exported graph already
  // performs NMS — confirm against the PPYOLO postprocess implementation.
  has_nms_ = true;
  runtime_option = custom_option;
  runtime_option.model_format = model_format;
  runtime_option.model_file = model_file;
  runtime_option.params_file = params_file;
  initialized = Initialize();
}
} // namespace ppdet
} // namespace vision
} // namespace fastdeploy
-19
View File
@@ -1,19 +0,0 @@
#pragma once
#include "fastdeploy/vision/ppdet/ppyolo.h"
namespace fastdeploy {
namespace vision {
namespace ppdet {
// CenterNet detector; inherits PPYOLO's pre/post-processing behavior and
// only customizes construction and the reported model name.
class FASTDEPLOY_DECL CenterNet : public PPYOLO {
 public:
  CenterNet(const std::string& model_file, const std::string& params_file,
            const std::string& config_file,
            const RuntimeOption& custom_option = RuntimeOption(),
            const Frontend& model_format = Frontend::PADDLE);
  // Name reported in logs and error messages.
  virtual std::string ModelName() const { return "PaddleDetection/CenterNet"; }
};
} // namespace ppdet
} // namespace vision
} // namespace fastdeploy
+16
View File
@@ -1,6 +1,22 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/ppdet/centernet.h"
#include "fastdeploy/vision/ppdet/picodet.h"
#include "fastdeploy/vision/ppdet/ppyolo.h"
#include "fastdeploy/vision/ppdet/ppyoloe.h"
#include "fastdeploy/vision/ppdet/rcnn.h"
#include "fastdeploy/vision/ppdet/yolov3.h"
#include "fastdeploy/vision/ppdet/yolox.h"
+14
View File
@@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ppdet/picodet.h"
#include "yaml-cpp/yaml.h"
+14
View File
@@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/ppdet/ppyoloe.h"
@@ -27,5 +27,21 @@ void BindPPDet(pybind11::module& m) {
self.Predict(&mat, &res);
return res;
});
pybind11::class_<vision::ppdet::PPYOLO, vision::ppdet::PPYOLOE>(ppdet_module,
"PPYOLO")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
Frontend>());
pybind11::class_<vision::ppdet::PicoDet, vision::ppdet::PPYOLOE>(ppdet_module,
"PicoDet")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
Frontend>());
pybind11::class_<vision::ppdet::YOLOX, vision::ppdet::PPYOLOE>(ppdet_module,
"YOLOX")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
Frontend>());
pybind11::class_<vision::ppdet::FasterRCNN, vision::ppdet::PPYOLOE>(
ppdet_module, "FasterRCNN")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
Frontend>());
}
} // namespace fastdeploy
+14 -4
View File
@@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ppdet/ppyolo.h"
namespace fastdeploy {
@@ -48,20 +62,16 @@ bool PPYOLO::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
outputs->resize(3);
(*outputs)[0].Allocate({1, 2}, FDDataType::FP32, "im_shape");
(*outputs)[2].Allocate({1, 2}, FDDataType::FP32, "scale_factor");
std::cout << "111111111" << std::endl;
float* ptr0 = static_cast<float*>((*outputs)[0].MutableData());
ptr0[0] = mat->Height();
ptr0[1] = mat->Width();
std::cout << "090909" << std::endl;
float* ptr2 = static_cast<float*>((*outputs)[2].MutableData());
ptr2[0] = mat->Height() * 1.0 / origin_h;
ptr2[1] = mat->Width() * 1.0 / origin_w;
std::cout << "88888" << std::endl;
(*outputs)[1].name = "image";
mat->ShareWithTensor(&((*outputs)[1]));
// reshape to [1, c, h, w]
(*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1);
std::cout << "??????" << std::endl;
return true;
}
+25 -9
View File
@@ -101,21 +101,38 @@ bool PPYOLOE::BuildPreprocessPipelineFromConfig() {
FDASSERT(target_size.size(),
"Require size of target_size be 2, but now it's " +
std::to_string(target_size.size()) + ".");
FDASSERT(!keep_ratio,
"Only support keep_ratio is false while deploy "
"PaddleDetection model.");
int width = target_size[1];
int height = target_size[0];
processors_.push_back(
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
if (!keep_ratio) {
int width = target_size[1];
int height = target_size[0];
processors_.push_back(
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
} else {
int min_target_size = std::min(target_size[0], target_size[1]);
int max_target_size = std::max(target_size[0], target_size[1]);
processors_.push_back(std::make_shared<ResizeByShort>(
min_target_size, interp, true, max_target_size));
}
} else if (op_name == "Permute") {
processors_.push_back(std::make_shared<HWC2CHW>());
// Do nothing, do permute as the last operation
continue;
// processors_.push_back(std::make_shared<HWC2CHW>());
} else if (op_name == "Pad") {
auto size = op["size"].as<std::vector<int>>();
auto value = op["fill_value"].as<std::vector<float>>();
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(
std::make_shared<PadToSize>(size[1], size[0], value));
} else if (op_name == "PadStride") {
auto stride = op["stride"].as<int>();
processors_.push_back(
std::make_shared<StridePad>(stride, std::vector<float>(3, 0)));
} else {
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
<< std::endl;
return false;
}
}
processors_.push_back(std::make_shared<HWC2CHW>());
return true;
}
@@ -224,7 +241,6 @@ bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result) {
processed_data[0].PrintInfo("Before infer");
float* tmp = static_cast<float*>(processed_data[1].Data());
std::cout << "==== " << tmp[0] << " " << tmp[1] << std::endl;
std::vector<FDTensor> infer_result;
if (!Infer(processed_data, &infer_result)) {
FDERROR << "Failed to inference while using model:" << ModelName() << "."
+19
View File
@@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
@@ -47,6 +61,11 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel {
bool normalized = true;
bool has_nms_ = false;
};
// Read configuration and build pipeline to process input image
bool BuildPreprocessPipelineFromConfig(
std::vector<std::shared_ptr<Processor>>* processors,
const std::string& config_file);
} // namespace ppdet
} // namespace vision
} // namespace fastdeploy
+89
View File
@@ -0,0 +1,89 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ppdet/rcnn.h"
namespace fastdeploy {
namespace vision {
namespace ppdet {
// Constructs a FasterRCNN detector from an exported PaddleDetection model.
//
// model_file / params_file: paths to the inference model and its weights.
// config_file: the exported YAML file describing the preprocess pipeline.
// custom_option: runtime backend options (copied, then overridden below).
// model_format: format of the model files (Paddle by default in the header).
FasterRCNN::FasterRCNN(const std::string& model_file,
                       const std::string& params_file,
                       const std::string& config_file,
                       const RuntimeOption& custom_option,
                       const Frontend& model_format) {
  config_file_ = config_file;
  // Only the Paddle Inference backend is listed as valid for this model.
  valid_cpu_backends = {Backend::PDINFER};
  valid_gpu_backends = {Backend::PDINFER};
  // NOTE(review): has_nms_ = true presumably means the exported graph already
  // performs NMS — confirm against the PPYOLOE postprocess implementation.
  has_nms_ = true;
  runtime_option = custom_option;
  runtime_option.model_format = model_format;
  runtime_option.model_file = model_file;
  runtime_option.params_file = params_file;
  initialized = Initialize();
}
bool FasterRCNN::Initialize() {
  // The preprocess pipeline must be ready before the backend is created;
  // each failure is reported and turns into a false return.
  if (!BuildPreprocessPipelineFromConfig()) {
    FDERROR << "Failed to build preprocess pipeline from configuration file."
            << std::endl;
  } else if (!InitRuntime()) {
    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
  } else {
    return true;
  }
  return false;
}
// Runs the configured preprocess pipeline on `mat` and packs the three
// inputs expected by PaddleDetection RCNN-style models:
//   [0] "im_shape"     - processed image (height, width) as FP32
//   [1] "image"        - the processed image data (shared, not copied)
//   [2] "scale_factor" - processed size / original size for (h, w)
// Returns false when any processor in the pipeline fails.
bool FasterRCNN::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
  int origin_w = mat->Width();
  int origin_h = mat->Height();
  float scale[2] = {1.0, 1.0};
  for (size_t i = 0; i < processors_.size(); ++i) {
    if (!(*(processors_[i].get()))(mat)) {
      FDERROR << "Failed to process image data in " << processors_[i]->Name()
              << "." << std::endl;
      return false;
    }
    // Only resize operators change the scale factor; later padding steps
    // must not be folded into it.
    if (processors_[i]->Name().find("Resize") != std::string::npos) {
      scale[0] = mat->Height() * 1.0 / origin_h;
      scale[1] = mat->Width() * 1.0 / origin_w;
    }
  }
  // BUGFIX: removed the leftover PrintInfo debug dumps that spammed every
  // prediction (same class of debug output stripped elsewhere in this file).
  outputs->resize(3);
  (*outputs)[0].Allocate({1, 2}, FDDataType::FP32, "im_shape");
  (*outputs)[2].Allocate({1, 2}, FDDataType::FP32, "scale_factor");
  float* ptr0 = static_cast<float*>((*outputs)[0].MutableData());
  ptr0[0] = mat->Height();
  ptr0[1] = mat->Width();
  float* ptr2 = static_cast<float*>((*outputs)[2].MutableData());
  ptr2[0] = scale[0];
  ptr2[1] = scale[1];
  (*outputs)[1].name = "image";
  mat->ShareWithTensor(&((*outputs)[1]));
  // reshape to [1, c, h, w]
  (*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1);
  return true;
}
} // namespace ppdet
} // namespace vision
} // namespace fastdeploy
+39
View File
@@ -0,0 +1,39 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/ppdet/ppyoloe.h"
namespace fastdeploy {
namespace vision {
namespace ppdet {
// FasterRCNN detector; shares PPYOLOE's pipeline machinery but overrides
// preprocessing (RCNN models consume an extra "im_shape" input tensor).
class FASTDEPLOY_DECL FasterRCNN : public PPYOLOE {
 public:
  FasterRCNN(const std::string& model_file, const std::string& params_file,
             const std::string& config_file,
             const RuntimeOption& custom_option = RuntimeOption(),
             const Frontend& model_format = Frontend::PADDLE);
  // Name reported in logs and error messages.
  virtual std::string ModelName() const { return "PaddleDetection/FasterRCNN"; }
  // Builds the "im_shape"/"image"/"scale_factor" input tensors.
  virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
  // Builds the preprocess pipeline, then initializes the runtime backend.
  virtual bool Initialize();

 protected:
  // For subclasses that configure their own options before Initialize().
  FasterRCNN() {}
};
} // namespace ppdet
} // namespace vision
} // namespace fastdeploy
+14 -4
View File
@@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ppdet/yolov3.h"
namespace fastdeploy {
@@ -34,20 +48,16 @@ bool YOLOv3::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
outputs->resize(3);
(*outputs)[0].Allocate({1, 2}, FDDataType::FP32, "im_shape");
(*outputs)[2].Allocate({1, 2}, FDDataType::FP32, "scale_factor");
std::cout << "111111111" << std::endl;
float* ptr0 = static_cast<float*>((*outputs)[0].MutableData());
ptr0[0] = mat->Height();
ptr0[1] = mat->Width();
std::cout << "090909" << std::endl;
float* ptr2 = static_cast<float*>((*outputs)[2].MutableData());
ptr2[0] = mat->Height() * 1.0 / origin_h;
ptr2[1] = mat->Width() * 1.0 / origin_w;
std::cout << "88888" << std::endl;
(*outputs)[1].name = "image";
mat->ShareWithTensor(&((*outputs)[1]));
// reshape to [1, c, h, w]
(*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1);
std::cout << "??????" << std::endl;
return true;
}
+14
View File
@@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/ppdet/ppyoloe.h"
+74
View File
@@ -0,0 +1,74 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ppdet/yolox.h"
namespace fastdeploy {
namespace vision {
namespace ppdet {
// Constructs a YOLOX detector from an exported PaddleDetection model.
//
// model_file / params_file: paths to the inference model and its weights.
// config_file: the exported YAML file describing the preprocess pipeline.
// custom_option: runtime backend options (copied, then overridden below).
// model_format: format of the model files (Paddle by default in the header).
YOLOX::YOLOX(const std::string& model_file, const std::string& params_file,
             const std::string& config_file, const RuntimeOption& custom_option,
             const Frontend& model_format) {
  config_file_ = config_file;
  valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
  valid_gpu_backends = {Backend::PDINFER, Backend::ORT};
  runtime_option = custom_option;
  runtime_option.model_format = model_format;
  runtime_option.model_file = model_file;
  runtime_option.params_file = params_file;
  // NMS parameters — presumably consumed by the inherited PPYOLOE
  // postprocess since the exported YOLOX graph lacks an NMS op; confirm
  // against PPYOLOE's postprocess implementation.
  background_label = -1;
  keep_top_k = 1000;
  nms_eta = 1;
  nms_threshold = 0.65;
  nms_top_k = 10000;
  normalized = true;
  score_threshold = 0.001;
  initialized = Initialize();
}
// Runs the configured preprocess pipeline on `mat` and packs the two inputs
// expected by the exported YOLOX model:
//   [0] image tensor (shared with `mat`, reshaped to [1, c, h, w])
//   [1] (h, w) scale factor = processed size / original size, FP32
// Returns false when any processor in the pipeline fails.
bool YOLOX::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
  int origin_w = mat->Width();
  int origin_h = mat->Height();
  float scale[2] = {1.0, 1.0};
  for (size_t i = 0; i < processors_.size(); ++i) {
    if (!(*(processors_[i].get()))(mat)) {
      FDERROR << "Failed to process image data in " << processors_[i]->Name()
              << "." << std::endl;
      return false;
    }
    // Only resize operators change the scale factor; later padding steps
    // must not be folded into it.
    if (processors_[i]->Name().find("Resize") != std::string::npos) {
      scale[0] = mat->Height() * 1.0 / origin_h;
      scale[1] = mat->Width() * 1.0 / origin_w;
    }
  }
  // BUGFIX: removed the leftover PrintInfo debug dumps that spammed every
  // prediction (same class of debug output stripped elsewhere in this file).
  outputs->resize(2);
  (*outputs)[0].name = InputInfoOfRuntime(0).name;
  mat->ShareWithTensor(&((*outputs)[0]));
  // reshape to [1, c, h, w]
  (*outputs)[0].shape.insert((*outputs)[0].shape.begin(), 1);
  (*outputs)[1].Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(1).name);
  float* ptr = static_cast<float*>((*outputs)[1].MutableData());
  ptr[0] = scale[0];
  ptr[1] = scale[1];
  return true;
}
} // namespace ppdet
} // namespace vision
} // namespace fastdeploy
+35
View File
@@ -0,0 +1,35 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/ppdet/ppyoloe.h"
namespace fastdeploy {
namespace vision {
namespace ppdet {
// YOLOX detector exported from PaddleDetection; reuses PPYOLOE's pipeline
// machinery but overrides preprocessing (the exported model takes an extra
// scale-factor input instead of "im_shape").
class FASTDEPLOY_DECL YOLOX : public PPYOLOE {
 public:
  YOLOX(const std::string& model_file, const std::string& params_file,
        const std::string& config_file,
        const RuntimeOption& custom_option = RuntimeOption(),
        const Frontend& model_format = Frontend::PADDLE);
  // Builds the image and scale-factor input tensors.
  virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
  // Name reported in logs and error messages.
  virtual std::string ModelName() const { return "PaddleDetection/YOLOX"; }
};
} // namespace ppdet
} // namespace vision
} // namespace fastdeploy
@@ -328,7 +328,6 @@ bool YOLOv5Lite::Predict(cv::Mat* im, DetectionResult* result,
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_START(0)
#endif
std::cout << nms_iou_threshold << nms_iou_threshold << std::endl;
Mat mat(*im);
std::vector<FDTensor> input_tensors(1);
+65 -1
View File
@@ -27,7 +27,7 @@ class PPYOLOE(FastDeployModel):
model_format=Frontend.PADDLE):
super(PPYOLOE, self).__init__(runtime_option)
assert model_format == Frontend.PADDLE, "PPYOLOE only support model format of Frontend.Paddle now."
assert model_format == Frontend.PADDLE, "PPYOLOE model only support model format of Frontend.Paddle now."
self._model = C.vision.ppdet.PPYOLOE(model_file, params_file,
config_file, self._runtime_option,
model_format)
@@ -36,3 +36,67 @@ class PPYOLOE(FastDeployModel):
def predict(self, input_image):
assert input_image is not None, "The input image data is None."
return self._model.predict(input_image)
class PPYOLO(PPYOLOE):
    def __init__(self,
                 model_file,
                 params_file,
                 config_file,
                 runtime_option=None,
                 model_format=Frontend.PADDLE):
        """Load an exported PaddleDetection PPYOLO model for deployment.

        :param model_file: path to the inference model file
        :param params_file: path to the model weights file
        :param config_file: path to the exported preprocess YAML config
        :param runtime_option: optional RuntimeOption for backend selection
        :param model_format: model format, only Frontend.PADDLE is supported
        """
        # BUGFIX: super(PPYOLO, self).__init__(runtime_option) would invoke
        # PPYOLOE.__init__, which requires model/params/config paths and would
        # raise TypeError. Skip PPYOLOE in the MRO and initialize the shared
        # FastDeployModel state directly.
        super(PPYOLOE, self).__init__(runtime_option)

        assert model_format == Frontend.PADDLE, "PPYOLO model only support model format of Frontend.Paddle now."
        self._model = C.vision.ppdet.PPYOLO(model_file, params_file,
                                            config_file, self._runtime_option,
                                            model_format)
        assert self.initialized, "PPYOLO model initialize failed."
class YOLOX(PPYOLOE):
    def __init__(self,
                 model_file,
                 params_file,
                 config_file,
                 runtime_option=None,
                 model_format=Frontend.PADDLE):
        """Load an exported PaddleDetection YOLOX model for deployment.

        :param model_file: path to the inference model file
        :param params_file: path to the model weights file
        :param config_file: path to the exported preprocess YAML config
        :param runtime_option: optional RuntimeOption for backend selection
        :param model_format: model format, only Frontend.PADDLE is supported
        """
        # BUGFIX: super(YOLOX, self).__init__(runtime_option) would invoke
        # PPYOLOE.__init__, which requires model/params/config paths and would
        # raise TypeError. Skip PPYOLOE in the MRO and initialize the shared
        # FastDeployModel state directly.
        super(PPYOLOE, self).__init__(runtime_option)

        assert model_format == Frontend.PADDLE, "YOLOX model only support model format of Frontend.Paddle now."
        self._model = C.vision.ppdet.YOLOX(model_file, params_file,
                                           config_file, self._runtime_option,
                                           model_format)
        assert self.initialized, "YOLOX model initialize failed."
class PicoDet(PPYOLOE):
    def __init__(self,
                 model_file,
                 params_file,
                 config_file,
                 runtime_option=None,
                 model_format=Frontend.PADDLE):
        """Load an exported PaddleDetection PicoDet model for deployment.

        :param model_file: path to the inference model file
        :param params_file: path to the model weights file
        :param config_file: path to the exported preprocess YAML config
        :param runtime_option: optional RuntimeOption for backend selection
        :param model_format: model format, only Frontend.PADDLE is supported
        """
        # BUGFIX: super(PicoDet, self).__init__(runtime_option) would invoke
        # PPYOLOE.__init__, which requires model/params/config paths and would
        # raise TypeError. Skip PPYOLOE in the MRO and initialize the shared
        # FastDeployModel state directly.
        super(PPYOLOE, self).__init__(runtime_option)

        assert model_format == Frontend.PADDLE, "PicoDet model only support model format of Frontend.Paddle now."
        self._model = C.vision.ppdet.PicoDet(model_file, params_file,
                                             config_file, self._runtime_option,
                                             model_format)
        assert self.initialized, "PicoDet model initialize failed."
class FasterRCNN(PPYOLOE):
    def __init__(self,
                 model_file,
                 params_file,
                 config_file,
                 runtime_option=None,
                 model_format=Frontend.PADDLE):
        """Load an exported PaddleDetection FasterRCNN model for deployment.

        :param model_file: path to the inference model file
        :param params_file: path to the model weights file
        :param config_file: path to the exported preprocess YAML config
        :param runtime_option: optional RuntimeOption for backend selection
        :param model_format: model format, only Frontend.PADDLE is supported
        """
        # BUGFIX: super(FasterRCNN, self).__init__(runtime_option) would
        # invoke PPYOLOE.__init__, which requires model/params/config paths
        # and would raise TypeError. Skip PPYOLOE in the MRO and initialize
        # the shared FastDeployModel state directly.
        super(PPYOLOE, self).__init__(runtime_option)

        assert model_format == Frontend.PADDLE, "FasterRCNN model only support model format of Frontend.Paddle now."
        self._model = C.vision.ppdet.FasterRCNN(
            model_file, params_file, config_file, self._runtime_option,
            model_format)
        assert self.initialized, "FasterRCNN model initialize failed."