Documentation
Index
- type BatchNorm2dModule
- type Conv2dModule
- type ConvTranspose2dModule
- type FunctionalModule
- type IModule
- type LinearModule
- type Module
- func (m *Module) Apply(function func(IModule))
- func (m *Module) Buffers() []torch.Tensor
- func (m *Module) Init(outer IModule)
- func (m *Module) IsTraining() bool
- func (m *Module) Name() string
- func (m *Module) NamedBuffers() map[string]torch.Tensor
- func (m *Module) NamedParameters() map[string]torch.Tensor
- func (m *Module) Outer() IModule
- func (m *Module) Parameters() []torch.Tensor
- func (m *Module) SetStateDict(sd map[string]torch.Tensor) error
- func (m *Module) StateDict() map[string]torch.Tensor
- func (m *Module) To(device torch.Device, dtype ...int8)
- func (m *Module) Train(on bool)
- type SequentialModule
- type Visitor
Constants
This section is empty.
Variables
This section is empty.
Functions
This section is empty.
Types
type BatchNorm2dModule
type BatchNorm2dModule struct {
	Module
	NumFeatures       int64
	Eps               float64
	Momentum          float64
	Affine            bool
	TrackRunningStats bool
	Weight            torch.Tensor
	Bias              torch.Tensor
	RunningMean       torch.Tensor `gotorch:"buffer"`
	RunningVar        torch.Tensor `gotorch:"buffer"`
}
BatchNorm2dModule corresponds to torch.nn.BatchNorm2d.
func BatchNorm2d
func BatchNorm2d(numFeatures int64, eps, momentum float64, affine, trackRunningStats bool) *BatchNorm2dModule
BatchNorm2d creates a `BatchNorm2dModule` instance.
type Conv2dModule
type Conv2dModule struct {
	Module
	InChannels  int64
	OutChannels int64
	KernelSize  int64
	Stride      int64
	Padding     int64
	Dilation    int64
	Groups      int64
	PaddingMode string
	Weight      torch.Tensor
	Bias        torch.Tensor
}
Conv2dModule applies convolution over a 2D input.
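To see how `KernelSize`, `Stride`, `Padding`, and `Dilation` shape the result, the output height (or width) can be computed with the formula given in the torch.nn.Conv2d documentation: out = floor((in + 2*padding - dilation*(kernelSize - 1) - 1) / stride) + 1. The helper below is hypothetical and not part of this package; it assumes square, symmetric settings.

```go
package main

import "fmt"

// conv2dOutSize returns the output height (or width) of a 2D convolution,
// following the formula in the torch.nn.Conv2d documentation. Integer
// division acts as floor because the numerator is assumed non-negative.
func conv2dOutSize(in, kernelSize, stride, padding, dilation int64) int64 {
	return (in+2*padding-dilation*(kernelSize-1)-1)/stride + 1
}

func main() {
	fmt.Println(conv2dOutSize(32, 3, 1, 1, 1)) // 3x3, padding 1: size preserved
	fmt.Println(conv2dOutSize(32, 3, 2, 1, 1)) // stride 2: size roughly halved
	fmt.Println(conv2dOutSize(28, 5, 1, 0, 1)) // 5x5, no padding
}
```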
type ConvTranspose2dModule
type ConvTranspose2dModule struct {
	Module
	InChannels  int64
	OutChannels int64
	KernelSize  int64
	Stride      int64
	Padding     int64
	OutPadding  int64
	Groups      int64
	Dilation    int64
	PaddingMode string
	Weight      torch.Tensor
	Bias        torch.Tensor
}
ConvTranspose2dModule corresponds to torch.nn.ConvTranspose2d.
func ConvTranspose2d
func ConvTranspose2d(inChannels, outChannels, kernelSize, stride, padding, outPadding, groups int64, bias bool, dilation int64, paddingMode string) *ConvTranspose2dModule
ConvTranspose2d creates a `ConvTranspose2dModule`, corresponding to torch.nn.ConvTranspose2d.
TODO(qijun): current limitations:
- only the zero padding mode is supported;
- kernel size, stride, padding, and dilation must be symmetric (the same for height and width);
- output_size is not supported when forwarding.
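Since output_size cannot be passed when forwarding, the output size is fully determined by the constructor arguments. The torch.nn.ConvTranspose2d documentation gives it as out = (in - 1)*stride - 2*padding + dilation*(kernelSize - 1) + outPadding + 1; the hypothetical helper below evaluates that formula for the symmetric case. It shows how `outPadding` selects among output sizes that a plain stride would make ambiguous.

```go
package main

import "fmt"

// convTranspose2dOutSize returns the output height (or width) of a transposed
// 2D convolution with symmetric kernel/stride/padding/dilation, following the
// formula in the torch.nn.ConvTranspose2d documentation.
func convTranspose2dOutSize(in, kernelSize, stride, padding, outPadding, dilation int64) int64 {
	return (in-1)*stride - 2*padding + dilation*(kernelSize-1) + outPadding + 1
}

func main() {
	// 4x4 kernel, stride 2, padding 1: doubles 16 to 32.
	fmt.Println(convTranspose2dOutSize(16, 4, 2, 1, 0, 1))
	// 3x3 kernel, stride 2, padding 1: yields 31, or 32 with outPadding 1.
	fmt.Println(convTranspose2dOutSize(16, 3, 2, 1, 0, 1))
	fmt.Println(convTranspose2dOutSize(16, 3, 2, 1, 1, 1))
}
```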
type FunctionalModule
FunctionalModule wraps a function in a `Module`.
func Functional
func Functional(f func(torch.Tensor) torch.Tensor) *FunctionalModule
Functional returns a new `FunctionalModule` instance.
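Wrapping a stateless function in a module presumably lets it be used wherever a module is expected, for example as an activation step in a `SequentialModule`. The sketch below shows the idea with a hypothetical `Tensor` type (a `[]float64`) and hypothetical lowercase names; it is not the package's definition of `FunctionalModule`.

```go
package main

import "fmt"

// Tensor is a hypothetical stand-in for torch.Tensor so the sketch runs
// without libtorch.
type Tensor []float64

// functionalModule stores a function and calls it as its forward pass.
type functionalModule struct {
	f func(Tensor) Tensor
}

// functional wraps f, in the spirit of Functional.
func functional(f func(Tensor) Tensor) *functionalModule {
	return &functionalModule{f: f}
}

// Forward applies the wrapped function to its input.
func (m *functionalModule) Forward(x Tensor) Tensor {
	return m.f(x)
}

// relu is an example stateless function: it zeroes negative entries.
func relu(x Tensor) Tensor {
	out := make(Tensor, len(x))
	for i, v := range x {
		if v > 0 {
			out[i] = v
		}
	}
	return out
}

func main() {
	act := functional(relu)
	fmt.Println(act.Forward(Tensor{-1, 0, 2})) // [0 0 2]
}
```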
type IModule
type IModule interface {
	// Train corresponds to torch.nn.Module.train(bool). It affects only
	// certain modules, like Dropout and BatchNorm.
	Train(on bool)
	// IsTraining returns true if the module is in training mode.
	IsTraining() bool
	// To corresponds to torch.nn.Module.to(). It recursively casts all
	// parameters to the given `dtype` and `device`.
	To(device torch.Device, dtype ...int8)
	// StateDict mimics torch.nn.Module.state_dict().
	StateDict() map[string]torch.Tensor
	// SetStateDict mimics torch.nn.Module.load_state_dict().
	SetStateDict(sd map[string]torch.Tensor) error
	// Apply applies a function recursively to each module.
	Apply(f func(IModule))
	// Name returns the module's type name.
	Name() string
}
IModule is the interface implemented by all `Module`s.
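Several methods of IModule act on the whole module tree. One way to relate them is to build a recursive mode switch on top of a recursive visitor. The sketch below does that with a hypothetical, pared-down `iModule` interface whose `Children` and `SetTraining` methods do not exist in this package; how the real `Module` finds its submodules is not documented here, so treat this as an illustration of the pattern only.

```go
package main

import "fmt"

// iModule is a hypothetical, simplified module interface.
type iModule interface {
	Name() string
	Children() []iModule
	SetTraining(on bool)
	IsTraining() bool
}

// base is a hypothetical default implementation shared by the toy modules.
type base struct {
	name     string
	training bool
	children []iModule
}

func (b *base) Name() string { return b.name }

func (b *base) Children() []iModule { return b.children }

func (b *base) SetTraining(on bool) { b.training = on }

func (b *base) IsTraining() bool { return b.training }

// apply calls f on m and then, recursively, on every descendant (pre-order).
func apply(m iModule, f func(iModule)) {
	f(m)
	for _, c := range m.Children() {
		apply(c, f)
	}
}

// train switches m and all of its descendants to training or eval mode,
// in the spirit of torch.nn.Module.train(bool).
func train(m iModule, on bool) {
	apply(m, func(x iModule) { x.SetTraining(on) })
}

func main() {
	root := &base{name: "Sequential", children: []iModule{
		&base{name: "Linear"},
		&base{name: "BatchNorm2d"},
	}}
	train(root, true)
	apply(root, func(x iModule) { fmt.Println(x.Name(), x.IsTraining()) })
}
```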
type LinearModule
type LinearModule struct {
	Module
	InFeatures  int64
	OutFeatures int64
	Weight      torch.Tensor
	Bias        torch.Tensor
}
LinearModule applies a linear transformation with optional bias.
func Linear
func Linear(in, out int64, bias bool) *LinearModule
Linear creates a `LinearModule` instance.
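The linear transformation computed by torch.nn.Linear is y = x * W^T + b, where W has shape [OutFeatures, InFeatures]. The sketch below spells that out with nested slices, assuming `Weight` uses PyTorch's layout; a nil bias models `Linear(in, out, false)`. The `linearForward` helper is hypothetical.

```go
package main

import "fmt"

// linearForward computes y = x * W^T + b for a batch of row vectors.
// weight has shape [outFeatures][inFeatures] (PyTorch's layout), and a nil
// bias models a layer created without one.
func linearForward(x, weight [][]float64, bias []float64) [][]float64 {
	y := make([][]float64, len(x))
	for n, row := range x {
		y[n] = make([]float64, len(weight))
		for o, w := range weight {
			sum := 0.0
			for i, v := range row {
				sum += v * w[i]
			}
			if bias != nil {
				sum += bias[o]
			}
			y[n][o] = sum
		}
	}
	return y
}

func main() {
	// InFeatures = 3, OutFeatures = 2, batch of two inputs.
	w := [][]float64{{1, 0, -1}, {0.5, 0.5, 0.5}}
	b := []float64{1, 0}
	x := [][]float64{{1, 2, 3}, {0, 0, 0}}
	fmt.Println(linearForward(x, w, b))   // [[-1 3] [1 0]]
	fmt.Println(linearForward(x, w, nil)) // [[-2 3] [0 0]]
}
```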
type Module
type Module struct {
// contains filtered or unexported fields
}
Module provides the default implementation of `IModule`; concrete modules embed it.
func (*Module) Init
func (m *Module) Init(outer IModule)
Init initializes a `Module`. Using a `Module` that has not been `Init`ed will panic. Example:
type MyModel struct {
	Module
}

func NewMyModel() *MyModel {
	r := &MyModel{}
	r.Init(r)
	return r
}
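Passing the outer struct to `Init` is needed because of how Go embedding works: a method promoted from the embedded `Module` receives only a `*Module`, so without a stored reference it cannot see fields such as `Weight` or `Bias` on the outer struct. The toy below sketches that mechanism with reflection, assuming the real package does something similar. Its `Module`, `Tensor`, and `FieldNames` are hypothetical, and its `Init` takes `interface{}` rather than `IModule` to stay self-contained.

```go
package main

import (
	"fmt"
	"reflect"
	"sort"
)

// Tensor is a hypothetical stand-in for torch.Tensor.
type Tensor struct{ data []float64 }

// Module is a toy analogue of nn.Module that remembers its outer struct.
type Module struct {
	outer interface{}
}

// Init records the outermost struct that embeds this Module.
func (m *Module) Init(outer interface{}) { m.outer = outer }

// FieldNames lists the Tensor-typed fields of the outer struct, sorted.
// It panics if Init was never called, mirroring the documented behavior.
func (m *Module) FieldNames() []string {
	if m.outer == nil {
		panic("Module not initialized: call Init(outer) first")
	}
	t := reflect.ValueOf(m.outer).Elem().Type()
	var names []string
	for i := 0; i < t.NumField(); i++ {
		if t.Field(i).Type == reflect.TypeOf(Tensor{}) {
			names = append(names, t.Field(i).Name)
		}
	}
	sort.Strings(names)
	return names
}

type MyModel struct {
	Module
	Weight Tensor
	Bias   Tensor
	Steps  int // not a Tensor, so FieldNames skips it
}

func NewMyModel() *MyModel {
	r := &MyModel{}
	r.Init(r)
	return r
}

func main() {
	fmt.Println(NewMyModel().FieldNames()) // [Bias Weight]
}
```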
func (*Module) IsTraining
func (m *Module) IsTraining() bool
IsTraining returns true if the module is in training mode.
func (*Module) NamedBuffers
func (m *Module) NamedBuffers() map[string]torch.Tensor
NamedBuffers returns the module's non-trainable tensors (buffers), recursively, keyed by name.
func (*Module) NamedParameters
func (m *Module) NamedParameters() map[string]torch.Tensor
NamedParameters returns the module's trainable parameters, recursively, keyed by name.
func (*Module) Parameters
func (m *Module) Parameters() []torch.Tensor
Parameters returns the module's trainable parameters, recursively.
func (*Module) SetStateDict
func (m *Module) SetStateDict(sd map[string]torch.Tensor) error
SetStateDict sets all of the module's tensor fields to the corresponding values in sd.
func (*Module) StateDict
func (m *Module) StateDict() map[string]torch.Tensor
StateDict mimics torch.nn.Module.state_dict(): it returns the module's parameters and buffers, keyed by their unique names.
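Because a state dict holds parameters and buffers together, saving and restoring a model reduces to a map round trip. The sketch below models a module's state as a plain `map[string]Tensor` with a hypothetical `Tensor` type. Its strict key checking is borrowed from PyTorch's `load_state_dict(strict=True)`; whether `SetStateDict` performs the same checks is not stated here, so treat that policy as an assumption.

```go
package main

import "fmt"

// Tensor is a hypothetical stand-in for torch.Tensor.
type Tensor []float64

// stateDict merges parameters and buffers into one map keyed by unique name.
func stateDict(params, buffers map[string]Tensor) map[string]Tensor {
	sd := make(map[string]Tensor, len(params)+len(buffers))
	for k, v := range params {
		sd[k] = v
	}
	for k, v := range buffers {
		sd[k] = v
	}
	return sd
}

// setStateDict copies every value in sd into dst. It rejects missing and
// unexpected keys before changing anything (an assumed, strict policy).
func setStateDict(dst, sd map[string]Tensor) error {
	for k := range dst {
		if _, ok := sd[k]; !ok {
			return fmt.Errorf("missing key %q", k)
		}
	}
	for k := range sd {
		if _, ok := dst[k]; !ok {
			return fmt.Errorf("unexpected key %q", k)
		}
	}
	for k, v := range sd {
		dst[k] = append(Tensor(nil), v...) // copy so dst does not alias sd
	}
	return nil
}

func main() {
	saved := stateDict(
		map[string]Tensor{"Weight": {1, 2}, "Bias": {0}},
		map[string]Tensor{"RunningMean": {0.5}},
	)

	// A freshly constructed module with the same layout but different values.
	fresh := map[string]Tensor{"Weight": {0, 0}, "Bias": {0}, "RunningMean": {0}}
	if err := setStateDict(fresh, saved); err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(fresh["Weight"], fresh["RunningMean"]) // [1 2] [0.5]

	delete(saved, "Bias")
	fmt.Println(setStateDict(fresh, saved)) // missing key "Bias"
}
```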
type SequentialModule
SequentialModule is a list of `Module`s that acts as a `Module` itself.
func Sequential
func Sequential(modules ...IModule) *SequentialModule
Sequential returns a new `SequentialModule` instance.
func (*SequentialModule) Forward
func (s *SequentialModule) Forward(inputs ...interface{}) interface{}
Forward feeds `inputs` to the first module, passes each module's output as the input to the next module, and returns the last module's output.
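Because `Forward` takes and returns `interface{}`, the chaining rule can be illustrated without tensors. In the toy below, the hypothetical `forwarder`, `fn`, and `sequential` names stand in for the real types and `int` values stand in for tensors. The first module receives all inputs and each later module receives the previous output, which is one reading of the description above. Returning nil for an empty list is an arbitrary choice of the sketch.

```go
package main

import "fmt"

// forwarder is a hypothetical interface matching the shape of
// SequentialModule.Forward.
type forwarder interface {
	Forward(inputs ...interface{}) interface{}
}

// fn adapts a plain function to forwarder.
type fn func(inputs ...interface{}) interface{}

func (f fn) Forward(inputs ...interface{}) interface{} { return f(inputs...) }

// sequential chains modules: the first receives all inputs, and each later
// module receives the previous module's single output.
type sequential struct{ modules []forwarder }

func (s *sequential) Forward(inputs ...interface{}) interface{} {
	if len(s.modules) == 0 {
		return nil // arbitrary choice for an empty chain
	}
	out := s.modules[0].Forward(inputs...)
	for _, m := range s.modules[1:] {
		out = m.Forward(out)
	}
	return out
}

func main() {
	add := fn(func(in ...interface{}) interface{} { return in[0].(int) + in[1].(int) })
	double := fn(func(in ...interface{}) interface{} { return in[0].(int) * 2 })
	inc := fn(func(in ...interface{}) interface{} { return in[0].(int) + 1 })
	s := &sequential{modules: []forwarder{add, double, inc}}
	fmt.Println(s.Forward(3, 4)) // (3+4)*2+1 = 15
}
```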