xla

package
v0.9.1
Published: Apr 20, 2024 License: Apache-2.0 Imports: 16 Imported by: 0

Documentation

Overview

Package xla wraps the XLA functionality, plus extra dependencies.

To set a default platform to use, set the environment variable GOMLX_PLATFORM to the value you want (e.g. "Host", "CUDA", "TPU").

One can configure XLA C++ logging with the environment variable TF_CPP_MIN_LOG_LEVEL.

Compiling this library requires the following (see README.md for details on how to install these):

  • gomlx_xla library installed
  • tcmalloc installed -- part of the `gperftools` library. See README.md. The distribution comes with a pre-compiled libtcmalloc.so in `$(INSTALL_DIR)/lib/gomlx/`. You can also copy it from there to a common libraries directory, or include that directory in your `LD_LIBRARY_PATH`.

Constants

const (
	CudaDirKey            = "CUDA_DIR"
	XlaFlagsKey           = "XLA_FLAGS"
	XlaFlagGpuCudaDataDir = "--xla_gpu_cuda_data_dir"
)

This file implements the workaround to help XLA find the (infamous) `libdevice.10.bc` file needed by the NVIDIA CUDA drivers. Unfortunately, it is a required file and there is no good default way of finding it.

XLA uses the environment variable XLA_FLAGS, with the flag --xla_gpu_cuda_data_dir set to the CUDA directory to be searched. If it's not set, it searches the current directory.

The strategies for GoMLX around this limitation, in case it is compiled with GPU/CUDA support, are:

1. If --xla_gpu_cuda_data_dir is not set:

  • If `CUDA_DIR` env variable is set, set that in XLA_FLAGS instead.
  • Try to find in standard (for now only Ubuntu/Debian) locations for CUDA driver files, and if `libdevice.10.bc` is found there, set it in XLA_FLAGS accordingly. The search starts with `./cuda_sdk_lib`, the default used by XLA (see file https://github.com/openxla/xla/blob/main/xla/debug_options_flags.cc).

2. Independent of whether (1) changed `XLA_FLAGS` or not, parse errors from XLA in search of `libdevice not found` and provide a much more detailed error message to the end user.

const DefaultPlatformEnv = "GOMLX_PLATFORM"
const InvalidClientId = ClientId(-1)
const LibDeviceNotFoundErrorMessage = `` /* 1069-byte string literal not displayed */
const XlaWrapperVersion = 14

XlaWrapperVersion is the version of the library. It should match the C++ one: if they are out of sync, very odd mistakes happen.

Please bump whenever a new NodeType is created, and keep the C++ (in `c/gomlx/status.h`) and Go version numbers in sync.

Variables

var (
	PlatformPreferences = []string{"TPU", "CUDA"}
	AvoidPlatform       = "Host"
)
var (
	LiteralsCountDeallocated = int64(0)
	LiteralsCountAllocated   = int64(0)
)
var (
	OnDeviceBufferCountDeallocated = int64(0)
	OnDeviceBufferCountAllocated   = int64(0)
)

Number of OnDeviceBuffer objects allocated and freed: used for profiling and debugging.

var LibDeviceDir string
var LibDeviceFound bool

LibDeviceFound indicates whether a location for `libdevice` CUDA file was found or pre-given by the user (in XLA_FLAGS) at init time.

If one attempts to use CUDA with LibDeviceFound == false, the library will return an error.

var OpsCount = 0

Functions

func CDataToSlice

func CDataToSlice[T any](data unsafe.Pointer, count int) (result []T)
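
CDataToSlice has no description here, so the following is only a guess at the technique its signature suggests: viewing `count` elements behind a raw pointer as a Go slice without copying. The helper name `dataToSlice` is hypothetical, and Go-owned memory stands in for a C buffer so the sketch runs without cgo.

```go
package main

import (
	"fmt"
	"unsafe"
)

// dataToSlice is a hypothetical stand-in for CDataToSlice: it views `count`
// elements of type T starting at `data` as a Go slice, without copying.
func dataToSlice[T any](data unsafe.Pointer, count int) []T {
	if data == nil || count <= 0 {
		return nil
	}
	return unsafe.Slice((*T)(data), count)
}

func main() {
	// Assumption: a Go array plays the role of the C buffer.
	backing := [4]float32{1, 2, 3, 4}
	view := dataToSlice[float32](unsafe.Pointer(&backing[0]), len(backing))
	view[0] = 10 // Writes through to the backing array: the slice is a view, not a copy.
	fmt.Println(view, backing[0])
}
```

With real C memory the slice stays valid only as long as the C allocation does, since Go's garbage collector does not track it.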

func CShapeFromShape

func CShapeFromShape(shape shapes.Shape) *C.Shape

CShapeFromShape allocates in the C heap a new C struct representing the shape.

func DataAndShapeFromLiteral

func DataAndShapeFromLiteral[T shapes.Number](l *Literal) (data []T, shape shapes.Shape, err error)

DataAndShapeFromLiteral returns the Literal's shape and data in one call. Returns an error if the DType is incompatible with the requested type.

func DataFromLiteral

func DataFromLiteral[T shapes.Supported](l *Literal) ([]T, error)

DataFromLiteral returns a slice pointing to the raw data of the literal (without consideration of shape). The data can technically be mutated.

The slice contents shouldn't be modified, though -- the underlying storage is not owned by Go, and tensor objects are supposed to be immutable. See discussion on data storage and exceptions to mutability on the package description if you really need to mutate its values.

func DeleteCSerializedNode

func DeleteCSerializedNode(cNode *C.SerializedNode)

DeleteCSerializedNode frees the C allocated memory within cNode. Note that cNode itself is assumed to be allocated in Go space, hence it is (and should be) automatically garbage collected.

func DirHasLibdevice added in v0.7.0

func DirHasLibdevice(dir string) bool

DirHasLibdevice checks whether the directory has "libdevice..." needed by GPU CUDA.

This is somewhat based on function `GetLibdeviceDir` defined in https://github.com/openxla/xla/blob/main/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
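
A minimal sketch of such a check, not the package's actual code. It assumes the file is looked up as `<dir>/nvvm/libdevice/libdevice.10.bc`, and the candidate directories in `main` are illustrative guesses rather than the list GoMLX searches.

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// dirHasLibdevice is a hypothetical re-implementation for illustration.
// Assumption: the file lives at <dir>/nvvm/libdevice/libdevice.10.bc.
func dirHasLibdevice(dir string) bool {
	path := filepath.Join(dir, "nvvm", "libdevice", "libdevice.10.bc")
	info, err := os.Stat(path)
	return err == nil && info.Mode().IsRegular()
}

func main() {
	// Illustrative candidates only; adjust to the local CUDA installation.
	for _, dir := range []string{"./cuda_sdk_lib", "/usr/local/cuda", "/usr/lib/nvidia-cuda-toolkit"} {
		fmt.Printf("%s: %v\n", dir, dirHasLibdevice(dir))
	}
}
```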

func ErrorFromStatus

func ErrorFromStatus(status *C.XlaStatus) (err error)

ErrorFromStatus converts a returned *C.XlaStatus to an error, or to nil if there were no errors or if status == nil. It also frees the returned *C.XlaStatus.

func GarbageCollectXLAObjects

func GarbageCollectXLAObjects(verbose bool)

GarbageCollectXLAObjects iteratively calls runtime.GC() until no more xla.Literal or xla.OnDeviceBuffer objects are being collected.
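
The loop this describes can be sketched as follows. This is an illustration under stated assumptions, not the package's implementation: `collectUntilStable` is a hypothetical helper, and the counter and collector are injected so the example is deterministic. In real use `gc` would be `runtime.GC` and `liveCount` something like the sum of `LiteralsCount()` and `OnDeviceBufferCount()`.

```go
package main

import "fmt"

// collectUntilStable calls gc repeatedly until a round no longer reduces
// liveCount, and returns the number of rounds performed.
func collectUntilStable(liveCount func() int64, gc func()) int {
	rounds := 0
	for {
		before := liveCount()
		gc()
		rounds++
		if liveCount() >= before {
			return rounds // Nothing was freed in this round: stop.
		}
	}
}

func main() {
	// Simulation: each "GC" round frees one object, as when finalizers
	// release objects that kept other objects alive.
	live := int64(3)
	gc := func() {
		if live > 0 {
			live--
		}
	}
	rounds := collectUntilStable(func() int64 { return live }, gc)
	fmt.Printf("rounds=%d live=%d\n", rounds, live)
}
```

Several rounds can be needed because finalizers run after a collection, and the objects they release only become collectable in a later one.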

func GetDefaultPlatform

func GetDefaultPlatform() (string, error)

GetDefaultPlatform returns the default platform from the list of available platforms. It avoids choosing "Host", that is, it tries to pick the first ML accelerator available. The choice can be overridden by the value of the environment variable GOMLX_PLATFORM.
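
One way to read this selection rule, as a sketch only: the exact algorithm is an assumption, and `pickPlatform` is a hypothetical helper. The `preferences` and `avoid` arguments mirror the package variables `PlatformPreferences` and `AvoidPlatform`.

```go
package main

import (
	"errors"
	"fmt"
	"os"
)

// pickPlatform is a hypothetical sketch of the selection described above.
// override plays the role of the GOMLX_PLATFORM environment variable.
func pickPlatform(available, preferences []string, avoid, override string) (string, error) {
	if override != "" {
		return override, nil // An explicit choice wins.
	}
	if len(available) == 0 {
		return "", errors.New("no platforms available")
	}
	for _, want := range preferences {
		for _, have := range available {
			if have == want {
				return have, nil
			}
		}
	}
	for _, have := range available {
		if have != avoid {
			return have, nil // Any platform other than the one to avoid.
		}
	}
	return available[0], nil // Only the avoided platform is left.
}

func main() {
	platform, err := pickPlatform(
		[]string{"Host", "CUDA"}, []string{"TPU", "CUDA"}, "Host", os.Getenv("GOMLX_PLATFORM"))
	fmt.Println(platform, err)
}
```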

func GetPlatforms

func GetPlatforms() ([]string, error)

GetPlatforms lists the available platforms.

func HumanBytes

func HumanBytes[T interface{ int64 | uint64 }](bytes T) string

func LiteralsCount

func LiteralsCount() int64

LiteralsCount returns the number of Literal objects still allocated. Used for profiling and debugging.

func Malloc

func Malloc[T any]() (ptr *T)

Malloc allocates a T in the C heap and initializes it to zero. It must be manually freed with C.free() by the user.

func MallocArray

func MallocArray[T any](n int) (ptr *T)

MallocArray allocates space to hold n copies of T in the C heap and initializes it to zero. It must be manually freed with C.free() by the user.

func MallocArrayAndSet

func MallocArrayAndSet[T any](n int, setFn func(i int) T) (ptr *T)

MallocArrayAndSet allocates space to hold n copies of T in the C heap, and sets each element `i` with the result of `setFn(i)`. It must be manually freed with C.free() by the user.

func MemoryStats

func MemoryStats() string

MemoryStats returns memory profiling/introspection using MallocExtension::GetStats().

func MemoryUsage

func MemoryUsage() uint64

MemoryUsage returns the memory used by the application, as reported by MallocExtension::GetNumericProperty("generic.current_allocated_bytes",...). If it returns 0, something went wrong.

func MemoryUsedByFn

func MemoryUsedByFn(fn func(), msg string) int64

MemoryUsedByFn returns the extra memory usage in C/C++ heap after calling the given function. Used for testing. If msg is not "", logs information before and after.

func NoGlobalLeaks

func NoGlobalLeaks() bool

NoGlobalLeaks instructs the heap-checker (part of tcmalloc) to output its heap profile immediately -- because we are in Go, it is not called at exit, so we need to manually call this. To be used with google-pprof.

func NumberToString

func NumberToString(n int) string

NumberToString converts a number to string in C++. Trivial function used for testing only.

func OnDeviceBufferCount

func OnDeviceBufferCount() int64

OnDeviceBufferCount returns the number of OnDeviceBuffer still allocated.

func ParseXlaFlags added in v0.7.0

func ParseXlaFlags() []string

ParseXlaFlags returns the flags passed in the environment variable `XLA_FLAGS`.
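
A minimal sketch of this kind of parsing, assuming the flags are separated by whitespace and contain no quoted values; the real function may handle more cases. `parseFlags` is a hypothetical name.

```go
package main

import (
	"fmt"
	"os"
	"strings"
)

// parseFlags splits a flags string into individual flags.
// Assumption: flags are whitespace-separated and contain no quoting.
func parseFlags(value string) []string {
	return strings.Fields(value)
}

func main() {
	// Whatever is currently set in the environment (possibly nothing).
	fmt.Printf("%q\n", parseFlags(os.Getenv("XLA_FLAGS")))
	// A fixed example value.
	fmt.Printf("%q\n", parseFlags("--xla_gpu_cuda_data_dir=/usr/local/cuda  --xla_dump_to=/tmp/dump"))
}
```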

func PointerOrError

func PointerOrError[T any](s C.StatusOr) (t *T, err error)

PointerOrError converts a StatusOr structure to either a pointer to T with the data or the Status converted to an error message and then freed.

func PresetXlaFlagsCudaDir added in v0.7.0

func PresetXlaFlagsCudaDir()

PresetXlaFlagsCudaDir will update the `XLA_FLAGS` env variable if `xla_gpu_cuda_data_dir` is not set, in an attempt to make the file `libdevice.10.bc` available (it is required by NVIDIA CUDA):

  • If `CUDA_DIR` env variable is set, set that in XLA_FLAGS instead.
  • Try to find in standard (for now only Ubuntu/Debian) locations for CUDA driver files, and if `libdevice.10.bc` is found there, set it in XLA_FLAGS accordingly. The search starts with `./cuda_sdk_lib`, the default used by XLA (see file https://github.com/openxla/xla/blob/main/xla/debug_options_flags.cc).
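
The two steps above can be sketched as a pure function over the current flag string. This is an illustration, not the package's code: `presetCudaDir`, its parameters, and the candidate directories are hypothetical, and the `--flag=value` form is assumed.

```go
package main

import (
	"fmt"
	"strings"
)

const cudaDataDirFlag = "--xla_gpu_cuda_data_dir"

// presetCudaDir returns the value to store in XLA_FLAGS (hypothetical sketch).
// cudaDir plays the role of the CUDA_DIR environment variable, and
// hasLibdevice reports whether a candidate directory contains libdevice.
func presetCudaDir(xlaFlags, cudaDir string, candidates []string, hasLibdevice func(string) bool) string {
	for _, flag := range strings.Fields(xlaFlags) {
		if flag == cudaDataDirFlag || strings.HasPrefix(flag, cudaDataDirFlag+"=") {
			return xlaFlags // Already set by the user: leave untouched.
		}
	}
	dir := cudaDir
	if dir == "" {
		for _, candidate := range candidates {
			if hasLibdevice(candidate) {
				dir = candidate
				break
			}
		}
	}
	if dir == "" {
		return xlaFlags // Nothing found: leave untouched.
	}
	return strings.TrimSpace(xlaFlags + " " + cudaDataDirFlag + "=" + dir)
}

func main() {
	// Assumption: illustrative candidates, with only the second one "containing" libdevice.
	candidates := []string{"./cuda_sdk_lib", "/usr/local/cuda"}
	found := func(dir string) bool { return dir == "/usr/local/cuda" }
	fmt.Println(presetCudaDir("--xla_dump_to=/tmp/dump", "", candidates, found))
}
```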

func RegisterFinalizer

func RegisterFinalizer[T Finalizer](o T)

RegisterFinalizer is a trivial helper function that calls the WrapperWithDestructor.Finalize.

func ScalarFromLiteral

func ScalarFromLiteral[T shapes.Number](l *Literal) (t T, err error)

ScalarFromLiteral returns the scalar stored in a Literal. Returns an error if Literal is not a scalar or if it is the wrong DType.

func ShapeFromCShape

func ShapeFromCShape(cShape *C.Shape) (shape shapes.Shape)

ShapeFromCShape converts a shape provided in C struct (cShape) into a shapes.Shape. cShape memory is NOT freed.

func SizeOf

func SizeOf[T any]() C.size_t

SizeOf returns the size of the given type in bytes. Notice some structures may be padded, and this will include that space.
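
A small sketch of the padding remark, using `unsafe.Sizeof` on a generic type. The helper `sizeOf` and the `padded` struct are hypothetical, and the sketch returns `uintptr` instead of `C.size_t` so it runs without cgo.

```go
package main

import (
	"fmt"
	"unsafe"
)

// sizeOf returns the size in bytes of T, including any padding.
func sizeOf[T any]() uintptr {
	var zero T
	return unsafe.Sizeof(zero)
}

// padded holds 9 bytes of fields, but alignment rules add padding after flag.
type padded struct {
	flag  byte
	value int64
}

func main() {
	// The second number is larger than 1+8 because of padding; the exact
	// value depends on the platform's alignment rules.
	fmt.Println(sizeOf[int64](), sizeOf[padded]())
}
```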

func SliceToVectorData

func SliceToVectorData[T any](slice []T) (vec *C.VectorData)

func StableHLOCurrentVersion

func StableHLOCurrentVersion() string

StableHLOCurrentVersion returns the current version for the StableHLO library.

func StrFree

func StrFree(cstr *C.char) (str string)

StrFree converts the allocated C string (char *) to a Go `string` and frees the C string immediately.

func StrVectorFree

func StrVectorFree(vp *C.VectorPointers) (strs []string)

StrVectorFree converts a C.VectorPointers that presumably contains `char *` to []string. It frees everything: the individual `char *` pointers, the array that contains it (`vp.data`) and finally `vp` itself.

func UnsafePointerOrError

func UnsafePointerOrError(s C.StatusOr) (unsafe.Pointer, error)

UnsafePointerOrError converts a StatusOr structure to either an unsafe.Pointer with the data or the Status converted to an error message and then freed.

func VectorDataToSlice

func VectorDataToSlice[T any](vec *C.VectorData) (data []T)

VectorDataToSlice converts the given VectorData to a slice and then frees the VectorData.

Types

type AOTExecutable

type AOTExecutable struct {
	// contains filtered or unexported fields
}

AOTExecutable executes Ahead-Of-Time (AOT) compiled graphs.

func NewAOTExecutable

func NewAOTExecutable(client *Client, aotResult []byte) (*AOTExecutable, error)

NewAOTExecutable creates an AOTExecutable, given the client and the aotResult returned by an earlier Computation.AOTCompile call. It may return an error.

func (*AOTExecutable) IsNil

func (exec *AOTExecutable) IsNil() bool

IsNil returns whether contents are invalid or have been freed already.

func (*AOTExecutable) Run

func (exec *AOTExecutable) Run(params []*OnDeviceBuffer) (*OnDeviceBuffer, error)

type AutoCPointer

type AutoCPointer[T any] struct {
	P *T
}

AutoCPointer wraps a C pointer, and frees it when AutoCPointer is garbage collected, or when Finalize is called.

func NewAutoCPointer

func NewAutoCPointer[T any](p *T) (w *AutoCPointer[T])

NewAutoCPointer holds a C pointer, and makes sure it is freed (`C.free`) when it is garbage collected or when Finalize is called.

Note that since Go 1.20, forward-declared C types (not completely known) are no longer supported as generic parameters -- even though we only use the pointers to them. See https://groups.google.com/g/golang-nuts/c/h75BwBsz4YA/m/FLBIjgFBBQAJ for details.

func (*AutoCPointer[T]) Finalize

func (w *AutoCPointer[T]) Finalize()

Finalize implements Finalizer.

func (*AutoCPointer[T]) IsNil

func (w *AutoCPointer[T]) IsNil() bool

IsNil returns whether the AutoCPointer is nil or the pointer it holds is nil.

func (*AutoCPointer[T]) UnsafePtr

func (w *AutoCPointer[T]) UnsafePtr() unsafe.Pointer

UnsafePtr returns the pointer to T cast as `unsafe.Pointer`.

type Client

type Client struct {
	Id                                ClientId
	DeviceCount, DefaultDeviceOrdinal int
	// contains filtered or unexported fields
}

Client is a wrapper for the C++ `Client`, itself a wrapper to xla::Client, described as "XLA service's client object -- wraps the service with convenience and lifetime-oriented methods."

It's a required parameter to most high level XLA functionality: compilation and execution of graphs, transfer of data, etc.

func NewClient

func NewClient(platform string, numReplicas, numThreads int) (*Client, error)

NewClient constructs the Client. If platform is empty, use what is returned by GetDefaultPlatform. If numThreads is -1, use the number of available cores. Not thread-safe.

func (*Client) Client

func (c *Client) Client() *Client

Client implements the tensor.HasClient interface, by returning itself.

func (*Client) Finalize

func (c *Client) Finalize()

Finalize implements Finalizer.

func (*Client) IsNil

func (c *Client) IsNil() bool

IsNil returns whether either the Client object is nil or the contained C pointer.

func (*Client) Ok

func (c *Client) Ok() bool

Ok returns whether the client was created successfully and is still valid.

type ClientId

type ClientId int

ClientId is a unique identifier to clients. Starts with 0 and increases.

type Computation

type Computation struct {
	// contains filtered or unexported fields
}

Computation is a wrapper for the C++ `XlaComputation`.

func NewComputation

func NewComputation(client *Client, name string) *Computation

NewComputation creates a new XlaComputation object and returns its wrapper.

func (*Computation) AOTCompile

func (comp *Computation) AOTCompile(paramShapes []shapes.Shape) ([]byte, error)

AOTCompile returns the Ahead-Of-Time compiled version of the graph, that can be used for execution later.

The graph needs to be compiled first. It is AOT-compiled to the same platform it was already compiled for -- TODO: cross-compile.

It returns a binary serialized format that can be executed later, without linking the whole GoMLX machinery. See tutorial on instructions and an example of how to do this.

func (*Computation) AddOp

func (comp *Computation) AddOp(node *SerializedNode) (opsNum int, shape shapes.Shape, err error)

AddOp adds an Op described by node to the compilation graph. Notice after the graph is compiled it becomes frozen and no more ops can be added. Ops are created in order, and

func (*Computation) Compile

func (comp *Computation) Compile(paramShapes []shapes.Shape, output int) error

Compile compiles the Computation, so it's ready for execution. After this no more ops can be added. The output is the index to the output node.

func (*Computation) Finalize

func (comp *Computation) Finalize()

Finalize implements Finalizer.

func (*Computation) IsCompiled

func (comp *Computation) IsCompiled() bool

IsCompiled returns whether this computation has already been compiled.

func (*Computation) IsNil

func (comp *Computation) IsNil() bool

IsNil checks whether the computation is nil or its underlying C object is nil.

func (*Computation) Run

func (comp *Computation) Run(params []*OnDeviceBuffer) (*OnDeviceBuffer, error)

Run runs a computation. The parameter values for the graph are given in params.

func (*Computation) SerializedToC

func (comp *Computation) SerializedToC(node *SerializedNode) *C.SerializedNode

func (*Computation) ToStableHLO

func (comp *Computation) ToStableHLO() (*StableHLO, error)

ToStableHLO returns a holder of the C++ object representing the StableHLO.

type ErrorCode

type ErrorCode int

ErrorCode is used by the underlying TensorFlow/XLA libraries, in Status objects.

const (
	OK                  ErrorCode = 0
	CANCELLED           ErrorCode = 1
	UNKNOWN             ErrorCode = 2
	INVALID_ARGUMENT    ErrorCode = 3
	DEADLINE_EXCEEDED   ErrorCode = 4
	NOT_FOUND           ErrorCode = 5
	ALREADY_EXISTS      ErrorCode = 6
	PERMISSION_DENIED   ErrorCode = 7
	UNAUTHENTICATED     ErrorCode = 16
	RESOURCE_EXHAUSTED  ErrorCode = 8
	FAILED_PRECONDITION ErrorCode = 9
	ABORTED             ErrorCode = 10
	OUT_OF_RANGE        ErrorCode = 11
)

Values copied from tensorflow/core/protobuf/error_codes.proto. TODO: convert the protos definitions to Go and use that instead.

func (ErrorCode) String

func (i ErrorCode) String() string
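
For illustration, a self-contained sketch of what such a String method can look like. The type is redeclared locally with a subset of the codes listed above; the map-based lookup and the fallback format are assumptions, not the package's implementation.

```go
package main

import "fmt"

// ErrorCode is redeclared here only so the sketch is self-contained.
type ErrorCode int

const (
	OK              ErrorCode = 0
	NOT_FOUND       ErrorCode = 5
	UNAUTHENTICATED ErrorCode = 16
)

// errorCodeNames covers only the subset declared above.
var errorCodeNames = map[ErrorCode]string{
	OK:              "OK",
	NOT_FOUND:       "NOT_FOUND",
	UNAUTHENTICATED: "UNAUTHENTICATED",
}

// String returns the symbolic name, or a numeric fallback for unknown codes.
func (c ErrorCode) String() string {
	if name, ok := errorCodeNames[c]; ok {
		return name
	}
	return fmt.Sprintf("ErrorCode(%d)", int(c))
}

func main() {
	fmt.Println(NOT_FOUND, ErrorCode(99))
}
```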

type FftType added in v0.6.0

type FftType int

FftType should be aligned with the constants in `xla_data.proto` file in the XLA code base.

XLA's FFT operator can work in different ways; this defines how.

const (
	// FftForward does a forward FFT: complex in, complex out.
	// FFT in the proto.
	FftForward FftType = iota

	// FftInverse does an inverse FFT: complex in, complex out.
	// IFFT in the proto.
	FftInverse

	// FftForwardReal does a forward real FFT: real in, fft_length / 2 + 1 complex out
	// RFFT in the proto.
	FftForwardReal

	// FftInverseReal does an inverse real FFT: fft_length / 2 + 1 complex in, real out
	// IRFFT in the proto.
	FftInverseReal
)
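
As a worked example of the `fft_length / 2 + 1` sizes mentioned in the comments above, the following sketch computes the number of complex values for a few lengths. The helper name is hypothetical, and integer division is assumed.

```go
package main

import "fmt"

// realFFTOutputLen returns the number of complex values produced by a forward
// real FFT over fftLength real inputs: fft_length / 2 + 1 (integer division).
func realFFTOutputLen(fftLength int) int {
	return fftLength/2 + 1
}

func main() {
	for _, n := range []int{8, 9, 16} {
		fmt.Printf("fft_length=%d -> %d complex values\n", n, realFFTOutputLen(n))
	}
}
```

The same count is the expected input size for the inverse real FFT of that length.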

func (FftType) String added in v0.6.0

func (i FftType) String() string

type Finalizer

type Finalizer interface {
	// Finalize frees the underlying resources, presumably outside
	// Go runtime control, like C pointers.
	//
	// Finalize should be idempotent: if called multiple times, subsequent calls
	// shouldn't affect it.
	Finalize()
}

Finalizer is any object that implements Finalize, that can be called when an object is deallocated using runtime.SetFinalizer.

Finalize should be idempotent: if called multiple times, subsequent calls shouldn't affect it.
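
A minimal sketch of an idempotent Finalize registered with `runtime.SetFinalizer`. The `resource` type and `registerFinalizer` helper are hypothetical, and a Go pointer stands in for a C pointer so the example runs without cgo.

```go
package main

import (
	"fmt"
	"runtime"
)

// Finalizer is redeclared here only so the sketch is self-contained.
type Finalizer interface {
	Finalize()
}

// resource is a hypothetical object owning something outside Go's control.
type resource struct {
	handle   *int // Stand-in for a C pointer.
	released int  // Counts how many times the handle was actually released.
}

// Finalize releases the handle once; later calls are no-ops.
func (r *resource) Finalize() {
	if r == nil || r.handle == nil {
		return
	}
	r.handle = nil
	r.released++
}

// registerFinalizer asks the runtime to call Finalize when o becomes
// unreachable. T must be a pointer type, or SetFinalizer panics.
func registerFinalizer[T Finalizer](o T) {
	runtime.SetFinalizer(o, func(o T) { o.Finalize() })
}

func main() {
	handle := 42
	r := &resource{handle: &handle}
	registerFinalizer(r)
	r.Finalize()
	r.Finalize() // No effect: already finalized.
	fmt.Println(r.released)
}
```

Idempotency matters here because a manual call and the runtime-triggered call can both happen for the same object.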

type GlobalData

type GlobalData struct {
	// contains filtered or unexported fields
}

GlobalData represents a value stored in the remote accelerator (or wherever the computation is executed). Use ToLiteral to bring it to the local CPU.

Note: GoMLX started using a different execution pathway that uses ShapedBuffer instead, and GlobalData is no longer being used right now. But I'm not sure if it will be used later when adding support for multiple accelerator devices.

func (*GlobalData) Client

func (gd *GlobalData) Client() *Client

func (*GlobalData) Finalize

func (gd *GlobalData) Finalize()

Finalize implements Finalizer.

func (*GlobalData) IsNil

func (gd *GlobalData) IsNil() bool

IsNil returns whether either the GlobalData object is nil or the contained C pointer.

func (*GlobalData) Shape

func (gd *GlobalData) Shape() (shape shapes.Shape, err error)

Shape retrieves the shape of data stored globally.

func (*GlobalData) SplitTuple

func (gd *GlobalData) SplitTuple() (gds []*GlobalData, err error)

SplitTuple splits the tuple GlobalData into its various components. This invalidates the current global data.

func (*GlobalData) ToLiteral

func (gd *GlobalData) ToLiteral() (*Literal, error)

ToLiteral transfers the data stored in the accelerator server to the CPU. Returns a Literal with the transferred data.

type Literal

type Literal struct {
	// contains filtered or unexported fields
}

Literal represents a value stored in the local CPU, in C++ heap, to interact with XLA.

func FromOnDeviceBuffer

func FromOnDeviceBuffer(buffer *OnDeviceBuffer) (*Literal, error)

func NewLiteralFromShape

func NewLiteralFromShape(shape shapes.Shape) *Literal

NewLiteralFromShape returns a new Literal with the given shape and uninitialized data.

func NewLiteralTuple

func NewLiteralTuple(literals []*Literal) *Literal

NewLiteralTuple creates a tuple from the individual literals. Flat is copied.

func NewLiteralWithZeros

func NewLiteralWithZeros(shape shapes.Shape) *Literal

func (*Literal) Bytes

func (l *Literal) Bytes() []byte

Bytes returns the same memory as Data, but the raw slice of bytes, with the proper size.

func (*Literal) Data

func (l *Literal) Data() any

Data returns a slice of the data for the corresponding DType (see package types). The underlying data is not owned/managed by Go, but it is mutable.

Notice if the Literal goes out-of-scope and is finalized, the underlying data gets freed without Go knowing about it, and the returned slice may become invalid. Be careful to make sure the Literal doesn't go out of scope --

func (*Literal) Finalize

func (l *Literal) Finalize()

Finalize implements Finalizer.

func (*Literal) IsNil

func (l *Literal) IsNil() bool

IsNil returns whether either the Literal object is nil or the contained C pointer.

func (*Literal) Refresh added in v0.2.1

func (l *Literal) Refresh()

Refresh will re-read the data pointer and size information from the C++ interface. Mostly used for debugging.

func (*Literal) Shape

func (l *Literal) Shape() (shape shapes.Shape)

Shape returns the shape of the literal.

func (*Literal) SplitTuple

func (l *Literal) SplitTuple() []*Literal

SplitTuple splits the tuple Literal into multiple literals. Unlike with GlobalData, this destroys the underlying Literal.

func (*Literal) ToGlobalData

func (l *Literal) ToGlobalData(client *Client) (*GlobalData, error)

ToGlobalData transfers the literal to the global accelerator server. GlobalData is no longer used in our execution path, so for now this is deprecated.

func (*Literal) ToOnDeviceBuffer

func (l *Literal) ToOnDeviceBuffer(client *Client, deviceOrdinal int) (*OnDeviceBuffer, error)

ToOnDeviceBuffer returns an OnDeviceBuffer allocated on the device (the number of the device is given by deviceOrdinal).

type NodeType

type NodeType int32

NodeType enumerates the various types of Nodes that can be converted to XLA.

const (
	InvalidNode NodeType = iota

	ConstantNode
	IotaNode
	ParameterNode
	ConvertTypeNode
	WhereNode
	TupleNode
	GetTupleElementNode
	ReshapeNode
	BroadcastNode
	BroadcastInDimNode
	ReduceSumNode
	ReduceMaxNode
	ReduceMultiplyNode
	SliceNode
	PadNode
	GatherNode
	ScatterNode
	ConcatenateNode
	ConvGeneralDilatedNode
	ReverseNode
	TransposeNode
	ReduceWindowNode
	SelectAndScatterNode
	BatchNormTrainingNode
	BatchNormInferenceNode
	BatchNormGradNode
	DotGeneralNode
	ArgMinMaxNode
	FftNode

	AbsNode
	NegNode
	ExpNode
	Expm1Node
	FloorNode
	CeilNode
	RoundNode
	LogNode
	Log1pNode
	LogicalNotNode
	LogisticNode
	SignNode
	ClzNode
	CosNode
	SinNode
	TanhNode
	SqrtNode
	RsqrtNode
	ImagNode
	RealNode
	ConjNode

	AddNode
	MulNode
	SubNode
	DivNode
	RemNode // Notice XLA implements Mod, not IEEE754 Remainder operation.
	AndNode
	OrNode
	XorNode
	DotNode
	MinNode
	MaxNode
	PowNode
	ComplexNode

	EqualNode
	NotEqualNode
	GreaterOrEqualNode
	GreaterThanNode
	LessOrEqualNode
	LessThanNode
	EqualTotalOrderNode
	NotEqualTotalOrderNode
	GreaterOrEqualTotalOrderNode
	GreaterThanTotalOrderNode
	LessOrEqualTotalOrderNode
	LessThanTotalOrderNode

	RngBitGeneratorNode
	RngNormalNode
	RngUniformNode
)

NodeType values need to be exactly the same as defined in the C++ code, in `c/gomlx/node.h`. TODO: keep those in sync using some generator script.

func (NodeType) String

func (i NodeType) String() string

type OnDeviceBuffer

type OnDeviceBuffer struct {
	// contains filtered or unexported fields
}

OnDeviceBuffer represents a value stored in an execution device (CPU or accelerator).

Notice there can be multiple accelerator devices: DeviceOrdinal will inform which device this is stored in.

func (*OnDeviceBuffer) Client

func (b *OnDeviceBuffer) Client() *Client

func (*OnDeviceBuffer) DeviceOrdinal

func (b *OnDeviceBuffer) DeviceOrdinal() int

DeviceOrdinal returns the ordinal number of the device -- an id in case there are several replicas of the device (? XLA documentation is not clear about this, just guessing).

func (*OnDeviceBuffer) Finalize

func (b *OnDeviceBuffer) Finalize()

Finalize implements Finalizer.

func (*OnDeviceBuffer) IsNil

func (b *OnDeviceBuffer) IsNil() bool

IsNil returns whether OnDeviceBuffer holds no data.

func (*OnDeviceBuffer) Shape

func (b *OnDeviceBuffer) Shape() shapes.Shape

func (*OnDeviceBuffer) String

func (b *OnDeviceBuffer) String() string

func (*OnDeviceBuffer) SubTree

func (b *OnDeviceBuffer) SubTree(path []int) (*OnDeviceBuffer, error)

SubTree retrieves an element from a nested tuple (tree) OnDeviceBuffer.

type RandomAlgorithm added in v0.4.0

type RandomAlgorithm int

RandomAlgorithm should be aligned with constants in `xla_data.proto` file in the XLA code base.

Each random algorithm entails a different shape of `initialState` that needs to be fed to `RngBitGenerator`.

See details and reference of the algorithms: https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator

Unfortunately there is no documented way of figuring out what is the initial state for `RngDefault` algorithm.

const (
	// RngDefault is the back-end specific algorithm with back-end specific shape requirements.
	// There doesn't seem to be any automatic way of figuring out what it takes for `initialState`.
	RngDefault RandomAlgorithm = iota

	// RngThreeFry counter-based PRNG algorithm. The initial_state shape is U64[2] with arbitrary values.
	// [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf).
	RngThreeFry

	// RngPhilox algorithm to generate random numbers in parallel. The initial_state shape is `U64[3]` with arbitrary values.
	// [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf).
	RngPhilox
)
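
The shape requirements in the comments above can be captured in a small lookup. This is a sketch with locally redeclared constants; `initialStateLen` is a hypothetical helper, and it reports "unknown" for `RngDefault` since its state shape is backend specific.

```go
package main

import "fmt"

// RandomAlgorithm and its constants are redeclared so the sketch is self-contained.
type RandomAlgorithm int

const (
	RngDefault RandomAlgorithm = iota
	RngThreeFry
	RngPhilox
)

// initialStateLen returns how many uint64 values initialState needs for the
// given algorithm, and false when the requirement is not known.
func initialStateLen(alg RandomAlgorithm) (n int, ok bool) {
	switch alg {
	case RngThreeFry:
		return 2, true // U64[2]
	case RngPhilox:
		return 3, true // U64[3]
	default:
		return 0, false // RngDefault: backend specific.
	}
}

func main() {
	for _, alg := range []RandomAlgorithm{RngDefault, RngThreeFry, RngPhilox} {
		n, ok := initialStateLen(alg)
		fmt.Println(int(alg), n, ok)
	}
}
```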

func (RandomAlgorithm) String added in v0.4.0

func (i RandomAlgorithm) String() string

type SerializedNode

type SerializedNode struct {
	Type       NodeType
	NodeInputs []int32  // Index to other nodes that are used as inputs.
	Literal    *Literal // If a Literal (constant) is involved in the operation.

	Int   int          // Used for any static integer inputs.
	Shape shapes.Shape // If a shape is used as a static input.
	Str   string       // Used for any static string argument.
	Ints  []int        // List of integer numbers.
	Float float32      // For a float parameter.
}

SerializedNode represents a graph.Node with its arguments. This can then be passed to the XLA C++ library. Only used by ops implementers.

type StableHLO

type StableHLO struct {
	// contains filtered or unexported fields
}

StableHLO is a wrapper for the C++ `StableHLOHolder`.

func NewStableHLO

func NewStableHLO(cPtr *C.StableHLOHolder) *StableHLO

NewStableHLO creates the wrapper.

func NewStableHLOFromSerialized

func NewStableHLOFromSerialized(serialized []byte) (*StableHLO, error)

func (*StableHLO) Finalize

func (shlo *StableHLO) Finalize()

Finalize implements Finalizer.

func (*StableHLO) IsNil

func (shlo *StableHLO) IsNil() bool

IsNil checks whether the StableHLO is nil or its underlying C object is nil.

func (*StableHLO) Serialize

func (shlo *StableHLO) Serialize(filePath string) error

Serialize to bytecode that can presumably be used by PjRT and IREE, as well as embedded in one of the TensorFlow SavedModel formats.(??)

It serializes to the given file path.

func (*StableHLO) SerializeWithVersion added in v0.7.2

func (shlo *StableHLO) SerializeWithVersion(fileDescriptor uintptr, version string) error

SerializeWithVersion to bytecode that can presumably be used by PjRT and IREE, as well as embedded in one of the TensorFlow SavedModel formats.(??)

It serializes to the given file descriptor, presumably a file opened for writing.

For version, the usual is to use the value returned by StableHLOCurrentVersion.

func (*StableHLO) String

func (shlo *StableHLO) String() string

String generates a human-readable version of the StableHLO.

type Status

type Status struct {
	// contains filtered or unexported fields
}

Status is a wrapper for `xla::Status`.

func NewStatus

func NewStatus(unsafeStatus unsafe.Pointer) *Status

NewStatus creates a Status object that owns the underlying C.XlaStatus.

func (*Status) Code

func (s *Status) Code() ErrorCode

func (*Status) Error

func (s *Status) Error() error

func (*Status) ErrorMessage

func (s *Status) ErrorMessage() string

func (*Status) Finalize

func (s *Status) Finalize()

Finalize implements Finalizer.

func (*Status) IsNil

func (s *Status) IsNil() bool

IsNil returns whether either the Status object is nil or the contained C pointer.

func (*Status) Ok

func (s *Status) Ok() bool

func (*Status) UnsafeCPtr

func (s *Status) UnsafeCPtr() unsafe.Pointer

UnsafeCPtr returns the underlying C pointer converted to unsafe.Pointer.

type UnsafeCPointer

type UnsafeCPointer struct {
	P unsafe.Pointer
}

UnsafeCPointer holds a C pointer, and makes sure it is freed (`C.free`) when it is garbage collected or when Finalize is called.

Since Go 1.20, forward-declared C types (not completely known) are no longer supported as generic parameters, so those pointers need to be kept as unsafe pointers and cast manually.

func NewUnsafeCPointer

func NewUnsafeCPointer(p unsafe.Pointer) *UnsafeCPointer

NewUnsafeCPointer holds a forward declared C pointer as `unsafe.Pointer`, and makes sure it is freed (`C.free`) when it is garbage collected or when WrapperWithDestructor.Finalize is called.

Note that since Go 1.20, forward-declared C types (not completely known) are no longer supported as type parameters -- even though we only use the pointers to them. See https://groups.google.com/g/golang-nuts/c/h75BwBsz4YA/m/FLBIjgFBBQAJ for details.

func (*UnsafeCPointer) Finalize

func (w *UnsafeCPointer) Finalize()

Finalize implements Finalizer.

func (*UnsafeCPointer) IsNil

func (w *UnsafeCPointer) IsNil() bool

IsNil returns whether the UnsafeCPointer is nil or the pointer it holds is nil.

func (*UnsafeCPointer) UnsafePtr

func (w *UnsafeCPointer) UnsafePtr() unsafe.Pointer

UnsafePtr returns the held pointer as `unsafe.Pointer`.

type WrapperWithDestructor

type WrapperWithDestructor[T any] struct {
	Data T
	// contains filtered or unexported fields
}

WrapperWithDestructor wraps a pointer to an arbitrary type and adds a finalize method.

func NewWrapperWithDestructor

func NewWrapperWithDestructor[T any](data T, destructor func(data T)) (w *WrapperWithDestructor[T])

NewWrapperWithDestructor creates a WrapperWithDestructor to type T, using the given destructor to finalize the object.

The `destructor` will only be called once, even if `Finalize()` is called manually. The wrapper sets the destructor to nil after the first time `Finalize()` is called.

There are no synchronization mechanisms: manually calling `Finalize()` concurrently is undefined.

func (*WrapperWithDestructor[T]) Empty

func (w *WrapperWithDestructor[T]) Empty() bool

Empty returns true if either w is nil, or its contents have already been finalized.

func (*WrapperWithDestructor[T]) Finalize

func (w *WrapperWithDestructor[T]) Finalize()

Finalize frees the pointer held. Notice this version is not concurrency safe -- but then Finalize should be called only once anyway.

Once Finalize is called, the object cannot be re-used, it will be forever marked as empty.
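
A simplified sketch of the destructor-once behavior described for this type. The lowercase `wrapper` and `newWrapper` are hypothetical stand-ins; unlike what the real constructor presumably does, this sketch does not register a runtime finalizer, and like the original it is not safe for concurrent use.

```go
package main

import "fmt"

// wrapper is a hypothetical, simplified stand-in for WrapperWithDestructor.
type wrapper[T any] struct {
	Data       T
	destructor func(data T)
}

func newWrapper[T any](data T, destructor func(data T)) *wrapper[T] {
	return &wrapper[T]{Data: data, destructor: destructor}
}

// Empty reports whether w is nil or has already been finalized.
func (w *wrapper[T]) Empty() bool {
	return w == nil || w.destructor == nil
}

// Finalize runs the destructor at most once, then marks the wrapper as empty.
func (w *wrapper[T]) Finalize() {
	if w.Empty() {
		return
	}
	destructor := w.destructor
	w.destructor = nil // Cleared before the call, so repeated calls are no-ops.
	destructor(w.Data)
}

func main() {
	calls := 0
	w := newWrapper("some resource", func(string) { calls++ })
	w.Finalize()
	w.Finalize()
	fmt.Println(calls, w.Empty())
}
```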
