checkpoints

package
v0.9.1
Published: Apr 20, 2024 License: Apache-2.0 Imports: 21 Imported by: 0

Documentation

Overview

Package checkpoints implements checkpoint management: saving and loading of checkpoints.

The main object is the Handler. Create it by calling Build, setting the desired options, and finally calling Config.Done. Once created, if a previously saved checkpoint exists, it automatically loads the variables and parameters for your model into the Context. As the model trains, one can call Handler.Save() at any time to save a new checkpoint -- typically inside train.EveryNSteps().

Example: after creating the Context, the code checks whether a checkpoint directory was set (`*flagCheckpoint`) and, if so, creates a checkpoints.Handler that saves a checkpoint every 100 steps, keeping the last `*flagCheckpointKeep` checkpoints.

```

…
ctx := context.NewContext(manager)
ctx.SetParam(optimizers.ParamLearningRate, *flagLearningRate)

var checkpoint *checkpoints.Handler
if *flagCheckpoint != "" {
	var err error
	checkpoint, err = checkpoints.Build(ctx).Dir(*flagCheckpoint).Keep(*flagCheckpointKeep).Done()
	Must(err)  // Panics if err != nil.
}
…
// Build training loop.
loop := train.NewLoop(trainer)
commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.
if checkpoint != nil {
	const priority = 100  // Large number here, means it runs last.
	train.EveryNSteps(loop, 100, "checkpointing", priority, checkpoint.OnStepFn)
}
…

```

TODO:

  1. Compress checkpoints.
  2. Allow to specify parts of the model to load / scope where they should be loaded to, for transfer learning.

Constants

This section is empty.

Variables

var (
	// DirPermMode is the default directory creation permission (before umask) used.
	DirPermMode = os.FileMode(0770)
)

Functions

This section is empty.

Types

type Config

type Config struct {
	// contains filtered or unexported fields
}

Config for the checkpoints Handler to be created. This is created with Build() and configured with the various methods. Once finished, call Done() and it will output a checkpoints.Handler that loads (if there are any previously saved checkpoints) and saves checkpoints.

func Build

func Build(ctx *context.Context) *Config

Build a configuration for building a checkpoints.Handler. After configuring the Config object returned, call `Done` to get the configured checkpoints.Handler.

func (*Config) Dir

func (c *Config) Dir(dir string) *Config

Dir sets the directory where to save / load the checkpoints.

One of Dir, DirFromBase or TempDir must be set before building the checkpoints.Handler.

func (*Config) DirFromBase added in v0.5.0

func (c *Config) DirFromBase(dir, baseDir string) *Config

DirFromBase sets the directory where to save / load the checkpoints. If `dir` is not an absolute path, assumes it is a subdirectory of baseDir.

One of Dir, DirFromBase or TempDir must be set before building the checkpoints.Handler.

func (*Config) Done

func (c *Config) Done() (*Handler, error)

Done creates a Handler with the current configuration. It returns an error if the configuration is invalid, or if it's missing information.

func (*Config) ExcludeParams

func (c *Config) ExcludeParams() *Config

ExcludeParams configures the Handler to exclude the Context parameters (values usually read/written by Context.GetParam and Context.SetParam).

By default, Params are loaded and set into Context the moment Handler is created (when Done() is called), overriding values already present in the Context.

func (*Config) ExcludeVarsFromSaving added in v0.9.0

func (c *Config) ExcludeVarsFromSaving(vars ...*context.Variable) *Config

ExcludeVarsFromSaving enumerates variables to be excluded from saving. The function can be called multiple times; each call adds to the set of excluded variables.

func (*Config) Immediate added in v0.9.0

func (c *Config) Immediate() *Config

Immediate forces all variables to be loaded immediately, as opposed to loading them dynamically from the checkpoint as they are used when building the model.

func (*Config) Keep

func (c *Config) Keep(n int) *Config

Keep configures the number of checkpoint files to keep. If set to -1, it will never erase older checkpoints. The default is 1.
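
The retention rule can be pictured with a small sketch. It assumes the checkpoints are listed oldest first and that any negative value disables erasing (the documentation only mentions -1); `toDelete` is a hypothetical helper, not the package's implementation.

```go
package main

import "fmt"

// toDelete is a hypothetical helper: given checkpoint names ordered oldest
// first and the number to keep, it returns the names that would be erased.
// A negative keep means nothing is ever erased (assumption: -1 in the docs).
func toDelete(checkpoints []string, keep int) []string {
	if keep < 0 || len(checkpoints) <= keep {
		return nil
	}
	return checkpoints[:len(checkpoints)-keep]
}

func main() {
	saved := []string{"ckpt-a", "ckpt-b", "ckpt-c"} // hypothetical names, oldest first
	fmt.Println(toDelete(saved, 1))  // the two oldest are candidates for removal
	fmt.Println(toDelete(saved, -1)) // nothing is removed
}
```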

func (*Config) MustDone

func (c *Config) MustDone() *Handler

MustDone constructs the checkpoints.Handler. It panics if there was an error.

func (*Config) TakeMean added in v0.4.1

func (c *Config) TakeMean(n int) *Config

TakeMean loads the mean of the last `n` checkpoints. If `n <= 0`, it takes the mean of all available checkpoints. Note that only trainable variables are averaged. Variables that have integer values or are not marked as trainable (e.g. the global step) are taken from the most recent checkpoint instead.

The default is 1, so only load the most recent checkpoint.

Notice the mean is taken one tensor at a time, so at any time there is only one copy of the model weights in memory, plus the tensor being merged.
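
One way to average with only a single accumulated copy in memory is a running mean, sketched below on plain `[]float64` slices standing in for tensors. Whether the package uses this exact update or a sum followed by a division is not stated in the documentation; `mergeMean` and the sample values are hypothetical.

```go
package main

import "fmt"

// mergeMean folds next into mean in place. Before the call, mean holds the
// average of count-1 tensors; after it, the average of count tensors.
func mergeMean(mean, next []float64, count int) {
	for i := range mean {
		mean[i] += (next[i] - mean[i]) / float64(count)
	}
}

func main() {
	// Three hypothetical checkpoints of the same trainable variable.
	checkpoints := [][]float64{{1, 2}, {3, 4}, {5, 6}}

	// Start from a copy of the first checkpoint, then merge one at a time:
	// only mean and the tensor being merged are needed at any moment.
	mean := append([]float64(nil), checkpoints[0]...)
	for k := 1; k < len(checkpoints); k++ {
		mergeMean(mean, checkpoints[k], k+1)
	}
	fmt.Println(mean) // prints "[3 4]"
}
```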

func (*Config) TempDir

func (c *Config) TempDir(dir, pattern string) *Config

TempDir creates a temporary directory under dir, with the pattern name, and uses this directory to load / save checkpoints. It's a convenience wrapper around os.MkdirTemp.

If dir is the empty string, MkdirTemp uses the default directory for temporary files, as returned by os.TempDir.

The new directory's name is generated by adding a random string to the end of pattern. If `pattern` includes a "*", the random string replaces the last "*" instead (see os.MkdirTemp).

Any errors are reported when Done is called.

One of Dir, DirFromBase or TempDir must be set before building the checkpoints.Handler.

type Handler

type Handler struct {
	// contains filtered or unexported fields
}

Handler handles saving and loading of checkpoints for a context.Context. See example in package documentation.

It is created and configured by calling Build(), setting options, and then calling Config.Done().

Loading data into the Handler happens at its creation time: it loads from the latest checkpoint. (Hyper-)parameters are loaded into the context immediately (unless Config.ExcludeParams was set), but the loaded variable values are only "consumed" (used) one at a time, as the variables are created during graph building (e.g. when building the model).

Saving of checkpoints is explicit, by calling Handler.Save(). Usually this is done by configuring train.Loop to call it using train.EveryNSteps or train.NTimesDuringLoop. When saving, all variables in the Context are saved, along with any variables previously loaded by the Handler that were not used by the Context, and with the `Params` for all scopes (including changed values).

There can be more than one Handler attached to a Context -- they are used for loading in the order they were created (so the first one created takes priority). A setup with multiple Handlers can be used, for instance, for transfer learning, where parts of the model are loaded from somewhere else.

A Handler can only be "attached" to one context.Context. If one wants to load the same checkpoint into two different contexts, another Handler object needs to be created. This is because once a variable is loaded, it is transferred to the Context, and the Handler does not keep it.
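
The priority and ownership rules above can be pictured with a simplified model. This is not the package's code: `loader` is a hypothetical stand-in for a Handler, values are plain `float64` instead of tensors, and the variable keys are made up for illustration.

```go
package main

import "fmt"

// loader is a simplified, hypothetical stand-in for a checkpoint Handler.
type loader struct {
	name string
	vars map[string]float64
}

// load returns the value for key and drops it from the loader, modelling the
// transfer of ownership to the context once a variable has been consumed.
func (l *loader) load(key string) (float64, bool) {
	v, ok := l.vars[key]
	if ok {
		delete(l.vars, key)
	}
	return v, ok
}

// loadFirst queries the loaders in creation order; the first hit wins.
func loadFirst(loaders []*loader, key string) (value float64, from string, found bool) {
	for _, l := range loaders {
		if v, ok := l.load(key); ok {
			return v, l.name, true
		}
	}
	return 0, "", false
}

func main() {
	pretrained := &loader{name: "pretrained", vars: map[string]float64{"/encoder/w": 1}}
	own := &loader{name: "own", vars: map[string]float64{"/encoder/w": 2, "/head/w": 3}}
	loaders := []*loader{pretrained, own} // creation order defines priority

	fmt.Println(loadFirst(loaders, "/encoder/w")) // served by "pretrained"
	fmt.Println(loadFirst(loaders, "/head/w"))    // only "own" has it
	fmt.Println(len(pretrained.vars))             // the consumed value is gone
}
```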

func (*Handler) Dir

func (h *Handler) Dir() string

Dir returns the directory the Handler is configured to. It cannot be changed once the Handler was created.

It returns "" (empty) if the Handler is `nil`.

func (*Handler) HasCheckpoints added in v0.4.0

func (h *Handler) HasCheckpoints() (bool, error)

HasCheckpoints returns whether there are any checkpoints saved.

func (*Handler) ListCheckpoints

func (h *Handler) ListCheckpoints() (checkpoints []string, err error)

ListCheckpoints returns the base file name of the checkpoints in the directory in time order (older first).

func (*Handler) LoadVariable

func (h *Handler) LoadVariable(ctx *context.Context, scope, name string) (value tensor.Tensor, found bool)

LoadVariable implements context.Loader. This is called by context.Context when the variable is used for the first time. The user may want to use this function to inspect loaded values for testing.

func (*Handler) LoadedVariables added in v0.4.1

func (h *Handler) LoadedVariables() map[string]tensor.Tensor

LoadedVariables for inspection. These are the values loaded -- but not necessarily immediately available in context, since they are actually used only when a model asks for the variable.

The Handler owns the returned map, don't change it -- the behavior is undefined if you do.

func (*Handler) OnStepFn added in v0.4.0

func (h *Handler) OnStepFn(_ *train.Loop, _ []tensor.Tensor) error

OnStepFn implements `train.OnStepFn`, which makes it convenient to attach to a training loop. It simply calls Save.

func (*Handler) Save

func (h *Handler) Save() error

Save creates a new checkpoint and saves the context variables and (optionally) Params.

All variables in the context are saved, as well as those previously loaded -- this allows one to load the variables only for a part of the model, update that part and save again with everything.

Params are (de)serialized with the encoding/json package.
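
One practical consequence of JSON serialization is worth keeping in mind: when decoding into untyped values, encoding/json represents every number as `float64`. The sketch below shows this with the standard library alone. The parameter names are hypothetical, and whether the Handler converts values back to their original types after loading is not stated here.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// roundTrip encodes params to JSON and decodes them back into a generic map,
// mimicking a save followed by a load.
func roundTrip(params map[string]any) (map[string]any, error) {
	data, err := json.Marshal(params)
	if err != nil {
		return nil, err
	}
	var out map[string]any
	if err := json.Unmarshal(data, &out); err != nil {
		return nil, err
	}
	return out, nil
}

func main() {
	out, err := roundTrip(map[string]any{"learning_rate": 0.001, "batch_size": 64})
	if err != nil {
		panic(err)
	}
	// The int 64 comes back as a float64 after the round trip.
	fmt.Printf("%T %T\n", out["learning_rate"], out["batch_size"]) // prints "float64 float64"
}
```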

func (*Handler) String

func (h *Handler) String() string

String implements Stringer.
