multiheadattention

package
v1.1.0 Latest
Warning

This package is not in the latest version of its module.

Published: Oct 30, 2023 License: BSD-2-Clause Imports: 11 Imported by: 4

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type Cache

type Cache []selfattention.Cache

Cache contains the self-attention cache for each head.

func (Cache) At

func (r Cache) At(i int) selfattention.Cache
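Conceptually, Cache is just an indexable slice holding one self-attention cache per head, and At is the per-head accessor. A minimal stdlib sketch of the same pattern (the headCache layout is hypothetical and stands in for selfattention.Cache; it is not spago's actual type):

```go
package main

import "fmt"

// headCache stands in for selfattention.Cache: the keys and values
// a single head has accumulated so far (hypothetical layout).
type headCache struct {
	keys, values [][]float64
}

// cache mirrors the package's Cache type: one entry per head.
type cache []headCache

// at returns the cache for head i, as Cache.At does.
func (c cache) at(i int) headCache {
	return c[i]
}

func main() {
	c := make(cache, 4) // one cache slot per head
	c[2].keys = append(c[2].keys, []float64{0.1, 0.2})
	fmt.Println(len(c.at(2).keys)) // → 1
}
```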

type Model

type Model struct {
	nn.Module
	Heads       []*selfattention.Model
	OutputMerge *linear.Model
}

Model contains the serializable parameters.

func New

func New[T float.DType](size, numOfHeads int, useCausalMask, isCrossAttention bool) *Model

New returns a new model with parameters initialized to zeros.
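The size argument is the model dimension, which is conventionally split evenly across the heads. A stdlib sketch of the shapes a constructor like New would set up (headSpec and newHeads are illustrative names, not spago's API):

```go
package main

import "fmt"

// headSpec records the input and output dimension of one head's
// projection (illustrative, not spago's internals).
type headSpec struct {
	in, out int
}

// newHeads sketches what a constructor like New sets up: numOfHeads
// attention heads, each projecting the model dimension size down to
// size/numOfHeads, before an output merge layer maps the concatenated
// heads back to size.
func newHeads(size, numOfHeads int) []headSpec {
	headDim := size / numOfHeads // conventional per-head dimension
	heads := make([]headSpec, numOfHeads)
	for i := range heads {
		heads[i] = headSpec{in: size, out: headDim}
	}
	return heads
}

func main() {
	heads := newHeads(512, 8)
	fmt.Println(len(heads), heads[0].out) // → 8 64
}
```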

func (*Model) Forward

func (m *Model) Forward(cache Cache, q, x []mat.Tensor) ([]mat.Tensor, [][]mat.Tensor, Cache)

Forward performs the forward step for each input node. It returns the merged attention outputs, the per-head attention weights, and the updated Cache.

func (*Model) Init

func (m *Model) Init(rng *rand.LockedRand)

Init initializes the self-attention heads and the merge layer with a Xavier (Glorot) uniform distribution.
