Documentation ¶
Overview ¶
Package network implements feedforward neural nets with training using stochastic gradient descent.
Index ¶
- Variables
- func DataSets() (s []string)
- func Init(imp blas.Impl)
- func Load(name string, samples int) (cfg *Config, net *Network, d *Dataset, err error)
- func Register(name string, l Loader)
- func StopCriteria(cfg *Config) func(*Stats) (done, failed bool)
- type Activation
- type Config
- type Data
- type Dataset
- type Distortion
- type Layer
- type Loader
- type Network
- func (n *Network) AddCrossEntropyOutput(nodes int)
- func (n *Network) AddLayer(dims []int, nout int, a Activation)
- func (n *Network) AddQuadraticOutput(nodes int, a Activation)
- func (n *Network) CheckGradient(nepochs int, maxError float32, samples int, scale float32)
- func (n *Network) Classify(output blas.Matrix) blas.Matrix
- func (n *Network) FeedForward(m blas.Matrix) blas.Matrix
- func (n *Network) GetError(samples int, d *Data, hist *vec.Vector, hmax float32) (totalErr, classErr float32)
- func (n *Network) Release()
- func (n *Network) SetRandomWeights()
- func (n *Network) String() string
- func (n *Network) Train(s *Stats, d *Dataset, cfg *Config)
- func (n *Network) TrainStep(epoch, batch, samples int, eta, lambda, momentum float32)
- type Sampler
- type Stats
- type StatsData
Variables ¶
var SamplerNames = []string{"uniform", "random"}
Functions ¶
func StopCriteria ¶
func StopCriteria(cfg *Config) func(*Stats) (done, failed bool)
StopCriteria function returns a function that checks whether training is complete.
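The returned function has the shape func(*Stats) (done, failed bool). The following is a minimal sketch of how such a closure could be built, not the package's actual logic: the `config` and `stats` types are simplified stand-ins, the `ValidErr` field is hypothetical, and the rule that running out of epochs or stalling counts as a failure is an assumption.

```go
package main

import "fmt"

// config and stats are simplified stand-ins for the package's Config and Stats types.
type config struct {
	MaxEpoch  int     // maximum epoch
	Threshold float32 // target cost threshold
	StopAfter int     // stop after n epochs with no improvement
}

type stats struct {
	Epoch    int
	ValidErr float32 // hypothetical field: latest validation cost
}

// stopCriteria returns a closure that reports whether training should stop.
// The rules below are assumptions made for illustration.
func stopCriteria(cfg config) func(s stats) (done, failed bool) {
	best := float32(-1) // best cost seen so far, -1 means none yet
	sinceBest := 0      // epochs since the best cost improved
	return func(s stats) (done, failed bool) {
		if best < 0 || s.ValidErr < best {
			best, sinceBest = s.ValidErr, 0
		} else {
			sinceBest++
		}
		switch {
		case s.ValidErr <= cfg.Threshold:
			return true, false // reached the target cost
		case s.Epoch >= cfg.MaxEpoch:
			return true, true // ran out of epochs
		case cfg.StopAfter > 0 && sinceBest >= cfg.StopAfter:
			return true, true // no improvement for StopAfter epochs
		}
		return false, false
	}
}

func main() {
	check := stopCriteria(config{MaxEpoch: 100, Threshold: 0.05, StopAfter: 3})
	for epoch, cost := range []float32{0.5, 0.2, 0.04} {
		done, failed := check(stats{Epoch: epoch + 1, ValidErr: cost})
		fmt.Println(epoch+1, done, failed)
	}
}
```

Keeping the best cost and the stall counter inside the closure means the caller only has to pass in the latest statistics after each epoch.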
Types ¶
type Activation ¶
type Activation struct {
	Func  blas.UnaryFunction
	Deriv blas.UnaryFunction
}
Activation type represents the activation function and its derivative.
var (
	Linear  = Activation{linear{}, nil}
	Sigmoid Activation
	Tanh    Activation
	Relu    Activation
	Softmax Activation
)
Standard activation functions.
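To illustrate how a function is paired with its derivative, here is a minimal sketch using scalar functions. The `unaryFunc` type is only a stand-in, since the definition of blas.UnaryFunction is not shown here, and the formulas are the textbook ones rather than a copy of the package's implementations.

```go
package main

import (
	"fmt"
	"math"
)

// unaryFunc is a scalar stand-in for blas.UnaryFunction (an assumption).
type unaryFunc func(x float64) float64

// activation mirrors the shape of the Activation struct.
type activation struct {
	Func  unaryFunc
	Deriv unaryFunc
}

func sigmoid(x float64) float64 { return 1 / (1 + math.Exp(-x)) }

var sigmoidAct = activation{
	Func: sigmoid,
	Deriv: func(x float64) float64 {
		s := sigmoid(x)
		return s * (1 - s)
	},
}

var tanhAct = activation{
	Func: math.Tanh,
	Deriv: func(x float64) float64 {
		t := math.Tanh(x)
		return 1 - t*t
	},
}

var reluAct = activation{
	Func: func(x float64) float64 { return math.Max(0, x) },
	Deriv: func(x float64) float64 {
		if x > 0 {
			return 1
		}
		return 0
	},
}

func main() {
	for _, x := range []float64{-2, 0, 2} {
		fmt.Printf("x=%v sigmoid=%.4f tanh=%.4f relu=%v\n",
			x, sigmoidAct.Func(x), tanhAct.Func(x), reluAct.Func(x))
	}
	fmt.Println(sigmoidAct.Deriv(0)) // prints 0.25
}
```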
type Config ¶
type Config struct {
	MaxRuns     int     // number of runs: required
	MaxEpoch    int     // maximum epoch: required
	LearnRate   float32 // learning rate eta: required
	WeightDecay float32 // weight decay epsilon
	Momentum    float32 // momentum term used in weight updates
	Threshold   float32 // target cost threshold
	BatchSize   int     // minibatch size
	StopAfter   int     // stop after n epochs with no improvement
	LogEvery    int     // log stats every n epochs
	Sampler     string  // sampler to use
	Distortion  float32 // distortion severity
}
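The LearnRate, WeightDecay and Momentum fields correspond to the eta, lambda and momentum parameters of TrainStep. As a rough sketch of how the three commonly combine in a stochastic gradient descent update, the example below applies classical momentum with L2 weight decay to a plain slice. This is the textbook rule and is assumed here for illustration; the package's own update may differ, and `sgdStep` is a hypothetical helper.

```go
package main

import "fmt"

// sgdStep applies one momentum update with L2 weight decay, in place.
// eta is the learning rate, lambda the weight decay and mu the momentum.
func sgdStep(w, grad, vel []float32, eta, lambda, mu float32) {
	for i := range w {
		// The velocity keeps a decaying sum of past steps.
		vel[i] = mu*vel[i] - eta*(grad[i]+lambda*w[i])
		w[i] += vel[i]
	}
}

func main() {
	w := []float32{1, -2}
	grad := []float32{0.5, 0.5}
	vel := make([]float32, len(w)) // velocities start at zero
	for step := 1; step <= 3; step++ {
		sgdStep(w, grad, vel, 0.1, 0.001, 0.9)
		fmt.Println(step, w)
	}
}
```

With momentum set to zero the update reduces to plain gradient descent with weight decay.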
type Data ¶
Data type represents a set of test or training data.
type Dataset ¶
type Dataset struct {
	Load          Loader
	OutputToClass blas.UnaryFunction
	Test          *Data
	Train         *Data
	Valid         *Data
	NumInputs     int
	NumOutputs    int
	MaxSamples    int
}
Dataset type represents a set of test, training and validation data.
type Distortion ¶
Distortion type is used for a bitmask of supported distortions
type Layer ¶
type Layer interface {
	Dims() []int
	Values() blas.Matrix
	FeedForward(in blas.Matrix) blas.Matrix
	BackProp(err blas.Matrix, momentum float32) blas.Matrix
	Weights() blas.Matrix
	Gradient() blas.Matrix
	Cost(t blas.Matrix) blas.Matrix
	Release()
}
Layer interface type represents one layer in the network.
type Loader ¶
type Loader interface {
	Load(samples int) (*Dataset, error)
	Config() *Config
	CreateNetwork(cfg *Config, d *Dataset) *Network
	DistortTypes() []Distortion
	Distort(in, out blas.Matrix, mask int, severity float32)
	Release()
	Debug(on bool)
}
Loader interface is used to load a new dataset.
type Network ¶
type Network struct {
	Nodes     []Layer
	Layers    int
	BatchSize int
	Verbose   bool
	// contains filtered or unexported fields
}
Neural network type is an array of layers.
func New ¶
func New(samples int, out2class blas.UnaryFunction) *Network
New function initialises a new network. samples is the maximum number of samples, i.e. the minibatch size.
func (*Network) AddCrossEntropyOutput ¶
func (n *Network) AddCrossEntropyOutput(nodes int)
AddCrossEntropyOutput method appends a cross entropy output layer with softmax activation to the network.
func (*Network) AddLayer ¶
func (n *Network) AddLayer(dims []int, nout int, a Activation)
AddLayer method adds a new input or hidden layer to the network.
func (*Network) AddQuadraticOutput ¶
func (n *Network) AddQuadraticOutput(nodes int, a Activation)
AddQuadraticOutput method appends a quadratic cost output layer to the network.
func (*Network) CheckGradient ¶
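func (n *Network) CheckGradient(nepochs int, maxError float32, samples int, scale float32)
CheckGradient method enables a gradient check every nepochs runs with a maximum error of maxError. If samples is > 0, the check is limited to this number of samples per layer. The scale parameter is a fudge factor that should not be needed.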
func (*Network) FeedForward ¶
func (n *Network) FeedForward(m blas.Matrix) blas.Matrix
FeedForward method calculates the output from the network for a given input.
func (*Network) GetError ¶
func (n *Network) GetError(samples int, d *Data, hist *vec.Vector, hmax float32) (totalErr, classErr float32)
GetError method calculates the error and classification error given a set of inputs and target outputs. samples parameter is the maximum number of samples to check.
func (*Network) Release ¶
func (n *Network) Release()
Release method frees up any resources used by the network.
func (*Network) SetRandomWeights ¶
func (n *Network) SetRandomWeights()
SetRandomWeights method initialises the weights to random values and sets the gradients to zero. Uses a normal distribution with mean zero and std dev 1/sqrt(num_inputs) for the weights. Bias weights are left at zero.
type Sampler ¶
type Sampler interface {
	Init(samples, batchSize int) Sampler
	Next() bool
	Sample(in, out blas.Matrix)
	Release()
}
Sampler interface is used to split the data set into batches.
func NewSampler ¶
NewSampler function creates a new sampler of the given type.
type Stats ¶
type Stats struct {
	Epoch      int
	Runs       int
	RunSuccess int
	StartEpoch time.Time
	EpochTime  time.Duration
	TotalTime  time.Duration
	Test       *StatsData
	Train      *StatsData
	Valid      *StatsData
	RunTime    *vec.RunningStat
	RegError   *vec.RunningStat
	ClsError   *vec.RunningStat
}
Stats struct holds a matrix of the error on each data set over time.
Directories ¶
Path | Synopsis |
---|---|
iris | Package iris loads the iris dataset from file. |
mnist | Package mnist loads the MNIST dataset of handwritten digits. |
xor | Package xor loads the dataset for an xor gate. |