torch

v0.0.0-...-5e92bb2
Warning: This package is not in the latest version of its module.
Published: Feb 25, 2019 License: Apache-2.0 Imports: 7 Imported by: 0

README

WORK IN PROGRESS... USE AT OWN RISK :-)

go-torch

LibTorch (PyTorch) bindings for Golang. The library is designed first and foremost for running inference against serialized models exported from the Python version of PyTorch. It can also be used to compile TorchScript applications directly from Go.

Installing

$ go get github.com/orktes/go-torch

Usage

go-torch requires the LibTorch shared library to be available. For more information, refer to https://pytorch.org/cppdocs/. There is also an example Dockerfile, which is used to run the library's tests.

import (
    "github.com/orktes/go-torch"
)
Creating Tensors

Supported scalar types:

  • torch.Byte (Go type uint8)
  • torch.Char (Go type int8)
  • torch.Int (Go type int32)
  • torch.Long (Go type int64)
  • torch.Float (Go type float32)
  • torch.Double (Go type float64)

matrix := [][]float32{
    {1, 2, 3},
    {4, 5, 6},
}
tensor, _ := torch.NewTensor(matrix)
tensor.Shape() // [2 3]
tensor.DType() // torch.Float
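The scalar type list above amounts to a lookup from a Go element type to a torch DType. As a rough illustration (not go-torch's actual conversion code), the hypothetical helper `scalarTypeName` below uses the standard `reflect` package to unwrap nested slices or arrays and return the matching DType name as a string. Any other element type, such as plain `int`, has no entry in the list.

```go
package main

import (
	"fmt"
	"reflect"
)

// scalarTypeName is a hypothetical helper (not part of go-torch) that maps
// the Go element type of a scalar, slice, or array to the name of the
// matching torch DType from the supported scalar types list.
func scalarTypeName(value interface{}) (string, bool) {
	t := reflect.TypeOf(value)
	if t == nil {
		return "", false // untyped nil has no element type
	}
	// Unwrap nested slices and arrays down to the element type.
	for t.Kind() == reflect.Slice || t.Kind() == reflect.Array {
		t = t.Elem()
	}
	switch t.Kind() {
	case reflect.Uint8:
		return "Byte", true
	case reflect.Int8:
		return "Char", true
	case reflect.Int32:
		return "Int", true
	case reflect.Int64:
		return "Long", true
	case reflect.Float32:
		return "Float", true
	case reflect.Float64:
		return "Double", true
	}
	return "", false
}

func main() {
	name, ok := scalarTypeName([][]float32{{1, 2, 3}, {4, 5, 6}})
	fmt.Println(name, ok) // Float true
	_, ok = scalarTypeName([]int{1, 2})
	fmt.Println(ok) // false: plain int is not in the list
}
```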
Using serialized PyTorch models

For instructions on exporting models from PyTorch, refer to the PyTorch documentation.

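// Load a serialized model exported from Python
module, err := torch.LoadJITModule("model.pt")
if err != nil {
    log.Fatal(err)
}

// Create a 1x3 input tensor
inputTensor, _ := torch.NewTensor([][]float32{
    {1, 2, 3},
})

// Forward propagation; the result is a *torch.Tensor or a torch.Tuple
res, err := module.Forward(inputTensor)
if err != nil {
    log.Fatal(err)
}
if t, ok := res.(*torch.Tensor); ok {
    fmt.Println(t.Value())
}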

Using TorchScript

For the TorchScript language itself, refer to the TorchScript documentation.

Currently supported input and output types:

  • Tensor
  • Tuple (of Tensor and/or nested Tuples)
sumScript := `
def sum(a, b):
    return a + b
`

// Compile TorchScript
module, _ := torch.CompileTorchScript(sumScript)

// Create inputs
a, _ := torch.NewTensor([]float32{1})
b, _ := torch.NewTensor([]float32{2})

res, _ := module.RunMethod("sum", a, b)
fmt.Printf("[1] + [2] = %+v\n", res.(*torch.Tensor).Value())
// output: [1] + [2] = [3]
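The supported input and output types listed above form a recursive rule: a value is valid if it is a tensor, or a tuple whose elements are themselves valid. The following sketch checks that rule in plain Go without go-torch. `Tensor` and `Tuple` are local stand-ins for `*torch.Tensor` and `torch.Tuple`, and `validateIO` is a hypothetical helper, so the library's own checks may behave differently.

```go
package main

import "fmt"

// Tensor is a local stand-in for *torch.Tensor (assumption for illustration).
type Tensor struct{ data []float32 }

// Tuple mirrors torch.Tuple, which is declared as []interface{}.
type Tuple []interface{}

// validateIO is a hypothetical helper: it reports whether v is a tensor or a
// (possibly nested) tuple whose leaves are all tensors.
func validateIO(v interface{}) error {
	switch x := v.(type) {
	case *Tensor:
		if x == nil {
			return fmt.Errorf("nil tensor")
		}
		return nil
	case Tuple:
		// Recurse into every element so nested tuples are checked too.
		for i, elem := range x {
			if err := validateIO(elem); err != nil {
				return fmt.Errorf("tuple element %d: %v", i, err)
			}
		}
		return nil
	default:
		return fmt.Errorf("unsupported type %T", v)
	}
}

func main() {
	a := &Tensor{data: []float32{1}}
	b := &Tensor{data: []float32{2}}
	fmt.Println(validateIO(Tuple{a, Tuple{b, a}})) // <nil>
	fmt.Println(validateIO(Tuple{a, "b"}))         // tuple element 1: unsupported type string
}
```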

Acknowledgements

Much of the functionality for converting Go types to PyTorch tensors is a shameless copy of what Google does in its Go TensorFlow bindings, so a big part of the credit definitely goes to The TensorFlow Authors.

LICENSE

See the LICENSE file (Apache-2.0).

Documentation

Constants

This section is empty.

Variables

This section is empty.

Functions

func PrintTensors

func PrintTensors(inputs ...*Tensor)

PrintTensors prints the contents of the given tensors.

Types

type DType

type DType C.Torch_DataType

DType is the scalar data type of a tensor.

const (
	// Byte is a byte tensor (Go type uint8)
	Byte DType = C.Torch_Byte
	// Char is a char tensor (Go type int8)
	Char DType = C.Torch_Char
	// Int is an int tensor (Go type int32)
	Int DType = C.Torch_Int
	// Long is a long tensor (Go type int64)
	Long DType = C.Torch_Long
	// Float is a float tensor (Go type float32)
	Float DType = C.Torch_Float
	// Double is a double tensor (Go type float64)
	Double DType = C.Torch_Double
)

type Error

type Error struct {
	// contains filtered or unexported fields
}

Error is the error type returned by torch functions.

func (*Error) Error

func (te *Error) Error() string

type JITModule

type JITModule struct {
	// contains filtered or unexported fields
}

JITModule is a JIT-compiled PyTorch module.

func CompileTorchScript

func CompileTorchScript(torchScript string) (*JITModule, error)

CompileTorchScript compiles TorchScript and returns a *JITModule

Example
module, _ := torch.CompileTorchScript(`
		def sum(a, b):
			return a + b
	`)

a, _ := torch.NewTensor([]float32{1})
b, _ := torch.NewTensor([]float32{2})

result, _ := module.RunMethod("sum", a, b)
fmt.Printf("[1] + [2] = %+v\n", result.(*torch.Tensor).Value())
Output:

[1] + [2] = [3]

func LoadJITModule

func LoadJITModule(path string) (*JITModule, error)

LoadJITModule loads a serialized module from the file at path.

func (*JITModule) Forward

func (m *JITModule) Forward(inputs ...interface{}) (interface{}, error)

Forward executes the module's forward method (forward propagation).

func (*JITModule) GetMethod

func (m *JITModule) GetMethod(method string) (*JITModuleMethod, error)

GetMethod returns a method from a JITModule

func (*JITModule) GetMethodNames

func (m *JITModule) GetMethodNames() []string

GetMethodNames returns the names of all methods in the module.

func (*JITModule) RunMethod

func (m *JITModule) RunMethod(method string, inputs ...interface{}) (interface{}, error)

RunMethod executes the given method with tensors or tuples as input.

func (*JITModule) Save

func (m *JITModule) Save(path string) error

Save saves the module to the given path.

type JITModuleMethod

type JITModuleMethod struct {
	Module *JITModule
	Name   string
	// contains filtered or unexported fields
}

JITModuleMethod is a single method of a JITModule.

func (*JITModuleMethod) Arguments

func (m *JITModuleMethod) Arguments() []JITModuleMethodArgument

Arguments returns the method's arguments, as described by its schema.

func (*JITModuleMethod) Returns

Returns returns the method's return type information, as described by its schema.

func (*JITModuleMethod) Run

func (m *JITModuleMethod) Run(inputs ...interface{}) (interface{}, error)

Run executes the method with tensors as input.

type JITModuleMethodArgument

type JITModuleMethodArgument struct {
	Name string
	Type string
}

JITModuleMethodArgument contains information about a single method argument.
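Since JITModuleMethodArgument exposes Name and Type, a method's schema can be rendered as a readable signature. The sketch below uses a local `MethodArgument` struct with the same two fields so that it runs without LibTorch. `formatSignature` is a hypothetical helper, and the `Tensor` type strings in the example are assumed rather than taken from a real schema.

```go
package main

import (
	"fmt"
	"strings"
)

// MethodArgument mirrors the exported fields of torch.JITModuleMethodArgument.
type MethodArgument struct {
	Name string
	Type string
}

// formatSignature is a hypothetical helper that renders a method name and its
// schema arguments as "name(arg: Type, ...)".
func formatSignature(method string, args []MethodArgument) string {
	parts := make([]string, len(args))
	for i, a := range args {
		parts[i] = a.Name + ": " + a.Type
	}
	return method + "(" + strings.Join(parts, ", ") + ")"
}

func main() {
	// The type strings are assumed; real values come from the method schema.
	args := []MethodArgument{{Name: "a", Type: "Tensor"}, {Name: "b", Type: "Tensor"}}
	fmt.Println(formatSignature("sum", args)) // sum(a: Tensor, b: Tensor)
}
```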

type Tensor

type Tensor struct {
	// contains filtered or unexported fields
}

Tensor holds a multi-dimensional array of elements of a single data type.

func NewTensor

func NewTensor(value interface{}) (*Tensor, error)

NewTensor converts from a Go value to a Tensor. Valid values are scalars, slices, and arrays. Every element of a slice must have the same length so that the resulting Tensor has a valid shape.

func NewTensorWithShape

func NewTensorWithShape(value interface{}, shape []int64, dt DType) (*Tensor, error)

NewTensorWithShape converts a one-dimensional Go array or slice into a Tensor with the given shape and data type.
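Because NewTensorWithShape takes flat data plus a shape, the number of elements must equal the product of the dimensions. The sketch below shows that check and how a multi-dimensional index maps into the flat slice. It assumes row-major (C-contiguous) ordering, which is LibTorch's default for contiguous tensors; the helpers `numElements` and `flatIndex` are hypothetical and not part of go-torch.

```go
package main

import "fmt"

// numElements returns the product of the dimensions in shape
// (1 for an empty shape, i.e. a scalar).
func numElements(shape []int64) int64 {
	n := int64(1)
	for _, d := range shape {
		n *= d
	}
	return n
}

// flatIndex is a hypothetical helper that maps a multi-dimensional index to
// an offset in a flat slice, assuming row-major (C-contiguous) order.
func flatIndex(shape, index []int64) (int64, error) {
	if len(index) != len(shape) {
		return 0, fmt.Errorf("index has %d dims, shape has %d", len(index), len(shape))
	}
	offset := int64(0)
	for i, d := range shape {
		if index[i] < 0 || index[i] >= d {
			return 0, fmt.Errorf("index %d out of range for dim %d of size %d", index[i], i, d)
		}
		// Horner-style accumulation: the last dimension varies fastest.
		offset = offset*d + index[i]
	}
	return offset, nil
}

func main() {
	data := []float32{1, 2, 3, 4, 5, 6}
	shape := []int64{2, 3}
	// The flat slice must hold exactly as many elements as the shape describes.
	if int64(len(data)) != numElements(shape) {
		panic("data length does not match shape")
	}
	off, _ := flatIndex(shape, []int64{1, 2})
	fmt.Println(off, data[off]) // 5 6
}
```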

func (*Tensor) DType

func (t *Tensor) DType() DType

DType returns the tensor's data type.

func (*Tensor) Shape

func (t *Tensor) Shape() []int64

Shape returns the tensor's shape.

func (*Tensor) Value

func (t *Tensor) Value() interface{}

Value returns the tensor's value as a Go type.

type Tuple

type Tuple []interface{}

Tuple is a tuple type.

func NewTuple

func NewTuple(vals ...interface{}) (Tuple, error)

NewTuple returns a new tuple containing the given values (Go types, torch.Tensor, or torch.Tuple).

func (Tuple) Get

func (t Tuple) Get(index int) interface{}

Get returns the value at the given tuple index, or nil otherwise.
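The behavior of Get can be sketched with a local copy of the Tuple declaration. This assumes "otherwise" means an index outside the range `[0, len(t))`; the method body below is a sketch of the documented behavior, not go-torch's source.

```go
package main

import "fmt"

// Tuple mirrors the declaration of torch.Tuple.
type Tuple []interface{}

// Get returns the value at index, or nil when index is out of range.
// This is a sketch of the documented behavior, not go-torch's source.
func (t Tuple) Get(index int) interface{} {
	if index < 0 || index >= len(t) {
		return nil
	}
	return t[index]
}

func main() {
	t := Tuple{"first", Tuple{1, 2}}
	fmt.Println(t.Get(0)) // first
	// Nested tuples are retrieved with a type assertion.
	if inner, ok := t.Get(1).(Tuple); ok {
		fmt.Println(inner.Get(1)) // 2
	}
	fmt.Println(t.Get(5)) // <nil>
}
```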
