model

package
v0.26.28 Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Mar 22, 2022 License: MIT Imports: 23 Imported by: 0

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func GetTfkgPythonCode added in v0.26.28

func GetTfkgPythonCode(customDefinitions []string) string

This code is generated automatically from model/tfkg_model.py by running `go generate ./...`. Do not edit it manually.

func GetVanillaPythonCode added in v0.26.28

func GetVanillaPythonCode() string

Types

type CompileConfig added in v0.26.28

type CompileConfig struct {
	Loss             Loss
	Optimizer        optimizer.Optimizer
	ModelInfoSaveDir string
	BatchSize        int
	CpuInference     bool
}

type EvaluateConfig

type EvaluateConfig struct {
	BatchSize int
	PreFetch  int
	Metrics   []metric.Metric
	Callbacks []callback.Callback
	Verbose   int
}

type FitConfig

type FitConfig struct {
	Epochs     int
	BatchSize  int
	Validation bool
	PreFetch   int
	Metrics    []metric.Metric
	Callbacks  []callback.Callback
	Verbose    int
}

type Loss added in v0.26.28

type Loss string
var (
	Version = "0.2.6"

	LossBinaryCrossentropy            Loss = "binary_crossentropy"
	LossSparseCategoricalCrossentropy Loss = "sparse_categorical_crossentropy"
	LossMSE                           Loss = "mse"
)

type TfkgModel added in v0.26.28

type TfkgModel struct {
	// contains filtered or unexported fields
}

func LoadModel

func LoadModel(
	errorHandler *cberrors.ErrorsContainer,
	logger *cblog.Logger,
	dir string,
	sessionOptions ...*for_core_protos_go_proto.ConfigProto,
) (*TfkgModel, error)

func LoadVanillaModel added in v0.26.28

func LoadVanillaModel(
	errorHandler *cberrors.ErrorsContainer,
	logger *cblog.Logger,
	dir string,
	loss Loss,
	optimizer optimizer.Optimizer,
	sessionOptions ...*for_core_protos_go_proto.ConfigProto,
) (*TfkgModel, error)

func NewModel added in v0.26.28

func NewModel(
	logger *cblog.Logger,
	errorHandler *cberrors.ErrorsContainer,
	output layer.Layer,
) *TfkgModel

func NewSequentialModel

func NewSequentialModel(
	logger *cblog.Logger,
	errorHandler *cberrors.ErrorsContainer,
	input layer.Layer,
	layers ...layer.Layer,
) *TfkgModel

func (*TfkgModel) CompileAndLoad added in v0.26.28

func (m *TfkgModel) CompileAndLoad(config CompileConfig, sessionOptions ...*for_core_protos_go_proto.ConfigProto) error

func (*TfkgModel) Evaluate added in v0.26.28

func (m *TfkgModel) Evaluate(
	mode data.GeneratorMode,
	dataset data.Dataset,
	config EvaluateConfig,
)

func (*TfkgModel) Fit added in v0.26.28

func (m *TfkgModel) Fit(
	dataset data.Dataset,
	config FitConfig,
)

func (*TfkgModel) GetLayerWeights added in v0.26.28

func (m *TfkgModel) GetLayerWeights(layerName string) ([]*tf.Tensor, error)

func (*TfkgModel) GetModelWeights added in v0.26.28

func (m *TfkgModel) GetModelWeights() ([]*tf.Tensor, error)

func (*TfkgModel) Predict added in v0.26.28

func (m *TfkgModel) Predict(inputs ...*tf.Tensor) (*tf.Tensor, error)

func (*TfkgModel) Save added in v0.26.28

func (m *TfkgModel) Save(dir string) error

func (*TfkgModel) SetModelWeights added in v0.26.28

func (m *TfkgModel) SetModelWeights(weights []*tf.Tensor) error

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL