Documentation ¶
Index ¶
- Constants
- func DecodeDataType(r io.Reader) (tensorflow.DataType, error)
- func DecodeInt64Array(r io.Reader) ([]int64, error)
- func DecodeString(r io.Reader) (string, error)
- func DecodeStringArray(r io.Reader) ([]string, error)
- func DecodeStringND(t reflect.Type, shape []int64, r io.Reader) (reflect.Value, error)
- func DecodeTensor(r io.Reader) (*tensorflow.Tensor, error)
- func DecodeTensorMap(r io.Reader) (map[string]*tensorflow.Tensor, error)
- func DerefType(dt tensorflow.DataType) tensorflow.DataType
- func EncodeDataType(w io.Writer, dt tensorflow.DataType) error
- func EncodeInt64Array(w io.Writer, arr []int64) error
- func EncodeString(w io.Writer, body string) error
- func EncodeStringArray(w io.Writer, arr []string) error
- func EncodeStringND(w io.Writer, val reflect.Value) error
- func EncodeTensor(w io.Writer, val *tensorflow.Tensor) error
- func EncodeTensorMap(w io.Writer, m map[string]*tensorflow.Tensor) error
- func LoadWeights(r io.Reader) (map[string]*tensorflow.Tensor, error)
- func ParseNodeOutput(path string) (string, int, error)
- func RefType(dt tensorflow.DataType) tensorflow.DataType
- type Batcher
- type Model
- func (model *Model) AddWeights(scale float64, weights map[string]*tensorflow.Tensor) error
- func (model Model) ApplyPrefix(op string) string
- func (model *Model) Close() error
- func (model *Model) ExportWeights(w io.Writer) error
- func (model *Model) Load(reader io.Reader) error
- func (model *Model) NumWeights() (int64, error)
- func (m *Model) Operation(path string) (*tensorflow.Operation, error)
- func (m *Model) Output(path string) (tensorflow.Output, error)
- func (model *Model) Save(writer io.Writer) error
- func (model *Model) SetWeights(weights map[string]*tensorflow.Tensor) error
- func (model *Model) WeightsMap() (map[string]*tensorflow.Tensor, error)
Constants ¶
const ( SaverDefName = "saver_def.pb" GraphDefName = "graph_def.pb" SavedModelName = "saved_model" ModelMetaName = "model_meta.pb" // FilePerm is the file permission all the model files use. FilePerm = 0600 )
const ( PokAssignOp = "pok/update/assign" PokAssignAddOp = "pok/update/assign_add" PokVarPrefix = "pok/update/var/" PokVarScaleOp = "pok/update/scale" )
Variables ¶
This section is empty.
Functions ¶
func DecodeDataType ¶
func DecodeDataType(r io.Reader) (tensorflow.DataType, error)
DecodeDataType decodes the data type from the reader.
func DecodeInt64Array ¶
DecodeInt64Array decodes an int64 array from the reader. Format is: int64(num) + repeated (int64)
func DecodeString ¶
DecodeString reads a string from a reader. Format is: int64(len) + body
func DecodeStringArray ¶
DecodeStringArray decodes a string array from the reader. Format is: int64(num strings) + repeated (string)
func DecodeStringND ¶
func DecodeTensor ¶
func DecodeTensor(r io.Reader) (*tensorflow.Tensor, error)
DecodeTensor decodes a tensor from the reader and returns it. See EncodeTensor.
func DecodeTensorMap ¶
DecodeTensorMap decodes a map[string]*tensorflow.Tensor from the reader and returns it. See EncodeTensorMap.
func DerefType ¶
func DerefType(dt tensorflow.DataType) tensorflow.DataType
func EncodeDataType ¶
func EncodeDataType(w io.Writer, dt tensorflow.DataType) error
EncodeDataType writes the data type to the writer.
func EncodeInt64Array ¶
EncodeInt64Array encodes an int64 array into the writer. Format is: int64(num) + repeated (int64)
func EncodeString ¶
EncodeString encodes a string into the writer.
func EncodeStringArray ¶
EncodeStringArray encodes a string array into the writer. Format is: int64(num strings) + repeated (string)
func EncodeTensor ¶
func EncodeTensor(w io.Writer, val *tensorflow.Tensor) error
EncodeTensor encodes a tensor into the writer. See DecodeTensor.
func EncodeTensorMap ¶
EncodeTensorMap encodes a map[string]*tensorflow.Tensor into the writer. See DecodeTensorMap.
func LoadWeights ¶
func ParseNodeOutput ¶
ParseNodeOutput returns the node name and output index when given a "<name>:<output #>" pair.
func RefType ¶
func RefType(dt tensorflow.DataType) tensorflow.DataType
Types ¶
type Batcher ¶
type Batcher struct {
// contains filtered or unexported fields
}
Batcher takes a fixed number of tensors and concatenates them into one larger tensor. This is mostly used for creating mini-batches of tensors for SGD.
func NewTensorBatcher ¶
NewTensorBatcher returns a new Batcher with the specified params. Shape should have exactly one dimension that is unspecified (-1) and the tensors will be concatenated along that axis.
func (*Batcher) Batch ¶
func (m *Batcher) Batch(values []*tensorflow.Tensor) (*tensorflow.Tensor, error)
Batch takes in values and returns a single output tensor that has all the values concatenated.
type Model ¶
type Model struct { Graph *tensorflow.Graph Session *tensorflow.Session SaverDef tensorflowpb.SaverDef Meta clientpb.ModelMeta Prefix string }
func LoadModel ¶
LoadModel loads a model from a provided .tar.gz io stream. The returned session must be closed when done using.
The .tar.gz file should contain the following files: - saver_def.pb - graph_def.pb - checkpoint - saved_model-<iteration>.{index,meta,data-*}
func (*Model) AddWeights ¶
AddWeights imports weights and then adds them to the current weights, multiplied by the scale factor scale.
func (Model) ApplyPrefix ¶
func (*Model) NumWeights ¶
NumWeights returns the total number of weights the model has.
func (*Model) SetWeights ¶
func (model *Model) SetWeights(weights map[string]*tensorflow.Tensor) error
SetWeights sets the weights of the model.
func (*Model) WeightsMap ¶
func (model *Model) WeightsMap() (map[string]*tensorflow.Tensor, error)