tpu

package
v2.12.0 Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Dec 8, 2023 License: BSD-3-Clause Imports: 11 Imported by: 0

Documentation

Index

Constants

This section is empty.

Variables

View Source
var (
	CompilationResultProto_ErrorCode_name = map[int32]string{
		0: "UNKNOWN",
		1: "OUT_OF_MEMORY",
	}
	CompilationResultProto_ErrorCode_value = map[string]int32{
		"UNKNOWN":       0,
		"OUT_OF_MEMORY": 1,
	}
)

Enum value maps for CompilationResultProto_ErrorCode.

View Source
var (
	TPUCompileMetadataProto_Arg_Kind_name = map[int32]string{
		0: "INVALID",
		1: "PARAMETER",
		2: "VARIABLE",
		3: "GUARANTEED_CONSTANT",
	}
	TPUCompileMetadataProto_Arg_Kind_value = map[string]int32{
		"INVALID":             0,
		"PARAMETER":           1,
		"VARIABLE":            2,
		"GUARANTEED_CONSTANT": 3,
	}
)

Enum value maps for TPUCompileMetadataProto_Arg_Kind.

View Source
var (
	TPUCompileMetadataProto_Arg_EnableXlaSharding_name = map[int32]string{
		0: "DISALLOWED",
		1: "TENTATIVE",
		2: "ALLOWED",
	}
	TPUCompileMetadataProto_Arg_EnableXlaSharding_value = map[string]int32{
		"DISALLOWED": 0,
		"TENTATIVE":  1,
		"ALLOWED":    2,
	}
)

Enum value maps for TPUCompileMetadataProto_Arg_EnableXlaSharding.

View Source
var (
	TPUCompileOptions_Precision_name = map[int32]string{
		0: "DEFAULT",
		1: "BFLOAT16",
		2: "FLOAT32",
		3: "TENSOR_FLOAT32",
	}
	TPUCompileOptions_Precision_value = map[string]int32{
		"DEFAULT":        0,
		"BFLOAT16":       1,
		"FLOAT32":        2,
		"TENSOR_FLOAT32": 3,
	}
)

Enum value maps for TPUCompileOptions_Precision.

View Source
var (
	GradientAccumulationStatus_Status_name = map[int32]string{
		0: "UNSPECIFIED",
		1: "ENABLED",
		2: "DISABLED",
	}
	GradientAccumulationStatus_Status_value = map[string]int32{
		"UNSPECIFIED": 0,
		"ENABLED":     1,
		"DISABLED":    2,
	}
)

Enum value maps for GradientAccumulationStatus_Status.

View Source
var (
	LowDimensionalPackingStatus_Status_name = map[int32]string{
		0: "UNSPECIFIED",
		1: "ENABLED",
		2: "DISABLED",
	}
	LowDimensionalPackingStatus_Status_value = map[string]int32{
		"UNSPECIFIED": 0,
		"ENABLED":     1,
		"DISABLED":    2,
	}
)

Enum value maps for LowDimensionalPackingStatus_Status.

View Source
var (
	HotIdReplicationConfiguration_Status_name = map[int32]string{
		0: "UNSPECIFIED",
		1: "ENABLED",
		2: "DISABLED",
		3: "MIGRATION_ONLY",
	}
	HotIdReplicationConfiguration_Status_value = map[string]int32{
		"UNSPECIFIED":    0,
		"ENABLED":        1,
		"DISABLED":       2,
		"MIGRATION_ONLY": 3,
	}
)

Enum value maps for HotIdReplicationConfiguration_Status.

View Source
var (
	TPUHardwareFeature_EmbeddingFeature_name = map[int32]string{
		0: "UNSUPPORTED",
		1: "V1",
		2: "V2",
	}
	TPUHardwareFeature_EmbeddingFeature_value = map[string]int32{
		"UNSUPPORTED": 0,
		"V1":          1,
		"V2":          2,
	}
)

Enum value maps for TPUHardwareFeature_EmbeddingFeature.

View Source
var (
	TPUEmbeddingConfiguration_Mode_name = map[int32]string{
		0: "UNSPECIFIED",
		1: "INFERENCE",
		2: "TRAINING",
		3: "BACKWARD_PASS_ONLY",
	}
	TPUEmbeddingConfiguration_Mode_value = map[string]int32{
		"UNSPECIFIED":        0,
		"INFERENCE":          1,
		"TRAINING":           2,
		"BACKWARD_PASS_ONLY": 3,
	}
)

Enum value maps for TPUEmbeddingConfiguration_Mode.

View Source
var (
	TPUEmbeddingConfiguration_ShardingStrategy_name = map[int32]string{
		0: "DIV_DEFAULT",
		1: "MOD",
	}
	TPUEmbeddingConfiguration_ShardingStrategy_value = map[string]int32{
		"DIV_DEFAULT": 0,
		"MOD":         1,
	}
)

Enum value maps for TPUEmbeddingConfiguration_ShardingStrategy.

View Source
var File_tensorflow_core_protobuf_tpu_compilation_result_proto protoreflect.FileDescriptor
View Source
var File_tensorflow_core_protobuf_tpu_compile_metadata_proto protoreflect.FileDescriptor
View Source
var File_tensorflow_core_protobuf_tpu_dynamic_padding_proto protoreflect.FileDescriptor
View Source
var File_tensorflow_core_protobuf_tpu_optimization_parameters_proto protoreflect.FileDescriptor
View Source
var File_tensorflow_core_protobuf_tpu_topology_proto protoreflect.FileDescriptor
View Source
var File_tensorflow_core_protobuf_tpu_tpu_embedding_configuration_proto protoreflect.FileDescriptor

Functions

This section is empty.

Types

type AdadeltaParameters

type AdadeltaParameters struct {
	Rho     float32 `protobuf:"fixed32,1,opt,name=rho,proto3" json:"rho,omitempty"`
	Epsilon float32 `protobuf:"fixed32,2,opt,name=epsilon,proto3" json:"epsilon,omitempty"`
	// contains filtered or unexported fields
}

https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adadelta https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L933

func (*AdadeltaParameters) Descriptor deprecated

func (*AdadeltaParameters) Descriptor() ([]byte, []int)

Deprecated: Use AdadeltaParameters.ProtoReflect.Descriptor instead.

func (*AdadeltaParameters) GetEpsilon

func (x *AdadeltaParameters) GetEpsilon() float32

func (*AdadeltaParameters) GetRho

func (x *AdadeltaParameters) GetRho() float32

func (*AdadeltaParameters) ProtoMessage

func (*AdadeltaParameters) ProtoMessage()

func (*AdadeltaParameters) ProtoReflect

func (x *AdadeltaParameters) ProtoReflect() protoreflect.Message

func (*AdadeltaParameters) Reset

func (x *AdadeltaParameters) Reset()

func (*AdadeltaParameters) String

func (x *AdadeltaParameters) String() string

type AdagradMomentumParameters

type AdagradMomentumParameters struct {

	// Moving average parameter for the momentum accumulator.
	Momentum float32 `protobuf:"fixed32,1,opt,name=momentum,proto3" json:"momentum,omitempty"`
	// Whether to use the Nesterov variant of momentum.
	UseNesterov bool `protobuf:"varint,2,opt,name=use_nesterov,json=useNesterov,proto3" json:"use_nesterov,omitempty"`
	// Exponent for the gradient^2 accumulator.
	Exponent float32 `protobuf:"fixed32,3,opt,name=exponent,proto3" json:"exponent,omitempty"`
	// Moving average parameter for the gradient^2 accumulator.
	Beta2 float32 `protobuf:"fixed32,4,opt,name=beta2,proto3" json:"beta2,omitempty"`
	// Offset added to the Adagrad accumulator.
	Epsilon float32 `protobuf:"fixed32,5,opt,name=epsilon,proto3" json:"epsilon,omitempty"`
	// contains filtered or unexported fields
}

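This optimizer combines the Adagrad and Momentum update rules. Below, $\beta_2$ and $\epsilon$ are the beta2 and epsilon fields, and lr is the learning rate:

$$
\mathrm{accum}_{new} =
\begin{cases}
\mathrm{accum}_{old} + \mathrm{grad}^2 & \text{if } \beta_2 = 1.0 \\
\beta_2 \cdot \mathrm{accum}_{old} + (1 - \beta_2) \cdot \mathrm{grad}^2 & \text{otherwise}
\end{cases}
$$

$$
\mathrm{accum\_with\_exponent} = (\mathrm{accum}_{new} + \epsilon)^{-1.0 / \mathrm{exponent}}
$$

$$
\mathrm{mom\_accum}_{new} = \mathrm{momentum} \cdot \mathrm{mom\_accum}_{old} + \mathrm{accum\_with\_exponent}
$$

$$
\mathrm{update} =
\begin{cases}
\mathrm{momentum} \cdot \mathrm{mom\_accum}_{new} + \mathrm{accum\_with\_exponent} & \text{if use\_nesterov} \\
\mathrm{mom\_accum}_{new} & \text{otherwise}
\end{cases}
$$

$$
\mathrm{var}_{new} = \mathrm{var}_{old} - \mathrm{lr} \cdot \mathrm{grad} \cdot \mathrm{update}
$$

Algorithm described in https://arxiv.org/abs/2002.11803.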

func (*AdagradMomentumParameters) Descriptor deprecated

func (*AdagradMomentumParameters) Descriptor() ([]byte, []int)

Deprecated: Use AdagradMomentumParameters.ProtoReflect.Descriptor instead.

func (*AdagradMomentumParameters) GetBeta2

func (x *AdagradMomentumParameters) GetBeta2() float32

func (*AdagradMomentumParameters) GetEpsilon

func (x *AdagradMomentumParameters) GetEpsilon() float32

func (*AdagradMomentumParameters) GetExponent

func (x *AdagradMomentumParameters) GetExponent() float32

func (*AdagradMomentumParameters) GetMomentum

func (x *AdagradMomentumParameters) GetMomentum() float32

func (*AdagradMomentumParameters) GetUseNesterov

func (x *AdagradMomentumParameters) GetUseNesterov() bool

func (*AdagradMomentumParameters) ProtoMessage

func (*AdagradMomentumParameters) ProtoMessage()

func (*AdagradMomentumParameters) ProtoReflect

func (*AdagradMomentumParameters) Reset

func (x *AdagradMomentumParameters) Reset()

func (*AdagradMomentumParameters) String

func (x *AdagradMomentumParameters) String() string

type AdagradParameters

type AdagradParameters struct {
	// contains filtered or unexported fields
}

https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adagrad https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L1634

func (*AdagradParameters) Descriptor deprecated

func (*AdagradParameters) Descriptor() ([]byte, []int)

Deprecated: Use AdagradParameters.ProtoReflect.Descriptor instead.

func (*AdagradParameters) ProtoMessage

func (*AdagradParameters) ProtoMessage()

func (*AdagradParameters) ProtoReflect

func (x *AdagradParameters) ProtoReflect() protoreflect.Message

func (*AdagradParameters) Reset

func (x *AdagradParameters) Reset()

func (*AdagradParameters) String

func (x *AdagradParameters) String() string

type AdamParameters

type AdamParameters struct {
	Beta1            float32 `protobuf:"fixed32,3,opt,name=beta1,proto3" json:"beta1,omitempty"`
	Beta2            float32 `protobuf:"fixed32,4,opt,name=beta2,proto3" json:"beta2,omitempty"`
	Epsilon          float32 `protobuf:"fixed32,5,opt,name=epsilon,proto3" json:"epsilon,omitempty"`
	UseNonLazyAdam   bool    `protobuf:"varint,8,opt,name=use_non_lazy_adam,json=useNonLazyAdam,proto3" json:"use_non_lazy_adam,omitempty"`
	UseSumInsideSqrt bool    `protobuf:"varint,10,opt,name=use_sum_inside_sqrt,json=useSumInsideSqrt,proto3" json:"use_sum_inside_sqrt,omitempty"`
	// contains filtered or unexported fields
}

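The Adam optimizer does not implement hyper-parameter updates due to hardware limitations. Use the dynamic learning rate feature instead, and set the learning rate to $\mathrm{learning\_rate} \cdot \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$. Here, learning_rate is the user-specified learning rate, $\beta_1$ and $\beta_2$ are the beta1 and beta2 fields, and $t$ is the current timestep.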

https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L32

Note that the code by default implements the lazy version of Adam (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/LazyAdamOptimizer). If the use_non_lazy_adam parameter is set, it instead implements the standard version of Adam, which updates all parameters in the embedding table, even entries that are not used in the current minibatch (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer). When use_non_lazy_adam is enabled, gradient accumulation must also be enabled to get correct results; otherwise a warning is printed (this may become an error in the future). If use_sum_inside_sqrt is set, the Adam variable update formula changes from $m / (\sqrt{v} + \epsilon)$ to $m / \sqrt{v + \epsilon^2}$. This option improves the performance of TPU training and is not expected to harm model quality.

func (*AdamParameters) Descriptor deprecated

func (*AdamParameters) Descriptor() ([]byte, []int)

Deprecated: Use AdamParameters.ProtoReflect.Descriptor instead.

func (*AdamParameters) GetBeta1

func (x *AdamParameters) GetBeta1() float32

func (*AdamParameters) GetBeta2

func (x *AdamParameters) GetBeta2() float32

func (*AdamParameters) GetEpsilon

func (x *AdamParameters) GetEpsilon() float32

func (*AdamParameters) GetUseNonLazyAdam

func (x *AdamParameters) GetUseNonLazyAdam() bool

func (*AdamParameters) GetUseSumInsideSqrt

func (x *AdamParameters) GetUseSumInsideSqrt() bool

func (*AdamParameters) ProtoMessage

func (*AdamParameters) ProtoMessage()

func (*AdamParameters) ProtoReflect

func (x *AdamParameters) ProtoReflect() protoreflect.Message

func (*AdamParameters) Reset

func (x *AdamParameters) Reset()

func (*AdamParameters) String

func (x *AdamParameters) String() string

type AssignParameters

type AssignParameters struct {
	// contains filtered or unexported fields
}

Optimizer that just sets the variable to the value of the gradient. To be correct, this requires either gradient accumulation (to sum the values of a computed expression across the samples) or deduplication of IDs within a single host (to assign the value from an arbitrary sample).

func (*AssignParameters) Descriptor deprecated

func (*AssignParameters) Descriptor() ([]byte, []int)

Deprecated: Use AssignParameters.ProtoReflect.Descriptor instead.

func (*AssignParameters) ProtoMessage

func (*AssignParameters) ProtoMessage()

func (*AssignParameters) ProtoReflect

func (x *AssignParameters) ProtoReflect() protoreflect.Message

func (*AssignParameters) Reset

func (x *AssignParameters) Reset()

func (*AssignParameters) String

func (x *AssignParameters) String() string

type BoundedAdagradParameters

type BoundedAdagradParameters struct {

	// Whether to use the updated or the old value of the accumulator when
	// computing the effective learning rate. When update_accumulator_first is set
	// to True, the updated value of the accumulator is used.
	UpdateAccumulatorFirst bool `` /* 130-byte string literal not displayed */
	// The max_var_update value to use. Set value to 0 (default) to disable using
	// max_var_update to clip the gradient.
	MaxVarUpdate float32 `protobuf:"fixed32,2,opt,name=max_var_update,json=maxVarUpdate,proto3" json:"max_var_update,omitempty"`
	// The maximum value of the accumulator. Set max_accumulator to 0 (default)
	// to disable using max_accumulator to clip the accumulator.
	MaxAccumulator float32 `protobuf:"fixed32,3,opt,name=max_accumulator,json=maxAccumulator,proto3" json:"max_accumulator,omitempty"`
	// contains filtered or unexported fields
}

Algorithm described in http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.

func (*BoundedAdagradParameters) Descriptor deprecated

func (*BoundedAdagradParameters) Descriptor() ([]byte, []int)

Deprecated: Use BoundedAdagradParameters.ProtoReflect.Descriptor instead.

func (*BoundedAdagradParameters) GetMaxAccumulator

func (x *BoundedAdagradParameters) GetMaxAccumulator() float32

func (*BoundedAdagradParameters) GetMaxVarUpdate

func (x *BoundedAdagradParameters) GetMaxVarUpdate() float32

func (*BoundedAdagradParameters) GetUpdateAccumulatorFirst

func (x *BoundedAdagradParameters) GetUpdateAccumulatorFirst() bool

func (*BoundedAdagradParameters) ProtoMessage

func (*BoundedAdagradParameters) ProtoMessage()

func (*BoundedAdagradParameters) ProtoReflect

func (x *BoundedAdagradParameters) ProtoReflect() protoreflect.Message

func (*BoundedAdagradParameters) Reset

func (x *BoundedAdagradParameters) Reset()

func (*BoundedAdagradParameters) String

func (x *BoundedAdagradParameters) String() string

type CenteredRmsPropParameters

type CenteredRmsPropParameters struct {
	Rho      float32 `protobuf:"fixed32,1,opt,name=rho,proto3" json:"rho,omitempty"`
	Momentum float32 `protobuf:"fixed32,2,opt,name=momentum,proto3" json:"momentum,omitempty"`
	Epsilon  float32 `protobuf:"fixed32,3,opt,name=epsilon,proto3" json:"epsilon,omitempty"`
	// contains filtered or unexported fields
}

https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L4358

func (*CenteredRmsPropParameters) Descriptor deprecated

func (*CenteredRmsPropParameters) Descriptor() ([]byte, []int)

Deprecated: Use CenteredRmsPropParameters.ProtoReflect.Descriptor instead.

func (*CenteredRmsPropParameters) GetEpsilon

func (x *CenteredRmsPropParameters) GetEpsilon() float32

func (*CenteredRmsPropParameters) GetMomentum

func (x *CenteredRmsPropParameters) GetMomentum() float32

func (*CenteredRmsPropParameters) GetRho

func (x *CenteredRmsPropParameters) GetRho() float32

func (*CenteredRmsPropParameters) ProtoMessage

func (*CenteredRmsPropParameters) ProtoMessage()

func (*CenteredRmsPropParameters) ProtoReflect

func (*CenteredRmsPropParameters) Reset

func (x *CenteredRmsPropParameters) Reset()

func (*CenteredRmsPropParameters) String

func (x *CenteredRmsPropParameters) String() string

type ClippingLimits

type ClippingLimits struct {
	Lower *wrapperspb.FloatValue `protobuf:"bytes,1,opt,name=lower,proto3" json:"lower,omitempty"` // -inf if not set
	Upper *wrapperspb.FloatValue `protobuf:"bytes,2,opt,name=upper,proto3" json:"upper,omitempty"` // +inf if not set
	// contains filtered or unexported fields
}

func (*ClippingLimits) Descriptor deprecated

func (*ClippingLimits) Descriptor() ([]byte, []int)

Deprecated: Use ClippingLimits.ProtoReflect.Descriptor instead.

func (*ClippingLimits) GetLower

func (x *ClippingLimits) GetLower() *wrapperspb.FloatValue

func (*ClippingLimits) GetUpper

func (x *ClippingLimits) GetUpper() *wrapperspb.FloatValue

func (*ClippingLimits) ProtoMessage

func (*ClippingLimits) ProtoMessage()

func (*ClippingLimits) ProtoReflect

func (x *ClippingLimits) ProtoReflect() protoreflect.Message

func (*ClippingLimits) Reset

func (x *ClippingLimits) Reset()

func (*ClippingLimits) String

func (x *ClippingLimits) String() string

type CompilationResultProto

type CompilationResultProto struct {

	// The error message, if any, returned during compilation.
	StatusCode         protobuf.Code `protobuf:"varint,1,opt,name=status_code,json=statusCode,proto3,enum=tensorflow.error.Code" json:"status_code,omitempty"`
	StatusErrorMessage string        `protobuf:"bytes,2,opt,name=status_error_message,json=statusErrorMessage,proto3" json:"status_error_message,omitempty"`
	// HLO proto.
	HloProtos []*service.HloProto              `protobuf:"bytes,3,rep,name=hlo_protos,json=hloProtos,proto3" json:"hlo_protos,omitempty"`
	ErrorCode CompilationResultProto_ErrorCode `` /* 142-byte string literal not displayed */
	// contains filtered or unexported fields
}

Describes the result of a TPU compilation. This message is also used as the TPU compilation result status payload, with type URI "type.googleapis.com/tensorflow.tpu.CompilationResultProto".

func (*CompilationResultProto) Descriptor deprecated

func (*CompilationResultProto) Descriptor() ([]byte, []int)

Deprecated: Use CompilationResultProto.ProtoReflect.Descriptor instead.

func (*CompilationResultProto) GetErrorCode

func (*CompilationResultProto) GetHloProtos

func (x *CompilationResultProto) GetHloProtos() []*service.HloProto

func (*CompilationResultProto) GetStatusCode

func (x *CompilationResultProto) GetStatusCode() protobuf.Code

func (*CompilationResultProto) GetStatusErrorMessage

func (x *CompilationResultProto) GetStatusErrorMessage() string

func (*CompilationResultProto) ProtoMessage

func (*CompilationResultProto) ProtoMessage()

func (*CompilationResultProto) ProtoReflect

func (x *CompilationResultProto) ProtoReflect() protoreflect.Message

func (*CompilationResultProto) Reset

func (x *CompilationResultProto) Reset()

func (*CompilationResultProto) String

func (x *CompilationResultProto) String() string

type CompilationResultProto_ErrorCode

type CompilationResultProto_ErrorCode int32
const (
	CompilationResultProto_UNKNOWN       CompilationResultProto_ErrorCode = 0
	CompilationResultProto_OUT_OF_MEMORY CompilationResultProto_ErrorCode = 1
)

func (CompilationResultProto_ErrorCode) Descriptor

func (CompilationResultProto_ErrorCode) Enum

func (CompilationResultProto_ErrorCode) EnumDescriptor deprecated

func (CompilationResultProto_ErrorCode) EnumDescriptor() ([]byte, []int)

Deprecated: Use CompilationResultProto_ErrorCode.Descriptor instead.

func (CompilationResultProto_ErrorCode) Number

func (CompilationResultProto_ErrorCode) String

func (CompilationResultProto_ErrorCode) Type

type DynamicLearningRate

type DynamicLearningRate struct {

	// For tables where learning rates are dynamically computed and communicated
	// to the TPU embedding program, a tag must be specified for the learning
	// rate.
	//
	// The tag must be a non-negative  integer. The total number of unique tags
	// must be less than or equal to the number of tables in the TPU embedding
	// configuration (a table does not specify any tag if it uses a constant
	// learning rate, and specifies exactly one tag if it uses dynamic learning
	// rates).
	//
	// All tags in the range [0, number_of_unique_tags) must be present in the TPU
	// embedding configuration, i.e. a tag cannot be skipped if a different tag
	// numerically greater than it is used in the configuration.
	//
	// If multiple tables specify the same tag, they *MUST* have
	// the same dynamic learning rate, for example, their dynamic learning rate
	// could be computed by the same TensorFlow sub-graph. The partitioning of the
	// embedding layer would be more optimal if the number_of_unique_tags is as
	// *LOW* as possible, i.e., if many tables share the same tag.
	//
	// The learning_rate input of the SendTPUEmbeddingGradients op is used to
	// communicate dynamic learning rates to the TPU embedding program.
	// The learning_rate input is a list of scalars where the size of the list is
	// equal to the number of unique tags. The learning rate associated with a
	// particular tag is specified by populating its corresponding index in the
	// list of learning_rate scalars.
	Tag int32 `protobuf:"varint,1,opt,name=tag,proto3" json:"tag,omitempty"`
	// contains filtered or unexported fields
}

Dynamic learning rate specification in the TPUEmbeddingConfiguration. The actual learning rates are provided as a scalar input list to the SendTPUEmbeddingGradients op, indexed by the tag specified in this message.

func (*DynamicLearningRate) Descriptor deprecated

func (*DynamicLearningRate) Descriptor() ([]byte, []int)

Deprecated: Use DynamicLearningRate.ProtoReflect.Descriptor instead.

func (*DynamicLearningRate) GetTag

func (x *DynamicLearningRate) GetTag() int32

func (*DynamicLearningRate) ProtoMessage

func (*DynamicLearningRate) ProtoMessage()

func (*DynamicLearningRate) ProtoReflect

func (x *DynamicLearningRate) ProtoReflect() protoreflect.Message

func (*DynamicLearningRate) Reset

func (x *DynamicLearningRate) Reset()

func (*DynamicLearningRate) String

func (x *DynamicLearningRate) String() string

type FrequencyEstimatorParameters

type FrequencyEstimatorParameters struct {

	// Learning rate between (0, 1) that is used to update the array D.
	Tau float32 `protobuf:"fixed32,1,opt,name=tau,proto3" json:"tau,omitempty"`
	// Maximum value of delta: difference between the current global step and the
	// last global step at which the row was sampled.
	MaxDelta float32 `protobuf:"fixed32,2,opt,name=max_delta,json=maxDelta,proto3" json:"max_delta,omitempty"`
	// Threshold used to determine whether the current update is an outlier.
	OutlierThreshold float32 `protobuf:"fixed32,3,opt,name=outlier_threshold,json=outlierThreshold,proto3" json:"outlier_threshold,omitempty"`
	// The weight exponent used to transform the estimated delta into weights.
	// The transformation function is: (delta / max_delta) ^ (weight_exponent)
	WeightExponent float32 `protobuf:"fixed32,4,opt,name=weight_exponent,json=weightExponent,proto3" json:"weight_exponent,omitempty"`
	// contains filtered or unexported fields
}

Estimator for the frequency of updates to a lookup table. It maintains an array (tf.Variable) D, where each element records the average number of global steps between two consecutive batches that hit the corresponding bucket. Once an item with bucket id i is sampled, D[i] is updated by:

D[i] <- D[i] * (1 - tau) + delta[i] * tau,

where tau is a learning rate between 0 and 1 (exclusive), and

delta[i] = current global step - the last global step at which i was sampled.

The estimated frequency (sampling rate in a batch) is thus 1 / D[i].

Elements in D are initialized with a large value max_delta. delta[i] will also be capped by this value.

The exact sequence of operations used in the optimizer is shown below. In it, last_hit_step[i] is a tf.Variable that holds the last global step at which i was sampled.

delta = global_step - last_hit_step[i]
clipped_delta = min(delta, params.max_delta)
is_outlier = (delta >= params.outlier_threshold * D[i])
D[i] <- is_outlier ? clipped_delta
                   : D[i] * (1 - params.tau) + clipped_delta * params.tau
last_hit_step[i] <- global_step

func (*FrequencyEstimatorParameters) Descriptor deprecated

func (*FrequencyEstimatorParameters) Descriptor() ([]byte, []int)

Deprecated: Use FrequencyEstimatorParameters.ProtoReflect.Descriptor instead.

func (*FrequencyEstimatorParameters) GetMaxDelta

func (x *FrequencyEstimatorParameters) GetMaxDelta() float32

func (*FrequencyEstimatorParameters) GetOutlierThreshold

func (x *FrequencyEstimatorParameters) GetOutlierThreshold() float32

func (*FrequencyEstimatorParameters) GetTau

func (*FrequencyEstimatorParameters) GetWeightExponent

func (x *FrequencyEstimatorParameters) GetWeightExponent() float32

func (*FrequencyEstimatorParameters) ProtoMessage

func (*FrequencyEstimatorParameters) ProtoMessage()

func (*FrequencyEstimatorParameters) ProtoReflect

func (*FrequencyEstimatorParameters) Reset

func (x *FrequencyEstimatorParameters) Reset()

func (*FrequencyEstimatorParameters) String

type FtrlParameters

type FtrlParameters struct {
	L1                 float32 `protobuf:"fixed32,1,opt,name=l1,proto3" json:"l1,omitempty"`
	L2                 float32 `protobuf:"fixed32,2,opt,name=l2,proto3" json:"l2,omitempty"`
	LrPower            float32 `protobuf:"fixed32,3,opt,name=lr_power,json=lrPower,proto3" json:"lr_power,omitempty"`
	Beta               float32 `protobuf:"fixed32,7,opt,name=beta,proto3" json:"beta,omitempty"`
	MultiplyLinearByLr bool    `protobuf:"varint,6,opt,name=multiply_linear_by_lr,json=multiplyLinearByLr,proto3" json:"multiply_linear_by_lr,omitempty"`
	// Previously, allow_zero_accumulator parameter changed some internal formulas
	// to allow zero and near-zero accumulator values at the cost of some
	// performance. The current implementation ignores this parameter; zero or
	// near-zero accumulator values are now always supported.
	//
	// Deprecated: Marked as deprecated in tensorflow/core/protobuf/tpu/optimization_parameters.proto.
	AllowZeroAccumulator bool `protobuf:"varint,8,opt,name=allow_zero_accumulator,json=allowZeroAccumulator,proto3" json:"allow_zero_accumulator,omitempty"`
	// contains filtered or unexported fields
}

https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L2646

The hyperparameters for FTRL are the same as for the Keras implementation, with some additions. The "beta" parameter matches the behavior described in the second link above; "beta" / (2 * learning rate) should be added to "l2" to get equivalent behavior in the other TensorFlow implementations of this optimizer. When the multiply_linear_by_lr field is set to true, a modified formula is used for FTRL that treats the "linear" accumulator as being pre-multiplied by the learning rate (i.e., the accumulator named "linear" actually stores "linear * learning_rate"). Other than checkpoint compatibility, this is mathematically equivalent for a static learning rate; for a dynamic learning rate, it is nearly the same as long as the learning rate does not change quickly. The benefit of setting multiply_linear_by_lr to true is that the modified formula handles zero and near-zero learning rates without producing NaNs, improving flexibility for learning rate ramp-up.

func (*FtrlParameters) Descriptor deprecated

func (*FtrlParameters) Descriptor() ([]byte, []int)

Deprecated: Use FtrlParameters.ProtoReflect.Descriptor instead.

func (*FtrlParameters) GetAllowZeroAccumulator deprecated

func (x *FtrlParameters) GetAllowZeroAccumulator() bool

Deprecated: Marked as deprecated in tensorflow/core/protobuf/tpu/optimization_parameters.proto.

func (*FtrlParameters) GetBeta

func (x *FtrlParameters) GetBeta() float32

func (*FtrlParameters) GetL1

func (x *FtrlParameters) GetL1() float32

func (*FtrlParameters) GetL2

func (x *FtrlParameters) GetL2() float32

func (*FtrlParameters) GetLrPower

func (x *FtrlParameters) GetLrPower() float32

func (*FtrlParameters) GetMultiplyLinearByLr

func (x *FtrlParameters) GetMultiplyLinearByLr() bool

func (*FtrlParameters) ProtoMessage

func (*FtrlParameters) ProtoMessage()

func (*FtrlParameters) ProtoReflect

func (x *FtrlParameters) ProtoReflect() protoreflect.Message

func (*FtrlParameters) Reset

func (x *FtrlParameters) Reset()

func (*FtrlParameters) String

func (x *FtrlParameters) String() string

type GradientAccumulationStatus

type GradientAccumulationStatus struct {
	// contains filtered or unexported fields
}

Status of using gradient accumulation (doing two passes over the input gradients: one to accumulate them into a temporary array and another to apply them using the actual optimization algorithm). The extra message is to wrap the enum for scoping.

func (*GradientAccumulationStatus) Descriptor deprecated

func (*GradientAccumulationStatus) Descriptor() ([]byte, []int)

Deprecated: Use GradientAccumulationStatus.ProtoReflect.Descriptor instead.

func (*GradientAccumulationStatus) ProtoMessage

func (*GradientAccumulationStatus) ProtoMessage()

func (*GradientAccumulationStatus) ProtoReflect

func (*GradientAccumulationStatus) Reset

func (x *GradientAccumulationStatus) Reset()

func (*GradientAccumulationStatus) String

func (x *GradientAccumulationStatus) String() string

type GradientAccumulationStatus_Status

type GradientAccumulationStatus_Status int32

If UNSPECIFIED (default), gradient accumulation is ENABLED.

const (
	GradientAccumulationStatus_UNSPECIFIED GradientAccumulationStatus_Status = 0
	GradientAccumulationStatus_ENABLED     GradientAccumulationStatus_Status = 1
	GradientAccumulationStatus_DISABLED    GradientAccumulationStatus_Status = 2
)

func (GradientAccumulationStatus_Status) Descriptor

func (GradientAccumulationStatus_Status) Enum

func (GradientAccumulationStatus_Status) EnumDescriptor deprecated

func (GradientAccumulationStatus_Status) EnumDescriptor() ([]byte, []int)

Deprecated: Use GradientAccumulationStatus_Status.Descriptor instead.

func (GradientAccumulationStatus_Status) Number

func (GradientAccumulationStatus_Status) String

func (GradientAccumulationStatus_Status) Type

type HotIdReplicationConfiguration

type HotIdReplicationConfiguration struct {
	Status HotIdReplicationConfiguration_Status `protobuf:"varint,1,opt,name=status,proto3,enum=tensorflow.tpu.HotIdReplicationConfiguration_Status" json:"status,omitempty"`
	// contains filtered or unexported fields
}

Configuration proto for hot ID optimization. This is an experimental feature and is disabled by default.

func (*HotIdReplicationConfiguration) Descriptor deprecated

func (*HotIdReplicationConfiguration) Descriptor() ([]byte, []int)

Deprecated: Use HotIdReplicationConfiguration.ProtoReflect.Descriptor instead.

func (*HotIdReplicationConfiguration) GetStatus

func (*HotIdReplicationConfiguration) ProtoMessage

func (*HotIdReplicationConfiguration) ProtoMessage()

func (*HotIdReplicationConfiguration) ProtoReflect

func (*HotIdReplicationConfiguration) Reset

func (x *HotIdReplicationConfiguration) Reset()

func (*HotIdReplicationConfiguration) String

type HotIdReplicationConfiguration_Status

type HotIdReplicationConfiguration_Status int32

Whether to enable or disable hot ID optimization. If set to UNSPECIFIED (default), hot ID optimization is DISABLED. If set to ENABLED, hot ID replication is turned ON. If set to MIGRATION_ONLY, hot ID migration is turned ON.

const (
	HotIdReplicationConfiguration_UNSPECIFIED    HotIdReplicationConfiguration_Status = 0
	HotIdReplicationConfiguration_ENABLED        HotIdReplicationConfiguration_Status = 1
	HotIdReplicationConfiguration_DISABLED       HotIdReplicationConfiguration_Status = 2
	HotIdReplicationConfiguration_MIGRATION_ONLY HotIdReplicationConfiguration_Status = 3
)

func (HotIdReplicationConfiguration_Status) Descriptor

func (HotIdReplicationConfiguration_Status) Enum

func (HotIdReplicationConfiguration_Status) EnumDescriptor deprecated

func (HotIdReplicationConfiguration_Status) EnumDescriptor() ([]byte, []int)

Deprecated: Use HotIdReplicationConfiguration_Status.Descriptor instead.

func (HotIdReplicationConfiguration_Status) Number

func (HotIdReplicationConfiguration_Status) String

func (HotIdReplicationConfiguration_Status) Type

type LearningRate

type LearningRate struct {

	// Types that are assignable to LearningRate:
	//
	//	*LearningRate_Constant
	//	*LearningRate_Dynamic
	LearningRate isLearningRate_LearningRate `protobuf_oneof:"learning_rate"`
	// contains filtered or unexported fields
}

Source of learning rate to use.

func (*LearningRate) Descriptor deprecated

func (*LearningRate) Descriptor() ([]byte, []int)

Deprecated: Use LearningRate.ProtoReflect.Descriptor instead.

func (*LearningRate) GetConstant

func (x *LearningRate) GetConstant() float32

func (*LearningRate) GetDynamic

func (x *LearningRate) GetDynamic() *DynamicLearningRate

func (*LearningRate) GetLearningRate

func (m *LearningRate) GetLearningRate() isLearningRate_LearningRate

func (*LearningRate) ProtoMessage

func (*LearningRate) ProtoMessage()

func (*LearningRate) ProtoReflect

func (x *LearningRate) ProtoReflect() protoreflect.Message

func (*LearningRate) Reset

func (x *LearningRate) Reset()

func (*LearningRate) String

func (x *LearningRate) String() string

type LearningRate_Constant

type LearningRate_Constant struct {
	Constant float32 `protobuf:"fixed32,1,opt,name=constant,proto3,oneof"`
}

type LearningRate_Dynamic

type LearningRate_Dynamic struct {
	Dynamic *DynamicLearningRate `protobuf:"bytes,2,opt,name=dynamic,proto3,oneof"`
}

type LowDimensionalPackingStatus

type LowDimensionalPackingStatus struct {
	// contains filtered or unexported fields
}

There is one important limitation of this HBM packing, however. When only a subset of rows in an 8-float chunk is accessed on a particular step, the adjoining rows in the same chunk are updated with zero gradients on the backward pass even if they are not touched. This is an artifact of the packing implementation. This operation is NOT functionally correct for optimizers where zero gradients change the embeddings/slot-variable values, e.g., momentum-based optimizers. Hence, this HBM packing cannot be enabled for embedding tables with such optimizers. The TPU software automatically recognizes that a zero gradient can modify state and turns off the low-dimensional embedding packing in that scenario.

However, for optimizers where a zero gradient is a NoOp, such as SGD, Adagrad, and FTRL, this packing optimization can be used. There are, nevertheless, some important considerations:

  • Clipping limits: The initial values for such embeddings should fall within the clipping limits specified in the optimization parameters. Otherwise, a zero gradient will cause the embeddings to be clipped. This changes state and hence is not a NoOp.
  • FTRL: The embedding vector is computed directly from the values of the accumulator and linear slot variables. Hence, the initial embedding values should match those computed from the initial values of the accumulator and linear slot variables. Note that in nearly all cases, the linear value is initialized to zero; this corresponds to an embedding value of zero.

Performance: The TPU has to perform additional work when low dimensional packing is enabled. In certain situations when the vocabulary size is small, it may not make sense to turn on this packing since the total memory usage due to padding is extremely low. Hence, the TPU software automatically turns off the packing optimization in such scenarios.

func (*LowDimensionalPackingStatus) Descriptor deprecated

func (*LowDimensionalPackingStatus) Descriptor() ([]byte, []int)

Deprecated: Use LowDimensionalPackingStatus.ProtoReflect.Descriptor instead.

func (*LowDimensionalPackingStatus) ProtoMessage

func (*LowDimensionalPackingStatus) ProtoMessage()

func (*LowDimensionalPackingStatus) ProtoReflect

func (*LowDimensionalPackingStatus) Reset

func (x *LowDimensionalPackingStatus) Reset()

func (*LowDimensionalPackingStatus) String

func (x *LowDimensionalPackingStatus) String() string

type LowDimensionalPackingStatus_Status

type LowDimensionalPackingStatus_Status int32

If UNSPECIFIED (default), low-dimensional packing is DISABLED. This may change in the future.

If ENABLED, low-dimensional packing is enabled only if all three of the following additional conditions are true:

  • The optimizer treats the zero gradient as a NoOp.
  • The embedding dimension is 1, 2, or 4.
  • The vocabulary size is large enough to avoid performance issues.

If DISABLED, low-dimensional packing is always disabled.

const (
	LowDimensionalPackingStatus_UNSPECIFIED LowDimensionalPackingStatus_Status = 0
	LowDimensionalPackingStatus_ENABLED     LowDimensionalPackingStatus_Status = 1
	LowDimensionalPackingStatus_DISABLED    LowDimensionalPackingStatus_Status = 2
)

func (LowDimensionalPackingStatus_Status) Descriptor

func (LowDimensionalPackingStatus_Status) Enum

func (LowDimensionalPackingStatus_Status) EnumDescriptor deprecated

func (LowDimensionalPackingStatus_Status) EnumDescriptor() ([]byte, []int)

Deprecated: Use LowDimensionalPackingStatus_Status.Descriptor instead.

func (LowDimensionalPackingStatus_Status) Number

func (LowDimensionalPackingStatus_Status) String

func (LowDimensionalPackingStatus_Status) Type

type MdlAdagradLightParameters

type MdlAdagradLightParameters struct {
	L2                    float32 `protobuf:"fixed32,1,opt,name=l2,proto3" json:"l2,omitempty"`
	LrPower               float32 `protobuf:"fixed32,2,opt,name=lr_power,json=lrPower,proto3" json:"lr_power,omitempty"`
	MinServableMdlBenefit float32 `` /* 130-byte string literal not displayed */
	MdlMixInMargin        float32 `protobuf:"fixed32,4,opt,name=mdl_mix_in_margin,json=mdlMixInMargin,proto3" json:"mdl_mix_in_margin,omitempty"`
	MdlBenefitRampupCoeff float32 `` /* 130-byte string literal not displayed */
	MdlMinWeight          float32 `protobuf:"fixed32,6,opt,name=mdl_min_weight,json=mdlMinWeight,proto3" json:"mdl_min_weight,omitempty"`
	BenefitRevisitScale   float32 `protobuf:"fixed32,7,opt,name=benefit_revisit_scale,json=benefitRevisitScale,proto3" json:"benefit_revisit_scale,omitempty"`
	MaxEventBenefit       float32 `protobuf:"fixed32,8,opt,name=max_event_benefit,json=maxEventBenefit,proto3" json:"max_event_benefit,omitempty"`
	MaxTotalBenefit       float32 `protobuf:"fixed32,9,opt,name=max_total_benefit,json=maxTotalBenefit,proto3" json:"max_total_benefit,omitempty"`
	MdlHardLimit          float32 `protobuf:"fixed32,10,opt,name=mdl_hard_limit,json=mdlHardLimit,proto3" json:"mdl_hard_limit,omitempty"`
	HardLimitMinBenefit   bool    `protobuf:"varint,11,opt,name=hard_limit_min_benefit,json=hardLimitMinBenefit,proto3" json:"hard_limit_min_benefit,omitempty"`
	MdlRegularize         bool    `protobuf:"varint,12,opt,name=mdl_regularize,json=mdlRegularize,proto3" json:"mdl_regularize,omitempty"`
	// contains filtered or unexported fields
}

Variant of the algorithm described in http://proceedings.mlr.press/v44/shamir15.pdf

func (*MdlAdagradLightParameters) Descriptor deprecated

func (*MdlAdagradLightParameters) Descriptor() ([]byte, []int)

Deprecated: Use MdlAdagradLightParameters.ProtoReflect.Descriptor instead.

func (*MdlAdagradLightParameters) GetBenefitRevisitScale

func (x *MdlAdagradLightParameters) GetBenefitRevisitScale() float32

func (*MdlAdagradLightParameters) GetHardLimitMinBenefit

func (x *MdlAdagradLightParameters) GetHardLimitMinBenefit() bool

func (*MdlAdagradLightParameters) GetL2

func (*MdlAdagradLightParameters) GetLrPower

func (x *MdlAdagradLightParameters) GetLrPower() float32

func (*MdlAdagradLightParameters) GetMaxEventBenefit

func (x *MdlAdagradLightParameters) GetMaxEventBenefit() float32

func (*MdlAdagradLightParameters) GetMaxTotalBenefit

func (x *MdlAdagradLightParameters) GetMaxTotalBenefit() float32

func (*MdlAdagradLightParameters) GetMdlBenefitRampupCoeff

func (x *MdlAdagradLightParameters) GetMdlBenefitRampupCoeff() float32

func (*MdlAdagradLightParameters) GetMdlHardLimit

func (x *MdlAdagradLightParameters) GetMdlHardLimit() float32

func (*MdlAdagradLightParameters) GetMdlMinWeight

func (x *MdlAdagradLightParameters) GetMdlMinWeight() float32

func (*MdlAdagradLightParameters) GetMdlMixInMargin

func (x *MdlAdagradLightParameters) GetMdlMixInMargin() float32

func (*MdlAdagradLightParameters) GetMdlRegularize

func (x *MdlAdagradLightParameters) GetMdlRegularize() bool

func (*MdlAdagradLightParameters) GetMinServableMdlBenefit

func (x *MdlAdagradLightParameters) GetMinServableMdlBenefit() float32

func (*MdlAdagradLightParameters) ProtoMessage

func (*MdlAdagradLightParameters) ProtoMessage()

func (*MdlAdagradLightParameters) ProtoReflect

func (*MdlAdagradLightParameters) Reset

func (x *MdlAdagradLightParameters) Reset()

func (*MdlAdagradLightParameters) String

func (x *MdlAdagradLightParameters) String() string

type MomentumParameters

type MomentumParameters struct {
	Momentum    float32 `protobuf:"fixed32,1,opt,name=momentum,proto3" json:"momentum,omitempty"`
	UseNesterov bool    `protobuf:"varint,2,opt,name=use_nesterov,json=useNesterov,proto3" json:"use_nesterov,omitempty"`
	// contains filtered or unexported fields
}

https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L3068

func (*MomentumParameters) Descriptor deprecated

func (*MomentumParameters) Descriptor() ([]byte, []int)

Deprecated: Use MomentumParameters.ProtoReflect.Descriptor instead.

func (*MomentumParameters) GetMomentum

func (x *MomentumParameters) GetMomentum() float32

func (*MomentumParameters) GetUseNesterov

func (x *MomentumParameters) GetUseNesterov() bool

func (*MomentumParameters) ProtoMessage

func (*MomentumParameters) ProtoMessage()

func (*MomentumParameters) ProtoReflect

func (x *MomentumParameters) ProtoReflect() protoreflect.Message

func (*MomentumParameters) Reset

func (x *MomentumParameters) Reset()

func (*MomentumParameters) String

func (x *MomentumParameters) String() string

type OnlineYogiParameters

type OnlineYogiParameters struct {

	// The L1 regularization parameter (used analogously to the one in FTRL).
	L1 float32 `protobuf:"fixed32,1,opt,name=l1,proto3" json:"l1,omitempty"`
	// The L2 regularization parameter (used analogously to the one in FTRL).
	L2 float32 `protobuf:"fixed32,2,opt,name=l2,proto3" json:"l2,omitempty"`
	// \beta_2 from Algorithm 2 in the paper.
	Beta2 float32 `protobuf:"fixed32,3,opt,name=beta2,proto3" json:"beta2,omitempty"`
	// contains filtered or unexported fields
}

The online Yogi optimizer does not implement hyper-parameter updates; use the dynamic learning rate feature instead, setting the learning rate to: user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t). Here, t is the current timestep.

Algorithm from https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf, plus some extensions based on FTRL.

Note that the code by default implements the lazy version of online Yogi.

func (*OnlineYogiParameters) Descriptor deprecated

func (*OnlineYogiParameters) Descriptor() ([]byte, []int)

Deprecated: Use OnlineYogiParameters.ProtoReflect.Descriptor instead.

func (*OnlineYogiParameters) GetBeta2

func (x *OnlineYogiParameters) GetBeta2() float32

func (*OnlineYogiParameters) GetL1

func (x *OnlineYogiParameters) GetL1() float32

func (*OnlineYogiParameters) GetL2

func (x *OnlineYogiParameters) GetL2() float32

func (*OnlineYogiParameters) ProtoMessage

func (*OnlineYogiParameters) ProtoMessage()

func (*OnlineYogiParameters) ProtoReflect

func (x *OnlineYogiParameters) ProtoReflect() protoreflect.Message

func (*OnlineYogiParameters) Reset

func (x *OnlineYogiParameters) Reset()

func (*OnlineYogiParameters) String

func (x *OnlineYogiParameters) String() string

type OptimizationParameters

type OptimizationParameters struct {

	// Learning rate used for updating the embedding layer parameters.
	LearningRate *LearningRate `protobuf:"bytes,13,opt,name=learning_rate,json=learningRate,proto3" json:"learning_rate,omitempty"`
	// Limits to which to clip the weight values after the backward pass; not
	// present means no limits are applied.
	ClippingLimits *ClippingLimits `protobuf:"bytes,2,opt,name=clipping_limits,json=clippingLimits,proto3" json:"clipping_limits,omitempty"`
	// Limits to which to clip the backward pass gradient before using it for
	// updates; not present means no limits are applied.
	GradientClippingLimits *ClippingLimits `` /* 129-byte string literal not displayed */
	// Amount of weight decay to apply; see weight_decay_optimizers.py for
	// details. All optimizers except MDL Adagrad Light are supported with this
	// option. Although there is no check, users who want weight decay will also
	// want to ensure that gradient accumulation is enabled so that the decay will
	// happen once per global batch.
	WeightDecayFactor float32 `protobuf:"fixed32,16,opt,name=weight_decay_factor,json=weightDecayFactor,proto3" json:"weight_decay_factor,omitempty"`
	// If true, the weight decay factor is multiplied by the current learning rate
	// before use; this is to match the note in DecoupledWeightDecayExtension in
	// weight_decay_optimizers.py.
	MultiplyWeightDecayFactorByLearningRate bool `` /* 190-byte string literal not displayed */
	// Configuration for simulated quantization which is used to reduce
	// training/serving skew when the serving variables are quantized. The same
	// quantization operations are executed during training to minimize
	// differences with serving.
	SimulatedQuantization *SimulatedQuantization `protobuf:"bytes,27,opt,name=simulated_quantization,json=simulatedQuantization,proto3" json:"simulated_quantization,omitempty"`
	// Status of using gradient accumulation (doing two passes over the input
	// gradients: one to accumulate them into a temporary array and another to
	// apply them using the actual optimization algorithm).
	GradientAccumulationStatus GradientAccumulationStatus_Status `` /* 197-byte string literal not displayed */
	// Status of the low-dimensional embedding packing optimization. This controls
	// whether to optimize the packing of 1-dimensional, 2-dimensional, and
	// 4-dimensional embedding tables in memory.
	LowDimensionalPackingStatus LowDimensionalPackingStatus_Status `` /* 203-byte string literal not displayed */
	// Configuration proto for hot ID replication. This is an experimental
	// feature that is currently disabled (by default).
	HotIdReplicationConfiguration *HotIdReplicationConfiguration `` /* 153-byte string literal not displayed */
	// Optimization algorithm parameters; which field is selected determines which
	// algorithm to use.
	//
	// Types that are assignable to Parameters:
	//
	//	*OptimizationParameters_Adagrad
	//	*OptimizationParameters_AdagradMomentum
	//	*OptimizationParameters_BoundedAdagrad
	//	*OptimizationParameters_StochasticGradientDescent
	//	*OptimizationParameters_Ftrl
	//	*OptimizationParameters_Adam
	//	*OptimizationParameters_Momentum
	//	*OptimizationParameters_RmsProp
	//	*OptimizationParameters_CenteredRmsProp
	//	*OptimizationParameters_MdlAdagradLight
	//	*OptimizationParameters_Adadelta
	//	*OptimizationParameters_ProximalAdagrad
	//	*OptimizationParameters_OnlineYogi
	//	*OptimizationParameters_ProximalYogi
	//	*OptimizationParameters_FrequencyEstimator
	//	*OptimizationParameters_UserDefinedProgram
	//	*OptimizationParameters_Assign
	Parameters isOptimizationParameters_Parameters `protobuf_oneof:"parameters"`
	// contains filtered or unexported fields
}

func (*OptimizationParameters) Descriptor deprecated

func (*OptimizationParameters) Descriptor() ([]byte, []int)

Deprecated: Use OptimizationParameters.ProtoReflect.Descriptor instead.

func (*OptimizationParameters) GetAdadelta

func (x *OptimizationParameters) GetAdadelta() *AdadeltaParameters

func (*OptimizationParameters) GetAdagrad

func (x *OptimizationParameters) GetAdagrad() *AdagradParameters

func (*OptimizationParameters) GetAdagradMomentum

func (x *OptimizationParameters) GetAdagradMomentum() *AdagradMomentumParameters

func (*OptimizationParameters) GetAdam

func (x *OptimizationParameters) GetAdam() *AdamParameters

func (*OptimizationParameters) GetAssign

func (x *OptimizationParameters) GetAssign() *AssignParameters

func (*OptimizationParameters) GetBoundedAdagrad

func (x *OptimizationParameters) GetBoundedAdagrad() *BoundedAdagradParameters

func (*OptimizationParameters) GetCenteredRmsProp

func (x *OptimizationParameters) GetCenteredRmsProp() *CenteredRmsPropParameters

func (*OptimizationParameters) GetClippingLimits

func (x *OptimizationParameters) GetClippingLimits() *ClippingLimits

func (*OptimizationParameters) GetFrequencyEstimator

func (x *OptimizationParameters) GetFrequencyEstimator() *FrequencyEstimatorParameters

func (*OptimizationParameters) GetFtrl

func (x *OptimizationParameters) GetFtrl() *FtrlParameters

func (*OptimizationParameters) GetGradientAccumulationStatus

func (x *OptimizationParameters) GetGradientAccumulationStatus() GradientAccumulationStatus_Status

func (*OptimizationParameters) GetGradientClippingLimits

func (x *OptimizationParameters) GetGradientClippingLimits() *ClippingLimits

func (*OptimizationParameters) GetHotIdReplicationConfiguration

func (x *OptimizationParameters) GetHotIdReplicationConfiguration() *HotIdReplicationConfiguration

func (*OptimizationParameters) GetLearningRate

func (x *OptimizationParameters) GetLearningRate() *LearningRate

func (*OptimizationParameters) GetLowDimensionalPackingStatus

func (x *OptimizationParameters) GetLowDimensionalPackingStatus() LowDimensionalPackingStatus_Status

func (*OptimizationParameters) GetMdlAdagradLight

func (x *OptimizationParameters) GetMdlAdagradLight() *MdlAdagradLightParameters

func (*OptimizationParameters) GetMomentum

func (x *OptimizationParameters) GetMomentum() *MomentumParameters

func (*OptimizationParameters) GetMultiplyWeightDecayFactorByLearningRate

func (x *OptimizationParameters) GetMultiplyWeightDecayFactorByLearningRate() bool

func (*OptimizationParameters) GetOnlineYogi

func (x *OptimizationParameters) GetOnlineYogi() *OnlineYogiParameters

func (*OptimizationParameters) GetParameters

func (m *OptimizationParameters) GetParameters() isOptimizationParameters_Parameters

func (*OptimizationParameters) GetProximalAdagrad

func (x *OptimizationParameters) GetProximalAdagrad() *ProximalAdagradParameters

func (*OptimizationParameters) GetProximalYogi

func (x *OptimizationParameters) GetProximalYogi() *ProximalYogiParameters

func (*OptimizationParameters) GetRmsProp

func (x *OptimizationParameters) GetRmsProp() *RmsPropParameters

func (*OptimizationParameters) GetSimulatedQuantization

func (x *OptimizationParameters) GetSimulatedQuantization() *SimulatedQuantization

func (*OptimizationParameters) GetStochasticGradientDescent

func (x *OptimizationParameters) GetStochasticGradientDescent() *StochasticGradientDescentParameters

func (*OptimizationParameters) GetUserDefinedProgram

func (x *OptimizationParameters) GetUserDefinedProgram() *UserDefinedProgramParameters

func (*OptimizationParameters) GetWeightDecayFactor

func (x *OptimizationParameters) GetWeightDecayFactor() float32

func (*OptimizationParameters) ProtoMessage

func (*OptimizationParameters) ProtoMessage()

func (*OptimizationParameters) ProtoReflect

func (x *OptimizationParameters) ProtoReflect() protoreflect.Message

func (*OptimizationParameters) Reset

func (x *OptimizationParameters) Reset()

func (*OptimizationParameters) String

func (x *OptimizationParameters) String() string

type OptimizationParameters_Adadelta

type OptimizationParameters_Adadelta struct {
	Adadelta *AdadeltaParameters `protobuf:"bytes,12,opt,name=adadelta,proto3,oneof"`
}

type OptimizationParameters_Adagrad

type OptimizationParameters_Adagrad struct {
	Adagrad *AdagradParameters `protobuf:"bytes,3,opt,name=adagrad,proto3,oneof"`
}

type OptimizationParameters_AdagradMomentum

type OptimizationParameters_AdagradMomentum struct {
	AdagradMomentum *AdagradMomentumParameters `protobuf:"bytes,26,opt,name=adagrad_momentum,json=adagradMomentum,proto3,oneof"`
}

type OptimizationParameters_Adam

type OptimizationParameters_Adam struct {
	Adam *AdamParameters `protobuf:"bytes,6,opt,name=adam,proto3,oneof"`
}

type OptimizationParameters_Assign

type OptimizationParameters_Assign struct {
	Assign *AssignParameters `protobuf:"bytes,25,opt,name=assign,proto3,oneof"`
}

type OptimizationParameters_BoundedAdagrad

type OptimizationParameters_BoundedAdagrad struct {
	BoundedAdagrad *BoundedAdagradParameters `protobuf:"bytes,19,opt,name=bounded_adagrad,json=boundedAdagrad,proto3,oneof"`
}

type OptimizationParameters_CenteredRmsProp

type OptimizationParameters_CenteredRmsProp struct {
	CenteredRmsProp *CenteredRmsPropParameters `protobuf:"bytes,10,opt,name=centered_rms_prop,json=centeredRmsProp,proto3,oneof"`
}

type OptimizationParameters_FrequencyEstimator

type OptimizationParameters_FrequencyEstimator struct {
	FrequencyEstimator *FrequencyEstimatorParameters `protobuf:"bytes,23,opt,name=frequency_estimator,json=frequencyEstimator,proto3,oneof"`
}

type OptimizationParameters_Ftrl

type OptimizationParameters_Ftrl struct {
	Ftrl *FtrlParameters `protobuf:"bytes,5,opt,name=ftrl,proto3,oneof"`
}

type OptimizationParameters_MdlAdagradLight

type OptimizationParameters_MdlAdagradLight struct {
	MdlAdagradLight *MdlAdagradLightParameters `protobuf:"bytes,11,opt,name=mdl_adagrad_light,json=mdlAdagradLight,proto3,oneof"`
}

type OptimizationParameters_Momentum

type OptimizationParameters_Momentum struct {
	Momentum *MomentumParameters `protobuf:"bytes,8,opt,name=momentum,proto3,oneof"`
}

type OptimizationParameters_OnlineYogi

type OptimizationParameters_OnlineYogi struct {
	OnlineYogi *OnlineYogiParameters `protobuf:"bytes,20,opt,name=online_yogi,json=onlineYogi,proto3,oneof"`
}

type OptimizationParameters_ProximalAdagrad

type OptimizationParameters_ProximalAdagrad struct {
	ProximalAdagrad *ProximalAdagradParameters `protobuf:"bytes,14,opt,name=proximal_adagrad,json=proximalAdagrad,proto3,oneof"`
}

type OptimizationParameters_ProximalYogi

type OptimizationParameters_ProximalYogi struct {
	ProximalYogi *ProximalYogiParameters `protobuf:"bytes,21,opt,name=proximal_yogi,json=proximalYogi,proto3,oneof"`
}

type OptimizationParameters_RmsProp

type OptimizationParameters_RmsProp struct {
	RmsProp *RmsPropParameters `protobuf:"bytes,9,opt,name=rms_prop,json=rmsProp,proto3,oneof"`
}

type OptimizationParameters_StochasticGradientDescent

type OptimizationParameters_StochasticGradientDescent struct {
	StochasticGradientDescent *StochasticGradientDescentParameters `protobuf:"bytes,4,opt,name=stochastic_gradient_descent,json=stochasticGradientDescent,proto3,oneof"`
}

type OptimizationParameters_UserDefinedProgram

type OptimizationParameters_UserDefinedProgram struct {
	UserDefinedProgram *UserDefinedProgramParameters `protobuf:"bytes,24,opt,name=user_defined_program,json=userDefinedProgram,proto3,oneof"`
}

type PaddingMap

type PaddingMap struct {

	// Input arg index with dynamic shapes.
	ArgIndex int32 `protobuf:"varint,1,opt,name=arg_index,json=argIndex,proto3" json:"arg_index,omitempty"`
	// The dynamic shape dimension index.
	ShapeIndex int32 `protobuf:"varint,2,opt,name=shape_index,json=shapeIndex,proto3" json:"shape_index,omitempty"`
	// The arg index that dynamic dimension maps to, which represents the value
	// of the real shape.
	PaddingArgIndex int32 `protobuf:"varint,3,opt,name=padding_arg_index,json=paddingArgIndex,proto3" json:"padding_arg_index,omitempty"`
	// contains filtered or unexported fields
}

A mapping between the dynamic shape dimension of an input and the arg that represents the real shape.

func (*PaddingMap) Descriptor deprecated

func (*PaddingMap) Descriptor() ([]byte, []int)

Deprecated: Use PaddingMap.ProtoReflect.Descriptor instead.

func (*PaddingMap) GetArgIndex

func (x *PaddingMap) GetArgIndex() int32

func (*PaddingMap) GetPaddingArgIndex

func (x *PaddingMap) GetPaddingArgIndex() int32

func (*PaddingMap) GetShapeIndex

func (x *PaddingMap) GetShapeIndex() int32

func (*PaddingMap) ProtoMessage

func (*PaddingMap) ProtoMessage()

func (*PaddingMap) ProtoReflect

func (x *PaddingMap) ProtoReflect() protoreflect.Message

func (*PaddingMap) Reset

func (x *PaddingMap) Reset()

func (*PaddingMap) String

func (x *PaddingMap) String() string

type ProximalAdagradParameters

type ProximalAdagradParameters struct {
	L1 float32 `protobuf:"fixed32,1,opt,name=l1,proto3" json:"l1,omitempty"`
	L2 float32 `protobuf:"fixed32,2,opt,name=l2,proto3" json:"l2,omitempty"`
	// contains filtered or unexported fields
}

https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/ProximalAdagradOptimizer https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L1961

func (*ProximalAdagradParameters) Descriptor deprecated

func (*ProximalAdagradParameters) Descriptor() ([]byte, []int)

Deprecated: Use ProximalAdagradParameters.ProtoReflect.Descriptor instead.

func (*ProximalAdagradParameters) GetL1

func (*ProximalAdagradParameters) GetL2

func (*ProximalAdagradParameters) ProtoMessage

func (*ProximalAdagradParameters) ProtoMessage()

func (*ProximalAdagradParameters) ProtoReflect

func (*ProximalAdagradParameters) Reset

func (x *ProximalAdagradParameters) Reset()

func (*ProximalAdagradParameters) String

func (x *ProximalAdagradParameters) String() string

type ProximalYogiParameters

type ProximalYogiParameters struct {

	// The L1 regularization parameter.
	L1 float32 `protobuf:"fixed32,1,opt,name=l1,proto3" json:"l1,omitempty"`
	// The L2 regularization parameter.
	L2 float32 `protobuf:"fixed32,2,opt,name=l2,proto3" json:"l2,omitempty"`
	// The exponential decay rate for the 1st moment estimates.
	Beta1 float32 `protobuf:"fixed32,3,opt,name=beta1,proto3" json:"beta1,omitempty"`
	// The exponential decay rate for the 2nd moment estimates.
	Beta2 float32 `protobuf:"fixed32,4,opt,name=beta2,proto3" json:"beta2,omitempty"`
	// A constant trading off adaptivity and noise.
	Epsilon float32 `protobuf:"fixed32,5,opt,name=epsilon,proto3" json:"epsilon,omitempty"`
	// contains filtered or unexported fields
}

The proximal Yogi optimizer does not implement hyper-parameter updates; use the dynamic learning rate feature instead, setting the learning rate to: user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t). Here, t is the current timestep.

Algorithm from https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf, plus some extensions based on FTRL.

Note that the code by default implements the lazy version of proximal Yogi.

func (*ProximalYogiParameters) Descriptor deprecated

func (*ProximalYogiParameters) Descriptor() ([]byte, []int)

Deprecated: Use ProximalYogiParameters.ProtoReflect.Descriptor instead.

func (*ProximalYogiParameters) GetBeta1

func (x *ProximalYogiParameters) GetBeta1() float32

func (*ProximalYogiParameters) GetBeta2

func (x *ProximalYogiParameters) GetBeta2() float32

func (*ProximalYogiParameters) GetEpsilon

func (x *ProximalYogiParameters) GetEpsilon() float32

func (*ProximalYogiParameters) GetL1

func (x *ProximalYogiParameters) GetL1() float32

func (*ProximalYogiParameters) GetL2

func (x *ProximalYogiParameters) GetL2() float32

func (*ProximalYogiParameters) ProtoMessage

func (*ProximalYogiParameters) ProtoMessage()

func (*ProximalYogiParameters) ProtoReflect

func (x *ProximalYogiParameters) ProtoReflect() protoreflect.Message

func (*ProximalYogiParameters) Reset

func (x *ProximalYogiParameters) Reset()

func (*ProximalYogiParameters) String

func (x *ProximalYogiParameters) String() string

type RmsPropParameters

type RmsPropParameters struct {
	Rho      float32 `protobuf:"fixed32,1,opt,name=rho,proto3" json:"rho,omitempty"`
	Momentum float32 `protobuf:"fixed32,2,opt,name=momentum,proto3" json:"momentum,omitempty"`
	Epsilon  float32 `protobuf:"fixed32,3,opt,name=epsilon,proto3" json:"epsilon,omitempty"`
	// contains filtered or unexported fields
}

https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L4229

func (*RmsPropParameters) Descriptor deprecated

func (*RmsPropParameters) Descriptor() ([]byte, []int)

Deprecated: Use RmsPropParameters.ProtoReflect.Descriptor instead.

func (*RmsPropParameters) GetEpsilon

func (x *RmsPropParameters) GetEpsilon() float32

func (*RmsPropParameters) GetMomentum

func (x *RmsPropParameters) GetMomentum() float32

func (*RmsPropParameters) GetRho

func (x *RmsPropParameters) GetRho() float32

func (*RmsPropParameters) ProtoMessage

func (*RmsPropParameters) ProtoMessage()

func (*RmsPropParameters) ProtoReflect

func (x *RmsPropParameters) ProtoReflect() protoreflect.Message

func (*RmsPropParameters) Reset

func (x *RmsPropParameters) Reset()

func (*RmsPropParameters) String

func (x *RmsPropParameters) String() string

type SimulatedQuantization

type SimulatedQuantization struct {

	// Whether simulated quantization is enabled.
	Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"`
	// Minimum and maximum values of the range used for quantization.
	ClippingLimits *ClippingLimits `protobuf:"bytes,2,opt,name=clipping_limits,json=clippingLimits,proto3" json:"clipping_limits,omitempty"`
	// Number of possible quantized values.
	NumBuckets int32 `protobuf:"varint,3,opt,name=num_buckets,json=numBuckets,proto3" json:"num_buckets,omitempty"`
	// contains filtered or unexported fields
}

Configuration for simulated quantization; simulated quantization is used to reduce training/serving skew when the serving variables are quantized. The same quantization operations are executed during training to minimize differences with serving.

Simulated quantization inserts the following operations on the forward pass after gathering the embedding vector from HBM. The backward pass operations are unchanged.

    clipped_val = clip(input, clipping_limits)
    quantum = clipping_limits.range() / (num_buckets - 1)
    quantized_val = floor((clipped_val - clipping_limits.lower()) / quantum + .5)
    return quantized_val * quantum + clipping_limits.lower()

func (*SimulatedQuantization) Descriptor deprecated

func (*SimulatedQuantization) Descriptor() ([]byte, []int)

Deprecated: Use SimulatedQuantization.ProtoReflect.Descriptor instead.

func (*SimulatedQuantization) GetClippingLimits

func (x *SimulatedQuantization) GetClippingLimits() *ClippingLimits

func (*SimulatedQuantization) GetEnabled

func (x *SimulatedQuantization) GetEnabled() bool

func (*SimulatedQuantization) GetNumBuckets

func (x *SimulatedQuantization) GetNumBuckets() int32

func (*SimulatedQuantization) ProtoMessage

func (*SimulatedQuantization) ProtoMessage()

func (*SimulatedQuantization) ProtoReflect

func (x *SimulatedQuantization) ProtoReflect() protoreflect.Message

func (*SimulatedQuantization) Reset

func (x *SimulatedQuantization) Reset()

func (*SimulatedQuantization) String

func (x *SimulatedQuantization) String() string

type StateVariableSpecification

type StateVariableSpecification struct {

	// Parameter name for the state variable.
	Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	// Usage type of this state variable.
	//
	// Types that are assignable to Usage:
	//
	//	*StateVariableSpecification_UserDefined_
	//	*StateVariableSpecification_FillWithConstant_
	Usage isStateVariableSpecification_Usage `protobuf_oneof:"usage"`
	// contains filtered or unexported fields
}

Specification of an optimization algorithm's state variables (both the main value vector and any extra accumulators, etc.). This proto is only used internally by the TPU software and is not exposed directly to the TF model.

func (*StateVariableSpecification) Descriptor deprecated

func (*StateVariableSpecification) Descriptor() ([]byte, []int)

Deprecated: Use StateVariableSpecification.ProtoReflect.Descriptor instead.

func (*StateVariableSpecification) GetFillWithConstant

func (*StateVariableSpecification) GetName

func (x *StateVariableSpecification) GetName() string

func (*StateVariableSpecification) GetUsage

func (m *StateVariableSpecification) GetUsage() isStateVariableSpecification_Usage

func (*StateVariableSpecification) GetUserDefined

func (*StateVariableSpecification) ProtoMessage

func (*StateVariableSpecification) ProtoMessage()

func (*StateVariableSpecification) ProtoReflect

func (*StateVariableSpecification) Reset

func (x *StateVariableSpecification) Reset()

func (*StateVariableSpecification) String

func (x *StateVariableSpecification) String() string

type StateVariableSpecification_FillWithConstant

// StateVariableSpecification_FillWithConstant describes a state variable that
// is filled with a constant and normally hidden from users (for example,
// intermediate gradients being accumulated).
type StateVariableSpecification_FillWithConstant struct {
	// InitialValue is the constant the variable is filled with (proto field
	// initial_value).
	InitialValue float64 `protobuf:"fixed64,1,opt,name=initial_value,json=initialValue,proto3" json:"initial_value,omitempty"`
	// contains filtered or unexported fields
}

A state variable that should be filled with a constant and normally hidden from users (used for intermediate gradients being accumulated, for example).

func (*StateVariableSpecification_FillWithConstant) Descriptor deprecated

Deprecated: Use StateVariableSpecification_FillWithConstant.ProtoReflect.Descriptor instead.

func (*StateVariableSpecification_FillWithConstant) GetInitialValue

func (*StateVariableSpecification_FillWithConstant) ProtoMessage

func (*StateVariableSpecification_FillWithConstant) ProtoReflect

func (*StateVariableSpecification_FillWithConstant) Reset

func (*StateVariableSpecification_FillWithConstant) String

type StateVariableSpecification_FillWithConstant_

// StateVariableSpecification_FillWithConstant_ is the oneof wrapper that sets
// StateVariableSpecification.Usage to fill_with_constant (proto field 3).
type StateVariableSpecification_FillWithConstant_ struct {
	FillWithConstant *StateVariableSpecification_FillWithConstant `protobuf:"bytes,3,opt,name=fill_with_constant,json=fillWithConstant,proto3,oneof"`
}

type StateVariableSpecification_UserDefined

// StateVariableSpecification_UserDefined marks a normal state variable that is
// saved and restored in checkpoints and used as an input or output of non-debug
// TensorFlow ops. It has no exported fields.
type StateVariableSpecification_UserDefined struct {
	// contains filtered or unexported fields
}

A normal state variable that should be saved and restored in checkpoints and used as an input or output to non-debug TensorFlow ops.

func (*StateVariableSpecification_UserDefined) Descriptor deprecated

func (*StateVariableSpecification_UserDefined) Descriptor() ([]byte, []int)

Deprecated: Use StateVariableSpecification_UserDefined.ProtoReflect.Descriptor instead.

func (*StateVariableSpecification_UserDefined) ProtoMessage

func (*StateVariableSpecification_UserDefined) ProtoReflect

func (*StateVariableSpecification_UserDefined) Reset

func (*StateVariableSpecification_UserDefined) String

type StateVariableSpecification_UserDefined_

// StateVariableSpecification_UserDefined_ is the oneof wrapper that sets
// StateVariableSpecification.Usage to user_defined (proto field 2).
type StateVariableSpecification_UserDefined_ struct {
	UserDefined *StateVariableSpecification_UserDefined `protobuf:"bytes,2,opt,name=user_defined,json=userDefined,proto3,oneof"`
}

type StochasticGradientDescentParameters

// StochasticGradientDescentParameters selects stochastic gradient descent
// (SGD) as the optimizer. The message has no exported fields. See
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD
type StochasticGradientDescentParameters struct {
	// contains filtered or unexported fields
}

Parameters for stochastic gradient descent (SGD). See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD and the corresponding TensorFlow kernel at https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L629.

func (*StochasticGradientDescentParameters) Descriptor deprecated

func (*StochasticGradientDescentParameters) Descriptor() ([]byte, []int)

Deprecated: Use StochasticGradientDescentParameters.ProtoReflect.Descriptor instead.

func (*StochasticGradientDescentParameters) ProtoMessage

func (*StochasticGradientDescentParameters) ProtoMessage()

func (*StochasticGradientDescentParameters) ProtoReflect

func (*StochasticGradientDescentParameters) Reset

func (*StochasticGradientDescentParameters) String

type TPUCompileMetadataProto

// TPUCompileMetadataProto is an experimental proto used in the TF/XLA bridge to
// store metadata for a compile op (e.g. _TPUCompileMlir).
type TPUCompileMetadataProto struct {
	// Per-argument and per-return-value descriptions of the computation
	// (types, shapes, kinds and cross-core sharding).
	Args    []*TPUCompileMetadataProto_Arg    `protobuf:"bytes,1,rep,name=args,proto3" json:"args,omitempty"`
	Retvals []*TPUCompileMetadataProto_Retval `protobuf:"bytes,2,rep,name=retvals,proto3" json:"retvals,omitempty"`
	// Number of replicas of the computation and number of cores in each replica.
	// TODO(b/140721404): it may not be necessary to state the number of cores per
	// replica here. Reconsider when replicated model-parallelism is implemented
	// in XLA.
	NumReplicas        int32                       `protobuf:"varint,3,opt,name=num_replicas,json=numReplicas,proto3" json:"num_replicas,omitempty"`
	NumCoresPerReplica int32                       `protobuf:"varint,4,opt,name=num_cores_per_replica,json=numCoresPerReplica,proto3" json:"num_cores_per_replica,omitempty"`
	DeviceAssignment   *data.DeviceAssignmentProto `protobuf:"bytes,8,opt,name=device_assignment,json=deviceAssignment,proto3" json:"device_assignment,omitempty"`
	// A fingerprint of the function library. Ensures that any functions called
	// by the computation have matching definitions.
	FunctionLibraryFingerprint uint64 `` /* 142-byte string literal not displayed */
	// Unique session identifier. Can be empty.
	SessionHandle string `protobuf:"bytes,9,opt,name=session_handle,json=sessionHandle,proto3" json:"session_handle,omitempty"`
	// Fingerprint of guaranteed_const value. The fingerprint computation inside
	// tpu_compile_op may be slow. The computation can be avoided by setting the
	// fingerprint value here.
	GuaranteedConstFingerprint string        `` /* 142-byte string literal not displayed */
	PaddingMaps                []*PaddingMap `protobuf:"bytes,11,rep,name=padding_maps,json=paddingMaps,proto3" json:"padding_maps,omitempty"`
	// The location of step markers that XLA compile will instrument.
	StepMarkerLocation xla.DebugOptions_StepMarkerLocation `` /* 160-byte string literal not displayed */
	// Minimum number of batches run through the XLA graph before XLA fusion
	// autotuner is enabled. Default value of zero disables the autotuner.
	// The XLA fusion autotuner can improve performance by executing a heuristic
	// search on the compiler parameters.
	XlaFusionAutotunerThresh int64 `` /* 139-byte string literal not displayed */
	// Enables TPU compiler to add partitioning policies for inputs/outputs to
	// the XLA computation for model parallelism.
	EnableAutomaticModelParallelism bool `` /* 160-byte string literal not displayed */
	// Whether to use XLA's SPMD or MPMD partitioner when compiler partitioning is
	// requested.
	// NOTE(review): presumably true selects SPMD and false selects MPMD —
	// confirm against the TF/XLA bridge.
	UseSpmdForXlaPartitioning bool `` /* 144-byte string literal not displayed */
	// Whether to automatically generate XLA shardings for SPMD partitioner.
	UseAutoSpmdForXlaPartitioning bool `` /* 158-byte string literal not displayed */
	// Device mesh shape used to create the sharding search space when
	// use_auto_spmd_partitioning=true.
	AutoSpmdMeshShape []int64 `protobuf:"varint,19,rep,packed,name=auto_spmd_mesh_shape,json=autoSpmdMeshShape,proto3" json:"auto_spmd_mesh_shape,omitempty"`
	// Device mesh ids compatible with the above mesh_shape used when
	// use_auto_spmd_partitioning=true.
	AutoSpmdMeshIds []int64 `protobuf:"varint,20,rep,packed,name=auto_spmd_mesh_ids,json=autoSpmdMeshIds,proto3" json:"auto_spmd_mesh_ids,omitempty"`
	// A fingerprint generated by hashing the MLIR module content.
	MlirFingerprint uint64             `protobuf:"varint,17,opt,name=mlir_fingerprint,json=mlirFingerprint,proto3" json:"mlir_fingerprint,omitempty"`
	CompileOptions  *TPUCompileOptions `protobuf:"bytes,21,opt,name=compile_options,json=compileOptions,proto3" json:"compile_options,omitempty"`
	// contains filtered or unexported fields
}

This is an experimental proto used in the TF/XLA bridge to store metadata for a compile op (e.g. _TPUCompileMlir). TODO(lyandy): Deprecate proto once generic metadata proto is created.

func (*TPUCompileMetadataProto) Descriptor deprecated

func (*TPUCompileMetadataProto) Descriptor() ([]byte, []int)

Deprecated: Use TPUCompileMetadataProto.ProtoReflect.Descriptor instead.

func (*TPUCompileMetadataProto) GetArgs

func (*TPUCompileMetadataProto) GetAutoSpmdMeshIds

func (x *TPUCompileMetadataProto) GetAutoSpmdMeshIds() []int64

func (*TPUCompileMetadataProto) GetAutoSpmdMeshShape

func (x *TPUCompileMetadataProto) GetAutoSpmdMeshShape() []int64

func (*TPUCompileMetadataProto) GetCompileOptions

func (x *TPUCompileMetadataProto) GetCompileOptions() *TPUCompileOptions

func (*TPUCompileMetadataProto) GetDeviceAssignment

func (x *TPUCompileMetadataProto) GetDeviceAssignment() *data.DeviceAssignmentProto

func (*TPUCompileMetadataProto) GetEnableAutomaticModelParallelism

func (x *TPUCompileMetadataProto) GetEnableAutomaticModelParallelism() bool

func (*TPUCompileMetadataProto) GetFunctionLibraryFingerprint

func (x *TPUCompileMetadataProto) GetFunctionLibraryFingerprint() uint64

func (*TPUCompileMetadataProto) GetGuaranteedConstFingerprint

func (x *TPUCompileMetadataProto) GetGuaranteedConstFingerprint() string

func (*TPUCompileMetadataProto) GetMlirFingerprint

func (x *TPUCompileMetadataProto) GetMlirFingerprint() uint64

func (*TPUCompileMetadataProto) GetNumCoresPerReplica

func (x *TPUCompileMetadataProto) GetNumCoresPerReplica() int32

func (*TPUCompileMetadataProto) GetNumReplicas

func (x *TPUCompileMetadataProto) GetNumReplicas() int32

func (*TPUCompileMetadataProto) GetPaddingMaps

func (x *TPUCompileMetadataProto) GetPaddingMaps() []*PaddingMap

func (*TPUCompileMetadataProto) GetRetvals

func (*TPUCompileMetadataProto) GetSessionHandle

func (x *TPUCompileMetadataProto) GetSessionHandle() string

func (*TPUCompileMetadataProto) GetStepMarkerLocation

func (x *TPUCompileMetadataProto) GetStepMarkerLocation() xla.DebugOptions_StepMarkerLocation

func (*TPUCompileMetadataProto) GetUseAutoSpmdForXlaPartitioning

func (x *TPUCompileMetadataProto) GetUseAutoSpmdForXlaPartitioning() bool

func (*TPUCompileMetadataProto) GetUseSpmdForXlaPartitioning

func (x *TPUCompileMetadataProto) GetUseSpmdForXlaPartitioning() bool

func (*TPUCompileMetadataProto) GetXlaFusionAutotunerThresh

func (x *TPUCompileMetadataProto) GetXlaFusionAutotunerThresh() int64

func (*TPUCompileMetadataProto) ProtoMessage

func (*TPUCompileMetadataProto) ProtoMessage()

func (*TPUCompileMetadataProto) ProtoReflect

func (x *TPUCompileMetadataProto) ProtoReflect() protoreflect.Message

func (*TPUCompileMetadataProto) Reset

func (x *TPUCompileMetadataProto) Reset()

func (*TPUCompileMetadataProto) String

func (x *TPUCompileMetadataProto) String() string

type TPUCompileMetadataProto_Arg

// TPUCompileMetadataProto_Arg describes the type and shape of one argument to a
// computation, together with how it is sharded, placed and laid out.
type TPUCompileMetadataProto_Arg struct {
	// Element type, shape and kind (parameter, variable, guaranteed constant)
	// of the argument.
	Dtype framework.DataType               `protobuf:"varint,1,opt,name=dtype,proto3,enum=tensorflow.DataType" json:"dtype,omitempty"`
	Shape *framework.TensorShapeProto      `protobuf:"bytes,2,opt,name=shape,proto3" json:"shape,omitempty"`
	Kind  TPUCompileMetadataProto_Arg_Kind `protobuf:"varint,3,opt,name=kind,proto3,enum=tensorflow.tpu.TPUCompileMetadataProto_Arg_Kind" json:"kind,omitempty"`
	// The cross-core sharding of this input within each replica, e.g.,
	// assigning to one core, or replicate across all cores.
	Sharding *data.OpSharding `protobuf:"bytes,4,opt,name=sharding,proto3" json:"sharding,omitempty"`
	// Whether this argument will receive the same data across all replicas.
	IsSameDataAcrossReplicas bool `` /* 140-byte string literal not displayed */
	// Whether to allow XLA to produce separate programs to shard/unshard this
	// argument. Requires this arg to be an on-device Kind::VARIABLE, or a
	// Kind::PARAMETER. For Kind::PARAMETER, it represents the initial value of
	// a variable, and retval_index_for_sharding must be specified for the
	// corresponding updated value.
	EnableXlaSharding TPUCompileMetadataProto_Arg_EnableXlaSharding `` /* 181-byte string literal not displayed */
	// If XLA sharding is allowed on a Kind::PARAMETER, this field is used to
	// specify the corresponding updated value in the return values. Use -1 for
	// variables that are not updated.
	RetvalIndexForSharding int32 `` /* 132-byte string literal not displayed */
	// Whether this argument is placed on fast memory or not.
	FastMem bool `protobuf:"varint,7,opt,name=fast_mem,json=fastMem,proto3" json:"fast_mem,omitempty"`
	// Whether to let XLA decide the layout during compilation, as opposed to
	// using a fixed layout determined by the shape.
	UnrestrictedLayout bool `protobuf:"varint,9,opt,name=unrestricted_layout,json=unrestrictedLayout,proto3" json:"unrestricted_layout,omitempty"`
	// Name of the node that the arg comes from.
	Name string `protobuf:"bytes,10,opt,name=name,proto3" json:"name,omitempty"`
	// Whether to use XLA collectives to broadcast this parameter to all
	// replicas, instead of using TensorFlow Send/Recv among the tasks.
	RequiresXlaBroadcast bool `protobuf:"varint,11,opt,name=requires_xla_broadcast,json=requiresXlaBroadcast,proto3" json:"requires_xla_broadcast,omitempty"`
	// contains filtered or unexported fields
}

Description of the types and shapes of the arguments to a computation.

func (*TPUCompileMetadataProto_Arg) Descriptor deprecated

func (*TPUCompileMetadataProto_Arg) Descriptor() ([]byte, []int)

Deprecated: Use TPUCompileMetadataProto_Arg.ProtoReflect.Descriptor instead.

func (*TPUCompileMetadataProto_Arg) GetDtype

func (*TPUCompileMetadataProto_Arg) GetEnableXlaSharding

func (*TPUCompileMetadataProto_Arg) GetFastMem

func (x *TPUCompileMetadataProto_Arg) GetFastMem() bool

func (*TPUCompileMetadataProto_Arg) GetIsSameDataAcrossReplicas

func (x *TPUCompileMetadataProto_Arg) GetIsSameDataAcrossReplicas() bool

func (*TPUCompileMetadataProto_Arg) GetKind

func (*TPUCompileMetadataProto_Arg) GetName

func (x *TPUCompileMetadataProto_Arg) GetName() string

func (*TPUCompileMetadataProto_Arg) GetRequiresXlaBroadcast

func (x *TPUCompileMetadataProto_Arg) GetRequiresXlaBroadcast() bool

func (*TPUCompileMetadataProto_Arg) GetRetvalIndexForSharding

func (x *TPUCompileMetadataProto_Arg) GetRetvalIndexForSharding() int32

func (*TPUCompileMetadataProto_Arg) GetShape

func (*TPUCompileMetadataProto_Arg) GetSharding

func (x *TPUCompileMetadataProto_Arg) GetSharding() *data.OpSharding

func (*TPUCompileMetadataProto_Arg) GetUnrestrictedLayout

func (x *TPUCompileMetadataProto_Arg) GetUnrestrictedLayout() bool

func (*TPUCompileMetadataProto_Arg) ProtoMessage

func (*TPUCompileMetadataProto_Arg) ProtoMessage()

func (*TPUCompileMetadataProto_Arg) ProtoReflect

func (*TPUCompileMetadataProto_Arg) Reset

func (x *TPUCompileMetadataProto_Arg) Reset()

func (*TPUCompileMetadataProto_Arg) String

func (x *TPUCompileMetadataProto_Arg) String() string

type TPUCompileMetadataProto_Arg_EnableXlaSharding

// TPUCompileMetadataProto_Arg_EnableXlaSharding controls whether XLA may
// produce separate programs to shard/unshard an argument (the
// TPUCompileMetadataProto_Arg.EnableXlaSharding field). Values are proto wire
// numbers.
type TPUCompileMetadataProto_Arg_EnableXlaSharding int32
const (
	// DISALLOWED is the zero value, so an Arg that never sets
	// enable_xla_sharding does not allow XLA sharding.
	TPUCompileMetadataProto_Arg_DISALLOWED TPUCompileMetadataProto_Arg_EnableXlaSharding = 0
	// Sharding is allowed if host training loop exists.
	TPUCompileMetadataProto_Arg_TENTATIVE TPUCompileMetadataProto_Arg_EnableXlaSharding = 1
	TPUCompileMetadataProto_Arg_ALLOWED   TPUCompileMetadataProto_Arg_EnableXlaSharding = 2
)

func (TPUCompileMetadataProto_Arg_EnableXlaSharding) Descriptor

func (TPUCompileMetadataProto_Arg_EnableXlaSharding) Enum

func (TPUCompileMetadataProto_Arg_EnableXlaSharding) EnumDescriptor deprecated

func (TPUCompileMetadataProto_Arg_EnableXlaSharding) EnumDescriptor() ([]byte, []int)

Deprecated: Use TPUCompileMetadataProto_Arg_EnableXlaSharding.Descriptor instead.

func (TPUCompileMetadataProto_Arg_EnableXlaSharding) Number

func (TPUCompileMetadataProto_Arg_EnableXlaSharding) String

func (TPUCompileMetadataProto_Arg_EnableXlaSharding) Type

type TPUCompileMetadataProto_Arg_Kind

// TPUCompileMetadataProto_Arg_Kind classifies an argument to a computation as a
// parameter, a variable, or a guaranteed constant. Values are proto wire
// numbers.
type TPUCompileMetadataProto_Arg_Kind int32
const (
	// INVALID is the zero value, so an Arg whose kind was never set reads
	// back as INVALID rather than as one of the real kinds below.
	TPUCompileMetadataProto_Arg_INVALID   TPUCompileMetadataProto_Arg_Kind = 0
	TPUCompileMetadataProto_Arg_PARAMETER TPUCompileMetadataProto_Arg_Kind = 1
	TPUCompileMetadataProto_Arg_VARIABLE  TPUCompileMetadataProto_Arg_Kind = 2
	// These are args which have been guaranteed to be constants during the
	// session lifetime by the use of the GuaranteeConstOp (or ConstantOp).
	TPUCompileMetadataProto_Arg_GUARANTEED_CONSTANT TPUCompileMetadataProto_Arg_Kind = 3
)

func (TPUCompileMetadataProto_Arg_Kind) Descriptor

func (TPUCompileMetadataProto_Arg_Kind) Enum

func (TPUCompileMetadataProto_Arg_Kind) EnumDescriptor deprecated

func (TPUCompileMetadataProto_Arg_Kind) EnumDescriptor() ([]byte, []int)

Deprecated: Use TPUCompileMetadataProto_Arg_Kind.Descriptor instead.

func (TPUCompileMetadataProto_Arg_Kind) Number

func (TPUCompileMetadataProto_Arg_Kind) String

func (TPUCompileMetadataProto_Arg_Kind) Type

type TPUCompileMetadataProto_Retval

// TPUCompileMetadataProto_Retval describes one return value from a computation.
type TPUCompileMetadataProto_Retval struct {

	// The cross-core sharding of this return value within each replica, e.g.,
	// assigning to one core, or replicate across all cores.
	Sharding *data.OpSharding `protobuf:"bytes,1,opt,name=sharding,proto3" json:"sharding,omitempty"`
	// contains filtered or unexported fields
}

Description of the return values from a computation.

func (*TPUCompileMetadataProto_Retval) Descriptor deprecated

func (*TPUCompileMetadataProto_Retval) Descriptor() ([]byte, []int)

Deprecated: Use TPUCompileMetadataProto_Retval.ProtoReflect.Descriptor instead.

func (*TPUCompileMetadataProto_Retval) GetSharding

func (*TPUCompileMetadataProto_Retval) ProtoMessage

func (*TPUCompileMetadataProto_Retval) ProtoMessage()

func (*TPUCompileMetadataProto_Retval) ProtoReflect

func (*TPUCompileMetadataProto_Retval) Reset

func (x *TPUCompileMetadataProto_Retval) Reset()

func (*TPUCompileMetadataProto_Retval) String

type TPUCompileOptions

// TPUCompileOptions is a stable protobuf for TPU compilation options, suitable
// for persistent storage; it must remain backward compatible.
type TPUCompileOptions struct {
	// NOTE(review): presumably the operand precision used by the TPU matrix
	// unit; see TPUCompileOptions_Precision for the allowed values.
	MatrixUnitOperandPrecision TPUCompileOptions_Precision `` /* 192-byte string literal not displayed */
	// contains filtered or unexported fields
}

Stable protobuf for TPU compilation options, suitable for persistent storage. This proto must remain backward compatible as it is maintained. TODO(timshen): investigate and migrate other options from TPUCompileMetadataProto.

func (*TPUCompileOptions) Descriptor deprecated

func (*TPUCompileOptions) Descriptor() ([]byte, []int)

Deprecated: Use TPUCompileOptions.ProtoReflect.Descriptor instead.

func (*TPUCompileOptions) GetMatrixUnitOperandPrecision

func (x *TPUCompileOptions) GetMatrixUnitOperandPrecision() TPUCompileOptions_Precision

func (*TPUCompileOptions) ProtoMessage

func (*TPUCompileOptions) ProtoMessage()

func (*TPUCompileOptions) ProtoReflect

func (x *TPUCompileOptions) ProtoReflect() protoreflect.Message

func (*TPUCompileOptions) Reset

func (x *TPUCompileOptions) Reset()

func (*TPUCompileOptions) String

func (x *TPUCompileOptions) String() string

type TPUCompileOptions_Precision

// TPUCompileOptions_Precision enumerates the precisions selectable through
// TPUCompileOptions.MatrixUnitOperandPrecision. Values are proto wire numbers.
type TPUCompileOptions_Precision int32
const (
	// DEFAULT is the zero value, so it is what an unset
	// MatrixUnitOperandPrecision reads as.
	TPUCompileOptions_DEFAULT        TPUCompileOptions_Precision = 0
	TPUCompileOptions_BFLOAT16       TPUCompileOptions_Precision = 1
	TPUCompileOptions_FLOAT32        TPUCompileOptions_Precision = 2
	TPUCompileOptions_TENSOR_FLOAT32 TPUCompileOptions_Precision = 3
)

func (TPUCompileOptions_Precision) Descriptor

func (TPUCompileOptions_Precision) Enum

func (TPUCompileOptions_Precision) EnumDescriptor deprecated

func (TPUCompileOptions_Precision) EnumDescriptor() ([]byte, []int)

Deprecated: Use TPUCompileOptions_Precision.Descriptor instead.

func (TPUCompileOptions_Precision) Number

func (TPUCompileOptions_Precision) String

func (TPUCompileOptions_Precision) Type

type TPUEmbeddingConfiguration

// TPUEmbeddingConfiguration configures the TPU embedding layer: its tables and
// input features, execution mode, host/TensorCore topology, id sharding
// strategy, and the directory for embedding lookup statistics.
type TPUEmbeddingConfiguration struct {
	// Descriptions of the embedding tables, and which embedding layer program
	// (inference, training or backward pass only) is run.
	TableDescriptor []*TPUEmbeddingConfiguration_TableDescriptor `protobuf:"bytes,1,rep,name=table_descriptor,json=tableDescriptor,proto3" json:"table_descriptor,omitempty"`
	Mode            TPUEmbeddingConfiguration_Mode               `protobuf:"varint,2,opt,name=mode,proto3,enum=tensorflow.tpu.TPUEmbeddingConfiguration_Mode" json:"mode,omitempty"`
	// Number of samples in each batch of embedding layer activations sent to
	// the TensorCore.
	BatchSizePerTensorCore int32 `` /* 134-byte string literal not displayed */
	// Number of TPU hosts used for inference/training.
	NumHosts int32 `protobuf:"varint,4,opt,name=num_hosts,json=numHosts,proto3" json:"num_hosts,omitempty"`
	// Number of TensorCores used for inference/training. ShardingStrategy
	// selects how the ids of each embedding table are assigned to hosts.
	NumTensorCores   int32                                      `protobuf:"varint,5,opt,name=num_tensor_cores,json=numTensorCores,proto3" json:"num_tensor_cores,omitempty"`
	ShardingStrategy TPUEmbeddingConfiguration_ShardingStrategy `` /* 173-byte string literal not displayed */
	// This parameter determines if the execution of the sparse core will be
	// pipelined with that of the TensorCore. This parameter only affects results
	// when mode=TRAINING. If mode=INFERENCE or BACKWARD_PASS_ONLY, this parameter
	// does not affect execution and hence, is a don't care value.
	//
	// false: The execution of the sparse core is not pipelined with that of the
	// TensorCore. The forward pass of every step on the sparse core is executed
	// only after the backward pass of the previous step is complete. And the
	// backward pass on the sparse core is executed only after the embedding
	// gradients have been computed on the TensorCore on every step. This ensures
	// that the activations on every step observe the gradient updates from the
	// previous step on both the sparse core and the TensorCore.
	//
	// true: The execution of the sparse core is pipelined with that of the
	// TensorCore. The forward pass of every step on the sparse core can be
	// executed after the forward pass of the previous step is complete without
	// waiting for the backward pass. This improves the utilization of the sparse
	// core allowing it to process step N+1 while the embedding gradients for step
	// N are computed on the TensorCore. The backward pass of every step on the
	// sparse core is executed directly after the forward pass for the next step
	// is complete. The drawback is that embedding activations for step N+1 do not
	// observe the embedding gradient updates from step N. This could affect model
	// quality if step N and N+1 involve the same set of embedding IDs. However,
	// since the embedding updates are sparse, this is generally not considered a
	// problem.
	PipelineExecutionWithTensorCore bool `` /* 161-byte string literal not displayed */
	// Directory where embedding lookup statistics are stored. These statistics
	// summarize information about the inputs to the embedding lookup
	// operation, in particular, the average number of embedding IDs per example
	// and how well the embedding IDs are load balanced across the system. The
	// lookup statistics are used during TPU initialization for embedding table
	// partitioning. Collection of lookup statistics is done at runtime by
	// profiling the embedding inputs: only 3% of input samples are profiled to
	// minimize host CPU overhead. Once a suitable number of samples are
	// profiled, the lookup statistics are saved to table-specific files in the
	// profile data directory generally at the end of a TPU training loop. The
	// filename corresponding to each table is obtained by hashing table specific
	// parameters (e.g., table name and number of features) and global
	// configuration parameters (e.g., sharding strategy and TPU worker task
	// count). The same profile data directory can be shared amongst several
	// models to reuse embedding lookup statistics.
	ProfileDataDirectory string `protobuf:"bytes,9,opt,name=profile_data_directory,json=profileDataDirectory,proto3" json:"profile_data_directory,omitempty"`
	// If the feature_descriptor field is populated, the model should NOT populate
	// TableDescriptor.num_features and batch_size_per_tensor_core. These two
	// fields will be auto-populated by the TPUEmbedding rewrite passes.
	// SpmdSharding passes the number of cores per replica to TPUEmbedding when
	// model parallelism is used on the TensorCore.
	FeatureDescriptor []*TPUEmbeddingConfiguration_FeatureDescriptor `protobuf:"bytes,10,rep,name=feature_descriptor,json=featureDescriptor,proto3" json:"feature_descriptor,omitempty"`
	SpmdSharding      *TPUEmbeddingConfiguration_SpmdSharding        `protobuf:"bytes,11,opt,name=spmd_sharding,json=spmdSharding,proto3" json:"spmd_sharding,omitempty"`
	// contains filtered or unexported fields
}

func (*TPUEmbeddingConfiguration) Descriptor deprecated

func (*TPUEmbeddingConfiguration) Descriptor() ([]byte, []int)

Deprecated: Use TPUEmbeddingConfiguration.ProtoReflect.Descriptor instead.

func (*TPUEmbeddingConfiguration) GetBatchSizePerTensorCore

func (x *TPUEmbeddingConfiguration) GetBatchSizePerTensorCore() int32

func (*TPUEmbeddingConfiguration) GetFeatureDescriptor

func (*TPUEmbeddingConfiguration) GetMode

func (*TPUEmbeddingConfiguration) GetNumHosts

func (x *TPUEmbeddingConfiguration) GetNumHosts() int32

func (*TPUEmbeddingConfiguration) GetNumTensorCores

func (x *TPUEmbeddingConfiguration) GetNumTensorCores() int32

func (*TPUEmbeddingConfiguration) GetPipelineExecutionWithTensorCore

func (x *TPUEmbeddingConfiguration) GetPipelineExecutionWithTensorCore() bool

func (*TPUEmbeddingConfiguration) GetProfileDataDirectory

func (x *TPUEmbeddingConfiguration) GetProfileDataDirectory() string

func (*TPUEmbeddingConfiguration) GetShardingStrategy

func (*TPUEmbeddingConfiguration) GetSpmdSharding

func (*TPUEmbeddingConfiguration) GetTableDescriptor

func (*TPUEmbeddingConfiguration) ProtoMessage

func (*TPUEmbeddingConfiguration) ProtoMessage()

func (*TPUEmbeddingConfiguration) ProtoReflect

func (*TPUEmbeddingConfiguration) Reset

func (x *TPUEmbeddingConfiguration) Reset()

func (*TPUEmbeddingConfiguration) String

func (x *TPUEmbeddingConfiguration) String() string

type TPUEmbeddingConfiguration_FeatureDescriptor

// TPUEmbeddingConfiguration_FeatureDescriptor describes one input feature: its
// name, the table it is looked up in, and its static input shape.
type TPUEmbeddingConfiguration_FeatureDescriptor struct {

	// Name of the input feature.
	Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	// Index of the corresponding table in the TableDescriptor list.
	TableId int32 `protobuf:"varint,2,opt,name=table_id,json=tableId,proto3" json:"table_id,omitempty"`
	// Static shape of the inputs (excluding the reduction axis). Note that
	// the shape of the actual inputs provided using the infeed op must be
	// strictly smaller than input_shape. The outputs received at the TensorCore
	// will have rank = input_shape.size() + 1. The innermost axis corresponds
	// to the embedding dimension. If the input has shape [m, n, k] (excluding
	// the reduction axis) and the embedding dimension is d, the output received
	// at the TensorCore will have shape [m, n, k, d].
	InputShape []int32 `protobuf:"varint,3,rep,packed,name=input_shape,json=inputShape,proto3" json:"input_shape,omitempty"`
	// contains filtered or unexported fields
}

Description of different input features.

func (*TPUEmbeddingConfiguration_FeatureDescriptor) Descriptor deprecated

Deprecated: Use TPUEmbeddingConfiguration_FeatureDescriptor.ProtoReflect.Descriptor instead.

func (*TPUEmbeddingConfiguration_FeatureDescriptor) GetInputShape

func (*TPUEmbeddingConfiguration_FeatureDescriptor) GetName

func (*TPUEmbeddingConfiguration_FeatureDescriptor) GetTableId

func (*TPUEmbeddingConfiguration_FeatureDescriptor) ProtoMessage

func (*TPUEmbeddingConfiguration_FeatureDescriptor) ProtoReflect

func (*TPUEmbeddingConfiguration_FeatureDescriptor) Reset

func (*TPUEmbeddingConfiguration_FeatureDescriptor) String

type TPUEmbeddingConfiguration_Mode

// TPUEmbeddingConfiguration_Mode selects which embedding layer program runs:
// inference (forward pass only), training (forward and backward passes), or
// the backward pass only.
type TPUEmbeddingConfiguration_Mode int32

Mode selects whether the embedding layer program runs for inference (forward pass only), training (both forward and backward passes), or only the backward pass.

const (
	// UNSPECIFIED is the zero value. INFERENCE runs the forward pass only,
	// TRAINING runs both forward and backward passes, and BACKWARD_PASS_ONLY
	// runs only the backward pass. Values are proto wire numbers.
	TPUEmbeddingConfiguration_UNSPECIFIED        TPUEmbeddingConfiguration_Mode = 0
	TPUEmbeddingConfiguration_INFERENCE          TPUEmbeddingConfiguration_Mode = 1
	TPUEmbeddingConfiguration_TRAINING           TPUEmbeddingConfiguration_Mode = 2
	TPUEmbeddingConfiguration_BACKWARD_PASS_ONLY TPUEmbeddingConfiguration_Mode = 3
)

func (TPUEmbeddingConfiguration_Mode) Descriptor

func (TPUEmbeddingConfiguration_Mode) Enum

func (TPUEmbeddingConfiguration_Mode) EnumDescriptor deprecated

func (TPUEmbeddingConfiguration_Mode) EnumDescriptor() ([]byte, []int)

Deprecated: Use TPUEmbeddingConfiguration_Mode.Descriptor instead.

func (TPUEmbeddingConfiguration_Mode) Number

func (TPUEmbeddingConfiguration_Mode) String

func (TPUEmbeddingConfiguration_Mode) Type

type TPUEmbeddingConfiguration_ShardingStrategy

// TPUEmbeddingConfiguration_ShardingStrategy selects how the ids of each
// embedding table are assigned to hosts: in contiguous ranges ("div") or by
// id % num_hosts ("mod").
type TPUEmbeddingConfiguration_ShardingStrategy int32

Sharding strategy of the embedding tables among the hosts. If the sharding_strategy is "mod", each id is assigned to host "id % num_hosts". For instance, 13 ids are split across 5 hosts as: [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]. If the sharding_strategy is "div", ids are assigned to hosts in a contiguous manner. In this case, 13 ids are split across 5 hosts as: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]. In both strategies, if the number of hosts does not evenly divide the id space (the vocabulary size), each of the first "table_descriptor.vocabulary_size % num_hosts" hosts is assigned one more id. This partitioning strategy exactly follows that of the embedding_lookup TensorFlow function in tensorflow/python/ops/embedding_ops.py.

const (
	// DIV_DEFAULT is the zero value (and therefore the default) and assigns
	// contiguous id ranges to hosts; MOD assigns each id to host
	// id % num_hosts. Values are proto wire numbers.
	TPUEmbeddingConfiguration_DIV_DEFAULT TPUEmbeddingConfiguration_ShardingStrategy = 0
	TPUEmbeddingConfiguration_MOD         TPUEmbeddingConfiguration_ShardingStrategy = 1
)

func (TPUEmbeddingConfiguration_ShardingStrategy) Descriptor

func (TPUEmbeddingConfiguration_ShardingStrategy) Enum

func (TPUEmbeddingConfiguration_ShardingStrategy) EnumDescriptor deprecated

func (TPUEmbeddingConfiguration_ShardingStrategy) EnumDescriptor() ([]byte, []int)

Deprecated: Use TPUEmbeddingConfiguration_ShardingStrategy.Descriptor instead.

func (TPUEmbeddingConfiguration_ShardingStrategy) Number

func (TPUEmbeddingConfiguration_ShardingStrategy) String

func (TPUEmbeddingConfiguration_ShardingStrategy) Type

type TPUEmbeddingConfiguration_SpmdSharding

// TPUEmbeddingConfiguration_SpmdSharding is the SPMD (Single Program Multiple
// Data) sharding configuration for TPUEmbedding. When the TensorCore uses model
// parallelism, it passes the number of cores per replica so the TF/XLA bridge
// can compute the right shapes.
type TPUEmbeddingConfiguration_SpmdSharding struct {

	// Whether SPMD sharding is enabled.
	Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"`
	// Number of cores per replica.
	NumCoresPerReplica int32 `protobuf:"varint,2,opt,name=num_cores_per_replica,json=numCoresPerReplica,proto3" json:"num_cores_per_replica,omitempty"`
	// contains filtered or unexported fields
}

SPMD (Single Program Multiple Data) sharding configuration for TPUEmbedding. When model parallelism is used on the TensorCore, the number of cores per replica must be passed to TPUEmbedding so that the right shapes can be computed in the TF/XLA bridge.

func (*TPUEmbeddingConfiguration_SpmdSharding) Descriptor deprecated

func (*TPUEmbeddingConfiguration_SpmdSharding) Descriptor() ([]byte, []int)

Deprecated: Use TPUEmbeddingConfiguration_SpmdSharding.ProtoReflect.Descriptor instead.

func (*TPUEmbeddingConfiguration_SpmdSharding) GetEnabled

func (*TPUEmbeddingConfiguration_SpmdSharding) GetNumCoresPerReplica

func (x *TPUEmbeddingConfiguration_SpmdSharding) GetNumCoresPerReplica() int32

func (*TPUEmbeddingConfiguration_SpmdSharding) ProtoMessage

func (*TPUEmbeddingConfiguration_SpmdSharding) ProtoReflect

func (*TPUEmbeddingConfiguration_SpmdSharding) Reset

func (*TPUEmbeddingConfiguration_SpmdSharding) String

type TPUEmbeddingConfiguration_TableDescriptor

// TPUEmbeddingConfiguration_TableDescriptor describes one embedding table: its
// name, size, width, number of mapped features and optimizer parameters.
type TPUEmbeddingConfiguration_TableDescriptor struct {

	// Name of the table.
	Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	// Size of the vocabulary (i.e., number of rows) in the table.
	VocabularySize int64 `protobuf:"varint,2,opt,name=vocabulary_size,json=vocabularySize,proto3" json:"vocabulary_size,omitempty"`
	// The embedding dimension (i.e., the width of the embedding table).
	Dimension int32 `protobuf:"varint,3,opt,name=dimension,proto3" json:"dimension,omitempty"`
	// Number of features mapped to this table.
	// Leave unset when TPUEmbeddingConfiguration.feature_descriptor is
	// populated; the TPUEmbedding rewrite passes fill it in.
	NumFeatures int32 `protobuf:"varint,4,opt,name=num_features,json=numFeatures,proto3" json:"num_features,omitempty"`
	// Details of the learning algorithm used to update the embedding
	// parameters.
	OptimizationParameters *OptimizationParameters `` /* 127-byte string literal not displayed */
	// contains filtered or unexported fields
}

Description of the various embedding tables.

func (*TPUEmbeddingConfiguration_TableDescriptor) Descriptor deprecated

func (*TPUEmbeddingConfiguration_TableDescriptor) Descriptor() ([]byte, []int)

Deprecated: Use TPUEmbeddingConfiguration_TableDescriptor.ProtoReflect.Descriptor instead.

func (*TPUEmbeddingConfiguration_TableDescriptor) GetDimension

func (*TPUEmbeddingConfiguration_TableDescriptor) GetName

func (*TPUEmbeddingConfiguration_TableDescriptor) GetNumFeatures

func (*TPUEmbeddingConfiguration_TableDescriptor) GetOptimizationParameters

func (x *TPUEmbeddingConfiguration_TableDescriptor) GetOptimizationParameters() *OptimizationParameters

func (*TPUEmbeddingConfiguration_TableDescriptor) GetVocabularySize

func (x *TPUEmbeddingConfiguration_TableDescriptor) GetVocabularySize() int64

func (*TPUEmbeddingConfiguration_TableDescriptor) ProtoMessage

func (*TPUEmbeddingConfiguration_TableDescriptor) ProtoReflect

func (*TPUEmbeddingConfiguration_TableDescriptor) Reset

func (*TPUEmbeddingConfiguration_TableDescriptor) String

type TPUEmbeddingError

type TPUEmbeddingError struct {
	// contains filtered or unexported fields
}

A placeholder message that is used to define a unique Status payload URL for TPU embedding errors.

func (*TPUEmbeddingError) Descriptor deprecated

func (*TPUEmbeddingError) Descriptor() ([]byte, []int)

Deprecated: Use TPUEmbeddingError.ProtoReflect.Descriptor instead.

func (*TPUEmbeddingError) ProtoMessage

func (*TPUEmbeddingError) ProtoMessage()

func (*TPUEmbeddingError) ProtoReflect

func (x *TPUEmbeddingError) ProtoReflect() protoreflect.Message

func (*TPUEmbeddingError) Reset

func (x *TPUEmbeddingError) Reset()

func (*TPUEmbeddingError) String

func (x *TPUEmbeddingError) String() string

type TPUHardwareFeature

type TPUHardwareFeature struct {
	EmbeddingFeature TPUHardwareFeature_EmbeddingFeature `` /* 166-byte string literal not displayed */
	// contains filtered or unexported fields
}

Describes features of a TPU.

func (*TPUHardwareFeature) Descriptor deprecated

func (*TPUHardwareFeature) Descriptor() ([]byte, []int)

Deprecated: Use TPUHardwareFeature.ProtoReflect.Descriptor instead.

func (*TPUHardwareFeature) GetEmbeddingFeature

func (*TPUHardwareFeature) ProtoMessage

func (*TPUHardwareFeature) ProtoMessage()

func (*TPUHardwareFeature) ProtoReflect

func (x *TPUHardwareFeature) ProtoReflect() protoreflect.Message

func (*TPUHardwareFeature) Reset

func (x *TPUHardwareFeature) Reset()

func (*TPUHardwareFeature) String

func (x *TPUHardwareFeature) String() string

type TPUHardwareFeature_EmbeddingFeature

type TPUHardwareFeature_EmbeddingFeature int32

Embedding feature of a TPU.

const (
	// No embedding lookup accelerator available on the tpu.
	TPUHardwareFeature_UNSUPPORTED TPUHardwareFeature_EmbeddingFeature = 0
	// Embedding lookup accelerator V1. The embedding lookup operation can only
	// be placed at the beginning of computation. Only one instance of embedding
	// lookup layer is allowed.
	TPUHardwareFeature_V1 TPUHardwareFeature_EmbeddingFeature = 1
	// Embedding lookup accelerator V2. The embedding lookup operation can be
	// placed anywhere in the computation. Multiple instances of the embedding
	// lookup layer are allowed.
	TPUHardwareFeature_V2 TPUHardwareFeature_EmbeddingFeature = 2
)

func (TPUHardwareFeature_EmbeddingFeature) Descriptor

func (TPUHardwareFeature_EmbeddingFeature) Enum

func (TPUHardwareFeature_EmbeddingFeature) EnumDescriptor deprecated

func (TPUHardwareFeature_EmbeddingFeature) EnumDescriptor() ([]byte, []int)

Deprecated: Use TPUHardwareFeature_EmbeddingFeature.Descriptor instead.

func (TPUHardwareFeature_EmbeddingFeature) Number

func (TPUHardwareFeature_EmbeddingFeature) String

func (TPUHardwareFeature_EmbeddingFeature) Type

type TopologyProto

type TopologyProto struct {

	// The dimensions of the TPU topology, in cores. Typically, this is a 4D
	// topology [x, y, z, core], where the major dimensions correspond to TPU
	// chips, and the minor dimension describes the number of cores on a multicore
	// chip.
	MeshShape []int32 `protobuf:"varint,1,rep,packed,name=mesh_shape,json=meshShape,proto3" json:"mesh_shape,omitempty"`
	// Number of TensorFlow tasks in the cluster.
	NumTasks int32 `protobuf:"varint,2,opt,name=num_tasks,json=numTasks,proto3" json:"num_tasks,omitempty"`
	// Number of TPU devices per task.
	NumTpuDevicesPerTask int32 `` /* 128-byte string literal not displayed */
	// A flattened rank 3 int32 array with shape
	// [num_tasks, num_tpu_devices_per_task, len(mesh_shape)].
	// `tasks` is the number of tasks in the TPU cluster, `devices` is the number
	// of TPU devices per task, and the minor dimension corresponds to a position
	// in the TPU mesh topology. Each entry [task, device, axis] gives the
	// `axis`-th coordinate in the topology of a task/device pair.
	DeviceCoordinates []int32 `protobuf:"varint,4,rep,packed,name=device_coordinates,json=deviceCoordinates,proto3" json:"device_coordinates,omitempty"`
	// TPU supported features.
	TpuHardwareFeature *TPUHardwareFeature `protobuf:"bytes,5,opt,name=tpu_hardware_feature,json=tpuHardwareFeature,proto3" json:"tpu_hardware_feature,omitempty"`
	// contains filtered or unexported fields
}

Describes the geometry of a TPU mesh.

func (*TopologyProto) Descriptor deprecated

func (*TopologyProto) Descriptor() ([]byte, []int)

Deprecated: Use TopologyProto.ProtoReflect.Descriptor instead.

func (*TopologyProto) GetDeviceCoordinates

func (x *TopologyProto) GetDeviceCoordinates() []int32

func (*TopologyProto) GetMeshShape

func (x *TopologyProto) GetMeshShape() []int32

func (*TopologyProto) GetNumTasks

func (x *TopologyProto) GetNumTasks() int32

func (*TopologyProto) GetNumTpuDevicesPerTask

func (x *TopologyProto) GetNumTpuDevicesPerTask() int32

func (*TopologyProto) GetTpuHardwareFeature

func (x *TopologyProto) GetTpuHardwareFeature() *TPUHardwareFeature

func (*TopologyProto) ProtoMessage

func (*TopologyProto) ProtoMessage()

func (*TopologyProto) ProtoReflect

func (x *TopologyProto) ProtoReflect() protoreflect.Message

func (*TopologyProto) Reset

func (x *TopologyProto) Reset()

func (*TopologyProto) String

func (x *TopologyProto) String() string

type UserDefinedProgramParameters

type UserDefinedProgramParameters struct {
	Program *service.HloModuleProto `protobuf:"bytes,1,opt,name=program,proto3" json:"program,omitempty"`
	// contains filtered or unexported fields
}

A user-defined optimizer. The contained HLO program must take the following arguments in the following order:

  1. gradients
  2. table weights
  3. slot variables
  4. an optional scalar input that is passed in via the dynamic learning rate mechanism.

It must return/end in a tuple op that contains the following values in the following order:

  1. new table values
  2. new slot variable values

The program must have shape (1,1) with dtype float32 throughout and only use HLO ops that operate elementwise (e.g., no reduce, no variables, no control flow, and no broadcasting outside of the single scalar input). The HLO program should be written as if it were a dense update. It will be called on each row that needs an update and will be applied elementwise.

func (*UserDefinedProgramParameters) Descriptor deprecated

func (*UserDefinedProgramParameters) Descriptor() ([]byte, []int)

Deprecated: Use UserDefinedProgramParameters.ProtoReflect.Descriptor instead.

func (*UserDefinedProgramParameters) GetProgram

func (*UserDefinedProgramParameters) ProtoMessage

func (*UserDefinedProgramParameters) ProtoMessage()

func (*UserDefinedProgramParameters) ProtoReflect

func (*UserDefinedProgramParameters) Reset

func (x *UserDefinedProgramParameters) Reset()

func (*UserDefinedProgramParameters) String

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL