Documentation ¶
Overview ¶
Package inference allows users to do inference through tflite (tf, pytorch, and others may be supported in the future)
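A minimal end-to-end sketch of the intended flow, assuming this package is imported as inference (the import path and model path below are placeholders, not part of this package):

package main

import (
    "fmt"
    "log"

    "go.viam.com/rdk/ml/inference" // assumed import path; adjust to your module
)

func main() {
    // Create a loader with default settings and load a tflite model from disk.
    loader, err := inference.NewDefaultTFLiteModelLoader()
    if err != nil {
        log.Fatal(err)
    }
    model, err := loader.Load("example.tflite") // placeholder model path
    if err != nil {
        log.Fatal(err)
    }
    defer model.Close()

    // Size an input buffer from the model's reported dimensions and run one inference.
    in := make([]uint8, model.Info.InputHeight*model.Info.InputWidth*model.Info.InputChannels)
    out, err := model.Infer(in)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("got %d output tensors\n", len(out))
}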
Index ¶
Constants ¶
const (
    UInt8   = InTensorType("UInt8")
    Float32 = InTensorType("Float32")
)
UInt8 and Float32 are the currently supported input tensor types.
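A small sketch of how a caller might guard against unsupported types before building an input buffer; the helper name is hypothetical and assumes this package is imported as inference:

// supportedTensorType reports whether an input tensor type is one the
// package currently supports.
func supportedTensorType(t inference.InTensorType) bool {
    return t == inference.UInt8 || t == inference.Float32
}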
Variables ¶
This section is empty.
Functions ¶
func FailedToGetError ¶
FailedToGetError is the default error message for when expected information fails to be fetched.
func FailedToLoadError ¶
FailedToLoadError is the default error message for when expected resources for inference fail to load.
func MetadataDoesNotExistError ¶
func MetadataDoesNotExistError() error
MetadataDoesNotExistError returns a metadata does not exist error.
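A hedged sketch of wrapping metadata retrieval so that a missing metadata section surfaces as this package's standard error; the helper name is hypothetical and assumes the inference and tflite metadata packages are imported:

// describeModel returns a model's metadata, converting a missing-metadata
// result into the package's standard error.
func describeModel(model *inference.TFLiteStruct) (*metadata.ModelMetadataT, error) {
    meta, err := model.GetMetadata()
    if err != nil {
        return nil, err
    }
    if meta == nil {
        return nil, inference.MetadataDoesNotExistError()
    }
    return meta, nil
}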
Types ¶
type InTensorType ¶
type InTensorType string
InTensorType is a wrapper around a string that details the allowed input tensor types.
type Interpreter ¶
type Interpreter interface {
    AllocateTensors() tflite.Status
    Invoke() tflite.Status
    GetOutputTensorCount() int
    GetInputTensorCount() int
    GetInputTensor(i int) *tflite.Tensor
    GetOutputTensor(i int) *tflite.Tensor
    Delete()
}
Interpreter interface holds methods used by a tflite interpreter.
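Because Interpreter is an interface, helpers can be written against it without caring which concrete tflite interpreter backs it. A minimal sketch (the function name is hypothetical):

// tensorCounts reports how many input and output tensors an interpreter
// exposes after its tensors have been allocated.
func tensorCounts(interp inference.Interpreter) (inputs, outputs int) {
    interp.AllocateTensors()
    return interp.GetInputTensorCount(), interp.GetOutputTensorCount()
}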
type MLModel ¶
type MLModel interface {
    // Infer takes an already ordered input tensor as an array,
    // and makes an inference on the model, returning an output tensor map.
    Infer(inputTensor interface{}) (config.AttributeMap, error)

    // GetMetadata gets the entire model metadata structure from file.
    GetMetadata() (interface{}, error)

    // Close closes the model and interpreter that allow inferences to be made, freeing up memory.
    // All models must be closed when done being used.
    Close() error
}
MLModel represents a trained machine learning model.
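Code that only needs to run inferences can accept the MLModel interface rather than a concrete type. A sketch with a hypothetical helper, assuming the inference and config packages are imported:

// runOnce makes a single inference against any MLModel implementation and
// closes the model afterwards.
func runOnce(model inference.MLModel, input interface{}) (config.AttributeMap, error) {
    defer model.Close()
    return model.Infer(input)
}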
type TFLiteInfo ¶
type TFLiteInfo struct {
    InputHeight       int
    InputWidth        int
    InputChannels     int
    InputShape        []int
    InputTensorType   InTensorType
    InputTensorCount  int
    OutputTensorCount int
    OutputTensorTypes []string
}
TFLiteInfo holds information about a model that is useful for creating the input tensor bytes.
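A sketch of using these fields to size a raw input buffer before calling Infer; the bytes-per-element mapping is an assumption based on the supported InTensorType constants:

// inputByteSize computes how many bytes a raw input buffer needs,
// based on the model's reported shape and input tensor type.
func inputByteSize(info *inference.TFLiteInfo) int {
    elems := 1
    for _, dim := range info.InputShape {
        elems *= dim
    }
    if info.InputTensorType == inference.Float32 {
        return elems * 4 // four bytes per float32 element
    }
    return elems // one byte per uint8 element
}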
type TFLiteModelLoader ¶
type TFLiteModelLoader struct {
// contains filtered or unexported fields
}
TFLiteModelLoader holds functions that set up a tflite model to be used.
func NewDefaultTFLiteModelLoader ¶
func NewDefaultTFLiteModelLoader() (*TFLiteModelLoader, error)
NewDefaultTFLiteModelLoader returns the default loader when using tflite.
func NewTFLiteModelLoader ¶
func NewTFLiteModelLoader(numThreads int) (*TFLiteModelLoader, error)
NewTFLiteModelLoader returns a loader that allows you to set the number of threads used by tflite.
func (TFLiteModelLoader) Load ¶
func (loader TFLiteModelLoader) Load(modelPath string) (*TFLiteStruct, error)
Load returns a TFLite struct that is ready to be used for inferences.
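A sketch combining the thread-count constructor with Load; the thread count and helper name are placeholders:

// loadModel builds a loader pinned to four interpreter threads and loads
// the model at the given path.
func loadModel(path string) (*inference.TFLiteStruct, error) {
    loader, err := inference.NewTFLiteModelLoader(4)
    if err != nil {
        return nil, err
    }
    return loader.Load(path)
}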
type TFLiteStruct ¶
type TFLiteStruct struct {
    Info *TFLiteInfo
    // contains filtered or unexported fields
}
TFLiteStruct holds the model, the interpreter, and associated information of a tflite model in Go.
func (*TFLiteStruct) Close ¶
func (model *TFLiteStruct) Close() error
Close should be called when done using the model in order to delete the related model and interpreter and free their memory.
func (*TFLiteStruct) GetMetadata ¶
func (model *TFLiteStruct) GetMetadata() (*metadata.ModelMetadataT, error)
GetMetadata provides the metadata information based on the model flatbuffer file.
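A sketch reading a value out of the returned metadata; the Name field is an assumption based on the standard tflite metadata flatbuffer schema and may differ in the generated metadata package:

// modelName returns the name recorded in a model's embedded metadata.
func modelName(model *inference.TFLiteStruct) (string, error) {
    meta, err := model.GetMetadata()
    if err != nil {
        return "", err
    }
    return meta.Name, nil // Name comes from the tflite ModelMetadataT schema
}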
func (*TFLiteStruct) Infer ¶
func (model *TFLiteStruct) Infer(inputTensor interface{}) ([]interface{}, error)
Infer takes an input array of the desired type and returns an array of the output tensors.
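A sketch of a full inference call against a loaded TFLiteStruct; the uint8 input assumes an image-style model whose input tensor type is UInt8, and the helper name is hypothetical (assumes the errors package is imported):

// classify runs one inference on raw image bytes and returns the first
// output tensor.
func classify(model *inference.TFLiteStruct, imgBytes []uint8) (interface{}, error) {
    outputs, err := model.Infer(imgBytes)
    if err != nil {
        return nil, err
    }
    if len(outputs) == 0 {
        return nil, errors.New("model returned no output tensors")
    }
    return outputs[0], nil
}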