agent

package v0.0.0-...-25ab4ef

Published: Apr 27, 2024 License: MIT Imports: 8 Imported by: 0

Documentation

Overview

Example (Rand)
package main

import (
	"fmt"
	randv2 "math/rand/v2"

	"github.com/itsubaki/neu/math/rand"
)

func main() {
	// A fresh source seeded with the same constant always yields
	// the same first value.
	for i := 0; i < 5; i++ {
		r := randv2.New(rand.Const(1))
		fmt.Println(r.Float64())
	}

	// Reusing one source advances its state, so each draw differs.
	s := rand.Const(1)
	for i := 0; i < 5; i++ {
		r := randv2.New(s)
		fmt.Println(r.Float64())
	}
}
Output:

0.23842319087387442
0.23842319087387442
0.23842319087387442
0.23842319087387442
0.23842319087387442
0.23842319087387442
0.50092138792625
0.04999911180706662
0.4894631469238666
0.7500167893718852
Example (Target)
package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
)

func main() {
	// For each element: target = reward + gamma*q, with the
	// discounted term dropped where done is true.
	fmt.Println(agent.Target(
		[]float64{1, 2, 3},         // rewards
		[]bool{false, false, true}, // done flags
		0.98,                       // gamma
		[]float64{1, 2, 3},         // q values
	))
}
Output:

[[1.98] [3.96] [3]]

Constants

This section is empty.

Variables

This section is empty.

Functions

func SortedKeys

func SortedKeys(m map[string]float64) []string
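
Go maps iterate in random order, so the examples below use SortedKeys to print Q and V tables deterministically. A minimal sketch (the map keys are made up for illustration):

package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
)

func main() {
	q := map[string]float64{
		"(0, 1): 0": 0.5,
		"(0, 0): 1": 0.1,
		"(0, 0): 0": 0.9,
	}

	// keys come back in sorted (lexicographic) order
	for _, k := range agent.SortedKeys(q) {
		fmt.Println(k, q[k])
	}
}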

Types

type Agent

type Agent struct {
	Epsilon float64
	Qs      []float64
	Ns      []float64
	Source  randv2.Source
}
Example
package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
	"github.com/itsubaki/neu/math/rand"
)

func main() {
	a := &agent.Agent{
		Epsilon: 0.5,
		Qs:      []float64{0, 0, 0, 0, 0},
		Ns:      []float64{0, 0, 0, 0, 0},
		Source:  rand.Const(1),
	}

	for i := 0; i < 10; i++ {
		action := a.GetAction()
		a.Update(action, 1.0)
		fmt.Printf("%v: %v\n", action, a.Qs)
	}

}
Output:

0: [1 0 0 0 0]
0: [1 0 0 0 0]
0: [1 0 0 0 0]
0: [1 0 0 0 0]
2: [1 0 1 0 0]
0: [1 0 1 0 0]
2: [1 0 1 0 0]
4: [1 0 1 0 1]
4: [1 0 1 0 1]
3: [1 0 1 1 1]
Example (Bandit)
package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
	"github.com/itsubaki/neu/agent/env"
	"github.com/itsubaki/neu/math/rand"
	"github.com/itsubaki/neu/math/vector"
)

func main() {
	arms, steps, runs, eps := 10, 1000, 200, 0.1
	s := rand.Const(1)

	all := make([][]float64, runs)
	for r := 0; r < runs; r++ {
		bandit := env.NewNonStatBandit(arms, s)
		agent := &agent.Agent{Epsilon: eps, Qs: make([]float64, arms), Ns: make([]float64, arms), Source: s}

		var total float64
		rates := make([]float64, steps)
		for i := 0; i < steps; i++ {
			action := agent.GetAction()
			reward := bandit.Play(action)
			agent.Update(action, reward)

			total += reward
			rates[i] = total / float64(i+1)
		}

		all[r] = rates
	}

	// all[i] holds the running reward rate of run i; print the
	// mean rate of each of the last 10 runs.
	for _, i := range []int{190, 191, 192, 193, 194, 195, 196, 197, 198, 199} {
		fmt.Printf("run=%3v: mean(rate)=%.4f\n", i, vector.Mean(all[i]))
	}

}
Output:

run=190: mean(rate)=0.7875
run=191: mean(rate)=0.9418
run=192: mean(rate)=0.9185
run=193: mean(rate)=0.6062
run=194: mean(rate)=0.8260
run=195: mean(rate)=0.8314
run=196: mean(rate)=0.8781
run=197: mean(rate)=0.8273
run=198: mean(rate)=0.8844
run=199: mean(rate)=0.8903

func (*Agent) GetAction

func (a *Agent) GetAction() int

func (*Agent) Update

func (a *Agent) Update(action int, reward float64)

type AlphaAgent

type AlphaAgent struct {
	Epsilon float64
	Alpha   float64
	Qs      []float64
	Source  randv2.Source
}
Example
package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
	"github.com/itsubaki/neu/math/rand"
)

func main() {
	a := &agent.AlphaAgent{
		Epsilon: 0.5,
		Alpha:   0.8,
		Qs:      []float64{0, 0, 0, 0, 0},
		Source:  rand.Const(1),
	}

	for i := 0; i < 10; i++ {
		action := a.GetAction()
		a.Update(action, 1.0)
		fmt.Printf("%v: %.4f\n", action, a.Qs)
	}

}
Output:

0: [0.8000 0.0000 0.0000 0.0000 0.0000]
0: [0.9600 0.0000 0.0000 0.0000 0.0000]
0: [0.9920 0.0000 0.0000 0.0000 0.0000]
0: [0.9984 0.0000 0.0000 0.0000 0.0000]
2: [0.9984 0.0000 0.8000 0.0000 0.0000]
0: [0.9997 0.0000 0.8000 0.0000 0.0000]
2: [0.9997 0.0000 0.9600 0.0000 0.0000]
4: [0.9997 0.0000 0.9600 0.0000 0.8000]
4: [0.9997 0.0000 0.9600 0.0000 0.9600]
3: [0.9997 0.0000 0.9600 0.8000 0.9600]
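
Note the geometric approach toward 1.0 in the output above: each update appears to move Q by Alpha*(reward - Q), giving 0.8, 0.96, 0.992 for a constant reward of 1.0, whereas Agent's count-based Update keeps a plain sample average and pins Q at exactly 1.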
Example (Bandit)
package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
	"github.com/itsubaki/neu/agent/env"
	"github.com/itsubaki/neu/math/rand"
	"github.com/itsubaki/neu/math/vector"
)

func main() {
	arms, steps, runs := 10, 1000, 200
	eps, alpha := 0.1, 0.8
	s := rand.Const(1)

	all := make([][]float64, runs)
	for r := 0; r < runs; r++ {
		bandit := env.NewNonStatBandit(arms, s)
		agent := &agent.AlphaAgent{Epsilon: eps, Alpha: alpha, Qs: make([]float64, arms), Source: s}

		var total float64
		rates := make([]float64, steps)
		for i := 0; i < steps; i++ {
			action := agent.GetAction()
			reward := bandit.Play(action)
			agent.Update(action, reward)

			total += reward
			rates[i] = total / float64(i+1)
		}

		all[r] = rates
	}

	// all[i] holds the running reward rate of run i; print the
	// mean rate of each of the last 10 runs.
	for _, i := range []int{190, 191, 192, 193, 194, 195, 196, 197, 198, 199} {
		fmt.Printf("run=%3v: mean(rate)=%.4f\n", i, vector.Mean(all[i]))
	}

}
Output:

run=190: mean(rate)=0.8010
run=191: mean(rate)=0.9436
run=192: mean(rate)=0.9249
run=193: mean(rate)=0.6519
run=194: mean(rate)=0.8373
run=195: mean(rate)=0.8367
run=196: mean(rate)=0.8759
run=197: mean(rate)=0.9196
run=198: mean(rate)=0.8844
run=199: mean(rate)=0.8871

func (*AlphaAgent) GetAction

func (a *AlphaAgent) GetAction() int

func (*AlphaAgent) Update

func (a *AlphaAgent) Update(action int, reward float64)

type Buffer

type Buffer struct {
	State     []float64
	Action    int
	Reward    float64
	NextState []float64
	Done      bool
}
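
Buffer is the transition record that ReplayBuffer stores. A minimal sketch of building one by hand, with illustrative field values:

package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
)

func main() {
	// one (state, action, reward, next state, done) transition
	b := agent.Buffer{
		State:     []float64{0, 1},
		Action:    2,
		Reward:    1.0,
		NextState: []float64{1, 1},
		Done:      false,
	}
	fmt.Println(b)
}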

type DQNAgent

type DQNAgent struct {
	Gamma        float64
	Epsilon      float64
	ReplayBuffer *ReplayBuffer
	ActionSize   int
	Q            *model.QNet
	QTarget      *model.QNet
	Optimizer    *optimizer.Adam
	Source       randv2.Source
}
Example
package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
	"github.com/itsubaki/neu/agent/env"
	"github.com/itsubaki/neu/math/matrix"
	"github.com/itsubaki/neu/math/rand"
	"github.com/itsubaki/neu/model"
	"github.com/itsubaki/neu/optimizer"
	"github.com/itsubaki/neu/weight"
)

func main() {
	e := env.NewGridWorld()
	s := rand.Const(1)
	a := &agent.DQNAgent{
		Gamma:        0.98,
		Epsilon:      0.1,
		ActionSize:   4,
		ReplayBuffer: agent.NewReplayBuffer(10000, 32, s),
		Q: model.NewQNet(&model.QNetConfig{
			InputSize:  12,
			OutputSize: 4,
			HiddenSize: []int{128, 128},
			WeightInit: weight.Xavier,
		}, s),
		QTarget: model.NewQNet(&model.QNetConfig{
			InputSize:  12,
			OutputSize: 4,
			HiddenSize: []int{128, 128},
			WeightInit: weight.Xavier,
		}, s),
		Optimizer: &optimizer.Adam{
			Alpha: 0.0005,
			Beta1: 0.9,
			Beta2: 0.999,
		},
		Source: s,
	}

	episodes, syncInterval := 1, 1
	for i := 0; i < episodes; i++ {
		state := e.OneHot(e.Reset())
		var totalLoss, totalReward float64
		var count int

		for {
			action := a.GetAction(state)
			next, reward, done := e.Step(action)
			nextoh := e.OneHot(next)
			loss := a.Update(state, action, reward, nextoh, done)
			state = nextoh

			totalLoss += loss[0][0]
			totalReward += reward
			count++

			if done {
				break
			}
		}

		// sync the target network with the online network
		if (i+1)%syncInterval == 0 {
			a.Sync()
		}

		fmt.Printf("%d: %.4f, %.4f\n", i, totalLoss/float64(count), totalReward/float64(count))
	}

	for _, s := range e.State {
		if s.Equals(e.GoalState) || s.Equals(e.WallState) {
			continue
		}

		q := a.Q.Predict(matrix.New(e.OneHot(&s)))
		for _, a := range e.Actions() {
			fmt.Printf("%s %-6s: %.4f\n", s, e.ActionMeaning[a], q[0][a])
		}
	}

}
Output:

0: 0.0143, -0.0046
(0, 0) UP    : 0.1057
(0, 0) DOWN  : 0.0411
(0, 0) LEFT  : -0.1044
(0, 0) RIGHT : -0.1153
(0, 1) UP    : 0.2575
(0, 1) DOWN  : 0.0565
(0, 1) LEFT  : -0.0090
(0, 1) RIGHT : -0.1412
(0, 2) UP    : -0.1237
(0, 2) DOWN  : 0.3362
(0, 2) LEFT  : 0.0124
(0, 2) RIGHT : -0.0446
(1, 0) UP    : 0.0993
(1, 0) DOWN  : 0.0425
(1, 0) LEFT  : -0.1653
(1, 0) RIGHT : -0.0591
(1, 2) UP    : -0.4625
(1, 2) DOWN  : 0.0474
(1, 2) LEFT  : -0.2263
(1, 2) RIGHT : -0.1712
(1, 3) UP    : 0.7964
(1, 3) DOWN  : 0.0965
(1, 3) LEFT  : 0.0643
(1, 3) RIGHT : -0.1828
(2, 0) UP    : -0.4854
(2, 0) DOWN  : 0.2162
(2, 0) LEFT  : -0.2302
(2, 0) RIGHT : -0.1094
(2, 1) UP    : 0.2301
(2, 1) DOWN  : 0.0680
(2, 1) LEFT  : -0.0531
(2, 1) RIGHT : -0.0764
(2, 2) UP    : -1.2781
(2, 2) DOWN  : 0.2185
(2, 2) LEFT  : -0.6493
(2, 2) RIGHT : -0.4158
(2, 3) UP    : -0.7843
(2, 3) DOWN  : 0.1689
(2, 3) LEFT  : -0.4043
(2, 3) RIGHT : -0.2855

func (*DQNAgent) GetAction

func (a *DQNAgent) GetAction(state []float64) int

func (*DQNAgent) Sync

func (a *DQNAgent) Sync()

func (*DQNAgent) Update

func (a *DQNAgent) Update(state []float64, action int, reward float64, next []float64, done bool) matrix.Matrix

type DefaultMap

type DefaultMap[T any] map[string]T
Example
package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
	"github.com/itsubaki/neu/agent/env"
)

func main() {
	m := agent.DefaultMap[agent.RandomActions]{}

	fmt.Println(m.Get(env.GridState{Height: 1, Width: 1}, agent.RandomActions{0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25}).Probs())
	for k, v := range m {
		fmt.Println(k, v)
	}

}
Output:

[0.25 0.25 0.25 0.25]
(1, 1) map[0:0.25 1:0.25 2:0.25 3:0.25]

func (DefaultMap[T]) Get

func (m DefaultMap[T]) Get(key fmt.Stringer, defaultValue T) T

type Deque

type Deque[T any] struct {
	// contains filtered or unexported fields
}
Example
package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
)

func main() {
	q := agent.NewDeque[agent.Memory](2)
	q.Add(agent.Memory{State: "a", Action: 1, Reward: 1, Done: false})
	q.Add(agent.Memory{State: "b", Action: 2, Reward: 2, Done: false})
	q.Add(agent.Memory{State: "c", Action: 3, Reward: 3, Done: true})

	fmt.Println(q.Get(0))
	fmt.Println(q.Get(1))
	fmt.Println(q.Len())
	fmt.Println(q.Size())

}
Output:

{b 2 2 false}
{c 3 3 true}
2
2

func NewDeque

func NewDeque[T any](size int) *Deque[T]

func (*Deque[T]) Add

func (q *Deque[T]) Add(m T)

func (*Deque[T]) Get

func (q *Deque[T]) Get(i int) T

func (*Deque[T]) Len

func (q *Deque[T]) Len() int

func (*Deque[T]) Size

func (q *Deque[T]) Size() int

type Memory

type Memory struct {
	State  string
	Action int
	Reward float64
	Done   bool
}

func NewMemory

func NewMemory(state fmt.Stringer, action int, reward float64, done bool) Memory
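
A minimal sketch of NewMemory, using env.GridState as the fmt.Stringer and illustrative values; the state is presumably stored as its String() form:

package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
	"github.com/itsubaki/neu/agent/env"
)

func main() {
	// the state is recorded as its String() form, e.g. "(1, 2)"
	m := agent.NewMemory(env.GridState{Height: 1, Width: 2}, 3, 1.0, false)
	fmt.Println(m)
}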

type MonteCarloAgent

type MonteCarloAgent struct {
	Gamma          float64
	Epsilon        float64
	Alpha          float64
	ActionSize     int
	DefaultActions RandomActions
	Pi             DefaultMap[RandomActions]
	Q              DefaultMap[float64]
	Memory         []Memory
	Source         randv2.Source
}
Example
package main

import (
	"fmt"
	"strconv"
	"strings"

	"github.com/itsubaki/neu/agent"
	"github.com/itsubaki/neu/agent/env"
	"github.com/itsubaki/neu/math/rand"
)

func main() {
	e := env.NewGridWorld()
	a := &agent.MonteCarloAgent{
		Gamma:          0.9,
		Epsilon:        0.1,
		Alpha:          0.1,
		ActionSize:     4,
		DefaultActions: agent.RandomActions{0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25},
		Pi:             make(map[string]agent.RandomActions),
		Q:              make(map[string]float64),
		Memory:         make([]agent.Memory, 0),
		Source:         rand.Const(1),
	}

	episodes := 10000
	for i := 0; i < episodes; i++ {
		state := e.Reset()
		a.Reset()

		for {
			action := a.GetAction(state)
			next, reward, done := e.Step(action)
			a.Add(state, action, reward)

			if done {
				a.Update()
				break
			}

			state = next
		}
	}

	for _, k := range agent.SortedKeys(a.Q) {
		s := strings.Split(k, ": ")
		move, _ := strconv.Atoi(s[1])
		fmt.Printf("%s %-6s: %.2f\n", s[0], e.ActionMeaning[move], a.Q[k])
	}

}
Output:

(0, 0) UP    : 0.70
(0, 0) DOWN  : 0.63
(0, 0) LEFT  : 0.73
(0, 0) RIGHT : 0.75
(0, 1) UP    : 0.81
(0, 1) DOWN  : 0.80
(0, 1) LEFT  : 0.69
(0, 1) RIGHT : 0.86
(0, 2) UP    : 0.89
(0, 2) DOWN  : 0.77
(0, 2) LEFT  : 0.80
(0, 2) RIGHT : 1.00
(1, 0) UP    : 0.68
(1, 0) DOWN  : 0.57
(1, 0) LEFT  : 0.61
(1, 0) RIGHT : 0.64
(1, 2) UP    : 0.88
(1, 2) DOWN  : 0.63
(1, 2) LEFT  : 0.78
(1, 2) RIGHT : -0.11
(1, 3) UP    : 1.00
(1, 3) DOWN  : -0.14
(1, 3) LEFT  : 0.30
(1, 3) RIGHT : -0.10
(2, 0) UP    : 0.61
(2, 0) DOWN  : 0.56
(2, 0) LEFT  : 0.54
(2, 0) RIGHT : 0.56
(2, 1) UP    : 0.51
(2, 1) DOWN  : 0.21
(2, 1) LEFT  : 0.45
(2, 1) RIGHT : 0.64
(2, 2) UP    : 0.71
(2, 2) DOWN  : 0.48
(2, 2) LEFT  : 0.42
(2, 2) RIGHT : -0.09
(2, 3) UP    : -0.20
(2, 3) DOWN  : -0.20
(2, 3) LEFT  : -0.04
(2, 3) RIGHT : -0.23

func (*MonteCarloAgent) Add

func (a *MonteCarloAgent) Add(state fmt.Stringer, action int, reward float64)

func (*MonteCarloAgent) GetAction

func (a *MonteCarloAgent) GetAction(state fmt.Stringer) int

func (*MonteCarloAgent) Reset

func (a *MonteCarloAgent) Reset()

func (*MonteCarloAgent) Update

func (a *MonteCarloAgent) Update()

type QLearningAgent

type QLearningAgent struct {
	Gamma      float64
	Alpha      float64
	Epsilon    float64
	ActionSize int
	Q          DefaultMap[float64]
	Source     randv2.Source
}
Example
package main

import (
	"fmt"
	"strconv"
	"strings"

	"github.com/itsubaki/neu/agent"
	"github.com/itsubaki/neu/agent/env"
	"github.com/itsubaki/neu/math/rand"
)

func main() {
	e := env.NewGridWorld()
	a := &agent.QLearningAgent{
		Gamma:      0.9,
		Alpha:      0.8,
		Epsilon:    0.1,
		ActionSize: 4,
		Q:          make(map[string]float64),
		Source:     rand.Const(1),
	}

	episodes := 10000
	for i := 0; i < episodes; i++ {
		state := e.Reset()

		for {
			action := a.GetAction(state)
			next, reward, done := e.Step(action)
			a.Update(state, action, reward, next, done)

			if done {
				break
			}

			state = next
		}
	}

	for _, k := range agent.SortedKeys(a.Q) {
		s := strings.Split(k, ": ")
		move, _ := strconv.Atoi(s[1])
		fmt.Printf("%s %-6s: %.4f\n", s[0], e.ActionMeaning[move], a.Q[k])
	}

}
Output:

(0, 0) UP    : 0.7290
(0, 0) DOWN  : 0.6561
(0, 0) LEFT  : 0.7290
(0, 0) RIGHT : 0.8100
(0, 1) UP    : 0.8100
(0, 1) DOWN  : 0.8100
(0, 1) LEFT  : 0.7290
(0, 1) RIGHT : 0.9000
(0, 2) UP    : 0.9000
(0, 2) DOWN  : 0.8100
(0, 2) LEFT  : 0.8100
(0, 2) RIGHT : 1.0000
(1, 0) UP    : 0.7290
(1, 0) DOWN  : 0.5905
(1, 0) LEFT  : 0.6561
(1, 0) RIGHT : 0.6561
(1, 2) UP    : 0.9000
(1, 2) DOWN  : 0.7290
(1, 2) LEFT  : 0.8100
(1, 2) RIGHT : -0.1001
(1, 3) UP    : 1.0000
(1, 3) DOWN  : 0.0000
(1, 3) LEFT  : 0.0000
(1, 3) RIGHT : -0.0812
(2, 0) UP    : 0.6561
(2, 0) DOWN  : 0.5905
(2, 0) LEFT  : 0.5905
(2, 0) RIGHT : 0.6561
(2, 1) UP    : 0.6559
(2, 1) DOWN  : 0.6559
(2, 1) LEFT  : 0.5905
(2, 1) RIGHT : 0.7290
(2, 2) UP    : 0.8100
(2, 2) DOWN  : 0.7290
(2, 2) LEFT  : 0.6561
(2, 2) RIGHT : 0.0000
(2, 3) UP    : -0.1000
(2, 3) DOWN  : 0.0000
(2, 3) LEFT  : 0.0000
(2, 3) RIGHT : 0.0000

func (*QLearningAgent) GetAction

func (a *QLearningAgent) GetAction(state fmt.Stringer) int

func (*QLearningAgent) Update

func (a *QLearningAgent) Update(state fmt.Stringer, action int, reward float64, next fmt.Stringer, done bool)

type RandomActions

type RandomActions map[int]float64

func (RandomActions) Probs

func (a RandomActions) Probs() []float64
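
Probs flattens the action-to-probability map into a slice, presumably indexed by action. A minimal sketch with illustrative, non-uniform probabilities:

package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
)

func main() {
	p := agent.RandomActions{0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4}
	// expected to print [0.1 0.2 0.3 0.4]
	fmt.Println(p.Probs())
}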

type RandomAgent

type RandomAgent struct {
	Gamma          float64
	ActionSize     int
	DefaultActions RandomActions
	Pi             DefaultMap[RandomActions]
	V              map[string]float64
	Counts         map[string]int
	Memory         []Memory
	Source         randv2.Source
}
Example
package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
	"github.com/itsubaki/neu/agent/env"
	"github.com/itsubaki/neu/math/rand"
)

func main() {
	e := env.NewGridWorld()
	a := &agent.RandomAgent{
		Gamma:          0.9,
		ActionSize:     4,
		DefaultActions: agent.RandomActions{0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25},
		Pi:             make(map[string]agent.RandomActions),
		V:              make(map[string]float64),
		Counts:         make(map[string]int),
		Memory:         make([]agent.Memory, 0),
		Source:         rand.Const(1),
	}

	episodes := 1000
	for i := 0; i < episodes; i++ {
		state := e.Reset()
		a.Reset()

		for {
			action := a.GetAction(state)
			next, reward, done := e.Step(action)
			a.Add(state, action, reward)

			if done {
				a.Eval()
				break
			}

			state = next
		}
	}

	for _, k := range agent.SortedKeys(a.V) {
		fmt.Println(k, a.V[k])
	}

}
Output:

(0, 0) 0.019763603081288356
(0, 1) 0.07803773078469887
(0, 2) 0.1844500459964045
(1, 0) -0.023319112315680166
(1, 2) -0.484607672112952
(1, 3) -0.3464933845764332
(2, 0) -0.10313829730561434
(2, 1) -0.22467848343570673
(2, 2) -0.4547704293279075
(2, 3) -0.760130052613597

func (*RandomAgent) Add

func (a *RandomAgent) Add(state fmt.Stringer, action int, reward float64)

func (*RandomAgent) Eval

func (a *RandomAgent) Eval()

func (*RandomAgent) GetAction

func (a *RandomAgent) GetAction(state fmt.Stringer) int

func (*RandomAgent) Reset

func (a *RandomAgent) Reset()

type ReplayBuffer

type ReplayBuffer struct {
	Buffer    *Deque[Buffer]
	BatchSize int
	Source    randv2.Source
}
Example
package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
	"github.com/itsubaki/neu/math/rand"
)

func main() {
	buf := agent.NewReplayBuffer(10, 3, rand.Const(1))
	for i := 0; i < 10; i++ {
		buf.Add([]float64{float64(i), float64(i)}, i, float64(i), []float64{float64(i * 10), float64(i * 10)}, false)
	}
	fmt.Println(buf.Len())

	state, action, reward, next, done := buf.Batch()
	for i := range state {
		fmt.Println(state[i], action[i], reward[i], next[i], done[i])
	}

}
Output:

10
[0 0] 0 0 [0 0] false
[7 7] 7 7 [70 70] false
[5 5] 5 5 [50 50] false
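
Batch draws BatchSize transitions at random from the stored experience (here 3 of the 10 added) and returns them as parallel state, action, reward, next-state, and done slices.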
Example (Rand)
package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
)

func main() {
	// the variadic randv2.Source argument is omitted here
	buf := agent.NewReplayBuffer(10, 3)
	for i := 0; i < 10; i++ {
		buf.Add([]float64{float64(i)}, i, float64(i), []float64{float64(i * 10)}, false)
	}
	fmt.Println(buf.Len())
}
Output:

10

func NewReplayBuffer

func NewReplayBuffer(bufferSize, batchSize int, s ...randv2.Source) *ReplayBuffer

func (*ReplayBuffer) Add

func (b *ReplayBuffer) Add(state []float64, action int, reward float64, next []float64, done bool)

func (*ReplayBuffer) Batch

func (b *ReplayBuffer) Batch() ([][]float64, []int, []float64, [][]float64, []bool)

func (*ReplayBuffer) Len

func (b *ReplayBuffer) Len() int

type SarsaAgent

type SarsaAgent struct {
	Gamma          float64
	Alpha          float64
	Epsilon        float64
	ActionSize     int
	DefaultActions RandomActions
	Pi             DefaultMap[RandomActions]
	Q              DefaultMap[float64]
	Memory         *Deque[Memory]
	Source         randv2.Source
}
Example
package main

import (
	"fmt"
	"strconv"
	"strings"

	"github.com/itsubaki/neu/agent"
	"github.com/itsubaki/neu/agent/env"
	"github.com/itsubaki/neu/math/rand"
)

func main() {
	e := env.NewGridWorld()
	a := &agent.SarsaAgent{
		Gamma:          0.9,
		Alpha:          0.8,
		Epsilon:        0.1,
		ActionSize:     4,
		DefaultActions: agent.RandomActions{0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25},
		Pi:             make(map[string]agent.RandomActions),
		Q:              make(map[string]float64),
		Memory:         agent.NewDeque[agent.Memory](2),
		Source:         rand.Const(1),
	}

	episodes := 10000
	for i := 0; i < episodes; i++ {
		state := e.Reset()
		a.Reset()

		for {
			action := a.GetAction(state)
			next, reward, done := e.Step(action)
			a.Update(state, action, reward, done)

			if done {
				// one extra update with a dummy action so the last
				// transition held in Memory is applied
				a.Update(state, -1, 0, false)
				break
			}

			state = next
		}
	}

	for _, k := range agent.SortedKeys(a.Q) {
		s := strings.Split(k, ": ")
		move, _ := strconv.Atoi(s[1])
		fmt.Printf("%s %-6s: %.4f\n", s[0], e.ActionMeaning[move], a.Q[k])
	}

}
Output:

(0, 0) UP    : 0.4435
(0, 0) DOWN  : 0.4794
(0, 0) LEFT  : 0.5081
(0, 0) RIGHT : 0.8100
(0, 1) UP    : 0.6910
(0, 1) DOWN  : 0.7147
(0, 1) LEFT  : 0.5726
(0, 1) RIGHT : 0.9000
(0, 2) UP    : 0.8856
(0, 2) DOWN  : 0.8097
(0, 2) LEFT  : 0.6566
(0, 2) RIGHT : 1.0000
(1, 0) UP    : 0.6862
(1, 0) DOWN  : 0.4138
(1, 0) LEFT  : 0.4232
(1, 0) RIGHT : 0.4689
(1, 2) UP    : 0.9000
(1, 2) DOWN  : 0.0828
(1, 2) LEFT  : 0.7719
(1, 2) RIGHT : -0.1973
(1, 3) UP    : 0.9920
(1, 3) DOWN  : 0.0000
(1, 3) LEFT  : 0.6116
(1, 3) RIGHT : -0.8000
(2, 0) UP    : 0.4778
(2, 0) DOWN  : 0.4032
(2, 0) LEFT  : 0.4687
(2, 0) RIGHT : 0.3780
(2, 1) UP    : 0.4054
(2, 1) DOWN  : 0.3012
(2, 1) LEFT  : 0.4845
(2, 1) RIGHT : 0.0498
(2, 2) UP    : -0.1938
(2, 2) DOWN  : 0.0000
(2, 2) LEFT  : 0.3481
(2, 2) RIGHT : -0.1597
(2, 3) UP    : -0.2669
(2, 3) DOWN  : 0.0000
(2, 3) LEFT  : 0.0000
(2, 3) RIGHT : 0.0000

func (*SarsaAgent) GetAction

func (a *SarsaAgent) GetAction(state fmt.Stringer) int

func (*SarsaAgent) Reset

func (a *SarsaAgent) Reset()

func (*SarsaAgent) Update

func (a *SarsaAgent) Update(state fmt.Stringer, action int, reward float64, done bool)

type SarsaOffPolicyAgent

type SarsaOffPolicyAgent struct {
	Gamma          float64
	Alpha          float64
	Epsilon        float64
	ActionSize     int
	DefaultActions RandomActions
	Pi             DefaultMap[RandomActions]
	B              DefaultMap[RandomActions]
	Q              DefaultMap[float64]
	Memory         *Deque[Memory]
	Source         randv2.Source
}
Example
package main

import (
	"fmt"
	"strconv"
	"strings"

	"github.com/itsubaki/neu/agent"
	"github.com/itsubaki/neu/agent/env"
	"github.com/itsubaki/neu/math/rand"
)

func main() {
	e := env.NewGridWorld()
	a := &agent.SarsaOffPolicyAgent{
		Gamma:          0.9,
		Alpha:          0.8,
		Epsilon:        0.1,
		ActionSize:     4,
		DefaultActions: agent.RandomActions{0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25},
		Pi:             make(map[string]agent.RandomActions),
		B:              make(map[string]agent.RandomActions),
		Q:              make(map[string]float64),
		Memory:         agent.NewDeque[agent.Memory](2),
		Source:         rand.Const(1),
	}

	episodes := 10000
	for i := 0; i < episodes; i++ {
		state := e.Reset()
		a.Reset()

		for {
			action := a.GetAction(state)
			next, reward, done := e.Step(action)
			a.Update(state, action, reward, done)

			if done {
				// one extra update with a dummy action so the last
				// transition held in Memory is applied
				a.Update(state, -1, 0, false)
				break
			}

			state = next
		}
	}

	for _, k := range agent.SortedKeys(a.Q) {
		s := strings.Split(k, ": ")
		move, _ := strconv.Atoi(s[1])
		fmt.Printf("%s %-6s: %.4f\n", s[0], e.ActionMeaning[move], a.Q[k])
	}

}
Output:

(0, 0) UP    : 0.0392
(0, 0) DOWN  : 0.0294
(0, 0) LEFT  : 0.0825
(0, 0) RIGHT : 0.9415
(0, 1) UP    : 0.1644
(0, 1) DOWN  : 0.0989
(0, 1) LEFT  : 0.7034
(0, 1) RIGHT : 0.9727
(0, 2) UP    : 0.9730
(0, 2) DOWN  : 0.7868
(0, 2) LEFT  : 0.3601
(0, 2) RIGHT : 1.0000
(1, 0) UP    : 0.8772
(1, 0) DOWN  : 0.0449
(1, 0) LEFT  : 0.0249
(1, 0) RIGHT : 0.1231
(1, 2) UP    : 0.9717
(1, 2) DOWN  : 0.0294
(1, 2) LEFT  : 0.0286
(1, 2) RIGHT : -0.1081
(1, 3) UP    : 1.0000
(1, 3) DOWN  : 0.2980
(1, 3) LEFT  : 0.9221
(1, 3) RIGHT : -0.1081
(2, 0) UP    : 0.7060
(2, 0) DOWN  : 0.1899
(2, 0) LEFT  : 0.0973
(2, 0) RIGHT : 0.0888
(2, 1) UP    : 0.0137
(2, 1) DOWN  : 0.0042
(2, 1) LEFT  : 0.4360
(2, 1) RIGHT : 0.6891
(2, 2) UP    : 0.9112
(2, 2) DOWN  : 0.0142
(2, 2) LEFT  : 0.0029
(2, 2) RIGHT : 0.0530
(2, 3) UP    : -0.0182
(2, 3) DOWN  : 0.0171
(2, 3) LEFT  : 0.0591
(2, 3) RIGHT : 0.0065

func (*SarsaOffPolicyAgent) GetAction

func (a *SarsaOffPolicyAgent) GetAction(state fmt.Stringer) int

func (*SarsaOffPolicyAgent) Reset

func (a *SarsaOffPolicyAgent) Reset()

func (*SarsaOffPolicyAgent) Update

func (a *SarsaOffPolicyAgent) Update(state fmt.Stringer, action int, reward float64, done bool)

type StateAction

type StateAction struct {
	State  string
	Action int
}

func (StateAction) String

func (s StateAction) String() string
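
The examples above split Q-table keys on ": ", which suggests String renders the pair as "<state>: <action>". A minimal sketch with illustrative values:

package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
)

func main() {
	sa := agent.StateAction{State: "(0, 1)", Action: 2}
	// expected to print something like "(0, 1): 2"
	fmt.Println(sa)
}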

type TemporalDiffAgent

type TemporalDiffAgent struct {
	Gamma          float64
	Alpha          float64
	ActionSize     int
	DefaultActions RandomActions
	Pi             DefaultMap[RandomActions]
	V              map[string]float64
	Source         randv2.Source
}
Example
package main

import (
	"fmt"

	"github.com/itsubaki/neu/agent"
	"github.com/itsubaki/neu/agent/env"
	"github.com/itsubaki/neu/math/rand"
)

func main() {
	e := env.NewGridWorld()
	a := &agent.TemporalDiffAgent{
		Gamma:          0.9,
		Alpha:          0.1,
		ActionSize:     4,
		DefaultActions: agent.RandomActions{0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25},
		Pi:             make(map[string]agent.RandomActions),
		V:              make(map[string]float64),
		Source:         rand.Const(1),
	}

	episodes := 1000
	for i := 0; i < episodes; i++ {
		state := e.Reset()

		for {
			action := a.GetAction(state)
			next, reward, done := e.Step(action)
			a.Eval(state, reward, next, done)

			if done {
				break
			}

			state = next
		}
	}

	for _, k := range agent.SortedKeys(a.V) {
		fmt.Printf("%s: %.2f\n", k, a.V[k])
	}

}
Output:

(0, 0): 0.07
(0, 1): 0.13
(0, 2): 0.05
(1, 0): -0.01
(1, 2): -0.68
(1, 3): -0.60
(2, 0): -0.08
(2, 1): -0.15
(2, 2): -0.43
(2, 3): -0.60

func (*TemporalDiffAgent) Eval

func (a *TemporalDiffAgent) Eval(state fmt.Stringer, reward float64, next fmt.Stringer, done bool)

func (*TemporalDiffAgent) GetAction

func (a *TemporalDiffAgent) GetAction(state fmt.Stringer) int
