Documentation ¶
Index ¶
- Variables
- func DecodeArgMinMax(op *Op) (axis int, outputDType dtypes.DType, isMin bool)
- func DecodeBroadcast(op *Op) (prefixDims []int)
- func DecodeConcatenate(op *Op) (axis int)
- func DecodeConvGeneralDilated(op *Op) (operand, filter *Op, axes ConvolveAxesConfig, strides []int, paddings [][2]int, ...)
- func DecodeConvertDType(op *Op) (dtype dtypes.DType)
- func DecodeDotGeneral(op *Op) (lhs *Op, lhsContractingAxes, lhsBatchAxes []int, rhs *Op, ...)
- func DecodeDynamicSlice(op *Op) (operand *Op, startIndices []*Op, sliceDims []int)
- func DecodeDynamicUpdateSlice(op *Op) (operand, update *Op, startIndices []*Op)
- func DecodeGather(op *Op) (indexVectorAxis int, ...)
- func DecodeGetTupleElement(op *Op) (elementIdx int)
- func DecodeReduceWindow(op *Op) (reduceType ReduceOpType, reduceComputation *XlaComputation, initialValue *Op, ...)
- func DecodeReshape(op *Op) (dimensions []int)
- func DecodeRngBitGenerator(op *Op) (state *Op, shape Shape)
- func DecodeScatter(op *Op) (indexVectorAxis int, ...)
- func DecodeSelectAndScatter(op *Op) (operand, source, defaultValue *Op, ...)
- func DecodeSlice(op *Op) (starts, limits, strides []int)
- func DecodeTranspose(op *Op) (permutations []int)
- func DecodeWhile(op *Op) (initialState *Op, condition, body *XlaComputation)
- func HasStableHLO() bool
- type ConvolveAxesConfig
- type Literal
- func NewArrayLiteral[T dtypes.Supported](flat []T, dimensions ...int) (*Literal, error)
- func NewArrayLiteralFromAny(flatAny any, dimensions ...int) (*Literal, error)
- func NewLiteralFromShape(shape Shape) (*Literal, error)
- func NewScalarLiteral[T dtypes.Supported](value T) *Literal
- func NewScalarLiteralFromAny(value any) (*Literal, error)
- func NewScalarLiteralFromFloat64(value float64, dtype dtypes.DType) (*Literal, error)
- type Op
- func Abs(x *Op) (*Op, error)
- func Add(x0, x1 *Op) (*Op, error)
- func And(x0, x1 *Op) (*Op, error)
- func ArgMinMax(x *Op, axis int, outputDType dtypes.DType, isMin bool) (*Op, error)
- func BatchNormForInference(operand, scale, offset, mean, variance *Op, epsilon float32, axis int) (*Op, error)
- func BatchNormForTraining(operand, scale, offset *Op, epsilon float32, axis int) (normalized, batchMean, batchVariance *Op, err error)
- func BatchNormGradient(operand, scale, mean, variance, gradOutput *Op, epsilon float32, axis int) (gradOperand, gradScale, gradOffset *Op, err error)
- func Broadcast(x *Op, prefixDims ...int) (*Op, error)
- func BroadcastInDim(x *Op, outputShape Shape, broadcastAxes []int) (*Op, error)
- func Call(builder *XlaBuilder, subComputation *XlaComputation, operands ...*Op) (*Op, error)
- func Ceil(x *Op) (*Op, error)
- func Clz(x *Op) (*Op, error)
- func Complex(x0, x1 *Op) (*Op, error)
- func Concatenate(axis int, operands ...*Op) (*Op, error)
- func Conj(x *Op) (*Op, error)
- func Constant(builder *XlaBuilder, x *Literal) (*Op, error)
- func ConvGeneralDilated(operand, filter *Op, axes ConvolveAxesConfig, strides []int, paddings [][2]int, ...) (*Op, error)
- func ConvertDType(x *Op, dtype dtypes.DType) (*Op, error)
- func Cos(x *Op) (*Op, error)
- func DecodeBatchNormForInference(op *Op) (operand, scale, offset, mean, variance *Op, epsilon float32, axis int)
- func DecodeBatchNormForTraining(op *Op) (operand, scale, offset *Op, epsilon float32, axis int)
- func DecodeBatchNormGrad(op *Op) (operand, scale, mean, variance, gradOutput *Op, epsilon float32, axis int)
- func DecodeFFT(op *Op) (operand *Op, fftType xla_data.FftType, fftLength []int)
- func DecodeReverse(op *Op) (x *Op, axes []int)
- func Div(x0, x1 *Op) (*Op, error)
- func Dot(x0, x1 *Op) (*Op, error)
- func DotGeneral(lhs *Op, lhsContractingAxes, lhsBatchAxes []int, rhs *Op, ...) (*Op, error)
- func DynamicSlice(operand *Op, startIndices []*Op, sliceDims []int) (*Op, error)
- func DynamicUpdateSlice(operand, update *Op, startIndices []*Op) (*Op, error)
- func Equal(x0, x1 *Op) (*Op, error)
- func EqualTotalOrder(x0, x1 *Op) (*Op, error)
- func Erf(x *Op) (*Op, error)
- func Exp(x *Op) (*Op, error)
- func Expm1(x *Op) (*Op, error)
- func FFT(operand *Op, fftType xla_data.FftType, fftLength []int) (*Op, error)
- func Floor(x *Op) (*Op, error)
- func Gather(operand, startIndices *Op, indexVectorAxis int, ...) (*Op, error)
- func GetTupleElement(input *Op, elementIdx int) (*Op, error)
- func GreaterOrEqual(x0, x1 *Op) (*Op, error)
- func GreaterOrEqualTotalOrder(x0, x1 *Op) (*Op, error)
- func GreaterThan(x0, x1 *Op) (*Op, error)
- func GreaterThanTotalOrder(x0, x1 *Op) (*Op, error)
- func Identity(input *Op) *Op
- func Imag(x *Op) (*Op, error)
- func Iota(builder *XlaBuilder, shape Shape, iotaAxis int) (*Op, error)
- func IsFinite(x *Op) (*Op, error)
- func LessOrEqual(x0, x1 *Op) (*Op, error)
- func LessOrEqualTotalOrder(x0, x1 *Op) (*Op, error)
- func LessThan(x0, x1 *Op) (*Op, error)
- func LessThanTotalOrder(x0, x1 *Op) (*Op, error)
- func Log(x *Op) (*Op, error)
- func Log1p(x *Op) (*Op, error)
- func LogicalNot(x *Op) (*Op, error)
- func Logistic(x *Op) (*Op, error)
- func Max(x0, x1 *Op) (*Op, error)
- func Min(x0, x1 *Op) (*Op, error)
- func Mul(x0, x1 *Op) (*Op, error)
- func Neg(x *Op) (*Op, error)
- func NotEqual(x0, x1 *Op) (*Op, error)
- func NotEqualTotalOrder(x0, x1 *Op) (*Op, error)
- func Or(x0, x1 *Op) (*Op, error)
- func Pad(x, fillValue *Op, axesConfig ...PadAxis) (*Op, error)
- func Parameter(builder *XlaBuilder, name string, paramIndex int, shape Shape) (*Op, error)
- func PopulationCount(x *Op) (*Op, error)
- func Pow(x0, x1 *Op) (*Op, error)
- func Real(x *Op) (*Op, error)
- func Reduce(x *Op, reduceComputation *XlaComputation, initialValue *Op, axes ...int) (*Op, error)
- func ReduceAnd(x *Op, axes ...int) (*Op, error)
- func ReduceMax(x *Op, axes ...int) (*Op, error)
- func ReduceMin(x *Op, axes ...int) (*Op, error)
- func ReduceOr(x *Op, axes ...int) (*Op, error)
- func ReduceProduct(x *Op, axes ...int) (*Op, error)
- func ReduceSum(x *Op, axes ...int) (*Op, error)
- func Rem(x0, x1 *Op) (*Op, error)
- func Reshape(x *Op, dimensions ...int) (*Op, error)
- func Reverse(x *Op, axes ...int) (*Op, error)
- func RngBitGenerator(state *Op, shape Shape) (newState, values *Op, err error)
- func Round(x *Op) (*Op, error)
- func Rsqrt(x *Op) (*Op, error)
- func ScalarOne(builder *XlaBuilder, dtype dtypes.DType) (*Op, error)
- func ScalarZero(builder *XlaBuilder, dtype dtypes.DType) (*Op, error)
- func ScatterAdd(operand, scatterIndices, updates *Op, indexVectorAxis int, ...) (*Op, error)
- func ScatterCustom(operand, scatterIndices, updates *Op, updateComputation *XlaComputation, ...) (*Op, error)
- func ScatterMax(operand, scatterIndices, updates *Op, indexVectorAxis int, ...) (*Op, error)
- func ScatterMin(operand, scatterIndices, updates *Op, indexVectorAxis int, ...) (*Op, error)
- func SelectAndScatterCustom(operand, source, defaultValue *Op, ...) (*Op, error)
- func SelectAndScatterMax(operand, source *Op, windowDimensions, windowStrides []int, paddings [][2]int) (*Op, error)
- func SelectAndScatterMin(operand, source *Op, windowDimensions, windowStrides []int, paddings [][2]int) (*Op, error)
- func SelectAndScatterSum(operand, source *Op, windowDimensions, windowStrides []int, paddings [][2]int) (*Op, error)
- func ShiftLeft(x0, x1 *Op) (*Op, error)
- func ShiftRightArithmetic(x0, x1 *Op) (*Op, error)
- func ShiftRightLogical(x0, x1 *Op) (*Op, error)
- func Sign(x *Op) (*Op, error)
- func Sin(x *Op) (*Op, error)
- func Slice(x *Op, starts, limits, strides []int) (*Op, error)
- func SplitTuple(tuple *Op) ([]*Op, error)
- func Sqrt(x *Op) (*Op, error)
- func Sub(x0, x1 *Op) (*Op, error)
- func Tanh(x *Op) (*Op, error)
- func Transpose(x *Op, permutations ...int) (*Op, error)
- func Tuple(inputs ...*Op) (*Op, error)
- func Where(condition, onTrue, onFalse *Op) (*Op, error)
- func While(initialState *Op, condition, body *XlaComputation) (*Op, error)
- func Xor(x0, x1 *Op) (*Op, error)
- type OpType
- type PadAxis
- type ReduceOpType
- type ReduceWindowConfig
- func (r *ReduceWindowConfig) Done() (*Op, error)
- func (r *ReduceWindowConfig) Max() *ReduceWindowConfig
- func (r *ReduceWindowConfig) Min() *ReduceWindowConfig
- func (r *ReduceWindowConfig) Product() *ReduceWindowConfig
- func (r *ReduceWindowConfig) Sum() *ReduceWindowConfig
- func (r *ReduceWindowConfig) UseComputation(reduceComputation *XlaComputation, initialValue *Op) *ReduceWindowConfig
- func (r *ReduceWindowConfig) WithBaseDilations(baseDilations []int) *ReduceWindowConfig
- func (r *ReduceWindowConfig) WithPadding(paddings [][2]int) *ReduceWindowConfig
- func (r *ReduceWindowConfig) WithStrides(strides []int) *ReduceWindowConfig
- func (r *ReduceWindowConfig) WithWindowDilations(windowDilations []int) *ReduceWindowConfig
- type RngAlgorithm
- type Shape
- func DecodeBroadcastInDim(op *Op) (outputShape Shape, broadcastAxes []int)
- func DecodeIota(op *Op) (shape Shape, iotaAxis int)
- func DecodeParameter(paramOp *Op) (name string, paramIndex int, shape Shape)
- func MakeShape(dtype dtypes.DType, dimensions ...int) Shape
- func MakeShapeOrError(dtype dtypes.DType, dimensions ...int) (Shape, error)
- type XlaBuilder
- func (b *XlaBuilder) Build(outputOp *Op) (*XlaComputation, error)
- func (b *XlaBuilder) CreateSubBuilder(computationName string) *XlaBuilder
- func (b *XlaBuilder) Destroy()
- func (b *XlaBuilder) GetReduceComputationAndInitialValue(reduction ReduceOpType, dtype dtypes.DType) (comp *XlaComputation, initialValue *Op, err error)
- func (b *XlaBuilder) GetSelectAndScatterComputation(reduction ReduceOpType, dtype dtypes.DType) (selectComputation, scatterComputation *XlaComputation, err error)
- func (b *XlaBuilder) IsNil() bool
- func (b *XlaBuilder) Name() string
- type XlaComputation
- func (comp *XlaComputation) Destroy()
- func (comp *XlaComputation) HasStableHLO() bool
- func (comp *XlaComputation) IsNil() bool
- func (comp *XlaComputation) Name() string
- func (comp *XlaComputation) SerializedHLO() *cbuffer.CBuffer
- func (comp *XlaComputation) SerializedStableHLO() (*cbuffer.CBuffer, error)
- func (comp *XlaComputation) TextHLO() string
- func (comp *XlaComputation) TextStableHLO() (string, error)
Constants ¶
This section is empty.
Variables ¶
var RngStateShape = MakeShape(dtypes.U64, 3)
Functions ¶
func DecodeArgMinMax ¶
DecodeArgMinMax retrieves the arguments for an ArgMinMax op.
func DecodeBroadcast ¶
DecodeBroadcast retrieves the arguments for a Broadcast op.
func DecodeConcatenate ¶
DecodeConcatenate retrieves the arguments for a Concatenate op.
func DecodeConvGeneralDilated ¶
func DecodeConvGeneralDilated(op *Op) (operand, filter *Op, axes ConvolveAxesConfig, strides []int, paddings [][2]int, inputDilation, filterDilation []int, filterGroupCount, batchGroupCount int)
DecodeConvGeneralDilated retrieves the arguments for the ConvGeneralDilated op.
func DecodeConvertDType ¶
DecodeConvertDType retrieves the arguments for a ConvertDType op.
func DecodeDotGeneral ¶
func DecodeDotGeneral(op *Op) (lhs *Op, lhsContractingAxes, lhsBatchAxes []int, rhs *Op, rhsContractingAxes, rhsBatchAxes []int)
DecodeDotGeneral retrieves the arguments for a DotGeneral op.
func DecodeDynamicSlice ¶ added in v0.2.1
DecodeDynamicSlice retrieves the arguments for the DynamicSlice op.
func DecodeDynamicUpdateSlice ¶ added in v0.2.1
DecodeDynamicUpdateSlice retrieves the arguments for the DynamicUpdateSlice op.
func DecodeGather ¶
func DecodeGather(op *Op) (indexVectorAxis int, offsetAxes, collapsedSliceAxes, startIndexMap, sliceSizes []int, indicesAreSorted bool)
DecodeGather retrieves the arguments for a Gather op.
func DecodeGetTupleElement ¶
DecodeGetTupleElement retrieves the arguments of a GetTupleElement op.
func DecodeReduceWindow ¶
func DecodeReduceWindow(op *Op) (reduceType ReduceOpType, reduceComputation *XlaComputation, initialValue *Op, windowDimensions, strides, baseDilations, windowDilations []int, paddings [][2]int)
DecodeReduceWindow retrieves the arguments for a ReduceWindow op.
func DecodeReshape ¶
DecodeReshape retrieves the arguments for a Reshape op.
func DecodeRngBitGenerator ¶
DecodeRngBitGenerator retrieves the arguments for the RngBitGenerator op.
func DecodeScatter ¶
func DecodeScatter(op *Op) (indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool)
DecodeScatter retrieves the arguments for a Scatter (ScatterCustom or ScatterAdd) op.
func DecodeSelectAndScatter ¶
func DecodeSelectAndScatter(op *Op) (operand, source, defaultValue *Op, selectComputation, scatterComputation *XlaComputation, windowDimensions, windowStrides []int, paddings [][2]int)
DecodeSelectAndScatter retrieves the arguments for a SelectAndScatter (SelectAndScatterCustom or SelectAndScatterMax) op.
func DecodeSlice ¶
DecodeSlice retrieves the arguments for a Slice op.
func DecodeTranspose ¶
DecodeTranspose retrieves the arguments for a Transpose op.
func DecodeWhile ¶ added in v0.1.1
func DecodeWhile(op *Op) (initialState *Op, condition, body *XlaComputation)
DecodeWhile retrieves the arguments for the While op.
func HasStableHLO ¶ added in v0.4.3
func HasStableHLO() bool
HasStableHLO returns whether StableHLO support was included in the build -- it's very large, so by default it is not.
Types ¶
type ConvolveAxesConfig ¶
type ConvolveAxesConfig struct {
InputBatch, InputChannel int
InputSpatial []int
KernelInputChannel, KernelOutputChannel int
KernelSpatial []int
OutputBatch, OutputChannel int
OutputSpatial []int
}
ConvolveAxesConfig defines the interpretation of the input/kernel/output tensor axes. There must be the same number of spatial dimensions (axes) for each of the 3 tensors. Input and output have batch and channel axes. Kernel has inputChannel and outputChannel axes.
type Literal ¶
type Literal struct {
// contains filtered or unexported fields
}
Literal defines a constant value for the graph, and is treated as immutable.
Since it's only used to feed values to the graph, it doesn't provide any way of inspecting its current values.
func NewArrayLiteral ¶
NewArrayLiteral creates a Literal initialized from the array flat data (a slice) and the dimensions of the array.
func NewArrayLiteralFromAny ¶ added in v0.2.0
NewArrayLiteralFromAny creates a slice Literal with the given dynamically typed flat values and its underlying dimensions. It uses reflection to inspect the type.
func NewLiteralFromShape ¶
NewLiteralFromShape creates a zero-initialized literal with the given shape. It cannot be used to create Literal tuples.
func NewScalarLiteral ¶
NewScalarLiteral creates a scalar Literal initialized with the given value.
func NewScalarLiteralFromAny ¶
NewScalarLiteralFromAny creates a scalar Literal with the given dynamically typed value. It uses reflection to inspect the type.
func NewScalarLiteralFromFloat64 ¶
NewScalarLiteralFromFloat64 creates a scalar Literal with the given dtype initialized from the given value as float64. This can be used to create common constants for arbitrary dtypes.
It returns an error if dtype cannot be converted.
func (*Literal) Data ¶ added in v0.4.3
Data calls accessFn with data pointing to the bytes (C++ allocated) of the literal.
The ownership of the data is maintained by the Literal, and it is guaranteed to be live (not GC'ed) until the end of the current function call, at least.
func (*Literal) Destroy ¶
func (l *Literal) Destroy()
Destroy the Literal, release resources, and the Literal is no longer valid. This is automatically called if the Literal is garbage collected.
type Op ¶
type Op struct {
    // Type is an OpType enum.
    Type OpType

    // Shape of the result of this Op.
    Shape Shape

    // UserPayload allows the user to add any type of meta-data. XlaBuilder simply ignores it.
    // Typically, extensions like github.com/gomlx/autodiff will cast UserPayload to the interfaces that matter to
    // them.
    UserPayload any

    // ReduceType is informative only. For some ops (ReduceMax, ScatterAdd, etc.) it informs what kind of
    // standard computations were used (set in ComputationArg).
    ReduceType ReduceOpType

    // OpInputs are the inputs that are generated by other ops, these are the edges on the computation graph.
    // Other inputs are "static", meaning they are independent of the values during the calculation.
    OpInputs []*Op

    // Index to other nodes that are used as inputs.
    LiteralArg *Literal // If a LiteralArg (constant) is involved in the operation.
    IntArg     int      // Used for any static integer inputs.
    StrArg     string   // Used for any static string argument.
    IntsArg    []int    // List of integer numbers.
    FloatArg   float32  // For a float parameter.
    ShapeArg   Shape    // For Ops that require a shape parameter.

    ComputationArg, SecondComputationArg *XlaComputation // For Ops that require a sub-computation(s).
    // contains filtered or unexported fields
}
Op holds information about an Op that is part of a computation being built with an XlaBuilder.
Each operation (e.g.: Add, Mul) will return an Op that represents both the operation itself and the output of that operation, which can be used as the input to another.
While the public fields can be introspected, they shouldn't be changed, except for UserPayload.
func Abs ¶
Abs returns the Op that represents the output of the corresponding operation. The op is created on the same XlaBuilder as used for x.
func Add ¶
Add returns the element-wise sum of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.
func And ¶
And returns the element-wise logic "and" operator. The op is created on the same XlaBuilder as used for x0 and x1.
func ArgMinMax ¶
ArgMinMax calculates the "argmin" or "argmax" across an axis of the given input array x. outputDType defines the output of the argmin/argmax, it doesn't need to be the same as the input.
It's a form of reduction on the given axis, and that axis goes away. So the rank of the result is one less than the rank of x.
Examples:
ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=1, isMin=true) -> {1, 0} // (it chooses the 0 and the -3)
ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=0, isMin=false) -> {0, 1, 0} // (it chooses the 2, 4 and 7)
func BatchNormForInference ¶ added in v0.2.0
func BatchNormForInference(operand, scale, offset, mean, variance *Op, epsilon float32, axis int) (*Op, error)
BatchNormForInference implements Batch Norm for inference. See details in https://www.tensorflow.org/xla/operation_semantics#batchnorminference
Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
func BatchNormForTraining ¶ added in v0.2.0
func BatchNormForTraining(operand, scale, offset *Op, epsilon float32, axis int) (normalized, batchMean, batchVariance *Op, err error)
BatchNormForTraining implements Batch Norm for training. See details in https://www.tensorflow.org/xla/operation_semantics#batchnormtraining
It returns the normalized tensor, the batchMean and the batchVariance.
Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
func BatchNormGradient ¶ added in v0.2.0
func BatchNormGradient(operand, scale, mean, variance, gradOutput *Op, epsilon float32, axis int) (gradOperand, gradScale, gradOffset *Op, err error)
BatchNormGradient calculates the BatchNorm gradient. See details in https://openxla.org/xla/operation_semantics#batchnormgrad
The gradOutput is the adjoint gradient, that is, the gradient with respect to the output of the batch normalization.
It returns the three gradients: gradOperand, gradScale and gradOffset.
Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
func Broadcast ¶
Broadcast prefixes dimensions to an array by duplicating the data in the array. See BroadcastInDim for a broadcast in between the axes.
The new dimensions are inserted on the left, i.e., if prefixDims has values `{a0, ..., aN}` and the operand shape has dimensions {b0, ..., bM}, then the shape of the output has dimensions {a0, ..., aN, b0, ..., bM}.
The new dimensions index into copies of the operand, i.e.
output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
func BroadcastInDim ¶
BroadcastInDim broadcasts x to an output with the given shape. broadcastAxes has one output axis value for each axis of x (len(broadcastAxes) == x.Shape.Rank()). The i-th axis of x is mapped to the broadcastAxes[i]-th dimension of the output. broadcastAxes must also be increasing: this operation cannot be used to transpose axes, it will only broadcast and introduce new axes in-between.
This also requires that the i-th input axis is either 1 or is the same as the output dimension it's broadcasting into.
For example, say operand `x = (s32)[2]{1, 2}`; outputShape = `(s32)[2,2]`:
- Specifying []int{1} as broadcastAxes will generate output {{1, 2}, {1, 2}}
- On the other hand, specifying []int{0} as broadcastAxes will generate output {{1, 1}, {2, 2}}
func Call ¶
func Call(builder *XlaBuilder, subComputation *XlaComputation, operands ...*Op) (*Op, error)
Call will evaluate a subComputation with the given operands. The given subComputation must have been created with a sub-builder (see XlaBuilder.CreateSubBuilder) of the given builder.
func Ceil ¶
Ceil returns the Op that represents the output of the corresponding operation. The op is created on the same XlaBuilder as used for x.
func Clz ¶
Clz returns element-wise the "count leading zeros" bits of input node x -- for integer values. The op is created on the same XlaBuilder as used for x.
func Complex ¶
Complex returns the complex number taking x0 as the real part and x1 as the imaginary part. The real (x0) and imaginary (x1) must have the same dtype, and they must be either `dtypes.Float32` or `dtypes.Float64`. The output will be either `dtypes.Complex64` or `dtypes.Complex128`, depending on x0 and x1 dtypes. The shapes of `real` or `imaginary` must be the same, or one must be a scalar, in which case the value is broadcast to every other value. The op is created on the same XlaBuilder as used for x0 and x1.
func Concatenate ¶
Concatenate results on the given axis.
All axes that are not being concatenated must match dimensions. It doesn't work with scalars -- use ExpandDims.
If there is only one operand, it is returned and this is a no-op.
func Conj ¶
Conj returns the conjugate of a complex number. E.g.: Conj(1+3i) = 1-3i. The op is created on the same XlaBuilder as used for x.
func Constant ¶
func Constant(builder *XlaBuilder, x *Literal) (*Op, error)
Constant introduces an Op that holds the given Literal x as a constant value in the computation.
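A minimal sketch of creating Literal constants and using them in a computation. It assumes this package is imported as xlabuilder and dtypes as github.com/gomlx/gopjrt/dtypes; the builder b, the function name and the values are illustrative only.

// constantsExample builds 0.5 * [[1 2 3][4 5 6]] -- a hypothetical helper for illustration.
func constantsExample(b *xlabuilder.XlaBuilder) (*xlabuilder.Op, error) {
    // A scalar constant.
    half, err := xlabuilder.Constant(b, xlabuilder.NewScalarLiteral(float32(0.5)))
    if err != nil {
        return nil, err
    }
    // A [2, 3] array constant built from its flat data.
    lit, err := xlabuilder.NewArrayLiteral([]float32{1, 2, 3, 4, 5, 6}, 2, 3)
    if err != nil {
        return nil, err
    }
    arr, err := xlabuilder.Constant(b, lit)
    if err != nil {
        return nil, err
    }
    // Standard broadcasting applies: every element is multiplied by 0.5.
    return xlabuilder.Mul(arr, half)
}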
func ConvGeneralDilated ¶
func ConvGeneralDilated(operand, filter *Op, axes ConvolveAxesConfig, strides []int, paddings [][2]int, inputDilation, filterDilation []int, filterGroupCount, batchGroupCount int) (*Op, error)
ConvGeneralDilated is a generic Convolution operation offered by XLA. featureAxisAfter defines whether the features (aka. channels or depth) axis comes after the spatial dimension. Example: a 2D input can be one of the two:
- featureAxisAfter=false: input=[batch_size, features, height, width], filter=[output_features, input_features, height, width]
- featureAxisAfter=true: input=[batch_size, height, width, features], filter=[output_features, height, width, input_features]
Some details in https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution. There, operand and filter are called lhs and rhs. (XLA documentation is unfortunately poor, much is guess-work). Also useful: https://arxiv.org/pdf/1603.07285v1.pdf.
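The axes bookkeeping is easiest to see in code. Below is a hedged sketch of a plain 2D convolution on a [batch, height, width, channels] input with a [height, width, inputChannels, outputChannels] filter; the builder b, function name, shapes and dtype are assumptions made for illustration, not values from this package.

// conv2DExample: 3x3 convolution, stride 1, "same" padding -- a hypothetical helper.
func conv2DExample(b *xlabuilder.XlaBuilder) (*xlabuilder.Op, error) {
    input, err := xlabuilder.Parameter(b, "input", 0,
        xlabuilder.MakeShape(dtypes.Float32, 8, 32, 32, 3)) // [batch, height, width, channels]
    if err != nil {
        return nil, err
    }
    filter, err := xlabuilder.Parameter(b, "filter", 1,
        xlabuilder.MakeShape(dtypes.Float32, 3, 3, 3, 16)) // [height, width, inputChannels, outputChannels]
    if err != nil {
        return nil, err
    }
    axes := xlabuilder.ConvolveAxesConfig{
        InputBatch: 0, InputChannel: 3, InputSpatial: []int{1, 2},
        KernelInputChannel: 2, KernelOutputChannel: 3, KernelSpatial: []int{0, 1},
        OutputBatch: 0, OutputChannel: 3, OutputSpatial: []int{1, 2},
    }
    return xlabuilder.ConvGeneralDilated(input, filter, axes,
        []int{1, 1},              // strides, one per spatial axis.
        [][2]int{{1, 1}, {1, 1}}, // "same" padding for a 3x3 filter.
        nil, nil,                 // inputDilation, filterDilation: none (nil assumed to be accepted).
        1, 1)                     // filterGroupCount, batchGroupCount.
}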
func ConvertDType ¶
ConvertDType of x to dtype.
func Cos ¶
Cos returns the Op that represents the output of the corresponding operation. The op is created on the same XlaBuilder as used for x.
func DecodeBatchNormForInference ¶ added in v0.2.0
func DecodeBatchNormForInference(op *Op) (operand, scale, offset, mean, variance *Op, epsilon float32, axis int)
DecodeBatchNormForInference retrieves the arguments for the BatchNormForInference op.
func DecodeBatchNormForTraining ¶ added in v0.2.0
DecodeBatchNormForTraining retrieves the arguments for the BatchNormForTraining op.
func DecodeBatchNormGrad ¶
func DecodeBatchNormGrad(op *Op) (operand, scale, mean, variance, gradOutput *Op, epsilon float32, axis int)
DecodeBatchNormGrad retrieves the arguments for the BatchNormGradient op.
func DecodeReverse ¶
DecodeReverse retrieves the arguments for the Reverse op.
func Div ¶
Div returns the element-wise division of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.
func Dot ¶
Dot returns the "dot product" operation. The exact semantics of this operation depend on the ranks of the operands:
| Input                             | Output         | Semantics                    |
| vector [n] dot vector [n]         | scalar         | vector dot product           |
| matrix [m x k] dot vector [k]     | vector [m]     | matrix-vector multiplication |
| matrix [m x k] dot matrix [k x n] | matrix [m x n] | matrix-matrix multiplication |
The operation performs sum of products over the second dimension of x0 (or the first if it has rank 1) and the first dimension of x1. These are the "contracted" dimensions. The contracted dimensions of x0 and x1 must be of the same size. In practice, it can be used to perform dot products between vectors, vector/matrix multiplications or matrix/matrix multiplications. The op is created on the same XlaBuilder as used for x0 and x1.
func DotGeneral ¶
func DotGeneral(lhs *Op, lhsContractingAxes, lhsBatchAxes []int, rhs *Op, rhsContractingAxes, rhsBatchAxes []int) (*Op, error)
DotGeneral takes as input lhs (left-hand-side) and rhs (right-hand-side) specifications for a general vector product -- a generalized "Einsum". Each axis can be:
- Just aligned (batch axes), so the output has the same axes as the inputs. The dimensions must match in lhs and rhs.
- Crossed (default), in which case the output is the combination (concatenation) of the dimensions.
- Contracted (contracting axes), where the corresponding values are multiplied and then summed over (reduced), so these axes disappear from the output.
It follows that the resulting shape starts with the batch dimensions, then the lhs non-contracting/non-batch dimensions, and finally the rhs non-contracting/non-batch dimensions.
It provides the basic means of implementing Einsum.
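For instance, a batched matrix multiplication maps onto DotGeneral as in the hedged sketch below; the builder b, function name and shapes are assumptions for illustration.

// batchedMatMul multiplies [4, 2, 3] x [4, 3, 5] -> [4, 2, 5] -- a hypothetical helper.
func batchedMatMul(b *xlabuilder.XlaBuilder) (*xlabuilder.Op, error) {
    lhs, err := xlabuilder.Parameter(b, "lhs", 0, xlabuilder.MakeShape(dtypes.Float32, 4, 2, 3))
    if err != nil {
        return nil, err
    }
    rhs, err := xlabuilder.Parameter(b, "rhs", 1, xlabuilder.MakeShape(dtypes.Float32, 4, 3, 5))
    if err != nil {
        return nil, err
    }
    // Axis 0 is a batch axis on both sides; lhs axis 2 contracts with rhs axis 1.
    // The result shape is [4, 2, 5]: batch axes first, then the remaining lhs and rhs axes.
    return xlabuilder.DotGeneral(lhs, []int{2}, []int{0}, rhs, []int{1}, []int{0})
}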
func DynamicSlice ¶ added in v0.2.1
DynamicSlice extracts a sub-array from the input array at dynamic start_indices. The size of the slice in each axis is passed in sliceDims, which specify the slice intervals for each axis: [start, start + size). The shape of startIndices must be rank == 1, with dimension size equal to the rank of operand.
See description in https://openxla.org/xla/operation_semantics#dynamicslice
func DynamicUpdateSlice ¶ added in v0.2.1
DynamicUpdateSlice generates a result which is the value of the input array operand, with a slice update overwritten at startIndices. The shape of update determines the shape of the sub-array of the result which is updated. The shape of startIndices must be rank == 1, with dimension size equal to the rank of operand.
See description in https://openxla.org/xla/operation_semantics#dynamicupdateslice
func Equal ¶
Equal performs element-wise equality check, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.
func EqualTotalOrder ¶
EqualTotalOrder returns the element-wise operation.
Standard broadcasting rules apply (see documentation).
The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.
func Erf ¶ added in v0.4.0
Erf returns the "error function", defined as erf(x) = 2/sqrt(Pi) * \int_{0}^{x}{e^{-t^2}dt}. The op is created on the same XlaBuilder as used for x.
func Exp ¶
Exp returns the Op that represents the output of the corresponding operation. The op is created on the same XlaBuilder as used for x.
func Expm1 ¶
Expm1 returns the Op that represents the output of the corresponding operation. The op is created on the same XlaBuilder as used for x.
func FFT ¶
FFT calls the XLA FFT operation, which implements {Forward, Inverse} x {Complex, Real} versions.
See documentation in https://www.tensorflow.org/xla/operation_semantics. Underlying, CPU FFT is backed by Eigen's TensorFFT and GPU FFT uses cuFFT.
func Floor ¶
Floor returns the Op that represents the output of the corresponding operation. The op is created on the same XlaBuilder as used for x.
func Gather ¶
func Gather(operand, startIndices *Op, indexVectorAxis int, offsetAxes, collapsedSliceAxes, startIndexMap, sliceSizes []int, indicesAreSorted bool) (*Op, error)
Gather is a powerful but cumbersome Gather operation offered by XLA. Full details in https://www.tensorflow.org/xla/operation_semantics#gather. (Warning: it's poorly described, with many undefined terms.) A worked sketch follows the argument list below.
Arguments:
- startIndices: the indices we want to gather. There will be one axis which enumerates the indices into the operand array, typically the last one. All other axes are "batch dimensions" and they will have equivalent axes in the output.
- indexVectorAxis: typically the last axis of startIndices, so startIndices.Shape.Rank()-1. Usually, the dimension of the indexVectorAxis equals the full rank of the operand, that is: startIndices.Shape.Dimensions[indexVectorAxis] = operand.Shape.Rank(). Let's call "one index vector" a value of startIndices formed by a slice across indexVectorAxis.
- startIndexMap: for each "index vector" from startIndices, this maps each element of the vector to an axis of the operand. Typically, this is [0, 1, 2, ..., operand.Shape.Rank()-1], that is, each "index vector" fully defines an element of the operand. If one is gathering slices of the operand (as opposed to individual values), one can skip some of those axes from startIndexMap; the index for those axes is considered 0, and sliceSizes is set to take the slice one wants (typically the full slice).
- sliceSizes: the "index vector" described above points to the data in the operand to be gathered. Then sliceSizes indicates how much data to gather. One value per axis of the operand must be set. For gathering individual values, set these all to 1.
- collapsedSliceAxes: the slice gathered for each "index vector" (with sizes sliceSizes), often has dimension one for most (or all, in case of gathering individual items) axes. collapsedSliceAxes allows one to collapse those axes, so they don't show up in the output. Usually, collapse all axes that are size one. These are axes within the rank of operand (from 0 to operand.Shape.Rank()-1).
- offsetAxes: for those gathered slices not collapsed (with collapsedSliceAxes), this maps them to a position in the output array. Typically, these will be consecutive numbers starting with indexVectorAxis. So, the output will have the same prefix shape (the "batch dimensions") as the startIndices array, and the suffix shape will be the gathered slices mapped to these `offsetAxes`. There must be one value per axis not collapsed with collapsedSliceAxes -- the value itself is an axis in the output shape.
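A hedged sketch of the most common case, gathering whole rows of a matrix; the builder b, function name, shapes and dtypes are assumptions for illustration.

// gatherRows picks 4 rows (by index) out of a [5, 3] matrix -- a hypothetical helper.
func gatherRows(b *xlabuilder.XlaBuilder) (*xlabuilder.Op, error) {
    operand, err := xlabuilder.Parameter(b, "operand", 0, xlabuilder.MakeShape(dtypes.Float32, 5, 3))
    if err != nil {
        return nil, err
    }
    // 4 indices, each an "index vector" of length 1 selecting a row of the operand.
    indices, err := xlabuilder.Parameter(b, "indices", 1, xlabuilder.MakeShape(dtypes.Int32, 4, 1))
    if err != nil {
        return nil, err
    }
    // The result has shape [4, 3]: one gathered row per index.
    return xlabuilder.Gather(operand, indices,
        1,           // indexVectorAxis: the last axis of indices.
        []int{1},    // offsetAxes: the kept slice axis becomes output axis 1.
        []int{0},    // collapsedSliceAxes: the size-1 row axis of each slice is dropped.
        []int{0},    // startIndexMap: each index vector addresses operand axis 0.
        []int{1, 3}, // sliceSizes: take 1 row and all 3 columns.
        false)       // indicesAreSorted.
}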
func GetTupleElement ¶
GetTupleElement extracts one element from a Tuple.
func GreaterOrEqual ¶
GreaterOrEqual performs element-wise comparison, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.
func GreaterOrEqualTotalOrder ¶
GreaterOrEqualTotalOrder returns the element-wise operation.
Standard broadcasting rules apply (see documentation).
The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.
func GreaterThan ¶
GreaterThan performs element-wise comparison, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.
func GreaterThanTotalOrder ¶
GreaterThanTotalOrder returns the element-wise operation.
Standard broadcasting rules apply (see documentation).
The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.
func Identity ¶
Identity returns an Op whose output is the same as its input.
It's a no-op that is not registered with the C++ XlaBuilder; it simply serves as a placeholder for some arbitrary meta-data the user may want to include in the UserPayload field.
func Imag ¶
Imag returns the imaginary part of a complex number. It returns 0 if the x is a float number. The op is created on the same XlaBuilder as used for x.
func Iota ¶
func Iota(builder *XlaBuilder, shape Shape, iotaAxis int) (*Op, error)
Iota creates a constant of the given shape with increasing numbers (starting from 0) on the given axis. So Iota([2,2], 1) returns [[0 1][0 1]], while Iota([2,2], 0) returns [[0 0][1 1]].
func IsFinite ¶ added in v0.4.2
IsFinite tests whether each element of operand is finite, i.e., is not positive or negative infinity, and is not NaN. It returns an array of boolean values with the same shape as the input, where each element is true if and only if the corresponding input element is finite. The op is created on the same XlaBuilder as used for x.
func LessOrEqual ¶
LessOrEqual performs element-wise comparison, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.
func LessOrEqualTotalOrder ¶
LessOrEqualTotalOrder returns the element-wise operation.
Standard broadcasting rules apply (see documentation).
The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.
func LessThan ¶
LessThan performs element-wise comparison, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.
func LessThanTotalOrder ¶
LessThanTotalOrder returns the element-wise operation.
Standard broadcasting rules apply (see documentation).
The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.
func Log ¶
Log returns the Op that represents the output of the corresponding operation. The op is created on the same XlaBuilder as used for x.
func Log1p ¶
Log1p returns the expression log(x+1). The op is created on the same XlaBuilder as used for x.
func LogicalNot ¶
LogicalNot returns the Op that represents the output of the corresponding operation. The op is created on the same XlaBuilder as used for x.
func Logistic ¶
Logistic returns the element-wise expression 1/(1+exp(-x)). Also known as the Sigmoid function. The op is created on the same XlaBuilder as used for x.
func Max ¶
Max returns the element-wise highest value among the two. The op is created on the same XlaBuilder as used for x0 and x1.
func Min ¶
Min returns the element-wise smallest value among the two. The op is created on the same XlaBuilder as used for x0 and x1.
func Mul ¶
Mul returns the element-wise multiplication of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.
func Neg ¶
Neg returns the Op that represents the output of the corresponding operation. The op is created on the same XlaBuilder as used for x.
func NotEqual ¶
NotEqual performs element-wise inequality check, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.
func NotEqualTotalOrder ¶
NotEqualTotalOrder returns the element-wise operation.
Standard broadcasting rules apply (see documentation).
The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.
func Or ¶
Or returns the element-wise logic "or" operator. The op is created on the same XlaBuilder as used for x0 and x1.
func Pad ¶
Pad injects padding on the start, end or interior (in between each element) of the given operand. There must be at most `operand.Rank()` axesConfig values. Missing PadAxis are assumed to be zeros, that is, no padding for those axes.
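A hedged sketch of padding a vector with zeros; the builder b, function name and the padding amounts are assumptions for illustration.

// padExample adds one zero before and two zeros after a 1D float32 operand -- a hypothetical helper.
func padExample(b *xlabuilder.XlaBuilder, x *xlabuilder.Op) (*xlabuilder.Op, error) {
    zero, err := xlabuilder.ScalarZero(b, dtypes.Float32)
    if err != nil {
        return nil, err
    }
    // One zero at the start, two at the end, nothing in between the elements.
    return xlabuilder.Pad(x, zero, xlabuilder.PadAxis{Start: 1, End: 2, Interior: 0})
}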
func Parameter ¶
Parameter creates a "retrieves a parameter value" op in builder.
The name is cosmetic, but should be unique among the parameters.
The paramIndex must be carefully set to match the parameters fed to the computation during execution and after it is compiled (see package pjrt for that).
The shape of the parameter must be given -- and match the value given during execution.
func PopulationCount ¶ added in v0.4.2
PopulationCount computes the number of bits set in each element of operand. The op is created on the same XlaBuilder as used for x.
func Pow ¶
Pow returns the Op that represents the output of the corresponding operation. The op is created on the same XlaBuilder as used for x0 and x1.
func Real ¶
Real returns the real part of a complex number. It returns x if x is a float number. The op is created on the same XlaBuilder as used for x.
func Reduce ¶
Reduce the selected axes of the input array, using the given custom reduceComputation. The initialValue should be a scalar value that the reduction starts with.
If no axes are given, it reduces the full array.
Consider instead using one of the standard ReduceSum, ReduceProduct, ReduceMax and ReduceMin. They use cached values for both the corresponding reduceComputation and initialValue, per dtype of x.
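A hedged sketch of the mechanics of a custom reduction, building the reduction computation with a sub-builder. ReduceMax already covers this particular case; the sketch only shows the plumbing. The builder b, operand x and function name are assumptions for illustration, and the standard math package is assumed to be imported.

// customReduceMax reduces axis 0 of x using a hand-built max computation -- a hypothetical helper.
func customReduceMax(b *xlabuilder.XlaBuilder, x *xlabuilder.Op) (*xlabuilder.Op, error) {
    // The reduction computation: a function of two scalars returning a scalar.
    subB := b.CreateSubBuilder("max_reduction")
    scalar := xlabuilder.MakeShape(dtypes.Float32)
    lhs, err := xlabuilder.Parameter(subB, "lhs", 0, scalar)
    if err != nil {
        return nil, err
    }
    rhs, err := xlabuilder.Parameter(subB, "rhs", 1, scalar)
    if err != nil {
        return nil, err
    }
    maxOp, err := xlabuilder.Max(lhs, rhs)
    if err != nil {
        return nil, err
    }
    maxComp, err := subB.Build(maxOp)
    if err != nil {
        return nil, err
    }
    // Initial value for a max reduction: negative infinity.
    initial, err := xlabuilder.Constant(b, xlabuilder.NewScalarLiteral(float32(math.Inf(-1))))
    if err != nil {
        return nil, err
    }
    return xlabuilder.Reduce(x, maxComp, initial, 0) // Reduce over axis 0.
}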
func ReduceAnd ¶ added in v0.3.2
ReduceAnd is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the logical-and of the reduced axes. It only works for booleans.
If no axes are given, it reduces the full array.
func ReduceMax ¶
ReduceMax is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the max value.
If no axes are given, it reduces the full array.
func ReduceMin ¶
ReduceMin is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the min value.
If no axes are given, it reduces the full array.
func ReduceOr ¶ added in v0.3.2
ReduceOr is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the logical-or of the reduced axes. It only works for booleans.
If no axes are given, it reduces the full array.
func ReduceProduct ¶
ReduceProduct is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the product of the reduced axes.
If no axes are given, it reduces the full array.
func ReduceSum ¶
ReduceSum is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the sum of the reduced axes.
If no axes are given, it reduces the full array.
func Rem ¶
Rem returns the remainder operation, also known as modulo (or Mod for short). Notice that despite the name, XLA implements Mod, not the IEEE 754 Remainder operation. The op is created on the same XlaBuilder as used for x0 and x1.
func Reshape ¶
Reshape reshapes x to the new dimensions. Total size cannot change, it's just a "reinterpretation" of the same flat data.
The dtype remains the same, see ConvertDType to actually convert the values.
func Reverse ¶
Reverse returns x with the values for the given dimensions reversed, that is, the value at index `i` will be swapped with the value at index `(dimension_size - 1 - i)`. The shape remains the same.
func RngBitGenerator ¶
RngBitGenerator generates the given shape filled with random bits. It takes as input the current random number generator (RNG) state, see RngState or RngStateFromSeed. The algorithm is hard-coded to use the Philox algorithm for now.
The state should be `[3]uint64` for Philox, see https://openxla.org/xla/operation_semantics#rngbitgenerator.
It returns the new state of the RNG and the generated values (with random bits) with the given shape.
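A hedged sketch of seeding and calling the generator; the builder b, function name, seed values, output shape and dtype names are assumptions for illustration.

// randomBits fills a [2, 3] uint32 array with random bits -- a hypothetical helper.
func randomBits(b *xlabuilder.XlaBuilder) (newState, values *xlabuilder.Op, err error) {
    // Philox state: 3 uint64 values, matching RngStateShape; the seed here is arbitrary.
    stateLit, err := xlabuilder.NewArrayLiteral([]uint64{42, 0, 0}, 3)
    if err != nil {
        return nil, nil, err
    }
    state, err := xlabuilder.Constant(b, stateLit)
    if err != nil {
        return nil, nil, err
    }
    return xlabuilder.RngBitGenerator(state, xlabuilder.MakeShape(dtypes.Uint32, 2, 3))
}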
func Round ¶
Round returns the Op that represents the output of the corresponding operation. The op is created on the same XlaBuilder as used for x.
func Rsqrt ¶
Rsqrt returns the element-wise reciprocal of square root operation 1/sqrt(x). The op is created on the same XlaBuilder as used for x.
func ScalarOne ¶
func ScalarOne(builder *XlaBuilder, dtype dtypes.DType) (*Op, error)
ScalarOne returns a one (1) constant for the given dtype. It caches the constant, so it doesn't get defined multiple times.
func ScalarZero ¶
func ScalarZero(builder *XlaBuilder, dtype dtypes.DType) (*Op, error)
ScalarZero returns a zero constant for the given dtype. It caches the constant, so it doesn't get defined multiple times.
func ScatterAdd ¶
func ScatterAdd(operand, scatterIndices, updates *Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (*Op, error)
ScatterAdd adds values from updates, pointed to by scatterIndices, into operand. Details in ScatterCustom, which is used with the updateComputation set to Sum.
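A hedged sketch mirroring the Gather example above: each row of updates is added to the row of operand selected by the corresponding index. The builder b, function name, shapes and dtypes are assumptions for illustration.

// scatterAddRows adds 4 update rows into the indexed rows of a [5, 3] operand -- a hypothetical helper.
func scatterAddRows(b *xlabuilder.XlaBuilder) (*xlabuilder.Op, error) {
    operand, err := xlabuilder.Parameter(b, "operand", 0, xlabuilder.MakeShape(dtypes.Float32, 5, 3))
    if err != nil {
        return nil, err
    }
    indices, err := xlabuilder.Parameter(b, "indices", 1, xlabuilder.MakeShape(dtypes.Int32, 4, 1))
    if err != nil {
        return nil, err
    }
    updates, err := xlabuilder.Parameter(b, "updates", 2, xlabuilder.MakeShape(dtypes.Float32, 4, 3))
    if err != nil {
        return nil, err
    }
    return xlabuilder.ScatterAdd(operand, indices, updates,
        1,        // indexVectorAxis: the last axis of indices.
        []int{1}, // updateWindowAxes: axis 1 of updates holds the row being scattered.
        []int{0}, // insertedWindowAxes: the operand row axis addressed by the index.
        []int{0}, // scatterAxesToOperandAxes: each index vector addresses operand axis 0.
        false,    // indicesAreSorted.
        false)    // uniqueIndices.
}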
func ScatterCustom ¶
func ScatterCustom(operand, scatterIndices, updates *Op, updateComputation *XlaComputation, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (*Op, error)
ScatterCustom is a powerful but cumbersome Scatter operation offered by XLA. Full details in https://www.tensorflow.org/xla/operation_semantics#scatter.
It takes a custom updateComputation used when scattering values. See ScatterAdd for a version that adds the values when scattering.
func ScatterMax ¶
func ScatterMax(operand, scatterIndices, updates *Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (*Op, error)
ScatterMax scatters values from updates, pointed to by scatterIndices, into operand, by taking the Max. Details in ScatterCustom, which is used with the updateComputation set to Max.
func ScatterMin ¶
func ScatterMin(operand, scatterIndices, updates *Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (*Op, error)
ScatterMin scatters values from updates, pointed to by scatterIndices, into operand, by taking the Min. Details in ScatterCustom, which is used with the updateComputation set to Min.
func SelectAndScatterCustom ¶
func SelectAndScatterCustom(operand, source, defaultValue *Op, selectComputation, scatterComputation *XlaComputation, windowDimensions, windowStrides []int, paddings [][2]int) (*Op, error)
SelectAndScatterCustom runs windows (similar to ReduceWindow) over the operand, selects values (selectComputation) to update the output (like Scatter) using the scatterComputation with values from source. The output is initialized with defaultValue. See details in https://openxla.org/xla/operation_semantics#selectandscatter
func SelectAndScatterMax ¶
func SelectAndScatterMax(operand, source *Op, windowDimensions, windowStrides []int, paddings [][2]int) (*Op, error)
SelectAndScatterMax calls SelectAndScatterCustom with a zero defaultValue, sum for updateComputation and an appropriate selectComputation to implement a SelectAndScatter that updates the max value in the windows. Details in SelectAndScatterCustom.
func SelectAndScatterMin ¶
func SelectAndScatterMin(operand, source *Op, windowDimensions, windowStrides []int, paddings [][2]int) (*Op, error)
SelectAndScatterMin calls SelectAndScatterCustom with a zero defaultValue, sum for updateComputation and an appropriate selectComputation to implement a SelectAndScatter that updates the min value in the windows. Details in SelectAndScatterCustom.
func SelectAndScatterSum ¶
func SelectAndScatterSum(operand, source *Op, windowDimensions, windowStrides []int, paddings [][2]int) (*Op, error)
SelectAndScatterSum calls SelectAndScatterCustom with a zero defaultValue, sum for updateComputation and a selectComputation that always selects, to implement a SelectAndScatter that sums the values in the windows. Details in SelectAndScatterCustom.
func ShiftLeft ¶ added in v0.5.1
ShiftLeft shifts left by n bits, preserving the sign. So ShiftLeft(-1, 1) = -2. The op is created on the same XlaBuilder as used for x0 and x1.
func ShiftRightArithmetic ¶ added in v0.5.1
ShiftRightArithmetic shifts right by n bits, preserving the sign bit. So ShiftRight(-2, 1) = -1. The op is created on the same XlaBuilder as used for x0 and x1.
func ShiftRightLogical ¶ added in v0.5.1
ShiftRightLogical shifts right by n bits, destroying the sign bit. The op is created on the same XlaBuilder as used for x0 and x1.
func Sign ¶
Sign returns element-wise +1, +/-0 or -1 depending on the sign of x. It returns NaN if the input is NaN. The op is created on the same XlaBuilder as used for x.
func Sin ¶
Sin returns the Op that represents the output of the corresponding operation. The op is created on the same XlaBuilder as used for x.
func Slice ¶
Slice extracts a sub-array from the input array. The sub-array is of the same rank as the input and contains the values inside a bounding box within the input array where the dimensions and indices of the bounding box are given as arguments to the slice operation.
The strides set the input stride of the slice in each axis and must be >= 1. It is optional, and if missing it is assumed to be 1 for every dimension.
Examples:
Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={4}, strides=nil) -> {2, 3}
Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={5}, strides={2}) -> {2, 4}
func SplitTuple ¶
SplitTuple is a convenience wrapper around GetTupleElement; it returns a slice with all the elements of the tuple.
func Sqrt ¶
Sqrt returns the Op that represents the output of the corresponding operation. The op is created on the same XlaBuilder as used for x.
func Sub ¶
Sub returns the element-wise subtraction of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.
func Tanh ¶
Tanh returns the Op that represents the output of the corresponding operation. The op is created on the same XlaBuilder as used for x.
func Transpose ¶
Transpose axes of x. There should be one value in permutations for each axis in x. The output will have: output.Shape.Dimensions[i] = x.Shape.Dimensions[permutations[i]].
func Tuple ¶
Tuple organizes multiple nodes in one tuple-node.
This is particularly useful to get multiple outputs from a computation.
func Where ¶
Where takes element-wise values from onTrue or onFalse depending on the value of condition (expected to be boolean).
func While ¶ added in v0.1.1
func While(initialState *Op, condition, body *XlaComputation) (*Op, error)
While executes a loop in the computation.
It takes as input:
- initialState: usually a tuple, that includes all variables used by condition and body.
- condition: a sub-computation (see XlaBuilder.CreateSubBuilder) that takes the current state as input and outputs a bool (dtypes.PRED) indicating whether the loop should keep iterating.
- body: a sub-computation (see XlaBuilder.CreateSubBuilder) that takes the current state as input and outputs an updated state.
See details in https://openxla.org/xla/operation_semantics#while
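A minimal hedged sketch of the wiring: a loop that counts from 0 to 10. The state is usually a tuple, but a plain scalar is used here to keep the sketch short; the builder b and the function name are assumptions for illustration.

// whileCounter loops `counter = counter + 1` while `counter < 10` -- a hypothetical helper.
func whileCounter(b *xlabuilder.XlaBuilder) (*xlabuilder.Op, error) {
    counterShape := xlabuilder.MakeShape(dtypes.Int32)

    // Condition: keep looping while counter < 10.
    condB := b.CreateSubBuilder("while_condition")
    condState, err := xlabuilder.Parameter(condB, "state", 0, counterShape)
    if err != nil {
        return nil, err
    }
    ten, err := xlabuilder.Constant(condB, xlabuilder.NewScalarLiteral(int32(10)))
    if err != nil {
        return nil, err
    }
    keepGoing, err := xlabuilder.LessThan(condState, ten)
    if err != nil {
        return nil, err
    }
    condition, err := condB.Build(keepGoing)
    if err != nil {
        return nil, err
    }

    // Body: increment the counter.
    bodyB := b.CreateSubBuilder("while_body")
    bodyState, err := xlabuilder.Parameter(bodyB, "state", 0, counterShape)
    if err != nil {
        return nil, err
    }
    one, err := xlabuilder.ScalarOne(bodyB, dtypes.Int32)
    if err != nil {
        return nil, err
    }
    next, err := xlabuilder.Add(bodyState, one)
    if err != nil {
        return nil, err
    }
    body, err := bodyB.Build(next)
    if err != nil {
        return nil, err
    }

    // Initial state: the counter starts at 0.
    initial, err := xlabuilder.ScalarZero(b, dtypes.Int32)
    if err != nil {
        return nil, err
    }
    return xlabuilder.While(initial, condition, body)
}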
func Xor ¶
Xor returns the element-wise logic "xor" operator. The op is created on the same XlaBuilder as used for x0 and x1.
func (*Op) Builder ¶ added in v0.2.0
func (op *Op) Builder() *XlaBuilder
Builder returns the XlaBuilder associated with this Op.
type OpType ¶
type OpType int32
OpType enumerates the various operation types supported by XLA.
const (
    InvalidOp OpType = iota
    ParameterOp
    IotaOp
    ConstantOp
    IdentityOp
    ConvertDTypeOp
    WhereOp
    TupleOp
    GetTupleElementOp
    ReshapeOp
    BroadcastOp
    BroadcastInDimOp
    TransposeOp
    CallOp
    ReduceOp
    ReduceWindowOp
    ConcatenateOp
    SliceOp
    ArgMinMaxOp
    PadOp
    GatherOp
    ScatterOp
    SelectAndScatterOp
    ConvGeneralDilatedOp
    ReverseOp
    DotGeneralOp
    FftOp
    BatchNormTrainingOp
    BatchNormInferenceOp
    BatchNormGradOp
    RngBitGeneratorOp
    WhileOp
    AbsOp
    NegOp
    ExpOp
    Expm1Op
    FloorOp
    CeilOp
    RoundOp
    LogOp
    Log1pOp
    LogicalNotOp
    LogisticOp
    SignOp
    ClzOp
    CosOp
    SinOp
    TanhOp
    SqrtOp
    RsqrtOp
    ImagOp
    RealOp
    ConjOp
    AddOp
    MulOp
    SubOp
    DivOp
    RemOp
    AndOp
    OrOp
    XorOp
    DotOp
    MinOp
    MaxOp
    PowOp
    ComplexOp
    EqualOp
    NotEqualOp
    GreaterOrEqualOp
    GreaterThanOp
    LessOrEqualOp
    LessThanOp
    EqualTotalOrderOp
    NotEqualTotalOrderOp
    GreaterOrEqualTotalOrderOp
    GreaterThanTotalOrderOp
    LessOrEqualTotalOrderOp
    LessThanTotalOrderOp
    DynamicSliceOp
    DynamicUpdateSliceOp
    ErfOp
    IsFiniteOp
    PopulationCountOp
    ShiftLeftOp
    ShiftRightArithmeticOp
    ShiftRightLogicalOp
)
type PadAxis ¶
type PadAxis struct {
Start, End, Interior int
}
PadAxis defines the amount of padding preceding one axis (Start), at the end of axis (End) or in between the inputs (Interior). This is used as a parameter for the Pad operation.
type ReduceOpType ¶
type ReduceOpType int
ReduceOpType select among the basic types of reduction supported, see XlaBuilder.ReduceComputation.
const (
    // UndefinedReduceType is an undefined value.
    UndefinedReduceType ReduceOpType = iota
    // ReduceSumType reduces by summing all elements being reduced.
    ReduceSumType
    // ReduceProductType reduces by multiplying all elements being reduced.
    ReduceProductType
    // ReduceMaxType reduces by taking the maximum value.
    ReduceMaxType
    // ReduceMinType reduces by taking the minimum value.
    ReduceMinType
    // ReduceAndType reduces by taking the logical-and value.
    ReduceAndType
    // ReduceOrType reduces by taking the logical-or value.
    ReduceOrType
)
func (ReduceOpType) String ¶
func (i ReduceOpType) String() string
type ReduceWindowConfig ¶
type ReduceWindowConfig struct {
// contains filtered or unexported fields
}
func ReduceWindow ¶
func ReduceWindow(x *Op, windowDimensions []int) *ReduceWindowConfig
ReduceWindow applies a reduction function to all elements in each window of x, producing an N-dimensional array as output. The output array has the same number of elements as the number of valid positions of the window.
A pooling layer (typical in image processing) can be expressed as a ReduceWindow.
x is the array to reduce, it cannot be a scalar. And windowDimensions is the size of the windows on which to reduce: they must be set for each axis of x.
There are other options, so this uses the "builder pattern": it returns a ReduceWindowConfig object that can be further configured. When finished call the ReduceWindowConfig.Done to trigger its execution.
More details and examples can be found on the OpenXLA site: https://openxla.org/xla/operation_semantics#reducewindow
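A hedged sketch of a common use, 2x2 max pooling with stride 2 over a batch of images; the builder b, function name, shapes and dtype are assumptions for illustration.

// maxPool2x2 pools over the two spatial axes of a [batch, height, width, channels] input -- a hypothetical helper.
func maxPool2x2(b *xlabuilder.XlaBuilder) (*xlabuilder.Op, error) {
    images, err := xlabuilder.Parameter(b, "images", 0,
        xlabuilder.MakeShape(dtypes.Float32, 8, 32, 32, 3)) // [batch, height, width, channels]
    if err != nil {
        return nil, err
    }
    // Window dimensions and strides must be given for every axis; batch and channels are left untouched.
    return xlabuilder.ReduceWindow(images, []int{1, 2, 2, 1}).
        Max().
        WithStrides([]int{1, 2, 2, 1}).
        Done()
}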
func (*ReduceWindowConfig) Done ¶
func (r *ReduceWindowConfig) Done() (*Op, error)
Done executes the ReduceWindow and returns the corresponding Op, or an error.
func (*ReduceWindowConfig) Max ¶
func (r *ReduceWindowConfig) Max() *ReduceWindowConfig
Max configures the reduction type.
There is no default for the type of reduction: one has to configure either Max, Min, Sum, Product or some arbitrary computation with UseComputation.
func (*ReduceWindowConfig) Min ¶
func (r *ReduceWindowConfig) Min() *ReduceWindowConfig
Min configures the reduction type.
There is no default for the type of reduction: one has to configure either Max, Min, Sum, Product or some arbitrary computation with UseComputation.
func (*ReduceWindowConfig) Product ¶
func (r *ReduceWindowConfig) Product() *ReduceWindowConfig
Product configures the reduction type.
There is no default for the type of reduction: one has to configure either Max, Min, Sum, Product or some arbitrary computation with UseComputation.
func (*ReduceWindowConfig) Sum ¶
func (r *ReduceWindowConfig) Sum() *ReduceWindowConfig
Sum configures the reduction type.
There is no default for the type of reduction: one has to configure either Max, Min, Sum, Product or some arbitrary computation with UseComputation.
func (*ReduceWindowConfig) UseComputation ¶
func (r *ReduceWindowConfig) UseComputation(reduceComputation *XlaComputation, initialValue *Op) *ReduceWindowConfig
UseComputation configures a custom reduction function and initial value.
reduceComputation must take two scalars of x dtype as input, and return a scalar as the output. The initialValue must be a scalar of the same dtype (typically a Constant, but it can be the result of another operation).
There is no default for the type of reduction: one has to configure either Max, Min, Sum, Product or some arbitrary computation with UseComputation.
func (*ReduceWindowConfig) WithBaseDilations ¶
func (r *ReduceWindowConfig) WithBaseDilations(baseDilations []int) *ReduceWindowConfig
WithBaseDilations provides the base dilation for each axis of x. Either nil or one value per axis of x must be given.
The default is 0 for every axis.
func (*ReduceWindowConfig) WithPadding ¶
func (r *ReduceWindowConfig) WithPadding(paddings [][2]int) *ReduceWindowConfig
WithPadding provides the amount of padding on the start and end of each axis of x. Either nil or one value per axis of x must be given.
The default is (0, 0) for every axis.
func (*ReduceWindowConfig) WithStrides ¶
func (r *ReduceWindowConfig) WithStrides(strides []int) *ReduceWindowConfig
WithStrides provides the stride size for each axis of x. Either nil or one value per axis of x must be given.
The default is the same value as windowDimensions.
func (*ReduceWindowConfig) WithWindowDilations ¶
func (r *ReduceWindowConfig) WithWindowDilations(windowDilations []int) *ReduceWindowConfig
WithWindowDilations provides the window dilation for each axis of x. Either nil or one value per axis of x must be given.
The default is 1 for every axis.
type RngAlgorithm ¶ added in v0.4.6
type RngAlgorithm int
RngAlgorithm is an enum of the types of algorithms supported by XLA. We use Philox for now.
const (
    RngAlgorithmDefault  RngAlgorithm = 0
    RngAlgorithmThreeFry RngAlgorithm = 1
    RngAlgorithmPhilox   RngAlgorithm = 2
)
type Shape ¶
type Shape struct {
    DType       dtypes.DType
    Dimensions  []int
    TupleShapes []Shape // Shapes of the tuple, if this is a tuple.
}
Shape is a minimalistic shape representation of a tensor. It is used to describe the output of an Op, or as an input for operations that change the Shape of another Op, or part of a Literal value.
It is defined as a DType (the underlying data type, e.g.: Float32, Int64, etc.) and the dimensions on each axis of the tensor. If len(Dimensions) is 0, it represents a scalar.
Alternatively, in XLA, a value can represent a "tuple" of sub-values. In this case Shape.TupleShapes is defined with the shapes of its sub-values -- it is a recursive structure. In this case DType is set to InvalidDType, and the shape doesn't have a value of itself.
func DecodeBroadcastInDim ¶
DecodeBroadcastInDim retrieves the arguments for a BroadcastInDim op.
func DecodeIota ¶
DecodeIota retrieves the arguments of an Iota op.
func DecodeParameter ¶
DecodeParameter extracts the arguments to the Parameter call that created the op.
func MakeShape ¶
MakeShape creates a Shape filled with the values given.
The dimensions must be >= 1, and it doesn't work for tuple shapes.
func MakeShapeOrError ¶ added in v0.3.0
MakeShapeOrError is the same as MakeShape, but it returns an error instead if the dimensions are <= 0.
func (Shape) Clone ¶ added in v0.1.1
Clone makes a deep copy (including dimensions and tuples) of the given shape.
func (Shape) IsScalar ¶
IsScalar returns whether the Shape is a scalar, i.e. its len(Shape.Dimensions) == 0.
func (Shape) Memory ¶ added in v0.2.0
Memory returns the memory used to store an array of the given shape, the same as the size in bytes. Careful: so far all types in Go and on device seem to use the same sizes, but for future types this is not guaranteed.
func (Shape) Rank ¶
Rank of a shape is the number of axes. A shortcut to len(Shape.Dimensions). Scalar values have rank 0.
func (Shape) Size ¶
Size returns the total size of the shape. E.g.: a Shape of dimensions [3, 5] has size 15. A scalar has size 1.
type XlaBuilder ¶
type XlaBuilder struct {
// contains filtered or unexported fields
}
XlaBuilder is used to create "computations" (XlaComputation), that are like "StableHLO" functions.
In turn XlaComputation can be exported to a serialized `HloModuleProto` (a binary blob) and used by a PJRT plugin (see github.com/gomlx/gopjrt/pjrt package) to compile and execute on accelerators.
Once created (New), one can issue "operations" ("ops" for short), like "Add", "Mul", etc, which are recorded. When the computation definition is finalized, call "XlaBuilder.Build" to get the XlaComputation representing the function built. The XlaComputation can then be used with PJRT (see XlaComputation.SerializedHLO), or pretty (+/-, relatively speaking) print (text, HTML, graphviz, etc). See XlaComputation documentation.
Once done (usually, just after StableHLO is called) deallocate the underlying C++ resources by calling Destroy.
Some observations:
- The XlaBuilder is used by all ops creating functions (like "Add", "Mul", etc.). But since the input of most ops, are other created ops, and they hold a link to the XlaBuilder, there is no need to explicitly pass the XlaBuilder to every op function.
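A minimal end-to-end sketch of the flow described above: build a computation that adds two vector parameters and serialize it for a PJRT client. The import paths and the program itself are assumptions based on this package's layout, not an official example.

package main

import (
    "fmt"

    "github.com/gomlx/gopjrt/dtypes"
    "github.com/gomlx/gopjrt/xlabuilder"
)

func main() {
    b := xlabuilder.New("add_example")
    defer b.Destroy()

    shape := xlabuilder.MakeShape(dtypes.Float32, 3)
    x, err := xlabuilder.Parameter(b, "x", 0, shape)
    if err != nil {
        panic(err)
    }
    y, err := xlabuilder.Parameter(b, "y", 1, shape)
    if err != nil {
        panic(err)
    }
    sum, err := xlabuilder.Add(x, y)
    if err != nil {
        panic(err)
    }

    comp, err := b.Build(sum)
    if err != nil {
        panic(err)
    }
    defer comp.Destroy()

    // The serialized HloModuleProto can be handed to a pjrt.Client for compilation and execution.
    hlo := comp.SerializedHLO()
    defer hlo.Destroy()
    fmt.Println("Built computation:", comp.Name())
}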
func New ¶
func New(name string) *XlaBuilder
New creates a new XlaBuilder with the given name, that can be used to create a new StableHLO program. See details on how to use it on XlaBuilder.
func (*XlaBuilder) Build ¶
func (b *XlaBuilder) Build(outputOp *Op) (*XlaComputation, error)
Build builds the computation (*XlaComputation) with the requested operations (the outputOp and all its dependencies) or returns a non-ok status.
Note that all ops that have been enqueued will be moved to the computation being returned and will no longer be valid.
func (*XlaBuilder) CreateSubBuilder ¶
func (b *XlaBuilder) CreateSubBuilder(computationName string) *XlaBuilder
CreateSubBuilder returns a new XlaBuilder whose resultant Computation is used only by this XlaBuilder.
Some operations, like Call and Reduce, take as input a sub-computation (the reduction function), that can be created with a sub-builder.
It takes as input the computationName that is going to be built with it.
func (*XlaBuilder) Destroy ¶
func (b *XlaBuilder) Destroy()
Destroy and free the underlying C++ object. It can be called more than once -- once finalized the first time, it becomes a no-op.
It is called at garbage-collection automatically.
func (*XlaBuilder) GetReduceComputationAndInitialValue ¶
func (b *XlaBuilder) GetReduceComputationAndInitialValue(reduction ReduceOpType, dtype dtypes.DType) (comp *XlaComputation, initialValue *Op, err error)
GetReduceComputationAndInitialValue builds or returns a cached computation that implements a reduction function with one of the standard ReduceOpType: sum, multiply, max or min.
func (*XlaBuilder) GetSelectAndScatterComputation ¶
func (b *XlaBuilder) GetSelectAndScatterComputation(reduction ReduceOpType, dtype dtypes.DType) (selectComputation, scatterComputation *XlaComputation, err error)
GetSelectAndScatterComputation builds or returns cached computations that implement the select and scatter functions for one of the standard ReduceOpType: sum, multiply, max or min. This is used for the SelectAndScatter family of operations.
func (*XlaBuilder) IsNil ¶
func (b *XlaBuilder) IsNil() bool
IsNil returns true if either b is nil or the contained C++ XlaBuilder is nil. Usually true for destroyed XlaBuilder objects.
func (*XlaBuilder) Name ¶
func (b *XlaBuilder) Name() string
Name returns the name after it was canonicalized by the XlaBuilder library -- so it may be different from the one given.
type XlaComputation ¶
type XlaComputation struct {
// contains filtered or unexported fields
}
XlaComputation represents a computation created with XlaBuilder.
It can be used as is by pjrt.Client.Compile or serialized (to be saved) with XlaComputation.SerializedHLO.
It is also used as a "subroutine" for other XlaBuilder ops, like Reduce, which takes the computation to use for reduction.
To print the contents of the HloModuleProto, the github.com/openxla/xla repository offers a small utility called `run_hlo_module`. Follow the XLA build instructions and build the target `//xla/tools:run_hlo_module`. E.g.: if you saved your serialized HLO to a file named "my_hlo.pb", you can print it out as:
$ run_hlo_module --platform=cpu --xla_dump_hlo_as_text my_hlo.pb
The `run_hlo_module` tool can also be used to run the program, export to HTML, graphviz, etc.
func (*XlaComputation) Destroy ¶
func (comp *XlaComputation) Destroy()
Destroy immediately the underlying (C/C++) XlaComputation. This is called automatically at garbage-collection.
func (*XlaComputation) HasStableHLO ¶ added in v0.4.3
func (comp *XlaComputation) HasStableHLO() bool
HasStableHLO returns whether StableHLO support was included in the build -- it's very large, so by default it is not.
func (*XlaComputation) IsNil ¶
func (comp *XlaComputation) IsNil() bool
IsNil returns whether the computation or the underlying C/C++ object are nil. It's true after it is destroyed.
func (*XlaComputation) Name ¶
func (comp *XlaComputation) Name() string
Name returns the name assigned to the computation (given at the builder construction).
func (*XlaComputation) SerializedHLO ¶
func (comp *XlaComputation) SerializedHLO() *cbuffer.CBuffer
SerializedHLO generates the HLO program as a <serialized HloModule proto> (something that PJRT can consume) for the given computation.
The returned CBuffer needs to be freed (CBuffer.Destroy) after being used (presumably by PJRT, or saved to a file).
See XlaComputation documentation on how to pretty-print the computation as text HLO.
func (*XlaComputation) SerializedStableHLO ¶ added in v0.4.3
func (comp *XlaComputation) SerializedStableHLO() (*cbuffer.CBuffer, error)
SerializedStableHLO exports the computation as StableHLO, as an `mlir:ModuleOp`.
It does that by converting the `HLOModule` proto to an `mlir:ModuleOp`.
This functionality is not included by default -- linking StableHLO will include LLVM and make the XlaBuilder library literally 10 times larger. If not included, it will return an error.
func (*XlaComputation) TextHLO ¶
func (comp *XlaComputation) TextHLO() string
TextHLO generates the HLO program as a <serialized HLOModule proto> and returns its text representation. It can be used for testing and debugging.
Alternatively, see XlaComputation documentation on how to pretty-print the computation as text HLO.
func (*XlaComputation) TextStableHLO ¶ added in v0.4.3
func (comp *XlaComputation) TextStableHLO() (string, error)
TextStableHLO generates the StableHLO program.