gengraph

package module
v0.0.1 (Latest)
Warning

This package is not in the latest version of its module.

Published: Sep 10, 2024 License: MIT Imports: 5 Imported by: 0

README

gengraph: High-Performance Code Generation for Computational Graphs in Go

gengraph is a small project for building simple computational graphs in Go. Unlike other graph computation frameworks, gengraph adds a code generation step that converts your graph into plain Go code. This minimizes runtime overhead for forward and gradient propagation, which makes it a good fit for performance-critical applications such as robotics.

I built gengraph as a further learning exercise in computational graphs, following on from my other graph computation project, github.com/JoshPattman/toygraph. I wanted to see how much code generation could improve the performance of graph computation, and I'm happy with the results so far.

Note: gengraph does not, and probably never will, support GPU processing. However, I plan to add support for matrices and vectors using gonum in the future.

Usage

To generate graphs with gengraph, you define them in Go files in the same project where you intend to use them. Build tags keep the graph-generation build separate from the normal build of your program.

Steps:
  1. Set up build tags:

    • Add //go:build !graph to the top of your main Go file (the one containing func main()). This excludes it during graph generation.

      Example:

      //go:build !graph
      
      package main
      
      func main() {
          // Your main program logic
      }
      
    • Because your main file is excluded when the graph tag is set, functions defined in it cannot be used for graph generation. Move any such functions to another file without build constraints.

  2. Create a graph generation file:

    • Create a file called main_generate.go with the following content:

      //go:build graph
      
      package main
      
      func main() {
          CreateGraphs()
      }
      
    • This file only calls the CreateGraphs function, which you define yourself (see step 3).

  3. Define the CreateGraphs function:

    • You can put the CreateGraphs() function in main_generate.go. However, editors such as VS Code do not fully check files excluded by the default build constraints, so for better syntax checking it's recommended to place CreateGraphs() in another file that has no build constraints.

  4. Run the project:

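    • Use the following command to generate the graphs and then run your project. The first go run sets the graph tag, so it builds main_generate.go and runs CreateGraphs; the second builds and runs your program normally: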
      $ go run -tags graph . && go run .
      
Example

For a more detailed example, refer to the example subdirectory in this repository. For a quick overview, here is what the body of the CreateGraphs function might contain.

This code creates a new graph, named CosGraph, that computes cos(x + 3). It then saves the generated code to graph_cosgraph.go in the local Go project:

// Create a new graph to build on, called CosGraph
g := G.NewGraph("CosGraph")

// Define an input variable, making sure to specify a capital letter for its first character so the field will be exported in the struct
a := G.Variable[float64](g, "Input")

// Define a constant, b, with a value of 3.0. Its name (like that of any other unnamed value) is generated automatically.
b := G.Constant(g, 3.0)

// Add the input and the constant together
added := G.NumAdd(a, b)

// Calculate the cosine of the sum
res := G.NumCos(added)

// Alias the result as "Result" so it gets a named, exported field we can easily access on the generated struct
G.Alias(res, "Result")

// Write the graph to the default file, graph_cosgraph.go (ToDefaultFile returns an error, so check it)
if err := g.ToDefaultFile(); err != nil {
	panic(err)
}
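For reference, the chain rule gives the value and derivative that the forward and backward passes of this graph should reproduce, evaluated at the input 5.0 used in the next example (values rounded to four decimal places):

```latex
f(x) = \cos(x + 3), \qquad f'(x) = -\sin(x + 3) \cdot \frac{d}{dx}(x + 3) = -\sin(x + 3)
f(5) = \cos 8 \approx -0.1455, \qquad f'(5) = -\sin 8 \approx -0.9894
```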

The following example code uses the generated graph from the main function:

// Create a new instance of our cos graph
// (the NewCosGraph function is generated and stored in graph_cosgraph.go)
g := NewCosGraph()

// Set the input value to 5.0
// If your graph has multiple inputs, you can set them all by setting their respective struct variables
g.Input = 5.0

// Run the forward pass (calculate cos(5.0 + 3.0))
g.Forward()
fmt.Println("Result=", g.Result)

// Clear the gradients
g.ClearGrads()

// The backward pass propagates whatever gradient is stored on each output back to the inputs; it does not compute raw partial derivatives on its own.
// So we seed the output gradient with 1.0, which makes InputGrad hold d(Result)/d(Input) after the backward pass.
// Side note: to get the partials of one particular output, set all gradients to 0 except that output's (an alias gives it a named field), then run the backward pass.
g.ResultGrad = 1.0

// Run the backward pass, storing gradients at each step
g.Backward()

// Every variable in gengraph has a corresponding gradient variable, suffixed with "Grad"
fmt.Println("InputGrad=", g.InputGrad)

By following these steps, you can integrate graph generation into your Go project with minimal runtime overhead. If you're working on a performance-critical application such as robotics, gengraph may give you the efficiency boost you need.

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type AliasNode

type AliasNode[T any] struct {
	From *Buffer[T]
	To   *Buffer[T]
}

func Alias

func Alias[T any](from BufferGetter[T], name string) *AliasNode[T]

func (*AliasNode[T]) BackLines

func (v *AliasNode[T]) BackLines() []string

func (*AliasNode[T]) BufferDefs

func (n *AliasNode[T]) BufferDefs() []string

func (*AliasNode[T]) BufferInits

func (n *AliasNode[T]) BufferInits() []string

func (*AliasNode[T]) FwdLines

func (n *AliasNode[T]) FwdLines() []string

func (*AliasNode[T]) GradBufferClears

func (v *AliasNode[T]) GradBufferClears() []string

func (*AliasNode[T]) Imports

func (n *AliasNode[T]) Imports() []string
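AliasNode holds a From and a To buffer. A plausible reading, which is an assumption rather than documented behaviour, is that the generated forward pass copies From into the named To field and the backward pass passes To's gradient straight back to From, since the derivative of a copy is 1. The sketch below illustrates that idea with hypothetical field names.

```go
package main

import "fmt"

// aliased is a hypothetical picture of generated code for Alias(node, "Result").
type aliased struct {
	tmp, tmpGrad       float64 // the unnamed buffer being aliased (From)
	Result, ResultGrad float64 // the exported alias (To)
}

// forward copies the aliased value into the named field.
func (a *aliased) forward() { a.Result = a.tmp }

// backward passes the alias's gradient through unchanged, because d(Result)/d(tmp) = 1.
func (a *aliased) backward() { a.tmpGrad += a.ResultGrad }

func main() {
	a := &aliased{tmp: 0.25}
	a.forward()
	a.ResultGrad = 1.0
	a.backward()
	fmt.Println(a.Result, a.tmpGrad) // prints 0.25 1
}
```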

type Buffer

type Buffer[T any] struct {
	Name    string
	OnGraph *Graph
}

func (*Buffer[T]) Buf

func (v *Buffer[T]) Buf() *Buffer[T]

func (*Buffer[T]) BufferDef

func (v *Buffer[T]) BufferDef() string

func (*Buffer[T]) GradBufferDef

func (v *Buffer[T]) GradBufferDef() string

func (*Buffer[T]) GradUseString

func (v *Buffer[T]) GradUseString() string

func (*Buffer[T]) UseString

func (v *Buffer[T]) UseString() string

type BufferGetter

type BufferGetter[T any] interface {
	Buf() *Buffer[T]
}
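BufferGetter is what lets constructors such as NumAdd and Alias accept any node that exposes an output buffer, while the shared type parameter T keeps both operands the same element type. The following self-contained sketch re-creates the two documented shapes in simplified form (the real Buffer also has an OnGraph field) and uses a hypothetical describeAdd function in place of a real constructor.

```go
package main

import "fmt"

// Simplified copies of the documented types, for illustration only.
type Buffer[T any] struct{ Name string }

type BufferGetter[T any] interface{ Buf() *Buffer[T] }

// Two unrelated node kinds that both expose their output buffer.
type value[T any] struct{ out *Buffer[T] }
type sum[T any] struct{ out *Buffer[T] }

func (v *value[T]) Buf() *Buffer[T] { return v.out }
func (s *sum[T]) Buf() *Buffer[T]   { return s.out }

// describeAdd is a hypothetical stand-in for a constructor like NumAdd: it
// accepts any BufferGetter[T], and T must be identical for both operands.
func describeAdd[T any](left, right BufferGetter[T]) string {
	return left.Buf().Name + " + " + right.Buf().Name
}

func main() {
	x := &value[float64]{out: &Buffer[float64]{Name: "Input"}}
	s := &sum[float64]{out: &Buffer[float64]{Name: "v0"}}
	fmt.Println(describeAdd[float64](x, s)) // prints Input + v0
}
```

Passing a value[float32] together with a sum[float64] would fail to compile, which is how this design catches element-type mismatches while the graph is being built rather than in the generated code.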

type Graph

type Graph struct {
	Nodes []Node
	Name  string
}

func NewGraph

func NewGraph(name string) *Graph

func (*Graph) Add

func (g *Graph) Add(n Node)

func (*Graph) String

func (g *Graph) String() string

func (*Graph) ToDefaultFile

func (g *Graph) ToDefaultFile() error

func (*Graph) ToFile

func (g *Graph) ToFile(filename string) error

type Node

type Node interface {
	FwdLines() []string
	BackLines() []string
	BufferDefs() []string
	BufferInits() []string
	GradBufferClears() []string
	Imports() []string
}
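Each Node method returns lines of Go source, which suggests the generator concatenates those lines into the generated struct and its methods. The sketch below illustrates that idea for two of the six methods: a hypothetical addConst node emits its forward and backward statements, and a hypothetical method function joins them, visiting nodes in reverse order for the backward pass so gradients flow from outputs to inputs. This is an illustration of the technique, not gengraph's own generator.

```go
package main

import (
	"fmt"
	"strings"
)

// node mirrors two of the documented Node methods. BufferDefs, BufferInits,
// GradBufferClears and Imports would be joined the same way into the struct
// definition, constructor, ClearGrads method and import list.
type node interface {
	FwdLines() []string
	BackLines() []string
}

// addConst is a hypothetical node computing out = in + c.
type addConst struct {
	in, out string
	c       float64
}

func (n addConst) FwdLines() []string {
	return []string{fmt.Sprintf("g.%s = g.%s + %v", n.out, n.in, n.c)}
}

func (n addConst) BackLines() []string {
	return []string{fmt.Sprintf("g.%sGrad += g.%sGrad", n.in, n.out)}
}

// method emits one generated method, walking nodes forwards for the forward
// pass and backwards for the backward pass.
func method(recv, name string, nodes []node, backward bool) string {
	var b strings.Builder
	fmt.Fprintf(&b, "func (g *%s) %s() {\n", recv, name)
	for i := range nodes {
		var lines []string
		if backward {
			lines = nodes[len(nodes)-1-i].BackLines()
		} else {
			lines = nodes[i].FwdLines()
		}
		for _, l := range lines {
			b.WriteString("\t" + l + "\n")
		}
	}
	b.WriteString("}\n")
	return b.String()
}

func main() {
	nodes := []node{
		addConst{in: "Input", out: "v0", c: 3},
		addConst{in: "v0", out: "Result", c: 1},
	}
	fmt.Print(method("Demo", "Forward", nodes, false))
	fmt.Print(method("Demo", "Backward", nodes, true))
}
```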

type NumBinaryNode

type NumBinaryNode[T Numerical] struct {
	// contains filtered or unexported fields
}

func NumAdd

func NumAdd[T Numerical](left BufferGetter[T], right BufferGetter[T]) *NumBinaryNode[T]

func NumDiv

func NumDiv[T Numerical](left BufferGetter[T], right BufferGetter[T]) *NumBinaryNode[T]

func NumMul

func NumMul[T Numerical](left BufferGetter[T], right BufferGetter[T]) *NumBinaryNode[T]

func NumSub

func NumSub[T Numerical](left BufferGetter[T], right BufferGetter[T]) *NumBinaryNode[T]

func (*NumBinaryNode[T]) BackLines

func (v *NumBinaryNode[T]) BackLines() []string

func (*NumBinaryNode[T]) Buf

func (n *NumBinaryNode[T]) Buf() *Buffer[T]

func (*NumBinaryNode[T]) BufferDefs

func (n *NumBinaryNode[T]) BufferDefs() []string

func (*NumBinaryNode[T]) BufferInits

func (n *NumBinaryNode[T]) BufferInits() []string

func (*NumBinaryNode[T]) FwdLines

func (n *NumBinaryNode[T]) FwdLines() []string

func (*NumBinaryNode[T]) GradBufferClears

func (v *NumBinaryNode[T]) GradBufferClears() []string

func (*NumBinaryNode[T]) Imports

func (n *NumBinaryNode[T]) Imports() []string

type NumUnaryNode

type NumUnaryNode[T Numerical] struct {
	// contains filtered or unexported fields
}

func NumCos

func NumCos[T Numerical](in BufferGetter[T]) *NumUnaryNode[T]

func NumSin

func NumSin[T Numerical](in BufferGetter[T]) *NumUnaryNode[T]

func (*NumUnaryNode[T]) BackLines

func (v *NumUnaryNode[T]) BackLines() []string

func (*NumUnaryNode[T]) Buf

func (n *NumUnaryNode[T]) Buf() *Buffer[T]

func (*NumUnaryNode[T]) BufferDefs

func (n *NumUnaryNode[T]) BufferDefs() []string

func (*NumUnaryNode[T]) BufferInits

func (n *NumUnaryNode[T]) BufferInits() []string

func (*NumUnaryNode[T]) FwdLines

func (n *NumUnaryNode[T]) FwdLines() []string

func (*NumUnaryNode[T]) GradBufferClears

func (v *NumUnaryNode[T]) GradBufferClears() []string

func (*NumUnaryNode[T]) Imports

func (n *NumUnaryNode[T]) Imports() []string

type Numerical

type Numerical interface {
	float32 | float64 | int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64
}
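Numerical is an ordinary Go type-set constraint, so generic code constrained by it can use arithmetic operators directly. Functions from the math package, however, only take float64, so generic code has to convert through float64 and back. The sketch below uses a copy of the constraint with two hypothetical functions; how gengraph's generated code handles integer types for NumCos and NumSin is not documented here.

```go
package main

import (
	"fmt"
	"math"
)

// Numerical copies the documented constraint so this sketch compiles on its own.
type Numerical interface {
	float32 | float64 | int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64
}

// add works for every type in the set because + is defined on all of them.
func add[T Numerical](a, b T) T { return a + b }

// cosOf converts through float64 because math.Cos only accepts float64.
// For integer T the fractional part of the result is discarded.
func cosOf[T Numerical](x T) T { return T(math.Cos(float64(x))) }

func main() {
	fmt.Println(add(2, 3), add(1.5, 2.25)) // prints 5 3.75
	fmt.Println(cosOf(0.0), cosOf(int32(3)))
}
```

The integer case shows why the element type matters: cos(3) is about -0.99, but converting it to int32 truncates toward zero and yields 0.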

type Value

type Value[T any] struct {
	Var *Buffer[T]
	// contains filtered or unexported fields
}

func Constant

func Constant[T any](g *Graph, val T) *Value[T]

func Variable

func Variable[T any](g *Graph, name string) *Value[T]

func (*Value[T]) BackLines

func (v *Value[T]) BackLines() []string

func (*Value[T]) Buf

func (v *Value[T]) Buf() *Buffer[T]

func (*Value[T]) BufferDefs

func (v *Value[T]) BufferDefs() []string

func (*Value[T]) BufferInits

func (v *Value[T]) BufferInits() []string

func (*Value[T]) FwdLines

func (v *Value[T]) FwdLines() []string

func (*Value[T]) GradBufferClears

func (v *Value[T]) GradBufferClears() []string

func (*Value[T]) Imports

func (v *Value[T]) Imports() []string
