package cnn
Version: v0.0.0-...-da9aff4
Warning

This package is not in the latest version of its module.

Published: Oct 7, 2020 License: MIT Imports: 5 Imported by: 0

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type Network

type Network struct {
	// contains filtered or unexported fields
}

func New

func New(inputDims []int, learningRate float64, loss metrics.LossFunction) *Network
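New returns a *Network, and every Add* method also returns *Network, so a model can be declared as one chained expression. The package's layer types are unexported and its import path is not shown here, so the following is only a self-contained sketch of that builder pattern. The `network` type, its `layers` field, the lower-case method names and the input shape `[28, 28, 1]` are all hypothetical, and the `loss` argument (a `metrics.LossFunction`) is left out.

```go
package main

import "fmt"

// network is a hypothetical stand-in for cnn.Network: it only records the
// layers that were requested so that the chaining pattern can be shown.
type network struct {
	inputDims    []int
	learningRate float64
	layers       []string
}

// newNetwork mirrors the shape of cnn.New, minus the loss argument.
func newNetwork(inputDims []int, learningRate float64) *network {
	return &network{inputDims: inputDims, learningRate: learningRate}
}

// Each builder method appends a layer description and returns the receiver,
// which is what makes the calls chainable.
func (n *network) addConvolutionLayer(filterDims []int, filterCount int) *network {
	n.layers = append(n.layers, fmt.Sprintf("conv %v x%d", filterDims, filterCount))
	return n
}

func (n *network) addReLULayer() *network {
	n.layers = append(n.layers, "relu")
	return n
}

func (n *network) addMaxPoolingLayer(stride int, dims []int) *network {
	n.layers = append(n.layers, fmt.Sprintf("maxpool %v stride %d", dims, stride))
	return n
}

func (n *network) addFullyConnectedLayer(outputLength int) *network {
	n.layers = append(n.layers, fmt.Sprintf("fc %d", outputLength))
	return n
}

func (n *network) addSoftmaxLayer() *network {
	n.layers = append(n.layers, "softmax")
	return n
}

func main() {
	// Hypothetical architecture; the dimension order is illustrative only.
	net := newNetwork([]int{28, 28, 1}, 0.01).
		addConvolutionLayer([]int{3, 3}, 8).
		addReLULayer().
		addMaxPoolingLayer(2, []int{2, 2}).
		addFullyConnectedLayer(10).
		addSoftmaxLayer()
	for _, l := range net.layers {
		fmt.Println(l)
	}
}
```

Returning the receiver keeps model definitions compact; the trade-off is that a builder like this cannot report a configuration error from an individual Add* call.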

func (*Network) AddConvolutionLayer

func (n *Network) AddConvolutionLayer(filterDimensions []int, filterCount int) *Network
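filterDimensions presumably gives the spatial size of each filter and filterCount the number of filters, each producing one output feature map; stride and padding are not documented. As an illustration only, the sketch below computes one output map for one filter over a single-channel input, assuming stride 1 and no padding. Like most CNN code it actually computes a cross-correlation (the filter is not flipped). It uses `[][]float64` instead of `maths.Tensor`, whose API is not shown here, and is not the package's implementation.

```go
package main

import "fmt"

// convolve2D computes a "valid" 2D cross-correlation (what CNN layers call
// convolution) of a single-channel input with one filter, assuming stride 1
// and no padding. The output has size (H-fh+1) x (W-fw+1).
func convolve2D(input, filter [][]float64) [][]float64 {
	h, w := len(input), len(input[0])
	fh, fw := len(filter), len(filter[0])
	outH, outW := h-fh+1, w-fw+1
	out := make([][]float64, outH)
	for i := 0; i < outH; i++ {
		out[i] = make([]float64, outW)
		for j := 0; j < outW; j++ {
			// Dot product of the filter with the input patch at (i, j).
			var sum float64
			for m := 0; m < fh; m++ {
				for n := 0; n < fw; n++ {
					sum += input[i+m][j+n] * filter[m][n]
				}
			}
			out[i][j] = sum
		}
	}
	return out
}

func main() {
	input := [][]float64{
		{1, 2, 3},
		{4, 5, 6},
		{7, 8, 9},
	}
	filter := [][]float64{
		{1, 0},
		{0, -1},
	}
	fmt.Println(convolve2D(input, filter)) // [[-4 -4] [-4 -4]]
}
```

With filterCount filters, a layer like this would repeat the computation once per filter and stack the resulting maps.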

func (*Network) AddFullyConnectedLayer

func (n *Network) AddFullyConnectedLayer(outputLength int) *Network
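A fully connected layer maps every element of its input to each of outputLength outputs. A common formulation is y = Wx + b, where W has outputLength rows. The standalone sketch below shows that computation on a flat input. The weight, bias and input values are hypothetical, and whether the package flattens multi-dimensional inputs in this way is an assumption.

```go
package main

import "fmt"

// fullyConnected computes y = W·x + b, where W has one row per output
// (outputLength rows) and one column per input element.
func fullyConnected(weights [][]float64, bias, x []float64) []float64 {
	y := make([]float64, len(weights))
	for i, row := range weights {
		sum := bias[i]
		for j, w := range row {
			sum += w * x[j]
		}
		y[i] = sum
	}
	return y
}

func main() {
	w := [][]float64{
		{1, 2, 3},
		{0, 1, 0},
	}
	b := []float64{0.5, -1}
	x := []float64{1, 1, 1}
	fmt.Println(fullyConnected(w, b, x)) // [6.5 0]
}
```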

func (*Network) AddMaxPoolingLayer

func (n *Network) AddMaxPoolingLayer(stride int, dimensions []int) *Network
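Max pooling slides a window over each feature map and keeps the largest value in each window, moving stride cells at a time. The sketch below assumes that dimensions holds the window's [height, width] and that windows running past the edge are dropped (no padding). Neither convention is documented by the package. It works on a single-channel `[][]float64` rather than `maths.Tensor`.

```go
package main

import "fmt"

// maxPool2D slides a dims[0] x dims[1] window over a single-channel input,
// moving `stride` cells at a time, and keeps the maximum of each window.
// Windows that would run past the edge are skipped (no padding).
func maxPool2D(input [][]float64, stride int, dims []int) [][]float64 {
	ph, pw := dims[0], dims[1] // assumed order: [height, width]
	outH := (len(input)-ph)/stride + 1
	outW := (len(input[0])-pw)/stride + 1
	out := make([][]float64, outH)
	for i := 0; i < outH; i++ {
		out[i] = make([]float64, outW)
		for j := 0; j < outW; j++ {
			best := input[i*stride][j*stride]
			for m := 0; m < ph; m++ {
				for n := 0; n < pw; n++ {
					if v := input[i*stride+m][j*stride+n]; v > best {
						best = v
					}
				}
			}
			out[i][j] = best
		}
	}
	return out
}

func main() {
	input := [][]float64{
		{1, 3, 2, 4},
		{5, 6, 1, 0},
		{7, 2, 9, 8},
		{1, 4, 3, 5},
	}
	fmt.Println(maxPool2D(input, 2, []int{2, 2})) // [[6 4] [7 9]]
}
```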

func (*Network) AddReLULayer

func (n *Network) AddReLULayer() *Network
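ReLU (rectified linear unit) is the element-wise function max(0, x). The sketch below shows its forward pass on a flat `[]float64`, used here in place of `maths.Tensor`. It is an illustration of the activation, not the package's code.

```go
package main

import "fmt"

// relu applies max(0, x) element-wise and returns a new slice.
func relu(xs []float64) []float64 {
	out := make([]float64, len(xs)) // zero-initialised, so negatives stay 0
	for i, x := range xs {
		if x > 0 {
			out[i] = x
		}
	}
	return out
}

func main() {
	fmt.Println(relu([]float64{-2, -0.5, 0, 1.5, 3})) // [0 0 0 1.5 3]
}
```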

func (*Network) AddSoftmaxLayer

func (n *Network) AddSoftmaxLayer() *Network
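A softmax layer turns the final scores into probabilities that are non-negative and sum to 1, which is what Predict reports. The standalone sketch below uses the standard numerically stable form: it subtracts the maximum score before exponentiating, which avoids overflow without changing the result. Whether the package uses this exact formulation is an assumption.

```go
package main

import (
	"fmt"
	"math"
)

// softmax converts scores into probabilities that sum to 1. Subtracting the
// maximum score first keeps math.Exp from overflowing for large inputs.
func softmax(scores []float64) []float64 {
	maxScore := math.Inf(-1)
	for _, s := range scores {
		if s > maxScore {
			maxScore = s
		}
	}
	out := make([]float64, len(scores))
	var sum float64
	for i, s := range scores {
		out[i] = math.Exp(s - maxScore)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

func main() {
	p := softmax([]float64{1, 2, 3})
	fmt.Printf("%.4f\n", p) // [0.0900 0.2447 0.6652]
}
```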

func (*Network) Fit

func (n *Network) Fit(inputs, labels, valInputs, valLabels []maths.Tensor, epochs int, batchSize int, verbose bool, logRate int, onEpochDone func())

Fit trains the CNN on inputs and their corresponding labels for the given number of epochs (full passes over the training data). If valInputs and valLabels are not nil, a validation step is run on that data after each epoch. batchSize is the number of samples in each propagation batch. If verbose is true, logging is enabled and written to stdout with fmt; a message is written every logRate iterations. onEpochDone is a callback that is invoked during training; it can be used, for example, to reduce the learning rate.
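Within each epoch, Fit processes the training set in batches of batchSize samples. The package does not say whether the data is shuffled or how a final partial batch is handled. The sketch below assumes consecutive batches with a smaller final batch. `batchBounds` is a hypothetical helper, not part of the package.

```go
package main

import "fmt"

// batchBounds splits n samples into consecutive [start, end) ranges of at
// most batchSize elements; the final batch holds any remainder.
// A non-positive batchSize is treated as one batch covering all samples.
func batchBounds(n, batchSize int) [][2]int {
	if batchSize <= 0 {
		batchSize = n
	}
	var bounds [][2]int
	for start := 0; start < n; start += batchSize {
		end := start + batchSize
		if end > n {
			end = n
		}
		bounds = append(bounds, [2]int{start, end})
	}
	return bounds
}

func main() {
	const epochs = 2
	for epoch := 1; epoch <= epochs; epoch++ {
		for _, b := range batchBounds(10, 4) {
			// A real training loop would run the forward and backward passes
			// for samples b[0] through b[1]-1 here and apply the update.
			fmt.Printf("epoch %d: batch %v\n", epoch, b)
		}
	}
}
```

With 10 samples and batchSize 4, each epoch yields the batches [0 4], [4 8] and [8 10].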

func (*Network) LearningRate

func (n *Network) LearningRate() float64

func (*Network) Predict

func (n *Network) Predict(input maths.Tensor) []float64

Predict runs input through the network and returns its output as a slice of probabilities.

func (*Network) PredictIndex

func (n *Network) PredictIndex(input maths.Tensor) int

PredictIndex returns the index of the highest probability in the prediction, that is, the most likely class.
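PredictIndex amounts to an argmax over the output of Predict. The standalone sketch below shows that reduction. Returning the first index on ties and -1 for an empty slice are choices made for this sketch; the package does not document either case.

```go
package main

import "fmt"

// argmax returns the index of the largest value in probs (the first one on
// ties), or -1 for an empty slice.
func argmax(probs []float64) int {
	best := -1
	for i, p := range probs {
		if best == -1 || p > probs[best] {
			best = i
		}
	}
	return best
}

func main() {
	fmt.Println(argmax([]float64{0.1, 0.7, 0.2})) // 1
}
```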

func (*Network) SetLearningRate

func (n *Network) SetLearningRate(rate float64)
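LearningRate and SetLearningRate together make it possible to adjust the rate from Fit's onEpochDone callback. Because that callback has type `func()` and receives no arguments, a schedule has to keep its own counter. The sketch below implements step decay: it multiplies the rate by a factor every few calls. The `rateSetter` interface and `fakeNet` type are hypothetical. `*Network` should satisfy `rateSetter` given the signatures above, but the example uses `fakeNet` so that it runs on its own.

```go
package main

import "fmt"

// rateSetter is a hypothetical interface covering the two learning-rate
// methods documented for *cnn.Network.
type rateSetter interface {
	LearningRate() float64
	SetLearningRate(rate float64)
}

// fakeNet stands in for *cnn.Network so the example is self-contained.
type fakeNet struct{ rate float64 }

func (f *fakeNet) LearningRate() float64        { return f.rate }
func (f *fakeNet) SetLearningRate(rate float64) { f.rate = rate }

// stepDecay returns a func() that multiplies the learning rate by factor on
// every `every`-th call, matching the type of Fit's onEpochDone parameter.
// every must be positive.
func stepDecay(n rateSetter, factor float64, every int) func() {
	calls := 0
	return func() {
		calls++
		if calls%every == 0 {
			n.SetLearningRate(n.LearningRate() * factor)
		}
	}
}

func main() {
	net := &fakeNet{rate: 0.1}
	onEpochDone := stepDecay(net, 0.5, 2)
	for epoch := 1; epoch <= 4; epoch++ {
		onEpochDone()
		fmt.Printf("after epoch %d: %.3f\n", epoch, net.LearningRate())
	}
}
```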

func (*Network) Validate

func (n *Network) Validate(inputs []maths.Tensor, labels []maths.Tensor)
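Validate returns no value, and the package does not document which metric it computes. For a classifier ending in a softmax layer, classification accuracy is one common validation metric. The standalone sketch below computes it under the assumption that labels are one-hot encoded. `accuracy` and its helper `argmax` are hypothetical, and `[][]float64` stands in for `[]maths.Tensor`.

```go
package main

import "fmt"

// argmax returns the index of the largest value (first one on ties).
func argmax(xs []float64) int {
	best := 0
	for i, x := range xs {
		if x > xs[best] {
			best = i
		}
	}
	return best
}

// accuracy returns the fraction of samples whose predicted class matches the
// position of the 1 in the corresponding one-hot label.
func accuracy(predictions, labels [][]float64) float64 {
	if len(predictions) == 0 {
		return 0
	}
	correct := 0
	for i := range predictions {
		if argmax(predictions[i]) == argmax(labels[i]) {
			correct++
		}
	}
	return float64(correct) / float64(len(predictions))
}

func main() {
	preds := [][]float64{{0.8, 0.2}, {0.3, 0.7}, {0.6, 0.4}}
	labels := [][]float64{{1, 0}, {0, 1}, {0, 1}}
	fmt.Println(accuracy(preds, labels)) // 0.6666666666666666
}
```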

