Overview
Package togomlx contains several conversion utilities between ONNX and GoMLX.
Package onnx provides functionality to parse ONNX models and generate the corresponding GoMLX graph.
- Parse: converts a serialized ONNX ModelProto to a Model.
- ReadFile: reads a file and calls Parse. It returns a Model.
- Model: object holding information about an ONNX model. It can be used to generate the corresponding GoMLX model graph, which can then be executed for inference or used in a training loop for fine-tuning. It can also be used to populate a context with the variables of the ONNX model.
Index
- Constants
- Variables
- func SafeVarName(onnxName string) (gomlxName string)
- func Shape(proto *protos.TensorProto) (shape shapes.Shape, err error)
- func SparseShape(proto *protos.SparseTensorProto) (shape shapes.Shape, err error)
- func TensorValueToONNX(t *tensors.Tensor, proto *protos.TensorProto) (err error)
- type DynamicShape
- type Model
- func (m *Model) CallGraph(ctx *context.Context, g *Graph, inputs map[string]*Node, outputNames ...string) (outputs []*Node)
- func (m *Model) ContextToONNX(ctx *context.Context) error
- func (m *Model) Inputs() (names []string, dshapes []DynamicShape)
- func (m *Model) Name() string
- func (m *Model) NumInputs() int
- func (m *Model) Outputs() (names []string, dshapes []DynamicShape)
- func (m *Model) PrintGraph(writer io.Writer) error
- func (m *Model) PrintGraphviz(writer io.Writer, targets ...string) error
- func (m *Model) PrintVariables(writer io.Writer) error
- func (m *Model) SaveToFile(path string) error
- func (m *Model) String() string
- func (m *Model) ValidateInputs(inputsShapes ...shapes.Shape) error
- func (m *Model) VariablesToContext(ctx *context.Context) error
- func (m *Model) WithInputsAsConstants(inputsAsConstants map[string]any) *Model
- func (m *Model) Write(w io.Writer) error
Constants
const UnnamedDynamicDimension = "?"
UnnamedDynamicDimension is a placeholder name for an unnamed dynamic dimension, which doesn't necessarily match any other dynamic dimension (in inputs or outputs).
Variables
var (
	GraphvizInputColor = "#FFF59E"
	GraphvizVarColor   = "#E0E0E0"
)
var ModelScope = "ONNX"
ModelScope is the default scope used for the ONNX model variables when converting to GoMLX.
Functions
func SafeVarName
func SafeVarName(onnxName string) (gomlxName string)
SafeVarName converts an ONNX variable name to a GoMLX safe variable name by replacing the scope separator with a "|".
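As a rough illustration, the replacement can be sketched with strings.ReplaceAll. This assumes the separator being replaced is "/" (GoMLX's context scope separator), and safeVarName is a hypothetical stand-in, not the package's own code:

```go
package main

import (
	"fmt"
	"strings"
)

// scopeSeparator is assumed to be "/", the separator GoMLX uses between context scopes.
const scopeSeparator = "/"

// safeVarName is a hypothetical re-implementation of SafeVarName: every scope
// separator in the ONNX name becomes "|", so the name is no longer split into
// nested scopes when used as a GoMLX variable name.
func safeVarName(onnxName string) (gomlxName string) {
	return strings.ReplaceAll(onnxName, scopeSeparator, "|")
}

func main() {
	fmt.Println(safeVarName("encoder/layer.0/attention/query/weight"))
	// Prints: encoder|layer.0|attention|query|weight
}
```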
func Shape (added in v0.1.3)
func Shape(proto *protos.TensorProto) (shape shapes.Shape, err error)
Shape converts an ONNX data type and shape to GoMLX shapes.Shape (it includes the dtype).
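The conversion combines two TensorProto fields: data_type (an enum) and dims. A minimal sketch, assuming a simplified shape type; simpleShape and shapeFromONNX are hypothetical names, and only a subset of the ONNX TensorProto.DataType enum values is covered:

```go
package main

import (
	"fmt"
	"strings"
)

// onnxDTypeNames maps some ONNX TensorProto.DataType enum values
// (as defined in onnx.proto) to Go-style dtype names.
var onnxDTypeNames = map[int32]string{
	1:  "float32", // FLOAT
	2:  "uint8",   // UINT8
	3:  "int8",    // INT8
	6:  "int32",   // INT32
	7:  "int64",   // INT64
	9:  "bool",    // BOOL
	10: "float16", // FLOAT16
	11: "float64", // DOUBLE
}

// simpleShape is a hypothetical stand-in for shapes.Shape: a dtype plus dimensions.
type simpleShape struct {
	DType      string
	Dimensions []int
}

func (s simpleShape) String() string {
	dims := make([]string, len(s.Dimensions))
	for i, d := range s.Dimensions {
		dims[i] = fmt.Sprint(d)
	}
	return fmt.Sprintf("(%s)[%s]", s.DType, strings.Join(dims, " "))
}

// shapeFromONNX combines the proto's data_type and dims into one shape,
// returning an error for data types it doesn't know.
func shapeFromONNX(dataType int32, dims []int64) (simpleShape, error) {
	name, ok := onnxDTypeNames[dataType]
	if !ok {
		return simpleShape{}, fmt.Errorf("unsupported ONNX data type %d", dataType)
	}
	out := simpleShape{DType: name, Dimensions: make([]int, len(dims))}
	for i, d := range dims {
		out.Dimensions[i] = int(d)
	}
	return out, nil
}

func main() {
	s, err := shapeFromONNX(1, []int64{2, 3})
	fmt.Println(s, err) // (float32)[2 3] <nil>
}
```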
func SparseShape (added in v0.1.3)
func SparseShape(proto *protos.SparseTensorProto) (shape shapes.Shape, err error)
SparseShape returns what would be the dense shape of an ONNX SparseTensor.
func TensorValueToONNX (added in v0.1.3)
func TensorValueToONNX(t *tensors.Tensor, proto *protos.TensorProto) (err error)
TensorValueToONNX copies the value of a GoMLX tensors.Tensor to the ONNX protos.TensorProto object, handling errors and the different data types.
Both tensors (GoMLX and ONNX) must already have the same shape.
Types
type DynamicShape
DynamicShape represents a shape for which some of the axes have unknown dimensions.
It is similar to a GoMLX Shape, but some of the dimensions may be -1, denoting an undefined (dynamic) dimension.
Dimensions may also be named, in which case shapes of inputs and outputs with the same name should match.
func (DynamicShape) Rank
func (dshape DynamicShape) Rank() int
Rank returns the DynamicShape's rank.
func (DynamicShape) String
func (dshape DynamicShape) String() string
String implements fmt.Stringer.
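A minimal sketch of how such a shape could be represented, with a Rank method and a String that prints dynamic axes by name. The dynamicShape type, its fields and its output format are hypothetical; the real DynamicShape's fields and String format aren't documented here:

```go
package main

import (
	"fmt"
	"strings"
)

// unnamedDynamicDimension mirrors the UnnamedDynamicDimension constant.
const unnamedDynamicDimension = "?"

// dynamicShape is a hypothetical sketch: a dtype name, one dimension per axis
// (-1 if unknown) and one name per axis (only used for the unknown ones).
type dynamicShape struct {
	DType      string
	Dimensions []int
	Names      []string
}

// Rank returns the number of axes.
func (d dynamicShape) Rank() int { return len(d.Dimensions) }

// String prints known dimensions as numbers and dynamic ones by their name.
func (d dynamicShape) String() string {
	parts := make([]string, d.Rank())
	for i, dim := range d.Dimensions {
		if dim < 0 {
			parts[i] = d.Names[i]
		} else {
			parts[i] = fmt.Sprint(dim)
		}
	}
	return fmt.Sprintf("(%s)[%s]", d.DType, strings.Join(parts, ", "))
}

func main() {
	ds := dynamicShape{
		DType:      "float32",
		Dimensions: []int{-1, -1, 768},
		Names:      []string{"batch_size", unnamedDynamicDimension, ""},
	}
	fmt.Println(ds.Rank(), ds) // 3 (float32)[batch_size, ?, 768]
}
```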
type Model
type Model struct {
	Proto                       protos.ModelProto
	InputsNames, OutputsNames   []string
	InputsShapes, OutputsShapes []DynamicShape
	// contains filtered or unexported fields
}
Model represents a parsed ONNX file.
func Parse
Parse parses an ONNX model into an internal representation that can be used to build a GoMLX graph.
func ReadFile
ReadFile parses an ONNX model file into an internal representation that can be used to build a GoMLX graph. Note that any large constants are converted to variables.
func (*Model) CallGraph
func (m *Model) CallGraph(ctx *context.Context, g *Graph, inputs map[string]*Node, outputNames ...string) (outputs []*Node)
CallGraph calls the ONNX graph, building it with GoMLX ops. This can be used for inference or training.
If the model has any variables, call Model.VariablesToContext first (only once) to upload all variable values from the ONNX model to the context -- or load them from a checkpoint if you saved one.
If the model has no variables, ctx can be nil.
The inputs (a map from input name to its graph.Node) can be given as normal input parameters to the graph or as static constants (see WithInputsAsConstants). Set inputs as constants if they are meant to be interpreted as constant (static) values that won't change across inference/training steps.
If outputNames is not given, it will output the model's registered outputs. Alternatively, you can select any list of node outputs to generate. It will return the values for the selected outputs.
The graph being built is given in g.
As in GoMLX graph functions, it panics (throws exceptions) in case of errors.
func (*Model) ContextToONNX (added in v0.1.3)
func (m *Model) ContextToONNX(ctx *context.Context) error
ContextToONNX converts the variables in the context back to the ONNX model. Do this before saving the ONNX model back to disk.
It's the inverse of VariablesToContext, and the context given must be set in the same scope as when VariablesToContext was first called.
Only those variables present in the original ONNX model are converted, so new variables (e.g., the moving averages of optimizers such as ADAM) are not converted.
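The filtering rule can be sketched as a set lookup. This assumes the context variable names have already been mapped back to ONNX names (see SafeVarName) and uses plain []float32 values; exportVariables is a hypothetical helper, not part of the package:

```go
package main

import (
	"fmt"
	"sort"
)

// exportVariables keeps only the variables whose names appear among the ONNX
// model's initializers; anything created later (e.g. optimizer state) is
// returned in skipped instead.
func exportVariables(onnxInitializers []string, ctxVars map[string][]float32) (exported map[string][]float32, skipped []string) {
	known := make(map[string]bool, len(onnxInitializers))
	for _, name := range onnxInitializers {
		known[name] = true
	}
	exported = make(map[string][]float32)
	for name, value := range ctxVars {
		if known[name] {
			exported[name] = value
		} else {
			skipped = append(skipped, name)
		}
	}
	sort.Strings(skipped) // map iteration order is random; sort for stable output
	return exported, skipped
}

func main() {
	ctxVars := map[string][]float32{"w": {1}, "b": {0}, "adam_m_w": {0.1}}
	exported, skipped := exportVariables([]string{"w", "b"}, ctxVars)
	fmt.Println(len(exported), skipped) // 2 [adam_m_w]
}
```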
func (*Model) Inputs
func (m *Model) Inputs() (names []string, dshapes []DynamicShape)
Inputs returns the names and DynamicShapes of the inputs.
func (*Model) Outputs
func (m *Model) Outputs() (names []string, dshapes []DynamicShape)
Outputs returns the names and DynamicShapes of the outputs.
func (*Model) PrintGraph
func (m *Model) PrintGraph(writer io.Writer) error
PrintGraph prints a more or less human-readable (or debuggable) version of the graph to the given writer.
func (*Model) PrintGraphviz
func (m *Model) PrintGraphviz(writer io.Writer, targets ...string) error
PrintGraphviz outputs the model graph using the "dot" language, starting from the target nodes and walking towards their dependencies.
If targets is left empty, it takes the default graph outputs as targets.
func (*Model) SaveToFile (added in v0.1.3)
func (m *Model) SaveToFile(path string) error
SaveToFile serializes the ONNX model to the given file.
This is useful if the model variables were updated (e.g., fine-tuned in GoMLX) and you want to save the model. See ContextToONNX to copy the variables in GoMLX's Context (presumably after some training/update) over to the ONNX model proto.
func (*Model) ValidateInputs
func (m *Model) ValidateInputs(inputsShapes ...shapes.Shape) error
ValidateInputs checks that the given input shapes are compatible with the model's input DynamicShapes.
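A minimal sketch of such a compatibility check, ignoring dtypes: ranks must agree, fixed axes must match exactly, and dynamic axes sharing a name (other than "?") must get the same size across all inputs. The dynShape type and the validate function are hypothetical:

```go
package main

import "fmt"

// dynShape is a hypothetical stand-in for DynamicShape: -1 marks a dynamic
// axis, and Names gives each axis' name ("?" when unnamed).
type dynShape struct {
	Dimensions []int
	Names      []string
}

// validate checks concrete input dimensions against the expected dynamic shapes.
func validate(expected []dynShape, actual [][]int) error {
	if len(expected) != len(actual) {
		return fmt.Errorf("got %d inputs, model expects %d", len(actual), len(expected))
	}
	named := make(map[string]int) // size seen so far for each named dynamic axis
	for i, want := range expected {
		got := actual[i]
		if len(got) != len(want.Dimensions) {
			return fmt.Errorf("input #%d: rank %d, expected %d", i, len(got), len(want.Dimensions))
		}
		for axis, dim := range want.Dimensions {
			if dim >= 0 { // fixed axis: must match exactly
				if got[axis] != dim {
					return fmt.Errorf("input #%d axis %d: got %d, expected %d", i, axis, got[axis], dim)
				}
				continue
			}
			name := want.Names[axis]
			if name == "?" { // unnamed: matches anything, independently
				continue
			}
			if prev, ok := named[name]; ok && prev != got[axis] {
				return fmt.Errorf("input #%d axis %d: %q is %d here but %d elsewhere", i, axis, name, got[axis], prev)
			}
			named[name] = got[axis]
		}
	}
	return nil
}

func main() {
	expected := []dynShape{
		{Dimensions: []int{-1, -1}, Names: []string{"batch", "seq"}},
		{Dimensions: []int{-1, -1}, Names: []string{"batch", "seq"}},
	}
	fmt.Println(validate(expected, [][]int{{2, 16}, {2, 16}})) // <nil>
	fmt.Println(validate(expected, [][]int{{2, 16}, {2, 8}}))  // error: "seq" differs
}
```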
func (*Model) VariablesToContext
func (m *Model) VariablesToContext(ctx *context.Context) error
VariablesToContext will create variables in the context (within scope ModelScope) from all variables present in the model initializer list.
Call this once in your context, before using the model with Model.CallGraph. Alternatively, if you have already checkpointed your model, load the variables from a checkpoint and don't call this.
See also ContextToONNX, if after converting and fine-tuning an ONNX model, you want to update its weights.
func (*Model) WithInputsAsConstants
func (m *Model) WithInputsAsConstants(inputsAsConstants map[string]any) *Model
WithInputsAsConstants marks inputs to be treated as constants, which don't vary across examples in training or inference. Call it immediately after creating the Model; later changes can cause inconsistencies.
This makes them constants in the graph, so they shouldn't be passed to CallGraph as inputs.
The value each input maps to is converted to a tensor with tensors.FromAnyValue.
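The resulting bookkeeping can be sketched as follows: inputs marked as constants must not be fed, and every other model input must be. checkFeeds is a hypothetical helper, and the constant's value is shown as a plain Go slice, the kind of value one would expect tensors.FromAnyValue to convert:

```go
package main

import "fmt"

// checkFeeds reports an error if a constant input is fed, or if a
// non-constant model input is missing from fed.
func checkFeeds(modelInputs []string, constants map[string]any, fed []string) error {
	fedSet := make(map[string]bool, len(fed))
	for _, name := range fed {
		if _, isConst := constants[name]; isConst {
			return fmt.Errorf("input %q is a constant and must not be fed", name)
		}
		fedSet[name] = true
	}
	for _, name := range modelInputs {
		if _, isConst := constants[name]; isConst {
			continue // baked into the graph
		}
		if !fedSet[name] {
			return fmt.Errorf("missing input %q", name)
		}
	}
	return nil
}

func main() {
	inputs := []string{"input_ids", "attention_mask", "token_type_ids"}
	constants := map[string]any{"token_type_ids": [][]int64{{0, 0, 0}}}
	fmt.Println(checkFeeds(inputs, constants, []string{"input_ids", "attention_mask"})) // <nil>
	fmt.Println(checkFeeds(inputs, constants, inputs)) // error: token_type_ids is a constant
}
```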
func (*Model) Write (added in v0.1.3)
func (m *Model) Write(w io.Writer) error
Write writes the ONNX model to the given writer (usually a file).
This is useful if the model variables were updated (e.g., fine-tuned in GoMLX) and you want to save the model. See ContextToONNX to copy the variables in GoMLX's Context (presumably after some training/update) over to the ONNX model proto.
See also Model.SaveToFile.