rwkv

v0.0.0-...-6a6eeea
Published: Feb 12, 2023 License: BSD-2-Clause Imports: 6 Imported by: 3

README

RWKV

RWKV (Receptance Weighted Key Value) is an RNN with Transformer-level performance that does not use the quadratic attention mechanism: only the hidden state at the current position is needed to compute the state at the next position.

RWKV is designed for efficient inference, even on CPUs, which makes it well suited to running large language models (LLMs) on ordinary consumer hardware at a reasonable speed.

This implementation is written in Go and utilizes the Spago machine learning framework.

How it works

There are currently no research papers describing this neural architecture. Most of the available information is in the original codebase by RWKV's author, PENG Bo (BlinkDL on GitHub).

Roughly speaking,

  • it gathers contextual information with a mechanism similar to an "exponential moving average", alternating time-mix and channel-mix layers. The layers decay at different rates, which helps the network retain important information for longer as it processes the input sequence.
  • the time-mix is inspired by Apple's AFT (Attention Free Transformer); the channel-mix is inspired by GeGLU.
  • it relies on careful parameter initialization (orthogonal matrices with proper scaling and special time curves) to converge quickly.
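The first point can be illustrated roughly with a toy per-channel exponential moving average. This is a simplification and not RWKV's actual time-mix, because the real layer also weights each input by a learned key and learns its decay rates. The sketch shows two things: only the current state is carried from one step to the next, and a channel with a slower decay retains an input for longer. The function `emaStep` and the decay values are hypothetical.

```go
package main

import "fmt"

// emaStep updates a per-channel exponential moving average in place:
// state[i] = decay[i]*state[i] + (1-decay[i])*x[i].
// A decay close to 1 keeps information for longer; close to 0 forgets quickly.
func emaStep(state, x, decay []float64) {
	for i := range state {
		state[i] = decay[i]*state[i] + (1-decay[i])*x[i]
	}
}

func main() {
	// Two channels with hypothetical decay rates: one fast, one slow.
	decay := []float64{0.5, 0.9}
	state := make([]float64, len(decay))

	// An impulse at t=0 followed by zeros: watch how each channel fades.
	inputs := [][]float64{{1, 1}, {0, 0}, {0, 0}, {0, 0}}
	for t, x := range inputs {
		emaStep(state, x, decay)
		fmt.Printf("t=%d fast=%.4f slow=%.4f\n", t, state[0], state[1])
	}
}
```

By t=3 the slow channel holds 0.0729 of the impulse, while the fast channel holds 0.0625.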

Installation

Requirements:

Clone this repo or get the library:

go get -u github.com/nlpodyssey/rwkv

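The library is optimized to run on x86-64 CPUs. If you are on a different architecture, you can build for x86-64 by setting the GOARCH=amd64 environment variable. The resulting binary must then run in an environment that can execute x86-64 code, for example through emulation.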

Roadmap

  • Parameter initialization (essential)
  • Unit tests
  • Documentation
  • Gob serialization for large models
  • Model optimization

Credits

References

@software{peng_bo_2021_5196578,
  author       = {PENG Bo},
  title        = {BlinkDL/RWKV-LM: 0.01},
  month        = aug,
  year         = 2021,
  publisher    = {Zenodo},
  version      = {0.01},
  doi          = {10.5281/zenodo.5196577},
  url          = {https://doi.org/10.5281/zenodo.5196577}
}

Documentation


Types

type ChannelMix

type ChannelMix struct {
	nn.Module

	Key        nn.Param `spago:"type:weights"`
	Value      nn.Param `spago:"type:weights"`
	Receptance nn.Param `spago:"type:weights"`

	TimeMixK nn.Param `spago:"type:weights"`
	TimeMixR nn.Param `spago:"type:weights"`
}

ChannelMix implements the channel mix module.

func NewChannelMix

func NewChannelMix[T float.DType](c Config, _ int) *ChannelMix

func (*ChannelMix) ForwardSequence

func (m *ChannelMix) ForwardSequence(x []ag.Node, state *LayerState) (rkv []ag.Node)

ForwardSequence performs the forward step for a sequence of nodes. The state is updated with the last node of the sequence.

func (*ChannelMix) ForwardSingle

func (m *ChannelMix) ForwardSingle(x ag.Node, state *LayerState) (rkv ag.Node)

ForwardSingle performs the forward step for a single node.
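ChannelMix's fields (Key, Value, Receptance, TimeMixK, TimeMixR) match the channel-mix in BlinkDL's RWKV-4 reference code. That code interpolates the current and previous inputs (a "token shift"), applies a squared ReLU to the key projection, and gates the result with a sigmoid of the receptance. The sketch below assumes this package follows the same computation; it has not been checked against its source. It uses float64 slices instead of Spago's ag.Node. The `channelMix` type and its methods are hypothetical stand-ins, and `prev` plays the role of LayerState.FfnXX.

```go
package main

import (
	"fmt"
	"math"
)

// matVec returns W·x for a row-major matrix W.
func matVec(w [][]float64, x []float64) []float64 {
	out := make([]float64, len(w))
	for i, row := range w {
		for j, v := range row {
			out[i] += v * x[j]
		}
	}
	return out
}

// tokenShift interpolates element-wise between the current input x and the
// previous input prev: mu*x + (1-mu)*prev.
func tokenShift(x, prev, mu []float64) []float64 {
	out := make([]float64, len(x))
	for i := range x {
		out[i] = mu[i]*x[i] + (1-mu[i])*prev[i]
	}
	return out
}

// channelMix is a hypothetical plain-slice stand-in for ChannelMix.
type channelMix struct {
	Key, Value, Receptance [][]float64
	TimeMixK, TimeMixR     []float64
}

// forwardSingle returns sigmoid(Wr·xr) ⊙ (Wv·relu(Wk·xk)²) and stores x in
// *prev so that the next call can use it for the token shift.
func (m channelMix) forwardSingle(x []float64, prev *[]float64) []float64 {
	xk := tokenShift(x, *prev, m.TimeMixK)
	xr := tokenShift(x, *prev, m.TimeMixR)
	*prev = append([]float64(nil), x...)

	k := matVec(m.Key, xk)
	for i, v := range k {
		relu := math.Max(v, 0)
		k[i] = relu * relu // squared ReLU
	}
	kv := matVec(m.Value, k)
	r := matVec(m.Receptance, xr)
	for i := range kv {
		kv[i] *= 1 / (1 + math.Exp(-r[i])) // sigmoid gate
	}
	return kv
}

// forwardSequence applies forwardSingle to each input in order, so *prev
// ends up holding the last input of the sequence.
func (m channelMix) forwardSequence(xs [][]float64, prev *[]float64) [][]float64 {
	out := make([][]float64, len(xs))
	for i, x := range xs {
		out[i] = m.forwardSingle(x, prev)
	}
	return out
}

func main() {
	id := [][]float64{{1, 0}, {0, 1}} // identity weights, for illustration only
	half := []float64{0.5, 0.5}
	m := channelMix{Key: id, Value: id, Receptance: id, TimeMixK: half, TimeMixR: half}

	prev := make([]float64, 2) // zero state, as for a fresh sequence
	out := m.forwardSequence([][]float64{{2, -2}, {4, 0}}, &prev)
	fmt.Println(out, prev)
}
```

In this sketch, processing a sequence is just a loop over single steps. The only value carried between steps is the previous input, which matches the description above that the state is updated with the last node of the sequence.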

type Config

type Config struct {
	DModel       int
	NumLayers    int
	RescaleLayer int
}

Config is the configuration of the RWKV model.

type Layer

type Layer struct {
	nn.Module

	LN0 *layernorm.Model
	LN1 *layernorm.Model
	LN2 *layernorm.Model

	ChanMix *ChannelMix
	TimeMix *TimeMix

	ID int
}

Layer is a single block of the RWKV model.

func NewLayer

func NewLayer[T float.DType](c Config, id int) *Layer

NewLayer returns a new RWKV layer.

func (*Layer) ForwardSequence

func (m *Layer) ForwardSequence(x []ag.Node, state *LayerState) []ag.Node

func (*Layer) ForwardSingle

func (m *Layer) ForwardSingle(x ag.Node, state *LayerState) ag.Node
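Layer's methods are undocumented, but its fields (LN0, LN1, LN2, TimeMix, ChanMix, ID) match the block structure of BlinkDL's RWKV-4 reference code. In that code, LN0 normalizes the input of the first block only, and each mixer sits in a pre-norm residual branch. The sketch below assumes this package composes the block the same way. `layerForward` and the `sublayer` type are hypothetical, and the layer norm omits the learned gain and bias.

```go
package main

import (
	"fmt"
	"math"
)

// layerNorm normalizes x to zero mean and unit variance. The learned gain and
// bias of a real layer norm are omitted (gain 1, bias 0).
func layerNorm(x []float64) []float64 {
	n := float64(len(x))
	var mean, variance float64
	for _, v := range x {
		mean += v
	}
	mean /= n
	for _, v := range x {
		variance += (v - mean) * (v - mean)
	}
	variance /= n
	std := math.Sqrt(variance + 1e-5)
	out := make([]float64, len(x))
	for i, v := range x {
		out[i] = (v - mean) / std
	}
	return out
}

// add returns the element-wise sum a + b.
func add(a, b []float64) []float64 {
	out := make([]float64, len(a))
	for i := range a {
		out[i] = a[i] + b[i]
	}
	return out
}

// sublayer stands in for TimeMix or ChannelMix (hypothetical type).
type sublayer func([]float64) []float64

// layerForward sketches one block: LN0 on the first layer's input only,
// then two pre-norm residual branches.
func layerForward(id int, x []float64, timeMix, chanMix sublayer) []float64 {
	if id == 0 {
		x = layerNorm(x) // LN0
	}
	x = add(x, timeMix(layerNorm(x))) // LN1 + TimeMix
	x = add(x, chanMix(layerNorm(x))) // LN2 + ChannelMix
	return x
}

func main() {
	// With mixers that output zeros, only the residual path remains.
	zero := func(v []float64) []float64 { return make([]float64, len(v)) }
	x := []float64{1, 2, 3}
	fmt.Println(layerForward(1, x, zero, zero)) // unchanged: [1 2 3]
	fmt.Println(layerForward(0, x, zero, zero)) // only LN0 applied
}
```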

type LayerState

type LayerState struct {
	FfnXX ag.Node
	AttXX ag.Node
	AttAA ag.Node
	AttBB ag.Node
	AttPP ag.Node
}

LayerState is the RWKV state for a single layer.

type Model

type Model struct {
	nn.Module
	Layers []*Layer
	Config Config
}

Model implements the RWKV neural network.

func New

func New[T float.DType](c Config) *Model

New returns a new RWKV model.

func (*Model) ForwardSequence

func (m *Model) ForwardSequence(x []ag.Node, state State) ([]ag.Node, State)

ForwardSequence performs the forward step for the entire sequence. It is a slightly more optimized equivalent of calling ForwardSingle for each element of the sequence, for example:

var x ag.Node
for _, e := range encoded {
	x, s = m.ForwardSingle(e, s)
}
return x, s

It returns the last computed state.

func (*Model) ForwardSingle

func (m *Model) ForwardSingle(x ag.Node, state State) (ag.Node, State)

ForwardSingle performs the forward step for a single element of the sequence.

type State

type State []*LayerState

func NewState

func NewState(c Config) State

NewState returns a new RWKV state.
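NewState's initial values are not documented here. In BlinkDL's RWKV-4 reference code, each layer's state starts with zero vectors for the previous inputs (FfnXX, AttXX) and the running WKV sums (AttAA, AttBB). The exponent AttPP starts at -1e30, so the empty history gets a weight of essentially zero. The sketch below assumes this package does the same. The `config` and `layerState` types are hypothetical plain-slice mirrors of Config and LayerState. The sketch also shows that the state size depends only on the configuration, not on the sequence length.

```go
package main

import "fmt"

// config is a hypothetical subset of Config with the fields the state needs.
type config struct {
	DModel    int
	NumLayers int
}

// layerState mirrors LayerState with plain slices instead of ag.Node.
type layerState struct {
	FfnXX, AttXX        []float64 // previous inputs for the token shift
	AttAA, AttBB, AttPP []float64 // running WKV numerator, denominator and exponent
}

// newState allocates one layerState per layer. AttPP starts at a very
// negative value so that the (empty) history gets weight exp(-1e30) ≈ 0.
func newState(c config) []*layerState {
	s := make([]*layerState, c.NumLayers)
	for i := range s {
		pp := make([]float64, c.DModel)
		for j := range pp {
			pp[j] = -1e30
		}
		s[i] = &layerState{
			FfnXX: make([]float64, c.DModel),
			AttXX: make([]float64, c.DModel),
			AttAA: make([]float64, c.DModel),
			AttBB: make([]float64, c.DModel),
			AttPP: pp,
		}
	}
	return s
}

func main() {
	c := config{DModel: 4, NumLayers: 2} // toy sizes, for illustration only
	s := newState(c)
	// The state size is fixed: 5 vectors of DModel values per layer,
	// regardless of how many tokens have been processed.
	fmt.Println(len(s), 5*c.DModel*c.NumLayers)
}
```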

type TimeMix

type TimeMix struct {
	nn.Module

	Key        nn.Param `spago:"type:weights"`
	Value      nn.Param `spago:"type:weights"`
	Receptance nn.Param `spago:"type:weights"`
	Output     nn.Param `spago:"type:weights"`

	TimeDecay nn.Param `spago:"type:weights"`
	TimeFirst nn.Param `spago:"type:weights"`
	TimeMixK  nn.Param `spago:"type:weights"`
	TimeMixV  nn.Param `spago:"type:weights"`
	TimeMixR  nn.Param `spago:"type:weights"`

	Config Config
}

TimeMix implements the time-mix module.

func NewTimeMix

func NewTimeMix[T float.DType](c Config, _ int) *TimeMix

func (*TimeMix) ForwardSequence

func (m *TimeMix) ForwardSequence(x []ag.Node, state *LayerState) []ag.Node

ForwardSequence performs the forward step for a sequence of inputs. The state is updated at the end of the sequence.

func (*TimeMix) ForwardSingle

func (m *TimeMix) ForwardSingle(x ag.Node, state *LayerState) ag.Node

ForwardSingle performs the forward step for a single input.
