gonet

package module
v0.0.1
Warning

This package is not in the latest version of its module.

Published: Apr 23, 2024 License: Apache-2.0 Imports: 3 Imported by: 0

README

Neural Networks


Examples

Feedforward

Examples

How to Build & Train
// Start by defining the shapes of your data and constructing the network.

nn := feedforward.New(
    // The first shape (3) is the number of inputs: the three variables of an
    // XOR-like function that outputs 1 unless all three inputs are equal.
    // The second shape (4) is the number of nodes in the single hidden layer.
    // The third shape (1) is the number of outputs: one value close to 0 or 1.
    feedforward.Shapes([]int{3, 4, 1}),

    // This choice is based on prior knowledge of the training data.
    // The sigmoid function squashes its output into the range (0, 1),
    // which makes it suitable for binary classification.
    feedforward.Activation(fns.Sigmoid),

    // Derivative of the activation function, used by backpropagation during training.
    feedforward.ActivationDerivative(fns.SigmoidDerivative),

    // Learning rate: too high can make training unstable, too low makes it slow.
    feedforward.LearningRate(0.1),
)

// Print shapes

fmt.Println(nn.String())
// Shapes: [3 4 1]
// Hidden Layers: 1
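The printed "Hidden Layers: 1" follows from the shapes: the first entry is the input size, the last is the output size, and everything between them is a hidden layer. The sketch below also counts trainable parameters. It assumes fully connected layers with one bias per non-input node, which the README does not confirm for gonet.

```go
package main

import "fmt"

// paramCount counts weights and biases for a fully connected network described by shapes.
// Assumption: every node in one layer connects to every node in the next layer,
// and each non-input node has its own bias.
func paramCount(shapes []int) (weights, biases int) {
	for i := 1; i < len(shapes); i++ {
		weights += shapes[i-1] * shapes[i]
		biases += shapes[i]
	}
	return weights, biases
}

// hiddenLayers returns the number of layers between the input and output layers.
func hiddenLayers(shapes []int) int {
	if len(shapes) < 2 {
		return 0
	}
	return len(shapes) - 2
}

func main() {
	shapes := []int{3, 4, 1}
	w, b := paramCount(shapes)
	fmt.Printf("hidden layers: %d, weights: %d, biases: %d\n", hiddenLayers(shapes), w, b)
}
```

For [3 4 1] this gives 3×4 + 4×1 = 16 weights and 4 + 1 = 5 biases.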

// Define training data

inputs := [][]float64{
    {0, 0, 0},
    {0, 0, 1},
    {0, 1, 0},
    {0, 1, 1},
    {1, 0, 0},
    {1, 0, 1},
    {1, 1, 0},
    {1, 1, 1},
}

targets := [][]float64{
    {0},
    {1},
    {1},
    {1},
    {1},
    {1},
    {1},
    {0},
}
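The targets above are 1 whenever the three inputs are not all equal. A strict three-input XOR (parity) would differ: it labels {0, 1, 1} as 0 and {1, 1, 1} as 1. The hypothetical helper below regenerates the same table for any number of inputs. This makes the rule explicit and avoids typing rows by hand.

```go
package main

import "fmt"

// truthTable enumerates every combination of n binary inputs in counting order
// (000, 001, 010, ...) and labels each row 1 unless all of its inputs are equal.
func truthTable(n int) (inputs, targets [][]float64) {
	for i := 0; i < 1<<n; i++ {
		row := make([]float64, n)
		ones := 0
		for bit := 0; bit < n; bit++ {
			// The most significant bit goes first, matching the README's row order.
			if i&(1<<(n-1-bit)) != 0 {
				row[bit] = 1
				ones++
			}
		}
		label := 1.0
		if ones == 0 || ones == n { // all zeros or all ones
			label = 0
		}
		inputs = append(inputs, row)
		targets = append(targets, []float64{label})
	}
	return inputs, targets
}

func main() {
	inputs, targets := truthTable(3)
	for i := range inputs {
		fmt.Println(inputs[i], "->", targets[i])
	}
}
```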

// Train/fit the function

help.Train(context.TODO(), nn, 100000, inputs, targets)
// Epoch 0000, Loss: 0.214975
// Epoch:(0) Inputs:(8) Duration:(72.166µs)
// Epoch 10000, Loss: 0.004950
// Epoch:(10000) Inputs:(8) Duration:(2.542µs)
// Epoch 20000, Loss: 0.001701
// Epoch:(20000) Inputs:(8) Duration:(2µs)
// Epoch 30000, Loss: 0.001014
// Epoch:(30000) Inputs:(8) Duration:(1.834µs)
// ...
// ...

// Let's predict

fmt.Println(nn.Predict([]float64{0, 0, 0})) // [0.02768982884099321]
fmt.Println(nn.Predict([]float64{1, 1, 0})) // [0.9965961389721709]
fmt.Println(nn.Predict([]float64{1, 1, 1})) // [0.012035375150277857]
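Predict returns a []float64 with one value per output node, so each result above is a one-element slice. To turn a single sigmoid output into a hard 0/1 class, a common choice is a 0.5 cut-off. The helper below is hypothetical and not part of gonet.

```go
package main

import "fmt"

// toLabel converts a sigmoid output into a 0/1 class using a 0.5 cut-off (an assumed threshold).
func toLabel(p float64) int {
	if p >= 0.5 {
		return 1
	}
	return 0
}

func main() {
	// Output values copied from the README's Predict examples for {0,0,0}, {1,1,0}, {1,1,1}.
	for _, p := range []float64{0.02768982884099321, 0.9965961389721709, 0.012035375150277857} {
		fmt.Println(p, "->", toLabel(p))
	}
}
```

With the README's outputs this yields 0, 1, 0, which matches the targets for those three rows.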
How to Save & Resume
// Saving a network to disk lets you load it later to resume training or to make predictions.

// To save the network, call its Save method with any io.Writer w (for example, an *os.File).
if err := nn.Save(w); err != nil {
    // do something with error
}

// Alternatively, call help.Save with a file name, which may include a directory path.
if err := help.Save("bin/my-model", nn); err != nil {
    // do something with error
}
// A saved network can be loaded from disk to resume training or to make predictions.

// To load a previously saved network, call feedforward.Load with any io.Reader r (for example, an *os.File).
nn, err := feedforward.Load(
    r, // io.Reader
    feedforward.Activation(fns.Sigmoid), // set for your network
    // REQUIRED if resuming network training
    feedforward.ActivationDerivative(fns.SigmoidDerivative), // set for your network
    feedforward.LearningRate(0.01), // set for your network
)
if err != nil {
    // do something with error
}

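// Alternatively, call help.LoadFeedforward with a file name, which may include a directory path.
// nn and err are already declared by the feedforward.Load call above, so this assigns with = rather than :=.
nn, err = help.LoadFeedforward(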
    "bin/my-model",
    feedforward.Activation(fns.Sigmoid), // set for your network
    // REQUIRED if resuming network training
    feedforward.ActivationDerivative(fns.SigmoidDerivative), // set for your network
    feedforward.LearningRate(0.01), // set for your network
)
if err != nil {
    // do something with error
}

Wish List

  • Define activation function for each network layer
  • Persist activation function with saved network
  • Draw network

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type Network

type Network interface {
	Train(epochs int, inputs, outputs [][]float64, callback func(int) bool)
	Predict(inputs []float64) []float64
	TrainingDuration() time.Duration
	EpochStats(epoch int) stats.Epoch
	Save(w io.Writer) error
	String() string
}

type NetworkOpt

type NetworkOpt func(*networkOpts)

func Activation

func Activation(v func(float64) float64) NetworkOpt

func ActivationDerivative

func ActivationDerivative(v func(float64) float64) NetworkOpt

func HiddenSize

func HiddenSize(v int) NetworkOpt

func InputSize

func InputSize(v int) NetworkOpt

func LearningRate

func LearningRate(v float64) NetworkOpt

func LoadFrom

func LoadFrom(v []byte) NetworkOpt

func OutputSize

func OutputSize(v int) NetworkOpt
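NetworkOpt follows Go's functional-options pattern: each option is a function that mutates an unexported options struct, and a constructor applies the options in order. The sketch below re-creates that pattern in isolation. The fields of networkOpts, the default learning rate, and newOpts are hypothetical, because the documentation does not show gonet's internals.

```go
package main

import "fmt"

// networkOpts (hypothetical fields) holds the settings the options configure.
type networkOpts struct {
	inputSize, hiddenSize, outputSize int
	learningRate                      float64
}

// NetworkOpt has the same shape as the documented type: a function that mutates options.
type NetworkOpt func(*networkOpts)

func InputSize(v int) NetworkOpt        { return func(o *networkOpts) { o.inputSize = v } }
func HiddenSize(v int) NetworkOpt       { return func(o *networkOpts) { o.hiddenSize = v } }
func OutputSize(v int) NetworkOpt       { return func(o *networkOpts) { o.outputSize = v } }
func LearningRate(v float64) NetworkOpt { return func(o *networkOpts) { o.learningRate = v } }

// newOpts starts from defaults and applies each option in order, so later options win.
func newOpts(opts ...NetworkOpt) networkOpts {
	o := networkOpts{learningRate: 0.1} // hypothetical default
	for _, opt := range opts {
		opt(&o)
	}
	return o
}

func main() {
	o := newOpts(InputSize(3), HiddenSize(4), OutputSize(1))
	fmt.Printf("%+v\n", o)
}
```

The pattern keeps the constructor signature stable as new settings are added, and callers only name the options they want to change.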
