nes

package
v0.1.1
Warning

This package is not in the latest version of its module.

Published: May 13, 2023 License: Apache-2.0 Imports: 13 Imported by: 0

README

Natural Evolution Strategies

An implementation of Natural Evolution Strategies for black box optimization.

How it works

Natural Evolution Strategies optimizes a black-box function parameterized by a set of weights. For each member of the population, it perturbs a baseline set of weights with noise, weights that noise by the reward the perturbed weights achieve, and sums the result across the entire population. At each generation, the baseline weights are then updated along the resulting estimate of the natural gradient for the whole population.
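For reference, a common form of this update (as in Salimans et al., 2017) is written below in terms of the evolver's hyperparameters: population size $n$ (`NPop`), noise standard deviation $\sigma$ (`Sigma`) and learning rate $\alpha$ (`Alpha`), with $F$ the black-box reward and $\theta_t$ the baseline weights at generation $t$. Whether this package rescales or normalizes the rewards before applying the step is not stated here, so treat this as a sketch of the technique rather than the package's exact rule.

```latex
\epsilon_i \sim \mathcal{N}(0, I), \qquad
F_i = F(\theta_t + \sigma \epsilon_i), \qquad i = 1, \dots, n

\theta_{t+1} = \theta_t + \frac{\alpha}{n \sigma} \sum_{i=1}^{n} F_i \, \epsilon_i
```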


Examples

See the experiments folder for example implementations.

Roadmap

  • More environments
  • K8s
  • Support multiple weights.


Documentation

Overview

Package nes is an agent implementation of the Natural Evolution Strategies algorithm.


Constants

This section is empty.

Variables

var DefaultAgentConfig = &AgentConfig{
	PolicyConfig: DefaultPolicyConfig,
}

DefaultAgentConfig is the default config for an NES agent.

var DefaultEvolverHyperparameters = &EvolverHyperparameters{
	NPop:  50,
	NGen:  300,
	Sigma: 0.1,
	Alpha: 0.001,
}

DefaultEvolverHyperparameters are the default hyperparameters for the evolver.
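Because DefaultEvolverHyperparameters is a pointer, assigning to its fields changes the default for every caller. A safer pattern is to dereference it into a copy and override only what you need. In the sketch below the struct and default are local mirrors of the documented ones so it compiles on its own, and `withPopulation` is a hypothetical helper, not part of the package.

```go
package main

import "fmt"

// EvolverHyperparameters mirrors the documented struct so this sketch is standalone.
type EvolverHyperparameters struct {
	NPop  int
	NGen  int
	Sigma float32
	Alpha float32
}

// DefaultEvolverHyperparameters mirrors the package-level default (note: a pointer).
var DefaultEvolverHyperparameters = &EvolverHyperparameters{
	NPop:  50,
	NGen:  300,
	Sigma: 0.1,
	Alpha: 0.001,
}

// withPopulation (hypothetical) copies the defaults and overrides the population size.
// Dereferencing copies the struct, so the shared default is left untouched.
func withPopulation(nPop int) *EvolverHyperparameters {
	hp := *DefaultEvolverHyperparameters
	hp.NPop = nPop
	return &hp
}

func main() {
	hp := withPopulation(100)
	fmt.Println(hp.NPop, DefaultEvolverHyperparameters.NPop) // 100 50
}
```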

var DefaultFCLayerBuilder = func(x, y *modelv1.Input) []layer.Config {
	return []layer.Config{
		layer.FC{Input: x.Squeeze()[0], Output: y.Squeeze()[0], Activation: layer.Linear, NoBias: true},
	}
}

DefaultFCLayerBuilder is a default fully connected layer builder.
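DefaultFCLayerBuilder yields a single fully connected layer with a linear activation and no bias, so the policy output is just the observation multiplied by a weight matrix. The sketch below shows that computation with plain slices; the `[in][out]` weight layout, the `linearNoBias` name and the CartPole-sized dimensions (4 observation values, 2 actions) are illustrative assumptions, not the package's tensor layout.

```go
package main

import "fmt"

// linearNoBias computes y = x·W for one fully connected layer with a linear
// activation and no bias term. W is stored as W[in][out].
func linearNoBias(x []float32, w [][]float32) []float32 {
	out := make([]float32, len(w[0]))
	for i, xi := range x {
		for j, wij := range w[i] {
			out[j] += xi * wij
		}
	}
	return out
}

func main() {
	// Hypothetical CartPole-sized example: 4 observation values, 2 actions.
	x := []float32{1, 0, -1, 2}
	w := [][]float32{
		{0.5, -0.5},
		{1, 1},
		{0, 1},
		{0.25, 0},
	}
	fmt.Println(linearNoBias(x, w)) // [1 -1.5]
}
```

With no bias, the weights being evolved under this layout are exactly `in × out` values (8 here).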

var DefaultPolicyConfig = &PolicyConfig{
	LayerBuilder: DefaultFCLayerBuilder,
	Track:        false,
}

DefaultPolicyConfig holds the default hyperparameters for a policy.

var DefaultSphereBlackBoxConfig = &SphereBlackBoxConfig{
	NumEpisodes: 100,
	EnvName:     "CartPole-v0",
	AgentConfig: DefaultAgentConfig,
	Logger:      log.DefaultLogger,
	SolvedChecker: func(reward float32) bool {
		return reward >= 195
	},
}

DefaultSphereBlackBoxConfig is the default config for a sphere black box.

Functions

func MakePolicy

func MakePolicy(config *PolicyConfig, base *agentv1.Base, env *envv1.Env) (modelv1.Model, error)

MakePolicy makes a policy model from the given policy config, agent base, and environment.

Types

type Agent

type Agent struct {
	// Base for the agent.
	*agentv1.Base

	// Policy for the agent.
	Policy model.Model
	// contains filtered or unexported fields
}

Agent is an NES agent.

func NewAgent

func NewAgent(c *AgentConfig, env *envv1.Env, base *agentv1.Base) (*Agent, error)

NewAgent returns a new NES agent.

func (*Agent) Action

func (a *Agent) Action(state *tensor.Dense) (action int, err error)

Action selects the best known action for the given state and weights.
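A common way to turn policy outputs into a discrete action is to pick the index of the largest output. The sketch below assumes that is what "best known action" means here; the `argmax` function and the plain `[]float32` input are illustrative stand-ins, not the package's API.

```go
package main

import (
	"errors"
	"fmt"
)

// argmax returns the index of the largest score, i.e. the greedy action.
// Ties resolve to the lowest index.
func argmax(scores []float32) (int, error) {
	if len(scores) == 0 {
		return 0, errors.New("argmax: empty scores")
	}
	best := 0
	for i, s := range scores[1:] {
		if s > scores[best] {
			best = i + 1 // i indexes scores[1:], so shift back by one
		}
	}
	return best, nil
}

func main() {
	action, err := argmax([]float32{1, -1.5})
	if err != nil {
		panic(err)
	}
	fmt.Println(action) // 0
}
```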

func (*Agent) SetWeights

func (a *Agent) SetWeights(weights *tensor.Dense) error

SetWeights sets the weights. TODO: support multiple weights.

type AgentConfig

type AgentConfig struct {
	// PolicyConfig for the agent.
	PolicyConfig *PolicyConfig
}

AgentConfig is the config for an NES agent.

type BlackBox

type BlackBox interface {
	// Run the black box.
	Run(weights *tensor.Dense) (reward float32, err error)

	// RunAsync the black box.
	RunAsync(populationID int, weights *tensor.Dense, results chan BlackBoxResult, wg *sync.WaitGroup)

	// Initialize the weights.
	InitWeights() *tensor.Dense
}

BlackBox is the black-box function to be optimized.
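To illustrate what the BlackBox contract asks for, here is a toy black box built on the classic sphere test function, whose reward is the negative squared distance of the weights from a fixed target. The `toyBlackBox` type, the narrowed `blackBox` interface (only `Run` and `InitWeights`) and the use of `[]float32` instead of `*tensor.Dense` are assumptions made so the sketch runs standalone; the package's SphereBlackBox instead runs episodes in an environment (`CartPole-v0` by default).

```go
package main

import "fmt"

// blackBox is a narrowed, slice-based stand-in for the BlackBox interface.
type blackBox interface {
	Run(weights []float32) (reward float32, err error)
	InitWeights() []float32
}

// toyBlackBox scores weights by the negative squared distance to target,
// so the best possible reward is 0 at weights == target.
type toyBlackBox struct {
	target []float32
}

func (b toyBlackBox) Run(weights []float32) (float32, error) {
	if len(weights) != len(b.target) {
		return 0, fmt.Errorf("expected %d weights, got %d", len(b.target), len(weights))
	}
	var sum float32
	for i, w := range weights {
		d := w - b.target[i]
		sum += d * d
	}
	return -sum, nil
}

// InitWeights starts from the origin.
func (b toyBlackBox) InitWeights() []float32 {
	return make([]float32, len(b.target))
}

func main() {
	var bb blackBox = toyBlackBox{target: []float32{1, 2}}
	r, err := bb.Run(bb.InitWeights())
	if err != nil {
		panic(err)
	}
	fmt.Println(r) // -5
}
```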

type BlackBoxResult

type BlackBoxResult struct {
	// Reward from black box run.
	Reward float32

	// Error from run.
	Err error

	// PopulationID is the ID of the population.
	PopulationID int

	// Solved tells if the agent solved the problem.
	Solved bool

	// Weights if solved.
	Weights *tensor.Dense
}

BlackBoxResult is the result of a black box run.

type Evolver

type Evolver struct {
	*EvolverHyperparameters
	*agent.Base
	// contains filtered or unexported fields
}

Evolver evolves a population of agents.

func NewEvolver

func NewEvolver(c *EvolverConfig) *Evolver

NewEvolver returns a new evolver.

func (*Evolver) Evolve

func (e *Evolver) Evolve() (weights *tensor.Dense, err error)

Evolve evolves the agents and returns the resulting weights.

type EvolverConfig

type EvolverConfig struct {
	// Hyperparameters for the evolver.
	*EvolverHyperparameters

	// BlackBox function to be optimized.
	BlackBox BlackBox

	// Base agent.
	Base *agent.Base
}

EvolverConfig is the config for the evolver.

type EvolverHyperparameters

type EvolverHyperparameters struct {
	// NPop is the population size.
	NPop int

	// NGen is the number of generations.
	NGen int

	// Sigma is the noise standard deviation.
	Sigma float32

	// Alpha is the learning rate.
	Alpha float32
}

EvolverHyperparameters are the hyperparameters for the evolver.

type LayerBuilder

type LayerBuilder func(x, y *modelv1.Input) []layer.Config

LayerBuilder builds the layer configs for a policy from its input and output shapes.

type PolicyConfig

type PolicyConfig struct {
	// LayerBuilder is a builder of layers.
	LayerBuilder LayerBuilder

	// Track is whether to track the model.
	Track bool
}

PolicyConfig holds the hyperparameters for a policy.

type SolvedChecker

type SolvedChecker func(reward float32) bool

SolvedChecker checks if the environment is solved.
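DefaultSphereBlackBoxConfig treats CartPole-v0 as solved at a reward of 195, which matches Gym's criterion for that environment (an average reward of at least 195 over 100 consecutive episodes). The sketch below shows one way to build a threshold checker and apply it to the mean reward of a batch of episodes; averaging before the check, and the `threshold` and `meanReward` helpers, are assumptions rather than behavior this documentation specifies.

```go
package main

import "fmt"

// SolvedChecker mirrors the package's function type.
type SolvedChecker func(reward float32) bool

// threshold returns a checker that reports solved once reward reaches minReward.
func threshold(minReward float32) SolvedChecker {
	return func(reward float32) bool { return reward >= minReward }
}

// meanReward averages per-episode rewards; it returns 0 when there are no episodes.
func meanReward(rewards []float32) float32 {
	if len(rewards) == 0 {
		return 0
	}
	var sum float32
	for _, r := range rewards {
		sum += r
	}
	return sum / float32(len(rewards))
}

func main() {
	solved := threshold(195)
	episodes := []float32{200, 200, 180}
	avg := meanReward(episodes)
	fmt.Println(avg, solved(avg))
}
```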

type SphereBlackBox

type SphereBlackBox struct {
	// contains filtered or unexported fields
}

SphereBlackBox is a sphere environment runner.

func NewSphereBlackBox

func NewSphereBlackBox(config *SphereBlackBoxConfig, server *envv1.Server) (*SphereBlackBox, error)

NewSphereBlackBox returns a new sphere black box.

func (*SphereBlackBox) InitWeights

func (s *SphereBlackBox) InitWeights() *tensor.Dense

InitWeights returns the initial weights for the black box.

func (*SphereBlackBox) Run

func (s *SphereBlackBox) Run(weights *tensor.Dense) (reward float32, err error)

Run runs the environment with the given weights and returns the reward.

func (*SphereBlackBox) RunAsync

func (s *SphereBlackBox) RunAsync(populationID int, weights *tensor.Dense, results chan BlackBoxResult, wg *sync.WaitGroup)

RunAsync runs the black box asynchronously and sends its result on the results channel.
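The RunAsync signature (population ID, results channel, `*sync.WaitGroup`) fits the usual fan-out/fan-in pattern: start one goroutine per population member, wait for all of them, then drain the channel. The sketch below assumes RunAsync sends exactly one result and calls `wg.Done()` when it finishes; the `runAsync` function, the trimmed `result` type and the slice-based weights are hypothetical stand-ins, not the package's API.

```go
package main

import (
	"fmt"
	"sync"
)

// result is a trimmed stand-in for BlackBoxResult.
type result struct {
	PopulationID int
	Reward       float32
	Err          error
}

// runAsync scores one weight vector and reports back on results.
// It mirrors the RunAsync signature with []float32 in place of *tensor.Dense.
func runAsync(populationID int, weights []float32, results chan result, wg *sync.WaitGroup) {
	defer wg.Done()
	var sum float32
	for _, w := range weights {
		sum -= w * w // toy sphere reward
	}
	results <- result{PopulationID: populationID, Reward: sum}
}

// evaluate runs every population member concurrently and returns rewards indexed by ID.
func evaluate(population [][]float32) ([]float32, error) {
	// Buffer the channel so no sender blocks before we start reading.
	results := make(chan result, len(population))
	var wg sync.WaitGroup
	for id, w := range population {
		wg.Add(1)
		go runAsync(id, w, results, &wg)
	}
	wg.Wait()
	close(results)

	rewards := make([]float32, len(population))
	for r := range results {
		if r.Err != nil {
			return nil, r.Err
		}
		rewards[r.PopulationID] = r.Reward
	}
	return rewards, nil
}

func main() {
	rewards, err := evaluate([][]float32{{1, 0}, {0, 2}, {0, 0}})
	if err != nil {
		panic(err)
	}
	fmt.Println(rewards) // [-1 -4 0]
}
```

Indexing by `PopulationID` matters because the goroutines finish in arbitrary order.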

type SphereBlackBoxConfig

type SphereBlackBoxConfig struct {
	// NumEpisodes is the number of episodes.
	NumEpisodes int

	// EnvName is the environment name.
	EnvName string

	// AgentConfig is the agent config.
	AgentConfig *AgentConfig

	// Logger for the box.
	Logger *log.Logger

	// SolvedChecker checks if the environment is solved.
	SolvedChecker SolvedChecker
}

SphereBlackBoxConfig is the sphere black box config.

Directories

Path Synopsis
experiments
