Documentation ¶
Overview ¶
Package shapes defines Shape and DType and associated tools.
Shape represents the shape (rank, dimensions and DType) of either a Tensor or the expected shape of a node in a computation Graph. DType indicates the type of the unit element of a Tensor (or its representation as a node in a computation Graph).
Shape and DType are used both by the concrete tensor values (see tensor package) and when working on the computation graph (see graph package).
## Glossary
- Rank: number of axes (dimensions) of a Tensor.
- Axis: the index of a dimension in a multi-dimensional Tensor. Sometimes used interchangeably with Dimension, but here we try to refer to a dimension index as "axis" (plural axes), and its size as its dimension.
- Dimension: the size of a multi-dimensional Tensor along one of its axes. See the example below.
- DType: the data type of the unit element in a tensor.
- Scalar: a shape with no axes (or dimensions), holding only a single value of the associated DType.
Example: The multi-dimensional array `[][]int32{{0, 1, 2}, {3, 4, 5}}`, if converted to a Tensor, would have shape `(int32)[2 3]`. We say it has rank 2 (so 2 axes), axis 0 has dimension 2, and axis 1 has dimension 3. This shape could be created with `shapes.Make(shapes.Int32, 2, 3)`.
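To make the glossary terms concrete, the following is a minimal stand-alone sketch that derives the rank and dimensions of a nested Go slice using only the standard library. The helper `dimsOf` is hypothetical and is not part of package shapes; it assumes the slice is regular (all rows have the same length).

```go
package main

import (
	"fmt"
	"reflect"
)

// dimsOf walks a regular nested slice and returns the size of each axis.
// Hypothetical helper for illustration only; not part of package shapes.
func dimsOf(value any) []int {
	var dims []int
	v := reflect.ValueOf(value)
	for v.Kind() == reflect.Slice {
		dims = append(dims, v.Len())
		if v.Len() == 0 {
			break
		}
		v = v.Index(0) // Descend into the first element of this axis.
	}
	return dims
}

func main() {
	dims := dimsOf([][]int32{{0, 1, 2}, {3, 4, 5}})
	fmt.Println("rank:", len(dims))   // rank: 2
	fmt.Println("dimensions:", dims) // dimensions: [2 3]
}
```

A scalar such as `int32(7)` yields no dimensions at all, which corresponds to rank 0.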
## Asserts
When coding ML models, one delicate part is keeping tabs on the shapes of the nodes of the graph. Unfortunately there is no compile-time checking of shapes, so validation only happens at runtime. To help with this, and also to serve as code documentation, this package provides two variations of _assert_ functionality. Examples:
`AssertRank` and `AssertDims` check that the rank and dimensions of the given object (anything with a `Shape` method) match the expected values, and panic otherwise. A dimension of `-1` is left unchecked (it can be anything).
```
func modelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node {
	_ = spec // Not needed here, we know the dataset.
	shapes.AssertRank(inputs[0], 2)
	batchSize := inputs[0].Shape().Dimensions[0]
	logits := layers.Dense(ctx, inputs[0], /* useBias= */ true, /* outputDim= */ 1)
	shapes.AssertDims(logits, batchSize, -1)
	return []*Node{logits}
}
```
If you don't want to panic, but instead return an error through the `graph.Graph`, you can use the `Node.AssertDims()` method. It would look like `logits.AssertDims(batchSize, -1)`.
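The two flavors described above (panic versus returned error) can be sketched independently of the package. The functions `checkDims` and `assertDims` below are local stand-ins written for this illustration; they operate on a plain `[]int` rather than on a `HasShape`, and only approximate the documented behavior.

```go
package main

import "fmt"

// uncheckedAxis mirrors the documented UncheckedAxis constant (-1).
const uncheckedAxis = -1

// checkDims is a stand-in for the documented CheckDims behavior: it returns
// an error if the rank differs or if any checked dimension doesn't match.
func checkDims(got []int, want ...int) error {
	if len(got) != len(want) {
		return fmt.Errorf("rank mismatch: got %d, want %d", len(got), len(want))
	}
	for axis, dim := range want {
		if dim != uncheckedAxis && dim != got[axis] {
			return fmt.Errorf("axis %d: got dimension %d, want %d", axis, got[axis], dim)
		}
	}
	return nil
}

// assertDims is the panicking variant, built on top of checkDims.
func assertDims(got []int, want ...int) {
	if err := checkDims(got, want...); err != nil {
		panic(err)
	}
}

func main() {
	logitsDims := []int{32, 1}
	assertDims(logitsDims, 32, uncheckedAxis)         // Passes: axis 1 is not checked.
	fmt.Println(checkDims(logitsDims, 32, 10) != nil) // true (dimension mismatch)
	fmt.Println(checkDims(logitsDims, 32) != nil)     // true (rank mismatch)
}
```

Building the panicking variant on top of the error-returning one keeps the matching logic in a single place.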
Index ¶
- Constants
- func AssertDims(shaped HasShape, dimensions ...int)
- func AssertRank(shaped HasShape, rank int)
- func AssertScalar(shaped HasShape)
- func AssertShape(dtype DType, dimensions ...int)
- func CastAsDType(value any, dtype DType) any
- func CheckDims(shaped HasShape, dimensions ...int) error
- func CheckRank(shaped HasShape, rank int) error
- func CheckScalar(shaped HasShape) error
- func ConvertTo[T NumberNotComplex](value any) T
- func DTypeStrings() []string
- func LowestValueForDType(dtype DType) any
- func SmallestNonZeroValueForDType(dtype DType) any
- func TypeForDType(dtype DType) reflect.Type
- func UnsafeSliceForDType(dtype DType, unsafePtr unsafe.Pointer, len int) any
- type DType
- func (dtype DType) GoStr() string
- func (i DType) IsADType() bool
- func (dtype DType) IsComplex() bool
- func (dtype DType) IsFloat() bool
- func (dtype DType) IsInt() bool
- func (dtype DType) IsSupported() bool
- func (i DType) MarshalJSON() ([]byte, error)
- func (i DType) MarshalText() ([]byte, error)
- func (i DType) MarshalYAML() (interface{}, error)
- func (dtype DType) Memory() int64
- func (dtype DType) RealDType() DType
- func (i DType) String() string
- func (dtype DType) Type() reflect.Type
- func (i *DType) UnmarshalJSON(data []byte) error
- func (i *DType) UnmarshalText(text []byte) error
- func (i *DType) UnmarshalYAML(unmarshal func(interface{}) error) error
- func (DType) Values() []string
- type GoFloat
- type HasShape
- type MultiDimensionSlice
- type Number
- type NumberNotComplex
- type Shape
- func (s Shape) Assert(dtype DType, dimensions ...int)
- func (s Shape) AssertDims(dimensions ...int)
- func (s Shape) AssertRank(rank int)
- func (s Shape) AssertScalar()
- func (s Shape) Check(dtype DType, dimensions ...int) error
- func (s Shape) CheckDims(dimensions ...int) error
- func (s Shape) CheckRank(rank int) error
- func (s Shape) CheckScalar() error
- func (s Shape) Copy() (s2 Shape)
- func (s Shape) Eq(s2 Shape) bool
- func (s Shape) EqDimensions(s2 Shape) bool
- func (s Shape) GobSerialize(encoder *gob.Encoder) (err error)
- func (s Shape) IsScalar() bool
- func (s Shape) IsTuple() bool
- func (s Shape) Memory() int64
- func (s Shape) Ok() bool
- func (s Shape) Rank() int
- func (s Shape) Shape() Shape
- func (s Shape) Size() (size int)
- func (s Shape) String() string
- func (s Shape) TupleSize() int
- type Supported
Constants ¶
const (
	U8   = UInt8
	U32  = UInt32
	U64  = UInt64
	I32  = Int32
	I64  = Int64
	F16  = Float16
	F32  = Float32
	F64  = Float64
	C64  = Complex64
	C128 = Complex128
)
const PRED = Bool
PRED type is an alias to Bool, used in `tensorflow/compiler/xla/xla_data.proto`.
const UncheckedAxis = int(-1)
UncheckedAxis can be used in CheckDims or AssertDims functions for an axis whose dimension doesn't matter.
Variables ¶
This section is empty.
Functions ¶
func AssertDims ¶
AssertDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.
It panics if it doesn't match.
See usage example in package shapes documentation.
func AssertRank ¶
AssertRank checks that the shape has the given rank.
It panics if it doesn't match.
See usage example in package shapes documentation.
func AssertScalar ¶
func AssertScalar(shaped HasShape)
AssertScalar checks that the shape is a scalar.
It panics if it doesn't match.
See usage example in package shapes documentation.
func AssertShape ¶ added in v0.9.0
func CastAsDType ¶
CastAsDType casts a numeric value to the Go type corresponding to the DType. If the value is a slice, it is converted to a newly allocated slice of the given DType.
It doesn't work for complex numbers.
func CheckDims ¶
CheckDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.
It returns an error if the rank is different or if any of the dimensions don't match.
func CheckRank ¶
CheckRank checks that the shape has the given rank.
It returns an error if the rank is different.
func CheckScalar ¶
CheckScalar checks that the shape is a scalar.
It returns an error if shape is not a scalar.
func ConvertTo ¶ added in v0.2.1
func ConvertTo[T NumberNotComplex](value any) T
ConvertTo converts any scalar (typically returned by `tensor.Local.Value()`) of the supported dtypes to `T`. It returns 0 if value is not a scalar or not a supported number (e.g. bool). It doesn't work if T (the output type) is a complex number. If value is a complex number, it converts by taking the real part of the number and discarding the imaginary part.
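The conversion rules described for ConvertTo can be sketched with a generic function and a type switch. The function `convertTo` and the constraint `numberNotComplex` below are local stand-ins for illustration; the actual implementation in the package may differ.

```go
package main

import "fmt"

// numberNotComplex mirrors the documented NumberNotComplex constraint.
type numberNotComplex interface {
	float32 | float64 | int | int32 | int64 | uint8 | uint32 | uint64
}

// convertTo is a stand-in for the documented ConvertTo behavior: supported
// scalars are converted to T, complex values keep only their real part, and
// anything else yields 0.
func convertTo[T numberNotComplex](value any) T {
	switch v := value.(type) {
	case float32:
		return T(v)
	case float64:
		return T(v)
	case int:
		return T(v)
	case int32:
		return T(v)
	case int64:
		return T(v)
	case uint8:
		return T(v)
	case uint32:
		return T(v)
	case uint64:
		return T(v)
	case complex64:
		return T(real(v))
	case complex128:
		return T(real(v))
	}
	return 0 // Not a supported number (e.g. bool or a slice).
}

func main() {
	fmt.Println(convertTo[int](float32(3.7)))       // 3
	fmt.Println(convertTo[float64](complex(1.5, 2))) // 1.5
	fmt.Println(convertTo[int32](true))             // 0
}
```

Note that float-to-integer conversion in Go truncates toward zero, which is why `3.7` becomes `3` in this sketch.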
func DTypeStrings ¶ added in v0.4.0
func DTypeStrings() []string
DTypeStrings returns a slice of all String values of the enum.
func LowestValueForDType ¶
LowestValueForDType returns the lowest value for the given dtype, converted to the corresponding Go type. For float values it returns negative infinity. There is no lowest value for complex numbers, since they are not ordered.
func SmallestNonZeroValueForDType ¶ added in v0.4.1
SmallestNonZeroValueForDType returns the smallest non-zero value for the given dtype. Only useful for float types. The return value is converted to the corresponding Go type. There is no smallest non-zero value for complex numbers, since they are not ordered.
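As a rough illustration of what these two functions return, the sketch below uses the standard `math` package. The helpers `lowestValueLike` and `smallestNonZeroLike` are hypothetical: they take a sample Go value instead of a DType, and the integer cases assume that "lowest value" means the minimum representable integer, which this document does not state explicitly.

```go
package main

import (
	"fmt"
	"math"
)

// lowestValueLike returns the lowest value of the same Go type as sample.
// Hypothetical helper; floats map to negative infinity as documented.
func lowestValueLike(sample any) any {
	switch sample.(type) {
	case float32:
		return float32(math.Inf(-1))
	case float64:
		return math.Inf(-1)
	case int32:
		return int32(math.MinInt32) // Assumption: minimum representable value.
	case int64:
		return int64(math.MinInt64) // Assumption: minimum representable value.
	case uint8:
		return uint8(0)
	}
	return nil // Unsupported type, e.g. complex numbers (not ordered).
}

// smallestNonZeroLike returns the smallest positive non-zero float value.
func smallestNonZeroLike(sample any) any {
	switch sample.(type) {
	case float32:
		return float32(math.SmallestNonzeroFloat32)
	case float64:
		return float64(math.SmallestNonzeroFloat64)
	}
	return nil // Only meaningful for float types.
}

func main() {
	fmt.Println(lowestValueLike(float32(0)))     // -Inf
	fmt.Println(lowestValueLike(int32(0)))       // -2147483648
	fmt.Println(smallestNonZeroLike(float64(0))) // 5e-324
}
```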
func TypeForDType ¶
TypeForDType returns the Go `reflect.Type` corresponding to the tensor DType.
Types ¶
type DType ¶
type DType int32
DType indicates the type of the unit element of a Tensor (or its representation in a computation graph). It enumerates the known data types. So far only Bool, UInt8 (U8), Int32 (I32), Int64 (I64), UInt64 (U64), Float32 (F32) and Float64 (F64) are supported.
The values of DType must match "tensorflow/compiler/xla/xla_data.pb.h", hence it needs to be an int32.
See example in package shapes documentation.
const (
	InvalidDType DType = iota
	Bool    // Bool, but also known as PRED in `xla_data.proto`.
	Int8    // S8
	Int16   // S16
	Int32   // S32
	Int64   // S64
	UInt8   // U8
	UInt16  // U16
	UInt32  // U32
	UInt64  // U64
	Float16 // F16
	Float32 // F32
	Float64 // F64

	BFloat16   DType = 16 // BF16
	Complex64  DType = 15 // C64
	Complex128 DType = 18 // C128

	Tuple      DType = 13
	OpaqueType DType = 14
	Token      DType = 17
)
DType constants must match `tensorflow/compiler/xla/xla_data.proto`.
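Because the first constants rely on `iota` while the later ones are assigned explicitly, the numeric value of each entry is not obvious at a glance. The sketch below reproduces part of the listing with a local stand-in type (`dtype`, with lower-case names chosen for this example) to show the values that result.

```go
package main

import "fmt"

// dtype is a local stand-in for the package's DType (an int32 enum).
type dtype int32

const (
	invalidDType dtype = iota // 0
	boolType                  // 1
	int8Type                  // 2
	int16Type                 // 3
	int32Type                 // 4
	int64Type                 // 5
	uint8Type                 // 6
	uint16Type                // 7
	uint32Type                // 8
	uint64Type                // 9
	float16Type               // 10
	float32Type               // 11
	float64Type               // 12

	// Explicit values, as in the listing above.
	bfloat16Type   dtype = 16
	complex64Type  dtype = 15
	complex128Type dtype = 18
)

func main() {
	fmt.Println(int32Type, float32Type, float64Type, bfloat16Type) // 4 11 12 16
}
```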
func DTypeForType ¶
func DTypeGeneric ¶
func DTypeString ¶ added in v0.4.0
DTypeString retrieves an enum value from the string name of an enum constant. It returns an error if the parameter is not part of the enum.
func DTypeValues ¶ added in v0.4.0
func DTypeValues() []DType
DTypeValues returns all values of the enum.
func (DType) GoStr ¶ added in v0.7.2
GoStr converts dtype to the corresponding Go type and converts that to a string. Notice the names are different from the DType (so `Int64` dtype is simply `int` in Go).
func (DType) IsADType ¶ added in v0.4.0
IsADType returns "true" if the value is listed in the enum definition, "false" otherwise.
func (DType) IsComplex ¶ added in v0.6.0
IsComplex returns whether dtype is a supported complex number type.
func (DType) IsFloat ¶
IsFloat returns whether dtype is a supported float -- float types not yet supported will return false. It returns false for complex numbers.
func (DType) IsInt ¶
IsInt returns whether dtype is a supported integer type -- integer types not yet supported will return false.
func (DType) IsSupported ¶
func (DType) MarshalJSON ¶ added in v0.4.0
MarshalJSON implements the json.Marshaler interface for DType
func (DType) MarshalText ¶ added in v0.4.0
MarshalText implements the encoding.TextMarshaler interface for DType
func (DType) MarshalYAML ¶ added in v0.4.0
MarshalYAML implements a YAML Marshaler for DType
func (DType) Memory ¶ added in v0.2.0
Memory returns the number of bytes in Go used to store the given DType.
func (DType) RealDType ¶ added in v0.6.0
RealDType returns the real component of complex dtypes. For float dtypes, it returns itself.
It returns InvalidDType for other non-(complex or float) dtypes.
func (DType) Type ¶ added in v0.2.0
Type returns the Go `reflect.Type` corresponding to the tensor DType.
func (*DType) UnmarshalJSON ¶ added in v0.4.0
UnmarshalJSON implements the json.Unmarshaler interface for DType
func (*DType) UnmarshalText ¶ added in v0.4.0
UnmarshalText implements the encoding.TextUnmarshaler interface for DType
func (*DType) UnmarshalYAML ¶ added in v0.4.0
UnmarshalYAML implements a YAML Unmarshaler for DType
type HasShape ¶
type HasShape interface {
Shape() Shape
}
HasShape is an interface for objects that have an associated Shape. `tensor.Tensor` (concrete tensor) and `graph.Node` (tensor representations in a computation graph), `context.Variable` and Shape itself implement the interface.
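The pattern where Shape itself satisfies HasShape means a single helper can accept tensors, nodes, variables, or bare shapes. The sketch below illustrates this with local stand-in types (`shape`, `hasShape`, `node`); none of these are the package's real definitions.

```go
package main

import "fmt"

// shape is a simplified stand-in for shapes.Shape.
type shape struct {
	dims []int
}

// Shape returns the shape itself, so shape satisfies hasShape.
func (s shape) Shape() shape { return s }

// Rank is the number of axes.
func (s shape) Rank() int { return len(s.dims) }

// hasShape mirrors the documented HasShape interface.
type hasShape interface {
	Shape() shape
}

// node is a hypothetical stand-in for a graph node carrying an output shape.
type node struct {
	out shape
}

func (n node) Shape() shape { return n.out }

// rankOf works for anything that has a shape, including a shape itself.
func rankOf(x hasShape) int { return x.Shape().Rank() }

func main() {
	fmt.Println(rankOf(shape{dims: []int{2, 3}}))         // 2
	fmt.Println(rankOf(node{out: shape{dims: []int{5}}})) // 1
}
```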
type MultiDimensionSlice ¶
type MultiDimensionSlice interface {
	bool | float32 | float64 | int | int32 | int64 | uint8 | uint32 | uint64 | complex64 | complex128 |
		[]bool | []float32 | []float64 | []int | []int32 | []int64 | []uint8 | []uint32 | []uint64 | []complex64 | []complex128 |
		[][]bool | [][]float32 | [][]float64 | [][]int | [][]int32 | [][]int64 | [][]uint8 | [][]uint32 | [][]uint64 | [][]complex64 | [][]complex128 |
		[][][]bool | [][][]float32 | [][][]float64 | [][][]int | [][][]int32 | [][][]int64 | [][][]uint8 | [][][]uint32 | [][][]uint64 | [][][]complex64 | [][][]complex128 |
		[][][][]bool | [][][][]float32 | [][][][]float64 | [][][][]int | [][][][]int32 | [][][][]int64 | [][][][]uint8 | [][][][]uint32 | [][][][]uint64 | [][][][]complex64 | [][][][]complex128 |
		[][][][][]bool | [][][][][]float32 | [][][][][]float64 | [][][][][]int | [][][][][]int32 | [][][][][]int64 | [][][][][]uint8 | [][][][][]uint32 | [][][][][]uint64 | [][][][][]complex64 | [][][][][]complex128 |
		[][][][][][]bool | [][][][][][]float32 | [][][][][][]float64 | [][][][][][]int | [][][][][][]int32 | [][][][][][]int64 | [][][][][][]uint8 | [][][][][][]uint32 | [][][][][][]uint64 | [][][][][][]complex64 | [][][][][][]complex128
}
MultiDimensionSlice lists the Go types a Tensor can be converted to/from. Go generics do not allow recursive constraint definitions, so we enumerate up to 7 levels of slices. Feel free to add more if needed; the implementation works with an arbitrary number of levels.
type Number ¶
type Number interface { float32 | float64 | int | int32 | int64 | uint8 | uint32 | uint64 | complex64 | complex128 }
Number represents the Go numeric types that are supported by graph package. Used as a Generics constraint. Notice that "int" becomes int64 in the implementation. Since it needs a 1:1 mapping, it doesn't support the native (Go) int64 type. It includes complex numbers.
type NumberNotComplex ¶ added in v0.6.0
type NumberNotComplex interface { float32 | float64 | int | int32 | int64 | uint8 | uint32 | uint64 }
NumberNotComplex represents the Go numeric types that are supported by graph package except the complex numbers. Used as a Generics constraint. See Number for details.
type Shape ¶
type Shape struct {
	DType       DType
	Dimensions  []int
	TupleShapes []Shape // Shapes of the tuple, if this is a tuple.
}
Shape represents the shape of either a Tensor or the expected shape of the value from a computation node.
Use Make to create a new shape. See example in package shapes documentation.
func ConcatenateDimensions ¶
ConcatenateDimensions of two shapes. The resulting rank is the sum of both ranks. They must have the same dtype. If any of them is a scalar, the resulting shape will be a copy of the other. It doesn't work for Tuples.
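The concatenation rule can be sketched as follows. The `shape` struct and `concatenateDimensions` function are local stand-ins for illustration; in particular, returning an error on a dtype mismatch is an assumption of this sketch, since the real function's signature and failure mode are not shown in this document.

```go
package main

import (
	"errors"
	"fmt"
)

// shape is a simplified stand-in for shapes.Shape (no tuple support).
type shape struct {
	dtype string
	dims  []int
}

// concatenateDimensions appends b's dimensions to a's. With a scalar operand
// (no dimensions) the result is naturally a copy of the other shape.
func concatenateDimensions(a, b shape) (shape, error) {
	if a.dtype != b.dtype {
		// Assumption: the sketch reports the mismatch as an error.
		return shape{}, errors.New("dtypes differ")
	}
	dims := make([]int, 0, len(a.dims)+len(b.dims))
	dims = append(dims, a.dims...)
	dims = append(dims, b.dims...)
	return shape{dtype: a.dtype, dims: dims}, nil
}

func main() {
	a := shape{"float32", []int{2, 3}}
	b := shape{"float32", []int{4}}
	c, _ := concatenateDimensions(a, b)
	fmt.Println(c.dims, len(c.dims)) // [2 3 4] 3
}
```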
func GobDeserialize ¶ added in v0.2.0
GobDeserialize a Shape. Returns new Shape or an error.
func (Shape) Assert ¶ added in v0.9.0
Assert checks that the shape has the given dtype, dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.
It panics if it doesn't match.
func (Shape) AssertDims ¶
AssertDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.
It panics if it doesn't match.
See usage example in package shapes documentation.
func (Shape) AssertRank ¶
AssertRank checks that the shape has the given rank.
It panics if it doesn't match.
See usage example in package shapes documentation.
func (Shape) AssertScalar ¶
func (s Shape) AssertScalar()
AssertScalar checks that the shape is a scalar.
It panics if it doesn't match.
See usage example in package shapes documentation.
func (Shape) Check ¶ added in v0.9.0
Check that the shape has the given dtype, dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.
It returns an error if the dtype or rank is different or if any of the dimensions don't match.
func (Shape) CheckDims ¶
CheckDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.
It returns an error if the rank is different or if any of the dimensions don't match.
func (Shape) CheckRank ¶
CheckRank checks that the shape has the given rank.
It returns an error if the rank is different.
func (Shape) CheckScalar ¶
CheckScalar checks that the shape is a scalar.
It returns an error if shape is not a scalar.
func (Shape) EqDimensions ¶ added in v0.9.0
EqDimensions compares two shapes for equality of dimensions. Dtypes can be different.
func (Shape) GobSerialize ¶ added in v0.2.0
GobSerialize shape in binary format.
func (Shape) IsScalar ¶
IsScalar returns whether the shape represents a scalar, that is there are no dimensions (rank==0).
func (Shape) Memory ¶ added in v0.2.0
Memory returns the number of bytes that would be used in Go to store the given data -- the actual memory may depend on the device implementation in some cases (e.g. bool).
func (Shape) Ok ¶
Ok returns whether this is a valid Shape. A "zero" shape, that is just instantiating it with Shape{} will be invalid.
func (Shape) Size ¶
Size returns the number of elements of DType that are needed for this shape. It's the product of all dimensions.
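The relationship between Size and Memory can be sketched with two small functions. Both `size` and `memory` below are hypothetical helpers; the sketch assumes that a scalar has size 1 (the empty product) and that the memory of a shape is its size multiplied by the byte size of one element.

```go
package main

import "fmt"

// size returns the product of all dimensions. With no dimensions (a scalar)
// the empty product gives 1, which this sketch assumes matches Shape.Size.
func size(dims []int) int {
	n := 1
	for _, d := range dims {
		n *= d
	}
	return n
}

// memory assumes the bytes needed are the element count times the
// per-element byte size (e.g. 4 for float32).
func memory(dims []int, elemBytes int64) int64 {
	return int64(size(dims)) * elemBytes
}

func main() {
	dims := []int{2, 3}
	fmt.Println(size(dims))      // 6
	fmt.Println(memory(dims, 4)) // 24
	fmt.Println(size(nil))       // 1
}
```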
type Supported ¶
type Supported interface { bool | float32 | float64 | int | int32 | int64 | uint8 | uint32 | uint64 | complex64 | complex128 }
Supported lists the Go types that are supported by the graph package. Used as a Generics constraint. See also Number.
Notice Go's `int` type is not portable, since it may translate to dtypes Int32 or Int64 depending on the platform.
Generated by `cmd/constraints_generator`.
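The portability remark about Go's `int` can be checked with the standard library: `strconv.IntSize` reports the width of `int` on the current platform. The helper `dtypeNameForInt` below is hypothetical and only returns the name of the dtype that `int` would correspond to under the platform-dependent mapping described above.

```go
package main

import (
	"fmt"
	"strconv"
)

// dtypeNameForInt reports which dtype name Go's int would map to on this
// platform, based on its bit width. Hypothetical helper for illustration.
func dtypeNameForInt() string {
	if strconv.IntSize == 32 {
		return "Int32"
	}
	return "Int64"
}

func main() {
	fmt.Println("int is", strconv.IntSize, "bits, maps to", dtypeNameForInt())
}
```

Using the fixed-width types `int32` or `int64` explicitly avoids the ambiguity.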