Documentation ¶
Index ¶
- func PrintTensors(inputs ...*Tensor)
- type DType
- type Error
- type JITModule
- func (m *JITModule) Forward(inputs ...interface{}) (interface{}, error)
- func (m *JITModule) GetMethod(method string) (*JITModuleMethod, error)
- func (m *JITModule) GetMethodNames() []string
- func (m *JITModule) RunMethod(method string, inputs ...interface{}) (interface{}, error)
- func (m *JITModule) Save(path string) error
- type JITModuleMethod
- type JITModuleMethodArgument
- type Tensor
- type Tuple
Examples ¶
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
Types ¶
type DType ¶
type DType C.Torch_DataType
DType is the scalar data type of a tensor's elements.
const ( // Byte byte tensors (go type uint8) Byte DType = C.Torch_Byte // Char char tensor (go type int8) Char DType = C.Torch_Char // Int int tensor (go type int32) Int DType = C.Torch_Int // Long long tensor (go type int64) Long DType = C.Torch_Long // Float tensor (go type float32) Float DType = C.Torch_Float // Double tensor (go type float64) Double DType = C.Torch_Double )
type Error ¶
type Error struct {
// contains filtered or unexported fields
}
Error is the error type returned by torch functions.
type JITModule ¶
type JITModule struct {
// contains filtered or unexported fields
}
JITModule is a JIT-compiled PyTorch module.
func CompileTorchScript ¶
CompileTorchScript compiles TorchScript and returns a *JITModule
Example ¶
module, _ := torch.CompileTorchScript(` def sum(a, b): return a + b `) a, _ := torch.NewTensor([]float32{1}) b, _ := torch.NewTensor([]float32{2}) result, _ := module.RunMethod("sum", a, b) fmt.Printf("[1] + [2] = %+v\n", result.(*torch.Tensor).Value())
Output: [1] + [2] = [3]
func LoadJITModule ¶
LoadJITModule loads a module from a file.
func (*JITModule) GetMethod ¶
func (m *JITModule) GetMethod(method string) (*JITModuleMethod, error)
GetMethod returns a method from a JITModule
func (*JITModule) GetMethodNames ¶
GetMethodNames returns all method names from the module
type JITModuleMethod ¶
type JITModuleMethod struct { Module *JITModule Name string // contains filtered or unexported fields }
JITModuleMethod is a single method from a JITModule.
func (*JITModuleMethod) Arguments ¶
func (m *JITModuleMethod) Arguments() []JITModuleMethodArgument
Arguments returns the method's arguments as described by the method schema.
func (*JITModuleMethod) Returns ¶
func (m *JITModuleMethod) Returns() []JITModuleMethodArgument
Returns returns the method's return type information as described by the method schema.
func (*JITModuleMethod) Run ¶
func (m *JITModuleMethod) Run(inputs ...interface{}) (interface{}, error)
Run executes the given method with tensors as input.
type JITModuleMethodArgument ¶
JITModuleMethodArgument contains information about a single method argument.
type Tensor ¶
type Tensor struct {
// contains filtered or unexported fields
}
Tensor holds a multi-dimensional array of elements of a single data type.
func NewTensor ¶
NewTensor converts from a Go value to a Tensor. Valid values are scalars, slices, and arrays. Every element of a slice must have the same length so that the resulting Tensor has a valid shape.
func NewTensorWithShape ¶
NewTensorWithShape converts a single-dimensional Go array or slice into a Tensor with the given shape.