Documentation ¶
Index ¶
- Variables
- type AdadeltaParameters
- func (*AdadeltaParameters) Descriptor() ([]byte, []int)deprecated
- func (x *AdadeltaParameters) GetEpsilon() float32
- func (x *AdadeltaParameters) GetRho() float32
- func (*AdadeltaParameters) ProtoMessage()
- func (x *AdadeltaParameters) ProtoReflect() protoreflect.Message
- func (x *AdadeltaParameters) Reset()
- func (x *AdadeltaParameters) String() string
- type AdagradMomentumParameters
- func (*AdagradMomentumParameters) Descriptor() ([]byte, []int)deprecated
- func (x *AdagradMomentumParameters) GetBeta2() float32
- func (x *AdagradMomentumParameters) GetEpsilon() float32
- func (x *AdagradMomentumParameters) GetExponent() float32
- func (x *AdagradMomentumParameters) GetMomentum() float32
- func (x *AdagradMomentumParameters) GetUseNesterov() bool
- func (*AdagradMomentumParameters) ProtoMessage()
- func (x *AdagradMomentumParameters) ProtoReflect() protoreflect.Message
- func (x *AdagradMomentumParameters) Reset()
- func (x *AdagradMomentumParameters) String() string
- type AdagradParameters
- type AdamParameters
- func (*AdamParameters) Descriptor() ([]byte, []int)deprecated
- func (x *AdamParameters) GetBeta1() float32
- func (x *AdamParameters) GetBeta2() float32
- func (x *AdamParameters) GetEpsilon() float32
- func (x *AdamParameters) GetUseNonLazyAdam() bool
- func (x *AdamParameters) GetUseSumInsideSqrt() bool
- func (*AdamParameters) ProtoMessage()
- func (x *AdamParameters) ProtoReflect() protoreflect.Message
- func (x *AdamParameters) Reset()
- func (x *AdamParameters) String() string
- type AssignParameters
- type BoundedAdagradParameters
- func (*BoundedAdagradParameters) Descriptor() ([]byte, []int)deprecated
- func (x *BoundedAdagradParameters) GetMaxAccumulator() float32
- func (x *BoundedAdagradParameters) GetMaxVarUpdate() float32
- func (x *BoundedAdagradParameters) GetUpdateAccumulatorFirst() bool
- func (*BoundedAdagradParameters) ProtoMessage()
- func (x *BoundedAdagradParameters) ProtoReflect() protoreflect.Message
- func (x *BoundedAdagradParameters) Reset()
- func (x *BoundedAdagradParameters) String() string
- type CenteredRmsPropParameters
- func (*CenteredRmsPropParameters) Descriptor() ([]byte, []int)deprecated
- func (x *CenteredRmsPropParameters) GetEpsilon() float32
- func (x *CenteredRmsPropParameters) GetMomentum() float32
- func (x *CenteredRmsPropParameters) GetRho() float32
- func (*CenteredRmsPropParameters) ProtoMessage()
- func (x *CenteredRmsPropParameters) ProtoReflect() protoreflect.Message
- func (x *CenteredRmsPropParameters) Reset()
- func (x *CenteredRmsPropParameters) String() string
- type ClippingLimits
- func (*ClippingLimits) Descriptor() ([]byte, []int)deprecated
- func (x *ClippingLimits) GetLower() *wrapperspb.FloatValue
- func (x *ClippingLimits) GetUpper() *wrapperspb.FloatValue
- func (*ClippingLimits) ProtoMessage()
- func (x *ClippingLimits) ProtoReflect() protoreflect.Message
- func (x *ClippingLimits) Reset()
- func (x *ClippingLimits) String() string
- type CompilationResultProto
- func (*CompilationResultProto) Descriptor() ([]byte, []int)deprecated
- func (x *CompilationResultProto) GetErrorCode() CompilationResultProto_ErrorCode
- func (x *CompilationResultProto) GetHloProtos() []*service.HloProto
- func (x *CompilationResultProto) GetStatusCode() protobuf.Code
- func (x *CompilationResultProto) GetStatusErrorMessage() string
- func (*CompilationResultProto) ProtoMessage()
- func (x *CompilationResultProto) ProtoReflect() protoreflect.Message
- func (x *CompilationResultProto) Reset()
- func (x *CompilationResultProto) String() string
- type CompilationResultProto_ErrorCode
- func (CompilationResultProto_ErrorCode) Descriptor() protoreflect.EnumDescriptor
- func (x CompilationResultProto_ErrorCode) Enum() *CompilationResultProto_ErrorCode
- func (CompilationResultProto_ErrorCode) EnumDescriptor() ([]byte, []int)deprecated
- func (x CompilationResultProto_ErrorCode) Number() protoreflect.EnumNumber
- func (x CompilationResultProto_ErrorCode) String() string
- func (CompilationResultProto_ErrorCode) Type() protoreflect.EnumType
- type DynamicLearningRate
- func (*DynamicLearningRate) Descriptor() ([]byte, []int)deprecated
- func (x *DynamicLearningRate) GetTag() int32
- func (*DynamicLearningRate) ProtoMessage()
- func (x *DynamicLearningRate) ProtoReflect() protoreflect.Message
- func (x *DynamicLearningRate) Reset()
- func (x *DynamicLearningRate) String() string
- type FrequencyEstimatorParameters
- func (*FrequencyEstimatorParameters) Descriptor() ([]byte, []int)deprecated
- func (x *FrequencyEstimatorParameters) GetMaxDelta() float32
- func (x *FrequencyEstimatorParameters) GetOutlierThreshold() float32
- func (x *FrequencyEstimatorParameters) GetTau() float32
- func (x *FrequencyEstimatorParameters) GetWeightExponent() float32
- func (*FrequencyEstimatorParameters) ProtoMessage()
- func (x *FrequencyEstimatorParameters) ProtoReflect() protoreflect.Message
- func (x *FrequencyEstimatorParameters) Reset()
- func (x *FrequencyEstimatorParameters) String() string
- type FtrlParameters
- func (*FtrlParameters) Descriptor() ([]byte, []int)deprecated
- func (x *FtrlParameters) GetAllowZeroAccumulator() booldeprecated
- func (x *FtrlParameters) GetBeta() float32
- func (x *FtrlParameters) GetL1() float32
- func (x *FtrlParameters) GetL2() float32
- func (x *FtrlParameters) GetLrPower() float32
- func (x *FtrlParameters) GetMultiplyLinearByLr() bool
- func (*FtrlParameters) ProtoMessage()
- func (x *FtrlParameters) ProtoReflect() protoreflect.Message
- func (x *FtrlParameters) Reset()
- func (x *FtrlParameters) String() string
- type GradientAccumulationStatus
- type GradientAccumulationStatus_Status
- func (GradientAccumulationStatus_Status) Descriptor() protoreflect.EnumDescriptor
- func (x GradientAccumulationStatus_Status) Enum() *GradientAccumulationStatus_Status
- func (GradientAccumulationStatus_Status) EnumDescriptor() ([]byte, []int)deprecated
- func (x GradientAccumulationStatus_Status) Number() protoreflect.EnumNumber
- func (x GradientAccumulationStatus_Status) String() string
- func (GradientAccumulationStatus_Status) Type() protoreflect.EnumType
- type HotIdReplicationConfiguration
- func (*HotIdReplicationConfiguration) Descriptor() ([]byte, []int)deprecated
- func (x *HotIdReplicationConfiguration) GetStatus() HotIdReplicationConfiguration_Status
- func (*HotIdReplicationConfiguration) ProtoMessage()
- func (x *HotIdReplicationConfiguration) ProtoReflect() protoreflect.Message
- func (x *HotIdReplicationConfiguration) Reset()
- func (x *HotIdReplicationConfiguration) String() string
- type HotIdReplicationConfiguration_Status
- func (HotIdReplicationConfiguration_Status) Descriptor() protoreflect.EnumDescriptor
- func (x HotIdReplicationConfiguration_Status) Enum() *HotIdReplicationConfiguration_Status
- func (HotIdReplicationConfiguration_Status) EnumDescriptor() ([]byte, []int)deprecated
- func (x HotIdReplicationConfiguration_Status) Number() protoreflect.EnumNumber
- func (x HotIdReplicationConfiguration_Status) String() string
- func (HotIdReplicationConfiguration_Status) Type() protoreflect.EnumType
- type LearningRate
- func (*LearningRate) Descriptor() ([]byte, []int)deprecated
- func (x *LearningRate) GetConstant() float32
- func (x *LearningRate) GetDynamic() *DynamicLearningRate
- func (m *LearningRate) GetLearningRate() isLearningRate_LearningRate
- func (*LearningRate) ProtoMessage()
- func (x *LearningRate) ProtoReflect() protoreflect.Message
- func (x *LearningRate) Reset()
- func (x *LearningRate) String() string
- type LearningRate_Constant
- type LearningRate_Dynamic
- type LowDimensionalPackingStatus
- type LowDimensionalPackingStatus_Status
- func (LowDimensionalPackingStatus_Status) Descriptor() protoreflect.EnumDescriptor
- func (x LowDimensionalPackingStatus_Status) Enum() *LowDimensionalPackingStatus_Status
- func (LowDimensionalPackingStatus_Status) EnumDescriptor() ([]byte, []int)deprecated
- func (x LowDimensionalPackingStatus_Status) Number() protoreflect.EnumNumber
- func (x LowDimensionalPackingStatus_Status) String() string
- func (LowDimensionalPackingStatus_Status) Type() protoreflect.EnumType
- type MdlAdagradLightParameters
- func (*MdlAdagradLightParameters) Descriptor() ([]byte, []int)deprecated
- func (x *MdlAdagradLightParameters) GetBenefitRevisitScale() float32
- func (x *MdlAdagradLightParameters) GetHardLimitMinBenefit() bool
- func (x *MdlAdagradLightParameters) GetL2() float32
- func (x *MdlAdagradLightParameters) GetLrPower() float32
- func (x *MdlAdagradLightParameters) GetMaxEventBenefit() float32
- func (x *MdlAdagradLightParameters) GetMaxTotalBenefit() float32
- func (x *MdlAdagradLightParameters) GetMdlBenefitRampupCoeff() float32
- func (x *MdlAdagradLightParameters) GetMdlHardLimit() float32
- func (x *MdlAdagradLightParameters) GetMdlMinWeight() float32
- func (x *MdlAdagradLightParameters) GetMdlMixInMargin() float32
- func (x *MdlAdagradLightParameters) GetMdlRegularize() bool
- func (x *MdlAdagradLightParameters) GetMinServableMdlBenefit() float32
- func (*MdlAdagradLightParameters) ProtoMessage()
- func (x *MdlAdagradLightParameters) ProtoReflect() protoreflect.Message
- func (x *MdlAdagradLightParameters) Reset()
- func (x *MdlAdagradLightParameters) String() string
- type MomentumParameters
- func (*MomentumParameters) Descriptor() ([]byte, []int)deprecated
- func (x *MomentumParameters) GetMomentum() float32
- func (x *MomentumParameters) GetUseNesterov() bool
- func (*MomentumParameters) ProtoMessage()
- func (x *MomentumParameters) ProtoReflect() protoreflect.Message
- func (x *MomentumParameters) Reset()
- func (x *MomentumParameters) String() string
- type OnlineYogiParameters
- func (*OnlineYogiParameters) Descriptor() ([]byte, []int)deprecated
- func (x *OnlineYogiParameters) GetBeta2() float32
- func (x *OnlineYogiParameters) GetL1() float32
- func (x *OnlineYogiParameters) GetL2() float32
- func (*OnlineYogiParameters) ProtoMessage()
- func (x *OnlineYogiParameters) ProtoReflect() protoreflect.Message
- func (x *OnlineYogiParameters) Reset()
- func (x *OnlineYogiParameters) String() string
- type OptimizationParameters
- func (*OptimizationParameters) Descriptor() ([]byte, []int)deprecated
- func (x *OptimizationParameters) GetAdadelta() *AdadeltaParameters
- func (x *OptimizationParameters) GetAdagrad() *AdagradParameters
- func (x *OptimizationParameters) GetAdagradMomentum() *AdagradMomentumParameters
- func (x *OptimizationParameters) GetAdam() *AdamParameters
- func (x *OptimizationParameters) GetAssign() *AssignParameters
- func (x *OptimizationParameters) GetBoundedAdagrad() *BoundedAdagradParameters
- func (x *OptimizationParameters) GetCenteredRmsProp() *CenteredRmsPropParameters
- func (x *OptimizationParameters) GetClippingLimits() *ClippingLimits
- func (x *OptimizationParameters) GetFrequencyEstimator() *FrequencyEstimatorParameters
- func (x *OptimizationParameters) GetFtrl() *FtrlParameters
- func (x *OptimizationParameters) GetGradientAccumulationStatus() GradientAccumulationStatus_Status
- func (x *OptimizationParameters) GetGradientClippingLimits() *ClippingLimits
- func (x *OptimizationParameters) GetHotIdReplicationConfiguration() *HotIdReplicationConfiguration
- func (x *OptimizationParameters) GetLearningRate() *LearningRate
- func (x *OptimizationParameters) GetLowDimensionalPackingStatus() LowDimensionalPackingStatus_Status
- func (x *OptimizationParameters) GetMdlAdagradLight() *MdlAdagradLightParameters
- func (x *OptimizationParameters) GetMomentum() *MomentumParameters
- func (x *OptimizationParameters) GetMultiplyWeightDecayFactorByLearningRate() bool
- func (x *OptimizationParameters) GetOnlineYogi() *OnlineYogiParameters
- func (m *OptimizationParameters) GetParameters() isOptimizationParameters_Parameters
- func (x *OptimizationParameters) GetProximalAdagrad() *ProximalAdagradParameters
- func (x *OptimizationParameters) GetProximalYogi() *ProximalYogiParameters
- func (x *OptimizationParameters) GetRmsProp() *RmsPropParameters
- func (x *OptimizationParameters) GetSimulatedQuantization() *SimulatedQuantization
- func (x *OptimizationParameters) GetStochasticGradientDescent() *StochasticGradientDescentParameters
- func (x *OptimizationParameters) GetUserDefinedProgram() *UserDefinedProgramParameters
- func (x *OptimizationParameters) GetWeightDecayFactor() float32
- func (*OptimizationParameters) ProtoMessage()
- func (x *OptimizationParameters) ProtoReflect() protoreflect.Message
- func (x *OptimizationParameters) Reset()
- func (x *OptimizationParameters) String() string
- type OptimizationParameters_Adadelta
- type OptimizationParameters_Adagrad
- type OptimizationParameters_AdagradMomentum
- type OptimizationParameters_Adam
- type OptimizationParameters_Assign
- type OptimizationParameters_BoundedAdagrad
- type OptimizationParameters_CenteredRmsProp
- type OptimizationParameters_FrequencyEstimator
- type OptimizationParameters_Ftrl
- type OptimizationParameters_MdlAdagradLight
- type OptimizationParameters_Momentum
- type OptimizationParameters_OnlineYogi
- type OptimizationParameters_ProximalAdagrad
- type OptimizationParameters_ProximalYogi
- type OptimizationParameters_RmsProp
- type OptimizationParameters_StochasticGradientDescent
- type OptimizationParameters_UserDefinedProgram
- type PaddingMap
- func (*PaddingMap) Descriptor() ([]byte, []int)deprecated
- func (x *PaddingMap) GetArgIndex() int32
- func (x *PaddingMap) GetPaddingArgIndex() int32
- func (x *PaddingMap) GetShapeIndex() int32
- func (*PaddingMap) ProtoMessage()
- func (x *PaddingMap) ProtoReflect() protoreflect.Message
- func (x *PaddingMap) Reset()
- func (x *PaddingMap) String() string
- type ProximalAdagradParameters
- func (*ProximalAdagradParameters) Descriptor() ([]byte, []int)deprecated
- func (x *ProximalAdagradParameters) GetL1() float32
- func (x *ProximalAdagradParameters) GetL2() float32
- func (*ProximalAdagradParameters) ProtoMessage()
- func (x *ProximalAdagradParameters) ProtoReflect() protoreflect.Message
- func (x *ProximalAdagradParameters) Reset()
- func (x *ProximalAdagradParameters) String() string
- type ProximalYogiParameters
- func (*ProximalYogiParameters) Descriptor() ([]byte, []int)deprecated
- func (x *ProximalYogiParameters) GetBeta1() float32
- func (x *ProximalYogiParameters) GetBeta2() float32
- func (x *ProximalYogiParameters) GetEpsilon() float32
- func (x *ProximalYogiParameters) GetL1() float32
- func (x *ProximalYogiParameters) GetL2() float32
- func (*ProximalYogiParameters) ProtoMessage()
- func (x *ProximalYogiParameters) ProtoReflect() protoreflect.Message
- func (x *ProximalYogiParameters) Reset()
- func (x *ProximalYogiParameters) String() string
- type RmsPropParameters
- func (*RmsPropParameters) Descriptor() ([]byte, []int)deprecated
- func (x *RmsPropParameters) GetEpsilon() float32
- func (x *RmsPropParameters) GetMomentum() float32
- func (x *RmsPropParameters) GetRho() float32
- func (*RmsPropParameters) ProtoMessage()
- func (x *RmsPropParameters) ProtoReflect() protoreflect.Message
- func (x *RmsPropParameters) Reset()
- func (x *RmsPropParameters) String() string
- type SimulatedQuantization
- func (*SimulatedQuantization) Descriptor() ([]byte, []int)deprecated
- func (x *SimulatedQuantization) GetClippingLimits() *ClippingLimits
- func (x *SimulatedQuantization) GetEnabled() bool
- func (x *SimulatedQuantization) GetNumBuckets() int32
- func (*SimulatedQuantization) ProtoMessage()
- func (x *SimulatedQuantization) ProtoReflect() protoreflect.Message
- func (x *SimulatedQuantization) Reset()
- func (x *SimulatedQuantization) String() string
- type StateVariableSpecification
- func (*StateVariableSpecification) Descriptor() ([]byte, []int)deprecated
- func (x *StateVariableSpecification) GetFillWithConstant() *StateVariableSpecification_FillWithConstant
- func (x *StateVariableSpecification) GetName() string
- func (m *StateVariableSpecification) GetUsage() isStateVariableSpecification_Usage
- func (x *StateVariableSpecification) GetUserDefined() *StateVariableSpecification_UserDefined
- func (*StateVariableSpecification) ProtoMessage()
- func (x *StateVariableSpecification) ProtoReflect() protoreflect.Message
- func (x *StateVariableSpecification) Reset()
- func (x *StateVariableSpecification) String() string
- type StateVariableSpecification_FillWithConstant
- func (*StateVariableSpecification_FillWithConstant) Descriptor() ([]byte, []int)deprecated
- func (x *StateVariableSpecification_FillWithConstant) GetInitialValue() float64
- func (*StateVariableSpecification_FillWithConstant) ProtoMessage()
- func (x *StateVariableSpecification_FillWithConstant) ProtoReflect() protoreflect.Message
- func (x *StateVariableSpecification_FillWithConstant) Reset()
- func (x *StateVariableSpecification_FillWithConstant) String() string
- type StateVariableSpecification_FillWithConstant_
- type StateVariableSpecification_UserDefined
- func (*StateVariableSpecification_UserDefined) Descriptor() ([]byte, []int)deprecated
- func (*StateVariableSpecification_UserDefined) ProtoMessage()
- func (x *StateVariableSpecification_UserDefined) ProtoReflect() protoreflect.Message
- func (x *StateVariableSpecification_UserDefined) Reset()
- func (x *StateVariableSpecification_UserDefined) String() string
- type StateVariableSpecification_UserDefined_
- type StochasticGradientDescentParameters
- func (*StochasticGradientDescentParameters) Descriptor() ([]byte, []int)deprecated
- func (*StochasticGradientDescentParameters) ProtoMessage()
- func (x *StochasticGradientDescentParameters) ProtoReflect() protoreflect.Message
- func (x *StochasticGradientDescentParameters) Reset()
- func (x *StochasticGradientDescentParameters) String() string
- type TPUCompileMetadataProto
- func (*TPUCompileMetadataProto) Descriptor() ([]byte, []int)deprecated
- func (x *TPUCompileMetadataProto) GetArgs() []*TPUCompileMetadataProto_Arg
- func (x *TPUCompileMetadataProto) GetAutoSpmdMeshIds() []int64
- func (x *TPUCompileMetadataProto) GetAutoSpmdMeshShape() []int64
- func (x *TPUCompileMetadataProto) GetCompileOptions() *TPUCompileOptions
- func (x *TPUCompileMetadataProto) GetDeviceAssignment() *data.DeviceAssignmentProto
- func (x *TPUCompileMetadataProto) GetEnableAutomaticModelParallelism() bool
- func (x *TPUCompileMetadataProto) GetFunctionLibraryFingerprint() uint64
- func (x *TPUCompileMetadataProto) GetGuaranteedConstFingerprint() string
- func (x *TPUCompileMetadataProto) GetMlirFingerprint() uint64
- func (x *TPUCompileMetadataProto) GetNumCoresPerReplica() int32
- func (x *TPUCompileMetadataProto) GetNumReplicas() int32
- func (x *TPUCompileMetadataProto) GetPaddingMaps() []*PaddingMap
- func (x *TPUCompileMetadataProto) GetRetvals() []*TPUCompileMetadataProto_Retval
- func (x *TPUCompileMetadataProto) GetSessionHandle() string
- func (x *TPUCompileMetadataProto) GetStepMarkerLocation() xla.DebugOptions_StepMarkerLocation
- func (x *TPUCompileMetadataProto) GetUseAutoSpmdForXlaPartitioning() bool
- func (x *TPUCompileMetadataProto) GetUseSpmdForXlaPartitioning() bool
- func (x *TPUCompileMetadataProto) GetXlaFusionAutotunerThresh() int64
- func (*TPUCompileMetadataProto) ProtoMessage()
- func (x *TPUCompileMetadataProto) ProtoReflect() protoreflect.Message
- func (x *TPUCompileMetadataProto) Reset()
- func (x *TPUCompileMetadataProto) String() string
- type TPUCompileMetadataProto_Arg
- func (*TPUCompileMetadataProto_Arg) Descriptor() ([]byte, []int)deprecated
- func (x *TPUCompileMetadataProto_Arg) GetDtype() framework.DataType
- func (x *TPUCompileMetadataProto_Arg) GetEnableXlaSharding() TPUCompileMetadataProto_Arg_EnableXlaSharding
- func (x *TPUCompileMetadataProto_Arg) GetFastMem() bool
- func (x *TPUCompileMetadataProto_Arg) GetIsSameDataAcrossReplicas() bool
- func (x *TPUCompileMetadataProto_Arg) GetKind() TPUCompileMetadataProto_Arg_Kind
- func (x *TPUCompileMetadataProto_Arg) GetName() string
- func (x *TPUCompileMetadataProto_Arg) GetRequiresXlaBroadcast() bool
- func (x *TPUCompileMetadataProto_Arg) GetRetvalIndexForSharding() int32
- func (x *TPUCompileMetadataProto_Arg) GetShape() *framework.TensorShapeProto
- func (x *TPUCompileMetadataProto_Arg) GetSharding() *data.OpSharding
- func (x *TPUCompileMetadataProto_Arg) GetUnrestrictedLayout() bool
- func (*TPUCompileMetadataProto_Arg) ProtoMessage()
- func (x *TPUCompileMetadataProto_Arg) ProtoReflect() protoreflect.Message
- func (x *TPUCompileMetadataProto_Arg) Reset()
- func (x *TPUCompileMetadataProto_Arg) String() string
- type TPUCompileMetadataProto_Arg_EnableXlaSharding
- func (TPUCompileMetadataProto_Arg_EnableXlaSharding) Descriptor() protoreflect.EnumDescriptor
- func (x TPUCompileMetadataProto_Arg_EnableXlaSharding) Enum() *TPUCompileMetadataProto_Arg_EnableXlaSharding
- func (TPUCompileMetadataProto_Arg_EnableXlaSharding) EnumDescriptor() ([]byte, []int)deprecated
- func (x TPUCompileMetadataProto_Arg_EnableXlaSharding) Number() protoreflect.EnumNumber
- func (x TPUCompileMetadataProto_Arg_EnableXlaSharding) String() string
- func (TPUCompileMetadataProto_Arg_EnableXlaSharding) Type() protoreflect.EnumType
- type TPUCompileMetadataProto_Arg_Kind
- func (TPUCompileMetadataProto_Arg_Kind) Descriptor() protoreflect.EnumDescriptor
- func (x TPUCompileMetadataProto_Arg_Kind) Enum() *TPUCompileMetadataProto_Arg_Kind
- func (TPUCompileMetadataProto_Arg_Kind) EnumDescriptor() ([]byte, []int)deprecated
- func (x TPUCompileMetadataProto_Arg_Kind) Number() protoreflect.EnumNumber
- func (x TPUCompileMetadataProto_Arg_Kind) String() string
- func (TPUCompileMetadataProto_Arg_Kind) Type() protoreflect.EnumType
- type TPUCompileMetadataProto_Retval
- func (*TPUCompileMetadataProto_Retval) Descriptor() ([]byte, []int)deprecated
- func (x *TPUCompileMetadataProto_Retval) GetSharding() *data.OpSharding
- func (*TPUCompileMetadataProto_Retval) ProtoMessage()
- func (x *TPUCompileMetadataProto_Retval) ProtoReflect() protoreflect.Message
- func (x *TPUCompileMetadataProto_Retval) Reset()
- func (x *TPUCompileMetadataProto_Retval) String() string
- type TPUCompileOptions
- func (*TPUCompileOptions) Descriptor() ([]byte, []int)deprecated
- func (x *TPUCompileOptions) GetMatrixUnitOperandPrecision() TPUCompileOptions_Precision
- func (*TPUCompileOptions) ProtoMessage()
- func (x *TPUCompileOptions) ProtoReflect() protoreflect.Message
- func (x *TPUCompileOptions) Reset()
- func (x *TPUCompileOptions) String() string
- type TPUCompileOptions_Precision
- func (TPUCompileOptions_Precision) Descriptor() protoreflect.EnumDescriptor
- func (x TPUCompileOptions_Precision) Enum() *TPUCompileOptions_Precision
- func (TPUCompileOptions_Precision) EnumDescriptor() ([]byte, []int)deprecated
- func (x TPUCompileOptions_Precision) Number() protoreflect.EnumNumber
- func (x TPUCompileOptions_Precision) String() string
- func (TPUCompileOptions_Precision) Type() protoreflect.EnumType
- type TPUEmbeddingConfiguration
- func (*TPUEmbeddingConfiguration) Descriptor() ([]byte, []int)deprecated
- func (x *TPUEmbeddingConfiguration) GetBatchSizePerTensorCore() int32
- func (x *TPUEmbeddingConfiguration) GetFeatureDescriptor() []*TPUEmbeddingConfiguration_FeatureDescriptor
- func (x *TPUEmbeddingConfiguration) GetMode() TPUEmbeddingConfiguration_Mode
- func (x *TPUEmbeddingConfiguration) GetNumHosts() int32
- func (x *TPUEmbeddingConfiguration) GetNumTensorCores() int32
- func (x *TPUEmbeddingConfiguration) GetPipelineExecutionWithTensorCore() bool
- func (x *TPUEmbeddingConfiguration) GetProfileDataDirectory() string
- func (x *TPUEmbeddingConfiguration) GetShardingStrategy() TPUEmbeddingConfiguration_ShardingStrategy
- func (x *TPUEmbeddingConfiguration) GetSpmdSharding() *TPUEmbeddingConfiguration_SpmdSharding
- func (x *TPUEmbeddingConfiguration) GetTableDescriptor() []*TPUEmbeddingConfiguration_TableDescriptor
- func (*TPUEmbeddingConfiguration) ProtoMessage()
- func (x *TPUEmbeddingConfiguration) ProtoReflect() protoreflect.Message
- func (x *TPUEmbeddingConfiguration) Reset()
- func (x *TPUEmbeddingConfiguration) String() string
- type TPUEmbeddingConfiguration_FeatureDescriptor
- func (*TPUEmbeddingConfiguration_FeatureDescriptor) Descriptor() ([]byte, []int)deprecated
- func (x *TPUEmbeddingConfiguration_FeatureDescriptor) GetInputShape() []int32
- func (x *TPUEmbeddingConfiguration_FeatureDescriptor) GetName() string
- func (x *TPUEmbeddingConfiguration_FeatureDescriptor) GetTableId() int32
- func (*TPUEmbeddingConfiguration_FeatureDescriptor) ProtoMessage()
- func (x *TPUEmbeddingConfiguration_FeatureDescriptor) ProtoReflect() protoreflect.Message
- func (x *TPUEmbeddingConfiguration_FeatureDescriptor) Reset()
- func (x *TPUEmbeddingConfiguration_FeatureDescriptor) String() string
- type TPUEmbeddingConfiguration_Mode
- func (TPUEmbeddingConfiguration_Mode) Descriptor() protoreflect.EnumDescriptor
- func (x TPUEmbeddingConfiguration_Mode) Enum() *TPUEmbeddingConfiguration_Mode
- func (TPUEmbeddingConfiguration_Mode) EnumDescriptor() ([]byte, []int)deprecated
- func (x TPUEmbeddingConfiguration_Mode) Number() protoreflect.EnumNumber
- func (x TPUEmbeddingConfiguration_Mode) String() string
- func (TPUEmbeddingConfiguration_Mode) Type() protoreflect.EnumType
- type TPUEmbeddingConfiguration_ShardingStrategy
- func (TPUEmbeddingConfiguration_ShardingStrategy) Descriptor() protoreflect.EnumDescriptor
- func (x TPUEmbeddingConfiguration_ShardingStrategy) Enum() *TPUEmbeddingConfiguration_ShardingStrategy
- func (TPUEmbeddingConfiguration_ShardingStrategy) EnumDescriptor() ([]byte, []int)deprecated
- func (x TPUEmbeddingConfiguration_ShardingStrategy) Number() protoreflect.EnumNumber
- func (x TPUEmbeddingConfiguration_ShardingStrategy) String() string
- func (TPUEmbeddingConfiguration_ShardingStrategy) Type() protoreflect.EnumType
- type TPUEmbeddingConfiguration_SpmdSharding
- func (*TPUEmbeddingConfiguration_SpmdSharding) Descriptor() ([]byte, []int)deprecated
- func (x *TPUEmbeddingConfiguration_SpmdSharding) GetEnabled() bool
- func (x *TPUEmbeddingConfiguration_SpmdSharding) GetNumCoresPerReplica() int32
- func (*TPUEmbeddingConfiguration_SpmdSharding) ProtoMessage()
- func (x *TPUEmbeddingConfiguration_SpmdSharding) ProtoReflect() protoreflect.Message
- func (x *TPUEmbeddingConfiguration_SpmdSharding) Reset()
- func (x *TPUEmbeddingConfiguration_SpmdSharding) String() string
- type TPUEmbeddingConfiguration_TableDescriptor
- func (*TPUEmbeddingConfiguration_TableDescriptor) Descriptor() ([]byte, []int)deprecated
- func (x *TPUEmbeddingConfiguration_TableDescriptor) GetDimension() int32
- func (x *TPUEmbeddingConfiguration_TableDescriptor) GetName() string
- func (x *TPUEmbeddingConfiguration_TableDescriptor) GetNumFeatures() int32
- func (x *TPUEmbeddingConfiguration_TableDescriptor) GetOptimizationParameters() *OptimizationParameters
- func (x *TPUEmbeddingConfiguration_TableDescriptor) GetVocabularySize() int64
- func (*TPUEmbeddingConfiguration_TableDescriptor) ProtoMessage()
- func (x *TPUEmbeddingConfiguration_TableDescriptor) ProtoReflect() protoreflect.Message
- func (x *TPUEmbeddingConfiguration_TableDescriptor) Reset()
- func (x *TPUEmbeddingConfiguration_TableDescriptor) String() string
- type TPUEmbeddingError
- type TPUHardwareFeature
- func (*TPUHardwareFeature) Descriptor() ([]byte, []int)deprecated
- func (x *TPUHardwareFeature) GetEmbeddingFeature() TPUHardwareFeature_EmbeddingFeature
- func (*TPUHardwareFeature) ProtoMessage()
- func (x *TPUHardwareFeature) ProtoReflect() protoreflect.Message
- func (x *TPUHardwareFeature) Reset()
- func (x *TPUHardwareFeature) String() string
- type TPUHardwareFeature_EmbeddingFeature
- func (TPUHardwareFeature_EmbeddingFeature) Descriptor() protoreflect.EnumDescriptor
- func (x TPUHardwareFeature_EmbeddingFeature) Enum() *TPUHardwareFeature_EmbeddingFeature
- func (TPUHardwareFeature_EmbeddingFeature) EnumDescriptor() ([]byte, []int)deprecated
- func (x TPUHardwareFeature_EmbeddingFeature) Number() protoreflect.EnumNumber
- func (x TPUHardwareFeature_EmbeddingFeature) String() string
- func (TPUHardwareFeature_EmbeddingFeature) Type() protoreflect.EnumType
- type TopologyProto
- func (*TopologyProto) Descriptor() ([]byte, []int)deprecated
- func (x *TopologyProto) GetDeviceCoordinates() []int32
- func (x *TopologyProto) GetMeshShape() []int32
- func (x *TopologyProto) GetNumTasks() int32
- func (x *TopologyProto) GetNumTpuDevicesPerTask() int32
- func (x *TopologyProto) GetTpuHardwareFeature() *TPUHardwareFeature
- func (*TopologyProto) ProtoMessage()
- func (x *TopologyProto) ProtoReflect() protoreflect.Message
- func (x *TopologyProto) Reset()
- func (x *TopologyProto) String() string
- type UserDefinedProgramParameters
- func (*UserDefinedProgramParameters) Descriptor() ([]byte, []int)deprecated
- func (x *UserDefinedProgramParameters) GetProgram() *service.HloModuleProto
- func (*UserDefinedProgramParameters) ProtoMessage()
- func (x *UserDefinedProgramParameters) ProtoReflect() protoreflect.Message
- func (x *UserDefinedProgramParameters) Reset()
- func (x *UserDefinedProgramParameters) String() string
Constants ¶
This section is empty.
Variables ¶
var (
    CompilationResultProto_ErrorCode_name = map[int32]string{
        0: "UNKNOWN",
        1: "OUT_OF_MEMORY",
    }
    CompilationResultProto_ErrorCode_value = map[string]int32{
        "UNKNOWN":       0,
        "OUT_OF_MEMORY": 1,
    }
)
Enum value maps for CompilationResultProto_ErrorCode.
var (
    TPUCompileMetadataProto_Arg_Kind_name = map[int32]string{
        0: "INVALID",
        1: "PARAMETER",
        2: "VARIABLE",
        3: "GUARANTEED_CONSTANT",
    }
    TPUCompileMetadataProto_Arg_Kind_value = map[string]int32{
        "INVALID":             0,
        "PARAMETER":           1,
        "VARIABLE":            2,
        "GUARANTEED_CONSTANT": 3,
    }
)
Enum value maps for TPUCompileMetadataProto_Arg_Kind.
var (
    TPUCompileMetadataProto_Arg_EnableXlaSharding_name = map[int32]string{
        0: "DISALLOWED",
        1: "TENTATIVE",
        2: "ALLOWED",
    }
    TPUCompileMetadataProto_Arg_EnableXlaSharding_value = map[string]int32{
        "DISALLOWED": 0,
        "TENTATIVE":  1,
        "ALLOWED":    2,
    }
)
Enum value maps for TPUCompileMetadataProto_Arg_EnableXlaSharding.
var (
    TPUCompileOptions_Precision_name = map[int32]string{
        0: "DEFAULT",
        1: "BFLOAT16",
        2: "FLOAT32",
        3: "TENSOR_FLOAT32",
    }
    TPUCompileOptions_Precision_value = map[string]int32{
        "DEFAULT":        0,
        "BFLOAT16":       1,
        "FLOAT32":        2,
        "TENSOR_FLOAT32": 3,
    }
)
Enum value maps for TPUCompileOptions_Precision.
var (
    GradientAccumulationStatus_Status_name = map[int32]string{
        0: "UNSPECIFIED",
        1: "ENABLED",
        2: "DISABLED",
    }
    GradientAccumulationStatus_Status_value = map[string]int32{
        "UNSPECIFIED": 0,
        "ENABLED":     1,
        "DISABLED":    2,
    }
)
Enum value maps for GradientAccumulationStatus_Status.
var (
    LowDimensionalPackingStatus_Status_name = map[int32]string{
        0: "UNSPECIFIED",
        1: "ENABLED",
        2: "DISABLED",
    }
    LowDimensionalPackingStatus_Status_value = map[string]int32{
        "UNSPECIFIED": 0,
        "ENABLED":     1,
        "DISABLED":    2,
    }
)
Enum value maps for LowDimensionalPackingStatus_Status.
var (
    HotIdReplicationConfiguration_Status_name = map[int32]string{
        0: "UNSPECIFIED",
        1: "ENABLED",
        2: "DISABLED",
    }
    HotIdReplicationConfiguration_Status_value = map[string]int32{
        "UNSPECIFIED": 0,
        "ENABLED":     1,
        "DISABLED":    2,
    }
)
Enum value maps for HotIdReplicationConfiguration_Status.
var (
    TPUHardwareFeature_EmbeddingFeature_name = map[int32]string{
        0: "UNSUPPORTED",
        1: "V1",
        2: "V2",
    }
    TPUHardwareFeature_EmbeddingFeature_value = map[string]int32{
        "UNSUPPORTED": 0,
        "V1":          1,
        "V2":          2,
    }
)
Enum value maps for TPUHardwareFeature_EmbeddingFeature.
var (
    TPUEmbeddingConfiguration_Mode_name = map[int32]string{
        0: "UNSPECIFIED",
        1: "INFERENCE",
        2: "TRAINING",
        3: "BACKWARD_PASS_ONLY",
    }
    TPUEmbeddingConfiguration_Mode_value = map[string]int32{
        "UNSPECIFIED":        0,
        "INFERENCE":          1,
        "TRAINING":           2,
        "BACKWARD_PASS_ONLY": 3,
    }
)
Enum value maps for TPUEmbeddingConfiguration_Mode.
var (
    TPUEmbeddingConfiguration_ShardingStrategy_name = map[int32]string{
        0: "DIV_DEFAULT",
        1: "MOD",
    }
    TPUEmbeddingConfiguration_ShardingStrategy_value = map[string]int32{
        "DIV_DEFAULT": 0,
        "MOD":         1,
    }
)
Enum value maps for TPUEmbeddingConfiguration_ShardingStrategy.
var File_tensorflow_core_protobuf_tpu_compilation_result_proto protoreflect.FileDescriptor
var File_tensorflow_core_protobuf_tpu_compile_metadata_proto protoreflect.FileDescriptor
var File_tensorflow_core_protobuf_tpu_dynamic_padding_proto protoreflect.FileDescriptor
var File_tensorflow_core_protobuf_tpu_optimization_parameters_proto protoreflect.FileDescriptor
var File_tensorflow_core_protobuf_tpu_topology_proto protoreflect.FileDescriptor
var File_tensorflow_core_protobuf_tpu_tpu_embedding_configuration_proto protoreflect.FileDescriptor
Functions ¶
This section is empty.
Types ¶
type AdadeltaParameters ¶
type AdadeltaParameters struct {
    Rho     float32 `protobuf:"fixed32,1,opt,name=rho,proto3" json:"rho,omitempty"`
    Epsilon float32 `protobuf:"fixed32,2,opt,name=epsilon,proto3" json:"epsilon,omitempty"`
    // contains filtered or unexported fields
}
https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adadelta https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L933
func (*AdadeltaParameters) Descriptor
deprecated
func (*AdadeltaParameters) Descriptor() ([]byte, []int)
Deprecated: Use AdadeltaParameters.ProtoReflect.Descriptor instead.
func (*AdadeltaParameters) GetEpsilon ¶
func (x *AdadeltaParameters) GetEpsilon() float32
func (*AdadeltaParameters) GetRho ¶
func (x *AdadeltaParameters) GetRho() float32
func (*AdadeltaParameters) ProtoMessage ¶
func (*AdadeltaParameters) ProtoMessage()
func (*AdadeltaParameters) ProtoReflect ¶
func (x *AdadeltaParameters) ProtoReflect() protoreflect.Message
func (*AdadeltaParameters) Reset ¶
func (x *AdadeltaParameters) Reset()
func (*AdadeltaParameters) String ¶
func (x *AdadeltaParameters) String() string
type AdagradMomentumParameters ¶
type AdagradMomentumParameters struct {
    // Moving average parameter for the momentum accumulator.
    Momentum float32 `protobuf:"fixed32,1,opt,name=momentum,proto3" json:"momentum,omitempty"`
    // Whether to use the Nesterov variant of momentum.
    UseNesterov bool `protobuf:"varint,2,opt,name=use_nesterov,json=useNesterov,proto3" json:"use_nesterov,omitempty"`
    // Exponent for the gradient^2 accumulator.
    Exponent float32 `protobuf:"fixed32,3,opt,name=exponent,proto3" json:"exponent,omitempty"`
    // Moving average parameter for the gradient^2 accumulator.
    Beta2 float32 `protobuf:"fixed32,4,opt,name=beta2,proto3" json:"beta2,omitempty"`
    // Offset added to the Adagrad accumulator.
    Epsilon float32 `protobuf:"fixed32,5,opt,name=epsilon,proto3" json:"epsilon,omitempty"`
    // contains filtered or unexported fields
}
This optimizer combines the Adagrad and Momentum update rules.

accum(new) = beta2 == 1.0 ?
    accum(old) + grad^2 :
    beta2 * accum(old) + (1 - beta2) * grad^2
accum_with_exponent = (accum(new) + epsilon)^(-1.0 / exponent)
mom_accum(new) = momentum * mom_accum(old) + accum_with_exponent
update = use_nesterov ?
    momentum * mom_accum(new) + accum_with_exponent :
    mom_accum(new)
var(new) = var(old) - lr * grad * update

Algorithm described in https://arxiv.org/abs/2002.11803.
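For illustration, a minimal Go sketch of the scalar form of this update rule. The helper type and plain float64 arithmetic are illustrative only; the real update runs inside the TPU embedding engine, with the learning rate supplied separately through LearningRate.

package main

import (
    "fmt"
    "math"
)

// Scalar form of the AdagradMomentum update above; field names mirror
// AdagradMomentumParameters.
type adagradMomentum struct {
    Momentum, Exponent, Beta2, Epsilon float64
    UseNesterov                        bool
}

func (p adagradMomentum) step(v, accum, momAccum, grad, lr float64) (newV, newAccum, newMomAccum float64) {
    if p.Beta2 == 1.0 {
        newAccum = accum + grad*grad // accum(new) = accum(old) + grad^2
    } else {
        newAccum = p.Beta2*accum + (1-p.Beta2)*grad*grad
    }
    accumWithExponent := math.Pow(newAccum+p.Epsilon, -1.0/p.Exponent)
    newMomAccum = p.Momentum*momAccum + accumWithExponent
    update := newMomAccum
    if p.UseNesterov {
        update = p.Momentum*newMomAccum + accumWithExponent
    }
    newV = v - lr*grad*update
    return newV, newAccum, newMomAccum
}

func main() {
    p := adagradMomentum{Momentum: 0.9, Exponent: 2, Beta2: 1, Epsilon: 1e-10}
    v, accum, momAccum := p.step(0.5, 0.0, 0.0, 0.1, 0.01)
    fmt.Println(v, accum, momAccum)
}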
func (*AdagradMomentumParameters) Descriptor
deprecated
func (*AdagradMomentumParameters) Descriptor() ([]byte, []int)
Deprecated: Use AdagradMomentumParameters.ProtoReflect.Descriptor instead.
func (*AdagradMomentumParameters) GetBeta2 ¶
func (x *AdagradMomentumParameters) GetBeta2() float32
func (*AdagradMomentumParameters) GetEpsilon ¶
func (x *AdagradMomentumParameters) GetEpsilon() float32
func (*AdagradMomentumParameters) GetExponent ¶
func (x *AdagradMomentumParameters) GetExponent() float32
func (*AdagradMomentumParameters) GetMomentum ¶
func (x *AdagradMomentumParameters) GetMomentum() float32
func (*AdagradMomentumParameters) GetUseNesterov ¶
func (x *AdagradMomentumParameters) GetUseNesterov() bool
func (*AdagradMomentumParameters) ProtoMessage ¶
func (*AdagradMomentumParameters) ProtoMessage()
func (*AdagradMomentumParameters) ProtoReflect ¶
func (x *AdagradMomentumParameters) ProtoReflect() protoreflect.Message
func (*AdagradMomentumParameters) Reset ¶
func (x *AdagradMomentumParameters) Reset()
func (*AdagradMomentumParameters) String ¶
func (x *AdagradMomentumParameters) String() string
type AdagradParameters ¶
type AdagradParameters struct {
// contains filtered or unexported fields
}
https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adagrad https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L1634
func (*AdagradParameters) Descriptor
deprecated
func (*AdagradParameters) Descriptor() ([]byte, []int)
Deprecated: Use AdagradParameters.ProtoReflect.Descriptor instead.
func (*AdagradParameters) ProtoMessage ¶
func (*AdagradParameters) ProtoMessage()
func (*AdagradParameters) ProtoReflect ¶
func (x *AdagradParameters) ProtoReflect() protoreflect.Message
func (*AdagradParameters) Reset ¶
func (x *AdagradParameters) Reset()
func (*AdagradParameters) String ¶
func (x *AdagradParameters) String() string
type AdamParameters ¶
type AdamParameters struct {
    Beta1            float32 `protobuf:"fixed32,3,opt,name=beta1,proto3" json:"beta1,omitempty"`
    Beta2            float32 `protobuf:"fixed32,4,opt,name=beta2,proto3" json:"beta2,omitempty"`
    Epsilon          float32 `protobuf:"fixed32,5,opt,name=epsilon,proto3" json:"epsilon,omitempty"`
    UseNonLazyAdam   bool    `protobuf:"varint,8,opt,name=use_non_lazy_adam,json=useNonLazyAdam,proto3" json:"use_non_lazy_adam,omitempty"`
    UseSumInsideSqrt bool    `protobuf:"varint,10,opt,name=use_sum_inside_sqrt,json=useSumInsideSqrt,proto3" json:"use_sum_inside_sqrt,omitempty"`
    // contains filtered or unexported fields
}
The Adam optimizer does not implement hyper-parameter update due to hardware limitations; use the dynamic learning rate feature instead, setting the learning rate to:

user_learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)

Here, t is the current timestep.
https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L32
Note that the code by default implements the lazy version of Adam (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/LazyAdamOptimizer) unless the use_non_lazy_adam parameter is set, in which case it implements the normal version of Adam that updates all parameters in the embedding table, even for entries that are not used in the current minibatch (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer).

If use_non_lazy_adam is enabled, gradient accumulation is also required to be enabled in order to get correct results; a warning will be printed otherwise (which may change to an error in the future).

If use_sum_inside_sqrt is set, the Adam variable update formula will be changed from m / (sqrt(v) + epsilon) to m / sqrt(v + epsilon**2); this option improves the performance of TPU training and is not expected to harm model quality.
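A small sketch of the host-side bias correction described above (the helper name is illustrative). The resulting per-step value would be fed to the TPU embedding program through the dynamic learning rate input of SendTPUEmbeddingGradients.

package main

import (
    "fmt"
    "math"
)

// adamLearningRate computes the bias-corrected learning rate for step t:
// user_learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t).
func adamLearningRate(userLR, beta1, beta2 float64, t int) float64 {
    ft := float64(t)
    return userLR * math.Sqrt(1-math.Pow(beta2, ft)) / (1 - math.Pow(beta1, ft))
}

func main() {
    for t := 1; t <= 3; t++ {
        fmt.Printf("step %d: lr = %.6f\n", t, adamLearningRate(0.001, 0.9, 0.999, t))
    }
}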
func (*AdamParameters) Descriptor
deprecated
func (*AdamParameters) Descriptor() ([]byte, []int)
Deprecated: Use AdamParameters.ProtoReflect.Descriptor instead.
func (*AdamParameters) GetBeta1 ¶
func (x *AdamParameters) GetBeta1() float32
func (*AdamParameters) GetBeta2 ¶
func (x *AdamParameters) GetBeta2() float32
func (*AdamParameters) GetEpsilon ¶
func (x *AdamParameters) GetEpsilon() float32
func (*AdamParameters) GetUseNonLazyAdam ¶
func (x *AdamParameters) GetUseNonLazyAdam() bool
func (*AdamParameters) GetUseSumInsideSqrt ¶
func (x *AdamParameters) GetUseSumInsideSqrt() bool
func (*AdamParameters) ProtoMessage ¶
func (*AdamParameters) ProtoMessage()
func (*AdamParameters) ProtoReflect ¶
func (x *AdamParameters) ProtoReflect() protoreflect.Message
func (*AdamParameters) Reset ¶
func (x *AdamParameters) Reset()
func (*AdamParameters) String ¶
func (x *AdamParameters) String() string
type AssignParameters ¶
type AssignParameters struct {
// contains filtered or unexported fields
}
Optimizer that just sets the variable to the value of the gradient. To be correct, this requires either gradient accumulation (to sum the values of a computed expression across the samples) or to deduplicate IDs within a single host (to assign the value from an arbitrary sample).
func (*AssignParameters) Descriptor
deprecated
func (*AssignParameters) Descriptor() ([]byte, []int)
Deprecated: Use AssignParameters.ProtoReflect.Descriptor instead.
func (*AssignParameters) ProtoMessage ¶
func (*AssignParameters) ProtoMessage()
func (*AssignParameters) ProtoReflect ¶
func (x *AssignParameters) ProtoReflect() protoreflect.Message
func (*AssignParameters) Reset ¶
func (x *AssignParameters) Reset()
func (*AssignParameters) String ¶
func (x *AssignParameters) String() string
type BoundedAdagradParameters ¶
type BoundedAdagradParameters struct {
    // Whether to use the updated or the old value of the accumulator when
    // computing the effective learning rate. When update_accumulator_first is set
    // to True, the updated value of the accumulator is used.
    UpdateAccumulatorFirst bool `` /* 130-byte string literal not displayed */
    // The max_var_update value to use. Set value to 0 (default) to disable using
    // max_var_update to clip the gradient.
    MaxVarUpdate float32 `protobuf:"fixed32,2,opt,name=max_var_update,json=maxVarUpdate,proto3" json:"max_var_update,omitempty"`
    // The maximum value of the accumulator. Set max_accumulator to 0 (default)
    // to disable using max_accumulator to clip the accumulator.
    MaxAccumulator float32 `protobuf:"fixed32,3,opt,name=max_accumulator,json=maxAccumulator,proto3" json:"max_accumulator,omitempty"`
    // contains filtered or unexported fields
}
Algorithm in http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.
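A minimal construction sketch based on the field comments above, assuming this package is imported under an alias such as tpu (the import path below is a placeholder for the package's actual path).

package main

import (
    // Placeholder import path; replace with this package's actual path.
    tpu "example.com/tensorflow/core/protobuf/tpu"
)

func main() {
    // Clip the per-step variable update to 1.0 and the accumulator to 100.0;
    // leaving either bound at 0 (the default) disables that clipping.
    p := &tpu.BoundedAdagradParameters{
        UpdateAccumulatorFirst: true,
        MaxVarUpdate:           1.0,
        MaxAccumulator:         100.0,
    }
    _ = p
}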
func (*BoundedAdagradParameters) Descriptor
deprecated
func (*BoundedAdagradParameters) Descriptor() ([]byte, []int)
Deprecated: Use BoundedAdagradParameters.ProtoReflect.Descriptor instead.
func (*BoundedAdagradParameters) GetMaxAccumulator ¶
func (x *BoundedAdagradParameters) GetMaxAccumulator() float32
func (*BoundedAdagradParameters) GetMaxVarUpdate ¶
func (x *BoundedAdagradParameters) GetMaxVarUpdate() float32
func (*BoundedAdagradParameters) GetUpdateAccumulatorFirst ¶
func (x *BoundedAdagradParameters) GetUpdateAccumulatorFirst() bool
func (*BoundedAdagradParameters) ProtoMessage ¶
func (*BoundedAdagradParameters) ProtoMessage()
func (*BoundedAdagradParameters) ProtoReflect ¶
func (x *BoundedAdagradParameters) ProtoReflect() protoreflect.Message
func (*BoundedAdagradParameters) Reset ¶
func (x *BoundedAdagradParameters) Reset()
func (*BoundedAdagradParameters) String ¶
func (x *BoundedAdagradParameters) String() string
type CenteredRmsPropParameters ¶
type CenteredRmsPropParameters struct {
    Rho      float32 `protobuf:"fixed32,1,opt,name=rho,proto3" json:"rho,omitempty"`
    Momentum float32 `protobuf:"fixed32,2,opt,name=momentum,proto3" json:"momentum,omitempty"`
    Epsilon  float32 `protobuf:"fixed32,3,opt,name=epsilon,proto3" json:"epsilon,omitempty"`
    // contains filtered or unexported fields
}
https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L4358
func (*CenteredRmsPropParameters) Descriptor
deprecated
func (*CenteredRmsPropParameters) Descriptor() ([]byte, []int)
Deprecated: Use CenteredRmsPropParameters.ProtoReflect.Descriptor instead.
func (*CenteredRmsPropParameters) GetEpsilon ¶
func (x *CenteredRmsPropParameters) GetEpsilon() float32
func (*CenteredRmsPropParameters) GetMomentum ¶
func (x *CenteredRmsPropParameters) GetMomentum() float32
func (*CenteredRmsPropParameters) GetRho ¶
func (x *CenteredRmsPropParameters) GetRho() float32
func (*CenteredRmsPropParameters) ProtoMessage ¶
func (*CenteredRmsPropParameters) ProtoMessage()
func (*CenteredRmsPropParameters) ProtoReflect ¶
func (x *CenteredRmsPropParameters) ProtoReflect() protoreflect.Message
func (*CenteredRmsPropParameters) Reset ¶
func (x *CenteredRmsPropParameters) Reset()
func (*CenteredRmsPropParameters) String ¶
func (x *CenteredRmsPropParameters) String() string
type ClippingLimits ¶
type ClippingLimits struct {
    Lower *wrapperspb.FloatValue `protobuf:"bytes,1,opt,name=lower,proto3" json:"lower,omitempty"` // -inf if not set
    Upper *wrapperspb.FloatValue `protobuf:"bytes,2,opt,name=upper,proto3" json:"upper,omitempty"` // +inf if not set
    // contains filtered or unexported fields
}
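A minimal construction sketch using the wrapperspb helpers; leaving a bound nil means -inf or +inf as noted above. The tpu import path below is a placeholder for this package's actual path.

package main

import (
    "fmt"

    "google.golang.org/protobuf/types/known/wrapperspb"
    // Placeholder import path; replace with this package's actual path.
    tpu "example.com/tensorflow/core/protobuf/tpu"
)

func main() {
    // Clip values to [-1, +1].
    limits := &tpu.ClippingLimits{
        Lower: wrapperspb.Float(-1.0),
        Upper: wrapperspb.Float(1.0),
    }
    fmt.Println(limits.GetLower().GetValue(), limits.GetUpper().GetValue())
}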
func (*ClippingLimits) Descriptor
deprecated
func (*ClippingLimits) Descriptor() ([]byte, []int)
Deprecated: Use ClippingLimits.ProtoReflect.Descriptor instead.
func (*ClippingLimits) GetLower ¶
func (x *ClippingLimits) GetLower() *wrapperspb.FloatValue
func (*ClippingLimits) GetUpper ¶
func (x *ClippingLimits) GetUpper() *wrapperspb.FloatValue
func (*ClippingLimits) ProtoMessage ¶
func (*ClippingLimits) ProtoMessage()
func (*ClippingLimits) ProtoReflect ¶
func (x *ClippingLimits) ProtoReflect() protoreflect.Message
func (*ClippingLimits) Reset ¶
func (x *ClippingLimits) Reset()
func (*ClippingLimits) String ¶
func (x *ClippingLimits) String() string
type CompilationResultProto ¶
type CompilationResultProto struct {
    // The error message, if any, returned during compilation.
    StatusCode         protobuf.Code `protobuf:"varint,1,opt,name=status_code,json=statusCode,proto3,enum=tensorflow.error.Code" json:"status_code,omitempty"`
    StatusErrorMessage string        `protobuf:"bytes,2,opt,name=status_error_message,json=statusErrorMessage,proto3" json:"status_error_message,omitempty"`
    // HLO proto.
    HloProtos []*service.HloProto              `protobuf:"bytes,3,rep,name=hlo_protos,json=hloProtos,proto3" json:"hlo_protos,omitempty"`
    ErrorCode CompilationResultProto_ErrorCode `` /* 142-byte string literal not displayed */
    // contains filtered or unexported fields
}
Describes the result of a TPU compilation. This is also used as TPU compilation result status payload. URI: "type.googleapis.com/tensorflow.tpu.CompilationResultProto"
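A hedged sketch of decoding a serialized CompilationResultProto, for example bytes recovered from a compilation status payload with the URI above. The tpu import path is a placeholder for this package's actual path.

package main

import (
    "fmt"
    "log"

    "google.golang.org/protobuf/proto"
    // Placeholder import path; replace with this package's actual path.
    tpu "example.com/tensorflow/core/protobuf/tpu"
)

// inspectCompilationResult decodes a serialized CompilationResultProto and
// reports out-of-memory compilation failures.
func inspectCompilationResult(raw []byte) error {
    var res tpu.CompilationResultProto
    if err := proto.Unmarshal(raw, &res); err != nil {
        return err
    }
    if res.GetErrorCode() == tpu.CompilationResultProto_OUT_OF_MEMORY {
        return fmt.Errorf("TPU compilation ran out of memory: %s", res.GetStatusErrorMessage())
    }
    fmt.Printf("status code %v, %d HLO protos\n", res.GetStatusCode(), len(res.GetHloProtos()))
    return nil
}

func main() {
    if err := inspectCompilationResult(nil); err != nil {
        log.Fatal(err)
    }
}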
func (*CompilationResultProto) Descriptor
deprecated
func (*CompilationResultProto) Descriptor() ([]byte, []int)
Deprecated: Use CompilationResultProto.ProtoReflect.Descriptor instead.
func (*CompilationResultProto) GetErrorCode ¶
func (x *CompilationResultProto) GetErrorCode() CompilationResultProto_ErrorCode
func (*CompilationResultProto) GetHloProtos ¶
func (x *CompilationResultProto) GetHloProtos() []*service.HloProto
func (*CompilationResultProto) GetStatusCode ¶
func (x *CompilationResultProto) GetStatusCode() protobuf.Code
func (*CompilationResultProto) GetStatusErrorMessage ¶
func (x *CompilationResultProto) GetStatusErrorMessage() string
func (*CompilationResultProto) ProtoMessage ¶
func (*CompilationResultProto) ProtoMessage()
func (*CompilationResultProto) ProtoReflect ¶
func (x *CompilationResultProto) ProtoReflect() protoreflect.Message
func (*CompilationResultProto) Reset ¶
func (x *CompilationResultProto) Reset()
func (*CompilationResultProto) String ¶
func (x *CompilationResultProto) String() string
type CompilationResultProto_ErrorCode ¶
type CompilationResultProto_ErrorCode int32
const (
    CompilationResultProto_UNKNOWN       CompilationResultProto_ErrorCode = 0
    CompilationResultProto_OUT_OF_MEMORY CompilationResultProto_ErrorCode = 1
)
func (CompilationResultProto_ErrorCode) Descriptor ¶
func (CompilationResultProto_ErrorCode) Descriptor() protoreflect.EnumDescriptor
func (CompilationResultProto_ErrorCode) Enum ¶
func (x CompilationResultProto_ErrorCode) Enum() *CompilationResultProto_ErrorCode
func (CompilationResultProto_ErrorCode) EnumDescriptor
deprecated
func (CompilationResultProto_ErrorCode) EnumDescriptor() ([]byte, []int)
Deprecated: Use CompilationResultProto_ErrorCode.Descriptor instead.
func (CompilationResultProto_ErrorCode) Number ¶
func (x CompilationResultProto_ErrorCode) Number() protoreflect.EnumNumber
func (CompilationResultProto_ErrorCode) String ¶
func (x CompilationResultProto_ErrorCode) String() string
func (CompilationResultProto_ErrorCode) Type ¶
func (CompilationResultProto_ErrorCode) Type() protoreflect.EnumType
type DynamicLearningRate ¶
type DynamicLearningRate struct {
    // For tables where learning rates are dynamically computed and communicated
    // to the TPU embedding program, a tag must be specified for the learning
    // rate.
    //
    // The tag must be a non-negative integer. The total number of unique tags
    // must be less than or equal to the number of tables in the TPU embedding
    // configuration (a table does not specify any tag if it uses a constant
    // learning rate, and specifies exactly one tag if it uses dynamic learning
    // rates).
    //
    // All tags in the range [0, number_of_unique_tags) must be present in the TPU
    // embedding configuration, i.e. a tag cannot be skipped if a different tag
    // numerically greater than it is used in the configuration.
    //
    // If multiple tables specify the same tag, they *MUST* have
    // the same dynamic learning rate, for example, their dynamic learning rate
    // could be computed by the same TensorFlow sub-graph. The partitioning of the
    // embedding layer would be more optimal if the number_of_unique_tags is as
    // *LOW* as possible, i.e., if many tables share the same tag.
    //
    // The learning_rate input of the SendTPUEmbeddingGradients op is used to
    // communicate dynamic learning rates to the TPU embedding program.
    // The learning_rate input is a list of scalars where the size of the list is
    // equal to the number of unique tags. The learning rate associated with a
    // particular tag is specified by populating its corresponding index in the
    // list of learning_rate scalars.
    Tag int32 `protobuf:"varint,1,opt,name=tag,proto3" json:"tag,omitempty"`
    // contains filtered or unexported fields
}
Dynamic learning rate specification in the TPUEmbeddingConfiguration. The actual learning rates are provided as a scalar input list to the SendTPUEmbeddingGradients Op indexed by their tag specified through the following proto.
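A minimal sketch of two tables sharing dynamic learning rate tag 0, assuming the usual protoc-gen-go oneof wrapper names (LearningRate_Dynamic) and a placeholder import path for this package.

package main

import (
    // Placeholder import path; replace with this package's actual path.
    tpu "example.com/tensorflow/core/protobuf/tpu"
)

// dynamicLR builds a LearningRate that is resolved at runtime via the given tag.
func dynamicLR(tag int32) *tpu.LearningRate {
    return &tpu.LearningRate{
        LearningRate: &tpu.LearningRate_Dynamic{
            Dynamic: &tpu.DynamicLearningRate{Tag: tag},
        },
    }
}

func main() {
    // Two tables share tag 0, so the scalar at index 0 of the learning_rate
    // input to SendTPUEmbeddingGradients applies to both.
    tableA := &tpu.OptimizationParameters{LearningRate: dynamicLR(0)}
    tableB := &tpu.OptimizationParameters{LearningRate: dynamicLR(0)}
    _, _ = tableA, tableB
}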
func (*DynamicLearningRate) Descriptor
deprecated
func (*DynamicLearningRate) Descriptor() ([]byte, []int)
Deprecated: Use DynamicLearningRate.ProtoReflect.Descriptor instead.
func (*DynamicLearningRate) GetTag ¶
func (x *DynamicLearningRate) GetTag() int32
func (*DynamicLearningRate) ProtoMessage ¶
func (*DynamicLearningRate) ProtoMessage()
func (*DynamicLearningRate) ProtoReflect ¶
func (x *DynamicLearningRate) ProtoReflect() protoreflect.Message
func (*DynamicLearningRate) Reset ¶
func (x *DynamicLearningRate) Reset()
func (*DynamicLearningRate) String ¶
func (x *DynamicLearningRate) String() string
type FrequencyEstimatorParameters ¶
type FrequencyEstimatorParameters struct {
    // Learning rate between (0, 1) that is used to update the array D.
    Tau float32 `protobuf:"fixed32,1,opt,name=tau,proto3" json:"tau,omitempty"`
    // Maximum value of delta: difference between the current global step and the
    // last global step at which the row was sampled.
    MaxDelta float32 `protobuf:"fixed32,2,opt,name=max_delta,json=maxDelta,proto3" json:"max_delta,omitempty"`
    // Threshold used to determine whether the current update is an outlier.
    OutlierThreshold float32 `protobuf:"fixed32,3,opt,name=outlier_threshold,json=outlierThreshold,proto3" json:"outlier_threshold,omitempty"`
    // The weight exponent used to transform the estimated delta into weights.
    // The transformation function is: (delta / max_delta) ^ (weight_exponent)
    WeightExponent float32 `protobuf:"fixed32,4,opt,name=weight_exponent,json=weightExponent,proto3" json:"weight_exponent,omitempty"`
    // contains filtered or unexported fields
}
Estimator for the frequency of updates to a lookup table. It maintains an array (tf.Variable) D, where each element records the average number of global steps between two consecutive batches that hit the corresponding bucket. Once an item with bucket id i is sampled, D[i] is updated by:
D[i] <- D[i] * (1 - tau) + delta[i] * tau,
where tau is a learning rate between 0 and 1 (exclusive), and
delta[i] = current global step - last step i is sampled.
The estimated frequency (sampling rate in a batch) is thus 1 / D[i].
Elements in D are initialized with a large value max_delta. delta[i] will also be capped by this value.
The exact sequence of operations used in the optimizer is shown below. last_hit_step[i] is a tf.Variable that holds the last global step at which i was sampled.
delta = global_step - last_hit_step[i]
clipped_delta = min(delta, params.max_delta)
is_outlier = (delta >= params.outlier_threshold * D[i])
D[i] <- is_outlier ? clipped_delta
                   : D[i] * (1 - params.tau) + clipped_delta * params.tau
last_hit_step[i] <- global_step
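For illustration, a minimal Go sketch of this update for a single row. The helper type and names are illustrative only; the real estimator runs as part of the TPU embedding program.

package main

import "fmt"

// Scalar sketch of the frequency estimator update described above; field
// names mirror FrequencyEstimatorParameters.
type frequencyEstimator struct {
    Tau, MaxDelta, OutlierThreshold float64
}

// update advances D[i] and last_hit_step[i] for a row sampled at globalStep.
func (p frequencyEstimator) update(d, lastHitStep, globalStep float64) (newD, newLastHit float64) {
    delta := globalStep - lastHitStep
    clipped := delta
    if clipped > p.MaxDelta {
        clipped = p.MaxDelta
    }
    if delta >= p.OutlierThreshold*d {
        newD = clipped // outlier: reset the estimate to the clipped delta
    } else {
        newD = d*(1-p.Tau) + clipped*p.Tau
    }
    return newD, globalStep
}

func main() {
    p := frequencyEstimator{Tau: 0.1, MaxDelta: 1000, OutlierThreshold: 10}
    d, last := p.update(100, 900, 1000) // row last seen 100 steps ago
    fmt.Println(d, last)                // estimated sampling rate is 1 / d
}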
func (*FrequencyEstimatorParameters) Descriptor
deprecated
func (*FrequencyEstimatorParameters) Descriptor() ([]byte, []int)
Deprecated: Use FrequencyEstimatorParameters.ProtoReflect.Descriptor instead.
func (*FrequencyEstimatorParameters) GetMaxDelta ¶
func (x *FrequencyEstimatorParameters) GetMaxDelta() float32
func (*FrequencyEstimatorParameters) GetOutlierThreshold ¶
func (x *FrequencyEstimatorParameters) GetOutlierThreshold() float32
func (*FrequencyEstimatorParameters) GetTau ¶
func (x *FrequencyEstimatorParameters) GetTau() float32
func (*FrequencyEstimatorParameters) GetWeightExponent ¶
func (x *FrequencyEstimatorParameters) GetWeightExponent() float32
func (*FrequencyEstimatorParameters) ProtoMessage ¶
func (*FrequencyEstimatorParameters) ProtoMessage()
func (*FrequencyEstimatorParameters) ProtoReflect ¶
func (x *FrequencyEstimatorParameters) ProtoReflect() protoreflect.Message
func (*FrequencyEstimatorParameters) Reset ¶
func (x *FrequencyEstimatorParameters) Reset()
func (*FrequencyEstimatorParameters) String ¶
func (x *FrequencyEstimatorParameters) String() string
type FtrlParameters ¶
type FtrlParameters struct {
    L1                 float32 `protobuf:"fixed32,1,opt,name=l1,proto3" json:"l1,omitempty"`
    L2                 float32 `protobuf:"fixed32,2,opt,name=l2,proto3" json:"l2,omitempty"`
    LrPower            float32 `protobuf:"fixed32,3,opt,name=lr_power,json=lrPower,proto3" json:"lr_power,omitempty"`
    Beta               float32 `protobuf:"fixed32,7,opt,name=beta,proto3" json:"beta,omitempty"`
    MultiplyLinearByLr bool    `protobuf:"varint,6,opt,name=multiply_linear_by_lr,json=multiplyLinearByLr,proto3" json:"multiply_linear_by_lr,omitempty"`
    // Previously, allow_zero_accumulator parameter changed some internal formulas
    // to allow zero and near-zero accumulator values at the cost of some
    // performance. The current implementation ignores this parameter; zero or
    // near-zero accumulator values are now always supported.
    //
    // Deprecated: Marked as deprecated in tensorflow/core/protobuf/tpu/optimization_parameters.proto.
    AllowZeroAccumulator bool `protobuf:"varint,8,opt,name=allow_zero_accumulator,json=allowZeroAccumulator,proto3" json:"allow_zero_accumulator,omitempty"`
    // contains filtered or unexported fields
}
https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L2646
The hyperparameters for FTRL are the same as for the Keras implementation, with some additions. The "beta" parameter matches the behavior described in the second link above; "beta" / (2 * learning rate) should be added to "l2" to get equivalent behavior in the other TensorFlow implementations of this optimizer.

When the multiply_linear_by_lr field is set to true, a modified formula is used for FTRL that treats the "linear" accumulator as being pre-multiplied by the learning rate (i.e., the accumulator named "linear" actually stores "linear * learning_rate"). Other than checkpoint compatibility, this is mathematically equivalent for a static learning rate; for a dynamic learning rate, it is nearly the same as long as the learning rate does not change quickly. The benefit of setting multiply_linear_by_lr to true is that the modified formula handles zero and near-zero learning rates without producing NaNs, improving flexibility for learning rate ramp-up.
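A small sketch of the beta-to-l2 conversion mentioned above, assuming a static learning rate (the helper name is illustrative).

package main

import "fmt"

// equivalentL2 converts the TPU FTRL "beta" hyperparameter into the extra
// L2 term needed to reproduce the same behavior with the other TensorFlow
// FTRL implementations: l2' = l2 + beta / (2 * learning_rate).
func equivalentL2(l2, beta, learningRate float32) float32 {
    return l2 + beta/(2*learningRate)
}

func main() {
    fmt.Println(equivalentL2(0.01, 0.1, 0.05)) // 0.01 + 0.1/(2*0.05) = 1.01
}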
func (*FtrlParameters) Descriptor
deprecated
func (*FtrlParameters) Descriptor() ([]byte, []int)
Deprecated: Use FtrlParameters.ProtoReflect.Descriptor instead.
func (*FtrlParameters) GetAllowZeroAccumulator
deprecated
func (x *FtrlParameters) GetAllowZeroAccumulator() bool
Deprecated: Marked as deprecated in tensorflow/core/protobuf/tpu/optimization_parameters.proto.
func (*FtrlParameters) GetBeta ¶
func (x *FtrlParameters) GetBeta() float32
func (*FtrlParameters) GetL1 ¶
func (x *FtrlParameters) GetL1() float32
func (*FtrlParameters) GetL2 ¶
func (x *FtrlParameters) GetL2() float32
func (*FtrlParameters) GetLrPower ¶
func (x *FtrlParameters) GetLrPower() float32
func (*FtrlParameters) GetMultiplyLinearByLr ¶
func (x *FtrlParameters) GetMultiplyLinearByLr() bool
func (*FtrlParameters) ProtoMessage ¶
func (*FtrlParameters) ProtoMessage()
func (*FtrlParameters) ProtoReflect ¶
func (x *FtrlParameters) ProtoReflect() protoreflect.Message
func (*FtrlParameters) Reset ¶
func (x *FtrlParameters) Reset()
func (*FtrlParameters) String ¶
func (x *FtrlParameters) String() string
type GradientAccumulationStatus ¶
type GradientAccumulationStatus struct {
// contains filtered or unexported fields
}
Status of using gradient accumulation (doing two passes over the input gradients: one to accumulate them into a temporary array and another to apply them using the actual optimization algorithm). The extra message is to wrap the enum for scoping.
func (*GradientAccumulationStatus) Descriptor
deprecated
func (*GradientAccumulationStatus) Descriptor() ([]byte, []int)
Deprecated: Use GradientAccumulationStatus.ProtoReflect.Descriptor instead.
func (*GradientAccumulationStatus) ProtoMessage ¶
func (*GradientAccumulationStatus) ProtoMessage()
func (*GradientAccumulationStatus) ProtoReflect ¶
func (x *GradientAccumulationStatus) ProtoReflect() protoreflect.Message
func (*GradientAccumulationStatus) Reset ¶
func (x *GradientAccumulationStatus) Reset()
func (*GradientAccumulationStatus) String ¶
func (x *GradientAccumulationStatus) String() string
type GradientAccumulationStatus_Status ¶
type GradientAccumulationStatus_Status int32
If UNSPECIFIED (default), gradient accumulation is ENABLED.
const (
	GradientAccumulationStatus_UNSPECIFIED GradientAccumulationStatus_Status = 0
	GradientAccumulationStatus_ENABLED     GradientAccumulationStatus_Status = 1
	GradientAccumulationStatus_DISABLED    GradientAccumulationStatus_Status = 2
)
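As noted above, leaving the status UNSPECIFIED keeps gradient accumulation enabled. A hypothetical helper (not part of this package) that interprets the enum accordingly:

func gradientAccumulationEnabled(s GradientAccumulationStatus_Status) bool {
	// Only an explicit DISABLED turns accumulation off; UNSPECIFIED defaults
	// to enabled.
	return s != GradientAccumulationStatus_DISABLED
}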
func (GradientAccumulationStatus_Status) Descriptor ¶
func (GradientAccumulationStatus_Status) Descriptor() protoreflect.EnumDescriptor
func (GradientAccumulationStatus_Status) Enum ¶
func (x GradientAccumulationStatus_Status) Enum() *GradientAccumulationStatus_Status
func (GradientAccumulationStatus_Status) EnumDescriptor
deprecated
func (GradientAccumulationStatus_Status) EnumDescriptor() ([]byte, []int)
Deprecated: Use GradientAccumulationStatus_Status.Descriptor instead.
func (GradientAccumulationStatus_Status) Number ¶
func (x GradientAccumulationStatus_Status) Number() protoreflect.EnumNumber
func (GradientAccumulationStatus_Status) String ¶
func (x GradientAccumulationStatus_Status) String() string
func (GradientAccumulationStatus_Status) Type ¶
func (GradientAccumulationStatus_Status) Type() protoreflect.EnumType
type HotIdReplicationConfiguration ¶
type HotIdReplicationConfiguration struct {
	Status HotIdReplicationConfiguration_Status `protobuf:"varint,1,opt,name=status,proto3,enum=tensorflow.tpu.HotIdReplicationConfiguration_Status" json:"status,omitempty"`
	// contains filtered or unexported fields
}
Configuration proto for hot ID optimization. This is an experimental feature that is currently disabled (by default).
func (*HotIdReplicationConfiguration) Descriptor
deprecated
func (*HotIdReplicationConfiguration) Descriptor() ([]byte, []int)
Deprecated: Use HotIdReplicationConfiguration.ProtoReflect.Descriptor instead.
func (*HotIdReplicationConfiguration) GetStatus ¶
func (x *HotIdReplicationConfiguration) GetStatus() HotIdReplicationConfiguration_Status
func (*HotIdReplicationConfiguration) ProtoMessage ¶
func (*HotIdReplicationConfiguration) ProtoMessage()
func (*HotIdReplicationConfiguration) ProtoReflect ¶
func (x *HotIdReplicationConfiguration) ProtoReflect() protoreflect.Message
func (*HotIdReplicationConfiguration) Reset ¶
func (x *HotIdReplicationConfiguration) Reset()
func (*HotIdReplicationConfiguration) String ¶
func (x *HotIdReplicationConfiguration) String() string
type HotIdReplicationConfiguration_Status ¶
type HotIdReplicationConfiguration_Status int32
Whether to enable or disable hot ID optimization. If UNSPECIFIED (default), hot ID optimization is DISABLED.
const (
	HotIdReplicationConfiguration_UNSPECIFIED HotIdReplicationConfiguration_Status = 0
	HotIdReplicationConfiguration_ENABLED     HotIdReplicationConfiguration_Status = 1
	HotIdReplicationConfiguration_DISABLED    HotIdReplicationConfiguration_Status = 2
)
func (HotIdReplicationConfiguration_Status) Descriptor ¶
func (HotIdReplicationConfiguration_Status) Descriptor() protoreflect.EnumDescriptor
func (HotIdReplicationConfiguration_Status) Enum ¶
func (x HotIdReplicationConfiguration_Status) Enum() *HotIdReplicationConfiguration_Status
func (HotIdReplicationConfiguration_Status) EnumDescriptor
deprecated
func (HotIdReplicationConfiguration_Status) EnumDescriptor() ([]byte, []int)
Deprecated: Use HotIdReplicationConfiguration_Status.Descriptor instead.
func (HotIdReplicationConfiguration_Status) Number ¶
func (x HotIdReplicationConfiguration_Status) Number() protoreflect.EnumNumber
func (HotIdReplicationConfiguration_Status) String ¶
func (x HotIdReplicationConfiguration_Status) String() string
func (HotIdReplicationConfiguration_Status) Type ¶
func (HotIdReplicationConfiguration_Status) Type() protoreflect.EnumType
type LearningRate ¶
type LearningRate struct {
	// Types that are assignable to LearningRate:
	//
	//	*LearningRate_Constant
	//	*LearningRate_Dynamic
	LearningRate isLearningRate_LearningRate `protobuf_oneof:"learning_rate"`
	// contains filtered or unexported fields
}
Source of learning rate to use.
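The oneof is populated through the wrapper types LearningRate_Constant and LearningRate_Dynamic (defined below). A sketch of two hypothetical constructors, assuming the generated types are in scope:

func constantLR(rate float32) *LearningRate {
	// A fixed learning rate embedded directly in the proto.
	return &LearningRate{LearningRate: &LearningRate_Constant{Constant: rate}}
}

func dynamicLR(d *DynamicLearningRate) *LearningRate {
	// A learning rate computed at runtime; see DynamicLearningRate for how the
	// value is selected.
	return &LearningRate{LearningRate: &LearningRate_Dynamic{Dynamic: d}}
}

GetConstant and GetDynamic then return the value for whichever case was set, and the zero value otherwise.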
func (*LearningRate) Descriptor
deprecated
func (*LearningRate) Descriptor() ([]byte, []int)
Deprecated: Use LearningRate.ProtoReflect.Descriptor instead.
func (*LearningRate) GetConstant ¶
func (x *LearningRate) GetConstant() float32
func (*LearningRate) GetDynamic ¶
func (x *LearningRate) GetDynamic() *DynamicLearningRate
func (*LearningRate) GetLearningRate ¶
func (m *LearningRate) GetLearningRate() isLearningRate_LearningRate
func (*LearningRate) ProtoMessage ¶
func (*LearningRate) ProtoMessage()
func (*LearningRate) ProtoReflect ¶
func (x *LearningRate) ProtoReflect() protoreflect.Message
func (*LearningRate) Reset ¶
func (x *LearningRate) Reset()
func (*LearningRate) String ¶
func (x *LearningRate) String() string
type LearningRate_Constant ¶
type LearningRate_Constant struct {
Constant float32 `protobuf:"fixed32,1,opt,name=constant,proto3,oneof"`
}
type LearningRate_Dynamic ¶
type LearningRate_Dynamic struct {
Dynamic *DynamicLearningRate `protobuf:"bytes,2,opt,name=dynamic,proto3,oneof"`
}
type LowDimensionalPackingStatus ¶
type LowDimensionalPackingStatus struct {
// contains filtered or unexported fields
}
There is one important limitation for this HBM packing though. When only a subset of rows in an 8-float chunk are accessed on a particular step, the adjoining rows in the same chunk are updated with zero gradients on the backward pass even if they are not touched. This is an artifact of the packing implementation. This operation is NOT functionally correct for optimizers where zero gradients change the embeddings/slot-variable values, e.g., momentum-based optimizers. Hence, this HBM packing cannot be enabled for embedding tables with such optimizers. The TPU software automatically recognizes that a zero gradient can modify state and turns off the low dimensional embedding packing in that scenario.
For optimizers where a zero gradient is a NoOp, such as SGD, Adagrad, and FTRL, this packing optimization can be used. However, there are some important considerations:
- Clipping limits: The initial values for such embeddings should fall within the clipping limits specified in the optimization parameters. Otherwise, a zero gradient will cause the embeddings to be clipped. This changes state and hence is not a NoOp.
- FTRL: The embedding vector is computed directly from the values of the accumulator and linear slot variables. Hence, the initial embedding values should match those computed from the initial values of the accumulator and linear slot variables. Note that in nearly all cases, the linear value is initialized to zero; this corresponds to an embedding value of zero.
- Performance: The TPU has to perform additional work when low-dimensional packing is enabled. In certain situations, such as when the vocabulary size is small, it may not make sense to turn on this packing since the total memory usage due to padding is extremely low. Hence, the TPU software automatically turns off the packing optimization in such scenarios.
func (*LowDimensionalPackingStatus) Descriptor
deprecated
func (*LowDimensionalPackingStatus) Descriptor() ([]byte, []int)
Deprecated: Use LowDimensionalPackingStatus.ProtoReflect.Descriptor instead.
func (*LowDimensionalPackingStatus) ProtoMessage ¶
func (*LowDimensionalPackingStatus) ProtoMessage()
func (*LowDimensionalPackingStatus) ProtoReflect ¶
func (x *LowDimensionalPackingStatus) ProtoReflect() protoreflect.Message
func (*LowDimensionalPackingStatus) Reset ¶
func (x *LowDimensionalPackingStatus) Reset()
func (*LowDimensionalPackingStatus) String ¶
func (x *LowDimensionalPackingStatus) String() string
type LowDimensionalPackingStatus_Status ¶
type LowDimensionalPackingStatus_Status int32
If UNSPECIFIED (default), low-dimensional packing is DISABLED. This may change in the future.
If ENABLED, low-dimensional packing is enabled only if the following three additional conditions are true:
- The optimizer treats the zero gradient as a NoOp.
- The embedding dimension is 1, 2, or 4.
- The vocabulary size is large enough to avoid performance issues.
If DISABLED, low-dimensional packing is always disabled.
const (
	LowDimensionalPackingStatus_UNSPECIFIED LowDimensionalPackingStatus_Status = 0
	LowDimensionalPackingStatus_ENABLED     LowDimensionalPackingStatus_Status = 1
	LowDimensionalPackingStatus_DISABLED    LowDimensionalPackingStatus_Status = 2
)
func (LowDimensionalPackingStatus_Status) Descriptor ¶
func (LowDimensionalPackingStatus_Status) Descriptor() protoreflect.EnumDescriptor
func (LowDimensionalPackingStatus_Status) Enum ¶
func (x LowDimensionalPackingStatus_Status) Enum() *LowDimensionalPackingStatus_Status
func (LowDimensionalPackingStatus_Status) EnumDescriptor
deprecated
func (LowDimensionalPackingStatus_Status) EnumDescriptor() ([]byte, []int)
Deprecated: Use LowDimensionalPackingStatus_Status.Descriptor instead.
func (LowDimensionalPackingStatus_Status) Number ¶
func (x LowDimensionalPackingStatus_Status) Number() protoreflect.EnumNumber
func (LowDimensionalPackingStatus_Status) String ¶
func (x LowDimensionalPackingStatus_Status) String() string
func (LowDimensionalPackingStatus_Status) Type ¶
func (LowDimensionalPackingStatus_Status) Type() protoreflect.EnumType
type MdlAdagradLightParameters ¶
type MdlAdagradLightParameters struct {
	L2                    float32 `protobuf:"fixed32,1,opt,name=l2,proto3" json:"l2,omitempty"`
	LrPower               float32 `protobuf:"fixed32,2,opt,name=lr_power,json=lrPower,proto3" json:"lr_power,omitempty"`
	MinServableMdlBenefit float32 `` /* 130-byte string literal not displayed */
	MdlMixInMargin        float32 `protobuf:"fixed32,4,opt,name=mdl_mix_in_margin,json=mdlMixInMargin,proto3" json:"mdl_mix_in_margin,omitempty"`
	MdlBenefitRampupCoeff float32 `` /* 130-byte string literal not displayed */
	MdlMinWeight          float32 `protobuf:"fixed32,6,opt,name=mdl_min_weight,json=mdlMinWeight,proto3" json:"mdl_min_weight,omitempty"`
	BenefitRevisitScale   float32 `protobuf:"fixed32,7,opt,name=benefit_revisit_scale,json=benefitRevisitScale,proto3" json:"benefit_revisit_scale,omitempty"`
	MaxEventBenefit       float32 `protobuf:"fixed32,8,opt,name=max_event_benefit,json=maxEventBenefit,proto3" json:"max_event_benefit,omitempty"`
	MaxTotalBenefit       float32 `protobuf:"fixed32,9,opt,name=max_total_benefit,json=maxTotalBenefit,proto3" json:"max_total_benefit,omitempty"`
	MdlHardLimit          float32 `protobuf:"fixed32,10,opt,name=mdl_hard_limit,json=mdlHardLimit,proto3" json:"mdl_hard_limit,omitempty"`
	HardLimitMinBenefit   bool    `protobuf:"varint,11,opt,name=hard_limit_min_benefit,json=hardLimitMinBenefit,proto3" json:"hard_limit_min_benefit,omitempty"`
	MdlRegularize         bool    `protobuf:"varint,12,opt,name=mdl_regularize,json=mdlRegularize,proto3" json:"mdl_regularize,omitempty"`
	// contains filtered or unexported fields
}
Variant of the algorithm in http://proceedings.mlr.press/v44/shamir15.pdf
func (*MdlAdagradLightParameters) Descriptor
deprecated
func (*MdlAdagradLightParameters) Descriptor() ([]byte, []int)
Deprecated: Use MdlAdagradLightParameters.ProtoReflect.Descriptor instead.
func (*MdlAdagradLightParameters) GetBenefitRevisitScale ¶
func (x *MdlAdagradLightParameters) GetBenefitRevisitScale() float32
func (*MdlAdagradLightParameters) GetHardLimitMinBenefit ¶
func (x *MdlAdagradLightParameters) GetHardLimitMinBenefit() bool
func (*MdlAdagradLightParameters) GetL2 ¶
func (x *MdlAdagradLightParameters) GetL2() float32
func (*MdlAdagradLightParameters) GetLrPower ¶
func (x *MdlAdagradLightParameters) GetLrPower() float32
func (*MdlAdagradLightParameters) GetMaxEventBenefit ¶
func (x *MdlAdagradLightParameters) GetMaxEventBenefit() float32
func (*MdlAdagradLightParameters) GetMaxTotalBenefit ¶
func (x *MdlAdagradLightParameters) GetMaxTotalBenefit() float32
func (*MdlAdagradLightParameters) GetMdlBenefitRampupCoeff ¶
func (x *MdlAdagradLightParameters) GetMdlBenefitRampupCoeff() float32
func (*MdlAdagradLightParameters) GetMdlHardLimit ¶
func (x *MdlAdagradLightParameters) GetMdlHardLimit() float32
func (*MdlAdagradLightParameters) GetMdlMinWeight ¶
func (x *MdlAdagradLightParameters) GetMdlMinWeight() float32
func (*MdlAdagradLightParameters) GetMdlMixInMargin ¶
func (x *MdlAdagradLightParameters) GetMdlMixInMargin() float32
func (*MdlAdagradLightParameters) GetMdlRegularize ¶
func (x *MdlAdagradLightParameters) GetMdlRegularize() bool
func (*MdlAdagradLightParameters) GetMinServableMdlBenefit ¶
func (x *MdlAdagradLightParameters) GetMinServableMdlBenefit() float32
func (*MdlAdagradLightParameters) ProtoMessage ¶
func (*MdlAdagradLightParameters) ProtoMessage()
func (*MdlAdagradLightParameters) ProtoReflect ¶
func (x *MdlAdagradLightParameters) ProtoReflect() protoreflect.Message
func (*MdlAdagradLightParameters) Reset ¶
func (x *MdlAdagradLightParameters) Reset()
func (*MdlAdagradLightParameters) String ¶
func (x *MdlAdagradLightParameters) String() string
type MomentumParameters ¶
type MomentumParameters struct {
	Momentum    float32 `protobuf:"fixed32,1,opt,name=momentum,proto3" json:"momentum,omitempty"`
	UseNesterov bool    `protobuf:"varint,2,opt,name=use_nesterov,json=useNesterov,proto3" json:"use_nesterov,omitempty"`
	// contains filtered or unexported fields
}
https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L3068
func (*MomentumParameters) Descriptor
deprecated
func (*MomentumParameters) Descriptor() ([]byte, []int)
Deprecated: Use MomentumParameters.ProtoReflect.Descriptor instead.
func (*MomentumParameters) GetMomentum ¶
func (x *MomentumParameters) GetMomentum() float32
func (*MomentumParameters) GetUseNesterov ¶
func (x *MomentumParameters) GetUseNesterov() bool
func (*MomentumParameters) ProtoMessage ¶
func (*MomentumParameters) ProtoMessage()
func (*MomentumParameters) ProtoReflect ¶
func (x *MomentumParameters) ProtoReflect() protoreflect.Message
func (*MomentumParameters) Reset ¶
func (x *MomentumParameters) Reset()
func (*MomentumParameters) String ¶
func (x *MomentumParameters) String() string
type OnlineYogiParameters ¶
type OnlineYogiParameters struct {
	// The L1 regularization parameter (used analogously to the one in FTRL).
	L1 float32 `protobuf:"fixed32,1,opt,name=l1,proto3" json:"l1,omitempty"`
	// The L2 regularization parameter (used analogously to the one in FTRL).
	L2 float32 `protobuf:"fixed32,2,opt,name=l2,proto3" json:"l2,omitempty"`
	// \beta_2 from Algorithm 2 in the paper.
	Beta2 float32 `protobuf:"fixed32,3,opt,name=beta2,proto3" json:"beta2,omitempty"`
	// contains filtered or unexported fields
}
The online Yogi optimizer does not implement hyper-parameter update; use the dynamic learning rate feature instead, setting the learning rate to user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t), where t is the current timestep.
https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf plus some extensions based on FTRL.
Note that the code by default implements the lazy version of online Yogi.
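A small sketch of the schedule suggested above, as a plain Go function (hypothetical helper, not part of this package):

import "math"

// onlineYogiLearningRate returns user_lr * sqrt(1 - beta2^t) / (1 - beta1^t)
// for timestep t >= 1, as described above.
func onlineYogiLearningRate(userLR, beta1, beta2 float64, t int) float64 {
	tf := float64(t)
	return userLR * math.Sqrt(1-math.Pow(beta2, tf)) / (1 - math.Pow(beta1, tf))
}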
func (*OnlineYogiParameters) Descriptor
deprecated
func (*OnlineYogiParameters) Descriptor() ([]byte, []int)
Deprecated: Use OnlineYogiParameters.ProtoReflect.Descriptor instead.
func (*OnlineYogiParameters) GetBeta2 ¶
func (x *OnlineYogiParameters) GetBeta2() float32
func (*OnlineYogiParameters) GetL1 ¶
func (x *OnlineYogiParameters) GetL1() float32
func (*OnlineYogiParameters) GetL2 ¶
func (x *OnlineYogiParameters) GetL2() float32
func (*OnlineYogiParameters) ProtoMessage ¶
func (*OnlineYogiParameters) ProtoMessage()
func (*OnlineYogiParameters) ProtoReflect ¶
func (x *OnlineYogiParameters) ProtoReflect() protoreflect.Message
func (*OnlineYogiParameters) Reset ¶
func (x *OnlineYogiParameters) Reset()
func (*OnlineYogiParameters) String ¶
func (x *OnlineYogiParameters) String() string
type OptimizationParameters ¶
type OptimizationParameters struct {
	// Learning rate used for updating the embedding layer parameters.
	LearningRate *LearningRate `protobuf:"bytes,13,opt,name=learning_rate,json=learningRate,proto3" json:"learning_rate,omitempty"`
	// Limits to which to clip the weight values after the backward pass; not
	// present means no limits are applied.
	ClippingLimits *ClippingLimits `protobuf:"bytes,2,opt,name=clipping_limits,json=clippingLimits,proto3" json:"clipping_limits,omitempty"`
	// Limits to which to clip the backward pass gradient before using it for
	// updates; not present means no limits are applied.
	GradientClippingLimits *ClippingLimits `` /* 129-byte string literal not displayed */
	// Amount of weight decay to apply; see weight_decay_optimizers.py for
	// details. All optimizers except MDL Adagrad Light are supported with this
	// option. Although there is no check, users who want weight decay will also
	// want to ensure that gradient accumulation is enabled so that the decay will
	// happen once per global batch.
	WeightDecayFactor float32 `protobuf:"fixed32,16,opt,name=weight_decay_factor,json=weightDecayFactor,proto3" json:"weight_decay_factor,omitempty"`
	// If true, the weight decay factor is multiplied by the current learning rate
	// before use; this is to match the note in DecoupledWeightDecayExtension in
	// weight_decay_optimizers.py.
	MultiplyWeightDecayFactorByLearningRate bool `` /* 190-byte string literal not displayed */
	// Configuration for simulated quantization which is used to reduce
	// training/serving skew when the serving variables are quantized. The same
	// quantization operations are executed during training to minimize
	// differences with serving.
	SimulatedQuantization *SimulatedQuantization `protobuf:"bytes,27,opt,name=simulated_quantization,json=simulatedQuantization,proto3" json:"simulated_quantization,omitempty"`
	// Status of using gradient accumulation (doing two passes over the input
	// gradients: one to accumulate them into a temporary array and another to
	// apply them using the actual optimization algorithm).
	GradientAccumulationStatus GradientAccumulationStatus_Status `` /* 197-byte string literal not displayed */
	// Status of the low-dimensional embedding packing optimization. This controls
	// whether to optimize the packing of 1-dimensional, 2-dimensional, and
	// 4-dimensional embedding tables in memory.
	LowDimensionalPackingStatus LowDimensionalPackingStatus_Status `` /* 203-byte string literal not displayed */
	// Configuration proto for hot ID replication. This is an experimental
	// feature that is currently disabled (by default).
	HotIdReplicationConfiguration *HotIdReplicationConfiguration `` /* 153-byte string literal not displayed */
	// Optimization algorithm parameters; which field is selected determines which
	// algorithm to use.
	//
	// Types that are assignable to Parameters:
	//
	//	*OptimizationParameters_Adagrad
	//	*OptimizationParameters_AdagradMomentum
	//	*OptimizationParameters_BoundedAdagrad
	//	*OptimizationParameters_StochasticGradientDescent
	//	*OptimizationParameters_Ftrl
	//	*OptimizationParameters_Adam
	//	*OptimizationParameters_Momentum
	//	*OptimizationParameters_RmsProp
	//	*OptimizationParameters_CenteredRmsProp
	//	*OptimizationParameters_MdlAdagradLight
	//	*OptimizationParameters_Adadelta
	//	*OptimizationParameters_ProximalAdagrad
	//	*OptimizationParameters_OnlineYogi
	//	*OptimizationParameters_ProximalYogi
	//	*OptimizationParameters_FrequencyEstimator
	//	*OptimizationParameters_UserDefinedProgram
	//	*OptimizationParameters_Assign
	Parameters isOptimizationParameters_Parameters `protobuf_oneof:"parameters"`
	// contains filtered or unexported fields
}
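A sketch of assembling OptimizationParameters for Adagrad with a constant learning rate and gradient clipping, assuming the generated types are in scope; the numeric values are illustrative only:

import "google.golang.org/protobuf/types/known/wrapperspb"

func adagradOptimization() *OptimizationParameters {
	return &OptimizationParameters{
		LearningRate: &LearningRate{
			LearningRate: &LearningRate_Constant{Constant: 0.1},
		},
		// Clip gradients to [-1, 1] before they are applied.
		GradientClippingLimits: &ClippingLimits{
			Lower: wrapperspb.Float(-1.0),
			Upper: wrapperspb.Float(1.0),
		},
		// UNSPECIFIED leaves gradient accumulation enabled and low-dimensional
		// packing disabled, per the enum documentation above.
		GradientAccumulationStatus:  GradientAccumulationStatus_UNSPECIFIED,
		LowDimensionalPackingStatus: LowDimensionalPackingStatus_UNSPECIFIED,
		// The parameters oneof selects the optimization algorithm.
		Parameters: &OptimizationParameters_Adagrad{Adagrad: &AdagradParameters{}},
	}
}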
func (*OptimizationParameters) Descriptor
deprecated
func (*OptimizationParameters) Descriptor() ([]byte, []int)
Deprecated: Use OptimizationParameters.ProtoReflect.Descriptor instead.
func (*OptimizationParameters) GetAdadelta ¶
func (x *OptimizationParameters) GetAdadelta() *AdadeltaParameters
func (*OptimizationParameters) GetAdagrad ¶
func (x *OptimizationParameters) GetAdagrad() *AdagradParameters
func (*OptimizationParameters) GetAdagradMomentum ¶
func (x *OptimizationParameters) GetAdagradMomentum() *AdagradMomentumParameters
func (*OptimizationParameters) GetAdam ¶
func (x *OptimizationParameters) GetAdam() *AdamParameters
func (*OptimizationParameters) GetAssign ¶
func (x *OptimizationParameters) GetAssign() *AssignParameters
func (*OptimizationParameters) GetBoundedAdagrad ¶
func (x *OptimizationParameters) GetBoundedAdagrad() *BoundedAdagradParameters
func (*OptimizationParameters) GetCenteredRmsProp ¶
func (x *OptimizationParameters) GetCenteredRmsProp() *CenteredRmsPropParameters
func (*OptimizationParameters) GetClippingLimits ¶
func (x *OptimizationParameters) GetClippingLimits() *ClippingLimits
func (*OptimizationParameters) GetFrequencyEstimator ¶
func (x *OptimizationParameters) GetFrequencyEstimator() *FrequencyEstimatorParameters
func (*OptimizationParameters) GetFtrl ¶
func (x *OptimizationParameters) GetFtrl() *FtrlParameters
func (*OptimizationParameters) GetGradientAccumulationStatus ¶
func (x *OptimizationParameters) GetGradientAccumulationStatus() GradientAccumulationStatus_Status
func (*OptimizationParameters) GetGradientClippingLimits ¶
func (x *OptimizationParameters) GetGradientClippingLimits() *ClippingLimits
func (*OptimizationParameters) GetHotIdReplicationConfiguration ¶
func (x *OptimizationParameters) GetHotIdReplicationConfiguration() *HotIdReplicationConfiguration
func (*OptimizationParameters) GetLearningRate ¶
func (x *OptimizationParameters) GetLearningRate() *LearningRate
func (*OptimizationParameters) GetLowDimensionalPackingStatus ¶
func (x *OptimizationParameters) GetLowDimensionalPackingStatus() LowDimensionalPackingStatus_Status
func (*OptimizationParameters) GetMdlAdagradLight ¶
func (x *OptimizationParameters) GetMdlAdagradLight() *MdlAdagradLightParameters
func (*OptimizationParameters) GetMomentum ¶
func (x *OptimizationParameters) GetMomentum() *MomentumParameters
func (*OptimizationParameters) GetMultiplyWeightDecayFactorByLearningRate ¶
func (x *OptimizationParameters) GetMultiplyWeightDecayFactorByLearningRate() bool
func (*OptimizationParameters) GetOnlineYogi ¶
func (x *OptimizationParameters) GetOnlineYogi() *OnlineYogiParameters
func (*OptimizationParameters) GetParameters ¶
func (m *OptimizationParameters) GetParameters() isOptimizationParameters_Parameters
func (*OptimizationParameters) GetProximalAdagrad ¶
func (x *OptimizationParameters) GetProximalAdagrad() *ProximalAdagradParameters
func (*OptimizationParameters) GetProximalYogi ¶
func (x *OptimizationParameters) GetProximalYogi() *ProximalYogiParameters
func (*OptimizationParameters) GetRmsProp ¶
func (x *OptimizationParameters) GetRmsProp() *RmsPropParameters
func (*OptimizationParameters) GetSimulatedQuantization ¶
func (x *OptimizationParameters) GetSimulatedQuantization() *SimulatedQuantization
func (*OptimizationParameters) GetStochasticGradientDescent ¶
func (x *OptimizationParameters) GetStochasticGradientDescent() *StochasticGradientDescentParameters
func (*OptimizationParameters) GetUserDefinedProgram ¶
func (x *OptimizationParameters) GetUserDefinedProgram() *UserDefinedProgramParameters
func (*OptimizationParameters) GetWeightDecayFactor ¶
func (x *OptimizationParameters) GetWeightDecayFactor() float32
func (*OptimizationParameters) ProtoMessage ¶
func (*OptimizationParameters) ProtoMessage()
func (*OptimizationParameters) ProtoReflect ¶
func (x *OptimizationParameters) ProtoReflect() protoreflect.Message
func (*OptimizationParameters) Reset ¶
func (x *OptimizationParameters) Reset()
func (*OptimizationParameters) String ¶
func (x *OptimizationParameters) String() string
type OptimizationParameters_Adadelta ¶
type OptimizationParameters_Adadelta struct {
Adadelta *AdadeltaParameters `protobuf:"bytes,12,opt,name=adadelta,proto3,oneof"`
}
type OptimizationParameters_Adagrad ¶
type OptimizationParameters_Adagrad struct {
Adagrad *AdagradParameters `protobuf:"bytes,3,opt,name=adagrad,proto3,oneof"`
}
type OptimizationParameters_AdagradMomentum ¶
type OptimizationParameters_AdagradMomentum struct {
AdagradMomentum *AdagradMomentumParameters `protobuf:"bytes,26,opt,name=adagrad_momentum,json=adagradMomentum,proto3,oneof"`
}
type OptimizationParameters_Adam ¶
type OptimizationParameters_Adam struct {
Adam *AdamParameters `protobuf:"bytes,6,opt,name=adam,proto3,oneof"`
}
type OptimizationParameters_Assign ¶
type OptimizationParameters_Assign struct {
Assign *AssignParameters `protobuf:"bytes,25,opt,name=assign,proto3,oneof"`
}
type OptimizationParameters_BoundedAdagrad ¶
type OptimizationParameters_BoundedAdagrad struct {
BoundedAdagrad *BoundedAdagradParameters `protobuf:"bytes,19,opt,name=bounded_adagrad,json=boundedAdagrad,proto3,oneof"`
}
type OptimizationParameters_CenteredRmsProp ¶
type OptimizationParameters_CenteredRmsProp struct {
CenteredRmsProp *CenteredRmsPropParameters `protobuf:"bytes,10,opt,name=centered_rms_prop,json=centeredRmsProp,proto3,oneof"`
}
type OptimizationParameters_FrequencyEstimator ¶
type OptimizationParameters_FrequencyEstimator struct {
FrequencyEstimator *FrequencyEstimatorParameters `protobuf:"bytes,23,opt,name=frequency_estimator,json=frequencyEstimator,proto3,oneof"`
}
type OptimizationParameters_Ftrl ¶
type OptimizationParameters_Ftrl struct {
Ftrl *FtrlParameters `protobuf:"bytes,5,opt,name=ftrl,proto3,oneof"`
}
type OptimizationParameters_MdlAdagradLight ¶
type OptimizationParameters_MdlAdagradLight struct {
MdlAdagradLight *MdlAdagradLightParameters `protobuf:"bytes,11,opt,name=mdl_adagrad_light,json=mdlAdagradLight,proto3,oneof"`
}
type OptimizationParameters_Momentum ¶
type OptimizationParameters_Momentum struct {
Momentum *MomentumParameters `protobuf:"bytes,8,opt,name=momentum,proto3,oneof"`
}
type OptimizationParameters_OnlineYogi ¶
type OptimizationParameters_OnlineYogi struct {
OnlineYogi *OnlineYogiParameters `protobuf:"bytes,20,opt,name=online_yogi,json=onlineYogi,proto3,oneof"`
}
type OptimizationParameters_ProximalAdagrad ¶
type OptimizationParameters_ProximalAdagrad struct {
ProximalAdagrad *ProximalAdagradParameters `protobuf:"bytes,14,opt,name=proximal_adagrad,json=proximalAdagrad,proto3,oneof"`
}
type OptimizationParameters_ProximalYogi ¶
type OptimizationParameters_ProximalYogi struct {
ProximalYogi *ProximalYogiParameters `protobuf:"bytes,21,opt,name=proximal_yogi,json=proximalYogi,proto3,oneof"`
}
type OptimizationParameters_RmsProp ¶
type OptimizationParameters_RmsProp struct {
RmsProp *RmsPropParameters `protobuf:"bytes,9,opt,name=rms_prop,json=rmsProp,proto3,oneof"`
}
type OptimizationParameters_StochasticGradientDescent ¶
type OptimizationParameters_StochasticGradientDescent struct {
StochasticGradientDescent *StochasticGradientDescentParameters `protobuf:"bytes,4,opt,name=stochastic_gradient_descent,json=stochasticGradientDescent,proto3,oneof"`
}
type OptimizationParameters_UserDefinedProgram ¶
type OptimizationParameters_UserDefinedProgram struct {
UserDefinedProgram *UserDefinedProgramParameters `protobuf:"bytes,24,opt,name=user_defined_program,json=userDefinedProgram,proto3,oneof"`
}
type PaddingMap ¶
type PaddingMap struct {
	// Input arg index with dynamic shapes.
	ArgIndex int32 `protobuf:"varint,1,opt,name=arg_index,json=argIndex,proto3" json:"arg_index,omitempty"`
	// The dynamic shape dimension index.
	ShapeIndex int32 `protobuf:"varint,2,opt,name=shape_index,json=shapeIndex,proto3" json:"shape_index,omitempty"`
	// The arg index that the dynamic dimension maps to, which represents the
	// value of the real shape.
	PaddingArgIndex int32 `protobuf:"varint,3,opt,name=padding_arg_index,json=paddingArgIndex,proto3" json:"padding_arg_index,omitempty"`
	// contains filtered or unexported fields
}
A mapping between the dynamic shape dimension of an input and the arg that represents the real shape.
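A hypothetical example, assuming the generated type is in scope: dimension 0 of input argument 2 is dynamic, and argument 5 holds its real (unpadded) size:

pm := &PaddingMap{
	ArgIndex:        2, // input arg with a dynamic shape
	ShapeIndex:      0, // which dimension of that arg is dynamic
	PaddingArgIndex: 5, // arg whose value is the real size of that dimension
}
_ = pm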
func (*PaddingMap) Descriptor
deprecated
func (*PaddingMap) Descriptor() ([]byte, []int)
Deprecated: Use PaddingMap.ProtoReflect.Descriptor instead.
func (*PaddingMap) GetArgIndex ¶
func (x *PaddingMap) GetArgIndex() int32
func (*PaddingMap) GetPaddingArgIndex ¶
func (x *PaddingMap) GetPaddingArgIndex() int32
func (*PaddingMap) GetShapeIndex ¶
func (x *PaddingMap) GetShapeIndex() int32
func (*PaddingMap) ProtoMessage ¶
func (*PaddingMap) ProtoMessage()
func (*PaddingMap) ProtoReflect ¶
func (x *PaddingMap) ProtoReflect() protoreflect.Message
func (*PaddingMap) Reset ¶
func (x *PaddingMap) Reset()
func (*PaddingMap) String ¶
func (x *PaddingMap) String() string
type ProximalAdagradParameters ¶
type ProximalAdagradParameters struct {
	L1 float32 `protobuf:"fixed32,1,opt,name=l1,proto3" json:"l1,omitempty"`
	L2 float32 `protobuf:"fixed32,2,opt,name=l2,proto3" json:"l2,omitempty"`
	// contains filtered or unexported fields
}
https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/ProximalAdagradOptimizer https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L1961
func (*ProximalAdagradParameters) Descriptor
deprecated
func (*ProximalAdagradParameters) Descriptor() ([]byte, []int)
Deprecated: Use ProximalAdagradParameters.ProtoReflect.Descriptor instead.
func (*ProximalAdagradParameters) GetL1 ¶
func (x *ProximalAdagradParameters) GetL1() float32
func (*ProximalAdagradParameters) GetL2 ¶
func (x *ProximalAdagradParameters) GetL2() float32
func (*ProximalAdagradParameters) ProtoMessage ¶
func (*ProximalAdagradParameters) ProtoMessage()
func (*ProximalAdagradParameters) ProtoReflect ¶
func (x *ProximalAdagradParameters) ProtoReflect() protoreflect.Message
func (*ProximalAdagradParameters) Reset ¶
func (x *ProximalAdagradParameters) Reset()
func (*ProximalAdagradParameters) String ¶
func (x *ProximalAdagradParameters) String() string
type ProximalYogiParameters ¶
type ProximalYogiParameters struct {
	// The L1 regularization parameter.
	L1 float32 `protobuf:"fixed32,1,opt,name=l1,proto3" json:"l1,omitempty"`
	// The L2 regularization parameter.
	L2 float32 `protobuf:"fixed32,2,opt,name=l2,proto3" json:"l2,omitempty"`
	// The exponential decay rate for the 1st moment estimates.
	Beta1 float32 `protobuf:"fixed32,3,opt,name=beta1,proto3" json:"beta1,omitempty"`
	// The exponential decay rate for the 2nd moment estimates.
	Beta2 float32 `protobuf:"fixed32,4,opt,name=beta2,proto3" json:"beta2,omitempty"`
	// A constant trading off adaptivity and noise.
	Epsilon float32 `protobuf:"fixed32,5,opt,name=epsilon,proto3" json:"epsilon,omitempty"`
	// contains filtered or unexported fields
}
The proximal Yogi optimizer does not implement hyper-parameter update; use the dynamic learning rate feature instead, setting the learning rate to user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t), where t is the current timestep.
https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf plus some extensions based on FTRL.
Note that the code by default implements the lazy version of proximal Yogi.
func (*ProximalYogiParameters) Descriptor
deprecated
func (*ProximalYogiParameters) Descriptor() ([]byte, []int)
Deprecated: Use ProximalYogiParameters.ProtoReflect.Descriptor instead.
func (*ProximalYogiParameters) GetBeta1 ¶
func (x *ProximalYogiParameters) GetBeta1() float32
func (*ProximalYogiParameters) GetBeta2 ¶
func (x *ProximalYogiParameters) GetBeta2() float32
func (*ProximalYogiParameters) GetEpsilon ¶
func (x *ProximalYogiParameters) GetEpsilon() float32
func (*ProximalYogiParameters) GetL1 ¶
func (x *ProximalYogiParameters) GetL1() float32
func (*ProximalYogiParameters) GetL2 ¶
func (x *ProximalYogiParameters) GetL2() float32
func (*ProximalYogiParameters) ProtoMessage ¶
func (*ProximalYogiParameters) ProtoMessage()
func (*ProximalYogiParameters) ProtoReflect ¶
func (x *ProximalYogiParameters) ProtoReflect() protoreflect.Message
func (*ProximalYogiParameters) Reset ¶
func (x *ProximalYogiParameters) Reset()
func (*ProximalYogiParameters) String ¶
func (x *ProximalYogiParameters) String() string
type RmsPropParameters ¶
type RmsPropParameters struct {
	Rho      float32 `protobuf:"fixed32,1,opt,name=rho,proto3" json:"rho,omitempty"`
	Momentum float32 `protobuf:"fixed32,2,opt,name=momentum,proto3" json:"momentum,omitempty"`
	Epsilon  float32 `protobuf:"fixed32,3,opt,name=epsilon,proto3" json:"epsilon,omitempty"`
	// contains filtered or unexported fields
}
https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L4229
func (*RmsPropParameters) Descriptor
deprecated
func (*RmsPropParameters) Descriptor() ([]byte, []int)
Deprecated: Use RmsPropParameters.ProtoReflect.Descriptor instead.
func (*RmsPropParameters) GetEpsilon ¶
func (x *RmsPropParameters) GetEpsilon() float32
func (*RmsPropParameters) GetMomentum ¶
func (x *RmsPropParameters) GetMomentum() float32
func (*RmsPropParameters) GetRho ¶
func (x *RmsPropParameters) GetRho() float32
func (*RmsPropParameters) ProtoMessage ¶
func (*RmsPropParameters) ProtoMessage()
func (*RmsPropParameters) ProtoReflect ¶
func (x *RmsPropParameters) ProtoReflect() protoreflect.Message
func (*RmsPropParameters) Reset ¶
func (x *RmsPropParameters) Reset()
func (*RmsPropParameters) String ¶
func (x *RmsPropParameters) String() string
type SimulatedQuantization ¶
type SimulatedQuantization struct {
	// Whether simulated quantization is enabled.
	Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"`
	// Minimum and maximum values of the range used for quantization.
	ClippingLimits *ClippingLimits `protobuf:"bytes,2,opt,name=clipping_limits,json=clippingLimits,proto3" json:"clipping_limits,omitempty"`
	// Number of possible quantized values.
	NumBuckets int32 `protobuf:"varint,3,opt,name=num_buckets,json=numBuckets,proto3" json:"num_buckets,omitempty"`
	// contains filtered or unexported fields
}
Configuration for simulated quantization; simulated quantization is used to reduce training/serving skew when the serving variables are quantized. The same quantization operations are executed during training to minimize differences with serving.
Simulated quantization inserts the following operations on the forward pass after gathering the embedding vector from HBM. The backward pass operations are unchanged.
clipped_val = clip(input, clipping_limits)
quantum = clipping_limits.range() / (num_buckets - 1)
quantized_val = floor((clipped_val - clipping_limits.lower()) / quantum + .5)
return quantized_val * quantum + clipping_limits.lower()
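A sketch of the same forward-pass math for a single float32 value, assuming lower and upper are the clipping limits and numBuckets > 1 (hypothetical helper, not part of this package):

import "math"

func simulateQuantization(input, lower, upper float32, numBuckets int32) float32 {
	// clip(input, clipping_limits)
	clipped := math.Min(float64(upper), math.Max(float64(lower), float64(input)))
	// quantum = clipping_limits.range() / (num_buckets - 1)
	quantum := float64(upper-lower) / float64(numBuckets-1)
	// quantized_val = floor((clipped_val - lower) / quantum + .5)
	quantized := math.Floor((clipped-float64(lower))/quantum + 0.5)
	return float32(quantized*quantum + float64(lower))
}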
func (*SimulatedQuantization) Descriptor
deprecated
func (*SimulatedQuantization) Descriptor() ([]byte, []int)
Deprecated: Use SimulatedQuantization.ProtoReflect.Descriptor instead.
func (*SimulatedQuantization) GetClippingLimits ¶
func (x *SimulatedQuantization) GetClippingLimits() *ClippingLimits
func (*SimulatedQuantization) GetEnabled ¶
func (x *SimulatedQuantization) GetEnabled() bool
func (*SimulatedQuantization) GetNumBuckets ¶
func (x *SimulatedQuantization) GetNumBuckets() int32
func (*SimulatedQuantization) ProtoMessage ¶
func (*SimulatedQuantization) ProtoMessage()
func (*SimulatedQuantization) ProtoReflect ¶
func (x *SimulatedQuantization) ProtoReflect() protoreflect.Message
func (*SimulatedQuantization) Reset ¶
func (x *SimulatedQuantization) Reset()
func (*SimulatedQuantization) String ¶
func (x *SimulatedQuantization) String() string
type StateVariableSpecification ¶
type StateVariableSpecification struct {
	// Parameter name for the state variable.
	Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	// Usage type of this state variable.
	//
	// Types that are assignable to Usage:
	//
	//	*StateVariableSpecification_UserDefined_
	//	*StateVariableSpecification_FillWithConstant_
	Usage isStateVariableSpecification_Usage `protobuf_oneof:"usage"`
	// contains filtered or unexported fields
}
Specification of an optimization algorithm's state variables (both the main value vector and any extra accumulators, etc.). This proto is only used internally by the TPU software and is not exposed directly to the TF model.
func (*StateVariableSpecification) Descriptor
deprecated
func (*StateVariableSpecification) Descriptor() ([]byte, []int)
Deprecated: Use StateVariableSpecification.ProtoReflect.Descriptor instead.
func (*StateVariableSpecification) GetFillWithConstant ¶
func (x *StateVariableSpecification) GetFillWithConstant() *StateVariableSpecification_FillWithConstant
func (*StateVariableSpecification) GetName ¶
func (x *StateVariableSpecification) GetName() string
func (*StateVariableSpecification) GetUsage ¶
func (m *StateVariableSpecification) GetUsage() isStateVariableSpecification_Usage
func (*StateVariableSpecification) GetUserDefined ¶
func (x *StateVariableSpecification) GetUserDefined() *StateVariableSpecification_UserDefined
func (*StateVariableSpecification) ProtoMessage ¶
func (*StateVariableSpecification) ProtoMessage()
func (*StateVariableSpecification) ProtoReflect ¶
func (x *StateVariableSpecification) ProtoReflect() protoreflect.Message
func (*StateVariableSpecification) Reset ¶
func (x *StateVariableSpecification) Reset()
func (*StateVariableSpecification) String ¶
func (x *StateVariableSpecification) String() string
type StateVariableSpecification_FillWithConstant ¶
type StateVariableSpecification_FillWithConstant struct {
	InitialValue float64 `protobuf:"fixed64,1,opt,name=initial_value,json=initialValue,proto3" json:"initial_value,omitempty"`
	// contains filtered or unexported fields
}
A state variable that should be filled with a constant and normally hidden from users (used for intermediate gradients being accumulated, for example).
func (*StateVariableSpecification_FillWithConstant) Descriptor
deprecated
func (*StateVariableSpecification_FillWithConstant) Descriptor() ([]byte, []int)
Deprecated: Use StateVariableSpecification_FillWithConstant.ProtoReflect.Descriptor instead.
func (*StateVariableSpecification_FillWithConstant) GetInitialValue ¶
func (x *StateVariableSpecification_FillWithConstant) GetInitialValue() float64
func (*StateVariableSpecification_FillWithConstant) ProtoMessage ¶
func (*StateVariableSpecification_FillWithConstant) ProtoMessage()
func (*StateVariableSpecification_FillWithConstant) ProtoReflect ¶
func (x *StateVariableSpecification_FillWithConstant) ProtoReflect() protoreflect.Message
func (*StateVariableSpecification_FillWithConstant) Reset ¶
func (x *StateVariableSpecification_FillWithConstant) Reset()
func (*StateVariableSpecification_FillWithConstant) String ¶
func (x *StateVariableSpecification_FillWithConstant) String() string
type StateVariableSpecification_FillWithConstant_ ¶
type StateVariableSpecification_FillWithConstant_ struct {
FillWithConstant *StateVariableSpecification_FillWithConstant `protobuf:"bytes,3,opt,name=fill_with_constant,json=fillWithConstant,proto3,oneof"`
}
type StateVariableSpecification_UserDefined ¶
type StateVariableSpecification_UserDefined struct {
// contains filtered or unexported fields
}
A normal state variable that should be saved and restored in checkpoints and used as an input or output to non-debug TensorFlow ops.
func (*StateVariableSpecification_UserDefined) Descriptor
deprecated
func (*StateVariableSpecification_UserDefined) Descriptor() ([]byte, []int)
Deprecated: Use StateVariableSpecification_UserDefined.ProtoReflect.Descriptor instead.
func (*StateVariableSpecification_UserDefined) ProtoMessage ¶
func (*StateVariableSpecification_UserDefined) ProtoMessage()
func (*StateVariableSpecification_UserDefined) ProtoReflect ¶
func (x *StateVariableSpecification_UserDefined) ProtoReflect() protoreflect.Message
func (*StateVariableSpecification_UserDefined) Reset ¶
func (x *StateVariableSpecification_UserDefined) Reset()
func (*StateVariableSpecification_UserDefined) String ¶
func (x *StateVariableSpecification_UserDefined) String() string
type StateVariableSpecification_UserDefined_ ¶
type StateVariableSpecification_UserDefined_ struct {
UserDefined *StateVariableSpecification_UserDefined `protobuf:"bytes,2,opt,name=user_defined,json=userDefined,proto3,oneof"`
}
type StochasticGradientDescentParameters ¶
type StochasticGradientDescentParameters struct {
// contains filtered or unexported fields
}
https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L629
func (*StochasticGradientDescentParameters) Descriptor
deprecated
func (*StochasticGradientDescentParameters) Descriptor() ([]byte, []int)
Deprecated: Use StochasticGradientDescentParameters.ProtoReflect.Descriptor instead.
func (*StochasticGradientDescentParameters) ProtoMessage ¶
func (*StochasticGradientDescentParameters) ProtoMessage()
func (*StochasticGradientDescentParameters) ProtoReflect ¶
func (x *StochasticGradientDescentParameters) ProtoReflect() protoreflect.Message
func (*StochasticGradientDescentParameters) Reset ¶
func (x *StochasticGradientDescentParameters) Reset()
func (*StochasticGradientDescentParameters) String ¶
func (x *StochasticGradientDescentParameters) String() string
type TPUCompileMetadataProto ¶
type TPUCompileMetadataProto struct {
	Args    []*TPUCompileMetadataProto_Arg    `protobuf:"bytes,1,rep,name=args,proto3" json:"args,omitempty"`
	Retvals []*TPUCompileMetadataProto_Retval `protobuf:"bytes,2,rep,name=retvals,proto3" json:"retvals,omitempty"`
	// Number of replicas of the computation and number of cores in each replica.
	// TODO(b/140721404): it may not be necessary to state the number of cores per
	// replica here. Reconsider when replicated model-parallelism is implemented
	// in XLA.
	NumReplicas        int32                       `protobuf:"varint,3,opt,name=num_replicas,json=numReplicas,proto3" json:"num_replicas,omitempty"`
	NumCoresPerReplica int32                       `protobuf:"varint,4,opt,name=num_cores_per_replica,json=numCoresPerReplica,proto3" json:"num_cores_per_replica,omitempty"`
	DeviceAssignment   *data.DeviceAssignmentProto `protobuf:"bytes,8,opt,name=device_assignment,json=deviceAssignment,proto3" json:"device_assignment,omitempty"`
	// A fingerprint of the function library. Ensures that any functions called
	// by the computation have matching definitions.
	FunctionLibraryFingerprint uint64 `` /* 142-byte string literal not displayed */
	// Unique session identifier. Can be empty.
	SessionHandle string `protobuf:"bytes,9,opt,name=session_handle,json=sessionHandle,proto3" json:"session_handle,omitempty"`
	// Fingerprint of guaranteed_const value. The fingerprint computation inside
	// tpu_compile_op may be slow. The computation can be avoided by setting the
	// fingerprint value here.
	GuaranteedConstFingerprint string        `` /* 142-byte string literal not displayed */
	PaddingMaps                []*PaddingMap `protobuf:"bytes,11,rep,name=padding_maps,json=paddingMaps,proto3" json:"padding_maps,omitempty"`
	// The location of step markers that XLA compile will instrument.
	StepMarkerLocation xla.DebugOptions_StepMarkerLocation `` /* 160-byte string literal not displayed */
	// Minimum number of batches run through the XLA graph before XLA fusion
	// autotuner is enabled. Default value of zero disables the autotuner.
	// The XLA fusion autotuner can improve performance by executing a heuristic
	// search on the compiler parameters.
	XlaFusionAutotunerThresh int64 `` /* 139-byte string literal not displayed */
	// Enables TPU compiler to add partitioning policies for inputs/outputs to
	// the XLA computation for model parallelism.
	EnableAutomaticModelParallelism bool `` /* 160-byte string literal not displayed */
	// Whether to use XLA's SPMD or MPMD partitioner when compiler partitioning is
	// requested.
	UseSpmdForXlaPartitioning bool `` /* 144-byte string literal not displayed */
	// Whether to automatically generate XLA shardings for SPMD partitioner.
	UseAutoSpmdForXlaPartitioning bool `` /* 158-byte string literal not displayed */
	// Device mesh shape used to create the sharding search space when
	// use_auto_spmd_partitioning=true.
	AutoSpmdMeshShape []int64 `protobuf:"varint,19,rep,packed,name=auto_spmd_mesh_shape,json=autoSpmdMeshShape,proto3" json:"auto_spmd_mesh_shape,omitempty"`
	// Device mesh ids compatible with the above mesh_shape used when
	// use_auto_spmd_partitioning=true.
	AutoSpmdMeshIds []int64 `protobuf:"varint,20,rep,packed,name=auto_spmd_mesh_ids,json=autoSpmdMeshIds,proto3" json:"auto_spmd_mesh_ids,omitempty"`
	// A fingerprint generated by hashing the MLIR module content.
	MlirFingerprint uint64             `protobuf:"varint,17,opt,name=mlir_fingerprint,json=mlirFingerprint,proto3" json:"mlir_fingerprint,omitempty"`
	CompileOptions  *TPUCompileOptions `protobuf:"bytes,21,opt,name=compile_options,json=compileOptions,proto3" json:"compile_options,omitempty"`
	// contains filtered or unexported fields
}
This is an experimental proto used in the TF/XLA bridge to store metadata for a compile op (e.g. _TPUCompileMlir). TODO(lyandy): Deprecate proto once generic metadata proto is created.
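A hypothetical, minimal metadata proto for a data-parallel computation (8 replicas, one core each, SPMD partitioning, bfloat16 matrix-unit operands), assuming the generated types are in scope; in practice this proto is populated by the TF/XLA bridge rather than built by hand:

meta := &TPUCompileMetadataProto{
	NumReplicas:               8,
	NumCoresPerReplica:        1,
	UseSpmdForXlaPartitioning: true,
	CompileOptions: &TPUCompileOptions{
		MatrixUnitOperandPrecision: TPUCompileOptions_BFLOAT16,
	},
}
_ = meta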
func (*TPUCompileMetadataProto) Descriptor
deprecated
func (*TPUCompileMetadataProto) Descriptor() ([]byte, []int)
Deprecated: Use TPUCompileMetadataProto.ProtoReflect.Descriptor instead.
func (*TPUCompileMetadataProto) GetArgs ¶
func (x *TPUCompileMetadataProto) GetArgs() []*TPUCompileMetadataProto_Arg
func (*TPUCompileMetadataProto) GetAutoSpmdMeshIds ¶
func (x *TPUCompileMetadataProto) GetAutoSpmdMeshIds() []int64
func (*TPUCompileMetadataProto) GetAutoSpmdMeshShape ¶
func (x *TPUCompileMetadataProto) GetAutoSpmdMeshShape() []int64
func (*TPUCompileMetadataProto) GetCompileOptions ¶
func (x *TPUCompileMetadataProto) GetCompileOptions() *TPUCompileOptions
func (*TPUCompileMetadataProto) GetDeviceAssignment ¶
func (x *TPUCompileMetadataProto) GetDeviceAssignment() *data.DeviceAssignmentProto
func (*TPUCompileMetadataProto) GetEnableAutomaticModelParallelism ¶
func (x *TPUCompileMetadataProto) GetEnableAutomaticModelParallelism() bool
func (*TPUCompileMetadataProto) GetFunctionLibraryFingerprint ¶
func (x *TPUCompileMetadataProto) GetFunctionLibraryFingerprint() uint64
func (*TPUCompileMetadataProto) GetGuaranteedConstFingerprint ¶
func (x *TPUCompileMetadataProto) GetGuaranteedConstFingerprint() string
func (*TPUCompileMetadataProto) GetMlirFingerprint ¶
func (x *TPUCompileMetadataProto) GetMlirFingerprint() uint64
func (*TPUCompileMetadataProto) GetNumCoresPerReplica ¶
func (x *TPUCompileMetadataProto) GetNumCoresPerReplica() int32
func (*TPUCompileMetadataProto) GetNumReplicas ¶
func (x *TPUCompileMetadataProto) GetNumReplicas() int32
func (*TPUCompileMetadataProto) GetPaddingMaps ¶
func (x *TPUCompileMetadataProto) GetPaddingMaps() []*PaddingMap
func (*TPUCompileMetadataProto) GetRetvals ¶
func (x *TPUCompileMetadataProto) GetRetvals() []*TPUCompileMetadataProto_Retval
func (*TPUCompileMetadataProto) GetSessionHandle ¶
func (x *TPUCompileMetadataProto) GetSessionHandle() string
func (*TPUCompileMetadataProto) GetStepMarkerLocation ¶
func (x *TPUCompileMetadataProto) GetStepMarkerLocation() xla.DebugOptions_StepMarkerLocation
func (*TPUCompileMetadataProto) GetUseAutoSpmdForXlaPartitioning ¶
func (x *TPUCompileMetadataProto) GetUseAutoSpmdForXlaPartitioning() bool
func (*TPUCompileMetadataProto) GetUseSpmdForXlaPartitioning ¶
func (x *TPUCompileMetadataProto) GetUseSpmdForXlaPartitioning() bool
func (*TPUCompileMetadataProto) GetXlaFusionAutotunerThresh ¶
func (x *TPUCompileMetadataProto) GetXlaFusionAutotunerThresh() int64
func (*TPUCompileMetadataProto) ProtoMessage ¶
func (*TPUCompileMetadataProto) ProtoMessage()
func (*TPUCompileMetadataProto) ProtoReflect ¶
func (x *TPUCompileMetadataProto) ProtoReflect() protoreflect.Message
func (*TPUCompileMetadataProto) Reset ¶
func (x *TPUCompileMetadataProto) Reset()
func (*TPUCompileMetadataProto) String ¶
func (x *TPUCompileMetadataProto) String() string
type TPUCompileMetadataProto_Arg ¶
type TPUCompileMetadataProto_Arg struct {
	Dtype framework.DataType               `protobuf:"varint,1,opt,name=dtype,proto3,enum=tensorflow.DataType" json:"dtype,omitempty"`
	Shape *framework.TensorShapeProto      `protobuf:"bytes,2,opt,name=shape,proto3" json:"shape,omitempty"`
	Kind  TPUCompileMetadataProto_Arg_Kind `protobuf:"varint,3,opt,name=kind,proto3,enum=tensorflow.tpu.TPUCompileMetadataProto_Arg_Kind" json:"kind,omitempty"`
	// The cross-core sharding of this input within each replica, e.g.,
	// assigning to one core, or replicate across all cores.
	Sharding *data.OpSharding `protobuf:"bytes,4,opt,name=sharding,proto3" json:"sharding,omitempty"`
	// Whether this argument will receive the same data across all replicas.
	IsSameDataAcrossReplicas bool `` /* 140-byte string literal not displayed */
	// Whether to allow XLA to produce separate programs to shard/unshard this
	// argument. Requires this arg to be an on-device Kind::VARIABLE, or a
	// Kind::PARAMETER. For Kind::PARAMETER, it represents the initial value of
	// a variable, and retval_index_for_sharding must be specified for the
	// corresponding updated value.
	EnableXlaSharding TPUCompileMetadataProto_Arg_EnableXlaSharding `` /* 181-byte string literal not displayed */
	// If XLA sharding is allowed on a Kind::PARAMETER, this field is used to
	// specify the corresponding updated value in the return values. Use -1 for
	// variables that are not updated.
	RetvalIndexForSharding int32 `` /* 132-byte string literal not displayed */
	// Whether this argument is placed on fast memory or not.
	FastMem bool `protobuf:"varint,7,opt,name=fast_mem,json=fastMem,proto3" json:"fast_mem,omitempty"`
	// Whether to let XLA to decide the layout during compilation, as opposed to
	// using a fixed layout determined by the shape.
	UnrestrictedLayout bool `protobuf:"varint,9,opt,name=unrestricted_layout,json=unrestrictedLayout,proto3" json:"unrestricted_layout,omitempty"`
	// Name of the node that the arg comes from.
	Name string `protobuf:"bytes,10,opt,name=name,proto3" json:"name,omitempty"`
	// Whether to use XLA collectives to broadcast this parameter to all
	// replicas, instead of using TensorFlow Send/Recv among the tasks.
	RequiresXlaBroadcast bool `protobuf:"varint,11,opt,name=requires_xla_broadcast,json=requiresXlaBroadcast,proto3" json:"requires_xla_broadcast,omitempty"`
	// contains filtered or unexported fields
}
Description of the types and shapes of the arguments to a computation.
func (*TPUCompileMetadataProto_Arg) Descriptor
deprecated
func (*TPUCompileMetadataProto_Arg) Descriptor() ([]byte, []int)
Deprecated: Use TPUCompileMetadataProto_Arg.ProtoReflect.Descriptor instead.
func (*TPUCompileMetadataProto_Arg) GetDtype ¶
func (x *TPUCompileMetadataProto_Arg) GetDtype() framework.DataType
func (*TPUCompileMetadataProto_Arg) GetEnableXlaSharding ¶
func (x *TPUCompileMetadataProto_Arg) GetEnableXlaSharding() TPUCompileMetadataProto_Arg_EnableXlaSharding
func (*TPUCompileMetadataProto_Arg) GetFastMem ¶
func (x *TPUCompileMetadataProto_Arg) GetFastMem() bool
func (*TPUCompileMetadataProto_Arg) GetIsSameDataAcrossReplicas ¶
func (x *TPUCompileMetadataProto_Arg) GetIsSameDataAcrossReplicas() bool
func (*TPUCompileMetadataProto_Arg) GetKind ¶
func (x *TPUCompileMetadataProto_Arg) GetKind() TPUCompileMetadataProto_Arg_Kind
func (*TPUCompileMetadataProto_Arg) GetName ¶
func (x *TPUCompileMetadataProto_Arg) GetName() string
func (*TPUCompileMetadataProto_Arg) GetRequiresXlaBroadcast ¶
func (x *TPUCompileMetadataProto_Arg) GetRequiresXlaBroadcast() bool
func (*TPUCompileMetadataProto_Arg) GetRetvalIndexForSharding ¶
func (x *TPUCompileMetadataProto_Arg) GetRetvalIndexForSharding() int32
func (*TPUCompileMetadataProto_Arg) GetShape ¶
func (x *TPUCompileMetadataProto_Arg) GetShape() *framework.TensorShapeProto
func (*TPUCompileMetadataProto_Arg) GetSharding ¶
func (x *TPUCompileMetadataProto_Arg) GetSharding() *data.OpSharding
func (*TPUCompileMetadataProto_Arg) GetUnrestrictedLayout ¶
func (x *TPUCompileMetadataProto_Arg) GetUnrestrictedLayout() bool
func (*TPUCompileMetadataProto_Arg) ProtoMessage ¶
func (*TPUCompileMetadataProto_Arg) ProtoMessage()
func (*TPUCompileMetadataProto_Arg) ProtoReflect ¶
func (x *TPUCompileMetadataProto_Arg) ProtoReflect() protoreflect.Message
func (*TPUCompileMetadataProto_Arg) Reset ¶
func (x *TPUCompileMetadataProto_Arg) Reset()
func (*TPUCompileMetadataProto_Arg) String ¶
func (x *TPUCompileMetadataProto_Arg) String() string
type TPUCompileMetadataProto_Arg_EnableXlaSharding ¶
type TPUCompileMetadataProto_Arg_EnableXlaSharding int32
const (
	TPUCompileMetadataProto_Arg_DISALLOWED TPUCompileMetadataProto_Arg_EnableXlaSharding = 0
	// Sharding is allowed if host training loop exists.
	TPUCompileMetadataProto_Arg_TENTATIVE TPUCompileMetadataProto_Arg_EnableXlaSharding = 1
	TPUCompileMetadataProto_Arg_ALLOWED   TPUCompileMetadataProto_Arg_EnableXlaSharding = 2
)
func (TPUCompileMetadataProto_Arg_EnableXlaSharding) Descriptor ¶
func (TPUCompileMetadataProto_Arg_EnableXlaSharding) Descriptor() protoreflect.EnumDescriptor
func (TPUCompileMetadataProto_Arg_EnableXlaSharding) EnumDescriptor
deprecated
func (TPUCompileMetadataProto_Arg_EnableXlaSharding) EnumDescriptor() ([]byte, []int)
Deprecated: Use TPUCompileMetadataProto_Arg_EnableXlaSharding.Descriptor instead.
func (TPUCompileMetadataProto_Arg_EnableXlaSharding) Number ¶
func (x TPUCompileMetadataProto_Arg_EnableXlaSharding) Number() protoreflect.EnumNumber
func (TPUCompileMetadataProto_Arg_EnableXlaSharding) String ¶
func (x TPUCompileMetadataProto_Arg_EnableXlaSharding) String() string
type TPUCompileMetadataProto_Arg_Kind ¶
type TPUCompileMetadataProto_Arg_Kind int32
const (
	TPUCompileMetadataProto_Arg_INVALID   TPUCompileMetadataProto_Arg_Kind = 0
	TPUCompileMetadataProto_Arg_PARAMETER TPUCompileMetadataProto_Arg_Kind = 1
	TPUCompileMetadataProto_Arg_VARIABLE  TPUCompileMetadataProto_Arg_Kind = 2
	// These are args which have been guaranteed to be constants during the
	// session lifetime by the use of the GuaranteeConstOp (or ConstantOp).
	TPUCompileMetadataProto_Arg_GUARANTEED_CONSTANT TPUCompileMetadataProto_Arg_Kind = 3
)
func (TPUCompileMetadataProto_Arg_Kind) Descriptor ¶
func (TPUCompileMetadataProto_Arg_Kind) Descriptor() protoreflect.EnumDescriptor
func (TPUCompileMetadataProto_Arg_Kind) Enum ¶
func (x TPUCompileMetadataProto_Arg_Kind) Enum() *TPUCompileMetadataProto_Arg_Kind
func (TPUCompileMetadataProto_Arg_Kind) EnumDescriptor
deprecated
func (TPUCompileMetadataProto_Arg_Kind) EnumDescriptor() ([]byte, []int)
Deprecated: Use TPUCompileMetadataProto_Arg_Kind.Descriptor instead.
func (TPUCompileMetadataProto_Arg_Kind) Number ¶
func (x TPUCompileMetadataProto_Arg_Kind) Number() protoreflect.EnumNumber
func (TPUCompileMetadataProto_Arg_Kind) String ¶
func (x TPUCompileMetadataProto_Arg_Kind) String() string
func (TPUCompileMetadataProto_Arg_Kind) Type ¶
func (TPUCompileMetadataProto_Arg_Kind) Type() protoreflect.EnumType
type TPUCompileMetadataProto_Retval ¶
type TPUCompileMetadataProto_Retval struct {
	// The cross-core sharding of this return value within each replica, e.g.,
	// assigning to one core, or replicate across all cores.
	Sharding *data.OpSharding `protobuf:"bytes,1,opt,name=sharding,proto3" json:"sharding,omitempty"`
	// contains filtered or unexported fields
}
Description of the return values from a computation.
func (*TPUCompileMetadataProto_Retval) Descriptor
deprecated
func (*TPUCompileMetadataProto_Retval) Descriptor() ([]byte, []int)
Deprecated: Use TPUCompileMetadataProto_Retval.ProtoReflect.Descriptor instead.
func (*TPUCompileMetadataProto_Retval) GetSharding ¶
func (x *TPUCompileMetadataProto_Retval) GetSharding() *data.OpSharding
func (*TPUCompileMetadataProto_Retval) ProtoMessage ¶
func (*TPUCompileMetadataProto_Retval) ProtoMessage()
func (*TPUCompileMetadataProto_Retval) ProtoReflect ¶
func (x *TPUCompileMetadataProto_Retval) ProtoReflect() protoreflect.Message
func (*TPUCompileMetadataProto_Retval) Reset ¶
func (x *TPUCompileMetadataProto_Retval) Reset()
func (*TPUCompileMetadataProto_Retval) String ¶
func (x *TPUCompileMetadataProto_Retval) String() string
type TPUCompileOptions ¶
type TPUCompileOptions struct { MatrixUnitOperandPrecision TPUCompileOptions_Precision `` /* 192-byte string literal not displayed */ // contains filtered or unexported fields }
Stable protobuf for TPU compilation options, suitable for persistent storage. This proto needs to remain backward compatible as it is maintained. TODO(timshen): investigate and migrate other options from TPUCompileMetadataProto.
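For orientation, a minimal sketch of selecting bfloat16 matrix-unit precision follows. The tpu import alias is an assumption here; the actual import path depends on where these bindings were generated.

    // Request bfloat16 operand precision on the matrix unit.
    // "tpu" is an assumed alias for this generated package.
    opts := &tpu.TPUCompileOptions{
        MatrixUnitOperandPrecision: tpu.TPUCompileOptions_BFLOAT16,
    }
    _ = opts.GetMatrixUnitOperandPrecision() // TPUCompileOptions_BFLOAT16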
func (*TPUCompileOptions) Descriptor
deprecated
func (*TPUCompileOptions) Descriptor() ([]byte, []int)
Deprecated: Use TPUCompileOptions.ProtoReflect.Descriptor instead.
func (*TPUCompileOptions) GetMatrixUnitOperandPrecision ¶
func (x *TPUCompileOptions) GetMatrixUnitOperandPrecision() TPUCompileOptions_Precision
func (*TPUCompileOptions) ProtoMessage ¶
func (*TPUCompileOptions) ProtoMessage()
func (*TPUCompileOptions) ProtoReflect ¶
func (x *TPUCompileOptions) ProtoReflect() protoreflect.Message
func (*TPUCompileOptions) Reset ¶
func (x *TPUCompileOptions) Reset()
func (*TPUCompileOptions) String ¶
func (x *TPUCompileOptions) String() string
type TPUCompileOptions_Precision ¶
type TPUCompileOptions_Precision int32
const ( TPUCompileOptions_DEFAULT TPUCompileOptions_Precision = 0 TPUCompileOptions_BFLOAT16 TPUCompileOptions_Precision = 1 TPUCompileOptions_FLOAT32 TPUCompileOptions_Precision = 2 TPUCompileOptions_TENSOR_FLOAT32 TPUCompileOptions_Precision = 3 )
func (TPUCompileOptions_Precision) Descriptor ¶
func (TPUCompileOptions_Precision) Descriptor() protoreflect.EnumDescriptor
func (TPUCompileOptions_Precision) Enum ¶
func (x TPUCompileOptions_Precision) Enum() *TPUCompileOptions_Precision
func (TPUCompileOptions_Precision) EnumDescriptor
deprecated
func (TPUCompileOptions_Precision) EnumDescriptor() ([]byte, []int)
Deprecated: Use TPUCompileOptions_Precision.Descriptor instead.
func (TPUCompileOptions_Precision) Number ¶
func (x TPUCompileOptions_Precision) Number() protoreflect.EnumNumber
func (TPUCompileOptions_Precision) String ¶
func (x TPUCompileOptions_Precision) String() string
func (TPUCompileOptions_Precision) Type ¶
func (TPUCompileOptions_Precision) Type() protoreflect.EnumType
type TPUEmbeddingConfiguration ¶
type TPUEmbeddingConfiguration struct { TableDescriptor []*TPUEmbeddingConfiguration_TableDescriptor `protobuf:"bytes,1,rep,name=table_descriptor,json=tableDescriptor,proto3" json:"table_descriptor,omitempty"` Mode TPUEmbeddingConfiguration_Mode `protobuf:"varint,2,opt,name=mode,proto3,enum=tensorflow.tpu.TPUEmbeddingConfiguration_Mode" json:"mode,omitempty"` // Number of samples in each batch of embedding layer activations sent to // the TensorCore. BatchSizePerTensorCore int32 `` /* 134-byte string literal not displayed */ // Number of TPU hosts used for inference/training. NumHosts int32 `protobuf:"varint,4,opt,name=num_hosts,json=numHosts,proto3" json:"num_hosts,omitempty"` // Number of TensorCores used for inference/training. NumTensorCores int32 `protobuf:"varint,5,opt,name=num_tensor_cores,json=numTensorCores,proto3" json:"num_tensor_cores,omitempty"` ShardingStrategy TPUEmbeddingConfiguration_ShardingStrategy `` /* 173-byte string literal not displayed */ // This parameter determines if the execution of the sparse core will be // pipelined with that of the TensorCore. This parameter only affects results // when mode=TRAINING. If mode=INFERENCE or BACKWARD_PASS_ONLY, this parameter // does not affect execution and hence, is a don't care value. // // false: The execution of the sparse core is not pipelined with that of the // TensorCore. The forward pass of every step on the sparse core is executed // only after the backward pass of the previous step is complete. And the // backward pass on the sparse core is executed only after the embedding // gradients have been computed on the TensorCore on every step. This ensures // that the activations on every step observe the gradient updates from the // previous step on both the sparse core and the TensorCore. // // true: The execution of the sparse core is pipelined with that of the // TensorCore. The forward pass of every step on the sparse core can be // executed after the forward pass of the previous step is complete without // waiting for the backward pass. This improves the utilization of the sparse // core allowing it to process step N+1 while the embedding gradients for step // N are computed on the TensorCore. The backward pass of every step on the // sparse core is executed directly after the forward pass for the next step // is complete. The drawback is that embedding activations for step N+1 do not // observe the embedding gradient updates from step N. This could affect model // quality if step N and N+1 involve the same set of embedding IDs. However, // since the embedding updates are sparse, this is generally not considered a // problem. PipelineExecutionWithTensorCore bool `` /* 161-byte string literal not displayed */ // Directory where embedding lookup statistics are stored. These statistics // summarize information about the inputs to the embedding lookup // operation, in particular, the average number of embedding IDs per example // and how well the embedding IDs are load balanced across the system. The // lookup statistics are used during TPU initialization for embedding table // partitioning. Collection of lookup statistics is done at runtime by // profiling the embedding inputs: only 3% of input samples are profiled to // minimize host CPU overhead. Once a suitable number of samples are // profiled, the lookup statistics are saved to table-specific files in the // profile data directory generally at the end of a TPU training loop. The // filename corresponding to each table is obtained by hashing table specific // parameters (e.g., table name and number of features) and global // configuration parameters (e.g., sharding strategy and TPU worker task // count). The same profile data directory can be shared amongst several // models to reuse embedding lookup statistics. ProfileDataDirectory string `protobuf:"bytes,9,opt,name=profile_data_directory,json=profileDataDirectory,proto3" json:"profile_data_directory,omitempty"` // If the feature_descriptor field is populated, the model should NOT populate // TableDescriptor.num_features and batch_size_per_tensor_core. These two // fields will be auto-populated by the TPUEmbedding rewrite passes. FeatureDescriptor []*TPUEmbeddingConfiguration_FeatureDescriptor `protobuf:"bytes,10,rep,name=feature_descriptor,json=featureDescriptor,proto3" json:"feature_descriptor,omitempty"` SpmdSharding *TPUEmbeddingConfiguration_SpmdSharding `protobuf:"bytes,11,opt,name=spmd_sharding,json=spmdSharding,proto3" json:"spmd_sharding,omitempty"` // contains filtered or unexported fields }
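As a rough sketch of how the fields above fit together, the fragment below builds a configuration with one table and one feature that looks it up. All field values (names, sizes, host and core counts) are illustrative, and tpu is an assumed alias for this generated package.

    cfg := &tpu.TPUEmbeddingConfiguration{
        Mode:             tpu.TPUEmbeddingConfiguration_TRAINING,
        NumHosts:         1,
        NumTensorCores:   8,
        ShardingStrategy: tpu.TPUEmbeddingConfiguration_DIV_DEFAULT,
        TableDescriptor: []*tpu.TPUEmbeddingConfiguration_TableDescriptor{{
            Name:           "user_table", // illustrative
            VocabularySize: 100000,
            Dimension:      64,
        }},
        // With FeatureDescriptor populated, num_features and
        // batch_size_per_tensor_core are auto-populated by the rewrite passes.
        FeatureDescriptor: []*tpu.TPUEmbeddingConfiguration_FeatureDescriptor{{
            Name:       "user_id",     // illustrative
            TableId:    0,             // index into TableDescriptor
            InputShape: []int32{128},  // excludes the reduction axis
        }},
    }
    _ = cfg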
func (*TPUEmbeddingConfiguration) Descriptor
deprecated
func (*TPUEmbeddingConfiguration) Descriptor() ([]byte, []int)
Deprecated: Use TPUEmbeddingConfiguration.ProtoReflect.Descriptor instead.
func (*TPUEmbeddingConfiguration) GetBatchSizePerTensorCore ¶
func (x *TPUEmbeddingConfiguration) GetBatchSizePerTensorCore() int32
func (*TPUEmbeddingConfiguration) GetFeatureDescriptor ¶
func (x *TPUEmbeddingConfiguration) GetFeatureDescriptor() []*TPUEmbeddingConfiguration_FeatureDescriptor
func (*TPUEmbeddingConfiguration) GetMode ¶
func (x *TPUEmbeddingConfiguration) GetMode() TPUEmbeddingConfiguration_Mode
func (*TPUEmbeddingConfiguration) GetNumHosts ¶
func (x *TPUEmbeddingConfiguration) GetNumHosts() int32
func (*TPUEmbeddingConfiguration) GetNumTensorCores ¶
func (x *TPUEmbeddingConfiguration) GetNumTensorCores() int32
func (*TPUEmbeddingConfiguration) GetPipelineExecutionWithTensorCore ¶
func (x *TPUEmbeddingConfiguration) GetPipelineExecutionWithTensorCore() bool
func (*TPUEmbeddingConfiguration) GetProfileDataDirectory ¶
func (x *TPUEmbeddingConfiguration) GetProfileDataDirectory() string
func (*TPUEmbeddingConfiguration) GetShardingStrategy ¶
func (x *TPUEmbeddingConfiguration) GetShardingStrategy() TPUEmbeddingConfiguration_ShardingStrategy
func (*TPUEmbeddingConfiguration) GetSpmdSharding ¶
func (x *TPUEmbeddingConfiguration) GetSpmdSharding() *TPUEmbeddingConfiguration_SpmdSharding
func (*TPUEmbeddingConfiguration) GetTableDescriptor ¶
func (x *TPUEmbeddingConfiguration) GetTableDescriptor() []*TPUEmbeddingConfiguration_TableDescriptor
func (*TPUEmbeddingConfiguration) ProtoMessage ¶
func (*TPUEmbeddingConfiguration) ProtoMessage()
func (*TPUEmbeddingConfiguration) ProtoReflect ¶
func (x *TPUEmbeddingConfiguration) ProtoReflect() protoreflect.Message
func (*TPUEmbeddingConfiguration) Reset ¶
func (x *TPUEmbeddingConfiguration) Reset()
func (*TPUEmbeddingConfiguration) String ¶
func (x *TPUEmbeddingConfiguration) String() string
type TPUEmbeddingConfiguration_FeatureDescriptor ¶
type TPUEmbeddingConfiguration_FeatureDescriptor struct { // Name of the input feature. Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` // Index of the corresponding table in the TableDescriptor list. TableId int32 `protobuf:"varint,2,opt,name=table_id,json=tableId,proto3" json:"table_id,omitempty"` // Static shape of the inputs (excluding the reduction axis). Note that // the shape of the actual inputs provided using the infeed op must be // strictly smaller than input_shape. The outputs received at the TensorCore // will have rank = input_shape.size() + 1. The innermost axis corresponds // to the embedding dimension. If the input has shape [m, n, k] (excluding // the reduction axis) and the embedding dimension is d, the output received // at the TensorCore will have shape [m, n, k, d]. InputShape []int32 `protobuf:"varint,3,rep,packed,name=input_shape,json=inputShape,proto3" json:"input_shape,omitempty"` // contains filtered or unexported fields }
Description of different input features.
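To make the shape contract concrete: if a feature's input has shape [m, n] (excluding the reduction axis) and its table has embedding dimension d, the activations delivered to the TensorCore have shape [m, n, d]. A hypothetical descriptor with illustrative sizes, assuming the tpu alias for this package:

    // Input of shape [128, 10] (excluding the reduction axis), looked up in a
    // table of dimension 64, yields TensorCore activations of shape [128, 10, 64].
    fd := &tpu.TPUEmbeddingConfiguration_FeatureDescriptor{
        Name:       "history_ids", // illustrative
        TableId:    0,
        InputShape: []int32{128, 10},
    }
    _ = fd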
func (*TPUEmbeddingConfiguration_FeatureDescriptor) Descriptor
deprecated
func (*TPUEmbeddingConfiguration_FeatureDescriptor) Descriptor() ([]byte, []int)
Deprecated: Use TPUEmbeddingConfiguration_FeatureDescriptor.ProtoReflect.Descriptor instead.
func (*TPUEmbeddingConfiguration_FeatureDescriptor) GetInputShape ¶
func (x *TPUEmbeddingConfiguration_FeatureDescriptor) GetInputShape() []int32
func (*TPUEmbeddingConfiguration_FeatureDescriptor) GetName ¶
func (x *TPUEmbeddingConfiguration_FeatureDescriptor) GetName() string
func (*TPUEmbeddingConfiguration_FeatureDescriptor) GetTableId ¶
func (x *TPUEmbeddingConfiguration_FeatureDescriptor) GetTableId() int32
func (*TPUEmbeddingConfiguration_FeatureDescriptor) ProtoMessage ¶
func (*TPUEmbeddingConfiguration_FeatureDescriptor) ProtoMessage()
func (*TPUEmbeddingConfiguration_FeatureDescriptor) ProtoReflect ¶
func (x *TPUEmbeddingConfiguration_FeatureDescriptor) ProtoReflect() protoreflect.Message
func (*TPUEmbeddingConfiguration_FeatureDescriptor) Reset ¶
func (x *TPUEmbeddingConfiguration_FeatureDescriptor) Reset()
func (*TPUEmbeddingConfiguration_FeatureDescriptor) String ¶
func (x *TPUEmbeddingConfiguration_FeatureDescriptor) String() string
type TPUEmbeddingConfiguration_Mode ¶
type TPUEmbeddingConfiguration_Mode int32
Mode: whether the embedding layer program should be run for inference (forward pass only), training (both forward and backward passes), or the backward pass only.
const ( TPUEmbeddingConfiguration_UNSPECIFIED TPUEmbeddingConfiguration_Mode = 0 TPUEmbeddingConfiguration_INFERENCE TPUEmbeddingConfiguration_Mode = 1 TPUEmbeddingConfiguration_TRAINING TPUEmbeddingConfiguration_Mode = 2 TPUEmbeddingConfiguration_BACKWARD_PASS_ONLY TPUEmbeddingConfiguration_Mode = 3 )
func (TPUEmbeddingConfiguration_Mode) Descriptor ¶
func (TPUEmbeddingConfiguration_Mode) Descriptor() protoreflect.EnumDescriptor
func (TPUEmbeddingConfiguration_Mode) Enum ¶
func (x TPUEmbeddingConfiguration_Mode) Enum() *TPUEmbeddingConfiguration_Mode
func (TPUEmbeddingConfiguration_Mode) EnumDescriptor
deprecated
func (TPUEmbeddingConfiguration_Mode) EnumDescriptor() ([]byte, []int)
Deprecated: Use TPUEmbeddingConfiguration_Mode.Descriptor instead.
func (TPUEmbeddingConfiguration_Mode) Number ¶
func (x TPUEmbeddingConfiguration_Mode) Number() protoreflect.EnumNumber
func (TPUEmbeddingConfiguration_Mode) String ¶
func (x TPUEmbeddingConfiguration_Mode) String() string
func (TPUEmbeddingConfiguration_Mode) Type ¶
func (TPUEmbeddingConfiguration_Mode) Type() protoreflect.EnumType
type TPUEmbeddingConfiguration_ShardingStrategy ¶
type TPUEmbeddingConfiguration_ShardingStrategy int32
Sharding strategy of the embedding tables among the hosts. If the sharding_strategy is "mod", each id is assigned to host "id % num_hosts". For instance, 13 ids are split across 5 hosts as: [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]. If the sharding_strategy is "div", ids are assigned to hosts in a contiguous manner. In this case, 13 ids are split across 5 hosts as: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]. In both strategies, if the id space does not evenly divide the number of hosts, each of the first "table_descriptor.vocabulary_size % num_hosts" hosts is assigned one more id. This partitioning strategy exactly follows that of the embedding_lookup TensorFlow function at tensorflow/python/ops/embedding_ops.py.
const ( TPUEmbeddingConfiguration_DIV_DEFAULT TPUEmbeddingConfiguration_ShardingStrategy = 0 TPUEmbeddingConfiguration_MOD TPUEmbeddingConfiguration_ShardingStrategy = 1 )
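The assignment rule can be reproduced directly. The plain-Go sketch below (not part of the generated API) recomputes both layouts for the 13-id, 5-host example above:

    // Reproduces the documented example: 13 ids split across 5 hosts.
    numIDs, numHosts := 13, 5
    mod := make([][]int, numHosts)
    div := make([][]int, numHosts)
    // "mod": id -> host id % num_hosts.
    for id := 0; id < numIDs; id++ {
        mod[id%numHosts] = append(mod[id%numHosts], id)
    }
    // "div": contiguous ranges; the first (numIDs % numHosts) hosts get one extra id.
    idsPerHost, extra := numIDs/numHosts, numIDs%numHosts
    next := 0
    for h := 0; h < numHosts; h++ {
        n := idsPerHost
        if h < extra {
            n++
        }
        for i := 0; i < n; i++ {
            div[h] = append(div[h], next)
            next++
        }
    }
    // mod: [[0 5 10] [1 6 11] [2 7 12] [3 8] [4 9]]
    // div: [[0 1 2] [3 4 5] [6 7 8] [9 10] [11 12]]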
func (TPUEmbeddingConfiguration_ShardingStrategy) Descriptor ¶
func (TPUEmbeddingConfiguration_ShardingStrategy) Descriptor() protoreflect.EnumDescriptor
func (TPUEmbeddingConfiguration_ShardingStrategy) EnumDescriptor
deprecated
func (TPUEmbeddingConfiguration_ShardingStrategy) EnumDescriptor() ([]byte, []int)
Deprecated: Use TPUEmbeddingConfiguration_ShardingStrategy.Descriptor instead.
func (TPUEmbeddingConfiguration_ShardingStrategy) Number ¶
func (x TPUEmbeddingConfiguration_ShardingStrategy) Number() protoreflect.EnumNumber
func (TPUEmbeddingConfiguration_ShardingStrategy) String ¶
func (x TPUEmbeddingConfiguration_ShardingStrategy) String() string
func (TPUEmbeddingConfiguration_ShardingStrategy) Type ¶
func (TPUEmbeddingConfiguration_ShardingStrategy) Type() protoreflect.EnumType
type TPUEmbeddingConfiguration_SpmdSharding ¶
type TPUEmbeddingConfiguration_SpmdSharding struct { // Whether SPMD sharding is enabled. Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"` // Number of cores per replica. NumCoresPerReplica int32 `protobuf:"varint,2,opt,name=num_cores_per_replica,json=numCoresPerReplica,proto3" json:"num_cores_per_replica,omitempty"` // contains filtered or unexported fields }
SPMD (Single Program Multiple Data) sharding configuration for TPUEmbedding. When model parallelism is used on the TensorCore, the number of cores per replica must be passed to TPUEmbedding so that the right shapes can be computed in the TF/XLA bridge.
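For example, when the TensorCore computation is partitioned over two cores per replica, the message might be populated as follows (illustrative values, assumed tpu alias):

    // SPMD model parallelism with two TensorCores per replica (illustrative).
    spmd := &tpu.TPUEmbeddingConfiguration_SpmdSharding{
        Enabled:            true,
        NumCoresPerReplica: 2, // must match the TensorCore model-parallelism setup
    }
    _ = spmd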
func (*TPUEmbeddingConfiguration_SpmdSharding) Descriptor
deprecated
func (*TPUEmbeddingConfiguration_SpmdSharding) Descriptor() ([]byte, []int)
Deprecated: Use TPUEmbeddingConfiguration_SpmdSharding.ProtoReflect.Descriptor instead.
func (*TPUEmbeddingConfiguration_SpmdSharding) GetEnabled ¶
func (x *TPUEmbeddingConfiguration_SpmdSharding) GetEnabled() bool
func (*TPUEmbeddingConfiguration_SpmdSharding) GetNumCoresPerReplica ¶
func (x *TPUEmbeddingConfiguration_SpmdSharding) GetNumCoresPerReplica() int32
func (*TPUEmbeddingConfiguration_SpmdSharding) ProtoMessage ¶
func (*TPUEmbeddingConfiguration_SpmdSharding) ProtoMessage()
func (*TPUEmbeddingConfiguration_SpmdSharding) ProtoReflect ¶
func (x *TPUEmbeddingConfiguration_SpmdSharding) ProtoReflect() protoreflect.Message
func (*TPUEmbeddingConfiguration_SpmdSharding) Reset ¶
func (x *TPUEmbeddingConfiguration_SpmdSharding) Reset()
func (*TPUEmbeddingConfiguration_SpmdSharding) String ¶
func (x *TPUEmbeddingConfiguration_SpmdSharding) String() string
type TPUEmbeddingConfiguration_TableDescriptor ¶
type TPUEmbeddingConfiguration_TableDescriptor struct { // Name of the table. Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` // Size of the vocabulary (i.e., number of rows) in the table. VocabularySize int64 `protobuf:"varint,2,opt,name=vocabulary_size,json=vocabularySize,proto3" json:"vocabulary_size,omitempty"` // The embedding dimension (i.e., the width of the embedding table). Dimension int32 `protobuf:"varint,3,opt,name=dimension,proto3" json:"dimension,omitempty"` // Number of features mapped to this table. NumFeatures int32 `protobuf:"varint,4,opt,name=num_features,json=numFeatures,proto3" json:"num_features,omitempty"` // Details of the learning algorithm used to update the embedding // parameters. OptimizationParameters *OptimizationParameters `` /* 127-byte string literal not displayed */ // contains filtered or unexported fields }
Description of the various embedding tables.
func (*TPUEmbeddingConfiguration_TableDescriptor) Descriptor
deprecated
func (*TPUEmbeddingConfiguration_TableDescriptor) Descriptor() ([]byte, []int)
Deprecated: Use TPUEmbeddingConfiguration_TableDescriptor.ProtoReflect.Descriptor instead.
func (*TPUEmbeddingConfiguration_TableDescriptor) GetDimension ¶
func (x *TPUEmbeddingConfiguration_TableDescriptor) GetDimension() int32
func (*TPUEmbeddingConfiguration_TableDescriptor) GetName ¶
func (x *TPUEmbeddingConfiguration_TableDescriptor) GetName() string
func (*TPUEmbeddingConfiguration_TableDescriptor) GetNumFeatures ¶
func (x *TPUEmbeddingConfiguration_TableDescriptor) GetNumFeatures() int32
func (*TPUEmbeddingConfiguration_TableDescriptor) GetOptimizationParameters ¶
func (x *TPUEmbeddingConfiguration_TableDescriptor) GetOptimizationParameters() *OptimizationParameters
func (*TPUEmbeddingConfiguration_TableDescriptor) GetVocabularySize ¶
func (x *TPUEmbeddingConfiguration_TableDescriptor) GetVocabularySize() int64
func (*TPUEmbeddingConfiguration_TableDescriptor) ProtoMessage ¶
func (*TPUEmbeddingConfiguration_TableDescriptor) ProtoMessage()
func (*TPUEmbeddingConfiguration_TableDescriptor) ProtoReflect ¶
func (x *TPUEmbeddingConfiguration_TableDescriptor) ProtoReflect() protoreflect.Message
func (*TPUEmbeddingConfiguration_TableDescriptor) Reset ¶
func (x *TPUEmbeddingConfiguration_TableDescriptor) Reset()
func (*TPUEmbeddingConfiguration_TableDescriptor) String ¶
func (x *TPUEmbeddingConfiguration_TableDescriptor) String() string
type TPUEmbeddingError ¶
type TPUEmbeddingError struct {
// contains filtered or unexported fields
}
A placeholder message that is used to define a unique Status payload URL for TPU embedding errors.
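One plausible use, sketched under the assumption that tpu is an alias for this package and that google.golang.org/grpc/status and google.golang.org/grpc/codes are imported, is to attach the message as a status detail and later test for its presence:

    st := status.New(codes.Internal, "embedding lookup failed")
    if detailed, err := st.WithDetails(&tpu.TPUEmbeddingError{}); err == nil {
        st = detailed
    }
    for _, d := range st.Details() {
        if _, ok := d.(*tpu.TPUEmbeddingError); ok {
            // This status was tagged as a TPU embedding error.
        }
    }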
func (*TPUEmbeddingError) Descriptor
deprecated
func (*TPUEmbeddingError) Descriptor() ([]byte, []int)
Deprecated: Use TPUEmbeddingError.ProtoReflect.Descriptor instead.
func (*TPUEmbeddingError) ProtoMessage ¶
func (*TPUEmbeddingError) ProtoMessage()
func (*TPUEmbeddingError) ProtoReflect ¶
func (x *TPUEmbeddingError) ProtoReflect() protoreflect.Message
func (*TPUEmbeddingError) Reset ¶
func (x *TPUEmbeddingError) Reset()
func (*TPUEmbeddingError) String ¶
func (x *TPUEmbeddingError) String() string
type TPUHardwareFeature ¶
type TPUHardwareFeature struct { EmbeddingFeature TPUHardwareFeature_EmbeddingFeature `` /* 166-byte string literal not displayed */ // contains filtered or unexported fields }
Describes the features of a TPU.
func (*TPUHardwareFeature) Descriptor
deprecated
func (*TPUHardwareFeature) Descriptor() ([]byte, []int)
Deprecated: Use TPUHardwareFeature.ProtoReflect.Descriptor instead.
func (*TPUHardwareFeature) GetEmbeddingFeature ¶
func (x *TPUHardwareFeature) GetEmbeddingFeature() TPUHardwareFeature_EmbeddingFeature
func (*TPUHardwareFeature) ProtoMessage ¶
func (*TPUHardwareFeature) ProtoMessage()
func (*TPUHardwareFeature) ProtoReflect ¶
func (x *TPUHardwareFeature) ProtoReflect() protoreflect.Message
func (*TPUHardwareFeature) Reset ¶
func (x *TPUHardwareFeature) Reset()
func (*TPUHardwareFeature) String ¶
func (x *TPUHardwareFeature) String() string
type TPUHardwareFeature_EmbeddingFeature ¶
type TPUHardwareFeature_EmbeddingFeature int32
Embedding feature of a TPU.
const ( // No embedding lookup accelerator available on the TPU. TPUHardwareFeature_UNSUPPORTED TPUHardwareFeature_EmbeddingFeature = 0 // Embedding lookup accelerator V1. The embedding lookup operation can only // be placed at the beginning of the computation. Only one instance of the embedding // lookup layer is allowed. TPUHardwareFeature_V1 TPUHardwareFeature_EmbeddingFeature = 1 // Embedding lookup accelerator V2. The embedding lookup operation can be // placed anywhere in the computation. Multiple instances of the embedding // lookup layer are allowed. TPUHardwareFeature_V2 TPUHardwareFeature_EmbeddingFeature = 2 )
func (TPUHardwareFeature_EmbeddingFeature) Descriptor ¶
func (TPUHardwareFeature_EmbeddingFeature) Descriptor() protoreflect.EnumDescriptor
func (TPUHardwareFeature_EmbeddingFeature) Enum ¶
func (x TPUHardwareFeature_EmbeddingFeature) Enum() *TPUHardwareFeature_EmbeddingFeature
func (TPUHardwareFeature_EmbeddingFeature) EnumDescriptor
deprecated
func (TPUHardwareFeature_EmbeddingFeature) EnumDescriptor() ([]byte, []int)
Deprecated: Use TPUHardwareFeature_EmbeddingFeature.Descriptor instead.
func (TPUHardwareFeature_EmbeddingFeature) Number ¶
func (x TPUHardwareFeature_EmbeddingFeature) Number() protoreflect.EnumNumber
func (TPUHardwareFeature_EmbeddingFeature) String ¶
func (x TPUHardwareFeature_EmbeddingFeature) String() string
func (TPUHardwareFeature_EmbeddingFeature) Type ¶
func (TPUHardwareFeature_EmbeddingFeature) Type() protoreflect.EnumType
type TopologyProto ¶
type TopologyProto struct { // The dimensions of the TPU topology, in cores. Typically, this is a 4D // topology [x, y, z, core], where the major dimensions correspond to TPU // chips, and the minor dimension describes the number of cores on a multicore // chip. MeshShape []int32 `protobuf:"varint,1,rep,packed,name=mesh_shape,json=meshShape,proto3" json:"mesh_shape,omitempty"` // Number of TensorFlow tasks in the cluster. NumTasks int32 `protobuf:"varint,2,opt,name=num_tasks,json=numTasks,proto3" json:"num_tasks,omitempty"` // Number of TPU devices per task. NumTpuDevicesPerTask int32 `` /* 128-byte string literal not displayed */ // A flattened rank 3 int32 array with shape // [num_tasks, num_tpu_devices_per_task, len(mesh_shape)]. // `tasks` is the number of tasks in the TPU cluster, `devices` is the number // of TPU devices per task, and the minor dimension corresponds to a position // in the TPU mesh topology. Each entry [task, device, axis] gives the // `axis`-th coordinate in the topology of a task/device pair. DeviceCoordinates []int32 `protobuf:"varint,4,rep,packed,name=device_coordinates,json=deviceCoordinates,proto3" json:"device_coordinates,omitempty"` // TPU supported features. TpuHardwareFeature *TPUHardwareFeature `protobuf:"bytes,5,opt,name=tpu_hardware_feature,json=tpuHardwareFeature,proto3" json:"tpu_hardware_feature,omitempty"` // contains filtered or unexported fields }
Describes the geometry of a TPU mesh.
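The flattened device_coordinates array can be indexed as shown below. coordinateOf is a hypothetical helper, not part of this package, and tpu is an assumed import alias:

    // coordinateOf returns the mesh coordinates of a (task, device) pair by
    // indexing the flattened [num_tasks, num_tpu_devices_per_task, len(mesh_shape)]
    // device_coordinates array.
    func coordinateOf(t *tpu.TopologyProto, task, device int) []int32 {
        rank := len(t.GetMeshShape())
        base := (task*int(t.GetNumTpuDevicesPerTask()) + device) * rank
        return t.GetDeviceCoordinates()[base : base+rank]
    }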
func (*TopologyProto) Descriptor
deprecated
func (*TopologyProto) Descriptor() ([]byte, []int)
Deprecated: Use TopologyProto.ProtoReflect.Descriptor instead.
func (*TopologyProto) GetDeviceCoordinates ¶
func (x *TopologyProto) GetDeviceCoordinates() []int32
func (*TopologyProto) GetMeshShape ¶
func (x *TopologyProto) GetMeshShape() []int32
func (*TopologyProto) GetNumTasks ¶
func (x *TopologyProto) GetNumTasks() int32
func (*TopologyProto) GetNumTpuDevicesPerTask ¶
func (x *TopologyProto) GetNumTpuDevicesPerTask() int32
func (*TopologyProto) GetTpuHardwareFeature ¶
func (x *TopologyProto) GetTpuHardwareFeature() *TPUHardwareFeature
func (*TopologyProto) ProtoMessage ¶
func (*TopologyProto) ProtoMessage()
func (*TopologyProto) ProtoReflect ¶
func (x *TopologyProto) ProtoReflect() protoreflect.Message
func (*TopologyProto) Reset ¶
func (x *TopologyProto) Reset()
func (*TopologyProto) String ¶
func (x *TopologyProto) String() string
type UserDefinedProgramParameters ¶
type UserDefinedProgramParameters struct { Program *service.HloModuleProto `protobuf:"bytes,1,opt,name=program,proto3" json:"program,omitempty"` // contains filtered or unexported fields }
A user-defined optimizer. The contained HLO program must take the following arguments in the following order:
- gradients
- table weights
- slot variables
- an optional scalar input that is passed in via the dynamic learning rate mechanism.
It must return/end in a tuple op that contains the following values in the following order:
1. new table values
2. new slot variable value
The program must have shape (1,1) with dtype float32 throughout and only use HLO ops that operate elementwise (e.g., no reduce, no variables, no control flow, and no broadcasting outside of the single scalar input). The HLO program should be written as if it were a dense update. It will be called on each row that needs an update and will be applied elementwise.
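On the Go side this message only wraps an already-built HLO module; constructing the program itself happens elsewhere (e.g., via an XLA builder). A minimal, hypothetical sketch, assuming tpu and service aliases for the generated packages:

    // hloModule is assumed to be a *service.HloModuleProto produced elsewhere
    // and to satisfy the elementwise, shape-(1,1) contract described above.
    params := &tpu.UserDefinedProgramParameters{Program: hloModule}
    _ = params.GetProgram()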
func (*UserDefinedProgramParameters) Descriptor
deprecated
func (*UserDefinedProgramParameters) Descriptor() ([]byte, []int)
Deprecated: Use UserDefinedProgramParameters.ProtoReflect.Descriptor instead.
func (*UserDefinedProgramParameters) GetProgram ¶
func (x *UserDefinedProgramParameters) GetProgram() *service.HloModuleProto
func (*UserDefinedProgramParameters) ProtoMessage ¶
func (*UserDefinedProgramParameters) ProtoMessage()
func (*UserDefinedProgramParameters) ProtoReflect ¶
func (x *UserDefinedProgramParameters) ProtoReflect() protoreflect.Message
func (*UserDefinedProgramParameters) Reset ¶
func (x *UserDefinedProgramParameters) Reset()
func (*UserDefinedProgramParameters) String ¶
func (x *UserDefinedProgramParameters) String() string