Documentation ¶
Overview ¶
Package xla wraps the XLA functionality, plus extra dependencies.
To set a default platform, set the environment variable GOMLX_PLATFORM to the value you want (e.g., "Host", "CUDA", "TPU").
One can configure XLA C++ logging with the environment variable TF_CPP_MIN_LOG_LEVEL.
Compiling this library requires the following (see README.md for details on how to install them):
- gomlx_xla library installed
- tcmalloc installed -- part of the `gperftools` library. See README.md. The distribution comes with a pre-compiled libtcmalloc.so in `$(INSTALL_DIR)/lib/gomlx/`. You can also copy it from there to a common libraries directory, or include that directory in your `LD_LIBRARY_PATH`.
Index ¶
- Constants
- Variables
- func CDataToSlice[T any](data unsafe.Pointer, count int) (result []T)
- func CShapeFromShape(shape shapes.Shape) *C.Shape
- func DataAndShapeFromLiteral[T shapes.Number](l *Literal) (data []T, shape shapes.Shape, err error)
- func DataFromLiteral[T shapes.Supported](l *Literal) ([]T, error)
- func DeleteCSerializedNode(cNode *C.SerializedNode)
- func DirHasLibdevice(dir string) bool
- func ErrorFromStatus(status *C.XlaStatus) (err error)
- func GarbageCollectXLAObjects(verbose bool)
- func GetDefaultPlatform() (string, error)
- func GetPlatforms() ([]string, error)
- func HumanBytes[T interface{ ... }](bytes T) string
- func LiteralsCount() int64
- func Malloc[T any]() (ptr *T)
- func MallocArray[T any](n int) (ptr *T)
- func MallocArrayAndSet[T any](n int, setFn func(i int) T) (ptr *T)
- func MemoryStats() string
- func MemoryUsage() uint64
- func MemoryUsedByFn(fn func(), msg string) int64
- func NoGlobalLeaks() bool
- func NumberToString(n int) string
- func OnDeviceBufferCount() int64
- func ParseXlaFlags() []string
- func PointerOrError[T any](s C.StatusOr) (t *T, err error)
- func PresetXlaFlagsCudaDir()
- func RegisterFinalizer[T Finalizer](o T)
- func ScalarFromLiteral[T shapes.Number](l *Literal) (t T, err error)
- func ShapeFromCShape(cShape *C.Shape) (shape shapes.Shape)
- func SizeOf[T any]() C.size_t
- func SliceToVectorData[T any](slice []T) (vec *C.VectorData)
- func StableHLOCurrentVersion() string
- func StrFree(cstr *C.char) (str string)
- func StrVectorFree(vp *C.VectorPointers) (strs []string)
- func UnsafePointerOrError(s C.StatusOr) (unsafe.Pointer, error)
- func VectorDataToSlice[T any](vec *C.VectorData) (data []T)
- type AOTExecutable
- type AutoCPointer
- type Client
- type ClientId
- type Computation
- func (comp *Computation) AOTCompile(paramShapes []shapes.Shape) ([]byte, error)
- func (comp *Computation) AddOp(node *SerializedNode) (opsNum int, shape shapes.Shape, err error)
- func (comp *Computation) Compile(paramShapes []shapes.Shape, output int) error
- func (comp *Computation) Finalize()
- func (comp *Computation) IsCompiled() bool
- func (comp *Computation) IsNil() bool
- func (comp *Computation) Run(params []*OnDeviceBuffer) (*OnDeviceBuffer, error)
- func (comp *Computation) SerializedToC(node *SerializedNode) *C.SerializedNode
- func (comp *Computation) ToStableHLO() (*StableHLO, error)
- type ErrorCode
- type FftType
- type Finalizer
- type GlobalData
- type Literal
- func (l *Literal) Bytes() []byte
- func (l *Literal) Data() any
- func (l *Literal) Finalize()
- func (l *Literal) IsNil() bool
- func (l *Literal) Refresh()
- func (l *Literal) Shape() (shape shapes.Shape)
- func (l *Literal) SplitTuple() []*Literal
- func (l *Literal) ToGlobalData(client *Client) (*GlobalData, error)
- func (l *Literal) ToOnDeviceBuffer(client *Client, deviceOrdinal int) (*OnDeviceBuffer, error)
- type NodeType
- type OnDeviceBuffer
- func (b *OnDeviceBuffer) Client() *Client
- func (b *OnDeviceBuffer) DeviceOrdinal() int
- func (b *OnDeviceBuffer) Finalize()
- func (b *OnDeviceBuffer) IsNil() bool
- func (b *OnDeviceBuffer) Shape() shapes.Shape
- func (b *OnDeviceBuffer) String() string
- func (b *OnDeviceBuffer) SubTree(path []int) (*OnDeviceBuffer, error)
- type RandomAlgorithm
- type SerializedNode
- type StableHLO
- type Status
- type UnsafeCPointer
- type WrapperWithDestructor
Constants ¶
const (
	CudaDirKey            = "CUDA_DIR"
	XlaFlagsKey           = "XLA_FLAGS"
	XlaFlagGpuCudaDataDir = "--xla_gpu_cuda_data_dir"
)
This file implements the workaround to help XLA find the (infamous) `libdevice.10.bc` file needed by the NVidia CUDA drivers. Unfortunately, it is a required file and there is not a good default way of finding it.
XLA uses the environment variable XLA_FLAGS, with the flag --xla_gpu_cuda_data_dir set to the CUDA directory to be searched. If it's not set, it searches the current directory.
The strategies for GoMLX around this limitation, in case it is compiled with GPU/CUDA support, are:
1. If --xla_gpu_cuda_data_dir is not set:
- If `CUDA_DIR` env variable is set, set that in XLA_FLAGS instead.
- Try to find `libdevice.10.bc` in the standard locations (for now only Ubuntu/Debian) for CUDA driver files, and if it is found there, set XLA_FLAGS accordingly. The search starts with `./cuda_sdk_lib`, the default used by XLA (see file https://github.com/openxla/xla/blob/main/xla/debug_options_flags.cc).
2. Independent of whether (1) changes `XLA_FLAGS` or not, parse errors from XLA in search of `libdevice not found` and provide a much more detailed error message to the end user.
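The first part of strategy (1) can be sketched as a pure function over the two environment values. This is only an illustration under stated assumptions, not the package's implementation: `resolveXlaFlags` and `hasCudaDataDir` are hypothetical names, and the sketch assumes the flag is written as `--xla_gpu_cuda_data_dir=<dir>` and that a flag already present in `XLA_FLAGS` always wins.

```go
package main

import (
	"fmt"
	"strings"
)

// Same value as the XlaFlagGpuCudaDataDir constant documented above.
const xlaFlagGpuCudaDataDir = "--xla_gpu_cuda_data_dir"

// hasCudaDataDir (hypothetical helper) reports whether the XLA_FLAGS value
// already sets --xla_gpu_cuda_data_dir.
func hasCudaDataDir(xlaFlags string) bool {
	for _, flag := range strings.Fields(xlaFlags) {
		if flag == xlaFlagGpuCudaDataDir || strings.HasPrefix(flag, xlaFlagGpuCudaDataDir+"=") {
			return true
		}
	}
	return false
}

// resolveXlaFlags (hypothetical helper) returns the value XLA_FLAGS should
// have. An existing --xla_gpu_cuda_data_dir is left untouched; otherwise
// cudaDir (the value of CUDA_DIR) is appended when it is not empty.
func resolveXlaFlags(xlaFlags, cudaDir string) string {
	if hasCudaDataDir(xlaFlags) || cudaDir == "" {
		return xlaFlags
	}
	newFlag := xlaFlagGpuCudaDataDir + "=" + cudaDir
	if strings.TrimSpace(xlaFlags) == "" {
		return newFlag
	}
	return xlaFlags + " " + newFlag
}

func main() {
	fmt.Println(resolveXlaFlags("", "/usr/local/cuda"))
	fmt.Println(resolveXlaFlags("--xla_dump_to=/tmp/dump", "/usr/local/cuda"))
	fmt.Println(resolveXlaFlags("--xla_gpu_cuda_data_dir=/opt/cuda", "/usr/local/cuda"))
}
```

In a real program the two arguments would come from `os.Getenv("XLA_FLAGS")` and `os.Getenv("CUDA_DIR")`, and the result would be written back with `os.Setenv`.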
const DefaultPlatformEnv = "GOMLX_PLATFORM"
const InvalidClientId = ClientId(-1)
const LibDeviceNotFoundErrorMessage = `` /* 1069-byte string literal not displayed */
const XlaWrapperVersion = 14
XlaWrapperVersion is the version of the library. It should match the C++ one; if they are out of sync, very odd mistakes happen.
Please bump whenever a new NodeType is created, and keep the C++ (in `c/gomlx/status.h`) and Go version numbers in sync.
Variables ¶
var (
	PlatformPreferences = []string{"TPU", "CUDA"}
	AvoidPlatform       = "Host"
)
var (
	LiteralsCountDeallocated = int64(0)
	LiteralsCountAllocated   = int64(0)
)
var (
	OnDeviceBufferCountDeallocated = int64(0)
	OnDeviceBufferCountAllocated   = int64(0)
)
Number of OnDeviceBuffer objects allocated and freed: used for profiling and debugging.
var LibDeviceDir string
var LibDeviceFound bool
LibDeviceFound indicates whether a location for `libdevice` CUDA file was found or pre-given by the user (in XLA_FLAGS) at init time.
If one attempts to use CUDA with LibDeviceFound == false, the library will return an error.
var OpsCount = 0
Functions ¶
func CShapeFromShape ¶
CShapeFromShape allocates in the C heap a new C struct representing the shape.
func DataAndShapeFromLiteral ¶
DataAndShapeFromLiteral returns the Literal's shape and data in one call. It returns an error if the DType is incompatible with the requested type.
func DataFromLiteral ¶
DataFromLiteral returns a slice pointing to the raw data of the literal (without consideration of shape). The data can be mutated.
The slices themselves shouldn't be modified -- the underlying storage is not owned by Go, and tensor objects are supposed to be immutable. See discussion on data storage and exceptions to mutability on the package description if you really need to mutate its values.
func DeleteCSerializedNode ¶
func DeleteCSerializedNode(cNode *C.SerializedNode)
DeleteCSerializedNode frees the C allocated memory within cNode. Note that cNode itself is assumed to be allocated in Go space, hence it is (and should be) automatically garbage collected.
func DirHasLibdevice ¶ added in v0.7.0
DirHasLibdevice checks whether the directory has "libdevice..." needed by GPU CUDA.
This is somewhat based on function `GetLibdeviceDir` defined in https://github.com/openxla/xla/blob/main/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
func ErrorFromStatus ¶
ErrorFromStatus converts a returned *C.XlaStatus to an error, or to nil if there were no errors or if status == nil. It also frees the *C.XlaStatus.
func GarbageCollectXLAObjects ¶
func GarbageCollectXLAObjects(verbose bool)
GarbageCollectXLAObjects iteratively calls runtime.GC() until no more xla.Literal or xla.OnDeviceBuffer objects are being collected.
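The loop described here can be sketched as follows. It is an illustration only: `collectUntilStable` is a hypothetical name, and the live-object counter is passed in as a function so the example runs without XLA. Finalizers run asynchronously after a collection, so a real implementation may need to wait between rounds; that detail is left out.

```go
package main

import (
	"fmt"
	"runtime"
)

// collectUntilStable (hypothetical sketch) calls runtime.GC() repeatedly until
// liveCount reports the same value in two consecutive rounds, and returns the
// number of GC rounds it took.
func collectUntilStable(liveCount func() int64) (rounds int) {
	prev := liveCount()
	for {
		runtime.GC()
		rounds++
		cur := liveCount()
		if cur == prev {
			return rounds
		}
		prev = cur
	}
}

func main() {
	// Simulated counter: 3 live objects, then 2, then 1, then it stays at 1.
	remaining := []int64{3, 2, 1, 1}
	i := 0
	liveCount := func() int64 {
		v := remaining[i]
		if i < len(remaining)-1 {
			i++
		}
		return v
	}
	fmt.Println("GC rounds:", collectUntilStable(liveCount)) // prints "GC rounds: 3"
}
```

With the real package, the counter could combine `LiteralsCount()` and `OnDeviceBufferCount()`.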
func GetDefaultPlatform ¶
GetDefaultPlatform returns the default platform from the list of available platforms. It avoids choosing "Host", that is, it tries to pick the first ML accelerator available. The choice can be overridden by the value of the environment variable GOMLX_PLATFORM.
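One way to read this selection rule, together with the `PlatformPreferences` and `AvoidPlatform` variables, is sketched below. The function `pickPlatform` is hypothetical and the exact order of checks is an assumption; the real function queries XLA for the available platforms and reads `GOMLX_PLATFORM` itself.

```go
package main

import (
	"errors"
	"fmt"
)

// Values copied from the package variables documented above.
var (
	platformPreferences = []string{"TPU", "CUDA"}
	avoidPlatform       = "Host"
)

// contains reports whether s is in list.
func contains(list []string, s string) bool {
	for _, item := range list {
		if item == s {
			return true
		}
	}
	return false
}

// pickPlatform (hypothetical sketch) chooses a platform name:
//  1. a non-empty envValue (the value of GOMLX_PLATFORM) overrides everything;
//  2. otherwise the first available entry of platformPreferences is used;
//  3. otherwise the first available platform other than avoidPlatform;
//  4. otherwise whatever is available, even if it is avoidPlatform.
func pickPlatform(available []string, envValue string) (string, error) {
	if envValue != "" {
		return envValue, nil
	}
	if len(available) == 0 {
		return "", errors.New("no platforms available")
	}
	for _, preferred := range platformPreferences {
		if contains(available, preferred) {
			return preferred, nil
		}
	}
	for _, platform := range available {
		if platform != avoidPlatform {
			return platform, nil
		}
	}
	return available[0], nil
}

func main() {
	cases := []struct {
		available []string
		env       string
	}{
		{[]string{"Host", "CUDA"}, ""},
		{[]string{"Host"}, ""},
		{[]string{"Host", "CUDA"}, "Host"},
	}
	for _, c := range cases {
		platform, err := pickPlatform(c.available, c.env)
		fmt.Printf("available=%v env=%q -> %q (err=%v)\n", c.available, c.env, platform, err)
	}
}
```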
func GetPlatforms ¶
GetPlatforms lists the available platforms.
func HumanBytes ¶
func LiteralsCount ¶
func LiteralsCount() int64
LiteralsCount returns the number of Literal objects still allocated. Used for profiling and debugging.
func Malloc ¶
func Malloc[T any]() (ptr *T)
Malloc allocates a T in the C heap and initializes it to zero. It must be manually freed with C.free() by the user.
func MallocArray ¶
MallocArray allocates space to hold n copies of T in the C heap and initializes it to zero. It must be manually freed with C.free() by the user.
func MallocArrayAndSet ¶
MallocArrayAndSet allocates space to hold n copies of T in the C heap, and sets each element `i` to the result of `setFn(i)`. It must be manually freed with C.free() by the user.
func MemoryStats ¶
func MemoryStats() string
MemoryStats returns memory profiling/introspection using MallocExtension::GetStats().
func MemoryUsage ¶
func MemoryUsage() uint64
MemoryUsage returns the memory used by the application, as reported by MallocExtension::GetNumericProperty("generic.current_allocated_bytes",...). If it returns 0, something went wrong.
func MemoryUsedByFn ¶
MemoryUsedByFn returns the extra memory usage in C/C++ heap after calling the given function. Used for testing. If msg is not "", logs information before and after.
func NoGlobalLeaks ¶
func NoGlobalLeaks() bool
NoGlobalLeaks instructs the heap-checker (part of tcmalloc) to output its heap profile immediately. Because we are in Go, it is not called at exit, so we need to call it manually. To be used with google-pprof.
func NumberToString ¶
NumberToString converts a number to string in C++. Trivial function used for testing only.
func OnDeviceBufferCount ¶
func OnDeviceBufferCount() int64
OnDeviceBufferCount returns the number of OnDeviceBuffer still allocated.
func ParseXlaFlags ¶ added in v0.7.0
func ParseXlaFlags() []string
ParseXlaFlags returns the flags passed in the environment variable `XLA_FLAGS`.
func PointerOrError ¶
PointerOrError converts a StatusOr structure to either a pointer to T with the data or the Status converted to an error message and then freed.
func PresetXlaFlagsCudaDir ¶ added in v0.7.0
func PresetXlaFlagsCudaDir()
PresetXlaFlagsCudaDir will update `XLA_FLAGS` env variable if `xla_gpu_cuda_data_dir` is not set, in an attempt to make the file `libdevice.10.bc` available (it is required by NVidia CUDA):
- If `CUDA_DIR` env variable is set, set that in XLA_FLAGS instead.
- Try to find `libdevice.10.bc` in the standard locations (for now only Ubuntu/Debian) for CUDA driver files, and if it is found there, set XLA_FLAGS accordingly. The search starts with `./cuda_sdk_lib`, the default used by XLA (see file https://github.com/openxla/xla/blob/main/xla/debug_options_flags.cc).
func RegisterFinalizer ¶
func RegisterFinalizer[T Finalizer](o T)
RegisterFinalizer is a trivial helper function that calls the WrapperWithDestructor.Finalize.
func ScalarFromLiteral ¶
ScalarFromLiteral returns the scalar stored in a Literal. Returns an error if Literal is not a scalar or if it is the wrong DType.
func ShapeFromCShape ¶
ShapeFromCShape converts a shape provided in C struct (cShape) into a shapes.Shape. cShape memory is NOT freed.
func SizeOf ¶
SizeOf returns the size of the given type in bytes. Notice some structures may be padded, and this will include that space.
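The padding remark can be seen with a pure-Go analogue built on `unsafe.Sizeof`. This is a sketch, not the package's function: the hypothetical `sizeOf` returns `uintptr`, whereas the documented signature returns `C.size_t`.

```go
package main

import (
	"fmt"
	"unsafe"
)

// sizeOf (hypothetical pure-Go analogue) returns the size in bytes of T,
// including any padding the compiler adds for alignment.
func sizeOf[T any]() uintptr {
	var zero T
	return unsafe.Sizeof(zero)
}

// padded holds 1 + 8 bytes of data, but the int64 field must be aligned,
// so the struct occupies more than 9 bytes.
type padded struct {
	flag  byte
	value int64
}

func main() {
	fmt.Println("int32: ", sizeOf[int32]())
	fmt.Println("padded:", sizeOf[padded](), "(more than 1+8 because of alignment padding)")
}
```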
func SliceToVectorData ¶
func SliceToVectorData[T any](slice []T) (vec *C.VectorData)
func StableHLOCurrentVersion ¶
func StableHLOCurrentVersion() string
StableHLOCurrentVersion returns the current version for the StableHLO library.
func StrFree ¶
StrFree converts the allocated C string (char *) to a Go `string` and frees the C string immediately.
func StrVectorFree ¶
func StrVectorFree(vp *C.VectorPointers) (strs []string)
StrVectorFree converts a C.VectorPointers that presumably contains `char *` to []string. It frees everything: the individual `char *` pointers, the array that contains them (`vp.data`) and finally `vp` itself.
func UnsafePointerOrError ¶
UnsafePointerOrError converts a StatusOr structure to either an unsafe.Pointer with the data or the Status converted to an error message and then freed.
func VectorDataToSlice ¶
func VectorDataToSlice[T any](vec *C.VectorData) (data []T)
VectorDataToSlice converts the given VectorData to a slice and then frees the VectorData.
Types ¶
type AOTExecutable ¶
type AOTExecutable struct {
// contains filtered or unexported fields
}
AOTExecutable executes Ahead-Of-Time (AOT) compiled graphs.
func NewAOTExecutable ¶
func NewAOTExecutable(client *Client, aotResult []byte) (*AOTExecutable, error)
NewAOTExecutable creates an AOTExecutable given the client and the aotResult returned by an earlier Computation.AOTCompile call. It may return an error.
func (*AOTExecutable) IsNil ¶
func (exec *AOTExecutable) IsNil() bool
IsNil returns whether contents are invalid or have been freed already.
func (*AOTExecutable) Run ¶
func (exec *AOTExecutable) Run(params []*OnDeviceBuffer) (*OnDeviceBuffer, error)
type AutoCPointer ¶
type AutoCPointer[T any] struct { P *T }
AutoCPointer wraps a C pointer, and frees it when AutoCPointer is garbage collected, or when Finalize is called.
func NewAutoCPointer ¶
func NewAutoCPointer[T any](p *T) (w *AutoCPointer[T])
NewAutoCPointer holds a C pointer, and makes sure it is freed (`C.free`) when it is garbage collected or when Finalize is called.
Note since Go 1.20 forward declared C types (not completely known) are no longer supported as generic parameters -- even though we only use the pointers to them. See https://groups.google.com/g/golang-nuts/c/h75BwBsz4YA/m/FLBIjgFBBQAJ for details.
func (*AutoCPointer[T]) Finalize ¶
func (w *AutoCPointer[T]) Finalize()
Finalize implements Finalizer.
func (*AutoCPointer[T]) IsNil ¶
func (w *AutoCPointer[T]) IsNil() bool
IsNil returns whether the AutoCPointer is nil or the pointer it holds is nil.
func (*AutoCPointer[T]) UnsafePtr ¶
func (w *AutoCPointer[T]) UnsafePtr() unsafe.Pointer
UnsafePtr returns the pointer to T cast as `unsafe.Pointer`.
type Client ¶
type Client struct {
	Id                                ClientId
	DeviceCount, DefaultDeviceOrdinal int
	// contains filtered or unexported fields
}
Client is a wrapper for the C++ `Client`, itself a wrapper to xla::Client, described as "XLA service's client object -- wraps the service with convenience and lifetime-oriented methods."
It's a required parameter to most high level XLA functionality: compilation and execution of graphs, transfer of data, etc.
func NewClient ¶
NewClient constructs the Client. If platform is empty, use what is returned by GetDefaultPlatform. If numThreads is -1, use the number of available cores. Not thread-safe.
type ClientId ¶
type ClientId int
ClientId is a unique identifier to clients. Starts with 0 and increases.
type Computation ¶
type Computation struct {
// contains filtered or unexported fields
}
Computation is a wrapper for the C++ `XlaComputation`.
func NewComputation ¶
func NewComputation(client *Client, name string) *Computation
NewComputation creates a new XlaComputation object and returns its wrapper.
func (*Computation) AOTCompile ¶
func (comp *Computation) AOTCompile(paramShapes []shapes.Shape) ([]byte, error)
AOTCompile returns the Ahead-Of-Time compiled version of the graph, that can be used for execution later.
The graph must already be compiled, and it is AOT-compiled for the same platform it was compiled for -- TODO: cross-compile.
It returns a binary serialized format that can be executed later, without linking the whole GoMLX machinery. See the tutorial for instructions and an example of how to do this.
func (*Computation) AddOp ¶
func (comp *Computation) AddOp(node *SerializedNode) (opsNum int, shape shapes.Shape, err error)
AddOp adds an Op described by node to the compilation graph. Notice after the graph is compiled it becomes frozen and no more ops can be added. Ops are created in order, and
func (*Computation) Compile ¶
func (comp *Computation) Compile(paramShapes []shapes.Shape, output int) error
Compile compiles the Computation, so it's ready for execution. After this no more ops can be added. The output is the index to the output node.
func (*Computation) IsCompiled ¶
func (comp *Computation) IsCompiled() bool
IsCompiled returns whether this computation has already been compiled.
func (*Computation) IsNil ¶
func (comp *Computation) IsNil() bool
IsNil checks whether the computation is nil or its underlying C object is nil.
func (*Computation) Run ¶
func (comp *Computation) Run(params []*OnDeviceBuffer) (*OnDeviceBuffer, error)
Run runs a computation. The parameter values for the graph are given in params.
func (*Computation) SerializedToC ¶
func (comp *Computation) SerializedToC(node *SerializedNode) *C.SerializedNode
func (*Computation) ToStableHLO ¶
func (comp *Computation) ToStableHLO() (*StableHLO, error)
ToStableHLO returns a holder of the C++ object representing the StableHLO.
type ErrorCode ¶
type ErrorCode int
ErrorCode is used by the underlying TensorFlow/XLA libraries, in Status objects.
const (
	OK                  ErrorCode = 0
	CANCELLED           ErrorCode = 1
	UNKNOWN             ErrorCode = 2
	INVALID_ARGUMENT    ErrorCode = 3
	DEADLINE_EXCEEDED   ErrorCode = 4
	NOT_FOUND           ErrorCode = 5
	ALREADY_EXISTS      ErrorCode = 6
	PERMISSION_DENIED   ErrorCode = 7
	UNAUTHENTICATED     ErrorCode = 16
	RESOURCE_EXHAUSTED  ErrorCode = 8
	FAILED_PRECONDITION ErrorCode = 9
	ABORTED             ErrorCode = 10
	OUT_OF_RANGE        ErrorCode = 11
)
Values copied from tensorflow/core/protobuf/error_codes.proto. TODO: convert the protos definitions to Go and use that instead.
type FftType ¶ added in v0.6.0
type FftType int
FftType should be aligned with the constants in `xla_data.proto` file in the XLA code base.
XLA's FFT operator can work in different ways; this type defines how.
const (
	// FftForward does a forward FFT: complex in, complex out.
	// FFT in the proto.
	FftForward FftType = iota
	// FftInverse does an inverse FFT: complex in, complex out.
	// IFFT in the proto.
	FftInverse
	// FftForwardReal does a forward real FFT: real in, fft_length / 2 + 1 complex out
	// RFFT in the proto.
	FftForwardReal
	// FftInverseReal does an inverse real FFT: fft_length / 2 + 1 complex in, real out
	// IRFFT in the proto.
	FftInverseReal
)
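The `fft_length / 2 + 1` size used by the real variants is an integer division, which a short worked example makes concrete. The helper name `rfftOutputLen` is hypothetical and only does the arithmetic; it does not call XLA.

```go
package main

import "fmt"

// rfftOutputLen (hypothetical helper) returns the number of complex values a
// forward real FFT produces for fftLength real inputs: fft_length/2 + 1,
// using integer division.
func rfftOutputLen(fftLength int) int {
	return fftLength/2 + 1
}

func main() {
	for _, n := range []int{8, 9, 1024} {
		fmt.Printf("fft_length=%d -> %d complex values\n", n, rfftOutputLen(n))
	}
	// fft_length=8 -> 5 complex values
	// fft_length=9 -> 5 complex values
	// fft_length=1024 -> 513 complex values
}
```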
type Finalizer ¶
type Finalizer interface {
	// Finalize frees the underlying resources, presumably outside
	// Go runtime control, like C pointers.
	//
	// Finalize should be idempotent: if called multiple times subsequent calls
	// shouldn't affect it.
	Finalize()
}
Finalizer is any object that implements Finalize, which can be registered with runtime.SetFinalizer so that it is called when the object is deallocated.
Finalize should be idempotent: if called multiple times, subsequent calls shouldn't affect it.
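A minimal sketch of why the repeated-call guarantee matters: the same `Finalize` may be called manually and later again by the garbage collector. The `resource` type and `registerFinalizer` helper below are hypothetical stand-ins, assuming registration goes through `runtime.SetFinalizer`; they are not the package's code.

```go
package main

import (
	"fmt"
	"runtime"
)

// Finalizer mirrors the interface documented above.
type Finalizer interface {
	Finalize()
}

// registerFinalizer (hypothetical sketch) arranges for o.Finalize() to run
// when o is garbage collected. T must be a pointer type, as required by
// runtime.SetFinalizer.
func registerFinalizer[T Finalizer](o T) {
	runtime.SetFinalizer(o, func(o T) { o.Finalize() })
}

// resource stands in for an object that owns memory outside Go's control.
type resource struct {
	freed bool
	frees int // Number of times the underlying memory was actually released.
}

// Finalize releases the resource once; later calls are no-ops.
func (r *resource) Finalize() {
	if r == nil || r.freed {
		return
	}
	r.freed = true
	r.frees++ // A real implementation would call C.free or a C++ destructor here.
}

func main() {
	r := &resource{}
	registerFinalizer(r)
	r.Finalize() // Manual release.
	r.Finalize() // Second call, e.g. from the garbage collector later: no effect.
	fmt.Println("times freed:", r.frees) // prints "times freed: 1"
}
```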
type GlobalData ¶
type GlobalData struct {
// contains filtered or unexported fields
}
GlobalData represents a value stored in the remote accelerator (or wherever the computation is executed). Use ToLiteral to bring it to the local CPU.
Note: GoMLX started using a different execution pathway that uses ShapedBuffer instead, and GlobalData is no longer being used right now. But I'm not sure if it will be used later when adding support for multiple accelerator devices.
func (*GlobalData) Client ¶
func (gd *GlobalData) Client() *Client
func (*GlobalData) IsNil ¶
func (gd *GlobalData) IsNil() bool
IsNil returns whether either the Client object is nil or the contained C pointer.
func (*GlobalData) Shape ¶
func (gd *GlobalData) Shape() (shape shapes.Shape, err error)
Shape retrieves the shape of data stored globally.
func (*GlobalData) SplitTuple ¶
func (gd *GlobalData) SplitTuple() (gds []*GlobalData, err error)
SplitTuple splits the tuple GlobalData into its various components. This invalidates the current global data.
func (*GlobalData) ToLiteral ¶
func (gd *GlobalData) ToLiteral() (*Literal, error)
ToLiteral transfers the data stored in the accelerator server to the CPU. Returns a Literal with the transferred data.
type Literal ¶
type Literal struct {
// contains filtered or unexported fields
}
Literal represents a value stored in the local CPU, in C++ heap, to interact with XLA.
func FromOnDeviceBuffer ¶
func FromOnDeviceBuffer(buffer *OnDeviceBuffer) (*Literal, error)
func NewLiteralFromShape ¶
NewLiteralFromShape returns a new Literal with the given shape and uninitialized data.
func NewLiteralTuple ¶
NewLiteralTuple creates a tuple from the individual literals. Flat is copied.
func NewLiteralWithZeros ¶
func (*Literal) Bytes ¶
Bytes returns the same memory as Data, but the raw slice of bytes, with the proper size.
func (*Literal) Data ¶
Data returns a slice of the data for the corresponding DType (see package types). The underlying data is not owned/managed by Go, but it is mutable.
Notice if the Literal goes out-of-scope and is finalized, the underlying data gets freed without Go knowing about it, and the returned slice may become invalid. Careful to make sure Literal doesn't get out of scope --
func (*Literal) IsNil ¶
IsNil returns whether either the Literal object is nil or the contained C pointer.
func (*Literal) Refresh ¶ added in v0.2.1
func (l *Literal) Refresh()
Refresh will re-read the data pointer and size information from the C++ interface. Mostly used for debugging.
func (*Literal) SplitTuple ¶
SplitTuple splits the tuple in the Literal into multiple literals. Unlike with GlobalData, this destroys the underlying Literal.
func (*Literal) ToGlobalData ¶
func (l *Literal) ToGlobalData(client *Client) (*GlobalData, error)
ToGlobalData transfers the literal to the global accelerator server. GlobalData is no longer used in our execution path, so for now this is deprecated.
func (*Literal) ToOnDeviceBuffer ¶
func (l *Literal) ToOnDeviceBuffer(client *Client, deviceOrdinal int) (*OnDeviceBuffer, error)
ToOnDeviceBuffer returns an OnDeviceBuffer allocated on the device (the number of the device is given by deviceOrdinal).
type NodeType ¶
type NodeType int32
NodeType enumerates the various types of Nodes that can be converted to XLA.
const (
	InvalidNode NodeType = iota
	ConstantNode
	IotaNode
	ParameterNode
	ConvertTypeNode
	WhereNode
	TupleNode
	GetTupleElementNode
	ReshapeNode
	BroadcastNode
	BroadcastInDimNode
	ReduceSumNode
	ReduceMaxNode
	ReduceMultiplyNode
	SliceNode
	PadNode
	GatherNode
	ScatterNode
	ConcatenateNode
	ConvGeneralDilatedNode
	ReverseNode
	TransposeNode
	ReduceWindowNode
	SelectAndScatterNode
	BatchNormTrainingNode
	BatchNormInferenceNode
	BatchNormGradNode
	DotGeneralNode
	ArgMinMaxNode
	FftNode
	AbsNode
	NegNode
	ExpNode
	Expm1Node
	FloorNode
	CeilNode
	RoundNode
	LogNode
	Log1pNode
	LogicalNotNode
	LogisticNode
	SignNode
	ClzNode
	CosNode
	SinNode
	TanhNode
	SqrtNode
	RsqrtNode
	ImagNode
	RealNode
	ConjNode
	AddNode
	MulNode
	SubNode
	DivNode
	RemNode // Notice XLA implements Mod, not IEEE754 Remainder operation.
	AndNode
	OrNode
	XorNode
	DotNode
	MinNode
	MaxNode
	PowNode
	ComplexNode
	EqualNode
	NotEqualNode
	GreaterOrEqualNode
	GreaterThanNode
	LessOrEqualNode
	LessThanNode
	EqualTotalOrderNode
	NotEqualTotalOrderNode
	GreaterOrEqualTotalOrderNode
	GreaterThanTotalOrderNode
	LessOrEqualTotalOrderNode
	LessThanTotalOrderNode
	RngBitGeneratorNode
	RngNormalNode
	RngUniformNode
)
NodeType values need to be exactly the same as defined in the C++ code, in `c/gomlx/node.h`. TODO: keep those in sync using some generator script.
type OnDeviceBuffer ¶
type OnDeviceBuffer struct {
// contains filtered or unexported fields
}
OnDeviceBuffer represents a value stored in an execution device (CPU or accelerator).
Notice there can be multiple accelerator devices; DeviceOrdinal tells which device this is stored in.
func (*OnDeviceBuffer) Client ¶
func (b *OnDeviceBuffer) Client() *Client
func (*OnDeviceBuffer) DeviceOrdinal ¶
func (b *OnDeviceBuffer) DeviceOrdinal() int
DeviceOrdinal returns the ordinal number of the device -- an id in case there are several replicas of the device (? XLA documentation is not clear about this, just guessing).
func (*OnDeviceBuffer) Finalize ¶
func (b *OnDeviceBuffer) Finalize()
Finalize implements Finalizer.
func (*OnDeviceBuffer) IsNil ¶
func (b *OnDeviceBuffer) IsNil() bool
IsNil returns whether OnDeviceBuffer holds no data.
func (*OnDeviceBuffer) Shape ¶
func (b *OnDeviceBuffer) Shape() shapes.Shape
func (*OnDeviceBuffer) String ¶
func (b *OnDeviceBuffer) String() string
func (*OnDeviceBuffer) SubTree ¶
func (b *OnDeviceBuffer) SubTree(path []int) (*OnDeviceBuffer, error)
SubTree retrieves an element from a nested tuple (tree) OnDeviceBuffer.
type RandomAlgorithm ¶ added in v0.4.0
type RandomAlgorithm int
RandomAlgorithm should be aligned with constants in `xla_data.proto` file in the XLA code base.
Each random algorithm entails a different shape of `initialState` that needs to be fed to `RngBitGenerator`.
See details and reference of the algorithms: https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator
Unfortunately there is no documented way of figuring out what is the initial state for `RngDefault` algorithm.
const (
	// RngDefault is the back-end specific algorithm with back-end specific shape requirements.
	// There doesn't seem to be any automatic way of figuring out what it takes for `initialState`.
	RngDefault RandomAlgorithm = iota
	// RngThreeFry counter-based PRNG algorithm. The initial_state shape is U64[2] with arbitrary values.
	// [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf).
	RngThreeFry
	// RngPhilox algorithm to generate random numbers in parallel. The initial_state shape is `U64[3]` with arbitrary values.
	// [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf).
	RngPhilox
)
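The shape requirements stated in these comments can be captured in a small helper that builds an `initialState` of the right length. This is a sketch with hypothetical names (`initialStateLen`, `newInitialState`, and lower-case copies of the constants); filling the state with `seed + i` is just one arbitrary choice, since the comments only require arbitrary values.

```go
package main

import (
	"errors"
	"fmt"
)

// randomAlgorithm is a local copy of the enum documented above.
type randomAlgorithm int

const (
	rngDefault randomAlgorithm = iota
	rngThreeFry
	rngPhilox
)

// initialStateLen (hypothetical helper) returns the number of uint64 values
// in the initial state: U64[2] for ThreeFry and U64[3] for Philox. The
// default algorithm is back-end specific, so an error is returned for it.
func initialStateLen(alg randomAlgorithm) (int, error) {
	switch alg {
	case rngThreeFry:
		return 2, nil
	case rngPhilox:
		return 3, nil
	default:
		return 0, errors.New("initial state shape is back-end specific or unknown")
	}
}

// newInitialState (hypothetical helper) builds a state of the right length,
// filled with arbitrary values derived from seed.
func newInitialState(alg randomAlgorithm, seed uint64) ([]uint64, error) {
	n, err := initialStateLen(alg)
	if err != nil {
		return nil, err
	}
	state := make([]uint64, n)
	for i := range state {
		state[i] = seed + uint64(i)
	}
	return state, nil
}

func main() {
	for _, alg := range []randomAlgorithm{rngThreeFry, rngPhilox, rngDefault} {
		state, err := newInitialState(alg, 42)
		fmt.Printf("algorithm=%d state=%v err=%v\n", alg, state, err)
	}
}
```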
func (RandomAlgorithm) String ¶ added in v0.4.0
func (i RandomAlgorithm) String() string
type SerializedNode ¶
type SerializedNode struct {
	Type       NodeType
	NodeInputs []int32      // Index to other nodes that are used as inputs.
	Literal    *Literal     // If a Literal (constant) is involved in the operation.
	Int        int          // Used for any static integer inputs.
	Shape      shapes.Shape // If a shape is used as a static input.
	Str        string       // Used for any static string argument.
	Ints       []int        // List of integer numbers.
	Float      float32      // For a float parameter.
}
SerializedNode represents a graph.Node with its arguments. This can then be passed to the XLA C++ library. Only used by ops implementers.
type StableHLO ¶
type StableHLO struct {
// contains filtered or unexported fields
}
StableHLO is a wrapper for the C++ `StableHLOHolder`.
func NewStableHLO ¶
func NewStableHLO(cPtr *C.StableHLOHolder) *StableHLO
NewStableHLO creates the wrapper.
func (*StableHLO) Serialize ¶
Serialize to bytecode that can presumably be used by PjRT and IREE, as well as embedded in one of the TensorFlow SavedModel formats.(??)
It serializes to the given file path.
func (*StableHLO) SerializeWithVersion ¶ added in v0.7.2
SerializeWithVersion to bytecode that can presumably be used by PjRT and IREE, as well as embedded in one of the TensorFlow SavedModel formats.(??)
It serializes to the given file descriptor, presumably a file opened for writing.
For version, the usual is to use the value returned by StableHLOCurrentVersion.
type Status ¶
type Status struct {
// contains filtered or unexported fields
}
Status is a wrapper for `xla::Status`.
func (*Status) ErrorMessage ¶
func (*Status) IsNil ¶
IsNil returns whether either the Status object is nil or the contained C pointer.
func (*Status) UnsafeCPtr ¶
UnsafeCPtr returns the underlying C pointer converted to unsafe.Pointer.
type UnsafeCPointer ¶
UnsafeCPointer holds a C pointer, and makes sure it is freed (`C.free`) when it is garbage collected or when Finalize is called.
Since Go 1.20, forward declared C types (not completely known) are no longer supported as generic parameters, so we need to keep those pointers as unsafe pointers, and they have to be manually cast.
func NewUnsafeCPointer ¶
func NewUnsafeCPointer(p unsafe.Pointer) *UnsafeCPointer
NewUnsafeCPointer holds a forward declared C pointer as `unsafe.Pointer`, and makes sure it is freed (`C.free`) when it is garbage collected or when WrapperWithDestructor.Finalize is called.
Note since Go 1.20 forward declared C types (not completely known) are no longer supported as type parameters -- even though we only use the pointers to them. See https://groups.google.com/g/golang-nuts/c/h75BwBsz4YA/m/FLBIjgFBBQAJ for details.
func (*UnsafeCPointer) Finalize ¶
func (w *UnsafeCPointer) Finalize()
Finalize implements Finalizer.
func (*UnsafeCPointer) IsNil ¶
func (w *UnsafeCPointer) IsNil() bool
IsNil returns whether the UnsafeCPointer is nil or the pointer it holds is nil.
func (*UnsafeCPointer) UnsafePtr ¶
func (w *UnsafeCPointer) UnsafePtr() unsafe.Pointer
UnsafePtr returns the held pointer as `unsafe.Pointer`.
type WrapperWithDestructor ¶
type WrapperWithDestructor[T any] struct {
	Data T
	// contains filtered or unexported fields
}
WrapperWithDestructor wraps a pointer to an arbitrary type and adds a finalize method.
func NewWrapperWithDestructor ¶
func NewWrapperWithDestructor[T any](data T, destructor func(data T)) (w *WrapperWithDestructor[T])
NewWrapperWithDestructor creates a WrapperWithDestructor to type T, using the given destructor to finalize the object.
The `destructor` will only be called once, even if `Finalize()` is called manually. The wrapper sets the destructor to nil after the first time `Finalize()` is called.
There are no synchronization mechanisms, so manually calling `Finalize()` concurrently is undefined.
func (*WrapperWithDestructor[T]) Empty ¶
func (w *WrapperWithDestructor[T]) Empty() bool
Empty returns true if either w is nil, or its contents have already been finalized.
func (*WrapperWithDestructor[T]) Finalize ¶
func (w *WrapperWithDestructor[T]) Finalize()
Finalize frees the pointer held. Notice this version is not concurrency safe -- but then Finalize should be called only once anyway.
Once Finalize is called, the object cannot be re-used, it will be forever marked as empty.
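The call-once behavior described for this type can be sketched in a few lines. The lower-case `wrapper` type below is a hypothetical stand-in that follows the description (the destructor is set to nil after the first `Finalize`); it omits any registration with the garbage collector and, like the original, has no synchronization.

```go
package main

import "fmt"

// wrapper (hypothetical sketch) pairs a value with a destructor that runs at most once.
type wrapper[T any] struct {
	Data       T
	destructor func(data T)
}

// newWrapper creates a wrapper around data using the given destructor.
func newWrapper[T any](data T, destructor func(data T)) *wrapper[T] {
	return &wrapper[T]{Data: data, destructor: destructor}
}

// Empty reports whether w is nil or has already been finalized.
func (w *wrapper[T]) Empty() bool {
	return w == nil || w.destructor == nil
}

// Finalize runs the destructor on the first call and marks the wrapper as
// empty; later calls do nothing. It is not safe for concurrent use.
func (w *wrapper[T]) Finalize() {
	if w.Empty() {
		return
	}
	destructor := w.destructor
	w.destructor = nil
	destructor(w.Data)
}

func main() {
	calls := 0
	w := newWrapper("some C resource", func(data string) {
		calls++
		fmt.Println("releasing:", data)
	})
	w.Finalize()
	w.Finalize() // No effect: the destructor was cleared by the first call.
	fmt.Println("destructor calls:", calls, "empty:", w.Empty())
}
```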