Documentation
Overview ¶
Package pjrt implements a Go wrapper for the PJRT_C_API.
Index ¶
- Constants
- func AlignedAlloc(size, alignment uintptr) unsafe.Pointer
- func AlignedFree(ptr unsafe.Pointer)
- func AvailablePlugins() (pluginsPaths map[string]string)
- func BufferToArray[T dtypes.Supported](buffer *Buffer) (flatValues []T, dimensions []int, err error)
- func BufferToScalar[T dtypes.Supported](b *Buffer) (value T, err error)
- func BuffersAlive() int64
- func LoadedExecutablesAlive() int64
- func RegisterPreloadedPlugin(name string, api uintptr) error
- func ScalarToRaw[T dtypes.Supported](value T) ([]byte, dtypes.DType, []int)
- func SuppressAbseilLoggingHack(fn func())
- type Buffer
- func ArrayToBuffer[T dtypes.Supported](client *Client, flatValues []T, dimensions ...int) (b *Buffer, err error)
- func ScalarToBuffer[T dtypes.Supported](client *Client, value T) (b *Buffer, err error)
- func ScalarToBufferOnDeviceNum[T dtypes.Supported](client *Client, deviceNum int, value T) (b *Buffer, err error)
- func (b *Buffer) Client() *Client
- func (b *Buffer) DType() (dtype dtypes.DType, err error)
- func (b *Buffer) Data() (flat any, err error)
- func (b *Buffer) Destroy() error
- func (b *Buffer) Device() (device *Device, err error)
- func (b *Buffer) Dimensions() (dims []int, err error)
- func (b *Buffer) IsShared() bool
- func (b *Buffer) Size() (int, error)
- func (b *Buffer) ToFlatDataAndDimensions() (flat any, dimensions []int, err error)
- func (b *Buffer) ToHost(dst []byte) error
- func (b *Buffer) UnsafePointer() (unsafe.Pointer, error)
- type BufferFromHostConfig
- func (b *BufferFromHostConfig) Done() (*Buffer, error)
- func (b *BufferFromHostConfig) FromFlatDataWithDimensions(flat any, dimensions []int) *BufferFromHostConfig
- func (b *BufferFromHostConfig) FromRawData(data []byte, dtype dtypes.DType, dimensions []int) *BufferFromHostConfig
- func (b *BufferFromHostConfig) ToDevice(device *Device) *BufferFromHostConfig
- func (b *BufferFromHostConfig) ToDeviceNum(deviceNum int) *BufferFromHostConfig
- type Client
- func (c *Client) AddressableDevices() []*Device
- func (c *Client) BufferFromHost() *BufferFromHostConfig
- func (c *Client) Compile() *CompileConfig
- func (c *Client) CreateViewOfDeviceBuffer(rawData unsafe.Pointer, dtype dtypes.DType, dimensions []int, ...) (*Buffer, error)
- func (c *Client) Destroy() error
- func (c *Client) Devices() ([]*Device, error)
- func (c *Client) NewSharedBuffer(dtype dtypes.DType, dimensions []int, device ...*Device) (buffer *Buffer, flat any, err error)
- func (c *Client) NumForDevice(device *Device) int
- func (c *Client) Platform() string
- func (c *Client) PlatformVersion() string
- func (c *Client) Plugin() *Plugin
- func (c *Client) ProcessIndex() int
- func (c *Client) String() string
- type CompileConfig
- type Device
- type DeviceDescription
- type Event
- type Executable
- type ExecutableMemoryUsageStats
- type ExecutionConfig
- func (c *ExecutionConfig) Donate(inputsIndices ...int) *ExecutionConfig
- func (c *ExecutionConfig) DonateAll() *ExecutionConfig
- func (c *ExecutionConfig) DonateNone() *ExecutionConfig
- func (c *ExecutionConfig) Done() ([]*Buffer, error)
- func (c *ExecutionConfig) OnDevices(devices ...*Device) *ExecutionConfig
- func (c *ExecutionConfig) OnDevicesByNum(devicesNum ...int) *ExecutionConfig
- func (c *ExecutionConfig) SetDonate(donate []bool) *ExecutionConfig
- type LoadedExecutable
- type NamedValuesMap
- type PJRT_Buffer_MemoryLayout_Type
- type PJRT_Error_Code
- type PJRT_Extension_Type
- type PJRT_HostBufferSemantics
- type PJRT_NamedValue_Type
- type Plugin
- type XlaComputation
Constants ¶
const (
	// PJRTPluginPathsEnv is the name of the environment variable that defines the search paths for plugins.
	PJRTPluginPathsEnv = "PJRT_PLUGIN_LIBRARY_PATH"

	// GetPJRTApiFunctionName is the name of the function exported by PJRT plugins that returns the API.
	GetPJRTApiFunctionName = "GetPjrtApi"
)
const BufferAlignment = 64
BufferAlignment is the default alignment required for memory shared with CPU PJRT. See AlignedAlloc and AlignedFree.
const EnvXlaDebugOptions = "XLA_DEBUG_OPTIONS"
EnvXlaDebugOptions is an environment variable that can be defined to set XLA DebugOptions proto when compiling a program.
Variables ¶
This section is empty.
Functions ¶
func AlignedAlloc ¶ added in v0.5.0
AlignedAlloc assumes that malloc/calloc already aligns to 8 bytes, and that alignment is a multiple of 8. The pointer returned must be freed with AlignedFree.
The allocation is filled with 0s.
func AlignedFree ¶ added in v0.5.0
AlignedFree frees an allocation created with AlignedAlloc.
func AvailablePlugins ¶
AvailablePlugins searches for available plugins in the standard directories and returns a map from their name to their paths.
Plugins are searched in the directory given by PJRT_PLUGIN_LIBRARY_PATH -- or directories, if it is a ":" separated list. If it is not set, the search covers "/usr/local/lib/gomlx/pjrt" and then the standard library directories of the system, in that order: on Linux, LD_LIBRARY_PATH and the /etc/ld.so.conf file; on Darwin, DYLD_LIBRARY_PATH as well.
If there are plugins with the same name but different versions in different directories, it respects the order of the directories given by PJRT_PLUGIN_LIBRARY_PATH or by the system.
func BufferToArray ¶
func BufferToArray[T dtypes.Supported](buffer *Buffer) (flatValues []T, dimensions []int, err error)
BufferToArray transfers the buffer to an array defined by a slice with its flat values, and also returns its underlying dimensions.
func BufferToScalar ¶
BufferToScalar is a generic function that transfers a Buffer back to host as a scalar of the given type.
func BuffersAlive ¶ added in v0.5.0
func BuffersAlive() int64
BuffersAlive returns the number of PJRT Buffers in memory and currently tracked by gopjrt.
func LoadedExecutablesAlive ¶ added in v0.5.0
func LoadedExecutablesAlive() int64
LoadedExecutablesAlive returns the number of LoadedExecutables currently in memory and tracked by gopjrt.
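Counters such as BuffersAlive and LoadedExecutablesAlive are handy for leak checks in tests. How gopjrt maintains them is not documented here; the sketch below only shows one common way to keep such a count, with a hypothetical tracked type and an atomic counter.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// alive counts tracked objects that were created and not yet destroyed.
var alive atomic.Int64

// tracked is a hypothetical stand-in for a resource such as a buffer.
type tracked struct{ destroyed bool }

func newTracked() *tracked {
	alive.Add(1)
	return &tracked{}
}

// Destroy decrements the counter once, even if it is called repeatedly.
func (t *tracked) Destroy() {
	if !t.destroyed {
		t.destroyed = true
		alive.Add(-1)
	}
}

func trackedAlive() int64 { return alive.Load() }

func main() {
	a, b := newTracked(), newTracked()
	a.Destroy()
	fmt.Println(trackedAlive()) // prints "1"
	b.Destroy()
	fmt.Println(trackedAlive()) // prints "0"
}
```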
func RegisterPreloadedPlugin ¶ added in v0.4.9
RegisterPreloadedPlugin can be used to register a PJRT plugin that has been pre-linked (dynamically or statically) with the binary -- as opposed to the usual loadPlugin using `dlopen` after the program has started.
It takes as input the name to be associated with the plugin and an unsafe pointer (uintptr) to the API table returned by the plugin's C.GetPjrtApi().
See sub-packages `cpu/static` and `cpu/dynamic` for examples of usage.
func ScalarToRaw ¶
ScalarToRaw generates the raw values needed by BufferFromHostConfig.FromRawData to feed a simple scalar value.
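As an illustration of what "raw values" means for a scalar, the sketch below encodes a float32 as bytes with empty dimensions. It is a hypothetical helper, not ScalarToRaw itself: it assumes a little-endian host and omits the dtype return value.

```go
package main

import (
	"encoding/binary"
	"fmt"
	"math"
)

// float32ToRaw returns the 4 bytes of value in little-endian order plus an
// empty dimensions slice, which denotes a scalar.
// Hypothetical helper: the byte order is an assumption of this sketch.
func float32ToRaw(value float32) (raw []byte, dimensions []int) {
	raw = make([]byte, 4)
	binary.LittleEndian.PutUint32(raw, math.Float32bits(value))
	return raw, []int{}
}

func main() {
	raw, dims := float32ToRaw(1.0)
	fmt.Println(raw, len(dims)) // prints "[0 0 128 63] 0"
}
```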
func SuppressAbseilLoggingHack ¶
func SuppressAbseilLoggingHack(fn func())
SuppressAbseilLoggingHack prevents some irrelevant logging from PJRT plugins, by duplicating the file descriptor (fd) 2, reassigning the new fd to Go's os.Stderr, and then closing fd 2, so PJRT plugins won't be able to write anything.
Usually this is only needed during creation of the Client of the CPU plugin. So you can just wrap that part.
Now since many things rely on fd 2 being stderr, it only does that, executes fn given and reverts the change.
The issue with doing this permanently is that Go's default panic handler outputs the stack trace to fd 2, and this would suppress that as well.
It's an overkill, because this may prevent valid logging, in some truly exceptional situation, but it's the only solution I can think of for now. See discussion in https://github.com/abseil/abseil-cpp/discussions/1700
Since file descriptors are a global resource, this function is not reentrant, and you should make sure no two goroutines are calling this at the same time.
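Because the function is not reentrant, callers that may reach it from several goroutines can serialize access themselves. A minimal sketch, assuming a package-level mutex is acceptable; serialized and stderrHackMu are hypothetical names, and the hack parameter stands in for SuppressAbseilLoggingHack so the example runs without a plugin.

```go
package main

import (
	"fmt"
	"sync"
)

// stderrHackMu guards every call that temporarily rewires fd 2.
var stderrHackMu sync.Mutex

// serialized runs hack(fn) while holding the mutex, so two goroutines never
// manipulate the process-wide file descriptors at the same time.
// In real code hack would be pjrt.SuppressAbseilLoggingHack.
func serialized(hack func(fn func()), fn func()) {
	stderrHackMu.Lock()
	defer stderrHackMu.Unlock()
	hack(fn)
}

func main() {
	fakeHack := func(fn func()) { fn() } // placeholder that just runs fn
	var wg sync.WaitGroup
	count := 0
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			serialized(fakeHack, func() { count++ }) // safe: guarded by the mutex
		}()
	}
	wg.Wait()
	fmt.Println(count) // prints "4"
}
```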
Types ¶
type Buffer ¶
type Buffer struct {
// contains filtered or unexported fields
}
Buffer is a reference to an array storage (buffer) on device.
func ArrayToBuffer ¶
func ArrayToBuffer[T dtypes.Supported](client *Client, flatValues []T, dimensions ...int) (b *Buffer, err error)
ArrayToBuffer transfers a slice to a Buffer on the default device. The underlying array is provided with its flat values as a slice, and the underlying dimensions.
It is a shortcut to a Client.BufferFromHost call with default parameters. If you need more control over where the value will be used, use Client.BufferFromHost instead.
func ScalarToBuffer ¶
ScalarToBuffer transfers the scalar value to a Buffer on the default device.
It is a shortcut to a Client.BufferFromHost call with default parameters. If you need more control over where the value will be used, use Client.BufferFromHost instead.
func ScalarToBufferOnDeviceNum ¶ added in v0.2.0
func ScalarToBufferOnDeviceNum[T dtypes.Supported](client *Client, deviceNum int, value T) (b *Buffer, err error)
ScalarToBufferOnDeviceNum transfers the scalar value to a Buffer on the given device.
It is a shortcut to a Client.BufferFromHost call with default parameters. If you need more control over where the value will be used, use Client.BufferFromHost instead.
func (*Buffer) Data ¶ added in v0.5.0
Data returns the flat slice pointing to the underlying storage data for the buffer.
This is an undocumented feature of PJRT and likely only works for CPU platforms. The flat slice returned is only valid while the buffer is alive.
func (*Buffer) Destroy ¶
Destroy the Buffer, release resources, and Buffer is no longer valid. This is automatically called if Buffer is garbage collected.
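The mechanism behind the automatic call is not stated here. One standard way to get this behavior in Go is runtime.SetFinalizer combined with an idempotent Destroy; the sketch below shows that pattern with a hypothetical resource type and makes no claim about how gopjrt implements it.

```go
package main

import (
	"fmt"
	"runtime"
)

// resource is a hypothetical stand-in for an object that wraps native memory.
type resource struct {
	released bool
}

// newResource registers a finalizer so Destroy runs if the object is
// garbage collected without an explicit call.
func newResource() *resource {
	r := &resource{}
	runtime.SetFinalizer(r, func(r *resource) { _ = r.Destroy() })
	return r
}

// Destroy releases the resource. It is safe to call more than once.
func (r *resource) Destroy() error {
	if r == nil || r.released {
		return nil
	}
	r.released = true
	runtime.SetFinalizer(r, nil) // explicit destroy makes the finalizer unnecessary
	return nil
}

func main() {
	r := newResource()
	err1 := r.Destroy()
	err2 := r.Destroy()
	fmt.Println(err1, err2, r.released) // prints "<nil> <nil> true"
}
```

Calling Destroy explicitly remains preferable for large device buffers, since the garbage collector does not see native memory pressure.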
func (*Buffer) Dimensions ¶
Dimensions of the Buffer. Returned slice is owned by the buffer, to avoid creating a copy. Don't change it.
func (*Buffer) IsShared ¶ added in v0.5.0
IsShared returns whether this buffer was created with Client.NewSharedBuffer. These buffers cannot be donated in execution.
func (*Buffer) Size ¶
Size returns the size in bytes required for the buffer to be transferred with ToHost.
func (*Buffer) ToFlatDataAndDimensions ¶ added in v0.2.0
ToFlatDataAndDimensions transfers the buffer to a flat slice and returns also its underlying dimensions.
Similar to the generic BufferToArray[T], but this returns an anonymous typed (`any`) flat slice instead of using generics.
func (*Buffer) ToHost ¶
ToHost transfers the contents of buffer stored on device to the host. The space in dst has to hold enough space (see Buffer.Size) to hold the required data, or an error is returned.
This always requests a major-to-minor layout, the assumption of the layout in host memory -- TPUs are known to reorganize the layout.
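In a major-to-minor (row-major) layout the last axis varies fastest. A small sketch of how a multi-dimensional index maps to an offset in the flat host data; rowMajorIndex is a hypothetical helper, not part of the package.

```go
package main

import "fmt"

// rowMajorIndex returns the position in the flat slice of the element at
// indices, for an array with the given dimensions in major-to-minor order.
// It assumes len(indices) == len(dimensions) and in-range indices.
func rowMajorIndex(dimensions, indices []int) int {
	flat := 0
	for axis, dim := range dimensions {
		flat = flat*dim + indices[axis]
	}
	return flat
}

func main() {
	// Element [1][2][3] of a 2x3x4 array: (1*3+2)*4+3 = 23.
	fmt.Println(rowMajorIndex([]int{2, 3, 4}, []int{1, 2, 3})) // prints "23"
}
```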
func (*Buffer) UnsafePointer ¶ added in v0.5.0
UnsafePointer returns platform-dependent address for the given buffer that is often but not guaranteed to be the physical/device address. Consider using the more convenient DirectAccess.
Probably, this should only be used by CPU plugins.
To be on the safe side, only use this if Client.HasSharedBuffers is true. It uses the undocumented PJRT_Buffer_UnsafePointer.
type BufferFromHostConfig ¶
type BufferFromHostConfig struct {
// contains filtered or unexported fields
}
BufferFromHostConfig is used to configure the transfer of a buffer from host memory to on-device memory. It is created with Client.BufferFromHost.
The data to transfer from host can be set up with one of the following methods:
- FromRawData: it takes as inputs the bytes and shape (dtype and dimensions).
- FromFlatDataWithDimensions: it takes as inputs a flat slice and the dimensions; the dtype follows from the element type of the slice.
The device defaults to 0, but it can be configured with BufferFromHostConfig.ToDevice or BufferFromHostConfig.ToDeviceNum.
At the end call BufferFromHostConfig.Done to actually initiate the transfer.
TODO: Implement async transfers, arbitrary memory layout, etc.
func (*BufferFromHostConfig) Done ¶
func (b *BufferFromHostConfig) Done() (*Buffer, error)
Done will use the configuration to start the transfer from host to device. It's synchronous: it waits for the transfer to finish and then returns.
func (*BufferFromHostConfig) FromFlatDataWithDimensions ¶ added in v0.2.0
func (b *BufferFromHostConfig) FromFlatDataWithDimensions(flat any, dimensions []int) *BufferFromHostConfig
FromFlatDataWithDimensions configures the data to come from a flat slice of the desired data type, and the underlying dimensions. The flat slice size must match the product of the dimensions. If no dimensions are given, it is assumed to be a scalar, and flat should have length 1.
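The size rule can be written down as a short check. checkFlatSize is a hypothetical helper that mirrors the rule as stated above; the library's own validation and error messages may differ.

```go
package main

import "fmt"

// checkFlatSize verifies that a flat slice of length flatLen can hold an array
// with the given dimensions. With no dimensions the product is 1 (a scalar).
func checkFlatSize(flatLen int, dimensions []int) error {
	want := 1
	for _, d := range dimensions {
		want *= d
	}
	if flatLen != want {
		return fmt.Errorf("flat slice has %d elements, dimensions %v require %d", flatLen, dimensions, want)
	}
	return nil
}

func main() {
	fmt.Println(checkFlatSize(6, []int{2, 3})) // prints "<nil>"
	fmt.Println(checkFlatSize(1, nil))         // prints "<nil>"
	fmt.Println(checkFlatSize(5, []int{2, 3}) != nil) // prints "true"
}
```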
func (*BufferFromHostConfig) FromRawData ¶
func (b *BufferFromHostConfig) FromRawData(data []byte, dtype dtypes.DType, dimensions []int) *BufferFromHostConfig
FromRawData configures the data from host to copy: a pointer to bytes that must be kept alive (and constant) during the call. The parameters dtype and dimensions provide the shape of the array.
func (*BufferFromHostConfig) ToDevice ¶
func (b *BufferFromHostConfig) ToDevice(device *Device) *BufferFromHostConfig
ToDevice configures which device to copy the host data to.
If left un-configured, it will pick the first device returned by Client.AddressableDevices.
You can also select a device by its index in Client.AddressableDevices, see ToDeviceNum.
func (*BufferFromHostConfig) ToDeviceNum ¶ added in v0.2.0
func (b *BufferFromHostConfig) ToDeviceNum(deviceNum int) *BufferFromHostConfig
ToDeviceNum configures which device to copy the host data to, given a deviceNum pointing to the device in the list returned by Client.AddressableDevices.
If left un-configured, it will pick the first device returned by Client.AddressableDevices.
You can also provide a device by their index in Client.AddressableDevices.
type Client ¶
type Client struct {
// contains filtered or unexported fields
}
Client manages the resources of one device: its buffers, compilation and execution of HLO code.
func (*Client) AddressableDevices ¶
AddressableDevices returns a list of devices addressable to the client. Addressable devices are those that the client can issue commands to. All devices are addressable in a single-process environment (Client.ProcessIndex() == 0).
The returned slice and the Devices are owned by the Client; don't modify them.
func (*Client) BufferFromHost ¶
func (c *Client) BufferFromHost() *BufferFromHostConfig
BufferFromHost creates an on-device buffer with the contents copied (optionally reused, if device is CPU) from the given host buffer.
It returns a BufferFromHostConfig that must be further configured -- at least the host data to transfer must be given. Call BufferFromHostConfig.Done to trigger the transfer.
func (*Client) Compile ¶
func (c *Client) Compile() *CompileConfig
Compile turns a StableHLO program into a "LoadedExecutable" that is the executable runner.
There are different formats of input, and many different compilation options [1], so this returns a CompileConfig that must be further configured. At the very least the program must be given: see CompileConfig.WithComputation or CompileConfig.WithHLO. Then the call to CompileConfig.Done triggers the compilation into a "LoadedExecutable".
[1] The original compilation options are defined in the proto CompileOptionsProto: https://github.com/openxla/xla/blob/main/xla/pjrt/compile_options.proto . But the proto itself is not documented; instead see the documentation in the C++ xla::CompileOptions class defined in: https://github.com/openxla/xla/blob/main/xla/pjrt/pjrt_executable.h .
func (*Client) CreateViewOfDeviceBuffer ¶ added in v0.5.0
func (c *Client) CreateViewOfDeviceBuffer(rawData unsafe.Pointer, dtype dtypes.DType, dimensions []int, device ...*Device) (*Buffer, error)
CreateViewOfDeviceBuffer creates a PJRT Buffer that is backed by storage on the same device, given by the caller as rawData and shape. Consider using the simpler API NewSharedBuffer.
Different PJRT plugins may have different requirements on alignment, but for the CPU PJRT the library provides AlignedAlloc and AlignedFree, which can be used to allocate the aligned storage space.
Example of typical usage, where the same buffer is reused as input in a loop:
	dtype := dtypes.Float32
	dimensions := []int{batchSize, sequenceLength, 384}
	rawData := pjrt.AlignedAlloc(dtype.SizeForDimensions(dimensions...), pjrt.BufferAlignment)
	defer pjrt.AlignedFree(rawData)
	buf, err := client.CreateViewOfDeviceBuffer(rawData, dtype, dimensions)
	if err != nil {
		return err
	}
	flat := unsafe.Slice((*float32)(rawData), batchSize*sequenceLength*384)
	for _, batch := range batches {
		// ... set flat values
		// ... use buf as input when executing a PJRT program
	}
If device is not given (at most one can be given), the first device available for the client is used.
The naming comes from PJRT and is unfortunate, since it's the name from PJRT's perspective (PJRT's view of a user's device buffer). It would probably have been better named "ShareDeviceBuffer" or something similar.
This may not be implemented for all hardware (or all PJRT plugins).
This can be useful to avoid the copy of values, by mutating directly in the memory shared with PJRT, to be used as input to a computation.
See: dtypes.SizeForDimensions() to calculate the size for an arbitrary shape; AlignedAlloc, AlignedFree and BufferAlignment (a constant with the required alignment size) to allocate and free aligned storage.
func (*Client) Destroy ¶
Destroy the client, release resources, and Client is no longer valid. This is automatically called if Client is garbage collected.
func (*Client) Devices ¶
Devices returns a list of all devices visible to the runtime, including addressable and non-addressable devices.
func (*Client) NewSharedBuffer ¶ added in v0.5.0
func (c *Client) NewSharedBuffer(dtype dtypes.DType, dimensions []int, device ...*Device) (buffer *Buffer, flat any, err error)
NewSharedBuffer returns a buffer that can be used for execution and shares the underlying memory space with the host/local, which can be read and mutated directly.
Shared buffers cannot be donated to executions.
The buffer should not be mutated while it is used by an execution.
When the buffer is finalized, the shared memory is also de-allocated.
It returns a handle to the buffer and a slice of the corresponding data type pointing to the shared data.
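What "sharing the underlying memory" means can be shown with plain Go memory: two slices of different types aliasing the same storage, so a write through one is visible through the other without a copy. This is only an analogy for the host side; bytesView is a hypothetical helper and no PJRT buffer is involved.

```go
package main

import (
	"fmt"
	"unsafe"
)

// bytesView returns a []byte that aliases the memory of values without copying.
// The view is only valid while values is alive.
func bytesView(values []float32) []byte {
	if len(values) == 0 {
		return nil
	}
	return unsafe.Slice((*byte)(unsafe.Pointer(&values[0])), len(values)*4)
}

func main() {
	values := []float32{1, 2, 3}
	view := bytesView(values)
	// Zero the 4 bytes of the second element through the byte view.
	for i := 4; i < 8; i++ {
		view[i] = 0
	}
	fmt.Println(len(view), values) // prints "12 [1 0 3]"
}
```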
func (*Client) NumForDevice ¶ added in v0.2.0
NumForDevice returns the "deviceNum" for the given device. The value deviceNum is an index to Client.AddressableDevices, and can be used in several other methods.
It returns -1 if the device is not found in Client.AddressableDevices.
func (*Client) PlatformVersion ¶
PlatformVersion returns the version of the client platform.
func (*Client) Plugin ¶ added in v0.4.3
Plugin returns the Plugin from which the Client was created.
func (*Client) ProcessIndex ¶
ProcessIndex returns the process index of the client platform. Always 0 in single-process settings.
type CompileConfig ¶
type CompileConfig struct {
// contains filtered or unexported fields
}
CompileConfig is created with Client.Compile, and is a "builder pattern" to configure a compilation call.
At a minimum one has to set the program to compile (use CompileConfig.WithHLO or CompileConfig.WithComputation). Optionally, many other options can be set.
Once finished call CompileConfig.Done to trigger the compilation and get back a LoadedExecutable or an error.
TODO: expose all (or more) configuration options with "WithX" methods.
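The builder pattern with a mutually exclusive program setting can be sketched in isolation. compileSketch below is a hypothetical type that imitates the documented behavior (cascading calls, panic when a second program is set, Done as the terminal call); it does not compile anything and is not the library's implementation.

```go
package main

import (
	"errors"
	"fmt"
)

// compileSketch is a hypothetical builder holding at most one program.
type compileSketch struct {
	format  string // "", "hlo" or "stablehlo"
	program []byte
}

// setProgram panics if a program was already configured.
func (c *compileSketch) setProgram(format string, program []byte) *compileSketch {
	if c.format != "" {
		panic(fmt.Sprintf("program already set as %q, cannot also set %q", c.format, format))
	}
	c.format, c.program = format, program
	return c // returning the receiver allows cascading calls
}

func (c *compileSketch) WithHLO(serialized []byte) *compileSketch {
	return c.setProgram("hlo", serialized)
}

func (c *compileSketch) WithStableHLO(serialized []byte) *compileSketch {
	return c.setProgram("stablehlo", serialized)
}

// Done is the terminal call: it fails if no program was given.
func (c *compileSketch) Done() (string, error) {
	if c.format == "" {
		return "", errors.New("no program set")
	}
	return fmt.Sprintf("compiled %d bytes of %s", len(c.program), c.format), nil
}

func main() {
	out, err := new(compileSketch).WithHLO([]byte{1, 2, 3}).Done()
	fmt.Println(out, err) // prints "compiled 3 bytes of hlo <nil>"
}
```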
func (*CompileConfig) Done ¶
func (cc *CompileConfig) Done() (*LoadedExecutable, error)
Done triggers the compilation of the program. If the compilation succeeds a LoadedExecutable is returned, otherwise an error is returned.
func (*CompileConfig) WithComputation ¶
func (cc *CompileConfig) WithComputation(computation XlaComputation) *CompileConfig
WithComputation configures the program to the xlabuilder.XlaComputation -- see xlabuilder package. Behind the scenes it is an HLO program (HloModule proto), but this handles the details. If plugin.UseStableHLO is set to true, it takes the StableHLO (in MLIR format) program instead.
Exactly one of WithHLO, WithStableHLO or WithComputation must be set before Done can be called to trigger the compilation. It panics if more than one version of the program is set.
It returns itself (CompileConfig) to allow cascading configuration calls.
func (*CompileConfig) WithHLO ¶
func (cc *CompileConfig) WithHLO(serialized []byte) *CompileConfig
WithHLO configures the program to the serialized HLO (HloModule proto). The serialized proto blob can be allocated in Go or in C/C++, and must be kept alive (and unchanged) until the call to Done has returned.
Exactly one of WithHLO, WithStableHLO or WithComputation must be set before Done can be called to trigger the compilation. It panics if more than one version of the program is set.
It returns itself (CompileConfig) to allow cascading configuration calls.
func (*CompileConfig) WithStableHLO ¶ added in v0.4.3
func (cc *CompileConfig) WithStableHLO(serialized []byte) *CompileConfig
WithStableHLO configures the program with a StableHLO program, encoded as a serialized `mlir.ModuleOp` object. The serialized blob can be allocated in Go or in C/C++, and must be kept alive (and unchanged) until the call to Done has returned.
Exactly one of WithHLO, WithStableHLO or WithComputation must be set before Done can be called to trigger the compilation. It panics if more than one version of the program is set.
type Device ¶
type Device struct {
// contains filtered or unexported fields
}
Device is a lightweight reference to a Device managed by a Client -- it doesn't own the underlying object.
(Explanation by Gemini) The meaning of Device in PJRT/XLA is a bit nuanced, it refers to an individual unit of processing capable of executing XLA computations.
Here's how it breaks down in different scenarios:
- Single-GPU System: In a computer with a single GPU, the device typically corresponds to that entire GPU.
- Multi-GPU System: In a system with multiple GPUs, each individual GPU is considered a separate device. You would typically create multiple PjrtClient instances, each associated with a different GPU device.
- TPU Pods/Slices: On Google Cloud TPUs, a device can represent either a whole TPU chip (with multiple cores) or a slice of a TPU chip (a subset of cores).
- CPU: In the context of the CPU plugin, the device usually refers to the entire CPU or a specific NUMA node (a group of CPU cores with faster access to a particular region of memory).
Device Selection: When creating a PjrtClient, you can either let PjRT choose a default device or explicitly specify which device to use. The PjrtClient_Devices function can help you list the available devices.
Device-Specific Operations: Some PJRT operations (like querying device attributes or transferring data to/from the device) are device-specific and operate on individual PjrtDevice objects (obtained from the PjrtClient_Devices list).
func (*Device) GetDescription ¶
func (d *Device) GetDescription() (*DeviceDescription, error)
GetDescription gets the DeviceDescription object associated with this device.
func (*Device) IsAddressable ¶
IsAddressable returns whether the device is addressable by this client.
func (*Device) LocalHardwareId ¶
LocalHardwareId returns an opaque hardware ID, e.g., the CUDA device number. In general, not guaranteed to be dense, and -1 if undefined.
type DeviceDescription ¶
type DeviceDescription struct {
// contains filtered or unexported fields
}
DeviceDescription may be associated with an actual device (via PJRT_Device_GetDescription), but they can also be used to describe a device that isn't currently available to the plugin. This is useful for compiling executables without hardware available, which can then be serialized and written somewhere durable, and then loaded and run on actual hardware later.
func (*DeviceDescription) DebugString ¶
func (dDesc *DeviceDescription) DebugString() string
DebugString suitable for logging when errors occur. Should be verbose enough to describe the current device unambiguously.
func (*DeviceDescription) ProcessIndex ¶
func (dDesc *DeviceDescription) ProcessIndex() int
ProcessIndex returns the index of the process that this device belongs to, i.e. is addressable from. This is not always identical to PJRT_Client_ProcessIndex in a multi-process setting, where each client can see devices from all processes, but only a subset of them are addressable and have the same process_index as the client.
type Event ¶
type Event struct {
// contains filtered or unexported fields
}
Event is a reference to a future event (when something is done), and it is created by asynchronous calls.
While it is exported here in case someone needs it to implement some extension, users of Go's pjrt package usually don't need to use it directly: the various methods of the API handle the events.
func (*Event) Await ¶
Await blocks the calling thread until `event` is ready, then returns the error, if any.
func (*Event) AwaitAndFree ¶
AwaitAndFree blocks the calling thread until `event` is ready, destroys the event and then returns the error, if any.
An error destroying the event is simply reported in the logs, but not returned.
type Executable ¶
type Executable struct {
// contains filtered or unexported fields
}
Executable is a reference that describes a compiled program -- it cannot be executed, only introspected.
This is usually not used directly: the LoadedExecutable, when created, automatically extracts the Executable and related information.
func (*Executable) Destroy ¶
func (e *Executable) Destroy() error
Destroy the Executable, release resources, and Executable is no longer valid. This is automatically called if Executable is garbage collected.
func (*Executable) GetMemoryStats ¶ added in v0.6.1
func (e *Executable) GetMemoryStats() (onDevice, onHost ExecutableMemoryUsageStats, err error)
GetMemoryStats returns the sizes (in bytes) for the compiled code, inputs, outputs, aliases and temporary memory used both in host and on device.
This can be used to estimate memory requirements for the program.
func (*Executable) Name ¶
func (e *Executable) Name() (string, error)
Name returns the name of the executable.
func (*Executable) NumOutputs ¶
func (e *Executable) NumOutputs() (int, error)
NumOutputs returns the number of outputs for the given executable.
type ExecutableMemoryUsageStats ¶ added in v0.6.1
type ExecutableMemoryUsageStats struct {
GeneratedCode, Inputs, Outputs, Aliases, Temporary int64
}
ExecutableMemoryUsageStats reports the static memory usage for a compiled program, in bytes. The on-device memory needed to run an executable is at least: GeneratedCode + Inputs + Outputs - Aliases + Temporary. See ExecutableMemoryUsageStats.Requirements.
Aliases is how much memory of the input is reused as output (?).
The documentation is sparse in XLA, here are the links:
- xla::CompiledMemoryStats: https://github.com/openxla/xla/blob/2fff53249ed49930de14b235f50ed2235e69df8b/xla/pjrt/pjrt_executable.h#L284
- PJRT C API: https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api.h#L1668
func (ExecutableMemoryUsageStats) Requirements ¶ added in v0.6.1
func (m ExecutableMemoryUsageStats) Requirements() int64
Requirements returns an estimate of memory requirements for the executable.
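Assuming the estimate follows the lower bound given in the type's documentation (GeneratedCode + Inputs + Outputs - Aliases + Temporary), the arithmetic looks like the sketch below. memoryStats and lowerBound are hypothetical local names, and the numbers are invented for the example.

```go
package main

import "fmt"

// memoryStats mirrors the fields of ExecutableMemoryUsageStats, in bytes.
type memoryStats struct {
	GeneratedCode, Inputs, Outputs, Aliases, Temporary int64
}

// lowerBound applies the documented formula. Aliases is subtracted because
// that input memory is reused for outputs and would otherwise be counted twice.
func (m memoryStats) lowerBound() int64 {
	return m.GeneratedCode + m.Inputs + m.Outputs - m.Aliases + m.Temporary
}

func main() {
	stats := memoryStats{GeneratedCode: 100, Inputs: 400, Outputs: 400, Aliases: 400, Temporary: 50}
	fmt.Println(stats.lowerBound()) // prints "550"
}
```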
type ExecutionConfig ¶
type ExecutionConfig struct {
// contains filtered or unexported fields
}
ExecutionConfig holds the configuration for executing a LoadedExecutable. It is created with LoadedExecutable.Execute.
After configuring it, call Done to actually trigger the execution.
TODO: add support for multi-device execution, with some inputs shared across devices, and some per-device specific.
func (*ExecutionConfig) Donate ¶
func (c *ExecutionConfig) Donate(inputsIndices ...int) *ExecutionConfig
Donate marks the inputs (referred to by their indices) to be donated.
This can be called more than once for different inputsIndices.
Donated inputs become invalid after the execution. Often donated arguments are also the output of a computation and are updated in place. See discussion in https://jax.readthedocs.io/en/latest/faq.html#buffer-donation
func (*ExecutionConfig) DonateAll ¶
func (c *ExecutionConfig) DonateAll() *ExecutionConfig
DonateAll marks all inputs to be "donated".
Donated inputs become invalid after the execution. Often donated arguments are also the output of a computation and are updated in place. See discussion in https://jax.readthedocs.io/en/latest/faq.html#buffer-donation
func (*ExecutionConfig) DonateNone ¶ added in v0.2.1
func (c *ExecutionConfig) DonateNone() *ExecutionConfig
DonateNone makes all inputs to be marked as non-donatable. This is the default.
Donated inputs become invalid after the execution. Often donated arguments are also the output of a computation and are updated in place. See discussion in https://jax.readthedocs.io/en/latest/faq.html#buffer-donation
func (*ExecutionConfig) Done ¶
func (c *ExecutionConfig) Done() ([]*Buffer, error)
func (*ExecutionConfig) OnDevices ¶ added in v0.2.0
func (c *ExecutionConfig) OnDevices(devices ...*Device) *ExecutionConfig
OnDevices selects which devices to execute on. Usually only 1, but more than one can be configured.
The default is to use the first addressable device. See also OnDevicesByNum.
func (*ExecutionConfig) OnDevicesByNum ¶ added in v0.2.0
func (c *ExecutionConfig) OnDevicesByNum(devicesNum ...int) *ExecutionConfig
OnDevicesByNum selects which devices to execute on. The devicesNum point to the devices in the list returned by Client.AddressableDevices. Usually only 1, but more than one can be configured.
The default is to use the first addressable device. See also OnDevices.
func (*ExecutionConfig) SetDonate ¶ added in v0.2.1
func (c *ExecutionConfig) SetDonate(donate []bool) *ExecutionConfig
SetDonate sets the donate status of all inputs in one call. The default is that no input is donated.
Donated inputs become invalid after the execution. Often donated arguments are also the output of a computation and are updated in place. See discussion in https://jax.readthedocs.io/en/latest/faq.html#buffer-donation
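How the four donation methods relate can be sketched as operations on one slice of flags, one per input. donateSketch is a hypothetical type written for illustration; it assumes valid indices and a donate slice with one entry per input.

```go
package main

import "fmt"

// donateSketch keeps one donate flag per input. All flags start as false,
// matching the documented default of donating nothing.
type donateSketch struct{ donate []bool }

func newDonateSketch(numInputs int) *donateSketch {
	return &donateSketch{donate: make([]bool, numInputs)}
}

// Donate marks the given inputs; repeated calls accumulate.
func (c *donateSketch) Donate(inputsIndices ...int) *donateSketch {
	for _, i := range inputsIndices {
		c.donate[i] = true
	}
	return c
}

func (c *donateSketch) setAll(value bool) *donateSketch {
	for i := range c.donate {
		c.donate[i] = value
	}
	return c
}

func (c *donateSketch) DonateAll() *donateSketch  { return c.setAll(true) }
func (c *donateSketch) DonateNone() *donateSketch { return c.setAll(false) }

// SetDonate overwrites every flag in one call.
func (c *donateSketch) SetDonate(donate []bool) *donateSketch {
	copy(c.donate, donate)
	return c
}

func main() {
	c := newDonateSketch(3).Donate(0).Donate(2)
	fmt.Println(c.donate) // prints "[true false true]"
}
```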
type LoadedExecutable ¶
type LoadedExecutable struct {
	// Name of the executable.
	Name string

	// NumOutputs of the executable.
	NumOutputs int

	// OnDeviceMemoryUsageStats, OnHostMemoryUsageStats can be used to estimate the required memory usage for the executable on device (and on host).
	OnDeviceMemoryUsageStats, OnHostMemoryUsageStats ExecutableMemoryUsageStats
	// contains filtered or unexported fields
}
LoadedExecutable is a reference to a compiled program ready to be executed.
All public attributes are read-only.
func (*LoadedExecutable) Destroy ¶
func (e *LoadedExecutable) Destroy() error
Destroy the LoadedExecutable, release resources, and LoadedExecutable is no longer valid. This is automatically called if LoadedExecutable is garbage collected.
func (*LoadedExecutable) Execute ¶
func (e *LoadedExecutable) Execute(inputs ...*Buffer) *ExecutionConfig
Execute the compiled computation. It returns an ExecutionConfig for further configuration. Call ExecutionConfig.Done and the computation is executed.
It provides good defaults, so in the common case nothing else is needed:
- Using the first addressable device.
- All input buffers marked as not-donated (see discussion in https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
See ExecutionConfig for more details and options.
Example:
outputBuffers, err := loadedExec.Execute(inputBuffer).Done()
type NamedValuesMap ¶
NamedValuesMap maps names to any of the supported named value types defined by PJRT_NamedValue_Type.
type PJRT_Buffer_MemoryLayout_Type ¶
type PJRT_Buffer_MemoryLayout_Type int
PJRT_Buffer_MemoryLayout_Type is a mapping of the corresponding C enum defined in pjrt_c_api.h.
const (
	// PJRT_Buffer_MemoryLayout_Type_Tiled is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Buffer_MemoryLayout_Type_Tiled PJRT_Buffer_MemoryLayout_Type = 0
	// PJRT_Buffer_MemoryLayout_Type_Strides is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Buffer_MemoryLayout_Type_Strides PJRT_Buffer_MemoryLayout_Type = 1
)
type PJRT_Error_Code ¶
type PJRT_Error_Code int
PJRT_Error_Code is a mapping of the corresponding C enum defined in pjrt_c_api.h.
const (
	// PJRT_Error_Code_CANCELLED is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_CANCELLED PJRT_Error_Code = 1
	// PJRT_Error_Code_UNKNOWN is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_UNKNOWN PJRT_Error_Code = 2
	// PJRT_Error_Code_INVALID_ARGUMENT is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_INVALID_ARGUMENT PJRT_Error_Code = 3
	// PJRT_Error_Code_DEADLINE_EXCEEDED is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_DEADLINE_EXCEEDED PJRT_Error_Code = 4
	// PJRT_Error_Code_NOT_FOUND is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_NOT_FOUND PJRT_Error_Code = 5
	// PJRT_Error_Code_ALREADY_EXISTS is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_ALREADY_EXISTS PJRT_Error_Code = 6
	// PJRT_Error_Code_PERMISSION_DENIED is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_PERMISSION_DENIED PJRT_Error_Code = 7
	// PJRT_Error_Code_RESOURCE_EXHAUSTED is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_RESOURCE_EXHAUSTED PJRT_Error_Code = 8
	// PJRT_Error_Code_FAILED_PRECONDITION is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_FAILED_PRECONDITION PJRT_Error_Code = 9
	// PJRT_Error_Code_ABORTED is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_ABORTED PJRT_Error_Code = 10
	// PJRT_Error_Code_OUT_OF_RANGE is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_OUT_OF_RANGE PJRT_Error_Code = 11
	// PJRT_Error_Code_UNIMPLEMENTED is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_UNIMPLEMENTED PJRT_Error_Code = 12
	// PJRT_Error_Code_INTERNAL is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_INTERNAL PJRT_Error_Code = 13
	// PJRT_Error_Code_UNAVAILABLE is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_UNAVAILABLE PJRT_Error_Code = 14
	// PJRT_Error_Code_DATA_LOSS is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_DATA_LOSS PJRT_Error_Code = 15
	// PJRT_Error_Code_UNAUTHENTICATED is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_UNAUTHENTICATED PJRT_Error_Code = 16
)
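When logging errors, it can help to turn these numeric codes into readable names. The following is a minimal, self-contained sketch: `ErrorCode` is a local stand-in type (not this package's PJRT_Error_Code), and the name table simply copies the values 1 to 16 listed in the constants. Whether the package itself provides a String method is not stated here, so treat this as an illustration only.

```go
package main

import "fmt"

// ErrorCode is a local stand-in for PJRT_Error_Code, so the sketch runs on its own.
type ErrorCode int

// errorCodeNames is indexed by code; the values mirror the constant block.
// Index 0 is intentionally left empty because the enum starts at 1.
var errorCodeNames = [...]string{
	1:  "CANCELLED",
	2:  "UNKNOWN",
	3:  "INVALID_ARGUMENT",
	4:  "DEADLINE_EXCEEDED",
	5:  "NOT_FOUND",
	6:  "ALREADY_EXISTS",
	7:  "PERMISSION_DENIED",
	8:  "RESOURCE_EXHAUSTED",
	9:  "FAILED_PRECONDITION",
	10: "ABORTED",
	11: "OUT_OF_RANGE",
	12: "UNIMPLEMENTED",
	13: "INTERNAL",
	14: "UNAVAILABLE",
	15: "DATA_LOSS",
	16: "UNAUTHENTICATED",
}

// String returns the symbolic name, or a numeric fallback for unknown codes.
func (c ErrorCode) String() string {
	if c >= 1 && int(c) < len(errorCodeNames) {
		return errorCodeNames[c]
	}
	return fmt.Sprintf("PJRT_Error_Code(%d)", int(c))
}

func main() {
	fmt.Println(ErrorCode(3))  // known code
	fmt.Println(ErrorCode(42)) // unknown code falls back to the numeric form
}
```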
type PJRT_Extension_Type ¶
type PJRT_Extension_Type int
PJRT_Extension_Type is a mapping of the corresponding C enum defined in pjrt_c_api.h.
const (
	// PJRT_Extension_Type_Gpu_Custom_Call is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_Gpu_Custom_Call PJRT_Extension_Type = 0
	// PJRT_Extension_Type_Profiler is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_Profiler PJRT_Extension_Type = 1
	// PJRT_Extension_Type_Custom_Partitioner is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_Custom_Partitioner PJRT_Extension_Type = 2
	// PJRT_Extension_Type_Stream is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_Stream PJRT_Extension_Type = 3
	// PJRT_Extension_Type_Layouts is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_Layouts PJRT_Extension_Type = 4
	// PJRT_Extension_Type_FFI is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_FFI PJRT_Extension_Type = 5
	// PJRT_Extension_Type_MemoryDescriptions is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_MemoryDescriptions PJRT_Extension_Type = 6
	// PJRT_Extension_Type_Triton is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_Triton PJRT_Extension_Type = 7
)
type PJRT_HostBufferSemantics ¶
type PJRT_HostBufferSemantics int
PJRT_HostBufferSemantics is a mapping of the corresponding C enum defined in pjrt_c_api.h.
const (
	// PJRT_HostBufferSemantics_kImmutableOnlyDuringCall is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	// The runtime may not hold references to `data` after the call to
	// `PJRT_Client_BufferFromHostBuffer` completes. The caller promises that
	// `data` is immutable and will not be freed only for the duration of the
	// PJRT_Client_BufferFromHostBuffer call.
	PJRT_HostBufferSemantics_kImmutableOnlyDuringCall PJRT_HostBufferSemantics = 0
	// PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	// The runtime may hold onto `data` after the call to
	// `PJRT_Client_BufferFromHostBuffer`
	// returns while the runtime completes a transfer to the device. The caller
	// promises not to mutate or free `data` until the transfer completes, at
	// which point `done_with_host_buffer` will be triggered.
	PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes PJRT_HostBufferSemantics = 1
	// PJRT_HostBufferSemantics_kImmutableZeroCopy is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	// The PjRtBuffer may alias `data` internally and the runtime may use the
	// `data` contents as long as the buffer is alive. The runtime promises not
	// to mutate contents of the buffer (i.e. it will not use it for aliased
	// output buffers). The caller promises to keep `data` alive and not to mutate
	// its contents as long as the buffer is alive; to notify the caller that the
	// buffer may be freed, the runtime will call `done_with_host_buffer` when the
	// PjRtBuffer is freed.
	PJRT_HostBufferSemantics_kImmutableZeroCopy PJRT_HostBufferSemantics = 2
	// PJRT_HostBufferSemantics_kMutableZeroCopy is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	// The PjRtBuffer may alias `data` internally and the runtime may use the
	// `data` contents as long as the buffer is alive. The runtime is allowed
	// to mutate contents of the buffer (i.e. use it for aliased output
	// buffers). The caller promises to keep `data` alive and not to mutate its
	// contents as long as the buffer is alive (otherwise it could be a data
	// race with the runtime); to notify the caller that the buffer may be
	// freed, the runtime will call `on_done_with_host_buffer` when the
	// PjRtBuffer is freed. On non-CPU platforms this acts identically to
	// kImmutableUntilTransferCompletes.
	PJRT_HostBufferSemantics_kMutableZeroCopy PJRT_HostBufferSemantics = 3
)
type PJRT_NamedValue_Type ¶
type PJRT_NamedValue_Type int
PJRT_NamedValue_Type is a mapping of the corresponding C enum defined in pjrt_c_api.h.
const (
	// PJRT_NamedValue_kString is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_NamedValue_kString PJRT_NamedValue_Type = 0
	// PJRT_NamedValue_kInt64 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_NamedValue_kInt64 PJRT_NamedValue_Type = 1
	// PJRT_NamedValue_kInt64List is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_NamedValue_kInt64List PJRT_NamedValue_Type = 2
	// PJRT_NamedValue_kFloat is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_NamedValue_kFloat PJRT_NamedValue_Type = 3
	// PJRT_NamedValue_kBool is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_NamedValue_kBool PJRT_NamedValue_Type = 4
)
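To illustrate how values in a map of named values might be classified into these five kinds, here is a self-contained sketch. Everything in it is an assumption made for illustration: `NamedValueKind` and `kindOf` are hypothetical names, and the Go types chosen for each kind (`string`, `int64`, `[]int64`, `float32`, `bool`) are a guess at a natural mapping, not something this page confirms. The option names in `main` are made up as well.

```go
package main

import "fmt"

// NamedValueKind is a local stand-in mirroring the PJRT_NamedValue_Type values.
type NamedValueKind int

const (
	KindString    NamedValueKind = 0
	KindInt64     NamedValueKind = 1
	KindInt64List NamedValueKind = 2
	KindFloat     NamedValueKind = 3
	KindBool      NamedValueKind = 4
)

// kindOf maps a Go value to a kind. The Go-type-to-kind mapping is an
// assumption of this sketch; any other type is rejected with an error.
func kindOf(v any) (NamedValueKind, error) {
	switch v.(type) {
	case string:
		return KindString, nil
	case int64:
		return KindInt64, nil
	case []int64:
		return KindInt64List, nil
	case float32:
		return KindFloat, nil
	case bool:
		return KindBool, nil
	default:
		return 0, fmt.Errorf("unsupported named value type %T", v)
	}
}

func main() {
	// Option names and values are invented for illustration.
	options := []struct {
		name  string
		value any
	}{
		{"platform_hint", "cpu"},
		{"device_ids", []int64{0, 1}},
		{"num_threads", 8}, // an untyped constant becomes int, not int64
	}
	for _, opt := range options {
		kind, err := kindOf(opt.value)
		if err != nil {
			fmt.Printf("%s: %v\n", opt.name, err)
			continue
		}
		fmt.Printf("%s: kind %d\n", opt.name, kind)
	}
}
```

The third option shows a common pitfall under this assumed mapping: a plain Go `int` is a different type from `int64`, so an explicit conversion such as `int64(8)` would be needed.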
type Plugin ¶
type Plugin struct {
	// UseStableHLO configures the plugin clients to convert XlaBuilder programs from "HLO" to "StableHLO"
	// before compilation. The "StableHLO" (encoded as MLIR) is the more recent
	// "intermediary representation" program language.
	//
	// Setting to true incurs in a conversion step from "HLO" to "StableHLO" during compilation. But some
	// PJRT will only support "StableHLO" (namely Apple Metal PJRT).
	//
	// Default is true, but it can be changed by setting the environment variable "GOPJRT_NO_STABLE_HLO=1"
	//
	// Most people don't need to worry about this, it should be an implementation detail.
	UseStableHLO bool
	// contains filtered or unexported fields
}
Plugin represents a loaded PJRT plugin that can be used to compile and execute StableHLO code.
Loaded plugins are singletons per platform and cached (GetPlugin will return a pointer to the same plugin if called with the same platform or its aliases).
Plugins are searched in the PJRT_PLUGIN_LIBRARY_PATH directory -- or directories, if it is a ":" separated list.
Design document: https://docs.google.com/document/d/1Qdptisz1tUPGn1qFAVgCV2omnfjN01zoQPwKLdlizas/edit
func GetPlugin ¶
GetPlugin returns the plugin with the given name. The name typically reflects the platform, e.g. "cpu" or "gpu", but one can also give the full path to the plugin's `.so` file.
Loaded plugins are singletons and cached (GetPlugin will return a pointer to the same plugin if called with the same name or its aliases).
Plugins are searched in the PJRT_PLUGIN_LIBRARY_PATH directory -- or directories, if it is a ":" separated list. If it is not set, the search falls back to `/usr/local/lib/gomlx` and the system's standard library directories (on Linux, those in LD_LIBRARY_PATH and the /etc/ld.so.conf file).
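The search-path rule can be sketched in a few lines. This is not the package's implementation: `searchDirs` is a hypothetical helper, dropping empty entries is an assumption, and the fallback covers only the `/usr/local/lib/gomlx` directory mentioned above (the system library directories are left out to keep the example short).

```go
package main

import (
	"fmt"
	"os"
	"strings"
)

const (
	pluginPathEnv = "PJRT_PLUGIN_LIBRARY_PATH"
	defaultDir    = "/usr/local/lib/gomlx"
)

// searchDirs splits a ":" separated list of directories, skipping empty
// entries. If nothing is left, it returns only the default directory.
func searchDirs(value string) []string {
	var dirs []string
	for _, dir := range strings.Split(value, ":") {
		if dir != "" {
			dirs = append(dirs, dir)
		}
	}
	if len(dirs) == 0 {
		return []string{defaultDir}
	}
	return dirs
}

func main() {
	// Prints the directories that would be searched in this environment.
	for _, dir := range searchDirs(os.Getenv(pluginPathEnv)) {
		fmt.Println(dir)
	}
}
```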
func (*Plugin) Attributes ¶
func (p *Plugin) Attributes() NamedValuesMap
Attributes returns a NamedValuesMap with the attributes returned by the plugin at the time of its initialization.
func (*Plugin) Name ¶
Name returns the name of the plugin, which usually reflects its platform (cpu, gpu, tpu, etc.).
func (*Plugin) NewClient ¶
func (p *Plugin) NewClient(options NamedValuesMap) (*Client, error)
NewClient creates a new Client object to manage available devices. The options (which can be left nil) are plugin specific, and should be (but often aren't) documented by the plugins.
type XlaComputation ¶
type XlaComputation interface {
	// SerializedHLO exports the computation as a serialized `HloModuleProto`.
	SerializedHLO() *cbuffer.CBuffer

	// HasStableHLO returns whether StableHLO is supported.
	HasStableHLO() bool

	// SerializedStableHLO exports the computation as a StableHLO as an `mlir:ModuleOp`.
	SerializedStableHLO() (*cbuffer.CBuffer, error)
}
XlaComputation is an interface that matches the xlabuilder.XlaComputation methods needed by PJRT.
It is defined here to avoid a hard dependency on the xlabuilder package.
The returned buffer ownership is transferred to the caller -- the caller must free the buffer later.
Directories ¶
Path | Synopsis |
---|---|
cpu | |
dynamic | Package dynamic will link (preload) a dynamically loaded library `libpjrt_c_api_cpu_dynamic`, that is used if the user requests a "cpu" plugin. |
static | Package static statically links a CPU PJRT plugin, and registers with the name "cpu". |
internal | |
cpudynamictest | Package cpudynamictest is just a hack around Go's limitation to use CGO in tests and to avoid cyclic dependency. |
cpustatictest | Package cpustatictest is just a hack around Go's limitation to use CGO in tests and to avoid cyclic dependency. |