pipelines

package
v0.1.4 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Jul 18, 2024 License: Apache-2.0 Imports: 16 Imported by: 0

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type ClassificationOutput

type ClassificationOutput struct {
	Label string
	Score float32
}

type Entity added in v0.0.5

type Entity struct {
	Entity    string
	Score     float32
	Scores    []float32
	Index     int
	Word      string
	TokenID   uint32
	Start     uint
	End       uint
	IsSubword bool
}

type FeatureExtractionOutput

type FeatureExtractionOutput struct {
	Embeddings [][]float32
}

func (*FeatureExtractionOutput) GetOutput added in v0.0.5

func (t *FeatureExtractionOutput) GetOutput() []any

type FeatureExtractionPipeline

type FeatureExtractionPipeline struct {
	Normalization bool
	OutputName    string
	Output        ort.InputOutputInfo
	// contains filtered or unexported fields
}

FeatureExtractionPipeline is a Go version of the Hugging Face feature extraction pipeline: https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/feature_extraction.py

func NewFeatureExtractionPipeline

func NewFeatureExtractionPipeline(config PipelineConfig[*FeatureExtractionPipeline], ortOptions *ort.SessionOptions) (*FeatureExtractionPipeline, error)

NewFeatureExtractionPipeline initializes a feature extraction pipeline.

func (*FeatureExtractionPipeline) Destroy added in v0.1.4

func (p *FeatureExtractionPipeline) Destroy() error

Destroy frees the feature extraction pipeline resources.

func (*FeatureExtractionPipeline) Forward added in v0.1.4

func (p *FeatureExtractionPipeline) Forward(batch *PipelineBatch) error

Forward performs the forward inference of the feature extraction pipeline.

func (*FeatureExtractionPipeline) GetMetadata added in v0.1.4

GetMetadata returns metadata information about the pipeline, in particular: OutputInfo: names and dimensions of the output layer.

func (*FeatureExtractionPipeline) GetStats added in v0.1.4

func (p *FeatureExtractionPipeline) GetStats() []string

GetStats returns the runtime statistics for the pipeline.

func (*FeatureExtractionPipeline) Postprocess

Postprocess parses the first output from the network similar to the transformers implementation.

func (*FeatureExtractionPipeline) Preprocess added in v0.1.4

func (p *FeatureExtractionPipeline) Preprocess(batch *PipelineBatch, inputs []string) error

Preprocess tokenizes the input strings.

func (*FeatureExtractionPipeline) Run

Run runs the pipeline on a batch of strings.

func (*FeatureExtractionPipeline) RunPipeline added in v0.0.6

func (p *FeatureExtractionPipeline) RunPipeline(inputs []string) (*FeatureExtractionOutput, error)

RunPipeline is like Run, but returns the concrete feature extraction output type rather than the interface.

func (*FeatureExtractionPipeline) Validate added in v0.0.5

func (p *FeatureExtractionPipeline) Validate() error

Validate checks that the pipeline is valid.

type OutputInfo added in v0.1.4

type OutputInfo struct {
	Name       string
	Dimensions []int64
}

type Pipeline

type Pipeline interface {
	Destroy() error                            // Destroy the pipeline along with its onnx session
	GetStats() []string                        // Get the pipeline running stats
	Validate() error                           // Validate the pipeline for correctness
	GetMetadata() PipelineMetadata             // Return metadata information for the pipeline
	Run([]string) (PipelineBatchOutput, error) // Run the pipeline on an input
}

Pipeline is the interface that any pipeline must implement.

type PipelineBatch

type PipelineBatch struct {
	Input             []tokenizedInput
	InputTensors      []*ort.Tensor[int64]
	MaxSequenceLength int
	OutputTensors     []*ort.Tensor[float32]
}

PipelineBatch represents a batch of inputs that runs through the pipeline.

func NewBatch added in v0.1.4

func NewBatch() *PipelineBatch

NewBatch initializes a new batch for inference.

func (*PipelineBatch) Destroy added in v0.1.4

func (b *PipelineBatch) Destroy() error

type PipelineBatchOutput added in v0.0.5

type PipelineBatchOutput interface {
	GetOutput() []any
}

type PipelineConfig added in v0.0.9

type PipelineConfig[T Pipeline] struct {
	ModelPath    string
	Name         string
	OnnxFilename string
	Options      []PipelineOption[T]
}

PipelineConfig is a configuration for a pipeline type that can be used to create that pipeline.

type PipelineMetadata added in v0.1.4

type PipelineMetadata struct {
	OutputsInfo []OutputInfo
}

type PipelineOption added in v0.0.9

type PipelineOption[T Pipeline] func(eo T)

PipelineOption is an option for a pipeline type.

func WithHypothesisTemplate added in v0.1.4

func WithHypothesisTemplate(hypothesisTemplate string) PipelineOption[*ZeroShotClassificationPipeline]

