package rae

v0.6.0 Latest
Warning

This package is not in the latest version of its module.

Published: May 13, 2021 License: BSD-2-Clause Imports: 9 Imported by: 0

Documentation

Overview

Package rae provides an implementation of the recursive auto-encoder (RAE) strategy described in "Towards Lossless Encoding of Sentences" by Prato et al., 2019. Unlike the method described in that paper, I opted to use the sinusoidal positional encoding introduced by Vaswani et al. (2017) for the step encoding.

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type Decoder

type Decoder struct {
	nn.BaseModel
	DecodingFNN1 nn.StandardModel // decoding part 1
	DecodingFFN2 nn.StandardModel // decoding part 2
	DescalingFFN nn.StandardModel
	StepEncoder  *pe.SinusoidalPositionalEncoder
	State        State `spago:"scope:processor"`
}

Decoder contains the serializable parameters.

func NewDefaultDecoder

func NewDefaultDecoder(embeddingSize, outputSize, maxSequenceLength int) *Decoder

NewDefaultDecoder returns a new RAE Decoder.

func (*Decoder) Forward added in v0.2.0

func (p *Decoder) Forward(x ag.Node) []ag.Node

Forward performs the forward step on the input node x and returns the decoded sequence of nodes.

func (*Decoder) SetSequenceLength added in v0.2.0

func (p *Decoder) SetSequenceLength(length int)

SetSequenceLength sets the length of the expected sequence.
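Decoder.Forward expands one node into a slice of nodes whose length is set with SetSequenceLength. The sketch below is a toy that assumes decoding mirrors a pairwise encoder. Each recursion turns k vectors into k+1 by splitting every vector into a left and a right child and averaging the two children that land on the same position, so n outputs take n-1 recursions. The names split and decode are hypothetical. split replaces the learned DecodingFNN1/DecodingFFN2 pair with a fixed offset, DescalingFFN is omitted, and this is not necessarily how spago or Prato et al. implement the decoder.

```go
package main

import "fmt"

// split is a hypothetical stand-in for the learned decoding networks: it
// turns one vector into a left and a right child. This toy subtracts and
// adds the step encoding.
func split(v, step []float64) (left, right []float64) {
	left = make([]float64, len(v))
	right = make([]float64, len(v))
	for i := range v {
		left[i] = v[i] - step[i]
		right[i] = v[i] + step[i]
	}
	return left, right
}

// decode expands z into seqLen vectors (seqLen <= 1 returns z alone). Each
// recursion turns k vectors into k+1: vector i contributes its left child to
// position i and its right child to position i+1, and positions that receive
// two children take their average.
func decode(z []float64, seqLen int, stepEnc func(step, dim int) []float64) [][]float64 {
	xs := [][]float64{z}
	for r := 0; len(xs) < seqLen; r++ {
		step := stepEnc(r, len(z))
		next := make([][]float64, len(xs)+1)
		for i, v := range xs {
			l, rt := split(v, step)
			if i == 0 {
				next[0] = l
			} else {
				// next[i] already holds the right child of xs[i-1].
				for j := range l {
					next[i][j] = (next[i][j] + l[j]) / 2
				}
			}
			next[i+1] = rt
		}
		xs = next
	}
	return xs
}

func main() {
	// A constant 1-dimensional step encoding keeps the arithmetic visible.
	constStep := func(step, dim int) []float64 { return []float64{1} }
	out := decode([]float64{0}, 3, constStep)
	fmt.Println(out) // prints [[-2] [0] [2]]
}
```

With the constant step used here, the single value 0 expands into three vectors over two recursions, which is the role SetSequenceLength plays for the real decoder: it fixes how many expansion steps to run.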

type Encoder

type Encoder struct {
	nn.BaseModel
	ScalingFFN  nn.StandardModel
	EncodingFFN nn.StandardModel
	StepEncoder *pe.SinusoidalPositionalEncoder
	Recursions  int `spago:"scope:processor"`
}

Encoder contains the serializable parameters.

func NewDefaultEncoder

func NewDefaultEncoder(inputSize, embeddingSize, maxSequenceLength int) *Encoder

NewDefaultEncoder returns a new RAE Encoder.

func (*Encoder) Forward added in v0.2.0

func (p *Encoder) Forward(xs ...ag.Node) []ag.Node

Forward performs the forward step for each input node and returns the result.

func (*Encoder) GetRecursions added in v0.2.0

func (p *Encoder) GetRecursions() int

GetRecursions returns the number of recursions.
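Encoder.Forward consumes a variable-length sequence, and GetRecursions reports how many recursions were performed. The sketch below is a toy version of that recursion, assuming the pairwise scheme in which every pair of adjacent vectors is merged into one, so a sequence of n vectors shrinks by one per step and needs n-1 recursions. The names merge and encode are hypothetical. merge stands in for the learned EncodingFFN with a plain average, ScalingFFN is omitted, and none of this is spago's actual code.

```go
package main

import "fmt"

// merge is a hypothetical stand-in for the learned EncodingFFN: it combines
// two adjacent vectors and the current step encoding into one vector. This
// toy uses the element-wise mean of the pair plus the step term.
func merge(a, b, step []float64) []float64 {
	out := make([]float64, len(a))
	for i := range a {
		out[i] = (a[i]+b[i])/2 + step[i]
	}
	return out
}

// encode merges every pair of adjacent vectors per recursion, so a sequence
// of n vectors shrinks by one each step until a single vector remains. It
// returns that vector and the number of recursions performed (n-1).
// stepEnc supplies the step encoding for a given recursion index and size.
func encode(xs [][]float64, stepEnc func(step, dim int) []float64) ([]float64, int) {
	if len(xs) == 0 {
		return nil, 0
	}
	recursions := 0
	for len(xs) > 1 {
		step := stepEnc(recursions, len(xs[0]))
		next := make([][]float64, len(xs)-1)
		for i := range next {
			next[i] = merge(xs[i], xs[i+1], step)
		}
		xs = next
		recursions++
	}
	return xs[0], recursions
}

func main() {
	// A zero step encoding keeps the arithmetic easy to follow.
	zeroStep := func(step, dim int) []float64 { return make([]float64, dim) }
	seq := [][]float64{{1, 0}, {0, 1}, {1, 1}, {0, 0}}
	vec, n := encode(seq, zeroStep)
	fmt.Println(vec, n) // prints [0.5 0.75] 3
}
```

Four input vectors yield one vector after three recursions. Passing a sinusoidal function as stepEnc would mirror the package's choice of step encoding.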

type State added in v0.2.0

type State struct {
	SequenceLength int
	MaxRecursions  int
	Recursions     int
}

State contains information used during the decoding recursion.
