network

Version: v0.0.0-...-f96ffc0
This package is not in the latest version of its module.
Published: Jul 11, 2015 License: GPL-3.0 Imports: 10 Imported by: 0

Documentation

Overview

Package network implements feedforward neural nets with training using stochastic gradient descent.

Index

Constants

This section is empty.

Variables

var SamplerNames = []string{"uniform", "random"}

Functions

func DataSets

func DataSets() (s []string)

DataSets function lists the names of all the registered data sets.

func Init

func Init(imp blas.Impl)

Init function initialises the package and sets the matrix implementation.

func Load

func Load(name string, samples int) (cfg *Config, net *Network, d *Dataset, err error)

Load function loads the named data set, creates the network and returns the default config.

func Register

func Register(name string, l Loader)

Register function is called on initialisation to make a new dataset available.
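Taken together, DataSets, Register and GetLoader (documented under the Loader type) suggest a name-keyed registry that data set packages fill from their init functions. The sketch below shows that pattern under stated assumptions: the one-method Loader interface, the map-based registry and the xorLoader type are hypothetical stand-ins, not the package's code.

```go
package main

import (
	"fmt"
	"sort"
)

// Loader is a hypothetical, cut-down stand-in for the package's Loader
// interface; only a name is needed to illustrate the registry.
type Loader interface {
	Name() string
}

// registry maps data set names to their loaders (assumed implementation).
var registry = map[string]Loader{}

// Register makes a loader available under the given name.
func Register(name string, l Loader) {
	registry[name] = l
}

// GetLoader looks up a loader by name.
func GetLoader(name string) (l Loader, ok bool) {
	l, ok = registry[name]
	return
}

// DataSets returns the registered names in sorted order.
func DataSets() (s []string) {
	for name := range registry {
		s = append(s, name)
	}
	sort.Strings(s)
	return s
}

// xorLoader is a hypothetical loader for illustration only.
type xorLoader struct{}

func (xorLoader) Name() string { return "xor" }

func init() {
	// A data set package would typically register itself from init.
	Register("xor", xorLoader{})
}

func main() {
	fmt.Println(DataSets())
	_, ok := GetLoader("mnist")
	fmt.Println(ok)
}
```

Registering from init means a program only needs a blank import of a data set package to make it available to Load.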

func StopCriteria

func StopCriteria(cfg *Config) func(*Stats) (done, failed bool)

StopCriteria function returns a function to check if training is complete.
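One plausible way to build such a closure from the Config fields MaxEpoch, Threshold and StopAfter is sketched below. The exact rules the package applies are not documented here, so the criteria are assumptions: stop successfully once the latest validation error reaches Threshold, stop as a failure after StopAfter epochs without improvement, and stop at MaxEpoch (a failure only if a threshold was set). The simplified Stats type with a ValidErr slice is hypothetical.

```go
package main

import "fmt"

// Config holds the subset of the package's Config fields used here.
type Config struct {
	MaxEpoch  int
	Threshold float32
	StopAfter int
}

// Stats is a hypothetical, simplified stand-in: the current epoch and the
// validation error recorded at each epoch so far.
type Stats struct {
	Epoch    int
	ValidErr []float32
}

// StopCriteria returns a closure reporting whether training should stop
// and, if so, whether the run failed to reach the target threshold.
func StopCriteria(cfg *Config) func(*Stats) (done, failed bool) {
	best := float32(-1) // best validation error seen so far
	bestEpoch := 0
	return func(s *Stats) (done, failed bool) {
		if len(s.ValidErr) == 0 {
			return false, false
		}
		e := s.ValidErr[len(s.ValidErr)-1]
		if best < 0 || e < best {
			best, bestEpoch = e, s.Epoch
		}
		switch {
		case cfg.Threshold > 0 && e <= cfg.Threshold:
			return true, false // target reached
		case cfg.StopAfter > 0 && s.Epoch-bestEpoch >= cfg.StopAfter:
			return true, true // no improvement for StopAfter epochs
		case s.Epoch >= cfg.MaxEpoch:
			return true, cfg.Threshold > 0 // out of epochs
		}
		return false, false
	}
}

func main() {
	stop := StopCriteria(&Config{MaxEpoch: 10, Threshold: 0.1, StopAfter: 3})
	s := &Stats{}
	for i, e := range []float32{0.5, 0.4, 0.45, 0.42, 0.41} {
		s.Epoch = i + 1
		s.ValidErr = append(s.ValidErr, e)
		if done, failed := stop(s); done {
			fmt.Printf("stopped at epoch %d, failed=%v\n", s.Epoch, failed)
			break
		}
	}
}
```

Returning a closure lets the best-so-far state live with the criteria rather than in Stats.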

Types

type Activation

type Activation struct {
	Func  blas.UnaryFunction
	Deriv blas.UnaryFunction
}

Activation type represents the activation function and its derivative.

var (
	Linear  = Activation{linear{}, nil}
	Sigmoid Activation
	Tanh    Activation
	Relu    Activation
	Softmax Activation
)

Standard activation functions.
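The element-wise activations above can be sketched with plain Go functions in place of blas.UnaryFunction. One assumption: each derivative is written in terms of the layer output y = f(x), a common choice because backpropagation already has the outputs to hand; whether the package's Deriv takes the input or the output is not stated. Softmax is omitted because it is not element-wise, and Linear is left out because the package gives it a nil Deriv.

```go
package main

import (
	"fmt"
	"math"
)

// Activation pairs a function with its derivative. In the package these are
// blas.UnaryFunction values; plain funcs are used here instead.
type Activation struct {
	Func  func(x float64) float64
	Deriv func(y float64) float64 // derivative expressed via the output y = Func(x)
}

var (
	// Sigmoid: f(x) = 1/(1+e^-x), f'(x) = y(1-y).
	Sigmoid = Activation{
		Func:  func(x float64) float64 { return 1 / (1 + math.Exp(-x)) },
		Deriv: func(y float64) float64 { return y * (1 - y) },
	}
	// Tanh: f'(x) = 1 - y^2.
	Tanh = Activation{
		Func:  math.Tanh,
		Deriv: func(y float64) float64 { return 1 - y*y },
	}
	// Relu: f(x) = max(0, x), derivative 1 where the output is positive.
	Relu = Activation{
		Func: func(x float64) float64 { return math.Max(0, x) },
		Deriv: func(y float64) float64 {
			if y > 0 {
				return 1
			}
			return 0
		},
	}
)

func main() {
	y := Sigmoid.Func(0)
	fmt.Println(y, Sigmoid.Deriv(y))
}
```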

type Config

type Config struct {
	MaxRuns     int     // number of runs: required
	MaxEpoch    int     // maximum epoch: required
	LearnRate   float32 // learning rate eta: required
	WeightDecay float32 // weight decay epsilon
	Momentum    float32 // momentum term used in weight updates
	Threshold   float32 // target cost threshold
	BatchSize   int     // minibatch size
	StopAfter   int     // stop after n epochs with no improvement
	LogEvery    int     // log stats every n epochs
	Sampler     string  // sampler to use
	Distortion  float32 // distortion severity
}
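The struct comments mark MaxRuns, MaxEpoch and LearnRate as required, and SamplerNames lists the accepted sampler names. A small check along these lines could catch a bad configuration before training starts. The validate helper is hypothetical and not part of the package.

```go
package main

import (
	"errors"
	"fmt"
)

// SamplerNames matches the package variable of the same name.
var SamplerNames = []string{"uniform", "random"}

// Config mirrors the package's Config struct.
type Config struct {
	MaxRuns     int
	MaxEpoch    int
	LearnRate   float32
	WeightDecay float32
	Momentum    float32
	Threshold   float32
	BatchSize   int
	StopAfter   int
	LogEvery    int
	Sampler     string
	Distortion  float32
}

// validate is a hypothetical helper enforcing the fields the struct
// comments mark as required, plus a known sampler name when one is set.
func validate(c *Config) error {
	switch {
	case c.MaxRuns <= 0:
		return errors.New("MaxRuns is required")
	case c.MaxEpoch <= 0:
		return errors.New("MaxEpoch is required")
	case c.LearnRate <= 0:
		return errors.New("LearnRate is required")
	}
	if c.Sampler == "" {
		return nil
	}
	for _, name := range SamplerNames {
		if c.Sampler == name {
			return nil
		}
	}
	return fmt.Errorf("unknown sampler %q", c.Sampler)
}

func main() {
	cfg := &Config{MaxRuns: 1, MaxEpoch: 50, LearnRate: 0.5, BatchSize: 10, Sampler: "random"}
	fmt.Println(validate(cfg))
}
```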

func (*Config) Print

func (c *Config) Print()

type Data

type Data struct {
	Input      blas.Matrix
	Output     blas.Matrix
	Classes    blas.Matrix
	NumSamples int
}

Data type represents a set of test or training data.

func LoadFile

func LoadFile(filename string, samples int, out2class blas.UnaryFunction) (d *Data, nin, nout int, err error)

LoadFile function reads a dataset from a text file. If samples is non-zero, it is the maximum number of records to load from each dataset.

func (*Data) Release

func (d *Data) Release()

func (*Data) String

func (d *Data) String() string

type Dataset

type Dataset struct {
	Load          Loader
	OutputToClass blas.UnaryFunction
	Test          *Data
	Train         *Data
	Valid         *Data
	NumInputs     int
	NumOutputs    int
	MaxSamples    int
}

Dataset type represents a set of test, training and validation data.

func (*Dataset) Release

func (d *Dataset) Release()

Release method frees any allocated resources.

type Distortion

type Distortion struct {
	Mask int
	Name string
}

Distortion type describes one supported distortion: its bitmask value and its name.
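Because each Distortion carries a bitmask, several distortions can be combined into the single mask argument that Loader.Distort takes. The sketch assumes each type occupies one bit; the three names in distortTypes are hypothetical examples, not what any loader in this package actually returns.

```go
package main

import (
	"fmt"
	"strings"
)

// Distortion mirrors the package type.
type Distortion struct {
	Mask int
	Name string
}

// distortTypes is a hypothetical list a Loader's DistortTypes method might return.
var distortTypes = []Distortion{
	{Mask: 1 << 0, Name: "shift"},
	{Mask: 1 << 1, Name: "rotate"},
	{Mask: 1 << 2, Name: "scale"},
}

// enabled returns the names of the distortions whose bit is set in mask.
func enabled(mask int) []string {
	var names []string
	for _, d := range distortTypes {
		if mask&d.Mask != 0 {
			names = append(names, d.Name)
		}
	}
	return names
}

func main() {
	// Combine two distortions with bitwise OR.
	mask := distortTypes[0].Mask | distortTypes[2].Mask
	fmt.Println(strings.Join(enabled(mask), ","))
}
```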

type Layer

type Layer interface {
	Dims() []int
	Values() blas.Matrix
	FeedForward(in blas.Matrix) blas.Matrix
	BackProp(err blas.Matrix, momentum float32) blas.Matrix
	Weights() blas.Matrix
	Gradient() blas.Matrix
	Cost(t blas.Matrix) blas.Matrix
	Release()
}

Layer interface type represents one layer in the network.

type Loader

type Loader interface {
	Load(samples int) (*Dataset, error)
	Config() *Config
	CreateNetwork(cfg *Config, d *Dataset) *Network
	DistortTypes() []Distortion
	Distort(in, out blas.Matrix, mask int, severity float32)
	Release()
	Debug(on bool)
}

Loader interface is used to load a new dataset.

func GetLoader

func GetLoader(name string) (l Loader, ok bool)

GetLoader function looks up a loader by name.

type Network

type Network struct {
	Nodes     []Layer
	Layers    int
	BatchSize int
	Verbose   bool
	// contains filtered or unexported fields
}

Network type represents a neural network as a slice of layers.

func New

func New(samples int, out2class blas.UnaryFunction) *Network

New function initialises a new network. samples is the maximum number of samples, i.e. the minibatch size.

func (*Network) AddCrossEntropyOutput

func (n *Network) AddCrossEntropyOutput(nodes int)

AddCrossEntropyOutput method appends a cross entropy output layer with softmax activation to the network.
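A softmax output paired with a cross entropy cost is popular because the gradient of the cost with respect to the pre-softmax inputs simplifies to y - t (prediction minus target). The sketch below shows the arithmetic on plain slices. It illustrates the standard technique and is not the package's layer code; the helper names are hypothetical.

```go
package main

import (
	"fmt"
	"math"
)

// softmax returns exp(z_i) / sum_j exp(z_j), shifting by max(z) for
// numerical stability (the shift cancels in the ratio).
func softmax(z []float64) []float64 {
	m := z[0]
	for _, v := range z[1:] {
		if v > m {
			m = v
		}
	}
	out := make([]float64, len(z))
	sum := 0.0
	for i, v := range z {
		out[i] = math.Exp(v - m)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

// crossEntropy returns -sum_i t_i log(y_i) for target t and prediction y.
func crossEntropy(y, t []float64) float64 {
	c := 0.0
	for i := range y {
		if t[i] != 0 {
			c -= t[i] * math.Log(y[i])
		}
	}
	return c
}

// outputDelta is the cost gradient with respect to the softmax inputs: y - t.
func outputDelta(y, t []float64) []float64 {
	d := make([]float64, len(y))
	for i := range y {
		d[i] = y[i] - t[i]
	}
	return d
}

func main() {
	y := softmax([]float64{1, 2, 3})
	t := []float64{0, 0, 1}
	fmt.Println(y, crossEntropy(y, t), outputDelta(y, t))
}
```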

func (*Network) AddLayer

func (n *Network) AddLayer(dims []int, nout int, a Activation)

AddLayer method adds a new input or hidden layer to the network.

func (*Network) AddQuadraticOutput

func (n *Network) AddQuadraticOutput(nodes int, a Activation)

AddQuadraticOutput method appends a quadratic cost output layer to the network.

func (*Network) CheckGradient

func (n *Network) CheckGradient(nepochs int, maxError float32, samples int, scale float32)

CheckGradient method enables a gradient check every nepochs epochs, with a maximum allowed error of maxError. If samples is > 0, the check is limited to this number of samples per layer. The scale parameter is a fudge factor that shouldn't be needed.
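A gradient check compares the backpropagated gradient against a numerical estimate. A common estimate is the central difference (f(w + h) - f(w - h)) / 2h, taken one weight at a time and compared using a relative error. The sketch below shows that technique on a toy cost. The helper names and the relative-error formula are assumptions, and how CheckGradient applies maxError in the package is not specified here.

```go
package main

import (
	"fmt"
	"math"
)

// numericGrad estimates df/dw_i by central differences, restoring each
// weight after it is perturbed.
func numericGrad(f func([]float64) float64, w []float64, h float64) []float64 {
	g := make([]float64, len(w))
	for i := range w {
		orig := w[i]
		w[i] = orig + h
		fp := f(w)
		w[i] = orig - h
		fm := f(w)
		w[i] = orig
		g[i] = (fp - fm) / (2 * h)
	}
	return g
}

// maxRelError returns the largest |a-b| / (|a|+|b|) over all components,
// guarding against division by zero.
func maxRelError(a, b []float64) float64 {
	worst := 0.0
	for i := range a {
		denom := math.Max(math.Abs(a[i])+math.Abs(b[i]), 1e-12)
		if e := math.Abs(a[i]-b[i]) / denom; e > worst {
			worst = e
		}
	}
	return worst
}

func main() {
	// Toy cost f(w) = w0^2 + 3*w0*w1 with analytic gradient (2*w0 + 3*w1, 3*w0).
	f := func(w []float64) float64 { return w[0]*w[0] + 3*w[0]*w[1] }
	w := []float64{1.5, -2}
	analytic := []float64{2*w[0] + 3*w[1], 3 * w[0]}
	numeric := numericGrad(f, w, 1e-5)
	fmt.Printf("max relative error: %.2e\n", maxRelError(analytic, numeric))
}
```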

func (*Network) Classify

func (n *Network) Classify(output blas.Matrix) blas.Matrix

Classify method returns a column vector with the classified output.
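For a multi-class output layer, classification usually means taking the index of the largest output in each row. The sketch assumes one sample per row and uses plain slices instead of blas.Matrix. Networks with a single output, such as the xor data set, presumably rely on the out2class function passed to New instead. The classify helper is hypothetical.

```go
package main

import "fmt"

// classify returns, for each row of output (one sample per row), the column
// index of the largest value, i.e. a column vector of predicted classes.
func classify(output [][]float32) []float32 {
	classes := make([]float32, len(output))
	for r, row := range output {
		best := 0
		for c := 1; c < len(row); c++ {
			if row[c] > row[best] {
				best = c
			}
		}
		classes[r] = float32(best)
	}
	return classes
}

func main() {
	out := [][]float32{
		{0.1, 0.7, 0.2},
		{0.9, 0.05, 0.05},
	}
	fmt.Println(classify(out))
}
```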

func (*Network) FeedForward

func (n *Network) FeedForward(m blas.Matrix) blas.Matrix

FeedForward method calculates the output from the network given an input.

func (*Network) GetError

func (n *Network) GetError(samples int, d *Data, hist *vec.Vector, hmax float32) (totalErr, classErr float32)

GetError method calculates the error and classification error given a set of inputs and target outputs. samples parameter is the maximum number of samples to check.

func (*Network) Release

func (n *Network) Release()

Release method frees up any resources used by the network.

func (*Network) SetRandomWeights

func (n *Network) SetRandomWeights()

SetRandomWeights method initialises the weights to random values and sets the gradients to zero. Uses a normal distribution with mean zero and std dev 1/sqrt(num_inputs) for the weights. Bias weights are left at zero.

func (*Network) String

func (n *Network) String() string

String method returns a printable representation of the network.

func (*Network) Train

func (n *Network) Train(s *Stats, d *Dataset, cfg *Config)

Train method trains the network on the given training set for one epoch.

func (*Network) TrainStep

func (n *Network) TrainStep(epoch, batch, samples int, eta, lambda, momentum float32)

TrainStep method performs one training step. eta is the learning rate and lambda is the weight decay.
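A standard form of the update that eta, lambda and momentum parameterise is v = momentum*v - eta*(g + lambda*w), followed by w = w + v, where g is the minibatch gradient. The sketch below applies it to a flat weight slice. It is an assumption that TrainStep uses this exact form: some formulations scale lambda by the training set size or fold eta in differently. The sgdStep helper is hypothetical.

```go
package main

import "fmt"

// sgdStep applies one minibatch update with L2 weight decay and momentum.
// w are the weights, v the velocities (both updated in place) and g the
// gradient averaged over the minibatch:
//
//	v = momentum*v - eta*(g + lambda*w)
//	w = w + v
func sgdStep(w, v, g []float32, eta, lambda, momentum float32) {
	for i := range w {
		v[i] = momentum*v[i] - eta*(g[i]+lambda*w[i])
		w[i] += v[i]
	}
}

func main() {
	w, v := []float32{1}, []float32{0}
	for step := 1; step <= 3; step++ {
		sgdStep(w, v, []float32{0.5}, 0.5, 0, 0.5)
		fmt.Printf("step %d: w=%v v=%v\n", step, w[0], v[0])
	}
}
```

With a constant gradient, the velocity grows toward -eta*g/(1-momentum), which is why momentum speeds progress along directions where the gradient is consistent.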

type Sampler

type Sampler interface {
	Init(samples, batchSize int) Sampler
	Next() bool
	Sample(in, out blas.Matrix)
	Release()
}

Sampler interface is used to split the data set into batches.

func NewSampler

func NewSampler(typ string) Sampler

NewSampler function creates a new sampler of the given type.
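The Sampler interface splits the data into minibatches. The sketch below interprets the two SamplerNames as follows, which is an assumption: "uniform" walks the samples in order and "random" visits them in a shuffled order, each sample once per epoch. To stay self-contained, this simplified interface returns the batch's row indices from Sample instead of filling in and out matrices. The sampler type and its fixed seed are hypothetical.

```go
package main

import (
	"fmt"
	"math/rand"
)

// Sampler is a simplified stand-in for the package interface: Sample
// returns row indices instead of copying rows into matrices.
type Sampler interface {
	Init(samples, batchSize int) Sampler
	Next() bool
	Sample() []int
}

type sampler struct {
	shuffle bool
	rng     *rand.Rand
	idx     []int // visiting order for this epoch
	batch   int
	pos     int   // start of the next batch
	cur     []int // indices in the current batch
}

func (s *sampler) Init(samples, batchSize int) Sampler {
	s.idx = make([]int, samples)
	for i := range s.idx {
		s.idx[i] = i
	}
	if s.shuffle {
		s.rng.Shuffle(samples, func(i, j int) { s.idx[i], s.idx[j] = s.idx[j], s.idx[i] })
	}
	s.batch, s.pos = batchSize, 0
	return s
}

// Next advances to the next batch, returning false once the epoch is exhausted.
func (s *sampler) Next() bool {
	if s.pos >= len(s.idx) {
		return false
	}
	end := s.pos + s.batch
	if end > len(s.idx) {
		end = len(s.idx) // last batch may be short
	}
	s.cur = s.idx[s.pos:end]
	s.pos = end
	return true
}

func (s *sampler) Sample() []int { return s.cur }

// NewSampler returns nil for an unknown type name.
func NewSampler(typ string) Sampler {
	switch typ {
	case "uniform":
		return &sampler{}
	case "random":
		return &sampler{shuffle: true, rng: rand.New(rand.NewSource(1))}
	}
	return nil
}

func main() {
	s := NewSampler("uniform").Init(5, 2)
	for s.Next() {
		fmt.Println(s.Sample())
	}
}
```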

type Stats

type Stats struct {
	Epoch      int
	Runs       int
	RunSuccess int
	StartEpoch time.Time
	EpochTime  time.Duration
	TotalTime  time.Duration
	Test       *StatsData
	Train      *StatsData
	Valid      *StatsData
	RunTime    *vec.RunningStat
	RegError   *vec.RunningStat
	ClsError   *vec.RunningStat
}

Stats struct holds the error on each data set over time, plus running statistics across runs.
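RunTime, RegError and ClsError are vec.RunningStat values that accumulate one figure per run. The vec package's API is not shown here, so the sketch below is a hypothetical RunningStat built on Welford's online algorithm, a common way to track a mean and standard deviation without storing every value.

```go
package main

import (
	"fmt"
	"math"
)

// RunningStat is a hypothetical accumulator for the mean and standard
// deviation of a stream of values (Welford's online algorithm).
type RunningStat struct {
	n    int
	mean float64
	m2   float64 // sum of squared deviations from the running mean
}

// Push adds one value, updating the mean and m2 incrementally.
func (r *RunningStat) Push(x float64) {
	r.n++
	d := x - r.mean
	r.mean += d / float64(r.n)
	r.m2 += d * (x - r.mean)
}

func (r *RunningStat) Mean() float64 { return r.mean }

// StdDev returns the sample standard deviation (n-1 denominator).
func (r *RunningStat) StdDev() float64 {
	if r.n < 2 {
		return 0
	}
	return math.Sqrt(r.m2 / float64(r.n-1))
}

func main() {
	var runTime RunningStat
	for _, secs := range []float64{2, 4, 4, 4, 5, 5, 7, 9} {
		runTime.Push(secs)
	}
	fmt.Printf("mean %.2f std %.2f\n", runTime.Mean(), runTime.StdDev())
}
```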

func NewStats

func NewStats() *Stats

NewStats function returns a new stats struct.

func (*Stats) EndRun

func (s *Stats) EndRun(failed bool) string

EndRun method updates the per-run statistics and returns them as a string.

func (*Stats) History

func (s *Stats) History() string

History method returns historical statistics.

func (*Stats) Reset

func (s *Stats) Reset()

Reset method resets all the stats.

func (*Stats) StartRun

func (s *Stats) StartRun()

StartRun method resets the stats vectors for this run and starts the timer.

func (*Stats) String

func (s *Stats) String() string

String method formats the stats for logging.

func (*Stats) Update

func (s *Stats) Update(n *Network, d *Dataset)

Update method calculates the error and updates the stats.

type StatsData

type StatsData struct {
	Error      *vec.Vector
	ClassError *vec.Vector
	ErrorHist  *vec.Vector
	HistMax    float32
}

StatsData stores vectors with the errors and classification errors.

func (*StatsData) String

func (d *StatsData) String() string

Directories

Path Synopsis
Package iris loads the iris dataset from file.
Package mnist loads the MNIST dataset of handwritten digits.
Package xor loads the dataset for an XOR gate.
