Documentation
¶
Overview ¶
Package bert provides an implementation of the BERT model (Bidirectional Encoder Representations from Transformers).
Reference: "Attention Is All You Need" by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin (2017) (http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf)
Index ¶
- Constants
- func ConvertHuggingFacePreTrained(modelPath string) error
- func Dump(value interface{}, pretty bool) ([]byte, error)
- type Answer
- type AnswerSlice
- type Body
- type ClassConfidencePair
- type Classifier
- type ClassifierConfig
- type ClassifyResponse
- type Config
- type Discriminator
- type DiscriminatorConfig
- type Embeddings
- type EmbeddingsConfig
- type EncodeResponse
- type Encoder
- type EncoderConfig
- type EncoderLayer
- type LabelerOptionsType
- type Model
- func (m *Model) Discriminate(encoded []ag.Node) []int
- func (m *Model) Encode(tokens []string) []ag.Node
- func (m *Model) Pool(transformed []ag.Node) ag.Node
- func (m *Model) PredictMasked(transformed []ag.Node, masked []int) map[int]ag.Node
- func (m *Model) PredictSeqRelationship(pooled ag.Node) ag.Node
- func (m *Model) SequenceClassification(transformed []ag.Node) ag.Node
- func (m *Model) TokenClassification(transformed []ag.Node) []ag.Node
- type Pooler
- type PoolerConfig
- type Predictor
- type PredictorConfig
- type QABody
- type QuestionAnsweringResponse
- type Response
- type Server
- func (s *Server) Answer(ctx context.Context, req *grpcapi.AnswerRequest) (*grpcapi.AnswerReply, error)
- func (s *Server) Classify(_ context.Context, req *grpcapi.ClassifyRequest) (*grpcapi.ClassifyReply, error)
- func (s *Server) ClassifyHandler(w http.ResponseWriter, req *http.Request)
- func (s *Server) Discriminate(ctx context.Context, req *grpcapi.DiscriminateRequest) (*grpcapi.DiscriminateReply, error)
- func (s *Server) DiscriminateHandler(w http.ResponseWriter, req *http.Request)
- func (s *Server) Encode(_ context.Context, req *grpcapi.EncodeRequest) (*grpcapi.EncodeReply, error)
- func (s *Server) LabelerHandler(w http.ResponseWriter, req *http.Request)
- func (s *Server) Predict(ctx context.Context, req *grpcapi.PredictRequest) (*grpcapi.PredictReply, error)
- func (s *Server) PredictHandler(w http.ResponseWriter, req *http.Request)
- func (s *Server) QaHandler(w http.ResponseWriter, req *http.Request)
- func (s *Server) SentenceEncoderHandler(w http.ResponseWriter, req *http.Request)
- func (s *Server) StartDefaultServer(address, grpcAddress, tlsCert, tlsKey string, tlsDisable bool)
- type SpanClassifier
- type SpanClassifierConfig
- type Token
- type TokenClassifierBody
- type Trainer
- type TrainingConfig
Constants ¶
const ( // DefaultConfigurationFile is the default BERT JSON configuration filename. DefaultConfigurationFile = "config.json" // DefaultVocabularyFile is the default BERT model's vocabulary filename. DefaultVocabularyFile = "vocab.txt" // DefaultModelFile is the default BERT spaGO model filename. DefaultModelFile = "spago_model.bin" // DefaultEmbeddingsStorage is the default directory name for BERT model's embedding storage. DefaultEmbeddingsStorage = "embeddings_storage" )
const DefaultFakeLabel = "FAKE"
DefaultFakeLabel is the default value for the fake label used for BERT "discriminate" server requests.
const DefaultPredictedLabel = "PREDICTED"
DefaultPredictedLabel is the default value for the predicted label used for BERT "predict" server requests.
const DefaultRealLabel = "REAL"
DefaultRealLabel is the default value for the real label used for BERT "discriminate" server requests.
Variables ¶
This section is empty.
Functions ¶
func ConvertHuggingFacePreTrained ¶
ConvertHuggingFacePreTrained converts a HuggingFace pre-trained BERT transformer model to a corresponding spaGO model.
Types ¶
type Answer ¶
type Answer struct { Text string `json:"text"` Start int `json:"start"` End int `json:"end"` Confidence mat.Float `json:"confidence"` }
Answer represents a single JSON-serializable BERT question-answering answer, used as part of a server's response.
type AnswerSlice ¶
type AnswerSlice []Answer
AnswerSlice is a slice of Answer elements, which implements the sort.Interface.
func (AnswerSlice) Less ¶
func (p AnswerSlice) Less(i, j int) bool
Less returns true if the Answer.Confidence of the element at position i is lower than the one of the element at position j.
func (AnswerSlice) Sort ¶
func (p AnswerSlice) Sort()
Sort sorts the AnswerSlice's elements by Answer.Confidence.
func (AnswerSlice) Swap ¶
func (p AnswerSlice) Swap(i, j int)
Swap swaps the elements at positions i and j.
type ClassConfidencePair ¶
type ClassConfidencePair struct { Class string `json:"class"` Confidence mat.Float `json:"confidence"` }
ClassConfidencePair associates a Confidence to a symbolic Class.
type Classifier ¶
type Classifier struct { Config ClassifierConfig *linear.Model }
Classifier implements a BERT Classifier.
func NewTokenClassifier ¶
func NewTokenClassifier(config ClassifierConfig) *Classifier
NewTokenClassifier returns a new BERT Classifier model.
type ClassifierConfig ¶
ClassifierConfig provides configuration settings for a BERT Classifier.
type ClassifyResponse ¶
type ClassifyResponse struct { Class string `json:"class"` Confidence mat.Float `json:"confidence"` Distribution []ClassConfidencePair `json:"distribution"` // Took is the number of milliseconds it took the server to execute the request. Took int64 `json:"took"` }
ClassifyResponse is a JSON-serializable server response for BERT "classify" requests.
type Config ¶
type Config struct { HiddenAct string `json:"hidden_act"` HiddenSize int `json:"hidden_size"` IntermediateSize int `json:"intermediate_size"` MaxPositionEmbeddings int `json:"max_position_embeddings"` NumAttentionHeads int `json:"num_attention_heads"` NumHiddenLayers int `json:"num_hidden_layers"` TypeVocabSize int `json:"type_vocab_size"` VocabSize int `json:"vocab_size"` ID2Label map[string]string `json:"id2label"` ReadOnly bool `json:"read_only"` }
Config provides configuration settings for a BERT Model.
func LoadConfig ¶
LoadConfig loads a BERT model Config from file.
type Discriminator ¶
Discriminator is a BERT Discriminator model.
func NewDiscriminator ¶
func NewDiscriminator(config DiscriminatorConfig) *Discriminator
NewDiscriminator returns a new BERT Discriminator model.
func (*Discriminator) Discriminate ¶ added in v0.2.0
func (m *Discriminator) Discriminate(encoded []ag.Node) []int
Discriminate returns 0 or 1 for each encoded element, where 1 means that the word is out of context.
type DiscriminatorConfig ¶
type DiscriminatorConfig struct { InputSize int HiddenSize int HiddenActivation ag.OpName OutputActivation ag.OpName }
DiscriminatorConfig provides configuration settings for a BERT Discriminator.
type Embeddings ¶
type Embeddings struct { nn.BaseModel EmbeddingsConfig Words *embeddings.Model Position []nn.Param `spago:"type:weights"` // TODO: stop auto-wrapping TokenType []nn.Param `spago:"type:weights"` Norm *layernorm.Model Projector *linear.Model UnknownEmbedding ag.Node `spago:"scope:processor"` }
Embeddings is a BERT Embeddings model.
func NewEmbeddings ¶
func NewEmbeddings(config EmbeddingsConfig) *Embeddings
NewEmbeddings returns a new BERT Embeddings model.
func (*Embeddings) Encode ¶ added in v0.2.0
func (m *Embeddings) Encode(words []string) []ag.Node
Encode transforms a string sequence into an encoded representation.
func (*Embeddings) InitProcessor ¶ added in v0.2.0
func (m *Embeddings) InitProcessor()
InitProcessor initializes the unknown embeddings.
type EmbeddingsConfig ¶
type EmbeddingsConfig struct { Size int OutputSize int MaxPositions int TokenTypes int WordsMapFilename string WordsMapReadOnly bool DeletePreEmbeddings bool }
EmbeddingsConfig provides configuration settings for BERT Embeddings.
type EncodeResponse ¶
type EncodeResponse struct { Data []mat.Float `json:"data"` // Took is the number of milliseconds it took the server to execute the request. Took int64 `json:"took"` }
EncodeResponse is a JSON-serializable server response for BERT "encode" requests.
type Encoder ¶
type Encoder struct { EncoderConfig *stack.Model }
Encoder is a BERT Encoder model.
func NewAlbertEncoder ¶
func NewAlbertEncoder(config EncoderConfig) *Encoder
NewAlbertEncoder returns a new variant of the BERT encoder model. In this variant the stack of N identical BERT encoder layers share the same parameters.
func NewBertEncoder ¶
func NewBertEncoder(config EncoderConfig) *Encoder
NewBertEncoder returns a new BERT encoder model composed of a stack of N identical BERT encoder layers.
type EncoderConfig ¶
type EncoderConfig struct { Size int NumOfAttentionHeads int IntermediateSize int IntermediateActivation ag.OpName NumOfLayers int }
EncoderConfig provides configuration parameters for BERT Encoder. TODO: include and use the dropout hyper-parameter
type EncoderLayer ¶
type EncoderLayer struct { nn.BaseModel MultiHeadAttention *multiheadattention.Model NormAttention *layernorm.Model FFN *stack.Model NormFFN *layernorm.Model Index int // layer index (useful for debugging) }
EncoderLayer is a BERT Encoder Layer model.
type LabelerOptionsType ¶
type LabelerOptionsType struct { MergeEntities bool `json:"mergeEntities"` // default false FilterNotEntities bool `json:"filterNotEntities"` // default false }
LabelerOptionsType is a JSON-serializable set of options for BERT "tag" (labeler) requests.
type Model ¶
type Model struct { nn.BaseModel Config Config Vocabulary *vocabulary.Vocabulary Embeddings *Embeddings Encoder *Encoder Predictor *Predictor Discriminator *Discriminator // used by "ELECTRA" training method Pooler *Pooler SeqRelationship *linear.Model SpanClassifier *SpanClassifier Classifier *Classifier }
Model implements a BERT model.
func NewDefaultBERT ¶
NewDefaultBERT returns a new model based on the original BERT architecture.
func (*Model) Discriminate ¶ added in v0.2.0
Discriminate returns 0 or 1 for each encoded element, where 1 means that the word is out of context.
func (*Model) Encode ¶ added in v0.2.0
Encode transforms a string sequence into an encoded representation.
func (*Model) Pool ¶ added in v0.2.0
Pool "pools" the model by simply taking the hidden state corresponding to the `[CLS]` token.
func (*Model) PredictMasked ¶ added in v0.2.0
PredictMasked performs a masked prediction task. It returns the predictions for indices associated to the masked nodes.
func (*Model) PredictSeqRelationship ¶ added in v0.2.0
PredictSeqRelationship predicts if the second sentence in the pair is the subsequent sentence in the original document.
func (*Model) SequenceClassification ¶ added in v0.2.0
SequenceClassification performs a single sentence-level classification, using the pooled CLS token.
type Pooler ¶
Pooler is a BERT Pooler model.
func NewPooler ¶
func NewPooler(config PoolerConfig) *Pooler
NewPooler returns a new BERT Pooler model.
type PoolerConfig ¶
PoolerConfig provides configuration settings for a BERT Pooler.
type Predictor ¶
Predictor is a BERT Predictor model.
func NewPredictor ¶
func NewPredictor(config PredictorConfig) *Predictor
NewPredictor returns a new BERT Predictor model.
type PredictorConfig ¶
type PredictorConfig struct { InputSize int HiddenSize int OutputSize int HiddenActivation ag.OpName OutputActivation ag.OpName }
PredictorConfig provides configuration settings for a BERT Predictor.
type QABody ¶
QABody is the JSON-serializable expected request body for BERT question-answering server requests.
type QuestionAnsweringResponse ¶
type QuestionAnsweringResponse struct { Answers AnswerSlice `json:"answers"` // Took is the number of milliseconds it took the server to execute the request. Took int64 `json:"took"` }
QuestionAnsweringResponse is the JSON-serializable structure for BERT question-answering server response.
type Response ¶
type Response struct { Tokens []Token `json:"tokens"` // Took is the number of milliseconds it took the server to execute the request. Took int64 `json:"took"` }
Response is the JSON-serializable server response for various BERT-related requests.
type Server ¶
type Server struct { TimeoutSeconds int MaxRequestBytes int // UnimplementedBERTServer must be embedded to have forward compatible implementations for gRPC. grpcapi.UnimplementedBERTServer // contains filtered or unexported fields }
Server contains everything needed to run a BERT server.
func (*Server) Answer ¶
func (s *Server) Answer(ctx context.Context, req *grpcapi.AnswerRequest) (*grpcapi.AnswerReply, error)
Answer handles a question-answering request over gRPC. TODO(evanmcclure@gmail.com) Reuse the gRPC message type for HTTP requests.
func (*Server) Classify ¶
func (s *Server) Classify(_ context.Context, req *grpcapi.ClassifyRequest) (*grpcapi.ClassifyReply, error)
Classify handles a classification request over gRPC. TODO(evanmcclure@gmail.com) Reuse the gRPC message type for HTTP requests.
func (*Server) ClassifyHandler ¶
func (s *Server) ClassifyHandler(w http.ResponseWriter, req *http.Request)
ClassifyHandler handles a classify request over HTTP.
func (*Server) Discriminate ¶
func (s *Server) Discriminate(ctx context.Context, req *grpcapi.DiscriminateRequest) (*grpcapi.DiscriminateReply, error)
Discriminate handles a discriminate request over gRPC. TODO(evanmcclure@gmail.com) Reuse the gRPC message type for HTTP requests.
func (*Server) DiscriminateHandler ¶
func (s *Server) DiscriminateHandler(w http.ResponseWriter, req *http.Request)
DiscriminateHandler handles a discriminate request over HTTP.
func (*Server) Encode ¶
func (s *Server) Encode(_ context.Context, req *grpcapi.EncodeRequest) (*grpcapi.EncodeReply, error)
Encode handles an encoding request over gRPC. TODO(evanmcclure@gmail.com) Reuse the gRPC message type for HTTP requests.
func (*Server) LabelerHandler ¶
func (s *Server) LabelerHandler(w http.ResponseWriter, req *http.Request)
LabelerHandler handles a labeling request over HTTP.
func (*Server) Predict ¶
func (s *Server) Predict(ctx context.Context, req *grpcapi.PredictRequest) (*grpcapi.PredictReply, error)
Predict handles a predict request over gRPC. TODO(evanmcclure@gmail.com) Reuse the gRPC message type for HTTP requests.
func (*Server) PredictHandler ¶
func (s *Server) PredictHandler(w http.ResponseWriter, req *http.Request)
PredictHandler handles a predict request over HTTP.
func (*Server) QaHandler ¶
func (s *Server) QaHandler(w http.ResponseWriter, req *http.Request)
QaHandler is the HTTP server handler function for BERT question-answering requests.
func (*Server) SentenceEncoderHandler ¶
func (s *Server) SentenceEncoderHandler(w http.ResponseWriter, req *http.Request)
SentenceEncoderHandler handles a sentence encoding request over HTTP.
func (*Server) StartDefaultServer ¶
StartDefaultServer is used to start a basic BERT HTTP server. If you want more control over the HTTP server, you can run your own HTTP router using the public handler functions.
type SpanClassifier ¶
SpanClassifier implements span classification for extractive question-answering tasks like SQuAD. It uses a linear layer to compute "span start logits" and "span end logits".
func NewSpanClassifier ¶
func NewSpanClassifier(config SpanClassifierConfig) *SpanClassifier
NewSpanClassifier returns a new BERT SpanClassifier model.
type SpanClassifierConfig ¶
type SpanClassifierConfig struct {
InputSize int
}
SpanClassifierConfig provides configuration settings for a BERT SpanClassifier.
type Token ¶
type Token struct { Text string `json:"text"` Start int `json:"start"` End int `json:"end"` Label string `json:"label"` }
Token is a JSON-serializable labeled text token.
type TokenClassifierBody ¶
type TokenClassifierBody struct { Options LabelerOptionsType `json:"options"` Text string `json:"text"` }
TokenClassifierBody provides JSON-serializable parameters for BERT "tag" (labeler) requests.
type Trainer ¶
type Trainer struct { TrainingConfig // contains filtered or unexported fields }
Trainer implements the training process for a BERT Model.
func NewTrainer ¶
func NewTrainer(model *Model, config TrainingConfig) *Trainer
NewTrainer returns a new BERT Trainer.
type TrainingConfig ¶
type TrainingConfig struct { Seed uint64 BatchSize int GradientClipping mat.Float UpdateMethod gd.MethodConfig CorpusPath string ModelPath string }
TrainingConfig provides configuration settings for a BERT Trainer.