WithHypothesisTemplate can be used to set the hypothesis template for classification.

func WithIgnoreLabels

func WithIgnoreLabels(ignoreLabels []string) PipelineOption[*TokenClassificationPipeline]

func WithLabels added in v0.1.4

WithLabels can be used to set the labels to classify the examples.

func WithMultiLabel added in v0.0.9

func WithMultiLabel() PipelineOption[*TextClassificationPipeline]

func WithMultilabel added in v0.1.4

func WithMultilabel(multilabel bool) PipelineOption[*ZeroShotClassificationPipeline]

WithMultilabel can be used to set whether the pipeline is multilabel.

func WithNormalization added in v0.1.1

func WithNormalization() PipelineOption[*FeatureExtractionPipeline]

WithNormalization applies normalization to the mean pooled output of the feature pipeline.

func WithOutputName added in v0.1.4

func WithOutputName(outputName string) PipelineOption[*FeatureExtractionPipeline]

WithOutputName selects which output of the underlying model is returned when the model has multiple outputs. If not set, the first output of the feature pipeline is returned.

func WithSigmoid added in v0.0.9

func WithSimpleAggregation

func WithSimpleAggregation() PipelineOption[*TokenClassificationPipeline]

WithSimpleAggregation sets the aggregation strategy for the token labels to simple. It reproduces the simple aggregation strategy of the Hugging Face implementation.

func WithSingleLabel added in v0.0.9

func WithSingleLabel() PipelineOption[*TextClassificationPipeline]

func WithSoftmax added in v0.0.9

func WithoutAggregation

func WithoutAggregation() PipelineOption[*TokenClassificationPipeline]

WithoutAggregation disables aggregation of the token labels, so the pipeline returns the per-token labels as predicted.

type TextClassificationOption

type TextClassificationOption func(eo *TextClassificationPipeline)

type TextClassificationOutput

type TextClassificationOutput struct {
	ClassificationOutputs [][]ClassificationOutput
}

func (*TextClassificationOutput) GetOutput added in v0.0.5

func (t *TextClassificationOutput) GetOutput() []any

type TextClassificationPipeline

type TextClassificationPipeline struct {
	IDLabelMap              map[int]string
	AggregationFunctionName string
	ProblemType             string
	// contains filtered or unexported fields
}

func NewTextClassificationPipeline

func NewTextClassificationPipeline(config PipelineConfig[*TextClassificationPipeline], ortOptions *ort.SessionOptions) (*TextClassificationPipeline, error)

NewTextClassificationPipeline initializes a new text classification pipeline.

func (*TextClassificationPipeline) Destroy added in v0.1.4

func (p *TextClassificationPipeline) Destroy() error

Destroy frees the text classification pipeline resources.

func (*TextClassificationPipeline) Forward

func (p *TextClassificationPipeline) Forward(batch *PipelineBatch) error

func (*TextClassificationPipeline) GetMetadata added in v0.1.4

GetMetadata returns metadata information about the pipeline, in particular: OutputInfo: names and dimensions of the output layer used for text classification.

func (*TextClassificationPipeline) GetStats added in v0.1.4

func (p *TextClassificationPipeline) GetStats() []string

GetStats returns the runtime statistics for the pipeline.

func (*TextClassificationPipeline) Postprocess

func (*TextClassificationPipeline) Preprocess added in v0.1.4

func (p *TextClassificationPipeline) Preprocess(batch *PipelineBatch, inputs []string) error

Preprocess tokenizes the input strings.

func (*TextClassificationPipeline) Run

Run runs the pipeline on a batch of strings.

func (*TextClassificationPipeline) RunPipeline added in v0.0.6

func (p *TextClassificationPipeline) RunPipeline(inputs []string) (*TextClassificationOutput, error)

func (*TextClassificationPipeline) Validate added in v0.0.5

func (p *TextClassificationPipeline) Validate() error

Validate checks that the pipeline is valid.

type TextClassificationPipelineConfig

type TextClassificationPipelineConfig struct {
	IDLabelMap map[int]string `json:"id2label"`
}

type TokenClassificationOutput

type TokenClassificationOutput struct {
	Entities [][]Entity
}

func (*TokenClassificationOutput) GetOutput added in v0.0.5

func (t *TokenClassificationOutput) GetOutput() []any

type TokenClassificationPipeline

type TokenClassificationPipeline struct {
	IDLabelMap          map[int]string
	AggregationStrategy string
	IgnoreLabels        []string
	// contains filtered or unexported fields
}

TokenClassificationPipeline is a Go version of the Hugging Face TokenClassificationPipeline: https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/token_classification.py

func NewTokenClassificationPipeline

func NewTokenClassificationPipeline(config PipelineConfig[*TokenClassificationPipeline], ortOptions *ort.SessionOptions) (*TokenClassificationPipeline, error)

