ad

package
v1.0.8
Warning

This package is not in the latest version of its module.

Published: Jun 20, 2022 License: MIT Imports: 21 Imported by: 6

Documentation

Overview

Package ad implements automatic differentiation of a model. A model is defined in its own package and must implement the interface model.Model. In the model's source code:

  1. Methods on the type implementing model.Model returning a single float64 or nothing are differentiated.
  2. Within these methods, the following are differentiated: a) assignments to float64 (including parallel assignments if all values are of type float64); b) returns of float64; c) standalone calls to methods on the type implementing model.Model (presumably made for their side effects on the model).
  3. The imported package name "ad" is reserved.
  4. Non-dummy identifiers starting with the prefix for generated identifiers ("_" by default) are reserved.
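The following is a sketch of model source written to follow the rules above. It assumes that model.Model consists of a single method, Observe(x []float64) float64, and declares that interface locally so the program compiles without the real package. The Gauss type, its fields, and its helper methods are hypothetical.

```go
package main

import (
	"fmt"
	"math"
)

// Model mirrors, by assumption, the model.Model interface:
// Observe returns the log-likelihood of the parameters x.
type Model interface {
	Observe(x []float64) float64
}

// Gauss is a hypothetical model: observations from a normal distribution
// with unknown mean x[0] and log standard deviation x[1].
type Gauss struct {
	Data []float64
	ll   float64 // accumulated log-likelihood, updated as a side effect
}

// logpdf returns a single float64, so under rule 1 it is a candidate
// for differentiation.
func (m *Gauss) logpdf(y, mu, logSigma float64) float64 {
	sigma := math.Exp(logSigma)
	z := (y - mu) / sigma
	return -0.5*z*z - logSigma - 0.5*math.Log(2*math.Pi)
}

// accumulate returns nothing and is called as a standalone statement
// for its side effect on the model (rule 2c).
func (m *Gauss) accumulate(v float64) {
	m.ll += v
}

// Observe sums the log-densities of all observations.
func (m *Gauss) Observe(x []float64) float64 {
	m.ll = 0
	for _, y := range m.Data {
		m.accumulate(m.logpdf(y, x[0], x[1]))
	}
	return m.ll
}

func main() {
	var m Model = &Gauss{Data: []float64{-1, 0, 1}}
	fmt.Println(m.Observe([]float64{0, 0}))
}
```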

Functions are considered elementals (and must have a registered derivative) if their signature is of the form

func (float64, float64*) float64

that is, one or more non-variadic float64 arguments and a float64 return value. For example, the function

func (float64, float64, float64) float64

is considered elemental, while functions

func (...float64) float64
func ([]float64) float64
func (int, float64) float64

are not.
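The signature rule can be illustrated with a runtime check using reflect. This is only an illustration of the rule: the package works on model source code, so it presumably does not decide this with reflect, and the helper isElemental is hypothetical.

```go
package main

import (
	"fmt"
	"math"
	"reflect"
)

// isElemental reports whether f has the shape func(float64, float64*) float64:
// at least one parameter, every parameter of type float64, not variadic,
// and exactly one result of type float64.
func isElemental(f interface{}) bool {
	t := reflect.TypeOf(f)
	if t == nil || t.Kind() != reflect.Func || t.IsVariadic() {
		return false
	}
	float64Type := reflect.TypeOf(float64(0))
	if t.NumIn() == 0 || t.NumOut() != 1 || t.Out(0) != float64Type {
		return false
	}
	for i := 0; i < t.NumIn(); i++ {
		if t.In(i) != float64Type {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(isElemental(math.Exp))                                    // true
	fmt.Println(isElemental(func(x, y, z float64) float64 { return x }))  // true
	fmt.Println(isElemental(func(x ...float64) float64 { return 0 }))     // false
	fmt.Println(isElemental(func(x []float64) float64 { return 0 }))      // false
	fmt.Println(isElemental(func(n int, x float64) float64 { return x })) // false
}
```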

Derivatives do not propagate through a function that is neither an elemental nor a call to a model method. If no derivative is registered for an elemental, calling the elemental in Observe causes a run-time error.

The differentiated model is written to the subpackage "ad" of the model's package; the generated package has the same name as the original package.

Index

Constants

const (
	OpNeg = iota
	OpAdd
	OpSub
	OpMul
	OpDiv
)

Arithmetic operation codes.
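A sketch of what the operation codes denote: a forward evaluator and the local partial derivatives a backward pass would need. The constants are copied from the block above; apply and partials are hypothetical helpers, and treating OpNeg as the only unary operation is an assumption based on the operation names.

```go
package main

import "fmt"

// Operation codes, declared as in the package's constant block.
const (
	OpNeg = iota
	OpAdd
	OpSub
	OpMul
	OpDiv
)

// apply evaluates an operation on its operands; OpNeg is unary,
// the others are binary.
func apply(op int, x ...float64) float64 {
	switch op {
	case OpNeg:
		return -x[0]
	case OpAdd:
		return x[0] + x[1]
	case OpSub:
		return x[0] - x[1]
	case OpMul:
		return x[0] * x[1]
	case OpDiv:
		return x[0] / x[1]
	}
	panic(fmt.Sprintf("unknown op %d", op))
}

// partials returns the partial derivatives of the operation with
// respect to each operand, as used on the backward pass.
func partials(op int, x ...float64) []float64 {
	switch op {
	case OpNeg:
		return []float64{-1}
	case OpAdd:
		return []float64{1, 1}
	case OpSub:
		return []float64{1, -1}
	case OpMul:
		return []float64{x[1], x[0]}
	case OpDiv:
		return []float64{1 / x[1], -x[0] / (x[1] * x[1])}
	}
	panic(fmt.Sprintf("unknown op %d", op))
}

func main() {
	fmt.Println(OpNeg, OpAdd, OpSub, OpMul, OpDiv) // 0 1 2 3 4
	fmt.Println(apply(OpDiv, 6, 3), partials(OpDiv, 6, 3))
}
```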

Variables

var (
	// When Fold is true, folded values are substituted instead
	// of constant expressions.
	Fold = true
)
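Constant folding itself can be shown with the standard go/types package, which evaluates a constant expression to its value. This illustrates folding in general; the documentation does not say how ad computes folded values, and the helper fold is hypothetical.

```go
package main

import (
	"fmt"
	"go/constant"
	"go/token"
	"go/types"
)

// fold evaluates expr as a Go constant expression and returns its float64
// value, or ok == false if expr is not a constant expression.
func fold(expr string) (v float64, ok bool) {
	// With a nil package, only the universe scope is visible.
	tv, err := types.Eval(token.NewFileSet(), nil, token.NoPos, expr)
	if err != nil || tv.Value == nil {
		return 0, false
	}
	v, _ = constant.Float64Val(tv.Value)
	return v, true
}

func main() {
	fmt.Println(fold("2 * 3.5")) // 7 true
	fmt.Println(fold("x + 1"))   // 0 false (x is undefined)
}
```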

Functions

func Arithmetic

func Arithmetic(op int, px ...*float64) *float64

Arithmetic encodes an arithmetic operation and returns the location of the result.
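The idea behind recording arithmetic on a tape and replaying it backward can be sketched as follows. This is not the package's implementation: it uses integer indices into a value slice instead of *float64 locations, supports only addition and multiplication, and the names tape, setup, value, arithmetic, and gradient are hypothetical stand-ins for Setup, Value, Arithmetic, and Gradient.

```go
package main

import "fmt"

const (
	OpAdd = iota
	OpMul
)

// record is one tape entry: an operation, the location of its result,
// and the locations of its operands.
type record struct {
	op   int
	res  int
	args []int
}

type tape struct {
	values  []float64 // value memory; parameters come first
	records []record
	nparams int
}

// setup places the parameters at the start of value memory.
func (t *tape) setup(x []float64) {
	t.values = append(t.values[:0], x...)
	t.records = t.records[:0]
	t.nparams = len(x)
}

// value stores v and returns its location.
func (t *tape) value(v float64) int {
	t.values = append(t.values, v)
	return len(t.values) - 1
}

// arithmetic computes op on the operands at a and b, records it,
// and returns the location of the result.
func (t *tape) arithmetic(op int, a, b int) int {
	var v float64
	switch op {
	case OpAdd:
		v = t.values[a] + t.values[b]
	case OpMul:
		v = t.values[a] * t.values[b]
	}
	r := t.value(v)
	t.records = append(t.records, record{op, r, []int{a, b}})
	return r
}

// gradient replays the records in reverse, accumulating adjoints,
// and returns the partial derivatives with respect to the parameters.
func (t *tape) gradient(out int) []float64 {
	adj := make([]float64, len(t.values))
	adj[out] = 1
	for i := len(t.records) - 1; i >= 0; i-- {
		r := t.records[i]
		a, b := r.args[0], r.args[1]
		switch r.op {
		case OpAdd:
			adj[a] += adj[r.res]
			adj[b] += adj[r.res]
		case OpMul:
			adj[a] += adj[r.res] * t.values[b]
			adj[b] += adj[r.res] * t.values[a]
		}
	}
	return adj[:t.nparams]
}

func main() {
	// f(x, y) = x*y + 3*x, so df/dx = y + 3 and df/dy = x.
	var t tape
	t.setup([]float64{2, 5})
	xy := t.arithmetic(OpMul, 0, 1)
	out := t.arithmetic(OpAdd, xy, t.arithmetic(OpMul, t.value(3), 0))
	fmt.Println(t.values[out], t.gradient(out)) // 16 [8 2]
}
```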

func Assignment

func Assignment(p *float64, px *float64)

Assignment encodes a single-value assignment.

func Call

func Call(
	f func(_vararg []float64),
	narg int,
	px ...*float64,
) *float64

Call wraps a call to a differentiated subfunction. narg is the number of non-variadic arguments.

func Called

func Called() bool

Called returns true iff the last record on the tape is a Call record. A Call record is added before a call to a differentiated method from another differentiated method.

func Deriv

func Deriv(mpath string, prefix string) (err error)

Deriv differentiates a model. The original model is in the package located at mpath, and the differentiated model is written to mpath/ad. Names of variables generated by the autodiff code start with prefix.
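The two conventions Deriv depends on, the output directory and the reserved identifier prefix from the overview, can be expressed as small checks. Both helpers are hypothetical, and treating the blank identifier "_" as the "dummy" identifier is an assumption.

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// outputDir returns the directory where the differentiated model
// is written: the "ad" subdirectory of mpath.
func outputDir(mpath string) string {
	return filepath.Join(mpath, "ad")
}

// reserved reports whether an identifier in the model source starts with
// the generated-name prefix. The blank identifier "_" is exempt.
func reserved(name, prefix string) bool {
	return name != "_" && strings.HasPrefix(name, prefix)
}

func main() {
	fmt.Println(outputDir("models/gauss")) // models/gauss/ad on Unix-like systems
	for _, name := range []string{"_", "_tmp", "mu", "x_"} {
		fmt.Println(name, reserved(name, "_"))
	}
}
```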

func DropAllTapes added in v0.4.0

func DropAllTapes()

DropAllTapes discards all tapes. It is intended for use with third-party inference algorithms that run in multiple goroutines.

func DropTape added in v0.4.0

func DropTape()

DropTape discards the goroutine's tape.

func Elemental

func Elemental(f interface{}, px ...*float64) *float64

Elemental encodes a call to the elemental f. To call the gradient without allocation on the backward pass, argument values are copied to the tape memory. Elemental returns the location of the result.

func Enter

func Enter(px ...*float64)

Enter copies the actual parameters to the formal parameters.

func Gradient

func Gradient() []float64

Gradient performs the backward pass on the tape and returns the gradient. It should be called immediately after the call to an automatically differentiated function, and can be called only once per call to an automatically differentiated function.

func IsMTSafe added in v0.4.0

func IsMTSafe() bool

IsMTSafe returns true if multithreading support is turned on, and multiple differentiations may run concurrently.

func MTSafeOn added in v0.4.0

func MTSafeOn() bool

MTSafeOn makes differentiation thread safe at the expense of a loss in performance. There is no corresponding MTSafeOff, as once things are safe they cannot safely become unsafe again.

MTSafeOn enables multithreading support only on some versions and architectures. The caller should check the return value (true on success) or call IsMTSafe if the code depends on the tape being thread-safe.
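A sketch of acting on the return value: differentiate concurrently only when thread safety was actually enabled, and fall back to sequential work otherwise. To keep the program self-contained, mtSafeOn and gradAt are hypothetical stand-ins for MTSafeOn and for one forward pass followed by Gradient.

```go
package main

import (
	"fmt"
	"sync"
)

// mtSafeOn is a hypothetical stand-in for ad.MTSafeOn.
var mtSafeOn = func() bool { return false }

// gradAt is a hypothetical stand-in for one differentiation;
// here f(x) = x*x, so f'(x) = 2x.
func gradAt(x float64) float64 { return 2 * x }

// gradients differentiates at every point, using goroutines only when
// the tape is known to be thread-safe.
func gradients(xs []float64) []float64 {
	out := make([]float64, len(xs))
	if !mtSafeOn() {
		for i, x := range xs {
			out[i] = gradAt(x)
		}
		return out
	}
	var wg sync.WaitGroup
	for i, x := range xs {
		wg.Add(1)
		go func(i int, x float64) {
			defer wg.Done()
			out[i] = gradAt(x) // each goroutine writes its own slot
		}(i, x)
	}
	wg.Wait()
	return out
}

func main() {
	fmt.Println(gradients([]float64{1, 2, 3})) // [2 4 6]
}
```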

func ParallelAssignment

func ParallelAssignment(ppx ...*float64)

ParallelAssignment encodes a parallel assignment.
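A parallel assignment such as x, y = y, x cannot be encoded as a sequence of single assignments, because the first write would clobber a value the second one reads. A sketch of the required semantics: read every source before writing any destination. The helper parallelAssign and its split into destination and source slices are hypothetical; how ParallelAssignment lays out ppx is not described here.

```go
package main

import "fmt"

// parallelAssign sets *dst[i] = *src[i] for all i "at once": all sources
// are copied to temporaries before any destination is written.
func parallelAssign(dst, src []*float64) {
	tmp := make([]float64, len(src))
	for i, p := range src {
		tmp[i] = *p
	}
	for i, p := range dst {
		*p = tmp[i]
	}
}

func main() {
	x, y := 1.0, 2.0
	parallelAssign([]*float64{&x, &y}, []*float64{&y, &x}) // x, y = y, x
	fmt.Println(x, y) // 2 1
}
```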

func Pop

func Pop()

Pop deallocates the current tape fragment. Gradient calls Pop; when the gradient is not needed, Pop can be called directly to skip the gradient computation.
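A sketch of a tape organized into fragments, assuming that fragments form a stack so that popping the current one restores the tape to its state before the matching setup. The fragTape type and its methods are hypothetical, and records are reduced to plain numbers.

```go
package main

import "fmt"

// fragTape groups records into fragments, tracked as a stack of start offsets.
type fragTape struct {
	records []float64 // stand-in for tape records
	starts  []int     // start offset of each open fragment
}

// setup opens a new fragment for a forward pass.
func (t *fragTape) setup() { t.starts = append(t.starts, len(t.records)) }

// record appends an entry to the current fragment.
func (t *fragTape) record(v float64) { t.records = append(t.records, v) }

// pop discards the current fragment, whether or not a gradient was computed.
func (t *fragTape) pop() {
	n := len(t.starts) - 1
	t.records = t.records[:t.starts[n]]
	t.starts = t.starts[:n]
}

func main() {
	var t fragTape
	t.setup()
	t.record(1)
	t.setup() // a second, nested fragment
	t.record(2)
	t.record(3)
	t.pop()                                    // drop it without computing a gradient
	fmt.Println(len(t.records), len(t.starts)) // 1 1
}
```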

func RegisterElemental

func RegisterElemental(f interface{}, g ElementalGradientFunc)

RegisterElemental registers the gradient for an elemental function.
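One way a registry of elemental gradients can work is sketched below: gradients are keyed by the function's code pointer, obtained with reflect. The helpers registerElemental and elementalGradient are hypothetical analogues of RegisterElemental and ElementalGradient; the package's actual keying is not documented here. The reflect documentation notes that a code pointer does not always identify a single function uniquely, so this is a simplification.

```go
package main

import (
	"fmt"
	"math"
	"reflect"
)

// ElementalGradientFunc mirrors the documented type of the same name.
type ElementalGradientFunc func(value float64, params ...float64) []float64

// registry maps a function's code pointer to its gradient.
var registry = map[uintptr]ElementalGradientFunc{}

// registerElemental records g as the gradient of f.
func registerElemental(f interface{}, g ElementalGradientFunc) {
	registry[reflect.ValueOf(f).Pointer()] = g
}

// elementalGradient looks up the gradient of f; ok is false if none is registered.
func elementalGradient(f interface{}) (ElementalGradientFunc, bool) {
	g, ok := registry[reflect.ValueOf(f).Pointer()]
	return g, ok
}

func main() {
	// d/dx sin(x) = cos(x); the function value is not needed.
	registerElemental(math.Sin, func(_ float64, p ...float64) []float64 {
		return []float64{math.Cos(p[0])}
	})
	if g, ok := elementalGradient(math.Sin); ok {
		fmt.Println(g(math.Sin(0), 0)) // [1]
	}
	_, ok := elementalGradient(math.Cos)
	fmt.Println(ok) // false
}
```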

func Return

func Return(px *float64) float64

Return returns the result of the differentiated function.

func Setup

func Setup(x []float64)

Setup sets up the tape for the forward pass.

func Value

func Value(v float64) *float64

Value adds value v to the memory and returns the location of the value.

func Vlemental added in v0.7.0

func Vlemental(f func([]float64) float64, x []float64) *float64

Vlemental encodes a call to the vector elemental f. To call the gradient without allocation on the backward pass, argument values are copied to the tape memory. Vlemental returns the location of the result.
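As an example of a vector elemental, log-sum-exp has the accepted signature func([]float64) float64, and its gradient is the softmax of the arguments. The documentation here does not say how Vlemental obtains the gradient of f, so the gradient helper logSumExpGrad and its signature are hypothetical; the sketch only shows the mathematics.

```go
package main

import (
	"fmt"
	"math"
)

// logSumExp computes log(sum(exp(x))); subtracting the maximum
// keeps exp from overflowing.
func logSumExp(x []float64) float64 {
	m := math.Inf(-1)
	for _, v := range x {
		m = math.Max(m, v)
	}
	s := 0.0
	for _, v := range x {
		s += math.Exp(v - m)
	}
	return m + math.Log(s)
}

// logSumExpGrad returns softmax(x), the gradient of logSumExp,
// reusing the function value as the log-normalizer.
func logSumExpGrad(value float64, x []float64) []float64 {
	g := make([]float64, len(x))
	for i, v := range x {
		g[i] = math.Exp(v - value)
	}
	return g
}

func main() {
	x := []float64{0, 0}
	v := logSumExp(x)
	fmt.Println(v, logSumExpGrad(v, x)) // about 0.6931 and [0.5 0.5]
}
```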

Types

type ElementalGradientFunc

type ElementalGradientFunc func(value float64, params ...float64) []float64

ElementalGradientFunc accepts the function value and the parameters and returns a vector of partial derivatives with respect to the parameters. Depending on the function, either the value or the parameters may be ignored in the computation of the gradient.
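The following gradients show why either argument may be unused: the derivative of exp is the function value itself, the derivative of log depends only on the parameter, and pow needs both. The type is repeated so the example compiles on its own; the variables expGrad, logGrad, and powGrad are hypothetical, and registering them with RegisterElemental is not shown.

```go
package main

import (
	"fmt"
	"math"
)

// ElementalGradientFunc mirrors the documented type.
type ElementalGradientFunc func(value float64, params ...float64) []float64

// expGrad: d/dx exp(x) = exp(x), which is the value; the parameter is ignored.
var expGrad ElementalGradientFunc = func(value float64, _ ...float64) []float64 {
	return []float64{value}
}

// logGrad: d/dx log(x) = 1/x; the value is ignored.
var logGrad ElementalGradientFunc = func(_ float64, p ...float64) []float64 {
	return []float64{1 / p[0]}
}

// powGrad: for value = x^y, d/dx = y*x^(y-1) and d/dy = x^y*ln(x), for x > 0.
var powGrad ElementalGradientFunc = func(value float64, p ...float64) []float64 {
	x, y := p[0], p[1]
	return []float64{y * math.Pow(x, y-1), value * math.Log(x)}
}

func main() {
	fmt.Println(expGrad(math.Exp(1), 1))
	fmt.Println(logGrad(math.Log(4), 4))
	fmt.Println(powGrad(math.Pow(2, 3), 2, 3))
}
```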

func ElementalGradient

func ElementalGradient(f interface{}) (ElementalGradientFunc, bool)

ElementalGradient returns the gradient for a function. If the function is not registered as elemental, the second returned value is false. Intended to be called from the backward pass of gradient computation. Exported for testing.
