package conditionalgeneration

Version: v0.6.0
Warning: this package is not in the latest version of its module.

Published: May 13, 2021 | License: BSD-2-Clause | Imports: 9 | Imported by: 0

Documentation

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type Model

type Model struct {
	nn.BaseModel
	BART       *bart.Model
	Projection *linear.Model
}

Model is a model for conditional generation tasks that embeds a pre-trained BART model.
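The Projection field is a linear layer applied on top of the BART output. As a rough illustration only, the sketch below models such a layer with plain float64 slices instead of spago's ag.Node graph values: a hypothetical linearHead maps a hidden state of size d to one logit per vocabulary entry as W·h + b. The type name and the tiny dimensions are assumptions for illustration, not the package's implementation.

```go
package main

import "fmt"

// linearHead is a hypothetical stand-in for a projection layer:
// it maps a hidden vector of size d to vocabSize logits as W·h + b.
type linearHead struct {
	W [][]float64 // shape: vocabSize x d
	B []float64   // shape: vocabSize
}

// Forward computes W·h + b for a single hidden state h.
func (l linearHead) Forward(h []float64) []float64 {
	out := make([]float64, len(l.W))
	for i, row := range l.W {
		sum := l.B[i]
		for j, w := range row {
			sum += w * h[j]
		}
		out[i] = sum
	}
	return out
}

func main() {
	// Toy head: hidden size 2, vocabulary of 3 tokens.
	head := linearHead{
		W: [][]float64{{1, 0}, {0, 1}, {1, 1}},
		B: []float64{0, 0, -1},
	}
	logits := head.Forward([]float64{0.5, 2})
	fmt.Println(logits) // [0.5 2 1.5]
}
```

Each output entry is a score for one vocabulary token; higher scores mark more likely next tokens once normalized.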

func New

func New(config config.Config, embeddingsPath string) *Model

New returns a new Model for conditional generation.

func (*Model) Close

func (m *Model) Close()

Close closes the BART model's embeddings DB.
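Because Close releases the embeddings database held by the BART model, a caller would normally pair New with a deferred Close. The sketch below shows that pattern with a hypothetical fakeModel type standing in for *Model, since constructing the real model requires a config.Config and an embeddings path; the type, its field, and the run helper are assumptions for illustration.

```go
package main

import (
	"errors"
	"fmt"
)

// fakeModel is a hypothetical stand-in for *Model that only tracks
// whether its (imaginary) embeddings database has been closed.
type fakeModel struct {
	closed bool
}

// Close mimics Model.Close by marking the resource as released.
func (m *fakeModel) Close() { m.closed = true }

// run uses m, defers its Close, and optionally fails midway;
// the deferred call still runs on the error path.
func run(m *fakeModel, fail bool) error {
	defer m.Close()
	if fail {
		return errors.New("generation failed")
	}
	return nil
}

func main() {
	m := &fakeModel{}
	err := run(m, true)
	fmt.Println(err, m.closed) // generation failed true
}
```

With the real package, the same shape would read m := New(cfg, embeddingsPath) followed by defer m.Close().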

func (*Model) Decode

func (m *Model) Decode(encodedInput []ag.Node, inputIDs []int, pastCache generation.Cache) (ag.Node, generation.Cache)

Decode satisfies the pkg/nlp/transformers/generation.Decoder interface.
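Decode takes the encoder output, the decoder input IDs produced so far, and a pastCache, and returns the next-token logits together with an updated cache. The sketch below illustrates how a caller threads that cache from one step to the next. It uses a hypothetical, simplified decoder interface with []float64 in place of ag.Node and a slice-based cache type; neither matches the real generation.Decoder or generation.Cache definitions.

```go
package main

import "fmt"

// cache is a hypothetical stand-in for generation.Cache: here it just
// records one entry per decoded position.
type cache []int

// decoder is a simplified, hypothetical version of the Decode contract.
type decoder interface {
	Decode(encoded [][]float64, inputIDs []int, past cache) ([]float64, cache)
}

// toyDecoder returns fixed logits and appends the newest input ID to the
// cache, mimicking how a real decoder stores per-step state.
type toyDecoder struct{}

func (toyDecoder) Decode(encoded [][]float64, inputIDs []int, past cache) ([]float64, cache) {
	next := append(cache(nil), past...) // copy so the caller's cache is not mutated
	next = append(next, inputIDs[len(inputIDs)-1])
	return []float64{0.1, 0.9}, next
}

// decodeSteps feeds the cache returned by each call into the next call.
func decodeSteps(d decoder, encoded [][]float64, ids []int) cache {
	var past cache
	for i := 1; i <= len(ids); i++ {
		_, past = d.Decode(encoded, ids[:i], past)
	}
	return past
}

func main() {
	fmt.Println(decodeSteps(toyDecoder{}, nil, []int{2, 7, 5})) // [2 7 5]
}
```

Returning a new cache instead of mutating the argument lets each step's state be kept or discarded independently, which matters when several candidate sequences share a prefix.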

func (*Model) Encode

func (m *Model) Encode(InputIDs []int) []ag.Node

Encode satisfies the pkg/nlp/transformers/generation.Encoder interface.
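A common Go idiom for guaranteeing that a type satisfies an interface, as Encode and Decode are documented to do, is a compile-time assertion of the form var _ Iface = (*T)(nil). The sketch below applies that idiom to a hypothetical encoder interface and toyModel type, and shows the shape of the contract: one output vector per input token ID. The lookup table is an assumption for illustration, not how BART encodes.

```go
package main

import "fmt"

// encoder is a hypothetical, simplified version of the Encode contract,
// using []float64 vectors in place of ag.Node.
type encoder interface {
	Encode(inputIDs []int) [][]float64
}

// toyModel encodes each token ID by looking it up in a small table;
// a real BART encoder would instead run a stack of transformer layers.
type toyModel struct {
	table [][]float64
}

// Compile-time check: the build fails if *toyModel stops satisfying encoder.
var _ encoder = (*toyModel)(nil)

func (m *toyModel) Encode(inputIDs []int) [][]float64 {
	out := make([][]float64, len(inputIDs))
	for i, id := range inputIDs {
		out[i] = m.table[id]
	}
	return out
}

func main() {
	m := &toyModel{table: [][]float64{{0, 0}, {1, 0}, {0, 1}}}
	fmt.Println(len(m.Encode([]int{2, 1, 2}))) // 3
}
```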

func (*Model) Generate

func (m *Model) Generate(inputIDs []int) []int

Generate generates sequences using beam-search decoding.

func (*Model) PredictNext

func (m *Model) PredictNext(
	encoderOutLastHiddenState []ag.Node,
	decoderInputIDs []int,
	pastKeyValues decoder.KeysValuesPairs,
) (ag.Node, decoder.KeysValuesPairs)

PredictNext returns the logits for the next possible tokens.
