Documentation ¶
Overview ¶
Package generation implements a generation search algorithm for conditional generation.
Index ¶
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
This section is empty.
Types ¶
type Decoder ¶
type Decoder interface {
	// Decode returns the log probabilities for each possible next element of a sequence.
	Decode(encodedInput []ag.Node, decodingInputIDs []int, pastCache Cache) (ag.Node, Cache)
}
Decoder is a model able to decode.
type Encoder ¶
type Encoder interface {
	// Encode transforms a sequence of input IDs into a sequence of nodes.
	Encode(InputIDs []int) []ag.Node
}
Encoder is a model able to encode each input of a sequence into a vector representation.
type EncoderDecoder ¶
EncoderDecoder is a model able to perform encoder-decoder conditional generation.
type Generator ¶
type Generator struct {
// contains filtered or unexported fields
}
Generator is an implementation of a generation search algorithm for conditional generation.
func NewGenerator ¶
func NewGenerator(config GeneratorConfig, model EncoderDecoder) *Generator
NewGenerator creates a new Generator object.
type GeneratorConfig ¶
type GeneratorConfig struct {
	// NumBeams is the number of beams for generation search.
	NumBeams int
	// MinLength is the minimum length of the sequence to be generated.
	MinLength int
	// MaxLength is the maximum length of the sequence to be generated.
	MaxLength int
	// IsEncoderDecoder reports whether the model is used as an encoder/decoder.
	IsEncoderDecoder bool
	// BOSTokenID is the ID of the Beginning-Of-Sequence token.
	BOSTokenID int
	// EOSTokenID is the ID of the End-Of-Sequence token.
	EOSTokenID int
	// PadTokenID is the ID of the padding token.
	PadTokenID int
	// DecoderStartTokenID is the ID of the start token for the decoder of an
	// encoder-decoder model.
	DecoderStartTokenID int
	// LengthPenalty is the exponential penalty to the length.
	// 1.0 means no penalty. Set it to a value < 1.0 in order to encourage the
	// model to generate shorter sequences, or to a value > 1.0 in order to
	// encourage the model to produce longer sequences.
	LengthPenalty mat.Float
	// EarlyStopping reports whether to stop the generation search when at
	// least NumBeams sentences per batch are finished.
	EarlyStopping bool
	// BadWordsIDs is a list of token IDs that are not allowed to be generated.
	BadWordsIDs [][]int
	// MaxConcurrentComputations is the maximum number of concurrent computations
	// handled by the generation search algorithm.
	MaxConcurrentComputations int
	// IncrementalForward indicates the graph usage mode.
	IncrementalForward bool
}
GeneratorConfig provides configuration options for the generation search algorithm.
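As a quick illustration, the sketch below fills in a GeneratorConfig and passes it to NewGenerator. It is written as if it lived inside this package, to avoid assuming an import path; the field values, the helper name, and the model argument are placeholders, while the field names, their types, and the NewGenerator signature come from the declarations above.

// newExampleGenerator is a hypothetical helper, not part of this package.
// The configuration values are illustrative placeholders only.
func newExampleGenerator(model EncoderDecoder) *Generator {
	config := GeneratorConfig{
		NumBeams:                  4,
		MinLength:                 0,
		MaxLength:                 128,
		IsEncoderDecoder:          true,
		BOSTokenID:                0,
		EOSTokenID:                2,
		PadTokenID:                1,
		DecoderStartTokenID:       2,
		LengthPenalty:             1.0, // 1.0 means no length penalty
		EarlyStopping:             false,
		BadWordsIDs:               nil, // no forbidden tokens
		MaxConcurrentComputations: 2,
		IncrementalForward:        true,
	}
	return NewGenerator(config, model)
}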
type Hypotheses ¶
type Hypotheses struct {
// contains filtered or unexported fields
}
Hypotheses provides hypotheses data for a generation Scorer.
func NewHypotheses ¶
func NewHypotheses(config GeneratorConfig) *Hypotheses
NewHypotheses returns a new Hypotheses.
func (*Hypotheses) Add ¶
func (h *Hypotheses) Add(hypVector []int, sumLogProbs mat.Float)
Add adds a new hypothesis to the list.
func (*Hypotheses) Beams ¶
func (h *Hypotheses) Beams() []Hypothesis
Beams returns the hypothesis beams.
func (*Hypotheses) IsDone ¶
func (h *Hypotheses) IsDone(bestSumLogProbs mat.Float, curLen int) bool
IsDone reports whether there are enough hypotheses and none of the hypotheses being generated can become better than the worst one in the heap.
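The condition IsDone checks can be pictured with the standalone sketch below. It is an assumption about the usual early-termination rule for beam search, not code from this package: it compares the best length-penalized score still reachable by an open beam (the sum of log probabilities divided by curLen raised to LengthPenalty, as described under GeneratorConfig) against the worst finished hypothesis. All identifiers are hypothetical.

package main

import (
	"fmt"
	"math"
)

// isDoneSketch mirrors the usual beam-search termination test; it does not
// reproduce this package's unexported state.
func isDoneSketch(numHyps, numBeams int, earlyStopping bool,
	worstScore, bestSumLogProbs, lengthPenalty float64, curLen int) bool {
	if numHyps < numBeams {
		return false // not enough finished hypotheses yet
	}
	if earlyStopping {
		return true // stop as soon as NumBeams hypotheses are finished
	}
	// Best length-penalized score an open beam could still reach.
	best := bestSumLogProbs / math.Pow(float64(curLen), lengthPenalty)
	return best <= worstScore
}

func main() {
	// Two finished hypotheses, no early stopping, length penalty 1.0:
	// -12.5/10 = -1.25 cannot beat the worst finished score of -0.9.
	fmt.Println(isDoneSketch(2, 2, false, -0.9, -12.5, 1.0, 10)) // true
}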
func (*Hypotheses) Len ¶
func (h *Hypotheses) Len() int
Len returns the number of hypotheses in the list.
type Hypothesis ¶
Hypothesis represents a single generation hypothesis, which is a sequence of token IDs paired with a score.
type ScoredToken ¶
ScoredToken associates a score with a token identified by its (generation-index, token-index) position.
type ScoredTokens ¶
type ScoredTokens []ScoredToken
ScoredTokens is a slice of ScoredToken.
func (ScoredTokens) TopK ¶
func (st ScoredTokens) TopK(k int) ScoredTokens
TopK returns the k highest-scored tokens, in decreasing order of score.
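Since ScoredToken's fields are not listed on this page, the standalone sketch below uses a hypothetical scoredItem type to show the kind of selection TopK performs: keep the k entries with the highest scores, in decreasing score order.

package main

import (
	"fmt"
	"sort"
)

// scoredItem is a hypothetical stand-in for ScoredToken: a score attached to
// a (generation-index, token-index) position.
type scoredItem struct {
	genIndex   int
	tokenIndex int
	score      float64
}

// topK returns the k highest-scored items in decreasing score order.
func topK(items []scoredItem, k int) []scoredItem {
	sorted := append([]scoredItem(nil), items...) // copy; leave the input untouched
	sort.Slice(sorted, func(i, j int) bool { return sorted[i].score > sorted[j].score })
	if k > len(sorted) {
		k = len(sorted)
	}
	return sorted[:k]
}

func main() {
	items := []scoredItem{
		{0, 11, -0.3}, {0, 7, -1.2}, {1, 42, -0.1}, {1, 3, -2.4},
	}
	fmt.Println(topK(items, 2)) // [{1 42 -0.1} {0 11 -0.3}]
}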
type Scorer ¶
type Scorer struct {
// contains filtered or unexported fields
}
Scorer is a generation scorer implementing standard generation search decoding.
func (*Scorer) Finalize ¶
Finalize finalizes the generation hypotheses and returns the best sequence.
func (*Scorer) IsDone ¶
IsDone reports whether there are enough hypotheses and none of the hypotheses being generated can become better than the worst one in the heap.
func (*Scorer) Process ¶
func (s *Scorer) Process(inputIDs [][]int, scoredTokens ScoredTokens) ScorerProcessOutput
Process processes a new set of scored tokens.
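As a rough picture of what a beam-scorer step does, the fragment below sketches the standard rule that Process presumably follows: candidates that end with the EOS token are handed to the finished-hypotheses pool, while the remaining highest-scored candidates stay open as the next beams. Every type and name here is hypothetical; the fragment is not taken from this package.

package main

import "fmt"

// candidate is a hypothetical scored continuation of one of the current beams.
type candidate struct {
	beamIndex int     // index of the beam this continuation extends
	tokenID   int     // proposed next token
	score     float64 // accumulated log probability
}

// beamStepSketch keeps up to numBeams non-EOS candidates open and finalizes
// the EOS-terminated ones via the provided callback. Candidates are assumed
// to be sorted by decreasing score.
func beamStepSketch(beams [][]int, candidates []candidate, numBeams, eosTokenID int,
	finalize func(sequence []int, score float64)) []candidate {
	next := make([]candidate, 0, numBeams)
	for _, c := range candidates {
		if c.tokenID == eosTokenID {
			seq := append([]int(nil), beams[c.beamIndex]...) // finished sequence
			finalize(seq, c.score)
			continue
		}
		next = append(next, c)
		if len(next) == numBeams {
			break
		}
	}
	return next
}

func main() {
	beams := [][]int{{5, 9}, {5, 3}}
	candidates := []candidate{
		{0, 2, -0.4}, // EOS (tokenID 2) on beam 0: finalize it
		{1, 7, -0.6},
		{0, 8, -0.9},
		{1, 4, -1.1},
	}
	open := beamStepSketch(beams, candidates, 2, 2, func(seq []int, score float64) {
		fmt.Println("finished:", seq, score)
	})
	fmt.Println("open:", open)
}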
type ScorerProcessOutput ¶
type ScorerProcessOutput struct {
// contains filtered or unexported fields
}
ScorerProcessOutput is the output value of Scorer.Process.