reinforce

package
v0.1.1
Published: May 13, 2023 License: Apache-2.0 Imports: 13 Imported by: 0

README

REINFORCE

An implementation of the REINFORCE algorithm.

How it works

REINFORCE is a form of policy gradient that uses a Monte Carlo rollout to compute rewards. It accumulates the rewards for the entire episode and then discounts them, weighting earlier rewards more heavily, using the reward discount equation:

[Equation image: reward discount]
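The discounting step can be sketched in a few lines of Go. This is a minimal sketch of the textbook discounted return, not the package's own code: the function name `discountRewards` is hypothetical, and the assumed convention is that the return at step t is the reward at t plus gamma times the return at t+1.

```go
package main

import "fmt"

// discountRewards returns, for each step t, the discounted return
// G_t = r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + ...
// It walks the episode backwards so each return reuses the next one.
// Hypothetical helper: the name and convention are assumptions.
func discountRewards(rewards []float32, gamma float32) []float32 {
	returns := make([]float32, len(rewards))
	var running float32
	for t := len(rewards) - 1; t >= 0; t-- {
		running = rewards[t] + gamma*running
		returns[t] = running
	}
	return returns
}

func main() {
	// Three steps with reward 1 each; gamma = 0.5 keeps the arithmetic exact.
	fmt.Println(discountRewards([]float32{1, 1, 1}, 0.5)) // [1.75 1.5 1]
}
```

With gamma set to 0 the returns equal the raw rewards, so each step is credited only with its immediate reward.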

The gradient is computed by the softmax loss of the discounted rewards with respect to the episode states. Actions are then sampled from the softmax distribution.
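Sampling from a softmax distribution can be illustrated without any dependencies. The sketch below is an assumption about how such sampling is commonly done (inverse sampling over the cumulative distribution); `softmax` and `sampleAction` are hypothetical names and are not taken from the package.

```go
package main

import (
	"fmt"
	"math"
	"math/rand"
)

// softmax converts raw scores into a probability distribution.
// Subtracting the largest score first keeps math.Exp from overflowing.
func softmax(logits []float64) []float64 {
	maxLogit := math.Inf(-1)
	for _, v := range logits {
		if v > maxLogit {
			maxLogit = v
		}
	}
	probs := make([]float64, len(logits))
	var sum float64
	for i, v := range logits {
		probs[i] = math.Exp(v - maxLogit)
		sum += probs[i]
	}
	for i := range probs {
		probs[i] /= sum
	}
	return probs
}

// sampleAction maps a uniform draw u in [0, 1) to an action index by
// walking the cumulative distribution, so index i is chosen with
// probability probs[i].
func sampleAction(probs []float64, u float64) int {
	var cumulative float64
	for i, p := range probs {
		cumulative += p
		if u < cumulative {
			return i
		}
	}
	return len(probs) - 1 // guards against floating-point rounding
}

func main() {
	probs := softmax([]float64{1, 2, 3})
	action := sampleAction(probs, rand.Float64())
	fmt.Println(probs, action)
}
```

Passing the uniform draw in as an argument keeps `sampleAction` deterministic, which makes it easy to test separately from the random source.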

[Equation image: full equation]
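For reference, the standard textbook form of the REINFORCE update is written out below. This is an assumption about what the README's equation images showed, not a transcription of them. Here $r_t$ is the reward at step $t$, $\gamma$ the discount factor, $T$ the episode length, and $\pi_\theta(a_t \mid s_t)$ the softmax probability the policy assigns to the action taken in state $s_t$.

```latex
% Discounted return from step t to the end of the episode
G_t = \sum_{k=0}^{T-t-1} \gamma^{k} \, r_{t+k}

% Monte Carlo estimate of the policy gradient from one episode
\nabla_\theta J(\theta) \approx \sum_{t=0}^{T-1} G_t \, \nabla_\theta \log \pi_\theta(a_t \mid s_t)
```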

Examples

See the experiments folder for example implementations.

Roadmap

  • more environments

Documentation

Overview

Package reinforce is an agent implementation of the REINFORCE algorithm.

Constants

This section is empty.

Variables

var DefaultAgentConfig = &AgentConfig{
	Hyperparameters: DefaultHyperparameters,
	PolicyConfig:    DefaultPolicyConfig,
	Base:            agentv1.NewBase("REINFORCE"),
}

DefaultAgentConfig is the default config for a REINFORCE agent.

var DefaultFCLayerBuilder = func(x, y *modelv1.Input) []layer.Config {
	return []layer.Config{
		layer.FC{Input: x.Squeeze()[0], Output: 24},
		layer.FC{Input: 24, Output: 24},
		layer.FC{Input: 24, Output: y.Squeeze()[0], Activation: layer.Softmax},
	}
}

DefaultFCLayerBuilder is a default fully connected layer builder.

var DefaultHyperparameters = &Hyperparameters{
	Gamma: 0.99,
}

DefaultHyperparameters are the default hyperparameters.

var DefaultPolicyConfig = &PolicyConfig{
	Optimizer:    g.NewAdamSolver(),
	LayerBuilder: DefaultFCLayerBuilder,
	Track:        true,
}

DefaultPolicyConfig is the default configuration for a policy.

Functions

func MakePolicy

func MakePolicy(config *PolicyConfig, base *agentv1.Base, env *envv1.Env) (modelv1.Model, error)

MakePolicy builds the policy model for the given environment from the policy config.

Types

type Agent

type Agent struct {
	// Base for the agent.
	*agentv1.Base

	// Hyperparameters for the REINFORCE agent.
	*Hyperparameters

	// Policy by which the agent acts.
	Policy model.Model

	// Memory of the agent.
	Memory *Memory
	// contains filtered or unexported fields
}

Agent is a REINFORCE agent.

func NewAgent

func NewAgent(c *AgentConfig, env *envv1.Env) (*Agent, error)

NewAgent returns a new REINFORCE agent.

func (*Agent) Action

func (a *Agent) Action(state *tensor.Dense) (action int, err error)

Action selects an action for the given state. Per the README, actions are sampled from the policy's softmax distribution.

func (*Agent) Learn

func (a *Agent) Learn() error

Learn trains the agent's policy on the episode stored in memory.
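To make the learning step concrete, here is a minimal, dependency-free sketch of a REINFORCE update on a toy two-armed bandit. It is not the package's implementation: the real agent trains a neural network policy, whereas this sketch keeps one score per arm, and every name in it (`trainBandit`, `learningRate`, `seed`) is hypothetical. Each episode is a single step, so the return equals the immediate reward.

```go
package main

import (
	"fmt"
	"math"
	"math/rand"
)

// trainBandit runs REINFORCE on a two-armed bandit in which arm 1 always
// pays 1 and arm 0 always pays 0. It returns the final action probabilities.
// Toy illustration only; names and setup are assumptions.
func trainBandit(episodes int, learningRate float64, seed int64) [2]float64 {
	rng := rand.New(rand.NewSource(seed))
	var logits [2]float64 // policy parameters, one score per arm

	// policy is the softmax over the two scores.
	policy := func() [2]float64 {
		e0, e1 := math.Exp(logits[0]), math.Exp(logits[1])
		return [2]float64{e0 / (e0 + e1), e1 / (e0 + e1)}
	}

	for i := 0; i < episodes; i++ {
		probs := policy()

		// Sample an action from the softmax distribution.
		action := 0
		if rng.Float64() >= probs[0] {
			action = 1
		}
		reward := float64(action) // arm 1 pays 1, arm 0 pays 0

		// The gradient of log pi(action) with respect to score j is
		// 1{j == action} - probs[j]. Scale it by the return and ascend.
		for j := range logits {
			indicator := 0.0
			if j == action {
				indicator = 1.0
			}
			logits[j] += learningRate * reward * (indicator - probs[j])
		}
	}
	return policy()
}

func main() {
	probs := trainBandit(1000, 0.1, 1)
	fmt.Printf("P(arm 0)=%.3f P(arm 1)=%.3f\n", probs[0], probs[1])
}
```

Because only the paying arm produces a nonzero return, every update pushes probability toward arm 1, so the policy ends up strongly preferring it.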

type AgentConfig

type AgentConfig struct {
	// Base for the agent.
	Base *agentv1.Base

	// Hyperparameters for the agent.
	*Hyperparameters

	// PolicyConfig for the agent.
	PolicyConfig *PolicyConfig
}

AgentConfig is the config for a REINFORCE agent.

type Hyperparameters

type Hyperparameters struct {
	// Gamma is the discount factor (0≤γ≤1). It determines how much importance we want to give to future
	// rewards. A high value for the discount factor (close to 1) captures the long-term effective reward, whereas
	// a discount factor of 0 makes our agent consider only the immediate reward, hence making it greedy.
	Gamma float32
}

Hyperparameters for the REINFORCE agent.
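As a worked example of what Gamma does, consider the weight $\gamma^{k}$ that a reward $k$ steps in the future receives under standard discounting (an assumption about the convention, which the page does not spell out). With the default $\gamma = 0.99$:

```latex
\gamma^{10} = 0.99^{10} \approx 0.904, \qquad
\gamma^{100} = 0.99^{100} \approx 0.366, \qquad
\sum_{k=0}^{\infty} \gamma^{k} = \frac{1}{1-\gamma} = 100
```

A reward 100 steps away therefore still counts for roughly a third of an immediate one, and the total weight across all future steps is bounded by 100. With $\gamma = 0$ every future reward has weight 0 and only the immediate reward counts.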

type LayerBuilder

type LayerBuilder func(x, y *modelv1.Input) []layer.Config

LayerBuilder builds layers.

type Memory

type Memory struct {
	States  []*tensor.Dense
	Actions []float32
	Rewards []float32
}

Memory for the agent.

func NewMemory

func NewMemory() *Memory

NewMemory returns a new Memory store.

func (*Memory) Clear

func (m *Memory) Clear()

Clear the memory.

func (*Memory) Pop

func (m *Memory) Pop() (states []*tensor.Dense, actions, rewards []float32)

Pop the states, actions, and rewards from memory.

func (*Memory) Store

func (m *Memory) Store(state *tensor.Dense, action int, reward float32)

Store a state, action, and reward.
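The Memory API above can be mirrored in a small self-contained sketch to show how the parallel slices are expected to line up. Two things here are assumptions rather than documented behavior: `State` is a hypothetical stand-in for `*tensor.Dense` so the example has no dependencies, and `Pop` is assumed to empty the memory after returning its contents.

```go
package main

import "fmt"

// State is a hypothetical stand-in for *tensor.Dense.
type State []float32

// Memory mirrors the documented fields: parallel slices, one entry per step.
type Memory struct {
	States  []State
	Actions []float32
	Rewards []float32
}

// NewMemory returns an empty memory.
func NewMemory() *Memory { return &Memory{} }

// Store appends one step. The action index is converted to float32 to
// match the documented type of the Actions field.
func (m *Memory) Store(state State, action int, reward float32) {
	m.States = append(m.States, state)
	m.Actions = append(m.Actions, float32(action))
	m.Rewards = append(m.Rewards, reward)
}

// Clear drops everything stored so far.
func (m *Memory) Clear() {
	m.States, m.Actions, m.Rewards = nil, nil, nil
}

// Pop returns the stored episode and then empties the memory.
// Assumption: the documentation does not say whether the real Pop clears.
func (m *Memory) Pop() (states []State, actions, rewards []float32) {
	states, actions, rewards = m.States, m.Actions, m.Rewards
	m.Clear()
	return states, actions, rewards
}

func main() {
	m := NewMemory()
	m.Store(State{0.1, 0.2}, 1, 1.0)
	m.Store(State{0.3, 0.4}, 0, 0.5)

	states, actions, rewards := m.Pop()
	fmt.Println(len(states), actions, rewards, len(m.States)) // 2 [1 0] [1 0.5] 0
}
```

Index i in each slice refers to the same step, which is what lets a learning step pair every state with the action taken and the reward received.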

type PolicyConfig

type PolicyConfig struct {
	// Optimizer to optimize the weights with regard to the error.
	Optimizer g.Solver

	// LayerBuilder is a builder of layers.
	LayerBuilder LayerBuilder

	// Track is whether to track the model.
	Track bool
}

PolicyConfig is the configuration for a policy.

Directories

Path Synopsis
experiments
