samplers

package
v0.0.0-...-83792b2
Warning

This package is not in the latest version of its module.

Published: Nov 1, 2024 License: Apache-2.0 Imports: 14 Imported by: 0

Documentation

Overview

Package samplers uses a transformer model to generate sentences based on prompts.

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func UpdateCacheAttentionMaskGraph

func UpdateCacheAttentionMaskGraph(currentStep *Node, attentionLen int, inputMask *Node) *Node

UpdateCacheAttentionMaskGraph builds the attention mask for the cache, given an inputMask (over the whole batch of example token ids), a currentStep, and a static attentionLen.

It is based on _compute_attention_mask in https://github.com/google-deepmind/gemma/blob/main/gemma/sampler.py#L32. The inputs and outputs there are very cryptic. My best guess (also with some help from Gemini) is that the original authors meant to generate a mask that is False where tokens can attend, except for positions in the "future" ("future" means positions > currentStep).

- currentStep: scalar with current step.
- attentionLen: length of the attention mask (it's ok to be larger than inputMask, it will pad the output accordingly).
- inputMask: mask of valid tokens in the input, shaped [batchSize, inputLen].
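
To make the parameters concrete, here is a minimal sketch of one possible reading of the description, written on plain Go slices rather than graph nodes. It is not the package's implementation. The helper name cacheAttentionMask is hypothetical, and two things are assumptions: the polarity (true means "may attend"; the real function may use the opposite convention) and the padding value (positions beyond inputLen are filled with false).

```go
package main

import "fmt"

// cacheAttentionMask is a hypothetical, plain-slice illustration of the idea
// described above. Assumptions: true means "may attend", and positions beyond
// the input length are padded with false.
func cacheAttentionMask(currentStep, attentionLen int, inputMask [][]bool) [][]bool {
	out := make([][]bool, len(inputMask))
	for b, row := range inputMask {
		out[b] = make([]bool, attentionLen)
		for j := 0; j < attentionLen; j++ {
			// A position is attendable only if it holds a valid token...
			valid := j < len(row) && row[j]
			// ...and it is not in the "future" (positions > currentStep).
			out[b][j] = valid && j <= currentStep
		}
	}
	return out
}

func main() {
	// Two examples with inputLen = 4; the second has one more padded position.
	inputMask := [][]bool{
		{true, true, true, false},
		{true, true, false, false},
	}
	// attentionLen = 6 is larger than inputLen, so each row is padded to 6.
	for _, row := range cacheAttentionMask(2, 6, inputMask) {
		fmt.Println(row)
	}
}
```

With currentStep = 2, the first row allows positions 0 to 2 and the second row only positions 0 and 1, since its position 2 is not a valid token.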

Types

type Sampler

type Sampler struct {
	Backend backends.Backend
	Vocab   Vocabulary

	// MaxGeneratedTokens default for Sampler.Sample.
	MaxGeneratedTokens int

	// Context with the model weights, used to execute the model.
	Context *context.Context

	// SampleStep graph computation.
	SampleStep *context.Exec

	// Config of the Gemma model, created from the weights.
	Config *transformers.Config

	// CacheTreeStructure holds the structure of the tree used for caching: the tree structure (paths) is stable
	// across different calls to Sample.
	CacheTreeStructure *trees.Tree[struct{}]
}

Sampler has a transformer (LLM) model and a vocabulary (sentencepiece) configured and generates sentences based on prompts.

func New

func New(backend backends.Backend, ctx *context.Context, vocab Vocabulary, maxGeneratedTokens int) (*Sampler, error)

New creates a new sampler with the registered vocabulary and model.

func (*Sampler) Sample

func (s *Sampler) Sample(prompts []string) ([]string, error)

Sample generates a continuation for each of the given prompts.

func (*Sampler) SampleMaxTokens

func (s *Sampler) SampleMaxTokens(prompts []string, maxTokens int) ([]string, error)

SampleMaxTokens is like Sample, but uses the given maxTokens instead of the default MaxGeneratedTokens.
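
The relationship between the two methods can be sketched without a backend or model weights. In the example below, promptSampler and stubSampler are hypothetical names introduced only for illustration: the interface copies the two documented method signatures, and the stub appends placeholder tokens instead of running a transformer. That Sample simply delegates to SampleMaxTokens with the default limit is an assumption drawn from the description above, not something the source confirms.

```go
package main

import (
	"fmt"
	"strings"
)

// promptSampler lists the two sampling methods documented for *Sampler, so
// calling code can be exercised without a backend or model weights.
type promptSampler interface {
	Sample(prompts []string) ([]string, error)
	SampleMaxTokens(prompts []string, maxTokens int) ([]string, error)
}

// stubSampler is a hypothetical stand-in for the real Sampler.
type stubSampler struct {
	MaxGeneratedTokens int
}

// Sample uses the default limit (assumed wiring, see the text above).
func (s *stubSampler) Sample(prompts []string) ([]string, error) {
	return s.SampleMaxTokens(prompts, s.MaxGeneratedTokens)
}

// SampleMaxTokens appends maxTokens placeholder tokens to every prompt.
func (s *stubSampler) SampleMaxTokens(prompts []string, maxTokens int) ([]string, error) {
	if maxTokens <= 0 {
		return nil, fmt.Errorf("maxTokens must be positive, got %d", maxTokens)
	}
	outputs := make([]string, len(prompts))
	for i, p := range prompts {
		outputs[i] = p + strings.Repeat(" <tok>", maxTokens)
	}
	return outputs, nil
}

func main() {
	var s promptSampler = &stubSampler{MaxGeneratedTokens: 3}

	defaults, err := s.Sample([]string{"Hello"})
	if err != nil {
		panic(err)
	}
	short, err := s.SampleMaxTokens([]string{"Hello"}, 1)
	if err != nil {
		panic(err)
	}
	fmt.Println(defaults[0]) // Hello <tok> <tok> <tok>
	fmt.Println(short[0])    // Hello <tok>
}
```

Both methods take a batch of prompts and return one output per prompt, so callers should check the error before indexing into the result.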

type Vocabulary

type Vocabulary interface {
	EncodeAsIDs(text string) []int
	DecodeIDs([]int) string

	BeginningOfSentenceID() int
	EndOfSentenceID() int
	UnknownID() int
	PadID() int
}
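
Any type with these six methods satisfies Vocabulary. As a rough illustration, the sketch below implements the interface with a toy whitespace tokenizer. The names toyVocab and newToyVocab, the special-token strings, and the ID assignments are all hypothetical; a real setup would presumably use a sentencepiece model, as the Sampler description suggests.

```go
package main

import (
	"fmt"
	"strings"
)

// Vocabulary mirrors the interface documented above.
type Vocabulary interface {
	EncodeAsIDs(text string) []int
	DecodeIDs([]int) string

	BeginningOfSentenceID() int
	EndOfSentenceID() int
	UnknownID() int
	PadID() int
}

// Special token IDs; the concrete values are arbitrary choices for this toy.
const (
	padID = iota
	bosID
	eosID
	unkID
)

// toyVocab is a hypothetical whitespace-based vocabulary.
type toyVocab struct {
	words []string       // id -> word
	ids   map[string]int // word -> id
}

func newToyVocab(words ...string) *toyVocab {
	v := &toyVocab{
		words: []string{"<pad>", "<bos>", "<eos>", "<unk>"},
		ids:   map[string]int{},
	}
	for _, w := range words {
		if _, seen := v.ids[w]; !seen { // skip duplicates
			v.ids[w] = len(v.words)
			v.words = append(v.words, w)
		}
	}
	return v
}

// EncodeAsIDs splits on whitespace and maps unknown words to UnknownID.
func (v *toyVocab) EncodeAsIDs(text string) []int {
	fields := strings.Fields(text)
	ids := make([]int, 0, len(fields))
	for _, w := range fields {
		id, ok := v.ids[w]
		if !ok {
			id = unkID
		}
		ids = append(ids, id)
	}
	return ids
}

// DecodeIDs joins the words with spaces; out-of-range IDs decode as unknown.
func (v *toyVocab) DecodeIDs(ids []int) string {
	parts := make([]string, 0, len(ids))
	for _, id := range ids {
		if id < 0 || id >= len(v.words) {
			id = unkID
		}
		parts = append(parts, v.words[id])
	}
	return strings.Join(parts, " ")
}

func (v *toyVocab) BeginningOfSentenceID() int { return bosID }
func (v *toyVocab) EndOfSentenceID() int       { return eosID }
func (v *toyVocab) UnknownID() int             { return unkID }
func (v *toyVocab) PadID() int                 { return padID }

// Compile-time check that toyVocab satisfies Vocabulary.
var _ Vocabulary = (*toyVocab)(nil)

func main() {
	v := newToyVocab("hello", "world")
	ids := v.EncodeAsIDs("hello brave world")
	fmt.Println(ids)              // [4 3 5]
	fmt.Println(v.DecodeIDs(ids)) // hello <unk> world
}
```

A stand-in like this can be handy for testing code that depends on Vocabulary, since it needs no model file.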