NewTokenClassificationPipeline initializes a token classification pipeline.

func (*TokenClassificationPipeline) Aggregate

func (p *TokenClassificationPipeline) Aggregate(input tokenizedInput, preEntities []Entity) ([]Entity, error)

func (*TokenClassificationPipeline) Destroy added in v0.1.4

func (p *TokenClassificationPipeline) Destroy() error

Destroy frees the token classification pipeline resources.

func (*TokenClassificationPipeline) Forward added in v0.1.4

func (p *TokenClassificationPipeline) Forward(batch *PipelineBatch) error

Forward performs the forward inference of the pipeline.

func (*TokenClassificationPipeline) GatherPreEntities

func (p *TokenClassificationPipeline) GatherPreEntities(input tokenizedInput, output [][]float32) []Entity

GatherPreEntities converts a batch of logits into a list of pre-aggregation entities.

func (*TokenClassificationPipeline) GetMetadata added in v0.1.4

GetMetadata returns metadata information about the pipeline, in particular: OutputInfo: names and dimensions of the output layer used for token classification.

func (*TokenClassificationPipeline) GetStats added in v0.1.4

func (p *TokenClassificationPipeline) GetStats() []string

GetStats returns the runtime statistics for the pipeline.

func (*TokenClassificationPipeline) GroupEntities

func (p *TokenClassificationPipeline) GroupEntities(entities []Entity) ([]Entity, error)

GroupEntities groups together adjacent tokens that have the same predicted entity.

func (*TokenClassificationPipeline) Postprocess

Postprocess function for a token classification pipeline.

func (*TokenClassificationPipeline) Preprocess added in v0.1.4

func (p *TokenClassificationPipeline) Preprocess(batch *PipelineBatch, inputs []string) error

Preprocess tokenizes the input strings.

func (*TokenClassificationPipeline) Run

Run runs the pipeline on a batch of strings.

func (*TokenClassificationPipeline) RunPipeline added in v0.0.6

RunPipeline is like Run but returns the concrete type rather than the interface.

func (*TokenClassificationPipeline) Validate added in v0.0.5

func (p *TokenClassificationPipeline) Validate() error

Validate checks that the pipeline is valid.

type TokenClassificationPipelineConfig

type TokenClassificationPipelineConfig struct {
	IDLabelMap map[int]string `json:"id2label"`
}

type ZeroShotClassificationOutput added in v0.1.4

type ZeroShotClassificationOutput struct {
	Sequence     string
	SortedValues []struct {
		Key   string
		Value float64
	}
}

type ZeroShotClassificationPipeline added in v0.1.4

type ZeroShotClassificationPipeline struct {
	IDLabelMap         map[int]string
	Sequences          []string
	Labels             []string
	HypothesisTemplate string
	Multilabel         bool
	// contains filtered or unexported fields
}

func NewZeroShotClassificationPipeline added in v0.1.4

func NewZeroShotClassificationPipeline(config PipelineConfig[*ZeroShotClassificationPipeline], ortOptions *ort.SessionOptions) (*ZeroShotClassificationPipeline, error)

NewZeroShotClassificationPipeline creates a new zero-shot classification pipeline.

func (*ZeroShotClassificationPipeline) Destroy added in v0.1.4

func (*ZeroShotClassificationPipeline) Forward added in v0.1.4

func (*ZeroShotClassificationPipeline) GetMetadata added in v0.1.4

func (*ZeroShotClassificationPipeline) GetStats added in v0.1.4

func (p *ZeroShotClassificationPipeline) GetStats() []string

func (*ZeroShotClassificationPipeline) Postprocess added in v0.1.4

func (p *ZeroShotClassificationPipeline) Postprocess(outputTensors [][][]float32, labels []string, sequences []string) (*ZeroShotOutput, error)

func (*ZeroShotClassificationPipeline) Preprocess added in v0.1.4

func (p *ZeroShotClassificationPipeline) Preprocess(batch *PipelineBatch, inputs []string) error

func (*ZeroShotClassificationPipeline) Run added in v0.1.4

func (*ZeroShotClassificationPipeline) RunPipeline added in v0.1.4

func (p *ZeroShotClassificationPipeline) RunPipeline(inputs []string) (*ZeroShotOutput, error)

func (*ZeroShotClassificationPipeline) Validate added in v0.1.4

func (p *ZeroShotClassificationPipeline) Validate() error

type ZeroShotClassificationPipelineConfig added in v0.1.4

type ZeroShotClassificationPipelineConfig struct {
	IDLabelMap map[int]string `json:"id2label"`
}

type ZeroShotOutput added in v0.1.4

type ZeroShotOutput struct {
	ClassificationOutputs []ZeroShotClassificationOutput
}

func (*ZeroShotOutput) GetOutput added in v0.1.4

func (t *ZeroShotOutput) GetOutput() []any

GetOutput converts raw output to readable output.

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL