barthead

package barthead

Version: v0.2.0 (latest)
Warning

This package is not in the latest version of its module.

Published: Dec 30, 2020 License: BSD-2-Clause Imports: 11 Imported by: 0

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type Classification

type Classification struct {
	Config ClassificationConfig
	*stack.Model
}

Classification is a BART head for sentence-level classification tasks.

func NewClassification

func NewClassification(config ClassificationConfig) *Classification

NewClassification returns a new Classification.

type ClassificationConfig

type ClassificationConfig struct {
	InputSize     int
	HiddenSize    int
	OutputSize    int
	PoolerDropout float64
}

ClassificationConfig provides the configuration settings for a Classification model, a BART head for sentence-level classification.
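Assuming the head consists of two biased linear layers (InputSize→HiddenSize and HiddenSize→OutputSize), the size fields fully determine its weight shapes, and PoolerDropout is a drop probability. The sketch below uses a local stand-in struct with the same fields as ClassificationConfig, plus hypothetical `paramCount` and `validate` helpers that are not part of the package. The sizes in `main` are illustrative only.

```go
package main

import "fmt"

// classificationConfig is a local stand-in with the same fields as
// barthead.ClassificationConfig; it is not the library type.
type classificationConfig struct {
	InputSize     int
	HiddenSize    int
	OutputSize    int
	PoolerDropout float64
}

// paramCount ASSUMES two biased linear layers:
// InputSize→HiddenSize and HiddenSize→OutputSize.
func paramCount(c classificationConfig) int {
	dense := c.InputSize*c.HiddenSize + c.HiddenSize
	outProj := c.HiddenSize*c.OutputSize + c.OutputSize
	return dense + outProj
}

// validate rejects configurations that cannot describe a head:
// non-positive sizes, or a dropout probability outside [0, 1).
func validate(c classificationConfig) error {
	if c.InputSize <= 0 || c.HiddenSize <= 0 || c.OutputSize <= 0 {
		return fmt.Errorf("sizes must be positive: %+v", c)
	}
	if c.PoolerDropout < 0 || c.PoolerDropout >= 1 {
		return fmt.Errorf("PoolerDropout must be in [0, 1): %v", c.PoolerDropout)
	}
	return nil
}

func main() {
	// Illustrative: a 1024-dimensional model with three output labels.
	cfg := classificationConfig{InputSize: 1024, HiddenSize: 1024, OutputSize: 3, PoolerDropout: 0.0}
	if err := validate(cfg); err != nil {
		panic(err)
	}
	fmt.Println(paramCount(cfg)) // 1052675
}
```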

type SequenceClassification

type SequenceClassification struct {
	nn.BaseModel
	BART           *bart.Model
	Classification *Classification
}

SequenceClassification is a model for sentence-level classification tasks that embeds a pre-trained BART model.

func LoadModelForSequenceClassification

func LoadModelForSequenceClassification(modelPath string) (*SequenceClassification, error)

LoadModelForSequenceClassification loads a SequenceClassification model from modelPath.

func NewSequenceClassification

func NewSequenceClassification(config bartconfig.Config, embeddingsPath string) *SequenceClassification

NewSequenceClassification returns a new SequenceClassification.

func (*SequenceClassification) Classify

func (m *SequenceClassification) Classify(in interface{}) ag.Node

Classify performs the classification using the last transformed state.

func (*SequenceClassification) Close

func (m *SequenceClassification) Close()

Close closes the BART model's embeddings DB.
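Because Close releases the embeddings DB, a caller would normally pair model creation with `defer m.Close()` so the DB is released even on an early return. The stand-in types below (`embeddingsDB`, `model`, `newModel`) are hypothetical and only illustrate that lifecycle. With the real package, the model would come from LoadModelForSequenceClassification or NewSequenceClassification.

```go
package main

import "fmt"

// embeddingsDB is a hypothetical stand-in for the on-disk embeddings store.
type embeddingsDB struct{ open bool }

func (db *embeddingsDB) Close() { db.open = false }

// model is a hypothetical stand-in for SequenceClassification.
type model struct{ db *embeddingsDB }

func newModel() *model { return &model{db: &embeddingsDB{open: true}} }

// Close releases the embeddings DB, as SequenceClassification.Close is documented to do.
func (m *model) Close() { m.db.Close() }

// run creates a model and guarantees it is closed when run returns.
func run() *embeddingsDB {
	m := newModel()
	defer m.Close() // executes after the return value is evaluated
	// ... classify inputs here ...
	return m.db
}

func main() {
	db := run()
	fmt.Println(db.open) // false
}
```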
