tabnet

package
v0.0.0-...-61db93c Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: May 14, 2022 License: Apache-2.0 Imports: 8 Imported by: 0

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type AttentiveTransformerModule

type AttentiveTransformerModule struct {
	// contains filtered or unexported fields
}

func AttentiveTransformer

func AttentiveTransformer(nn *godl.Model, opts AttentiveTransformerOpts) *AttentiveTransformerModule

AttentiveTransformer implements an attention transformer layer

func (*AttentiveTransformerModule) Forward

func (m *AttentiveTransformerModule) Forward(inputs ...*godl.Node) godl.Nodes

type AttentiveTransformerOpts

type AttentiveTransformerOpts struct {
	InputDimension                   int
	OutputDimension                  int
	Momentum                         float64
	Epsilon                          float64
	VirtualBatchSize                 int
	Activation                       activation.Function
	WithBias                         bool
	WeightsInit, ScaleInit, BiasInit gorgonia.InitWFn
}

type Classifier

type Classifier struct {
	// contains filtered or unexported fields
}

func NewClassifier

func NewClassifier(inputDim int, catDims []int, catIdxs []int, catEmbDim []int, opts ClassifierOpts) *Classifier

func (*Classifier) Model

func (r *Classifier) Model() *godl.Model

func (*Classifier) Train

func (r *Classifier) Train(trainX, trainY, validateX, validateY tensor.Tensor, opts godl.TrainOpts) error

type ClassifierOpts

type ClassifierOpts struct {
	BatchSize        int
	VirtualBatchSize int
	MaskFunction     activation.Function
	WithBias         bool

	SharedBlocks       int
	IndependentBlocks  int
	DecisionSteps      int
	PredictionLayerDim int
	AttentionLayerDim  int

	Gamma    float64
	Momentum float64
	Epsilon  float64

	WeightsInit, ScaleInit, BiasInit gorgonia.InitWFn
}

type FeatureTransformerModule

type FeatureTransformerModule struct {
	// contains filtered or unexported fields
}

func FeatureTransformer

func FeatureTransformer(nn *godl.Model, opts FeatureTransformerOpts) *FeatureTransformerModule

FeatureTransformer implements a feature transformer layer

func (*FeatureTransformerModule) Forward

func (m *FeatureTransformerModule) Forward(inputs ...*godl.Node) godl.Nodes

type FeatureTransformerOpts

type FeatureTransformerOpts struct {
	Shared            []*godl.LinearModule
	VirtualBatchSize  int
	IndependentBlocks int
	InputDimension    int
	OutputDimension   int
	WithBias          bool
	Momentum          float64

	WeightsInit gorgonia.InitWFn
}

FeatureTransformerOpts contains options for feature transformer layer

type GLUBlockModule

type GLUBlockModule struct {
	// contains filtered or unexported fields
}

func GLUBlock

func GLUBlock(nn *godl.Model, opts GLUBlockOpts) *GLUBlockModule

func (*GLUBlockModule) Forward

func (m *GLUBlockModule) Forward(inputs ...*godl.Node) godl.Nodes

type GLUBlockOpts

type GLUBlockOpts struct {
	InputDimension   int
	OutputDimension  int
	Shared           []*godl.LinearModule
	VirtualBatchSize int

	Size int

	WithBias    bool
	Momentum    float64
	WeightsInit gorgonia.InitWFn
}

type Regressor

type Regressor struct {
	// contains filtered or unexported fields
}

func NewRegressor

func NewRegressor(inputDim int, catDims []int, catIdxs []int, catEmbDim []int, opts RegressorOpts) *Regressor

func (*Regressor) Solve

func (r *Regressor) Solve(x tensor.Tensor, y tensor.Tensor) (tensor.Tensor, error)

FIXME: this shouldn't receive Y

func (*Regressor) Train

func (r *Regressor) Train(trainX, trainY, validateX, validateY tensor.Tensor, opts godl.TrainOpts) error

type RegressorOpts

type RegressorOpts struct {
	BatchSize        int
	VirtualBatchSize int
	MaskFunction     activation.Function
	WithBias         bool

	SharedBlocks       int
	IndependentBlocks  int
	DecisionSteps      int
	PredictionLayerDim int
	AttentionLayerDim  int

	Gamma    float64
	Momentum float64
	Epsilon  float64

	WeightsInit, ScaleInit, BiasInit gorgonia.InitWFn
}

type TabNetModule

type TabNetModule struct {
	// contains filtered or unexported fields
}

func TabNet

func TabNet(nn *godl.Model, opts TabNetOpts) *TabNetModule

func (*TabNetModule) Forward

func (m *TabNetModule) Forward(inputs ...*godl.Node) godl.Nodes

func (*TabNetModule) Name

func (m *TabNetModule) Name() string

type TabNetNoEmbeddingsModule

type TabNetNoEmbeddingsModule struct {
	// contains filtered or unexported fields
}

func TabNetNoEmbeddings

func TabNetNoEmbeddings(nn *godl.Model, opts TabNetNoEmbeddingsOpts) *TabNetNoEmbeddingsModule

TabNetNoEmbeddings implements the tab net architecture

func (*TabNetNoEmbeddingsModule) Forward

func (m *TabNetNoEmbeddingsModule) Forward(inputs ...*godl.Node) godl.Nodes

type TabNetNoEmbeddingsOpts

type TabNetNoEmbeddingsOpts struct {
	OutputSize int
	InputSize  int
	BatchSize  int

	SharedBlocks       int
	IndependentBlocks  int
	DecisionSteps      int
	PredictionLayerDim int
	AttentionLayerDim  int

	MaskFunction activation.Function

	WithBias bool

	Gamma                            float64
	Momentum                         float64
	Epsilon                          float64
	VirtualBatchSize                 int
	WeightsInit, ScaleInit, BiasInit gorgonia.InitWFn
}

TabNetNoEmbeddingsOpts contains parameters to configure the tab net algorithm

type TabNetOpts

type TabNetOpts struct {
	OutputSize int
	InputSize  int
	BatchSize  int

	SharedBlocks       int
	IndependentBlocks  int
	DecisionSteps      int
	PredictionLayerDim int
	AttentionLayerDim  int

	MaskFunction activation.Function

	WithBias bool

	Gamma                            float64
	Momentum                         float64
	Epsilon                          float64
	VirtualBatchSize                 int
	WeightsInit, ScaleInit, BiasInit gorgonia.InitWFn

	CatDims   []int
	CatIdxs   []int
	CatEmbDim []int
}

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL