tg

package tg
Version: v0.0.0-...-27647ab (Latest)
Warning: This package is not in the latest version of its module.

Go to the latest version.
Published: Oct 17, 2023 · License: BSD-3-Clause, MIT · Imports: 10 · Imported by: 0

Documentation

Index

Constants

This section is empty.

Variables

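View Source

Two alternative definitions of NdDot are listed below. Go does not allow a package-level variable to be declared twice in the same build, so only one is compiled into any given build. They presumably come from source files with different build constraints (one selecting NdDotFma, the other NdDotSep).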
var NdDot = NdDotFma

var NdDot = NdDotSep

Functions

func NdDotFma

func NdDotFma(a, b nd.NdArray[float32]) nd.NdArray[float32]

func NdDotSep

func NdDotSep(a, b nd.NdArray[float32]) nd.NdArray[float32]

func NdTensorDot

func NdTensorDot(a, b nd.NdArray[float32], axes_a, axes_b nd.Dims) nd.NdArray[float32]

Types

type AddFunc

type AddFunc struct{}

func (AddFunc) Backward

func (a AddFunc) Backward(ctx *FuncContext, g *LazyBuffer) []*LazyBuffer

func (AddFunc) Forward

func (a AddFunc) Forward(ctx *FuncContext, bs []*LazyBuffer) *LazyBuffer

type Buffer

type Buffer struct {
	// contains filtered or unexported fields
}

func BufferOf

func BufferOf(shape Shape, s []float32) *Buffer

func BufferOfNd

func BufferOfNd(a nd.NdArray[float32]) *Buffer

func MakeConstBuffer

func MakeConstBuffer(c float32, shape Shape) *Buffer

func NewBuffer

func NewBuffer(shape Shape) *Buffer

func (*Buffer) BinaryOp

func (b *Buffer) BinaryOp(op Op, y *Buffer) *Buffer

func (*Buffer) Dot

func (b *Buffer) Dot(y *Buffer, baxes, yaxes []int) *Buffer

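Despite its name, Dot is really a tensordot: it follows the rule dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m]). The baxes and yaxes arguments presumably select which axes of b and y are contracted. Confirm against the implementation.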

func (*Buffer) Expand

func (b *Buffer) Expand(newShape Shape) *Buffer

func (*Buffer) Flip

func (b *Buffer) Flip(axes []Dim) *Buffer

func (*Buffer) Get

func (b *Buffer) Get(idxs ...Dim) float32

func (*Buffer) MovementOp

func (b *Buffer) MovementOp(op Op, arg any) *Buffer

func (*Buffer) Nd

func (b *Buffer) Nd() nd.NdArray[float32]

func (*Buffer) Pad

func (b *Buffer) Pad(padding []SliceBound) *Buffer

func (*Buffer) Permute

func (b *Buffer) Permute(order []Dim) *Buffer

func (*Buffer) ProcessingOp

func (b *Buffer) ProcessingOp(op Op, w *Buffer, arg any) *Buffer

func (*Buffer) ReduceOp

func (b *Buffer) ReduceOp(op Op, nsh Shape) *Buffer

func (*Buffer) Reshape

func (b *Buffer) Reshape(newShape Shape) *Buffer

func (*Buffer) Shape

func (b *Buffer) Shape() Shape

func (*Buffer) Slice

func (b *Buffer) Slice(bounds ...SliceBound) *Buffer

func (*Buffer) Strides

func (b *Buffer) Strides() Strides

func (*Buffer) Transpose

func (b *Buffer) Transpose(order []Dim) *Buffer

func (*Buffer) UnaryOp

func (b *Buffer) UnaryOp(op Op) *Buffer

type Conv2dFunc

type Conv2dFunc struct {
	// contains filtered or unexported fields
}

func (Conv2dFunc) Backward

func (f Conv2dFunc) Backward(ctx *FuncContext, g *LazyBuffer) []*LazyBuffer

func (Conv2dFunc) Forward

func (f Conv2dFunc) Forward(ctx *FuncContext, bs []*LazyBuffer) *LazyBuffer

type ConvArgs

type ConvArgs struct {
	// contains filtered or unexported fields
}

func BuildConvArgs

func BuildConvArgs(xsh, wsh Shape, opts ConvOpts) ConvArgs

type ConvOpts

type ConvOpts struct {
	Stride   []Dim
	Groups   Dim
	Padding  []Dim
	Dilation []Dim
	OutShape Shape
}

type Dim

type Dim = int64

type ExpFunc

type ExpFunc struct{}

func (ExpFunc) Backward

func (f ExpFunc) Backward(ctx *FuncContext, g *LazyBuffer) []*LazyBuffer

func (ExpFunc) Forward

func (f ExpFunc) Forward(ctx *FuncContext, bs []*LazyBuffer) *LazyBuffer

type ExpandFunc

type ExpandFunc struct {
	Shape Shape
}

func (ExpandFunc) Backward

func (f ExpandFunc) Backward(ctx *FuncContext, g *LazyBuffer) []*LazyBuffer

func (ExpandFunc) Forward

func (f ExpandFunc) Forward(ctx *FuncContext, bs []*LazyBuffer) *LazyBuffer

type Func

type Func interface {
	Forward(ctx *FuncContext, bs []*LazyBuffer) *LazyBuffer
	Backward(ctx *FuncContext, g *LazyBuffer) []*LazyBuffer
}

type FuncContext

type FuncContext struct {
	// contains filtered or unexported fields
}

func NewFuncContext

func NewFuncContext(fn Func, parents []*Tensor) *FuncContext

type Lazy

type Lazy interface {
	// contains filtered or unexported methods
}

type LazyBuffer

type LazyBuffer struct {
	bt.NoCopy
	// contains filtered or unexported fields
}

func ElementwiseOp

func ElementwiseOp(op Op, srcs ...*LazyBuffer) *LazyBuffer

func MakeLoadBuffer

func MakeLoadBuffer(data *Buffer) *LazyBuffer

func MakeLoadConstBuffer

func MakeLoadConstBuffer(c float32, shape Shape) *LazyBuffer

func NewLazyBuffer

func NewLazyBuffer(st *ShapeTracker, ot OpType, op *LazyOp) *LazyBuffer

func (*LazyBuffer) BinaryOp

func (b *LazyBuffer) BinaryOp(op Op, y *LazyBuffer) *LazyBuffer

func (*LazyBuffer) MovementOp

func (b *LazyBuffer) MovementOp(op Op, arg any) *LazyBuffer

func (*LazyBuffer) ProcessingOp

func (b *LazyBuffer) ProcessingOp(op Op, w *LazyBuffer, arg any) *LazyBuffer

func (*LazyBuffer) Realize

func (b *LazyBuffer) Realize() *Buffer

func (*LazyBuffer) ReduceOp

func (b *LazyBuffer) ReduceOp(op Op, newShape Shape) *LazyBuffer

func (*LazyBuffer) Shape

func (b *LazyBuffer) Shape() Shape

func (*LazyBuffer) UnaryOp

func (b *LazyBuffer) UnaryOp(op Op) *LazyBuffer

type LazyOp

type LazyOp struct {
	bt.NoCopy
	// contains filtered or unexported fields
}

func (*LazyOp) ForEachBuffer

func (o *LazyOp) ForEachBuffer(fn func(*LazyBuffer) bool) bool

func (*LazyOp) ForEachOp

func (o *LazyOp) ForEachOp(fn func(*LazyOp) bool) bool

func (*LazyOp) GetBuffers

func (o *LazyOp) GetBuffers() []*LazyBuffer

func (*LazyOp) GetOps

func (o *LazyOp) GetOps() []*LazyOp

type LogFunc

type LogFunc struct{}

func (LogFunc) Backward

func (f LogFunc) Backward(ctx *FuncContext, g *LazyBuffer) []*LazyBuffer

func (LogFunc) Forward

func (f LogFunc) Forward(ctx *FuncContext, bs []*LazyBuffer) *LazyBuffer

type MaxFunc

type MaxFunc struct {
	// contains filtered or unexported fields
}

func (MaxFunc) Backward

func (f MaxFunc) Backward(ctx *FuncContext, g *LazyBuffer) []*LazyBuffer

func (MaxFunc) Forward

func (f MaxFunc) Forward(ctx *FuncContext, bs []*LazyBuffer) *LazyBuffer

type MulFunc

type MulFunc struct{}

func (MulFunc) Backward

func (a MulFunc) Backward(ctx *FuncContext, g *LazyBuffer) []*LazyBuffer

func (MulFunc) Forward

func (a MulFunc) Forward(ctx *FuncContext, bs []*LazyBuffer) *LazyBuffer

type Op

type Op uint8
const (
	InvalidOp Op = iota

	NoOp
	NegOp
	ReluOp
	ExpOp
	LogOp
	SignOp

	AddOp
	SubOp
	MulOp
	DivOp
	PowOp
	CmpEqOp

	SumOp
	MaxOp

	ReshapeOp
	PermuteOp
	SliceOp
	ExpandOp
	FlipOp

	ConvOp

	FromCpuOp
)

func (Op) String

func (o Op) String() string

func (Op) Type

func (o Op) Type() OpType

type OpType

type OpType uint8
const (
	InvalidOpType OpType = iota
	UnaryOpType
	BinaryOpType
	ReduceOpType
	MovementOpType
	ProcessingOpType
	LoadOpType
)

func (OpType) String

func (t OpType) String() string

type PermuteFunc

type PermuteFunc struct {
	Order []Dim
}

func (PermuteFunc) Backward

func (f PermuteFunc) Backward(ctx *FuncContext, g *LazyBuffer) []*LazyBuffer

func (PermuteFunc) Forward

func (f PermuteFunc) Forward(ctx *FuncContext, bs []*LazyBuffer) *LazyBuffer

type PowFunc

type PowFunc struct{}

func (PowFunc) Backward

func (a PowFunc) Backward(ctx *FuncContext, g *LazyBuffer) []*LazyBuffer

func (PowFunc) Forward

func (a PowFunc) Forward(ctx *FuncContext, bs []*LazyBuffer) *LazyBuffer

type RealizedOp

type RealizedOp struct {
	// contains filtered or unexported fields
}

type ReluFunc

type ReluFunc struct{}

func (ReluFunc) Backward

func (f ReluFunc) Backward(ctx *FuncContext, g *LazyBuffer) []*LazyBuffer

func (ReluFunc) Forward

func (f ReluFunc) Forward(ctx *FuncContext, bs []*LazyBuffer) *LazyBuffer

type ReshapeFunc

type ReshapeFunc struct {
	Shape Shape
}

func (ReshapeFunc) Backward

func (f ReshapeFunc) Backward(ctx *FuncContext, g *LazyBuffer) []*LazyBuffer

func (ReshapeFunc) Forward

func (f ReshapeFunc) Forward(ctx *FuncContext, bs []*LazyBuffer) *LazyBuffer

type Shape

type Shape []Dim

func (Shape) At

func (sh Shape) At(i int) Dim

func (Shape) Dim

func (sh Shape) Dim() Dim

func (Shape) Equals

func (sh Shape) Equals(o Shape) bool

type ShapeStride

type ShapeStride struct {
	Shape, Stride Dim
}

func ToShapeStrides

func ToShapeStrides(shape Shape, strides Strides) []ShapeStride

type ShapeTracker

type ShapeTracker struct {
	// contains filtered or unexported fields
}

func NewShapeTracker

func NewShapeTracker(shape Shape) *ShapeTracker

func (*ShapeTracker) Clone

func (st *ShapeTracker) Clone() *ShapeTracker

func (*ShapeTracker) Contiguous

func (st *ShapeTracker) Contiguous() bool

func (*ShapeTracker) Expand

func (st *ShapeTracker) Expand(newShape Shape)

func (*ShapeTracker) Flip

func (st *ShapeTracker) Flip(axes ...Dim)

func (*ShapeTracker) MovementOp

func (st *ShapeTracker) MovementOp(op Op, arg any)

func (*ShapeTracker) Offset

func (st *ShapeTracker) Offset() Dim

func (*ShapeTracker) Permute

func (st *ShapeTracker) Permute(axis ...Dim)

func (*ShapeTracker) Reshape

func (st *ShapeTracker) Reshape(newShape Shape)

func (*ShapeTracker) Shape

func (st *ShapeTracker) Shape() Shape

func (*ShapeTracker) Stride

func (st *ShapeTracker) Stride(mul ...Dim)

func (*ShapeTracker) Strides

func (st *ShapeTracker) Strides() Strides

type SliceBound

type SliceBound struct {
	Start, Stop Dim
}

type Strides

type Strides []Dim

func StridesForShape

func StridesForShape(shape Shape) Strides

func (Strides) Offset

func (st Strides) Offset(idxs ...Dim) Dim

type SubFunc

type SubFunc struct{}

func (SubFunc) Backward

func (a SubFunc) Backward(ctx *FuncContext, g *LazyBuffer) []*LazyBuffer

func (SubFunc) Forward

func (a SubFunc) Forward(ctx *FuncContext, bs []*LazyBuffer) *LazyBuffer

type SumFunc

type SumFunc struct {
	Axis []int
}

func (SumFunc) Backward

func (f SumFunc) Backward(ctx *FuncContext, g *LazyBuffer) []*LazyBuffer

func (SumFunc) Forward

func (f SumFunc) Forward(ctx *FuncContext, bs []*LazyBuffer) *LazyBuffer

type Tensor

type Tensor struct {
	// contains filtered or unexported fields
}

func Apply

func Apply(fn Func, parents []*Tensor) *Tensor

func BroadcastedTensor

func BroadcastedTensor(fn func(x, y *Tensor) *Tensor, x, y *Tensor) *Tensor

func NewTensor

func NewTensor(data *LazyBuffer, requiresGrad bool) *Tensor

func (*Tensor) Add

func (t *Tensor) Add(y *Tensor) *Tensor

func (*Tensor) Assign

func (t *Tensor) Assign(x *Tensor)

func (*Tensor) Backward

func (t *Tensor) Backward()

func (*Tensor) ClearGrad

func (t *Tensor) ClearGrad()

func (*Tensor) Conv2d

func (t *Tensor) Conv2d(weight *Tensor, bias *Tensor, opts ConvOpts) *Tensor

func (*Tensor) Data

func (t *Tensor) Data() *LazyBuffer

func (*Tensor) Dot

func (t *Tensor) Dot(w *Tensor) *Tensor

func (*Tensor) Exp

func (t *Tensor) Exp() *Tensor

func (*Tensor) Expand

func (t *Tensor) Expand(shape Shape) *Tensor

func (*Tensor) FullSoftmax

func (t *Tensor) FullSoftmax() TensorSoftmax

func (*Tensor) Grad

func (t *Tensor) Grad() *Tensor

func (*Tensor) Log

func (t *Tensor) Log() *Tensor

func (*Tensor) LogSoftmax

func (t *Tensor) LogSoftmax() *Tensor

func (*Tensor) Matmul

func (t *Tensor) Matmul(w *Tensor) *Tensor

func (*Tensor) Max

func (t *Tensor) Max(axis []int, keepDim bool) *Tensor

func (*Tensor) Mean

func (t *Tensor) Mean(axis []int, keepDim bool) *Tensor

func (*Tensor) Mul

func (t *Tensor) Mul(y *Tensor) *Tensor

func (*Tensor) Permute

func (t *Tensor) Permute(order []Dim) *Tensor

func (*Tensor) Relu

func (t *Tensor) Relu() *Tensor

func (*Tensor) RequiresGrad

func (t *Tensor) RequiresGrad() bool

func (*Tensor) Reshape

func (t *Tensor) Reshape(shape Shape) *Tensor

func (*Tensor) Shape

func (t *Tensor) Shape() Shape

func (*Tensor) Sub

func (t *Tensor) Sub(y *Tensor) *Tensor

func (*Tensor) Sum

func (t *Tensor) Sum(axis []int, keepDim bool) *Tensor

func (*Tensor) Transpose

func (t *Tensor) Transpose(order []Dim) *Tensor

type TensorSoftmax

type TensorSoftmax struct {
	M, E, Ss *Tensor
}

type View

type View struct {
	// contains filtered or unexported fields
}

func NewView

func NewView(shape Shape, strides Strides, offset Dim) View

func ViewFromShape

func ViewFromShape(shape Shape) View

func (*View) BaseStrides

func (v *View) BaseStrides() Strides

func (*View) Contiguous

func (v *View) Contiguous() bool

func (View) Offset

func (v View) Offset() Dim

func (View) Shape

func (v View) Shape() Shape

func (View) ShapeStrides

func (v View) ShapeStrides() []ShapeStride

func (View) Strides

func (v View) Strides() Strides

Directories

Path Synopsis

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL