nn

package
v0.0.0-...-9afed2f Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Oct 28, 2020 License: MIT Imports: 7 Imported by: 2

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type BatchNorm2dModule

type BatchNorm2dModule struct {
	Module
	NumFeatures       int64
	Eps               float64
	Momentum          float64
	Affine            bool
	TrackRunningStats bool
	Weight            torch.Tensor
	Bias              torch.Tensor
	RunningMean       torch.Tensor `gotorch:"buffer"`
	RunningVar        torch.Tensor `gotorch:"buffer"`
}

BatchNorm2dModule corresponds to torch.nn.BatchNorm2d.

func BatchNorm2d

func BatchNorm2d(numFeatures int64, eps, momentum float64,
	affine, trackRunningStats bool) *BatchNorm2dModule

BatchNorm2d creates a `BatchNorm2dModule` instance.

func (*BatchNorm2dModule) Forward

func (b *BatchNorm2dModule) Forward(x torch.Tensor) torch.Tensor

Forward applies 2D batch normalization to the input tensor x and returns the result.

type Conv2dModule

type Conv2dModule struct {
	Module
	InChannels  int64
	OutChannels int64
	KernelSize  int64
	Stride      int64
	Padding     int64
	Dilation    int64
	Groups      int64
	PaddingMode string
	Weight      torch.Tensor
	Bias        torch.Tensor
}

Conv2dModule applies convolution over a 2D input.

func Conv2d

func Conv2d(inChannels, outChannels, kernelSize, stride, padding, dilation,
	groups int64, bias bool, paddingMode string) *Conv2dModule

Conv2d creates a `Conv2dModule` instance. TODO(qijun): only the zero padding mode is supported, and kernel size, stride, padding, and dilation must be symmetric (a single value used for both spatial dimensions).

func (*Conv2dModule) Forward

func (c *Conv2dModule) Forward(x torch.Tensor) torch.Tensor

Forward applies the 2D convolution to the input tensor x and returns the result.

type ConvTranspose2dModule

type ConvTranspose2dModule struct {
	Module
	InChannels  int64
	OutChannels int64
	KernelSize  int64
	Stride      int64
	Padding     int64
	OutPadding  int64
	Groups      int64
	Dilation    int64
	PaddingMode string
	Weight      torch.Tensor
	Bias        torch.Tensor
}

ConvTranspose2dModule corresponds to torch.nn.ConvTranspose2d

func ConvTranspose2d

func ConvTranspose2d(inChannels, outChannels, kernelSize, stride, padding,
	outPadding, groups int64, bias bool, dilation int64, paddingMode string) *ConvTranspose2dModule

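ConvTranspose2d creates a `ConvTranspose2dModule` instance, corresponding to torch.nn.ConvTranspose2d. TODO(qijun): only the zero padding mode is supported; kernel size, stride, padding, and dilation must be symmetric (a single value used for both spatial dimensions); and output_size is not supported when forwarding.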

func (*ConvTranspose2dModule) Forward

Forward applies the transposed 2D convolution to its input and returns the result.

type FunctionalModule

type FunctionalModule struct {
	Module
	Function func(torch.Tensor) torch.Tensor
}

FunctionalModule wraps a function in a `Module`.

func Functional

func Functional(f func(torch.Tensor) torch.Tensor) *FunctionalModule

Functional returns a new `FunctionalModule` instance that wraps f.

func (*FunctionalModule) Forward

func (f *FunctionalModule) Forward(input torch.Tensor) torch.Tensor

Forward feeds the `input` tensor to the underlying function.

type IModule

type IModule interface {
	// Train corresponds to torch.nn.Module.train(bool). It affects only
	// certain modules like Dropout and BatchNorm.
	Train(on bool)
	// IsTraining returns true if the module is in training mode
	IsTraining() bool
	// To corresponds to torch.nn.Module.to().  It recursively casts all
	// parameters to the given `dtype` and `device`.
	To(device torch.Device, dtype ...int8)
	// StateDict mimics torch.nn.Module.state_dict()
	StateDict() map[string]torch.Tensor
	// SetStateDict mimics torch.nn.Module.load_state_dict()
	SetStateDict(sd map[string]torch.Tensor) error
	// Apply function recursively to each module
	Apply(f func(IModule))
	// Name returns module type name
	Name() string
}

IModule is the interface implemented by all `Module`s.

type LinearModule

type LinearModule struct {
	Module
	InFeatures  int64
	OutFeatures int64
	Weight      torch.Tensor
	Bias        torch.Tensor
}

LinearModule applies a linear transformation with optional bias.

func Linear

func Linear(in, out int64, bias bool) *LinearModule

Linear creates a `LinearModule` instance.

func (*LinearModule) Forward

func (l *LinearModule) Forward(x torch.Tensor) torch.Tensor

Forward applies the linear transformation to the input tensor x and returns the result.

type Module

type Module struct {
	// contains filtered or unexported fields
}

Module contains the default implementation shared by all `Module`s; embed it in a struct to obtain that behavior.

func (*Module) Apply

func (m *Module) Apply(function func(IModule))

Apply calls `function` on each module, recursively.

func (*Module) Buffers

func (m *Module) Buffers() []torch.Tensor

Buffers returns the module's non-trainable tensors (buffers), collected recursively.

func (*Module) Init

func (m *Module) Init(outer IModule)

Init initializes a `Module`; using a `Module` that has not been `Init`ed will panic. Example:

type MyModel struct {
	Module
}
func NewMyModel() *MyModel {
	r := &MyModel{}
	r.Init(r)
	return r
}

func (*Module) IsTraining

func (m *Module) IsTraining() bool

IsTraining returns true if the module is in training mode

func (*Module) Name

func (m *Module) Name() string

Name returns the module's type name.

func (*Module) NamedBuffers

func (m *Module) NamedBuffers() map[string]torch.Tensor

NamedBuffers returns the module's non-trainable tensors (buffers), collected recursively and keyed by name.

func (*Module) NamedParameters

func (m *Module) NamedParameters() map[string]torch.Tensor

NamedParameters returns the module's trainable parameters, collected recursively and keyed by name.

func (*Module) Outer

func (m *Module) Outer() IModule

Outer returns the outer module that embeds this `Module`, i.e. the one registered via Init.

func (*Module) Parameters

func (m *Module) Parameters() []torch.Tensor

Parameters returns the module's trainable parameters, collected recursively.

func (*Module) SetStateDict

func (m *Module) SetStateDict(sd map[string]torch.Tensor) error

SetStateDict sets all of the module's tensor fields to the corresponding values in sd.

func (*Module) StateDict

func (m *Module) StateDict() map[string]torch.Tensor

StateDict mimics torch.nn.Module.state_dict(), which returns parameters and buffers with their (unique) names.

func (*Module) To

func (m *Module) To(device torch.Device, dtype ...int8)

To recursively casts all parameters to the given `dtype` and `device`.

func (*Module) Train

func (m *Module) Train(on bool)

Train turns "training" mode on or off according to on.

type SequentialModule

type SequentialModule struct {
	Module
	Modules []IModule
}

SequentialModule is a list of `Module`s that acts as a `Module` itself.

func Sequential

func Sequential(modules ...IModule) *SequentialModule

Sequential returns a new `SequentialModule` instance.

func (*SequentialModule) Forward

func (s *SequentialModule) Forward(inputs ...interface{}) interface{}

Forward feeds `inputs` to the first module and then chains outputs to inputs, returning the last output.

type Visitor

type Visitor func(f reflect.StructField, v reflect.Value, prefix string, noSuffix bool) error

Visitor is a function type intended to be called by visitTensors. The parameters f and v are a tensor-typed field in a given module. Returning a non-nil error stops the recursive visiting process.

Directories

Path Synopsis

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL