rwkvlm

package
v0.0.0-...-4d121eb Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Feb 22, 2023 License: BSD-2-Clause Imports: 23 Imported by: 0

Documentation

Index

Constants

View Source
// Package-level defaults for the model conversion settings and layer
// normalization.
const (
	// DefaultPyModelFilename is the default name of the input PyTorch model file.
	DefaultPyModelFilename   = "pytorch_model.pt"
	// DefaultOutputFilename is the default name of the output (converted) model file.
	DefaultOutputFilename    = "spago_model.bin"
	// DefaultEmbeddingRepoPath is the default path of the embedding repository.
	DefaultEmbeddingRepoPath = "embeddings"

	// DefaultLayerNormEps is presumably the epsilon used by the layer
	// normalization modules when none is specified — TODO confirm at the
	// point of use.
	DefaultLayerNormEps = 1e-5
)

Variables

This section is empty.

Functions

func ConvertPickledModelToRWKVLM

func ConvertPickledModelToRWKVLM[T float.DType](config ConverterConfig) error

ConvertPickledModelToRWKVLM converts a PyTorch model to an RWKVLM model. It expects a configuration file "config.json" in the same directory as the model file containing the model configuration.

func Dump

func Dump(obj *Model, filename string) error

Dump saves the Model to a file. See gobEncode for further details.

Types

type Config

// Config is the model configuration. Each field carries a JSON tag naming
// the key it is read from or written to.
type Config struct {
	// DModel primarily corresponds to the embedding size.
	//
	// When converting a torch model, it can be left zero, letting the
	// process deduce the value automatically.
	DModel int `json:"d_model"`
	// NumHiddenLayers is the number of hidden layers.
	//
	// When converting a torch model, it can be left zero, letting the
	// process deduce the value automatically.
	NumHiddenLayers int `json:"num_hidden_layers"`
	// VocabSize is the vocabulary size.
	//
	// When converting a torch model, it can be left zero, letting the
	// process deduce the value automatically.
	VocabSize           int    `json:"vocab_size"`
	// RescaleLayer is stored under the "rescale_layer" key.
	// NOTE(review): its exact semantics are not documented here; presumably
	// it controls a periodic rescaling across layers — confirm against the
	// encoder implementation.
	RescaleLayer        int    `json:"rescale_layer"`
	// EmbeddingsStoreName is stored under the "embeddings_store_name" key.
	// NOTE(review): presumably the name of the store, within the embeddings
	// repository, that holds the token embeddings — confirm against the
	// loading code.
	EmbeddingsStoreName string `json:"embeddings_store_name"`
}

func LoadConfig

func LoadConfig(filePath string) (Config, error)

type ConverterConfig

// ConverterConfig holds the settings accepted by ConvertPickledModelToRWKVLM.
type ConverterConfig struct {
	// ModelDir is the path to the directory where the models will be read
	// from and written to.
	ModelDir string
	// PyModelFilename is the path to the input model file
	// (default "pytorch_model.pt").
	PyModelFilename string
	// GoModelFilename is the path to the output model file
	// (default "spago_model.bin").
	GoModelFilename string
	// EmbeddingRepoPath is the path to the embedding repository
	// (default "embeddings").
	EmbeddingRepoPath string
	// OverwriteIfExist, if true, overwrites the model file if it already
	// exists (default false).
	OverwriteIfExist bool
}

type Embeddings

// Embeddings embeds the token embeddings.
type Embeddings struct {
	nn.Module
	// Tokens is the embeddings model keyed by int, the token type that
	// Encode accepts.
	Tokens *emb.Model[int]
	// Config is the configuration stored alongside the module.
	Config Config
}

Embeddings embeds the token embeddings.

func NewEmbeddings

func NewEmbeddings[T float.DType](c emb.Config, repo store.Repository) *Embeddings

NewEmbeddings returns a new embedding module.

func (*Embeddings) Encode

func (m *Embeddings) Encode(tokens []int) []ag.Node

Encode performs the input encoding.

type Model

// Model is the RWKV language model. It bundles the token embeddings, the
// RWKV encoder, a layer normalization module and a weights parameter,
// together with the configuration.
type Model struct {
	nn.Module
	// Embeddings is the token embedding module.
	Embeddings *Embeddings
	// Encoder is the underlying RWKV model.
	Encoder    *rwkv.Model
	// LN is a layer normalization module.
	// NOTE(review): presumably applied to the encoder output before the
	// final projection — confirm in Predict/EncodeEmbeddings.
	LN         *layernorm.Model
	// Linear is a parameter tagged as weights.
	// NOTE(review): presumably the output projection producing the logits
	// returned by Predict — confirm.
	Linear     nn.Param `spago:"type:weights"`
	// Config is the model configuration.
	Config     Config
}

func Load

func Load(dir string) (*Model, error)

Load loads a pre-trained model from the given path.

func New

func New[T float.DType](c Config, repo store.Repository) *Model

func (*Model) ApplyEmbeddings

func (m *Model) ApplyEmbeddings(repo *diskstore.Repository) (err error)

ApplyEmbeddings sets the embeddings of the model.

func (*Model) Encode

func (m *Model) Encode(ctx context.Context, s rwkv.State, tokens ...int) (ag.Node, rwkv.State)

Encode performs EncodeTokens and EncodeEmbeddings.

func (*Model) EncodeEmbeddings

func (m *Model) EncodeEmbeddings(_ context.Context, s rwkv.State, xs []ag.Node) (ag.Node, rwkv.State)

EncodeEmbeddings returns the encoding of the given input considering the last state. At least one token is required; otherwise it can panic. If the input is a sequence, the last state is returned.

func (*Model) EncodeTokens

func (m *Model) EncodeTokens(_ context.Context, tokens ...int) []ag.Node

EncodeTokens returns the embeddings of the given tokens.

func (*Model) Predict

func (m *Model) Predict(x ag.Node) ag.Node

Predict returns the prediction logits of the next token.

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL