class OnnxRuntime::TrainingSession

Overview

The TrainingSession class provides a high-level API for training ONNX models. It is currently a placeholder, reserved for a future implementation once the ONNX Runtime training API is supported.

Defined in:

onnxruntime/training_session.cr

Constructor Detail

def self.new(env : OrtEnvironment, model_path : String, **options)

Creates a new TrainingSession instance for the model at model_path in the environment env.



Instance Method Detail

def eval_step(input_feed)

Evaluates the model on input_feed without updating its parameters.


def learning_rate=(rate : Float64)

Sets the learning rate used by the optimizer to rate.


def load_checkpoint(checkpoint_path : String)

Loads a checkpoint from checkpoint_path.


def optimizer_state

Returns the current optimizer state.


def save_checkpoint(checkpoint_path : String)

Saves the trained model state as a checkpoint at checkpoint_path.
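
Checkpointing is expected to round-trip: whatever save_checkpoint writes, load_checkpoint should restore. The sketch below illustrates that contract with a hypothetical ToyCheckpoint struct serialized to JSON via Crystal's standard library. The struct, its fields, and the two top-level helpers are assumptions for illustration only; they make no claim about the on-disk format the real methods will use.

```crystal
require "json"

# Hypothetical training state; NOT the shard's API or ONNX Runtime's
# checkpoint format. JSON::Serializable provides to_json / from_json.
struct ToyCheckpoint
  include JSON::Serializable

  getter weight : Float64
  getter learning_rate : Float64
  getter step : Int32

  def initialize(@weight : Float64, @learning_rate : Float64, @step : Int32)
  end
end

# Hypothetical helper: writes the state to checkpoint_path as JSON.
def save_checkpoint(checkpoint_path : String, state : ToyCheckpoint) : Nil
  File.write(checkpoint_path, state.to_json)
end

# Hypothetical helper: reads the JSON back into a ToyCheckpoint.
def load_checkpoint(checkpoint_path : String) : ToyCheckpoint
  ToyCheckpoint.from_json(File.read(checkpoint_path))
end

path = File.tempname("toy_checkpoint", ".json")
saved = ToyCheckpoint.new(weight: 1.98, learning_rate: 0.05, step: 10)
save_checkpoint(path, saved)
restored = load_checkpoint(path)
File.delete(path)
puts "restored step #{restored.step}, weight #{restored.weight}"
```

Structs in Crystal compare field by field, so restored == saved is a direct check that nothing was lost in the round trip.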


def train_step(input_feed)

Trains the model for one step on input_feed.
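
Because TrainingSession is still a placeholder, the sketch below uses a hypothetical ToyTrainer class (not part of this shard) to show what a training step, an evaluation step, and a learning-rate setter conventionally do. It assumes a NamedTuple feed of inputs and targets, since the real input_feed type is not documented, and fits y = w * x with plain gradient descent on mean squared error.

```crystal
# Hypothetical stand-in that mirrors the documented method names
# (learning_rate=, train_step, eval_step). It is NOT the shard's
# implementation: it fits y = w * x with plain gradient descent on MSE.
class ToyTrainer
  # Assumed feed shape: parallel arrays of inputs and targets.
  alias Feed = NamedTuple(x: Array(Float64), y: Array(Float64))

  # Generates both the learning_rate getter and the learning_rate= setter.
  property learning_rate : Float64 = 0.1
  getter weight : Float64 = 0.0

  # One training step: compute the MSE gradient, update the weight,
  # and return the loss after the update.
  def train_step(input_feed : Feed) : Float64
    xs, ys = input_feed[:x], input_feed[:y]
    # d/dw of mean((w*x - y)^2) = (2/n) * sum((w*x - y) * x)
    grad = xs.zip(ys).sum { |x, y| (@weight * x - y) * x } * 2.0 / xs.size
    @weight -= @learning_rate * grad
    eval_step(input_feed)
  end

  # Evaluation only: returns the MSE loss and leaves the weight unchanged.
  def eval_step(input_feed : Feed) : Float64
    xs, ys = input_feed[:x], input_feed[:y]
    xs.zip(ys).sum { |x, y| (@weight * x - y) ** 2 } / xs.size
  end
end

trainer = ToyTrainer.new
trainer.learning_rate = 0.05
feed = {x: [1.0, 2.0, 3.0], y: [2.0, 4.0, 6.0]}
initial_loss = trainer.eval_step(feed)
50.times { trainer.train_step(feed) }
puts "loss #{initial_loss} -> #{trainer.eval_step(feed)}, weight #{trainer.weight}"
```

Each train_step call moves the weight toward 2.0 by an amount scaled by the learning rate, while eval_step only reports the loss. A real training session would perform the forward pass, gradient computation, and optimizer update inside ONNX Runtime rather than in Crystal.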

