Add more SessionOptions functions

- This change adds three new functions for SessionOptions:
   SetOptimizedModelFilePath, EnableProfiling, and DisableProfiling. All
   three are simple wrappers around their respective C API functions
   with the same names.
This commit is contained in:
yalue
2026-03-04 10:05:52 -05:00
parent 45eb52b8ed
commit 2215308333
4 changed files with 132 additions and 0 deletions
+40
View File
@@ -1930,6 +1930,46 @@ func (o *SessionOptions) AppendExecutionProvider(providerName string,
return nil
}
// Wraps the SetOptimizedModelFilePath API function for these session options.
// Onnxruntime will save the optimized model file to the given path, after
// graph-level transformations.
func (o *SessionOptions) SetOptimizedModelFilePath(path string) error {
ortCharPath, e := createOrtCharString(path)
if e != nil {
return fmt.Errorf("Error encoding optimized model file path: %w", e)
}
defer C.free(unsafe.Pointer(ortCharPath))
status := C.SetOptimizedModelFilePath(o.o, ortCharPath)
if status != nil {
return statusToError(status)
}
return nil
}
// Enables profiling for these session options. The profile will be a JSON file
// with the given path prefix.
func (o *SessionOptions) EnableProfiling(profileFilePrefix string) error {
ortCharPath, e := createOrtCharString(profileFilePrefix)
if e != nil {
return fmt.Errorf("Error encoding profile path prefix: %w", e)
}
defer C.free(unsafe.Pointer(ortCharPath))
status := C.EnableProfiling(o.o, ortCharPath)
if status != nil {
return statusToError(status)
}
return nil
}
// Disables profiling for these session options.
func (o *SessionOptions) DisableProfiling() error {
status := C.DisableProfiling(o.o)
if status != nil {
return statusToError(status)
}
return nil
}
// Initializes and returns a SessionOptions struct, used when setting options
// in new AdvancedSession instances. The caller must call the Destroy()
// function on the returned struct when it's no longer needed.
+69
View File
@@ -2,9 +2,11 @@ package onnxruntime_go
import (
"fmt"
"io/fs"
"math"
"math/rand"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
@@ -1930,6 +1932,73 @@ func TestSessionOptionsConfig(t *testing.T) {
}
}
func TestSessionWithProfilingAndOptimization(t *testing.T) {
InitializeRuntime(t)
defer CleanupRuntime(t)
tmpDir := t.TempDir()
profilePrefix := filepath.Join(tmpDir, "profile")
optimizedModelPath := filepath.Join(tmpDir, "optimized_model.onnx")
options, e := NewSessionOptions()
if e != nil {
t.Fatalf("Error creating session options: %s\n", e)
}
defer options.Destroy()
e = options.SetOptimizedModelFilePath(optimizedModelPath)
if e != nil {
t.Fatalf("Error setting optimized model file path to %s: %s",
optimizedModelPath, e)
}
e = options.EnableProfiling(profilePrefix)
if e != nil {
t.Fatalf("Error setting profile path prefix to %s: %s", profilePrefix,
e)
}
// Set up and run a useless session to make sure our files got created.
originalModelPath := "test_data/example ż 大 김.onnx"
input := newTestTensor[int32](t, NewShape(1, 2))
defer input.Destroy()
output := newTestTensor[int32](t, NewShape(1))
defer output.Destroy()
session, e := NewAdvancedSession(originalModelPath, []string{"in"},
[]string{"out"}, []Value{input}, []Value{output}, options)
if e != nil {
t.Fatalf("Failed creating session: %s\n", e)
}
e = session.Run()
// Destroy the session now so its profile is written.
session.Destroy()
if e != nil {
t.Fatalf("Error running session: %s\n", e)
}
numFiles := 0
e = filepath.Walk(tmpDir, func(f string, n fs.FileInfo, e error) error {
if e != nil {
t.Errorf("Got error traversing temp directory %s: %s\n", tmpDir, e)
return e
}
if f == tmpDir {
// Don't log or count the directory itself; we just want to ensure
// the profile and the model file were created.
return nil
}
t.Logf("Found file %s: %d bytes\n", f, n.Size())
numFiles++
return nil
})
if numFiles < 2 {
t.Errorf("Failed to create both a profile and an optimized model; "+
"only found %d files in temp dir %s\n", numFiles, tmpDir)
}
e = options.DisableProfiling()
if e != nil {
t.Errorf("Error disabling profiling on session options where it "+
"was previously enabled: %s\n", e)
}
}
func TestGraphOptimizationLevel(t *testing.T) {
InitializeRuntime(t)
defer CleanupRuntime(t)
+13
View File
@@ -222,6 +222,19 @@ OrtStatus *AppendExecutionProvider(OrtSessionOptions *o,
keys, values, num_keys);
}
OrtStatus *SetOptimizedModelFilePath(OrtSessionOptions *o,
char *path) {
return ort_api->SetOptimizedModelFilePath(o, (const ORTCHAR_T*) path);
}
OrtStatus *EnableProfiling(OrtSessionOptions *o, char *path) {
return ort_api->EnableProfiling(o, (const ORTCHAR_T*) path);
}
OrtStatus *DisableProfiling(OrtSessionOptions *o) {
return ort_api->DisableProfiling(o);
}
OrtStatus *RegisterExecutionProviderLibrary(OrtEnv *env,
const char *registration_name, char *path) {
return ort_api->RegisterExecutionProviderLibrary(env, registration_name,
+10
View File
@@ -160,6 +160,16 @@ OrtStatus *AppendExecutionProvider(OrtSessionOptions *o,
const char *provider_name, const char **keys, const char **values,
int num_keys);
// Wraps ort_api->SetOptimizedModelFilePath. NOTE: takes an ORTCHAR_T*.
OrtStatus *SetOptimizedModelFilePath(OrtSessionOptions *o,
char *path);
// Wraps ort_api->EnableProfiling. NOTE: takes an ORTCHAR_T*.
OrtStatus *EnableProfiling(OrtSessionOptions *o, char *path);
// Wraps ort_api->DisableProfiling.
OrtStatus *DisableProfiling(OrtSessionOptions *o);
// Wraps ort_api->RegisterExecutionProviderLibrary
OrtStatus *RegisterExecutionProviderLibrary(OrtEnv *env,
const char *registration_name, char *path);