Documentation ¶
Overview ¶
Package samplers uses a transformer model to generate sentences based on prompts.
Index ¶
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func UpdateCacheAttentionMaskGraph ¶
func UpdateCacheAttentionMaskGraph(currentStep *Node, attentionLen int, inputMask *Node) *Node
UpdateCacheAttentionMaskGraph builds the attention mask for the current decoding step, given an inputMask (over the whole batch of example token ids), a currentStep and a static attentionLen.
It is based on _compute_attention_mask in https://github.com/google-deepmind/gemma/blob/main/gemma/sampler.py#L32. The inputs and outputs there are rather cryptic; the best guess (arrived at with some help from Gemini) is that the original authors meant to generate a mask marking the positions that can be attended to, excluding anything in the "future" ("future" meaning positions > currentStep). A plain-Go sketch of this interpretation follows the parameter list below.
- currentStep: scalar with the current step.
- attentionLen: length of the attention mask (it's ok for it to be larger than inputMask; the output will be padded accordingly).
- inputMask: mask of valid tokens in the input, shaped [batchSize, inputLen].
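The following is only an illustrative sketch of the assumed semantics, written in eager plain Go rather than as a GoMLX computation graph; the helper name stepAttentionMask and the convention that true means "may attend" are assumptions, not part of this package.

package main

import "fmt"

// stepAttentionMask is a hypothetical, eager restatement of the mask logic:
// for each batch entry it returns attentionLen booleans, where true means
// "this position may be attended to at currentStep". A position is assumed
// attendable when it is a valid input token and not in the "future"
// (index > currentStep); positions beyond the input length stay false (padding).
func stepAttentionMask(currentStep, attentionLen int, inputMask [][]bool) [][]bool {
	out := make([][]bool, len(inputMask))
	for b, row := range inputMask {
		mask := make([]bool, attentionLen) // attentionLen may exceed the input length; the extra tail stays false.
		for i := 0; i < attentionLen && i < len(row); i++ {
			mask[i] = row[i] && i <= currentStep
		}
		out[b] = mask
	}
	return out
}

func main() {
	inputMask := [][]bool{{true, true, true, false}} // one example: 3 valid tokens, then padding
	fmt.Println(stepAttentionMask(1, 6, inputMask))  // [[true true false false false false]]
}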
Types ¶
type Sampler ¶
type Sampler struct {
	Backend backends.Backend
	Vocab   Vocabulary

	// MaxGeneratedTokens default for Sampler.Sample.
	MaxGeneratedTokens int

	// Context with the model weights, used to execute the model.
	Context *context.Context

	// SampleStep graph computation.
	SampleStep *context.Exec

	// Config of the Gemma model, created from the weights.
	Config *transformers.Config

	// CacheTreeStructure holds the structure of the tree used for caching: the tree structure (paths) is stable
	// across different calls to Sample.
	CacheTreeStructure *trees.Tree[struct{}]
}
Sampler holds a configured transformer (LLM) model and a vocabulary (SentencePiece), and generates sentences based on prompts.
func New ¶
func New(backend backends.Backend, ctx *context.Context, vocab Vocabulary, maxGeneratedTokens int) (*Sampler, error)
New creates a new Sampler with the given vocabulary and the model weights stored in ctx.
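A minimal usage sketch follows. It assumes this package's import path is github.com/gomlx/gemma/samplers, that ctx has already been loaded with the Gemma weights, and that Sampler has a Sample method taking a slice of prompts and returning one generated sentence per prompt; those details are not shown in this section, so treat them as assumptions.

package main

import (
	"fmt"
	"log"

	"github.com/gomlx/gemma/samplers" // assumed import path for this package
	"github.com/gomlx/gomlx/backends" // a concrete backend implementation must also be registered (typically via a blank import)
	"github.com/gomlx/gomlx/ml/context"
)

func main() {
	backend := backends.New() // picks up the default registered backend
	ctx := context.New()
	// Loading the Gemma weights into ctx and building the SentencePiece vocabulary
	// are out of scope here; vocab below is only a placeholder.
	var vocab samplers.Vocabulary

	sampler, err := samplers.New(backend, ctx, vocab, 128 /* MaxGeneratedTokens */)
	if err != nil {
		log.Fatal(err)
	}
	// Assumption: Sample generates one continuation per prompt.
	sentences := sampler.Sample([]string{"The capital of France is"})
	fmt.Println(sentences)
}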