mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2026-04-22 16:17:24 +08:00
Add more SessionOptions functions
- This change adds three new functions for SessionOptions: SetOptimizedModelFilePath, EnableProfiling, and DisableProfiling. All three are simple wrappers around their respective C API functions with the same names.
This commit is contained in:
@@ -1930,6 +1930,46 @@ func (o *SessionOptions) AppendExecutionProvider(providerName string,
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wraps the SetOptimizedModelFilePath API function for these session options.
|
||||
// Onnxruntime will save the optimized model file to the given path, after
|
||||
// graph-level transformations.
|
||||
func (o *SessionOptions) SetOptimizedModelFilePath(path string) error {
|
||||
ortCharPath, e := createOrtCharString(path)
|
||||
if e != nil {
|
||||
return fmt.Errorf("Error encoding optimized model file path: %w", e)
|
||||
}
|
||||
defer C.free(unsafe.Pointer(ortCharPath))
|
||||
status := C.SetOptimizedModelFilePath(o.o, ortCharPath)
|
||||
if status != nil {
|
||||
return statusToError(status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Enables profiling for these session options. The profile will be a JSON file
|
||||
// with the given path prefix.
|
||||
func (o *SessionOptions) EnableProfiling(profileFilePrefix string) error {
|
||||
ortCharPath, e := createOrtCharString(profileFilePrefix)
|
||||
if e != nil {
|
||||
return fmt.Errorf("Error encoding profile path prefix: %w", e)
|
||||
}
|
||||
defer C.free(unsafe.Pointer(ortCharPath))
|
||||
status := C.EnableProfiling(o.o, ortCharPath)
|
||||
if status != nil {
|
||||
return statusToError(status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disables profiling for these session options.
|
||||
func (o *SessionOptions) DisableProfiling() error {
|
||||
status := C.DisableProfiling(o.o)
|
||||
if status != nil {
|
||||
return statusToError(status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initializes and returns a SessionOptions struct, used when setting options
|
||||
// in new AdvancedSession instances. The caller must call the Destroy()
|
||||
// function on the returned struct when it's no longer needed.
|
||||
|
||||
@@ -2,9 +2,11 @@ package onnxruntime_go
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"math"
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -1930,6 +1932,73 @@ func TestSessionOptionsConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionWithProfilingAndOptimization(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
tmpDir := t.TempDir()
|
||||
profilePrefix := filepath.Join(tmpDir, "profile")
|
||||
optimizedModelPath := filepath.Join(tmpDir, "optimized_model.onnx")
|
||||
options, e := NewSessionOptions()
|
||||
if e != nil {
|
||||
t.Fatalf("Error creating session options: %s\n", e)
|
||||
}
|
||||
defer options.Destroy()
|
||||
e = options.SetOptimizedModelFilePath(optimizedModelPath)
|
||||
if e != nil {
|
||||
t.Fatalf("Error setting optimized model file path to %s: %s",
|
||||
optimizedModelPath, e)
|
||||
}
|
||||
e = options.EnableProfiling(profilePrefix)
|
||||
if e != nil {
|
||||
t.Fatalf("Error setting profile path prefix to %s: %s", profilePrefix,
|
||||
e)
|
||||
}
|
||||
|
||||
// Set up and run a useless session to make sure our files got created.
|
||||
originalModelPath := "test_data/example ż 大 김.onnx"
|
||||
input := newTestTensor[int32](t, NewShape(1, 2))
|
||||
defer input.Destroy()
|
||||
output := newTestTensor[int32](t, NewShape(1))
|
||||
defer output.Destroy()
|
||||
session, e := NewAdvancedSession(originalModelPath, []string{"in"},
|
||||
[]string{"out"}, []Value{input}, []Value{output}, options)
|
||||
if e != nil {
|
||||
t.Fatalf("Failed creating session: %s\n", e)
|
||||
}
|
||||
e = session.Run()
|
||||
// Destroy the session now so its profile is written.
|
||||
session.Destroy()
|
||||
if e != nil {
|
||||
t.Fatalf("Error running session: %s\n", e)
|
||||
}
|
||||
|
||||
numFiles := 0
|
||||
e = filepath.Walk(tmpDir, func(f string, n fs.FileInfo, e error) error {
|
||||
if e != nil {
|
||||
t.Errorf("Got error traversing temp directory %s: %s\n", tmpDir, e)
|
||||
return e
|
||||
}
|
||||
if f == tmpDir {
|
||||
// Don't log or count the directory itself; we just want to ensure
|
||||
// the profile and the model file were created.
|
||||
return nil
|
||||
}
|
||||
t.Logf("Found file %s: %d bytes\n", f, n.Size())
|
||||
numFiles++
|
||||
return nil
|
||||
})
|
||||
if numFiles < 2 {
|
||||
t.Errorf("Failed to create both a profile and an optimized model; "+
|
||||
"only found %d files in temp dir %s\n", numFiles, tmpDir)
|
||||
}
|
||||
|
||||
e = options.DisableProfiling()
|
||||
if e != nil {
|
||||
t.Errorf("Error disabling profiling on session options where it "+
|
||||
"was previously enabled: %s\n", e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraphOptimizationLevel(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
|
||||
@@ -222,6 +222,19 @@ OrtStatus *AppendExecutionProvider(OrtSessionOptions *o,
|
||||
keys, values, num_keys);
|
||||
}
|
||||
|
||||
OrtStatus *SetOptimizedModelFilePath(OrtSessionOptions *o,
|
||||
char *path) {
|
||||
return ort_api->SetOptimizedModelFilePath(o, (const ORTCHAR_T*) path);
|
||||
}
|
||||
|
||||
OrtStatus *EnableProfiling(OrtSessionOptions *o, char *path) {
|
||||
return ort_api->EnableProfiling(o, (const ORTCHAR_T*) path);
|
||||
}
|
||||
|
||||
OrtStatus *DisableProfiling(OrtSessionOptions *o) {
|
||||
return ort_api->DisableProfiling(o);
|
||||
}
|
||||
|
||||
OrtStatus *RegisterExecutionProviderLibrary(OrtEnv *env,
|
||||
const char *registration_name, char *path) {
|
||||
return ort_api->RegisterExecutionProviderLibrary(env, registration_name,
|
||||
|
||||
@@ -160,6 +160,16 @@ OrtStatus *AppendExecutionProvider(OrtSessionOptions *o,
|
||||
const char *provider_name, const char **keys, const char **values,
|
||||
int num_keys);
|
||||
|
||||
// Wraps ort_api->SetOptimizedModelFilePath. NOTE: takes an ORTCHAR_T*.
|
||||
OrtStatus *SetOptimizedModelFilePath(OrtSessionOptions *o,
|
||||
char *path);
|
||||
|
||||
// Wraps ort_api->EnableProfiling. NOTE: takes an ORTCHAR_T*.
|
||||
OrtStatus *EnableProfiling(OrtSessionOptions *o, char *path);
|
||||
|
||||
// Wraps ort_api->DisableProfiling.
|
||||
OrtStatus *DisableProfiling(OrtSessionOptions *o);
|
||||
|
||||
// Wraps ort_api->RegisterExecutionProviderLibrary
|
||||
OrtStatus *RegisterExecutionProviderLibrary(OrtEnv *env,
|
||||
const char *registration_name, char *path);
|
||||
|
||||
Reference in New Issue
Block a user