mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-05-08 08:23:25 +08:00
43cabceb6d
* [Backend] TRT backend & PP-Infer backend support pinned memory (#403) * TRT backend use pinned memory * refine fd tensor pinned memory logic * TRT enable pinned memory configurable * paddle inference support pinned memory * pinned memory pybindings Co-authored-by: Jason <jiangjiajun@baidu.com> * [Bug Fix] release task scripts (#411) * Update py_run.bat * Update cpp_run.bat * Update compare_with_gt.py Increase score_diff and boxes_diff_ratio threshold * Update cpp_run.bat * Update release task scripts according to diffrent platforms * Delete CMAKE_CXX_COMPILER in cpp_run.bat * [Doc] add contributor for js application (#413) add contributor * [Other] Refactor js submodule (#415) * Refactor js submodule * Remove change-log * Update ocr module * Update ocr-detection module * Update ocr-detection module * Remove change-log * [Doc] Add PicoDet & PaddleClas Android demo docs (#412) * [Backend] Add override flag to lite backend * [Docs] Add Android C++ SDK build docs * [Doc] fix android_build_docs typos * Update CMakeLists.txt * Update android.md * [Doc] Add PicoDet Android demo docs * [Doc] Update PicoDet Andorid demo docs * [Doc] Update PaddleClasModel Android demo docs * [Doc] Update fastdeploy android jni docs * [Doc] Update fastdeploy android jni usage docs Co-authored-by: Jason <jiangjiajun@baidu.com> * Update README.md * Update README_CN.md * Update README_CN.md * Update README_EN.md * [Doc] Add tutorial of supporting new models (#418) * first commit for yolov7 * pybind for yolov7 * CPP README.md * CPP README.md * modified yolov7.cc * README.md * python file modify * delete license in fastdeploy/ * repush the conflict part * README.md modified * README.md modified * file path modified * file path modified * file path modified * file path modified * file path modified * README modified * README modified * move some helpers to private * add examples for yolov7 * api.md modified * api.md modified * api.md modified * YOLOv7 * yolov7 release link * yolov7 release link 
* yolov7 release link * copyright * change some helpers to private * change variables to const and fix documents. * gitignore * Transfer some funtions to private member of class * Transfer some funtions to private member of class * Merge from develop (#9) * Fix compile problem in different python version (#26) * fix some usage problem in linux * Fix compile problem Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com> * Add PaddleDetetion/PPYOLOE model support (#22) * add ppdet/ppyoloe * Add demo code and documents * add convert processor to vision (#27) * update .gitignore * Added checking for cmake include dir * fixed missing trt_backend option bug when init from trt * remove un-need data layout and add pre-check for dtype * changed RGB2BRG to BGR2RGB in ppcls model * add model_zoo yolov6 c++/python demo * fixed CMakeLists.txt typos * update yolov6 cpp/README.md * add yolox c++/pybind and model_zoo demo * move some helpers to private * fixed CMakeLists.txt typos * add normalize with alpha and beta * add version notes for yolov5/yolov6/yolox * add copyright to yolov5.cc * revert normalize * fixed some bugs in yolox * fixed examples/CMakeLists.txt to avoid conflicts * add convert processor to vision * format examples/CMakeLists summary * Fix bug while the inference result is empty with YOLOv5 (#29) * Add multi-label function for yolov5 * Update README.md Update doc * Update fastdeploy_runtime.cc fix variable option.trt_max_shape wrong name * Update runtime_option.md Update resnet model dynamic shape setting name from images to x * Fix bug when inference result boxes are empty * Delete detection.py Co-authored-by: Jason <jiangjiajun@baidu.com> Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com> Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com> Co-authored-by: huangjianhui <852142024@qq.com> * first commit for yolor * for merge * Develop (#11) * Fix compile problem in different python version (#26) * fix some usage 
problem in linux * Fix compile problem Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com> * Add PaddleDetetion/PPYOLOE model support (#22) * add ppdet/ppyoloe * Add demo code and documents * add convert processor to vision (#27) * update .gitignore * Added checking for cmake include dir * fixed missing trt_backend option bug when init from trt * remove un-need data layout and add pre-check for dtype * changed RGB2BRG to BGR2RGB in ppcls model * add model_zoo yolov6 c++/python demo * fixed CMakeLists.txt typos * update yolov6 cpp/README.md * add yolox c++/pybind and model_zoo demo * move some helpers to private * fixed CMakeLists.txt typos * add normalize with alpha and beta * add version notes for yolov5/yolov6/yolox * add copyright to yolov5.cc * revert normalize * fixed some bugs in yolox * fixed examples/CMakeLists.txt to avoid conflicts * add convert processor to vision * format examples/CMakeLists summary * Fix bug while the inference result is empty with YOLOv5 (#29) * Add multi-label function for yolov5 * Update README.md Update doc * Update fastdeploy_runtime.cc fix variable option.trt_max_shape wrong name * Update runtime_option.md Update resnet model dynamic shape setting name from images to x * Fix bug when inference result boxes are empty * Delete detection.py Co-authored-by: Jason <jiangjiajun@baidu.com> Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com> Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com> Co-authored-by: huangjianhui <852142024@qq.com> * Yolor (#16) * Develop (#11) (#12) * Fix compile problem in different python version (#26) * fix some usage problem in linux * Fix compile problem Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com> * Add PaddleDetetion/PPYOLOE model support (#22) * add ppdet/ppyoloe * Add demo code and documents * add convert processor to vision (#27) * update .gitignore * Added checking for cmake include dir * fixed missing trt_backend option bug when init 
from trt * remove un-need data layout and add pre-check for dtype * changed RGB2BRG to BGR2RGB in ppcls model * add model_zoo yolov6 c++/python demo * fixed CMakeLists.txt typos * update yolov6 cpp/README.md * add yolox c++/pybind and model_zoo demo * move some helpers to private * fixed CMakeLists.txt typos * add normalize with alpha and beta * add version notes for yolov5/yolov6/yolox * add copyright to yolov5.cc * revert normalize * fixed some bugs in yolox * fixed examples/CMakeLists.txt to avoid conflicts * add convert processor to vision * format examples/CMakeLists summary * Fix bug while the inference result is empty with YOLOv5 (#29) * Add multi-label function for yolov5 * Update README.md Update doc * Update fastdeploy_runtime.cc fix variable option.trt_max_shape wrong name * Update runtime_option.md Update resnet model dynamic shape setting name from images to x * Fix bug when inference result boxes are empty * Delete detection.py Co-authored-by: Jason <jiangjiajun@baidu.com> Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com> Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com> Co-authored-by: huangjianhui <852142024@qq.com> Co-authored-by: Jason <jiangjiajun@baidu.com> Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com> Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com> Co-authored-by: huangjianhui <852142024@qq.com> * Develop (#13) * Fix compile problem in different python version (#26) * fix some usage problem in linux * Fix compile problem Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com> * Add PaddleDetetion/PPYOLOE model support (#22) * add ppdet/ppyoloe * Add demo code and documents * add convert processor to vision (#27) * update .gitignore * Added checking for cmake include dir * fixed missing trt_backend option bug when init from trt * remove un-need data layout and add pre-check for dtype * changed RGB2BRG to BGR2RGB in ppcls model * add model_zoo yolov6 
c++/python demo * fixed CMakeLists.txt typos * update yolov6 cpp/README.md * add yolox c++/pybind and model_zoo demo * move some helpers to private * fixed CMakeLists.txt typos * add normalize with alpha and beta * add version notes for yolov5/yolov6/yolox * add copyright to yolov5.cc * revert normalize * fixed some bugs in yolox * fixed examples/CMakeLists.txt to avoid conflicts * add convert processor to vision * format examples/CMakeLists summary * Fix bug while the inference result is empty with YOLOv5 (#29) * Add multi-label function for yolov5 * Update README.md Update doc * Update fastdeploy_runtime.cc fix variable option.trt_max_shape wrong name * Update runtime_option.md Update resnet model dynamic shape setting name from images to x * Fix bug when inference result boxes are empty * Delete detection.py Co-authored-by: Jason <jiangjiajun@baidu.com> Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com> Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com> Co-authored-by: huangjianhui <852142024@qq.com> * documents * documents * documents * documents * documents * documents * documents * documents * documents * documents * documents * documents * Develop (#14) * Fix compile problem in different python version (#26) * fix some usage problem in linux * Fix compile problem Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com> * Add PaddleDetetion/PPYOLOE model support (#22) * add ppdet/ppyoloe * Add demo code and documents * add convert processor to vision (#27) * update .gitignore * Added checking for cmake include dir * fixed missing trt_backend option bug when init from trt * remove un-need data layout and add pre-check for dtype * changed RGB2BRG to BGR2RGB in ppcls model * add model_zoo yolov6 c++/python demo * fixed CMakeLists.txt typos * update yolov6 cpp/README.md * add yolox c++/pybind and model_zoo demo * move some helpers to private * fixed CMakeLists.txt typos * add normalize with alpha and beta * add version 
notes for yolov5/yolov6/yolox * add copyright to yolov5.cc * revert normalize * fixed some bugs in yolox * fixed examples/CMakeLists.txt to avoid conflicts * add convert processor to vision * format examples/CMakeLists summary * Fix bug while the inference result is empty with YOLOv5 (#29) * Add multi-label function for yolov5 * Update README.md Update doc * Update fastdeploy_runtime.cc fix variable option.trt_max_shape wrong name * Update runtime_option.md Update resnet model dynamic shape setting name from images to x * Fix bug when inference result boxes are empty * Delete detection.py Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com> Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com> Co-authored-by: huangjianhui <852142024@qq.com> Co-authored-by: Jason <jiangjiajun@baidu.com> Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com> Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com> Co-authored-by: huangjianhui <852142024@qq.com> Co-authored-by: Jason <928090362@qq.com> * add is_dynamic for YOLO series (#22) * modify ppmatting backend and docs * modify ppmatting docs * fix the PPMatting size problem * fix LimitShort's log * retrigger ci * modify PPMatting docs * modify the way for dealing with LimitShort * change develop_a_new_model.md dir Co-authored-by: Jason <jiangjiajun@baidu.com> Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com> Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com> Co-authored-by: huangjianhui <852142024@qq.com> Co-authored-by: Jason <928090362@qq.com> * [Doc] add readme for js packages (#421) * add contributor * add package readme * refine ocr readme * refine ocr readme Co-authored-by: Wang Xinyu <wangxinyu_es@163.com> Co-authored-by: huangjianhui <852142024@qq.com> Co-authored-by: Double_V <liuvv0203@163.com> Co-authored-by: chenqianhe <54462604+chenqianhe@users.noreply.github.com> Co-authored-by: DefTruth 
<31974251+DefTruth@users.noreply.github.com> Co-authored-by: leiqing <54695910+leiqing1@users.noreply.github.com> Co-authored-by: ziqi-jin <67993288+ziqi-jin@users.noreply.github.com> Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com>
214 lines
10 KiB
C++
Executable File
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/pybind/main.h"

namespace fastdeploy {

void BindRuntime(pybind11::module& m) {
  pybind11::class_<RuntimeOption>(m, "RuntimeOption")
      .def(pybind11::init())
      .def("set_model_path", &RuntimeOption::SetModelPath)
      .def("use_gpu", &RuntimeOption::UseGpu)
      .def("use_cpu", &RuntimeOption::UseCpu)
      .def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum)
      .def("use_paddle_backend", &RuntimeOption::UsePaddleBackend)
      .def("use_poros_backend", &RuntimeOption::UsePorosBackend)
      .def("use_ort_backend", &RuntimeOption::UseOrtBackend)
      .def("set_ort_graph_opt_level", &RuntimeOption::SetOrtGraphOptLevel)
      .def("use_trt_backend", &RuntimeOption::UseTrtBackend)
      .def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend)
      .def("use_lite_backend", &RuntimeOption::UseLiteBackend)
      .def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN)
      .def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo)
      .def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo)
      .def("set_paddle_mkldnn_cache_size",
           &RuntimeOption::SetPaddleMKLDNNCacheSize)
      .def("enable_lite_fp16", &RuntimeOption::EnableLiteFP16)
      .def("disable_lite_fp16", &RuntimeOption::DisableLiteFP16)
      .def("set_lite_power_mode", &RuntimeOption::SetLitePowerMode)
      .def("set_trt_input_shape", &RuntimeOption::SetTrtInputShape)
      .def("set_trt_max_workspace_size", &RuntimeOption::SetTrtMaxWorkspaceSize)
      .def("enable_paddle_to_trt", &RuntimeOption::EnablePaddleToTrt)
      .def("enable_trt_fp16", &RuntimeOption::EnableTrtFP16)
      .def("disable_trt_fp16", &RuntimeOption::DisableTrtFP16)
      .def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile)
      .def("enable_pinned_memory", &RuntimeOption::EnablePinnedMemory)
      .def("disable_pinned_memory", &RuntimeOption::DisablePinnedMemory)
      .def("enable_paddle_trt_collect_shape",
           &RuntimeOption::EnablePaddleTrtCollectShape)
      .def("disable_paddle_trt_collect_shape",
           &RuntimeOption::DisablePaddleTrtCollectShape)
      .def_readwrite("model_file", &RuntimeOption::model_file)
      .def_readwrite("params_file", &RuntimeOption::params_file)
      .def_readwrite("model_format", &RuntimeOption::model_format)
      .def_readwrite("backend", &RuntimeOption::backend)
      .def_readwrite("cpu_thread_num", &RuntimeOption::cpu_thread_num)
      .def_readwrite("device_id", &RuntimeOption::device_id)
      .def_readwrite("device", &RuntimeOption::device)
      .def_readwrite("ort_graph_opt_level", &RuntimeOption::ort_graph_opt_level)
      .def_readwrite("ort_inter_op_num_threads",
                     &RuntimeOption::ort_inter_op_num_threads)
      .def_readwrite("ort_execution_mode", &RuntimeOption::ort_execution_mode)
      .def_readwrite("trt_max_shape", &RuntimeOption::trt_max_shape)
      .def_readwrite("trt_opt_shape", &RuntimeOption::trt_opt_shape)
      .def_readwrite("trt_min_shape", &RuntimeOption::trt_min_shape)
      .def_readwrite("trt_serialize_file", &RuntimeOption::trt_serialize_file)
      .def_readwrite("trt_enable_fp16", &RuntimeOption::trt_enable_fp16)
      .def_readwrite("trt_enable_int8", &RuntimeOption::trt_enable_int8)
      .def_readwrite("trt_max_batch_size", &RuntimeOption::trt_max_batch_size)
      .def_readwrite("trt_max_workspace_size",
                     &RuntimeOption::trt_max_workspace_size)
      .def_readwrite("is_dynamic", &RuntimeOption::is_dynamic)
      .def_readwrite("long_to_int", &RuntimeOption::long_to_int)
      .def_readwrite("use_nvidia_tf32", &RuntimeOption::use_nvidia_tf32)
      .def_readwrite("unconst_ops_thres", &RuntimeOption::unconst_ops_thres)
      .def_readwrite("poros_file", &RuntimeOption::poros_file);

  pybind11::class_<TensorInfo>(m, "TensorInfo")
      .def_readwrite("name", &TensorInfo::name)
      .def_readwrite("shape", &TensorInfo::shape)
      .def_readwrite("dtype", &TensorInfo::dtype);

  pybind11::class_<Runtime>(m, "Runtime")
      .def(pybind11::init())
      .def("init", &Runtime::Init)
      .def("compile",
           [](Runtime& self,
              const std::vector<std::vector<pybind11::array>>& warm_datas,
              const RuntimeOption& _option) {
             // Size each row from its own length so rows with different
             // numbers of arrays (or an empty outer list) are handled safely.
             std::vector<std::vector<FDTensor>> warm_tensors(warm_datas.size());
             for (size_t i = 0; i < warm_datas.size(); ++i) {
               warm_tensors[i].resize(warm_datas[i].size());
               for (size_t j = 0; j < warm_datas[i].size(); ++j) {
                 const auto& py_array = warm_datas[i][j];
                 auto dtype = NumpyDataTypeToFDDataType(py_array.dtype());
                 std::vector<int64_t> data_shape(
                     py_array.shape(), py_array.shape() + py_array.ndim());
                 warm_tensors[i][j].Resize(data_shape, dtype);
                 // Read through data() so read-only arrays are accepted; the
                 // byte copy assumes the array is C-contiguous.
                 memcpy(warm_tensors[i][j].MutableData(), py_array.data(),
                        py_array.nbytes());
               }
             }
             return self.Compile(warm_tensors, _option);
           })
      .def("infer",
           [](Runtime& self, std::vector<FDTensor>& inputs) {
             std::vector<FDTensor> outputs(self.NumOutputs());
             self.Infer(inputs, &outputs);
             return outputs;
           })
      .def("infer",
           [](Runtime& self,
              const std::map<std::string, pybind11::array>& data) {
             std::vector<FDTensor> inputs(data.size());
             size_t index = 0;
             for (const auto& item : data) {
               const auto& py_array = item.second;
               std::vector<int64_t> data_shape(
                   py_array.shape(), py_array.shape() + py_array.ndim());
               auto dtype = NumpyDataTypeToFDDataType(py_array.dtype());
               // TODO(jiangjiajun): Skipping this copy by wrapping the numpy
               // buffer with SetExternalData may be a better choice.
               inputs[index].Resize(data_shape, dtype);
               // data() accepts read-only arrays; assumes C-contiguous input.
               memcpy(inputs[index].MutableData(), py_array.data(),
                      py_array.nbytes());
               inputs[index].name = item.first;
               index += 1;
             }

             std::vector<FDTensor> outputs(self.NumOutputs());
             self.Infer(inputs, &outputs);

             // Copy every output FDTensor into a newly allocated numpy array.
             std::vector<pybind11::array> results;
             results.reserve(outputs.size());
             for (size_t i = 0; i < outputs.size(); ++i) {
               auto numpy_dtype = FDDataTypeToNumpyDataType(outputs[i].dtype);
               results.emplace_back(
                   pybind11::array(numpy_dtype, outputs[i].shape));
               memcpy(results[i].mutable_data(), outputs[i].Data(),
                      outputs[i].Numel() * FDDataTypeSize(outputs[i].dtype));
             }
             return results;
           })
      .def("num_inputs", &Runtime::NumInputs)
      .def("num_outputs", &Runtime::NumOutputs)
      .def("get_input_info", &Runtime::GetInputInfo)
      .def("get_output_info", &Runtime::GetOutputInfo)
      .def_readonly("option", &Runtime::option);

  pybind11::enum_<Backend>(m, "Backend", pybind11::arithmetic(),
                           "Backend for inference.")
      .value("UNKOWN", Backend::UNKNOWN)
      .value("ORT", Backend::ORT)
      .value("TRT", Backend::TRT)
      .value("POROS", Backend::POROS)
      .value("PDINFER", Backend::PDINFER)
      .value("OPENVINO", Backend::OPENVINO)
      .value("LITE", Backend::LITE);
  pybind11::enum_<ModelFormat>(m, "ModelFormat", pybind11::arithmetic(),
                               "ModelFormat for inference.")
      .value("PADDLE", ModelFormat::PADDLE)
      .value("TORCHSCRIPT", ModelFormat::TORCHSCRIPT)
      .value("ONNX", ModelFormat::ONNX);
  pybind11::enum_<Device>(m, "Device", pybind11::arithmetic(),
                          "Device for inference.")
      .value("CPU", Device::CPU)
      .value("GPU", Device::GPU);

  pybind11::enum_<FDDataType>(m, "FDDataType", pybind11::arithmetic(),
                              "Data type of FastDeploy.")
      .value("BOOL", FDDataType::BOOL)
      .value("INT8", FDDataType::INT8)
      .value("INT16", FDDataType::INT16)
      .value("INT32", FDDataType::INT32)
      .value("INT64", FDDataType::INT64)
      .value("FP16", FDDataType::FP16)
      .value("FP32", FDDataType::FP32)
      .value("FP64", FDDataType::FP64)
      .value("UINT8", FDDataType::UINT8);

  pybind11::class_<FDTensor>(m, "FDTensor", pybind11::buffer_protocol())
      .def(pybind11::init())
      .def("cpu_data",
           [](pybind11::object self_obj) {
             auto& self = self_obj.cast<FDTensor&>();
             auto ptr = self.CpuData();
             auto dtype = FDDataTypeToNumpyDataType(self.dtype);
             // Using the Python FDTensor object as the array's base avoids a
             // copy and keeps the tensor alive while the view is in use.
             return pybind11::array(dtype, self.shape, ptr, self_obj);
           })
      .def("resize",
           static_cast<void (FDTensor::*)(size_t)>(&FDTensor::Resize))
      .def("resize",
           static_cast<void (FDTensor::*)(const std::vector<int64_t>&)>(
               &FDTensor::Resize))
      .def("resize",
           [](FDTensor& self, const std::vector<int64_t>& shape,
              const FDDataType& dtype, const std::string& name,
              const Device& device) {
             self.Resize(shape, dtype, name, device);
           })
      .def("numel", &FDTensor::Numel)
      .def("nbytes", &FDTensor::Nbytes)
      .def_readwrite("name", &FDTensor::name)
      .def_readwrite("is_pinned_memory", &FDTensor::is_pinned_memory)
      .def_readonly("shape", &FDTensor::shape)
      .def_readonly("dtype", &FDTensor::dtype)
      .def_readonly("device", &FDTensor::device);

  m.def("get_available_backends", []() { return GetAvailableBackends(); });
}

} // namespace fastdeploy
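Both the `compile` binding and the dict overload of `infer` turn a numpy array into an `FDTensor` the same way. They build a shape vector from the array's `shape()` pointer and `ndim()`, call `Resize`, then `memcpy` `nbytes()` bytes. The sketch below reproduces that pattern without pybind11 or FastDeploy so it compiles on its own. `HostTensor` and `CopyFromBuffer` are hypothetical names, and the dtype is reduced to an element size. Like the binding, it assumes the source buffer is C-contiguous.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for FDTensor: a shape plus an owned byte buffer.
struct HostTensor {
  std::vector<int64_t> shape;
  size_t element_size = 0;
  std::vector<uint8_t> buffer;

  // Product of the dimensions; an empty shape (a scalar) has one element.
  int64_t Numel() const {
    return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }
  size_t Nbytes() const { return static_cast<size_t>(Numel()) * element_size; }

  // Analogue of Resize(shape, dtype): reallocate to match the new shape.
  void Resize(const std::vector<int64_t>& new_shape, size_t new_element_size) {
    shape = new_shape;
    element_size = new_element_size;
    buffer.resize(Nbytes());
  }
};

// Mirrors the binding: shape from a (pointer, ndim) pair, resize the
// destination, then copy `nbytes` raw bytes from a C-contiguous source.
HostTensor CopyFromBuffer(const void* data, const int64_t* dims, size_t ndim,
                          size_t element_size, size_t nbytes) {
  HostTensor tensor;
  tensor.Resize(std::vector<int64_t>(dims, dims + ndim), element_size);
  assert(tensor.Nbytes() == nbytes && "shape and byte count disagree");
  if (nbytes > 0) {
    std::memcpy(tensor.buffer.data(), data, nbytes);
  }
  return tensor;
}
```

The `assert` compares the byte count implied by the shape with the byte count of the source before copying. The binding itself does not check this and relies on `Resize` and `nbytes()` agreeing.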
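The dict overload of `infer` walks a `std::map<std::string, pybind11::array>`, and `std::map` visits its keys in sorted order. The resulting `inputs` vector is therefore ordered by name. It does not follow the order the Python dict was written in or the order in which the model declares its inputs. Each tensor also gets its key copied into `name`. That suggests the runtime matches inputs by name, although this file does not show it. Below is a minimal sketch of the loop that uses a hypothetical `NamedInput` record in place of `FDTensor`.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Hypothetical minimal input record: a tensor name plus its shape.
struct NamedInput {
  std::string name;
  std::vector<int64_t> shape;
};

// Same shape as the dict overload's loop: one entry per map key, with the
// key stored as the tensor name. Because std::map iterates in ascending key
// order, the output vector is sorted by name.
std::vector<NamedInput> CollectInputs(
    const std::map<std::string, std::vector<int64_t>>& feeds) {
  std::vector<NamedInput> inputs;
  inputs.reserve(feeds.size());
  for (const auto& item : feeds) {
    inputs.push_back(NamedInput{item.first, item.second});
  }
  return inputs;
}
```

If a backend matched inputs by position instead of by name, this sorted order would matter. That is why attaching `name` to every tensor is the safer design.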
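`cpu_data` hands Python a numpy array that points at the tensor's memory instead of copying it. A view like that is safe only while the memory's owner stays alive, and numpy records that owner in the array's `base` object. Plain C++ expresses the same lifetime rule with the `std::shared_ptr` aliasing constructor. The view points at the buffer but shares ownership of the object that holds it. The sketch below uses a hypothetical `OwnedBuffer` as the owner. It illustrates the lifetime rule only and is not taken from FastDeploy.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical owner standing in for a tensor that holds a float buffer.
struct OwnedBuffer {
  std::vector<float> values;
};

// Returns a pointer to the first element that also shares ownership of
// `owner`. The aliasing constructor reuses owner's control block but stores
// owner->values.data(), so the buffer stays valid as long as any view exists.
std::shared_ptr<float> MakeView(const std::shared_ptr<OwnedBuffer>& owner) {
  assert(owner && "owner must not be null");
  return std::shared_ptr<float>(owner, owner->values.data());
}
```

If the caller drops its handle to the owner, the buffer stays alive until the last view is released. A view whose "base" did not own the memory would have no such guarantee.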