mirror of
https://github.com/swdee/go-rknnlite.git
synced 2026-04-22 22:57:03 +08:00
Merge pull request #48 from uljjmhn520/model_from_bytes
Added NewRuntimeFromBytes() function to initialize the RKNN Runtime with a model loaded from bytes.
This commit is contained in:
@@ -11,6 +11,24 @@ import (
|
||||
)
|
||||
|
||||
func TestMobileNetTop5(t *testing.T) {
|
||||
testMobileNetTop5Common(t, func(modelFile string) (*Runtime, error) {
|
||||
return NewRuntime(modelFile, NPUCoreAuto)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMobileNetTop5FromBytes(t *testing.T) {
|
||||
testMobileNetTop5Common(t, func(modelFile string) (*Runtime, error) {
|
||||
modelBytes, err := os.ReadFile(modelFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewRuntimeFromBytes(modelBytes, NPUCoreAuto)
|
||||
})
|
||||
}
|
||||
|
||||
func testMobileNetTop5Common(t *testing.T, newRuntime func(modelFile string) (*Runtime, error)) {
|
||||
t.Helper()
|
||||
|
||||
modelFile := os.Getenv("RKNN_MODEL")
|
||||
|
||||
@@ -25,10 +43,10 @@ func TestMobileNetTop5(t *testing.T) {
|
||||
}
|
||||
|
||||
// Initialize runtime
|
||||
rt, err := NewRuntime(modelFile, NPUCoreAuto)
|
||||
rt, err := newRuntime(modelFile)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("NewRuntime failed: %v", err)
|
||||
t.Fatalf("runtime init failed: %v", err)
|
||||
}
|
||||
|
||||
defer rt.Close()
|
||||
@@ -80,8 +98,10 @@ func TestMobileNetTop5(t *testing.T) {
|
||||
}
|
||||
|
||||
if i > 0 && p.Probability > top5[i-1].Probability {
|
||||
t.Errorf("probabilities not descending: index %d has %v > previous %v",
|
||||
i, p.Probability, top5[i-1].Probability)
|
||||
t.Errorf(
|
||||
"probabilities not descending: index %d has %v > previous %v",
|
||||
i, p.Probability, top5[i-1].Probability,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,13 +110,16 @@ func TestMobileNetTop5(t *testing.T) {
|
||||
|
||||
for i, p := range top5 {
|
||||
if int(p.LabelIndex) < 0 || int(p.LabelIndex) >= numClasses {
|
||||
t.Errorf("entry %d: label index %d out of range [0,%d)", i, p.LabelIndex, numClasses)
|
||||
t.Errorf(
|
||||
"entry %d: label index %d out of range [0,%d)",
|
||||
i, p.LabelIndex, numClasses,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Sanity check: at least one probability above a tiny epsilon
|
||||
const eps = 1e-3
|
||||
var found bool
|
||||
found := false
|
||||
|
||||
for _, p := range top5 {
|
||||
if p.Probability > eps {
|
||||
|
||||
+126
-33
@@ -118,6 +118,10 @@ type Runtime struct {
|
||||
// inputTypeFloat32 indicates if we pass the input gocv.Mat's data as float32
|
||||
// to the RKNN backend
|
||||
inputTypeFloat32 bool
|
||||
// C-owned model buffer when initialized FromBytes()
|
||||
// Keep this alive for the lifetime of the RKNN context
|
||||
modelData unsafe.Pointer
|
||||
modelSize C.uint32_t
|
||||
}
|
||||
|
||||
// NewRuntime returns a RKNN run time instance. Provide the full path and
|
||||
@@ -134,32 +138,7 @@ func NewRuntime(modelFile string, core CoreMask) (*Runtime, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// setCoreMask is only supported on RK3588, allow skipping for other Rockchip models
|
||||
// like RK3566
|
||||
if core != NPUSkipSetCore {
|
||||
err = r.setCoreMask(core)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// cache IONumber
|
||||
r.ioNum, err = r.QueryModelIONumber()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// query Input tensors
|
||||
r.inputAttrs, err = r.QueryInputTensors()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// query Output tensors
|
||||
r.outputAttrs, err = r.QueryOutputTensors()
|
||||
err = r.setup(core)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -168,6 +147,61 @@ func NewRuntime(modelFile string, core CoreMask) (*Runtime, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// NewRuntimeFromBytes returns a RKNN run time instance. Provide the model as a byte buffer.
|
||||
func NewRuntimeFromBytes(modelBuffer []byte, core CoreMask) (*Runtime, error) {
|
||||
|
||||
r := &Runtime{
|
||||
wantFloat: true,
|
||||
}
|
||||
|
||||
err := r.initFromBytes(modelBuffer)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = r.setup(core)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// setup performs the common initialization steps for the RKNN runtime
|
||||
func (r *Runtime) setup(core CoreMask) error {
|
||||
var err error
|
||||
|
||||
// setCoreMask is only supported on RK3588, allow skipping for other Rockchip models
|
||||
// like RK3566
|
||||
if core != NPUSkipSetCore {
|
||||
if err = r.setCoreMask(core); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// cache IONumber
|
||||
r.ioNum, err = r.QueryModelIONumber()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// query Input tensors
|
||||
r.inputAttrs, err = r.QueryInputTensors()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// query Output tensors
|
||||
r.outputAttrs, err = r.QueryOutputTensors()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// init wraps C.rknn_init which initializes the RKNN context with the given
|
||||
// model. The modelFile is the full path and filename of the RKNN compiled
|
||||
// model file to run.
|
||||
@@ -200,6 +234,36 @@ func (r *Runtime) init(modelFile string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// initFromBytes copies modelBytes into C-allocated memory and initializes
|
||||
// the RKNN context from that C-owned model buffer so its lifetime is managed
|
||||
// independently of the Go GC.
|
||||
func (r *Runtime) initFromBytes(modelBytes []byte) error {
|
||||
|
||||
if len(modelBytes) == 0 {
|
||||
return fmt.Errorf("model bytes is empty")
|
||||
}
|
||||
|
||||
// Allocate C-owned memory so the model buffer lifetime is under our control.
|
||||
modelData := C.CBytes(modelBytes)
|
||||
if modelData == nil {
|
||||
return fmt.Errorf("failed to allocate C memory for model bytes")
|
||||
}
|
||||
|
||||
size := C.uint32_t(len(modelBytes))
|
||||
ret := C.rknn_init(&r.ctx, modelData, size, 0, nil)
|
||||
|
||||
if ret != C.RKNN_SUCC {
|
||||
C.free(modelData)
|
||||
return fmt.Errorf("C.rknn_init call failed with code %d, error: %s",
|
||||
ret, ErrorCodes(ret).String())
|
||||
}
|
||||
|
||||
r.modelData = modelData
|
||||
r.modelSize = size
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setCoreMark wraps C.rknn_set_core_mask and specifies the NPU core configuration
|
||||
// to run the model on
|
||||
func (r *Runtime) setCoreMask(mask CoreMask) error {
|
||||
@@ -225,9 +289,21 @@ func (r *Runtime) Close() error {
|
||||
ret, ErrorCodes(ret).String())
|
||||
}
|
||||
|
||||
// free any memory with loaded model data
|
||||
r.freeModelData()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// freeModelData frees C memory used to store Model when loaded from bytes
|
||||
func (r *Runtime) freeModelData() {
|
||||
if r.modelData != nil {
|
||||
C.free(r.modelData)
|
||||
r.modelData = nil
|
||||
r.modelSize = 0
|
||||
}
|
||||
}
|
||||
|
||||
// SetWantFloat defines if the Model load requires Output tensors to be converted
|
||||
// to float32 for post processing, or left as quantitized int8
|
||||
func (r *Runtime) SetWantFloat(val bool) {
|
||||
@@ -294,22 +370,39 @@ func (r *Runtime) OutputAttrs() []TensorAttr {
|
||||
// rk3562|rk3566|rk3568|rk3576|rk3582|rk3582|rk3588
|
||||
// Provide the full path and filename of the RKNN compiled model file to run
|
||||
func NewRuntimeByPlatform(platform string, modelFile string) (*Runtime, error) {
|
||||
useCore, err := getCoreMaskByPlatform(platform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewRuntime(modelFile, useCore)
|
||||
}
|
||||
|
||||
// NewRuntimeByPlatformFromBytes returns a RKNN run time instance and automatically
|
||||
// selects the NPU cores to run on the given platform string.
|
||||
func NewRuntimeByPlatformFromBytes(platform string, modelBytes []byte) (*Runtime, error) {
|
||||
useCore, err := getCoreMaskByPlatform(platform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewRuntimeFromBytes(modelBytes, useCore)
|
||||
}
|
||||
|
||||
func getCoreMaskByPlatform(platform string) (CoreMask, error) {
|
||||
|
||||
platform = strings.TrimSpace(platform)
|
||||
platform = strings.ToLower(platform)
|
||||
|
||||
var useCore CoreMask
|
||||
|
||||
switch platform {
|
||||
case "rk3562", "rk3566", "rk3568":
|
||||
useCore = NPUSkipSetCore
|
||||
|
||||
return NPUSkipSetCore, nil
|
||||
|
||||
case "rk3576", "rk3582", "rk3588":
|
||||
useCore = NPUCoreAuto
|
||||
|
||||
return NPUCoreAuto, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown platform: %s", platform)
|
||||
}
|
||||
|
||||
return NewRuntime(modelFile, useCore)
|
||||
return 0, fmt.Errorf("unknown platform: %s", platform)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user