pjrt

package
v0.6.2 Latest
Warning

This package is not in the latest version of its module.

Published: Feb 26, 2025 License: Apache-2.0 Imports: 24 Imported by: 2

Documentation

Overview

Package pjrt implements a Go wrapper for the PJRT_C_API.

Index

Constants

const (
	// PJRTPluginPathsEnv is the name of the environment variable that defines the search paths for plugins.
	PJRTPluginPathsEnv = "PJRT_PLUGIN_LIBRARY_PATH"

	// GetPJRTApiFunctionName is the name of the function exported by PJRT plugins that returns the API.
	GetPJRTApiFunctionName = "GetPjrtApi"
)
const BufferAlignment = 64

BufferAlignment is the default alignment required for memory shared with CPU PJRT. See AlignedAlloc and AlignedFree.

const EnvXlaDebugOptions = "XLA_DEBUG_OPTIONS"

EnvXlaDebugOptions is an environment variable that can be defined to set XLA DebugOptions proto when compiling a program.

Variables

This section is empty.

Functions

func AlignedAlloc added in v0.5.0

func AlignedAlloc(size, alignment uintptr) unsafe.Pointer

AlignedAlloc assumes that malloc/calloc already aligns to 8 bytes and that alignment is a multiple of 8. The pointer returned must be freed with AlignedFree.

The allocation is filled with 0s.
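The alignment arithmetic can be illustrated in isolation. The sketch below is not the package's implementation (which is not shown here); it only demonstrates the common technique of rounding an address up to the next multiple of an alignment, assuming the alignment is a power of two such as 64. The helper name alignUp is hypothetical.

```go
package main

import "fmt"

// alignUp rounds addr up to the next multiple of alignment.
// Hypothetical helper, not part of pjrt. It assumes alignment is a
// power of two (64, the value of BufferAlignment, satisfies this).
func alignUp(addr, alignment uintptr) uintptr {
	return (addr + alignment - 1) &^ (alignment - 1)
}

func main() {
	const alignment = 64 // same value as the BufferAlignment constant
	for _, addr := range []uintptr{0, 8, 64, 72, 120} {
		fmt.Printf("%3d -> %3d\n", addr, alignUp(addr, alignment))
	}
}
```

An allocator built on this idea typically over-allocates by up to one alignment unit so that the rounded-up address still leaves room for the requested size.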

func AlignedFree added in v0.5.0

func AlignedFree(ptr unsafe.Pointer)

AlignedFree frees an allocation created with AlignedAlloc.

func AvailablePlugins

func AvailablePlugins() (pluginsPaths map[string]string)

AvailablePlugins searches for available plugins in the standard directories and returns a map from their name to their paths.

Plugins are searched in the PJRT_PLUGIN_LIBRARY_PATH directory -- or directories, if it is a ":" separated list. If it is not set, it searches in "/usr/local/lib/gomlx/pjrt" and the standard library directories of the system (on Linux, LD_LIBRARY_PATH and the /etc/ld.so.conf file; on Darwin, also DYLD_LIBRARY_PATH), in that order.

If there are plugins with the same name but different versions in different directories, it respects the order of the directories given by PJRT_PLUGIN_LIBRARY_PATH or by the system.
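The search order described above can be sketched with the standard library only. This is an illustration of the documented behavior, not the package's code: searchDirs, firstWins, the example directories and the way a path is built from a directory and a plugin name are all assumptions made for the example.

```go
package main

import (
	"fmt"
	"strings"
)

// defaultPluginDir is the fallback directory named in the documentation.
const defaultPluginDir = "/usr/local/lib/gomlx/pjrt"

// searchDirs returns the directories to scan, in priority order.
// envValue stands for the content of PJRT_PLUGIN_LIBRARY_PATH and
// systemDirs for the system library directories (hypothetical inputs).
func searchDirs(envValue string, systemDirs []string) []string {
	if envValue != "" {
		return strings.Split(envValue, ":")
	}
	return append([]string{defaultPluginDir}, systemDirs...)
}

// firstWins maps each plugin name to a path, keeping the entry from the
// earliest directory when the same name appears more than once.
// found maps a directory to the plugin names present in it.
func firstWins(dirs []string, found map[string][]string) map[string]string {
	plugins := make(map[string]string)
	for _, dir := range dirs {
		for _, name := range found[dir] {
			if _, seen := plugins[name]; !seen {
				plugins[name] = dir + "/" + name
			}
		}
	}
	return plugins
}

func main() {
	dirs := searchDirs("/opt/pjrt:/usr/lib/pjrt", nil)
	found := map[string][]string{
		"/opt/pjrt":     {"cpu"},
		"/usr/lib/pjrt": {"cpu", "cuda"},
	}
	plugins := firstWins(dirs, found)
	fmt.Println(plugins["cpu"])  // /opt/pjrt/cpu
	fmt.Println(plugins["cuda"]) // /usr/lib/pjrt/cuda
}
```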

func BufferToArray

func BufferToArray[T dtypes.Supported](buffer *Buffer) (flatValues []T, dimensions []int, err error)

BufferToArray transfers the buffer to an array defined by a slice with its flat values, and also returns its underlying dimensions.

func BufferToScalar

func BufferToScalar[T dtypes.Supported](b *Buffer) (value T, err error)

BufferToScalar is a generic function that transfers a Buffer back to host as a scalar of the given type.

func BuffersAlive added in v0.5.0

func BuffersAlive() int64

BuffersAlive returns the number of PJRT Buffers in memory and currently tracked by gopjrt.

func LoadedExecutablesAlive added in v0.5.0

func LoadedExecutablesAlive() int64

LoadedExecutablesAlive returns the number of LoadedExecutables currently in memory and tracked by gopjrt.

func RegisterPreloadedPlugin added in v0.4.9

func RegisterPreloadedPlugin(name string, api uintptr) error

RegisterPreloadedPlugin can be used to register a PJRT plugin that has been pre-linked (dynamically or statically) with the binary -- as opposed to the usual loadPlugin using `dlopen` after the program has started.

It takes as input the name to be associated with the plugin and an unsafe pointer (uintptr) to the API table returned by the plugin's C.GetPjrtApi().

See sub-packages `cpu/static` and `cpu/dynamic` for examples of usage.

func ScalarToRaw

func ScalarToRaw[T dtypes.Supported](value T) ([]byte, dtypes.DType, []int)

ScalarToRaw generates the raw values needed by BufferFromHostConfig.FromRawData to feed a simple scalar value.
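To make "raw values" concrete, the sketch below shows what the raw form of a float32 scalar could look like: four bytes plus an empty dimensions list (a scalar has rank 0). It is not the package's implementation. The helper float32ToRaw is hypothetical, and little-endian byte order is an assumption made for the example; the real function may simply use the host's native byte order.

```go
package main

import (
	"encoding/binary"
	"fmt"
	"math"
)

// float32ToRaw returns the 4 raw bytes of value and an empty dimensions
// slice. Hypothetical helper; little-endian order is an assumption.
func float32ToRaw(value float32) (data []byte, dimensions []int) {
	data = make([]byte, 4)
	binary.LittleEndian.PutUint32(data, math.Float32bits(value))
	return data, []int{}
}

func main() {
	data, dims := float32ToRaw(1.0)
	fmt.Printf("bytes=% x rank=%d\n", data, len(dims))
}
```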

func SuppressAbseilLoggingHack

func SuppressAbseilLoggingHack(fn func())

SuppressAbseilLoggingHack prevents some irrelevant logging from PJRT plugins, by duplicating the file descriptor (fd) 2, reassigning the new fd to Go's os.Stderr, and then closing fd 2, so PJRT plugins won't be able to write anything.

Usually this is only needed during creation of the Client of the CPU plugin. So you can just wrap that part.

Since many things rely on fd 2 being stderr, the change is only temporary: it executes the given fn and then reverts the change.

The issue with doing this permanently is that Go's default panic handler outputs the stack trace to fd 2, and this would suppress that as well.

It's overkill, because this may prevent valid logging in some truly exceptional situation, but it's the only solution I can think of for now. See discussion in https://github.com/abseil/abseil-cpp/discussions/1700

Since file descriptors are a global resource, this function is not reentrant, and you should make sure no two goroutines are calling this at the same time.

Types

type Buffer

type Buffer struct {
	// contains filtered or unexported fields
}

Buffer is a reference to an array storage (buffer) on device.

func ArrayToBuffer

func ArrayToBuffer[T dtypes.Supported](client *Client, flatValues []T, dimensions ...int) (b *Buffer, err error)

ArrayToBuffer transfers a slice to a Buffer on the default device. The underlying array is provided with its flat values as a slice, and the underlying dimensions.

It is a shortcut to Client.BufferFromHost call with default parameters. If you need more control over where the value will be used, you'll have to use Client.BufferFromHost instead.

func ScalarToBuffer

func ScalarToBuffer[T dtypes.Supported](client *Client, value T) (b *Buffer, err error)

ScalarToBuffer transfers the scalar value to a Buffer on the default device.

It is a shortcut to Client.BufferFromHost call with default parameters. If you need more control over where the value will be used, you'll have to use Client.BufferFromHost instead.

func ScalarToBufferOnDeviceNum added in v0.2.0

func ScalarToBufferOnDeviceNum[T dtypes.Supported](client *Client, deviceNum int, value T) (b *Buffer, err error)

ScalarToBufferOnDeviceNum transfers the scalar value to a Buffer on the given device.

It is a shortcut to Client.BufferFromHost call with default parameters. If you need more control over where the value will be used, you'll have to use Client.BufferFromHost instead.

func (*Buffer) Client added in v0.2.0

func (b *Buffer) Client() *Client

Client returns the client that created this Buffer.

func (*Buffer) DType

func (b *Buffer) DType() (dtype dtypes.DType, err error)

DType of the Buffer (PJRT_Buffer_ElementType).

func (*Buffer) Data added in v0.5.0

func (b *Buffer) Data() (flat any, err error)

Data returns the flat slice pointing to the underlying storage data for the buffer.

This is an undocumented feature of PJRT and likely only works for CPU platforms. The flat slice returned is only valid while the buffer is alive.

func (*Buffer) Destroy

func (b *Buffer) Destroy() error

Destroy destroys the Buffer and releases its resources; the Buffer is no longer valid afterwards. This is automatically called if the Buffer is garbage collected.

func (*Buffer) Device added in v0.2.0

func (b *Buffer) Device() (device *Device, err error)

Device returns the device where the buffer is stored.

func (*Buffer) Dimensions

func (b *Buffer) Dimensions() (dims []int, err error)

Dimensions of the Buffer. Returned slice is owned by the buffer, to avoid creating a copy. Don't change it.

func (*Buffer) IsShared added in v0.5.0

func (b *Buffer) IsShared() bool

IsShared returns whether this buffer was created with Client.NewSharedBuffer. These buffers cannot be donated in execution.

func (*Buffer) Size

func (b *Buffer) Size() (int, error)

Size returns the size in bytes required for the buffer to be transferred with ToHost.

func (*Buffer) ToFlatDataAndDimensions added in v0.2.0

func (b *Buffer) ToFlatDataAndDimensions() (flat any, dimensions []int, err error)

ToFlatDataAndDimensions transfers the buffer to a flat slice and also returns its underlying dimensions.

Similar to the generic BufferToArray[T], but this returns an anonymous typed (`any`) flat slice instead of using generics.

func (*Buffer) ToHost

func (b *Buffer) ToHost(dst []byte) error

ToHost transfers the contents of buffer stored on device to the host. dst must be large enough to hold the required data (see Buffer.Size), or an error is returned.

This always requests a major-to-minor layout, the assumption of the layout in host memory -- TPUs are known to reorganize the layout.
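In a major-to-minor layout the last axis varies fastest, which is what is usually called row-major order. The sketch below shows how a multi-dimensional index maps to a position in such a flat slice; flatIndex is a hypothetical helper written for illustration and is not part of the package.

```go
package main

import "fmt"

// flatIndex converts multi-dimensional indices into a position in a
// major-to-minor (row-major) flat slice: the last axis varies fastest.
// Hypothetical helper; it assumes len(indices) == len(dimensions).
func flatIndex(dimensions, indices []int) int {
	pos := 0
	for axis, dim := range dimensions {
		pos = pos*dim + indices[axis]
	}
	return pos
}

func main() {
	dimensions := []int{2, 3}
	flat := []float32{0, 1, 2, 10, 11, 12} // rows [0 1 2] and [10 11 12]
	fmt.Println(flat[flatIndex(dimensions, []int{1, 2})]) // 12
}
```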

func (*Buffer) UnsafePointer added in v0.5.0

func (b *Buffer) UnsafePointer() (unsafe.Pointer, error)

UnsafePointer returns a platform-dependent address for the given buffer that is often but not guaranteed to be the physical/device address. Consider using the more convenient DirectAccess.

Probably, this should only be used by CPU plugins.

To be on the safe side, only use this if Client.HasSharedBuffers is true. It uses the undocumented PJRT_Buffer_UnsafePointer.

type BufferFromHostConfig

type BufferFromHostConfig struct {
	// contains filtered or unexported fields
}

BufferFromHostConfig is used to configure the transfer of a buffer from host memory to on-device memory. It is created with Client.BufferFromHost.

The data to transfer from host can be set up with one of the following methods:

- FromRawData: it takes as inputs the bytes and shape (dtype and dimensions).
- FromFlatDataWithDimensions: it takes as inputs a flat slice of the desired data type and its dimensions.

The device defaults to 0, but it can be configured with BufferFromHostConfig.ToDevice or BufferFromHostConfig.ToDeviceNum.

At the end call BufferFromHostConfig.Done to actually initiate the transfer.

TODO: Implement async transfers, arbitrary memory layout, etc.

func (*BufferFromHostConfig) Done

func (b *BufferFromHostConfig) Done() (*Buffer, error)

Done will use the configuration to start the transfer from host to device. It's synchronous: it waits for the transfer to finish and then returns.

func (*BufferFromHostConfig) FromFlatDataWithDimensions added in v0.2.0

func (b *BufferFromHostConfig) FromFlatDataWithDimensions(flat any, dimensions []int) *BufferFromHostConfig

FromFlatDataWithDimensions configures the data to come from a flat slice of the desired data type, and the underlying dimensions. The flat slice size must match the product of the dimensions. If no dimensions are given, it is assumed to be a scalar, and flat should have length 1.
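The size rule can be checked before handing the data over. The following is a minimal sketch of that rule only, with hypothetical helper names (numElements, checkFlatLen); how the package itself validates its inputs and what errors it reports are not shown here.

```go
package main

import "fmt"

// numElements returns the product of the dimensions. With no dimensions
// (a scalar) the product is 1.
func numElements(dimensions []int) int {
	n := 1
	for _, dim := range dimensions {
		n *= dim
	}
	return n
}

// checkFlatLen reports whether a flat slice of length flatLen matches
// the given dimensions. Hypothetical helper for illustration.
func checkFlatLen(flatLen int, dimensions []int) error {
	if want := numElements(dimensions); flatLen != want {
		return fmt.Errorf("flat has %d values, but dimensions %v require %d", flatLen, dimensions, want)
	}
	return nil
}

func main() {
	fmt.Println(checkFlatLen(6, []int{2, 3})) // matches
	fmt.Println(checkFlatLen(1, nil))         // scalar
	fmt.Println(checkFlatLen(5, []int{2, 3})) // mismatch
}
```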

func (*BufferFromHostConfig) FromRawData

func (b *BufferFromHostConfig) FromRawData(data []byte, dtype dtypes.DType, dimensions []int) *BufferFromHostConfig

FromRawData configures the data from host to copy: a pointer to bytes that must be kept alive (and constant) during the call. The parameters dtype and dimensions provide the shape of the array.

func (*BufferFromHostConfig) ToDevice

func (b *BufferFromHostConfig) ToDevice(device *Device) *BufferFromHostConfig

ToDevice configures which device to copy the host data to.

If left un-configured, it will pick the first device returned by Client.AddressableDevices.

You can also provide a device by its index in Client.AddressableDevices, see BufferFromHostConfig.ToDeviceNum.

func (*BufferFromHostConfig) ToDeviceNum added in v0.2.0

func (b *BufferFromHostConfig) ToDeviceNum(deviceNum int) *BufferFromHostConfig

ToDeviceNum configures which device to copy the host data to, given a deviceNum pointing to the device in the list returned by Client.AddressableDevices.

If left un-configured, it will pick the first device returned by Client.AddressableDevices.

You can also provide the device itself with BufferFromHostConfig.ToDevice.

type Client

type Client struct {
	// contains filtered or unexported fields
}

Client manages the resources of one device: its buffers, compilation and execution of HLO code.

func (*Client) AddressableDevices

func (c *Client) AddressableDevices() []*Device

AddressableDevices returns a list of devices addressable to the client. Addressable devices are those that the client can issue commands to. All devices are addressable in a single-process environment (Client.ProcessIndex() == 0).

The returned slice and the Devices are owned by the Client; do not modify them.

func (*Client) BufferFromHost

func (c *Client) BufferFromHost() *BufferFromHostConfig

BufferFromHost creates an on-device buffer with the contents copied (optionally reused, if device is CPU) from the given host buffer.

It returns a BufferFromHostConfig that must be further configured -- at least the host data to transfer must be given. Call BufferFromHostConfig.Done to trigger the transfer.

func (*Client) Compile

func (c *Client) Compile() *CompileConfig

Compile turns a StableHLO program into a "LoadedExecutable" that is the executable runner.

There are different formats of input, and many different compilation options [1], so this returns a CompileConfig that must be further configured. At the very least the program must be given: see CompileConfig.WithComputation or CompileConfig.WithHLO. Then the call to CompileConfig.Done triggers the compilation into a "LoadedExecutable".

[1] The original compilation options are defined as the proto CompileOptionsProto: https://github.com/openxla/xla/blob/main/xla/pjrt/compile_options.proto . But the proto itself is not documented, instead see documentation in the C++ xla::CompileOptions class defined in: https://github.com/openxla/xla/blob/main/xla/pjrt/pjrt_executable.h .

func (*Client) CreateViewOfDeviceBuffer added in v0.5.0

func (c *Client) CreateViewOfDeviceBuffer(rawData unsafe.Pointer, dtype dtypes.DType, dimensions []int, device ...*Device) (*Buffer, error)

CreateViewOfDeviceBuffer creates a PJRT Buffer that is backed by storage on the same device given by the caller as rawData and shape. Consider using the simpler API NewSharedBuffer.

Different PJRT plugins may have different requirements on alignment, but for the CPU PJRT the library provides AlignedAlloc and AlignedFree, which can be used to allocate the aligned storage space.

Example of a typical usage, where the same buffer is reused as input in a loop:

dtype := dtypes.Float32
dimensions := []int{batchSize, sequenceLength, 384}
rawData := pjrt.AlignedAlloc(dtype.SizeForDimensions(dimensions...), pjrt.BufferAlignment)
defer pjrt.AlignedFree(rawData)
buf, err := client.CreateViewOfDeviceBuffer(rawData, dtype, dimensions)
if err != nil {
	panic(err) // or handle the error as appropriate
}
// flat is a Go view of the same memory that rawData points to.
flat := unsafe.Slice((*float32)(rawData), batchSize*sequenceLength*384)
for _, batch := range batches {
	// ... set flat values
	// ... use buf as input when executing a PJRT program
}

If device is not given (at most one can be given), the first device available for the client is used.

The naming comes from PJRT and is unfortunate, since it's the name from PJRT's perspective (PJRT's view of a user's device buffer). It would probably have been better named "ShareDeviceBuffer" or something similar.

This may not be implemented for all hardware (or all PJRT plugins).

This can be useful to avoid the copy of values, by mutating directly in the memory shared with PJRT, to be used as input to a computation.

See: dtypes.SizeForDimensions() to calculate the size for an arbitrary shape; AlignedAlloc, AlignedFree and BufferAlignment (a constant with the required alignment size) to allocate and free aligned storage.

func (*Client) Destroy

func (c *Client) Destroy() error

Destroy destroys the Client and releases its resources; the Client is no longer valid afterwards. This is automatically called if the Client is garbage collected.

func (*Client) Devices

func (c *Client) Devices() ([]*Device, error)

Devices returns a list of all devices visible to the runtime, including addressable and non-addressable devices.

func (*Client) NewSharedBuffer added in v0.5.0

func (c *Client) NewSharedBuffer(dtype dtypes.DType, dimensions []int, device ...*Device) (buffer *Buffer, flat any, err error)

NewSharedBuffer returns a buffer that can be used for execution and shares the underlying memory space with the host/local, which can be read and mutated directly.

Shared buffers cannot be donated to executions.

The buffer should not be mutated while it is used by an execution.

When the buffer is finalized, the shared memory is also de-allocated.

It returns a handle to the buffer and a slice of the corresponding data type pointing to the shared data.

func (*Client) NumForDevice added in v0.2.0

func (c *Client) NumForDevice(device *Device) int

NumForDevice returns the "deviceNum" for the given device. The value deviceNum is an index to Client.AddressableDevices, and can be used in several other methods.

It returns -1 if device not found in Client.AddressableDevices.

func (*Client) Platform

func (c *Client) Platform() string

Platform returns the name of the client platform.

func (*Client) PlatformVersion

func (c *Client) PlatformVersion() string

PlatformVersion returns the version of the client platform.

func (*Client) Plugin added in v0.4.3

func (c *Client) Plugin() *Plugin

Plugin returns the Plugin from which the Client was created.

func (*Client) ProcessIndex

func (c *Client) ProcessIndex() int

ProcessIndex returns the process index of the client platform. Always 0 in single-process settings.

func (*Client) String

func (c *Client) String() string

String implements fmt.Stringer.

type CompileConfig

type CompileConfig struct {
	// contains filtered or unexported fields
}

CompileConfig is created with Client.Compile, and is a "builder pattern" to configure a compilation call.

At a minimum one has to set the program to compile (use CompileConfig.WithHLO or CompileConfig.WithComputation). Optionally, many other options can be set.

Once finished call CompileConfig.Done to trigger the compilation and get back a LoadedExecutable or an error.

TODO: expose all (or more) configuration options with "WithX" methods.

func (*CompileConfig) Done

func (cc *CompileConfig) Done() (*LoadedExecutable, error)

Done triggers the compilation of the program. If the compilation succeeds a LoadedExecutable is returned, otherwise an error is returned.

func (*CompileConfig) WithComputation

func (cc *CompileConfig) WithComputation(computation XlaComputation) *CompileConfig

WithComputation configures the program to the xlabuilder.XlaComputation -- see xlabuilder package. Behind the scenes it is an HLO program (HloModule proto), but this handles the details. If plugin.UseStableHLO is set to true, it takes the StableHLO (in MLIR format) program instead.

Exactly one of WithHLO, WithStableHLO or WithComputation must be set before Done can be called to trigger the compilation. It panics if more than one version of the program is set.

It returns itself (CompileConfig) to allow cascading configuration calls.

func (*CompileConfig) WithHLO

func (cc *CompileConfig) WithHLO(serialized []byte) *CompileConfig

WithHLO configures the program to the serialized HLO (HloModule proto). The serialized proto blob can be allocated in Go or in C/C++, and must be kept alive (and unchanged) until the call to Done returns.

Exactly one of WithHLO, WithStableHLO or WithComputation must be set before Done can be called to trigger the compilation. It panics if more than one version of the program is set.

It returns itself (CompileConfig) to allow cascading configuration calls.

func (*CompileConfig) WithStableHLO added in v0.4.3

func (cc *CompileConfig) WithStableHLO(serialized []byte) *CompileConfig

WithStableHLO configures the program with a StableHLO program, encoded as a serialized `mlir.ModuleOp` object. The serialized blob can be allocated in Go or in C/C++, and must be kept alive (and unchanged) until the call to Done returns.

Exactly one of WithHLO, WithStableHLO or WithComputation must be set before Done can be called to trigger the compilation. It panics if more than one version of the program is set.

type Device

type Device struct {
	// contains filtered or unexported fields
}

Device is a lightweight reference to a Device managed by a Client -- it doesn't own the underlying object.

(Explanation by Gemini) The meaning of Device in PJRT/XLA is a bit nuanced: it refers to an individual unit of processing capable of executing XLA computations.

Here's how it breaks down in different scenarios:

- Single-GPU System: In a computer with a single GPU, the device typically corresponds to that entire GPU.
- Multi-GPU System: In a system with multiple GPUs, each individual GPU is considered a separate device. You would typically create multiple PjrtClient instances, each associated with a different GPU device.
- TPU Pods/Slices: On Google Cloud TPUs, a device can represent either a whole TPU chip (with multiple cores) or a slice of a TPU chip (a subset of cores).
- CPU: In the context of the CPU plugin, the device usually refers to the entire CPU or a specific NUMA node (a group of CPU cores with faster access to a particular region of memory).

Device Selection: When creating a PjrtClient, you can either let PJRT choose a default device or explicitly specify which device to use. The PjrtClient_Devices function can help you list the available devices.

Device-Specific Operations: Some PJRT operations (like querying device attributes or transferring data to/from the device) are device-specific and operate on individual PjrtDevice objects (obtained from the PjrtClient_Devices list).

func (*Device) GetDescription

func (d *Device) GetDescription() (*DeviceDescription, error)

GetDescription gets a DeviceDescription object associated with this device.

func (*Device) IsAddressable

func (d *Device) IsAddressable() (bool, error)

IsAddressable returns whether the device is addressable by this client.

func (*Device) LocalHardwareId

func (d *Device) LocalHardwareId() int

LocalHardwareId returns an opaque hardware ID, e.g., the CUDA device number. In general, not guaranteed to be dense, and -1 if undefined.

type DeviceDescription

type DeviceDescription struct {
	// contains filtered or unexported fields
}

DeviceDescription may be associated with an actual device (via PJRT_Device_GetDescription), but they can also be used to describe a device that isn't currently available to the plugin. This is useful for compiling executables without hardware available, which can then be serialized and written somewhere durable, and then loaded and run on actual hardware later.

func (*DeviceDescription) DebugString

func (dDesc *DeviceDescription) DebugString() string

DebugString suitable for logging when errors occur. Should be verbose enough to describe the current device unambiguously.

func (*DeviceDescription) ProcessIndex

func (dDesc *DeviceDescription) ProcessIndex() int

ProcessIndex returns the index of the process that this device belongs to, i.e. is addressable from. This is not always identical to PJRT_Client_ProcessIndex in a multi-process setting, where each client can see devices from all processes, but only a subset of them are addressable and have the same process_index as the client.

type Event

type Event struct {
	// contains filtered or unexported fields
}

Event is a reference to a future event (when something is done), and it is created by asynchronous calls.

While it is exported here in case someone needs it to implement some extension, users of the Go pjrt package usually don't need to use it directly: the various methods of the API handle the events.

func (*Event) Await

func (e *Event) Await() error

Await blocks the calling thread until `event` is ready, then returns the error, if any.

func (*Event) AwaitAndFree

func (e *Event) AwaitAndFree() error

AwaitAndFree blocks the calling thread until `event` is ready, destroys the event and then returns the error, if any.

An error destroying the event is simply reported in the logs, but not returned.

func (*Event) Destroy

func (e *Event) Destroy() error

Destroy destroys the Event and releases its resources; the Event is no longer valid afterwards. This is automatically called if the Event is garbage collected.

type Executable

type Executable struct {
	// contains filtered or unexported fields
}

Executable is a reference that describes a compiled program -- it cannot be executed, only introspected.

This is usually not directly used: when created, the LoadedExecutable automatically extracts the Executable and related information.

func (*Executable) Destroy

func (e *Executable) Destroy() error

Destroy destroys the Executable and releases its resources; the Executable is no longer valid afterwards. This is automatically called if the Executable is garbage collected.

func (*Executable) GetMemoryStats added in v0.6.1

func (e *Executable) GetMemoryStats() (onDevice, onHost ExecutableMemoryUsageStats, err error)

GetMemoryStats returns the sizes (in bytes) for the compiled code, inputs, outputs, aliases and temporary memory used both in host and on device.

This can be used to estimate memory requirements for the program.

func (*Executable) Name

func (e *Executable) Name() (string, error)

Name returns the name of the executable.

func (*Executable) NumOutputs

func (e *Executable) NumOutputs() (int, error)

NumOutputs returns the number of outputs for the given executable.

type ExecutableMemoryUsageStats added in v0.6.1

type ExecutableMemoryUsageStats struct {
	GeneratedCode, Inputs, Outputs, Aliases, Temporary int64
}

ExecutableMemoryUsageStats reports the static memory usage for a compiled program, in bytes. The on-device memory needed to run an executable is at least: GeneratedCode + Inputs + Outputs - Aliases + Temporary. See ExecutableMemoryUsageStats.Requirements.

Aliases is how much memory of the input is reused as output (?).

The documentation is sparse in XLA, here are the links:

func (ExecutableMemoryUsageStats) Requirements added in v0.6.1

func (m ExecutableMemoryUsageStats) Requirements() int64

Requirements returns an estimate of memory requirements for the executable.
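The lower bound given in the type description (GeneratedCode + Inputs + Outputs - Aliases + Temporary) can be computed as below. The struct is a local mirror of the documented fields so that the example is self-contained; whether Requirements uses exactly this expression is an assumption based on that description, and the numbers are made up.

```go
package main

import "fmt"

// memoryStats mirrors the exported fields of ExecutableMemoryUsageStats
// (local stand-in type, defined here only for the example).
type memoryStats struct {
	GeneratedCode, Inputs, Outputs, Aliases, Temporary int64
}

// lowerBound applies the formula from the type description. Aliased
// bytes are subtracted because they are counted in both inputs and outputs.
func (m memoryStats) lowerBound() int64 {
	return m.GeneratedCode + m.Inputs + m.Outputs - m.Aliases + m.Temporary
}

func main() {
	// Made-up numbers: a 1 MiB input fully reused as the output.
	stats := memoryStats{
		GeneratedCode: 4096,
		Inputs:        1 << 20,
		Outputs:       1 << 20,
		Aliases:       1 << 20,
		Temporary:     2048,
	}
	fmt.Println(stats.lowerBound()) // 1054720
}
```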

type ExecutionConfig

type ExecutionConfig struct {
	// contains filtered or unexported fields
}

ExecutionConfig holds the configuration for executing a LoadedExecutable. It is created with LoadedExecutable.Execute.

After configuring it, call Done to actually trigger the execution.

TODO: add support for multi-device execution, with some inputs shared across devices, and some per-device specific.

func (*ExecutionConfig) Donate

func (c *ExecutionConfig) Donate(inputsIndices ...int) *ExecutionConfig

Donate marks the inputs (referred to by their indices) to be donated.

This can be called more than once for different inputsIndices.

Donated inputs become invalid after the execution. Often donated arguments are also the output of a computation and are updated in place. See discussion in https://jax.readthedocs.io/en/latest/faq.html#buffer-donation

func (*ExecutionConfig) DonateAll

func (c *ExecutionConfig) DonateAll() *ExecutionConfig

DonateAll marks all inputs to be "donated".

Donated inputs become invalid after the execution. Often donated arguments are also the output of a computation and are updated in place. See discussion in https://jax.readthedocs.io/en/latest/faq.html#buffer-donation

func (*ExecutionConfig) DonateNone added in v0.2.1

func (c *ExecutionConfig) DonateNone() *ExecutionConfig

DonateNone marks all inputs as non-donatable. This is the default.

Donated inputs become invalid after the execution. Often donated arguments are also the output of a computation and are updated in place. See discussion in https://jax.readthedocs.io/en/latest/faq.html#buffer-donation

func (*ExecutionConfig) Done

func (c *ExecutionConfig) Done() ([]*Buffer, error)

func (*ExecutionConfig) OnDevices added in v0.2.0

func (c *ExecutionConfig) OnDevices(devices ...*Device) *ExecutionConfig

OnDevices selects which devices to execute on. Usually only 1, but more than one can be configured.

The default is to use the first addressable device. See also OnDevicesByNum.

func (*ExecutionConfig) OnDevicesByNum added in v0.2.0

func (c *ExecutionConfig) OnDevicesByNum(devicesNum ...int) *ExecutionConfig

OnDevicesByNum selects which devices to execute on. The devicesNum point to the devices in the list returned by Client.AddressableDevices. Usually only 1, but more than one can be configured.

The default is to use the first addressable device. See also OnDevices.

func (*ExecutionConfig) SetDonate added in v0.2.1

func (c *ExecutionConfig) SetDonate(donate []bool) *ExecutionConfig

SetDonate sets the donate status of all inputs in one call. The default is that no input is donated.

Donated inputs become invalid after the execution. Often donated arguments are also the output of a computation and are updated in place. See discussion in https://jax.readthedocs.io/en/latest/faq.html#buffer-donation
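Donate takes input indices while SetDonate takes one boolean per input. The sketch below converts between the two forms. It assumes that the slice passed to SetDonate has one entry per input, which the text above suggests but does not state explicitly; donateMask is a hypothetical helper and not part of the package.

```go
package main

import "fmt"

// donateMask builds a per-input boolean mask from a list of input
// indices. Hypothetical helper: numInputs is the number of inputs of
// the executable, indices are the inputs to donate.
func donateMask(numInputs int, indices ...int) ([]bool, error) {
	mask := make([]bool, numInputs)
	for _, i := range indices {
		if i < 0 || i >= numInputs {
			return nil, fmt.Errorf("input index %d out of range [0, %d)", i, numInputs)
		}
		mask[i] = true
	}
	return mask, nil
}

func main() {
	mask, err := donateMask(3, 0, 2)
	fmt.Println(mask, err) // [true false true] <nil>
}
```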

type LoadedExecutable

type LoadedExecutable struct {

	// Name of the executable.
	Name string

	// NumOutputs of the executable.
	NumOutputs int

	// OnDeviceMemoryUsageStats, OnHostMemoryUsageStats can be used to estimate the required memory usage for the executable on device (and on host).
	OnDeviceMemoryUsageStats, OnHostMemoryUsageStats ExecutableMemoryUsageStats
	// contains filtered or unexported fields
}

LoadedExecutable is a reference to a compiled program ready to be executed.

All public attributes are read-only.

func (*LoadedExecutable) Destroy

func (e *LoadedExecutable) Destroy() error

Destroy destroys the LoadedExecutable and releases its resources; the LoadedExecutable is no longer valid afterwards. This is automatically called if the LoadedExecutable is garbage collected.

func (*LoadedExecutable) Execute

func (e *LoadedExecutable) Execute(inputs ...*Buffer) *ExecutionConfig

Execute the compiled computation. It returns an ExecutionConfig for further configuration. Call ExecutionConfig.Done to execute the computation.

It provides good defaults, so in the common case nothing else is needed:

- Using the first addressable device.
- All input buffers marked as not-donated (see discussion in https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).

See ExecutionConfig for more details and options.

Example:

outputBuffers, err := loadedExec.Execute(inputBuffer).Done()

type NamedValuesMap

type NamedValuesMap map[string]any

NamedValuesMap maps names to any of the supported named values types defined by PJRT_NamedValue_Type.

type PJRT_Buffer_MemoryLayout_Type

type PJRT_Buffer_MemoryLayout_Type int

PJRT_Buffer_MemoryLayout_Type is a mapping of the corresponding C enum defined in pjrt_c_api.h.

const (
	// PJRT_Buffer_MemoryLayout_Type_Tiled is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Buffer_MemoryLayout_Type_Tiled PJRT_Buffer_MemoryLayout_Type = 0

	// PJRT_Buffer_MemoryLayout_Type_Strides is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Buffer_MemoryLayout_Type_Strides PJRT_Buffer_MemoryLayout_Type = 1
)

type PJRT_Error_Code

type PJRT_Error_Code int

PJRT_Error_Code is a mapping of the corresponding C enum defined in pjrt_c_api.h.

const (
	// PJRT_Error_Code_CANCELLED is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_CANCELLED PJRT_Error_Code = 1

	// PJRT_Error_Code_UNKNOWN is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_UNKNOWN PJRT_Error_Code = 2

	// PJRT_Error_Code_INVALID_ARGUMENT is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_INVALID_ARGUMENT PJRT_Error_Code = 3

	// PJRT_Error_Code_DEADLINE_EXCEEDED is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_DEADLINE_EXCEEDED PJRT_Error_Code = 4

	// PJRT_Error_Code_NOT_FOUND is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_NOT_FOUND PJRT_Error_Code = 5

	// PJRT_Error_Code_ALREADY_EXISTS is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_ALREADY_EXISTS PJRT_Error_Code = 6

	// PJRT_Error_Code_PERMISSION_DENIED is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_PERMISSION_DENIED PJRT_Error_Code = 7

	// PJRT_Error_Code_RESOURCE_EXHAUSTED is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_RESOURCE_EXHAUSTED PJRT_Error_Code = 8

	// PJRT_Error_Code_FAILED_PRECONDITION is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_FAILED_PRECONDITION PJRT_Error_Code = 9

	// PJRT_Error_Code_ABORTED is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_ABORTED PJRT_Error_Code = 10

	// PJRT_Error_Code_OUT_OF_RANGE is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_OUT_OF_RANGE PJRT_Error_Code = 11

	// PJRT_Error_Code_UNIMPLEMENTED is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_UNIMPLEMENTED PJRT_Error_Code = 12

	// PJRT_Error_Code_INTERNAL is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_INTERNAL PJRT_Error_Code = 13

	// PJRT_Error_Code_UNAVAILABLE is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_UNAVAILABLE PJRT_Error_Code = 14

	// PJRT_Error_Code_DATA_LOSS is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_DATA_LOSS PJRT_Error_Code = 15

	// PJRT_Error_Code_UNAUTHENTICATED is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Error_Code_UNAUTHENTICATED PJRT_Error_Code = 16
)
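When logging a failure it is often handy to print the symbolic name of a code rather than its number. The following is a minimal sketch of such a helper. It uses a local stand-in type (`errorCode`) with the values copied from the list above, so that it compiles without cgo or a PJRT plugin; the type, the `codeNames` table and the `String` method are illustrative assumptions and not part of this package.

```go
package main

import "fmt"

// errorCode is a hypothetical local stand-in for pjrt.PJRT_Error_Code.
type errorCode int

// Values copied from the PJRT_Error_Code constants listed above.
const (
	codeCancelled errorCode = 1
	codeNotFound  errorCode = 5
	codeInternal  errorCode = 13
)

// codeNames maps each numeric code to the suffix of its constant name.
var codeNames = map[errorCode]string{
	1: "CANCELLED", 2: "UNKNOWN", 3: "INVALID_ARGUMENT",
	4: "DEADLINE_EXCEEDED", 5: "NOT_FOUND", 6: "ALREADY_EXISTS",
	7: "PERMISSION_DENIED", 8: "RESOURCE_EXHAUSTED",
	9: "FAILED_PRECONDITION", 10: "ABORTED", 11: "OUT_OF_RANGE",
	12: "UNIMPLEMENTED", 13: "INTERNAL", 14: "UNAVAILABLE",
	15: "DATA_LOSS", 16: "UNAUTHENTICATED",
}

// String returns the symbolic name of the code, or a numeric fallback
// for values that are not in the list above.
func (c errorCode) String() string {
	if name, ok := codeNames[c]; ok {
		return name
	}
	return fmt.Sprintf("PJRT_Error_Code(%d)", int(c))
}

func main() {
	fmt.Println(codeCancelled, codeNotFound, codeInternal, errorCode(0))
}
```

Note that the list above starts at 1, so the sketch treats 0 like any other unlisted value.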

type PJRT_Extension_Type

type PJRT_Extension_Type int

PJRT_Extension_Type is a mapping of the corresponding C enum defined in pjrt_c_api.h.

const (
	// PJRT_Extension_Type_Gpu_Custom_Call is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_Gpu_Custom_Call PJRT_Extension_Type = 0

	// PJRT_Extension_Type_Profiler is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_Profiler PJRT_Extension_Type = 1

	// PJRT_Extension_Type_Custom_Partitioner is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_Custom_Partitioner PJRT_Extension_Type = 2

	// PJRT_Extension_Type_Stream is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_Stream PJRT_Extension_Type = 3

	// PJRT_Extension_Type_Layouts is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_Layouts PJRT_Extension_Type = 4

	// PJRT_Extension_Type_FFI is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_FFI PJRT_Extension_Type = 5

	// PJRT_Extension_Type_MemoryDescriptions is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_MemoryDescriptions PJRT_Extension_Type = 6

	// PJRT_Extension_Type_Triton is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_Extension_Type_Triton PJRT_Extension_Type = 7
)

type PJRT_HostBufferSemantics

type PJRT_HostBufferSemantics int

PJRT_HostBufferSemantics is a mapping of the corresponding C enum defined in pjrt_c_api.h.

const (
	// PJRT_HostBufferSemantics_kImmutableOnlyDuringCall is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	// The runtime may not hold references to `data` after the call to
	// `PJRT_Client_BufferFromHostBuffer` completes. The caller promises that
	// `data` is immutable and will not be freed only for the duration of the
	// PJRT_Client_BufferFromHostBuffer call.
	PJRT_HostBufferSemantics_kImmutableOnlyDuringCall PJRT_HostBufferSemantics = 0

	// PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	// The runtime may hold onto `data` after the call to
	// `PJRT_Client_BufferFromHostBuffer`
	// returns while the runtime completes a transfer to the device. The caller
	// promises not to mutate or free `data` until the transfer completes, at
	// which point `done_with_host_buffer` will be triggered.
	PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes PJRT_HostBufferSemantics = 1

	// PJRT_HostBufferSemantics_kImmutableZeroCopy is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	// The PjRtBuffer may alias `data` internally and the runtime may use the
	// `data` contents as long as the buffer is alive. The runtime promises not
	// to mutate contents of the buffer (i.e. it will not use it for aliased
	// output buffers). The caller promises to keep `data` alive and not to mutate
	// its contents as long as the buffer is alive; to notify the caller that the
	// buffer may be freed, the runtime will call `done_with_host_buffer` when the
	// PjRtBuffer is freed.
	PJRT_HostBufferSemantics_kImmutableZeroCopy PJRT_HostBufferSemantics = 2

	// PJRT_HostBufferSemantics_kMutableZeroCopy is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	// The PjRtBuffer may alias `data` internally and the runtime may use the
	// `data` contents as long as the buffer is alive. The runtime is allowed
	// to mutate contents of the buffer (i.e. use it for aliased output
	// buffers). The caller promises to keep `data` alive and not to mutate its
	// contents as long as the buffer is alive (otherwise it could be a data
	// race with the runtime); to notify the caller that the buffer may be
	// freed, the runtime will call `on_done_with_host_buffer` when the
	// PjRtBuffer is freed. On non-CPU platforms this acts identically to
	// kImmutableUntilTransferCompletes.
	PJRT_HostBufferSemantics_kMutableZeroCopy PJRT_HostBufferSemantics = 3
)
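The practical difference between these four values is when the caller gets `data` back and whether the runtime may write to it. The sketch below restates the comments above as two predicates. It uses a local stand-in type so that it compiles without cgo; the type and both function names are hypothetical and not part of this package.

```go
package main

import "fmt"

// hostBufferSemantics is a hypothetical local stand-in for
// pjrt.PJRT_HostBufferSemantics, with the values listed above.
type hostBufferSemantics int

const (
	immutableOnlyDuringCall         hostBufferSemantics = 0
	immutableUntilTransferCompletes hostBufferSemantics = 1
	immutableZeroCopy               hostBufferSemantics = 2
	mutableZeroCopy                 hostBufferSemantics = 3
)

// reusableAfterCall reports whether the caller may mutate or free `data`
// as soon as the PJRT_Client_BufferFromHostBuffer call returns. For every
// other value the caller must wait for the done-with-host-buffer callback.
func reusableAfterCall(s hostBufferSemantics) bool {
	return s == immutableOnlyDuringCall
}

// runtimeMayMutate reports whether the runtime is allowed to write to
// `data`. Per the comment above, this applies to CPU platforms; elsewhere
// kMutableZeroCopy acts like kImmutableUntilTransferCompletes.
func runtimeMayMutate(s hostBufferSemantics) bool {
	return s == mutableZeroCopy
}

func main() {
	for s := immutableOnlyDuringCall; s <= mutableZeroCopy; s++ {
		fmt.Printf("semantics=%d reusableAfterCall=%t runtimeMayMutate=%t\n",
			int(s), reusableAfterCall(s), runtimeMayMutate(s))
	}
}
```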

type PJRT_NamedValue_Type

type PJRT_NamedValue_Type int

PJRT_NamedValue_Type is a mapping of the corresponding C enum defined in pjrt_c_api.h.

const (
	// PJRT_NamedValue_kString is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_NamedValue_kString PJRT_NamedValue_Type = 0

	// PJRT_NamedValue_kInt64 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_NamedValue_kInt64 PJRT_NamedValue_Type = 1

	// PJRT_NamedValue_kInt64List is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_NamedValue_kInt64List PJRT_NamedValue_Type = 2

	// PJRT_NamedValue_kFloat is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_NamedValue_kFloat PJRT_NamedValue_Type = 3

	// PJRT_NamedValue_kBool is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h.
	PJRT_NamedValue_kBool PJRT_NamedValue_Type = 4
)
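Since a NamedValuesMap holds values of several kinds, code that fills one has to know which Go value corresponds to which enum entry. The sketch below shows one plausible classification using a type switch. The Go types chosen here (`string`, `int64`, `[]int64`, `float32`, `bool`) are an assumption based on the C types in pjrt_c_api.h, and the local type and function names are hypothetical; check the package source for the types it actually accepts.

```go
package main

import "fmt"

// namedValueType is a hypothetical local stand-in for
// pjrt.PJRT_NamedValue_Type, with the values listed above.
type namedValueType int

const (
	kString    namedValueType = 0
	kInt64     namedValueType = 1
	kInt64List namedValueType = 2
	kFloat     namedValueType = 3
	kBool      namedValueType = 4
)

// namedValueTypeOf classifies a Go value into one of the named value kinds.
// The mapping from Go types is an assumption made for this sketch.
func namedValueTypeOf(v any) (namedValueType, error) {
	switch v.(type) {
	case string:
		return kString, nil
	case int64:
		return kInt64, nil
	case []int64:
		return kInt64List, nil
	case float32:
		return kFloat, nil
	case bool:
		return kBool, nil
	default:
		return 0, fmt.Errorf("unsupported named value type %T", v)
	}
}

func main() {
	for _, v := range []any{"cpu", int64(2), []int64{0, 1}, float32(0.5), true, 3.14} {
		t, err := namedValueTypeOf(v)
		fmt.Println(t, err)
	}
}
```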

type Plugin

type Plugin struct {

	// UseStableHLO configures the plugin clients to convert XlaBuilder programs from "HLO" to "StableHLO"
	// before compilation. The "StableHLO" (encoded as MLIR) is the more recent
	// "intermediary representation" program language.
	//
	// Setting it to true incurs a conversion step from "HLO" to "StableHLO" during compilation, but some
	// PJRT plugins only support "StableHLO" (namely the Apple Metal PJRT).
	//
	// The default is true, but it can be changed by setting the environment variable "GOPJRT_NO_STABLE_HLO=1".
	//
	// Most people don't need to worry about this; it should be an implementation detail.
	UseStableHLO bool
	// contains filtered or unexported fields
}

Plugin represents a loaded PJRT plugin that can be used to compile and execute StableHLO code.

Loaded plugins are singletons per platform and cached (GetPlugin will return a pointer to the same plugin if called with the same platform or its aliases).

Plugins are searched in the PJRT_PLUGIN_LIBRARY_PATH directory -- or directories, if it is a ":" separated list.

Design document: https://docs.google.com/document/d/1Qdptisz1tUPGn1qFAVgCV2omnfjN01zoQPwKLdlizas/edit

func GetPlugin

func GetPlugin(name string) (*Plugin, error)

GetPlugin returns the plugin with the given name. Typically the name reflects the platform, e.g. "cpu" or "gpu", but one can also give the full path to the `.so` file with the plugin.

Loaded plugins are singletons and cached (GetPlugin will return a pointer to the same plugin if called with the same name or its aliases).

Plugins are searched in the PJRT_PLUGIN_LIBRARY_PATH directory -- or directories, if it is a ":" separated list. If it is not set, the search covers `/usr/local/lib/gomlx` and the standard library directories of the system (on Linux, those in LD_LIBRARY_PATH and in the /etc/ld.so.conf file).
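The directory lookup described above can be sketched as follows. This is only an illustration of splitting a ":" separated value and falling back to a default list, not the package's actual implementation; the function name `pluginSearchDirs` is hypothetical, and the system library directories are left out of the fallback for brevity.

```go
package main

import (
	"fmt"
	"os"
	"strings"
)

// pluginSearchDirs returns the directories to search, in order.
// If envValue (the content of PJRT_PLUGIN_LIBRARY_PATH) is empty, the
// fallback list is returned unchanged. Empty entries such as those
// produced by "a::b" are skipped.
func pluginSearchDirs(envValue string, fallback []string) []string {
	if envValue == "" {
		return fallback
	}
	var dirs []string
	for _, dir := range strings.Split(envValue, ":") {
		if dir != "" {
			dirs = append(dirs, dir)
		}
	}
	return dirs
}

func main() {
	// The fallback here only names the directory mentioned in the text above.
	fallback := []string{"/usr/local/lib/gomlx"}
	dirs := pluginSearchDirs(os.Getenv("PJRT_PLUGIN_LIBRARY_PATH"), fallback)
	for i, dir := range dirs {
		fmt.Printf("%d: %s\n", i, dir)
	}
}
```

Keeping the order of the split entries matters because, when the same plugin name appears in several directories, the earlier directory takes precedence.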

func (*Plugin) Attributes

func (p *Plugin) Attributes() NamedValuesMap

Attributes returns a NamedValuesMap with the attributes returned by the plugin at the time of its initialization.

func (*Plugin) Name

func (p *Plugin) Name() string

Name returns the name of the plugin, which usually reflects its platform (cpu, gpu, tpu, etc.).

func (*Plugin) NewClient

func (p *Plugin) NewClient(options NamedValuesMap) (*Client, error)

NewClient creates a new Client object to manage available devices. The options (which can be left nil) are plugin-specific, and should be (but often aren't) documented by the plugins.

func (*Plugin) Path

func (p *Plugin) Path() string

Path returns the path from where the plugin was loaded.

func (*Plugin) String

func (p *Plugin) String() string

String implements fmt.Stringer. It returns the platform and version of the plugin.

func (*Plugin) Version

func (p *Plugin) Version() (major, minor int)

Version returns the version reported by the loaded plugin.

type XlaComputation

type XlaComputation interface {
	// SerializedHLO exports the computation as a serialized `HloModuleProto`.
	SerializedHLO() *cbuffer.CBuffer

	// HasStableHLO returns whether StableHLO is supported.
	HasStableHLO() bool

	// SerializedStableHLO exports the computation as StableHLO, encoded as an `mlir::ModuleOp`.
	SerializedStableHLO() (*cbuffer.CBuffer, error)
}

XlaComputation is an interface that matches the xlabuilder.XlaComputation methods needed by PJRT.

It is defined here to avoid creating a hard dependency on the xlabuilder package.

Ownership of the returned buffers is transferred to the caller, who must free them later.

Directories

Path Synopsis
cpu
dynamic
Package dynamic will link (preload) a dynamically loaded library `libpjrt_c_api_cpu_dynamic`, that is used if the user requests a "cpu" plugin.
static
Package static statically links a CPU PJRT plugin, and registers with the name "cpu".
internal
cpudynamictest
Package cpudynamictest is just a hack around Go's limitation to use CGO in tests and to avoid cyclic dependency.
cpustatictest
Package cpustatictest is just a hack around Go's limitation to use CGO in tests and to avoid cyclic dependency.
