db

package
v0.38.0-rc7 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Nov 4, 2024 License: Apache-2.0 Imports: 59 Imported by: 0

Documentation

Index

Constants

View Source
// Postgres SQLSTATE error codes, as listed in the "PostgreSQL Error Codes"
// appendix: https://www.postgresql.org/docs/10/errcodes-appendix.html
const (
	// CodeForeignKeyViolation is the SQLSTATE Postgres reports when an insert or
	// update would break a foreign key constraint (class 23, integrity
	// constraint violation).
	CodeForeignKeyViolation = "23503"
	// CodeUniqueViolation is the SQLSTATE Postgres reports when an insert or
	// update would break a uniqueness constraint (class 23, integrity
	// constraint violation).
	CodeUniqueViolation = "23505"
	// CodeSerializationFailure is the SQLSTATE Postgres reports when a
	// transaction fails because of a serialization failure (class 40,
	// transaction rollback).
	CodeSerializationFailure = "40001"
)
View Source
// JSONB string encodings for non-finite floats, plus the type names used when
// summarizing metrics.
const (
	// JSON has no literal for infinities or NaN, so non-finite float values are
	// stored in JSONB as the strings below.

	// InfPostgresString is how we store +infinity in JSONB in Postgres.
	InfPostgresString = "Infinity"
	// NegInfPostgresString is how we store -infinity in JSONB in Postgres.
	NegInfPostgresString = "-Infinity"
	// NaNPostgresString is how we store NaN in JSONB in Postgres.
	NaNPostgresString = "NaN"

	// MetricTypeString is the summary metric type for string or mixed types.
	MetricTypeString = "string"
	// MetricTypeNumber is the summary metric type for floats or ints.
	MetricTypeNumber = "number"
	// MetricTypeBool is the summary metric type for booleans.
	MetricTypeBool = "boolean"
	// MetricTypeDate is the summary metric type for date metrics.
	MetricTypeDate = "date"
	// MetricTypeObject is the summary metric type for object types.
	MetricTypeObject = "object"
	// MetricTypeArray is the summary metric type for array types.
	MetricTypeArray = "array"
	// MetricTypeNull is the summary metric type for null values.
	MetricTypeNull = "null"
)
View Source
// ClusterMessageMaxLength is the upper bound on the length of a cluster-wide
// message. It is an untyped constant so it can be compared against any integer
// type without conversion.
const (
	ClusterMessageMaxLength = 250
)

ClusterMessageMaxLength caps the length of a cluster-wide message.

Variables

View Source
var (
	// ErrNotFound is returned if nothing is found.
	ErrNotFound = errors.New("not found")

	// ErrTooManyRowsAffected is returned if too many rows are affected.
	ErrTooManyRowsAffected = errors.New("too many rows are affected")

	// ErrDuplicateRecord is returned when trying to create a row that already exists.
	ErrDuplicateRecord = errors.New("row already exists")

	// ErrInvalidInput is returned when the data passed to a function is invalid for semantic or
	// syntactic reasons.
	ErrInvalidInput = errors.New("invalid input")

	// ErrDeleteDefaultBinding is returned when trying to delete a workspace bound to its default
	// namespace.
	ErrDeleteDefaultBinding = errors.New("cannot delete the default namespace binding")
)

Functions

func ActiveLogPolicies

func ActiveLogPolicies(
	ctx context.Context, id int,
) (expconf.LogPoliciesConfig, error)

ActiveLogPolicies returns log pattern policies for an experiment ID. This should only be called on a running experiment.

func AddAllocation

func AddAllocation(ctx context.Context, a *model.Allocation) error

AddAllocation upserts the existence of an allocation. Allocation IDs may conflict in the event the master restarts and the trial run ID increment is not persisted, but it is the same allocation so this is OK.

func AddAllocationExitStatus

func AddAllocationExitStatus(ctx context.Context, a *model.Allocation) error

AddAllocationExitStatus adds the allocation exit status to the allocations table.

func AddCheckpointMetadata

func AddCheckpointMetadata(ctx context.Context, m *model.CheckpointV2, runID int) error

AddCheckpointMetadata persists metadata for a completed checkpoint to the database.

func AddExperiment

func AddExperiment(
	ctx context.Context,
	experiment *model.Experiment,
	modelDef []byte,
	activeConfig expconf.ExperimentConfig,
) (err error)

AddExperiment adds the experiment to the database and sets its ID.

func AddExperimentTx

func AddExperimentTx(
	ctx context.Context, idb bun.IDB,
	experiment *model.Experiment,
	modelDef []byte,
	activeConfig expconf.ExperimentConfig,
	upsert bool,
) (err error)

AddExperimentTx adds the experiment to the database and sets its ID.

func AddJob

func AddJob(j *model.Job) error

AddJob persists the existence of a job.

func AddJobTx

func AddJobTx(ctx context.Context, idb bun.IDB, j *model.Job) error

AddJobTx persists the existence of a job with a transaction.

func AddNonExperimentTasksContextDirectory

func AddNonExperimentTasksContextDirectory(ctx context.Context, tID model.TaskID, bytes []byte) error

AddNonExperimentTasksContextDirectory adds a context directory for a non experiment task.

func AddProjectHparams

func AddProjectHparams(ctx context.Context, tx bun.Tx, projectID int, runIDs []int32) error

AddProjectHparams adds project hyperparams from provided runs to provided project.

func AddRPWorkspaceBindings

func AddRPWorkspaceBindings(ctx context.Context, workspaceIds []int32, poolName string,
	resourcePools []config.ResourcePoolConfig,
) error

AddRPWorkspaceBindings inserts new bindings between workspaceIds and poolName.

func AddTask

func AddTask(ctx context.Context, t *model.Task) error

AddTask UPSERTs the existence of a task.

func AddTaskTx

func AddTaskTx(ctx context.Context, idb bun.IDB, t *model.Task) error

AddTaskTx UPSERTs the existence of a task in a transaction.

func AddTrial

func AddTrial(ctx context.Context, trial *model.Trial, taskID model.TaskID) error

AddTrial adds the trial to the database and sets its ID.

func AllocationByID

func AllocationByID(ctx context.Context, aID model.AllocationID) (*model.Allocation, error)

AllocationByID retrieves an allocation by its ID.

func ApplyDoubleFieldFilter

func ApplyDoubleFieldFilter[T string | schema.Ident](
	q *bun.SelectQuery,
	column T,
	filter *commonv1.DoubleFieldFilter,
) (*bun.SelectQuery, error)

ApplyDoubleFieldFilter applies filtering on a bun query for double field.

func ApplyInt32FieldFilter

func ApplyInt32FieldFilter[T string | schema.Ident](
	q *bun.SelectQuery,
	column T,
	filter *commonv1.Int32FieldFilter,
) (*bun.SelectQuery, error)

ApplyInt32FieldFilter applies filtering on a bun query for int32 field.

func ApplyPolymorphicFilter

func ApplyPolymorphicFilter(
	q *bun.SelectQuery,
	column string,
	filter *commonv1.PolymorphicFilter,
) (*bun.SelectQuery, error)

ApplyPolymorphicFilter applies filtering on a bun query for a polymorphic filter.

func ApplyTimestampFieldFilter

func ApplyTimestampFieldFilter[T string | schema.Ident](
	q *bun.SelectQuery,
	column T,
	filter *commonv1.TimestampFieldFilter,
) (*bun.SelectQuery, error)

ApplyTimestampFieldFilter applies filtering on a bun query for timestamp field.

func BuildRunHParams

func BuildRunHParams(runID int, projectID int, hparams map[string]any,
	parentName string,
) ([]model.RunHparam, []model.ProjectHparam, error)

BuildRunHParams builds hyperparameters objects to add into the `run_hparams` & `project_hparams` table.

func Bun

func Bun() *bun.DB

Bun returns the singleton database connection through the bun library. bun is the database library we have decided to use for new code in the future due to its superior composability over bare SQL, and its superior flexibility over e.g. gorm. New code should not use the old bare SQL tooling.

func BunSelectMetricGroupNames

func BunSelectMetricGroupNames() *bun.SelectQuery

BunSelectMetricGroupNames sets up a bun select query for getting all the metric group and names.

func BunSelectMetricsQuery

func BunSelectMetricsQuery(mGroup model.MetricGroup, inclArchived bool) *bun.SelectQuery

BunSelectMetricsQuery sets up a bun select query based on the new metrics table, simplifying some of the complexity we set up for pg10 support.

func CheckIfRPUnbound

func CheckIfRPUnbound(poolName string) error

CheckIfRPUnbound checks that the specified resource pool is not bound to any workspace and returns an error if it is.

func ClearClusterMessage

func ClearClusterMessage(ctx context.Context, db *bun.DB) error

ClearClusterMessage clears the active cluster message.

func CloseOpenAllocations

func CloseOpenAllocations(ctx context.Context, exclude []model.AllocationID) error

CloseOpenAllocations finds all allocations that were open when the master crashed and adds an end time.

func CompleteAllocation

func CompleteAllocation(ctx context.Context, a *model.Allocation) error

CompleteAllocation persists the end of an allocation lifetime.

func CompleteAllocationTelemetry

func CompleteAllocationTelemetry(ctx context.Context, aID model.AllocationID) ([]byte, error)

CompleteAllocationTelemetry returns the analytics of an allocation for the telemetry.

func CompleteGenericTask

func CompleteGenericTask(tID model.TaskID, endTime time.Time) error

CompleteGenericTask persists the completion of a task of type GENERIC.

func CompleteTask

func CompleteTask(ctx context.Context, tID model.TaskID, endTime time.Time) error

CompleteTask persists the completion of a task.

func DeleteAllocationSession

func DeleteAllocationSession(ctx context.Context, allocationID model.AllocationID) error

DeleteAllocationSession deletes the task session with the given AllocationID.

func DeleteDispatch

func DeleteDispatch(
	ctx context.Context,
	id string,
) (int64, error)

DeleteDispatch deletes the specified dispatch and returns the number deleted.

func DeleteDispatches

func DeleteDispatches(
	ctx context.Context,
	opts func(*bun.DeleteQuery) *bun.DeleteQuery,
) (int64, error)

DeleteDispatches deletes all dispatches for the specified query and returns the number deleted.

func DeleteNotebookSessionByTask

func DeleteNotebookSessionByTask(
	ctx context.Context,
	taskID model.TaskID,
) error

DeleteNotebookSessionByTask deletes the notebook session associated with the task.

func DoPermissionsExist

func DoPermissionsExist(ctx context.Context, curUserID model.UserID,
	permissionIDs ...rbacv1.PermissionType,
) error

DoPermissionsExist checks for the existence of a permission in any workspace.

func DoesPermissionMatch

func DoesPermissionMatch(ctx context.Context, curUserID model.UserID, workspaceID *int32,
	permissionID rbacv1.PermissionType,
) error

DoesPermissionMatch checks for the existence of a permission in a workspace.

func DoesPermissionMatchAll

func DoesPermissionMatchAll(ctx context.Context, curUserID model.UserID,
	permissionID rbacv1.PermissionType, workspaceIds ...int32,
) error

DoesPermissionMatchAll checks for the existence of a permission in all specified workspaces.

func EndAgentStats

func EndAgentStats(a *model.AgentStats) error

EndAgentStats updates the end time of an instance.

func EndAllTaskStats

func EndAllTaskStats(ctx context.Context) error

EndAllTaskStats is called when the master starts, in case the master previously crashed.

func ExperimentBestSearcherValidation

func ExperimentBestSearcherValidation(ctx context.Context, id int) (float32, error)

ExperimentBestSearcherValidation returns the best searcher validation for an experiment.

func ExperimentByExternalIDTx

func ExperimentByExternalIDTx(ctx context.Context, idb bun.IDB, externalExperimentID string) (
	*model.Experiment, error,
)

ExperimentByExternalIDTx looks up an experiment by a given external experiment id.

func ExperimentByID

func ExperimentByID(ctx context.Context, expID int) (*model.Experiment, error)

ExperimentByID looks up an experiment by ID in a database, returning an error if none exists.

func ExperimentByTaskID

func ExperimentByTaskID(
	ctx context.Context, taskID model.TaskID,
) (*model.Experiment, error)

ExperimentByTaskID looks up an experiment by a given taskID, returning an error if none exists.

func ExperimentByTrialID

func ExperimentByTrialID(ctx context.Context, trialID int) (*model.Experiment, error)

ExperimentByTrialID looks up an experiment by a given trialID, returning an error if none exists.

func ExperimentIDsToWorkspaceIDs

func ExperimentIDsToWorkspaceIDs(ctx context.Context, experimentIDs []int32) (
	[]model.AccessScopeID, error,
)

ExperimentIDsToWorkspaceIDs returns a slice of workspaces that the given experiments belong to.

func ExperimentNumSteps

func ExperimentNumSteps(ctx context.Context, id int) (int64, error)

ExperimentNumSteps returns the total number of steps for all trials of the experiment.

func ExperimentTotalStepTime

func ExperimentTotalStepTime(ctx context.Context, id int) (float64, error)

ExperimentTotalStepTime returns the total elapsed time for all allocations of the experiment with the given ID. Any step with a NULL end_time does not contribute. Elapsed time is expressed as a floating point number of seconds.

func ExperimentsByTrialID

func ExperimentsByTrialID(ctx context.Context, trialIDs []int) ([]*model.Experiment, error)

ExperimentsByTrialID looks up the experiments for a given list of trialIDs, returning an error if none exist.

func ExperimentsTrialAndTaskIDs

func ExperimentsTrialAndTaskIDs(ctx context.Context, idb bun.IDB, expIDs []int) (
	[]int, []model.TaskID, error,
)

ExperimentsTrialAndTaskIDs returns the trial and task IDs for one or more experiments.

func GenerateNotebookSessionToken

func GenerateNotebookSessionToken(
	userID model.UserID,
	taskID model.TaskID,
) (string, error)

GenerateNotebookSessionToken generates a token for a notebook session.

func GetActiveClusterMessage

func GetActiveClusterMessage(ctx context.Context, db *bun.DB) (model.ClusterMessage, error)

GetActiveClusterMessage returns the active cluster message if one is set and active, or ErrNotFound if not.

func GetCheckpoint

func GetCheckpoint(ctx context.Context, checkpointUUID string) (*checkpointv1.Checkpoint, error)

GetCheckpoint gets a checkpointv1.Checkpoint from the database by UUID. It can be moved to master/internal/checkpoints once db/postgres_model_intg_test is bunified. WARNING: this function does not account for "NaN", "Infinity", or "-Infinity" due to Bun unmarshalling.

func GetClusterMessage

func GetClusterMessage(ctx context.Context, db *bun.DB) (model.ClusterMessage, error)

GetClusterMessage returns the cluster message even if it's not yet active, or ErrNotFound if all cluster messages have expired.

func GetDefaultPoolsForWorkspace

func GetDefaultPoolsForWorkspace(ctx context.Context, workspaceID int,
) (computePool, auxPool string, err error)

GetDefaultPoolsForWorkspace returns the default compute and aux pools for a workspace.

func GetMetrics

func GetMetrics(ctx context.Context, trialID, afterBatches, limit int,
	mGroup *string,
) ([]*trialv1.MetricsReport, error)

GetMetrics returns a subset of the metrics of the requested type for the given trial ID.

func GetNonGlobalWorkspacesWithPermission

func GetNonGlobalWorkspacesWithPermission(ctx context.Context, curUserID model.UserID,
	permissionID rbacv1.PermissionType,
) ([]int, error)

GetNonGlobalWorkspacesWithPermission returns all workspaces the user has permissionID on. This does not check for permissions granted on scopes higher than workspace level (e.g., cluster).

func GetNonTerminalExperimentCount

func GetNonTerminalExperimentCount(ctx context.Context,
	experimentIDs []int32,
) (count int, err error)

GetNonTerminalExperimentCount returns the number of non terminal experiments.

func GetRunMetadata

func GetRunMetadata(ctx context.Context, runID int) (map[string]any, error)

GetRunMetadata returns the metadata of a run from the database. If the run does not have any metadata, it returns an empty map.

func GetTokenKeys

func GetTokenKeys() *model.AuthTokenKeypair

GetTokenKeys returns tokenKeys.

func GetTrialProfilerAvailableSeries

func GetTrialProfilerAvailableSeries(
	ctx context.Context, trialID int32,
) ([]*trialv1.TrialProfilerMetricLabels, error)

GetTrialProfilerAvailableSeries returns all available system profiling metric names. This method is to be deprecated in the future in place of generic metrics APIs.

func GetUnboundRPs

func GetUnboundRPs(
	ctx context.Context, resourcePools []string,
) ([]string, error)

GetUnboundRPs gets unbound resource pools.

func HackAddUser

func HackAddUser(
	ctx context.Context,
	user *model.User,
) (model.UserID, error)

HackAddUser is used to prevent an import cycle in postgres_test_utils & postgres_scim (EE).

func InitAuthKeys

func InitAuthKeys() error

InitAuthKeys initializes auth token keypairs.

func InsertDispatch

func InsertDispatch(ctx context.Context, r *Dispatch) error

InsertDispatch persists the existence of a dispatch.

func InsertModel

func InsertModel(ctx context.Context, name string, description string, metadata []byte,
	labels string, notes string, userID model.UserID, workspaceID int,
) (*modelv1.Model, error)

InsertModel inserts the model into the database.

func InsertModelVersion

func InsertModelVersion(ctx context.Context, id int32, ckptID string, name string, comment string,
	metadata []byte, labels string, notes string, userID model.UserID,
) (*modelv1.ModelVersion, error)

InsertModelVersion inserts the model version into the database.

func IsPaused

func IsPaused(ctx context.Context, tID model.TaskID) (bool, error)

IsPaused returns true if given task is in paused/pausing state.

func JobByID

func JobByID(ctx context.Context, jobID model.JobID) (*model.Job, error)

JobByID retrieves a job by ID.

func KillGenericTask

func KillGenericTask(tID model.TaskID, endTime time.Time) error

KillGenericTask persists the termination of a task of type GENERIC.

func MatchSentinelError

func MatchSentinelError(err error) error

MatchSentinelError checks if the error belongs to specific families of errors and ensures that the returned error has the proper type and text.

func MetricBatches

func MetricBatches(
	experimentID int, metricName string, startTime time.Time, metricGroup model.MetricGroup,
) (
	batches []int32, endTime time.Time, err error,
)

MetricBatches returns the milestones (in batches processed) at which a specific metric was recorded.

func MustHaveAffectedRows

func MustHaveAffectedRows(result sql.Result, err error) error

MustHaveAffectedRows checks if bun has affected rows in a table or not. Returns ErrNotFound if no rows were affected and returns the provided error otherwise.

func NonExperimentTasksContextDirectory

func NonExperimentTasksContextDirectory(ctx context.Context, tID model.TaskID) ([]byte, error)

NonExperimentTasksContextDirectory returns a non experiment's context directory.

func OrderByToSQL

func OrderByToSQL(order apiv1.OrderBy) string

OrderByToSQL computes the SQL keyword corresponding to the given ordering type.

func OverwriteRPWorkspaceBindings

func OverwriteRPWorkspaceBindings(ctx context.Context,
	workspaceIds []int32, poolName string, resourcePools []config.ResourcePoolConfig,
) error

OverwriteRPWorkspaceBindings overwrites the bindings between workspaceIds and poolName.

func PaginateBun

func PaginateBun(
	query *bun.SelectQuery,
	orderColumn string,
	direction SortDirection,
	offset,
	limit int,
) *bun.SelectQuery

PaginateBun adds sorting and pagination to the provided bun query, defaulting to certain values if they are not specified. By default, we order by ascending on the id column, with no limit.

func PaginateBunUnsafe

func PaginateBunUnsafe(
	query *bun.SelectQuery,
	orderColumn string,
	direction SortDirection,
	offset,
	limit int,
) *bun.SelectQuery

PaginateBunUnsafe is a version of PaginateBun that allows an arbitrary order expression like `metrics->>'loss'`.

func ParseMapToProto

func ParseMapToProto(dest map[string]interface{}, val interface{}) error

ParseMapToProto converts sqlx/bun-scanned map to a proto object.

func ProjectExperiments

func ProjectExperiments(ctx context.Context, pID int) (experiments []*model.Experiment, err error)

ProjectExperiments returns a list of experiments within a project.

func ReadRPsAvailableToWorkspace

func ReadRPsAvailableToWorkspace(
	ctx context.Context,
	workspaceID int32,
	offset int32,
	limit int32,
	resourcePoolConfig []config.ResourcePoolConfig,
) ([]string, *apiv1.Pagination, error)

ReadRPsAvailableToWorkspace returns the names of resource pools available to a workspace.

func RecordTaskEndStats

func RecordTaskEndStats(ctx context.Context, stats *model.TaskStats) error

RecordTaskEndStats records end stats for tasks.

func RecordTaskStats

func RecordTaskStats(ctx context.Context, stats *model.TaskStats) error

RecordTaskStats records stats for tasks.

func RegisterModel

func RegisterModel(m interface{})

RegisterModel registers a model in Bun or, if theOneBun is not yet initialized, sets it up to be registered once initialized. It's generally best to pass a nil pointer of your model's type as argument m.

func RemoveOutdatedProjectHparams

func RemoveOutdatedProjectHparams(ctx context.Context, tx bun.Tx, projectID int) error

RemoveOutdatedProjectHparams removes outdated project hyperparams from provided project.

func RemoveRPWorkspaceBindings

func RemoveRPWorkspaceBindings(ctx context.Context,
	workspaceIds []int32, poolName string,
) error

RemoveRPWorkspaceBindings removes the bindings between workspaceIds and poolName.

func SetClause

func SetClause(fields []string) string

SetClause returns a SET subquery.

func SetClusterMessage

func SetClusterMessage(ctx context.Context, db *bun.DB, msg model.ClusterMessage) error

SetClusterMessage sets the cluster-wide message. Any existing message will be expired because only one cluster message is allowed at any time. Messages may be at most ClusterMessageMaxLength characters long. Returns a wrapped ErrInvalidInput when input is invalid.

func SetErrorState

func SetErrorState(taskID model.TaskID, endTime time.Time) error

SetErrorState sets the given task to an ERROR state.

func SetPausedState

func SetPausedState(taskID model.TaskID, endTime time.Time) error

SetPausedState sets the given task to a PAUSED state.

func SetTokenKeys

func SetTokenKeys(tk *model.AuthTokenKeypair)

SetTokenKeys sets tokenKeys.

func StartAllocationSession

func StartAllocationSession(
	ctx context.Context,
	allocationID model.AllocationID,
	owner *model.User,
) (string, error)

StartAllocationSession creates a row in the allocation_sessions table.

func StartNotebookSession

func StartNotebookSession(
	ctx context.Context,
	userID model.UserID,
	taskID model.TaskID,
) error

StartNotebookSession persists a new notebook session row into the database.

func TaskByID

func TaskByID(ctx context.Context, tID model.TaskID) (*model.Task, error)

TaskByID returns a task by its ID.

func TaskCompleted

func TaskCompleted(ctx context.Context, tID model.TaskID) (bool, error)

TaskCompleted checks if the end time exists for a task, if so, the task has completed.

func TopTrialsByMetric

func TopTrialsByMetric(
	ctx context.Context, experimentID int, maxTrials int, metric string, smallerIsBetter bool,
) ([]int32, error)

TopTrialsByMetric chooses the subset of trials from an experiment that recorded the best values for the specified metric at any point during the trial.

func TrialByExperimentAndRequestID

func TrialByExperimentAndRequestID(
	ctx context.Context, experimentID int, requestID model.RequestID,
) (*model.Trial, error)

TrialByExperimentAndRequestID looks up a trial, returning an error if none exists.

func TrialByID

func TrialByID(ctx context.Context, id int) (*model.Trial, error)

TrialByID looks up a trial by ID, returning an error if none exists.

func TrialByTaskID

func TrialByTaskID(ctx context.Context, taskID model.TaskID) (*model.Trial, error)

TrialByTaskID looks up a trial by taskID, returning an error if none exists. This errors if you called it with a non trial taskID.

func TrialIDsToWorkspaceIDs

func TrialIDsToWorkspaceIDs(ctx context.Context, trialIDs []int32) (
	[]model.AccessScopeID, error,
)

TrialIDsToWorkspaceIDs returns a slice of workspaces that the given trials belong to.

func TrialTaskIDsByTrialID

func TrialTaskIDsByTrialID(ctx context.Context, trialID int) ([]*model.RunTaskID, error)

TrialTaskIDsByTrialID returns the trial ID/task ID pairs for the given trial ID, sorted by start time.

func UpdateAllocationPorts

func UpdateAllocationPorts(ctx context.Context, a model.Allocation) error

UpdateAllocationPorts stores the latest task state and readiness.

func UpdateAllocationProxyAddress

func UpdateAllocationProxyAddress(ctx context.Context, a model.Allocation) error

UpdateAllocationProxyAddress stores the proxy address.

func UpdateAllocationStartTime

func UpdateAllocationStartTime(ctx context.Context, a model.Allocation) error

UpdateAllocationStartTime stores the latest start time.

func UpdateAllocationState

func UpdateAllocationState(ctx context.Context, a model.Allocation) error

UpdateAllocationState stores the latest task state and readiness.

func UpdateCheckpointSizeTx

func UpdateCheckpointSizeTx(ctx context.Context, idb bun.IDB, checkpoints []uuid.UUID) error

UpdateCheckpointSizeTx, which updates the checkpoint size and count on the experiment and trial, is duplicated here. Remove it from this file when bunifying. The original is in master/internal/checkpoints/postgres_checkpoints.go.

func UpdateRunMetadata

func UpdateRunMetadata(
	ctx context.Context,
	runID int,
	rawMetadata map[string]any,
	flatMetadata []model.RunMetadataIndex,
) (result map[string]any, err error)

UpdateRunMetadata updates the metadata of a run, including the metadata indexes.

func UpdateTrial

func UpdateTrial(ctx context.Context, id int, newState model.State) error

UpdateTrial updates the state of an existing trial. end_time is set if the trial moves to a terminal state.

func UpsertTrialByExternalIDTx

func UpsertTrialByExternalIDTx(
	ctx context.Context, tx bun.Tx, trial *model.Trial, taskID model.TaskID,
) error

UpsertTrialByExternalIDTx UPSERTs the trial with respect to the external_trial_id.

func ValidateDoubleFieldFilterComparison

func ValidateDoubleFieldFilterComparison(
	filter *commonv1.DoubleFieldFilter,
) error

ValidateDoubleFieldFilterComparison validates the min and max values in the range.

func ValidateInt32FieldFilterComparison

func ValidateInt32FieldFilterComparison(
	filter *commonv1.Int32FieldFilter,
) error

ValidateInt32FieldFilterComparison validates the min and max values in the range.

func ValidatePolymorphicFilter

func ValidatePolymorphicFilter(
	filter *commonv1.PolymorphicFilter,
) error

ValidatePolymorphicFilter ensures that a Polymorphic filter contains at most one valid range.

func ValidateTimeStampFieldFilterComparison

func ValidateTimeStampFieldFilterComparison(
	filter *commonv1.TimestampFieldFilter,
) error

ValidateTimeStampFieldFilterComparison validates the min and max timestamps in the range.

Types

type ClientStore

type ClientStore struct {
	// contains filtered or unexported fields
}

ClientStore is a store for OAuth clients. It is separate from PgDB so we can implement an interface of the external OAuth library without polluting PgDB's method set.

func (*ClientStore) Create

func (s *ClientStore) Create(c model.OAuthClient) error

Create adds a new client to the database.

func (*ClientStore) GetByID

func (s *ClientStore) GetByID(id string) (oauth2.ClientInfo, error)

GetByID returns a client given its ID, including the secret. It implements the gopkg.in/oauth2.v3#ClientStore interface, so it returns an external interface type; the returned object is always actually of type model.OAuthClient.

func (*ClientStore) List

func (s *ClientStore) List() ([]model.OAuthClient, error)

List returns all OAuth clients in the database. The secrets are not included.

func (*ClientStore) RemoveByID

func (s *ClientStore) RemoveByID(id string) error

RemoveByID removes the client with the given client ID.

type DB

// DB is an interface for _all_ the functionality packed into the DB.
//
// The methods are grouped by subject below. An interface's method set is
// unordered, so the grouping has no effect on which types satisfy DB.
type DB interface {
	// Lifecycle, migrations, and telemetry.
	Migrate(migrationURL, codeURL string, actions []string) error
	Close() error
	GetOrCreateClusterID(telemetryID string) (string, error)
	PeriodicTelemetryInfo() ([]byte, error)

	// Experiments.
	AddExperiment(experiment *model.Experiment, modelDef []byte, activeConfig expconf.ExperimentConfig) error
	NonTerminalExperiments() ([]*model.Experiment, error)
	TerminateExperimentInRestart(id int, state model.State) error
	SaveExperimentConfig(id int, config expconf.ExperimentConfig) error
	SaveExperimentState(experiment *model.Experiment) error
	SaveExperimentArchiveStatus(experiment *model.Experiment) error
	SaveExperimentProgress(id int, progress *float64) error
	DeleteExperiments(ctx context.Context, ids []int) error
	ExperimentHasCheckpointsInRegistry(id int) (bool, error)
	ActiveExperimentConfig(id int) (expconf.ExperimentConfig, error)
	ExperimentNumTrials(id int) (int64, error)
	ExperimentTrialIDs(expID int) ([]int, error)
	ExperimentModelDefinitionRaw(id int) ([]byte, error)
	ExperimentLabelUsage(projectID int32) (labelUsage map[string]int, err error)
	GetExperimentStatus(experimentID int) (state model.State, progress float64, err error)

	// Experiment snapshots.
	ExperimentSnapshot(experimentID int) ([]byte, int, error)
	SaveSnapshot(experimentID int, version int, experimentSnapshot []byte) error
	DeleteSnapshotsForExperiment(experimentID int) error
	DeleteSnapshotsForTerminalExperiments() error

	// Trials.
	TrialExperimentAndRequestID(id int) (int, model.RequestID, error)
	ExperimentIDByTrialID(trialID int) (int, error)
	UpdateTrialFields(id int, newRunnerMetadata *trialv1.TrialRunnerMetadata, newRunID, newRestarts int) error
	TrialRunIDAndRestarts(trialID int) (int, int, error)
	TrialState(trialID int) (model.State, error)
	TrialStatus(trialID int) (model.State, *time.Time, error)
	TrialsSnapshot(
		experimentID int, minBatches int, maxBatches int,
		metricName string, startTime time.Time, metricGroup model.MetricGroup,
	) (trials []*apiv1.TrialsSnapshotResponse_Trial, endTime time.Time, err error)
	TopTrialsByTrainingLength(
		experimentID int, maxTrials int, metric string, smallerIsBetter bool,
	) (trials []int32, err error)

	// Metrics and checkpoints.
	AddTrainingMetrics(ctx context.Context, m *trialv1.TrialMetrics) error
	AddValidationMetrics(ctx context.Context, m *trialv1.TrialMetrics) error
	ValidationByTotalBatches(trialID, totalBatches int) (*model.TrialMetrics, error)
	CheckpointByTotalBatches(trialID, totalBatches int) (*model.Checkpoint, error)
	LatestCheckpointForTrial(trialID int) (*model.Checkpoint, error)
	TrainingMetricBatches(
		experimentID int, metricName string, startTime time.Time,
	) (batches []int32, endTime time.Time, err error)
	ValidationMetricBatches(
		experimentID int, metricName string, startTime time.Time,
	) (batches []int32, endTime time.Time, err error)

	// Trial profiler metrics.
	InsertTrialProfilerMetricsBatch(values []float32, batches []int32, timestamps []time.Time, labels []byte) error
	GetTrialProfilerMetricsBatches(
		labels *trialv1.TrialProfilerMetricLabels, offset, limit int,
	) (model.TrialProfilerMetricsBatchBatch, error)

	// Trial logs.
	TrialLogs(
		trialID, limit int, fs []api.Filter, order apiv1.OrderBy, followState any,
	) ([]*model.TrialLog, any, error)
	TrialLogsCount(trialID int, fs []api.Filter) (int, error)
	TrialLogsFields(trialID int) (*apiv1.TrialLogsFieldsResponse, error)
	DeleteTrialLogs(ids []int) error

	// Named queries.
	Query(queryName string, v any, params ...any) error
	QueryF(queryName string, args []any, v any, params ...any) error
	QueryProto(queryName string, v any, args ...any) error
	QueryProtof(queryName string, args []any, v any, params ...any) error
	RawQuery(queryName string, params ...any) ([]byte, error)

	// Resource allocation aggregation and agent/instance stats.
	UpdateResourceAllocationAggregation() error
	RecordAgentStats(a *model.AgentStats) error
	EndAllAgentStats() error
	RecordInstanceStats(a *model.InstanceStats) error
	EndInstanceStats(a *model.InstanceStats) error
	EndAllInstanceStats() error
}

DB is an interface for _all_ the functionality packed into the DB.

type Dispatch

type Dispatch struct {
	bun.BaseModel `bun:"table:resourcemanagers_dispatcher_dispatches"`

	DispatchID       string             `bun:"dispatch_id"`
	ResourceID       sproto.ResourcesID `bun:"resource_id"`
	AllocationID     model.AllocationID `bun:"allocation_id"`
	ImpersonatedUser string             `bun:"impersonated_user"`
}

Dispatch is the Determined-persisted representation for dispatch existence.

func DispatchByID

func DispatchByID(
	ctx context.Context,
	id string,
) (*Dispatch, error)

DispatchByID retrieves a dispatch by its ID.

func ListAllDispatches

func ListAllDispatches(ctx context.Context) ([]*Dispatch, error)

ListAllDispatches lists all dispatches in the DB.

func ListDispatches

func ListDispatches(
	ctx context.Context,
	opts func(*bun.SelectQuery) (*bun.SelectQuery, error),
) ([]*Dispatch, error)

ListDispatches lists all dispatches according to the options provided.

func ListDispatchesByAllocationID

func ListDispatchesByAllocationID(
	ctx context.Context,
	id model.AllocationID,
) ([]*Dispatch, error)

ListDispatchesByAllocationID lists all dispatches for an allocation ID.

func ListDispatchesByJobID

func ListDispatchesByJobID(
	ctx context.Context,
	jobID string,
) ([]*Dispatch, error)

ListDispatchesByJobID returns a list of dispatches associated with the specified job.

type FilterComparison

type FilterComparison[T any] struct {
	Gt  *T
	Gte *T
	Lt  *T
	Lte *T
}

FilterComparison holds optional Gt, Gte, Lt, and Lte bounds for comparing a value of type T. (It makes you wish for properties in generic structs/interfaces.)

type MetricMeasurements

type MetricMeasurements struct {
	Values  map[string]interface{}
	Batches uint
	Time    time.Time
	Epoch   *float64 `json:"epoch,omitempty"`
	TrialID int32
}

MetricMeasurements represents a metric measured by all possible independent variables.

type MetricPartitionType

type MetricPartitionType string

MetricPartitionType denotes what type the metric is. This is planned to be deprecated once we upgrade to pg11 and can use DEFAULT partitioning.

const (
	// TrainingMetric designates metrics from training steps.
	TrainingMetric MetricPartitionType = "TRAINING"
	// ValidationMetric designates metrics from validation steps.
	ValidationMetric MetricPartitionType = "VALIDATION"
	// ProfilingMetric designates metrics from profiling steps.
	ProfilingMetric MetricPartitionType = "PROFILING"
	// GenericMetric designates metrics from other sources.
	GenericMetric MetricPartitionType = "GENERIC"
)

type PgDB

type PgDB struct {
	URL string
	// contains filtered or unexported fields
}

PgDB represents a Postgres database connection. The type definition is needed to define methods.

func Connect

func Connect(opts *config.DBConfig) (*PgDB, error)

Connect connects to the database, but doesn't run migrations & inits.

func ConnectPostgres

func ConnectPostgres(url string) (*PgDB, error)

ConnectPostgres connects to a Postgres database.

func Setup

func Setup(
	opts *config.DBConfig, postConnectHooks ...func(*PgDB) error,
) (db *PgDB, err error)

Setup connects to the database and runs any necessary migrations.

func SingleDB

func SingleDB() *PgDB

SingleDB returns a singleton database client. Bun() should be preferred over this for all new queries.

func (*PgDB) ActiveExperimentConfig

func (db *PgDB) ActiveExperimentConfig(id int) (expconf.ExperimentConfig, error)

ActiveExperimentConfig returns the full config object for an experiment.

func (*PgDB) AddExperiment

func (db *PgDB) AddExperiment(
	experiment *model.Experiment, modelDef []byte, activeConfig expconf.ExperimentConfig,
) (err error)

AddExperiment adds the experiment to the database and sets its ID.

TODO(ilia): deprecate and use module function instead.

func (*PgDB) AddTaskLogs

func (db *PgDB) AddTaskLogs(logs []*model.TaskLog) error

AddTaskLogs bulk-inserts a list of *model.TaskLog objects to the database with automatic IDs.

func (*PgDB) AddTrainingMetrics

func (db *PgDB) AddTrainingMetrics(ctx context.Context, m *trialv1.TrialMetrics) error

AddTrainingMetrics [DEPRECATED] adds a completed step to the database with the given training metrics. If these training metrics occur before any others, a rollback is assumed and later training and validation metrics are cleaned up.

func (*PgDB) AddTrialMetrics

func (db *PgDB) AddTrialMetrics(
	ctx context.Context, m *trialv1.TrialMetrics, mGroup model.MetricGroup,
) error

AddTrialMetrics persists the given trial metrics to the database.

func (*PgDB) AddValidationMetrics

func (db *PgDB) AddValidationMetrics(
	ctx context.Context, m *trialv1.TrialMetrics,
) error

AddValidationMetrics [DEPRECATED] adds a completed validation to the database with the given validation metrics. If these validation metrics occur before any others, a rollback is assumed and later metrics are cleaned up from the database.

func (*PgDB) CheckpointByTotalBatches

func (db *PgDB) CheckpointByTotalBatches(trialID, totalBatches int) (*model.Checkpoint, error)

CheckpointByTotalBatches looks up a checkpoint by trial and total batch, returning nil if none exists.

func (*PgDB) ClientStore

func (db *PgDB) ClientStore() *ClientStore

ClientStore returns a store for OAuth clients backed by this database.

func (*PgDB) Close

func (db *PgDB) Close() error

Close closes the underlying pq connection.

func (*PgDB) DeleteExperiments

func (db *PgDB) DeleteExperiments(ctx context.Context, ids []int) error

DeleteExperiments deletes zero or more experiments.

func (*PgDB) DeleteSnapshotsForExperiment

func (db *PgDB) DeleteSnapshotsForExperiment(experimentID int) error

DeleteSnapshotsForExperiment deletes all snapshots for one given experiment.

func (*PgDB) DeleteSnapshotsForTerminalExperiments

func (db *PgDB) DeleteSnapshotsForTerminalExperiments() error

DeleteSnapshotsForTerminalExperiments deletes all snapshots for terminal state experiments from the database.

func (*PgDB) DeleteTaskLogs

func (db *PgDB) DeleteTaskLogs(ids []model.TaskID) error

DeleteTaskLogs deletes the logs for the given tasks.

func (*PgDB) DeleteTrialLogs

func (db *PgDB) DeleteTrialLogs(ids []int) error

DeleteTrialLogs deletes the logs for the given trial IDs.

func (*PgDB) EndAllAgentStats

func (db *PgDB) EndAllAgentStats() error

EndAllAgentStats is called when the master starts, in case the master previously crashed. If the master stops, statistics would treat “live” agents as live until the master restarts.

func (*PgDB) EndAllInstanceStats

func (db *PgDB) EndAllInstanceStats() error

EndAllInstanceStats is called when the master starts, in case the master previously crashed. If the master stops, statistics would treat “live” instances as live until the master restarts.

func (*PgDB) EndInstanceStats

func (db *PgDB) EndInstanceStats(a *model.InstanceStats) error

EndInstanceStats updates the end time of an instance.

func (*PgDB) ExperimentConfigRaw

func (db *PgDB) ExperimentConfigRaw(id int) ([]byte, error)

ExperimentConfigRaw returns the full config object for an experiment as a JSON string.

func (*PgDB) ExperimentHasCheckpointsInRegistry

func (db *PgDB) ExperimentHasCheckpointsInRegistry(id int) (bool, error)

ExperimentHasCheckpointsInRegistry checks if the experiment has any checkpoints in the registry.

func (*PgDB) ExperimentIDByTrialID

func (db *PgDB) ExperimentIDByTrialID(trialID int) (int, error)

ExperimentIDByTrialID looks up an experiment ID by a trial ID.

func (*PgDB) ExperimentLabelUsage

func (db *PgDB) ExperimentLabelUsage(projectID int32) (labelUsage map[string]int, err error)

ExperimentLabelUsage returns a flattened and deduplicated list of all the labels in use across all experiments.

func (*PgDB) ExperimentModelDefinitionRaw

func (db *PgDB) ExperimentModelDefinitionRaw(id int) ([]byte, error)

ExperimentModelDefinitionRaw returns the zipped model definition for an experiment as a byte array.

func (*PgDB) ExperimentNumTrials

func (db *PgDB) ExperimentNumTrials(id int) (int64, error)

ExperimentNumTrials returns the total number of trials for the experiment.

func (*PgDB) ExperimentSnapshot

func (db *PgDB) ExperimentSnapshot(experimentID int) ([]byte, int, error)

ExperimentSnapshot returns the snapshot for the specified experiment.

func (*PgDB) ExperimentTrialIDs

func (db *PgDB) ExperimentTrialIDs(expID int) ([]int, error)

ExperimentTrialIDs returns the trial IDs for the experiment.

func (*PgDB) FailDeletingExperiment

func (db *PgDB) FailDeletingExperiment() error

FailDeletingExperiment finds all experiments that were deleting when the master crashed and moves them to DELETE_FAILED.

func (*PgDB) GetExperimentStatus

func (db *PgDB) GetExperimentStatus(experimentID int) (state model.State, progress float64,
	err error,
)

GetExperimentStatus returns the current state of the experiment.

func (*PgDB) GetOrCreateClusterID

func (db *PgDB) GetOrCreateClusterID(telemetryID string) (string, error)

GetOrCreateClusterID queries the master uuid in the database, adding one if it doesn't exist. If a nonempty telemetryID is provided, it will be the one added, otherwise a uuid is generated.

func (*PgDB) GetTrialProfilerMetricsBatches

func (db *PgDB) GetTrialProfilerMetricsBatches(
	labels *trialv1.TrialProfilerMetricLabels, offset, limit int,
) (model.TrialProfilerMetricsBatchBatch, error)

GetTrialProfilerMetricsBatches gets a batch of profiler metric batches from the database. This method is for backwards compatibility and should be deprecated in the future in favor of the generic metrics APIs.

Profiler metrics are stored in the metrics table as a nested JSON mapping of labels to values. All profiler metrics are associated with an agent ID, but certain metrics (e.g. gpu_util) may be associated with other labels. For example:

{
	"agent-ID-1": {
		"GPU-UUID-1": {
			"gpu_util": 0.12,
			"gpu_free_memory": 0.34
		}
	}
}

func (*PgDB) InsertTrialProfilerMetricsBatch

func (db *PgDB) InsertTrialProfilerMetricsBatch(
	values []float32, batches []int32, timestamps []time.Time, labels []byte,
) error

InsertTrialProfilerMetricsBatch inserts a batch of metrics into the database.

func (*PgDB) LatestCheckpointForTrial

func (db *PgDB) LatestCheckpointForTrial(trialID int) (*model.Checkpoint, error)

LatestCheckpointForTrial finds the latest completed checkpoint for a trial, returning nil if none exists.

func (*PgDB) LegacyExperimentConfigByID

func (db *PgDB) LegacyExperimentConfigByID(
	id int,
) (expconf.LegacyConfig, error)

LegacyExperimentConfigByID parses very old configs, returning a LegacyConfig which exposes a select subset of fields in a type-safe way.

func (*PgDB) MaxTerminationDelay

func (db *PgDB) MaxTerminationDelay() time.Duration

MaxTerminationDelay is the max delay before a consumer can be sure all logs have been received. For Postgres, we don't need to wait very long at all; this was a hypothetical cap on fluent-to-DB latency prior to fluent's deprecation.

func (*PgDB) MetricNames

func (db *PgDB) MetricNames(ctx context.Context, experimentIDs []int) (
	map[model.MetricGroup][]string, error,
)

MetricNames returns a list of metric names for the given experiment IDs.

func (*PgDB) Migrate

func (db *PgDB) Migrate(
	migrationURL string, dbCodeDir string, actions []string,
) error

Migrate runs the migrations from the specified directory URL.

func (*PgDB) NonTerminalExperiments

func (db *PgDB) NonTerminalExperiments() ([]*model.Experiment, error)

NonTerminalExperiments finds all experiments in the database whose states are not terminal.

func (*PgDB) PeriodicTelemetryInfo

func (db *PgDB) PeriodicTelemetryInfo() ([]byte, error)

PeriodicTelemetryInfo returns anonymous information about the usage of the current Determined cluster.

func (*PgDB) Query

func (db *PgDB) Query(queryName string, v interface{}, params ...interface{}) error

Query returns the result of the query. Any placeholder parameters are replaced with supplied params.

func (*PgDB) QueryF

func (db *PgDB) QueryF(
	queryName string, args []interface{}, v interface{}, params ...interface{},
) error

QueryF returns the result of the formatted query. Any placeholder parameters are replaced with supplied params.

func (*PgDB) QueryProto

func (db *PgDB) QueryProto(queryName string, v interface{}, args ...interface{}) error

QueryProto returns the result of the query. Any placeholder parameters are replaced with supplied args. Enum values must be the full name of the enum.

func (*PgDB) QueryProtof

func (db *PgDB) QueryProtof(
	queryName string, args []interface{}, v interface{}, params ...interface{},
) error

QueryProtof returns the result of the formatted query. Any placeholder parameters are replaced with supplied params.

func (*PgDB) RawQuery

func (db *PgDB) RawQuery(queryName string, params ...interface{}) ([]byte, error)

RawQuery returns the result of the query as a raw byte string. Any placeholder parameters are replaced with supplied params.

func (*PgDB) RecordAgentStats

func (db *PgDB) RecordAgentStats(a *model.AgentStats) error

RecordAgentStats inserts a record of the agent start time if the agent has not been started or has already ended.

func (*PgDB) RecordInstanceStats

func (db *PgDB) RecordInstanceStats(a *model.InstanceStats) error

RecordInstanceStats inserts a record of the instance start time if the instance has not been started or has already ended.

func (*PgDB) SaveExperimentArchiveStatus

func (db *PgDB) SaveExperimentArchiveStatus(experiment *model.Experiment) error

SaveExperimentArchiveStatus saves the current experiment archive status to the database.

func (*PgDB) SaveExperimentConfig

func (db *PgDB) SaveExperimentConfig(id int, config expconf.ExperimentConfig) error

SaveExperimentConfig saves the current experiment config to the database.

func (*PgDB) SaveExperimentProgress

func (db *PgDB) SaveExperimentProgress(id int, progress *float64) error

SaveExperimentProgress stores the progress for an experiment in the database.

func (*PgDB) SaveExperimentState

func (db *PgDB) SaveExperimentState(experiment *model.Experiment) error

SaveExperimentState saves the current experiment state to the database.

func (*PgDB) SaveSnapshot

func (db *PgDB) SaveSnapshot(
	experimentID int, version int, experimentSnapshot []byte,
) error

SaveSnapshot saves a searcher and trial snapshot together.

func (*PgDB) TaskLogs

func (db *PgDB) TaskLogs(
	taskID model.TaskID, limit int, fs []api.Filter, order apiv1.OrderBy, followState interface{},
) ([]*model.TaskLog, interface{}, error)

TaskLogs takes a task ID and log offset, limit and filters and returns matching logs.

func (*PgDB) TaskLogsCount

func (db *PgDB) TaskLogsCount(taskID model.TaskID, fs []api.Filter) (int, error)

TaskLogsCount returns the number of logs in postgres for the given task.

func (*PgDB) TaskLogsFields

func (db *PgDB) TaskLogsFields(taskID model.TaskID) (*apiv1.TaskLogsFieldsResponse, error)

TaskLogsFields returns the unique fields that can be filtered on for the given task.

func (*PgDB) TerminateExperimentInRestart

func (db *PgDB) TerminateExperimentInRestart(id int, state model.State) error

TerminateExperimentInRestart is used during master restart to properly terminate an experiment which was either in the process of stopping or which is not restorable for some reason, such as an invalid experiment config after a version upgrade.

func (*PgDB) TokenStore

func (db *PgDB) TokenStore() *TokenStore

TokenStore returns a store for OAuth tokens backed by this database.

func (*PgDB) TopTrialsByTrainingLength

func (db *PgDB) TopTrialsByTrainingLength(experimentID int, maxTrials int, metric string,
	smallerIsBetter bool,
) (trials []int32, err error)

TopTrialsByTrainingLength chooses the subset of trials that has been training for the highest number of batches, using the specified metric as a tie breaker.

func (*PgDB) TrainingMetricBatches

func (db *PgDB) TrainingMetricBatches(experimentID int, metricName string, startTime time.Time) (
	batches []int32, endTime time.Time, err error,
)

TrainingMetricBatches returns the milestones (in batches processed) at which a specific training metric was recorded.

func (*PgDB) TrialExperimentAndRequestID

func (db *PgDB) TrialExperimentAndRequestID(id int) (int, model.RequestID, error)

TrialExperimentAndRequestID returns the trial's experiment and request ID.

func (*PgDB) TrialLogs

func (db *PgDB) TrialLogs(
	trialID, limit int, fs []api.Filter, order apiv1.OrderBy, followState interface{},
) ([]*model.TrialLog, interface{}, error)

TrialLogs takes a trial ID and log offset, limit and filters and returns matching trial logs.

func (*PgDB) TrialLogsCount

func (db *PgDB) TrialLogsCount(trialID int, fs []api.Filter) (int, error)

TrialLogsCount returns the number of logs in postgres for the given trial.

func (*PgDB) TrialLogsFields

func (db *PgDB) TrialLogsFields(trialID int) (*apiv1.TrialLogsFieldsResponse, error)

TrialLogsFields returns the unique fields that can be filtered on for the given trial.

func (*PgDB) TrialRunIDAndRestarts

func (db *PgDB) TrialRunIDAndRestarts(trialID int) (runID int, restart int, err error)

TrialRunIDAndRestarts returns the run id and restart count for a trial.

func (*PgDB) TrialState

func (db *PgDB) TrialState(trialID int) (model.State, error)

TrialState returns the current state of the given trial.

func (*PgDB) TrialStatus

func (db *PgDB) TrialStatus(trialID int) (model.State, *time.Time, error)

TrialStatus returns the current status of the given trial, including the end time without returning all its hparams and other unneeded details. Called in paths hotter than TrialByID allows.

func (*PgDB) TrialsSnapshot

func (db *PgDB) TrialsSnapshot(experimentID int, minBatches int, maxBatches int,
	metricName string, startTime time.Time, metricGroup model.MetricGroup,
) (trials []*apiv1.TrialsSnapshotResponse_Trial, endTime time.Time, err error)

TrialsSnapshot returns metrics across each trial in an experiment at a specific point of progress, for metric groups other than training and validation.

func (*PgDB) TrySaveExperimentState

func (db *PgDB) TrySaveExperimentState(experiment *model.Experiment) error

TrySaveExperimentState saves the current experiment state to the database and reports whether the state was successfully changed.

func (*PgDB) UpdateClusterHeartBeat

func (db *PgDB) UpdateClusterHeartBeat(currentClusterHeartbeat time.Time) error

UpdateClusterHeartBeat updates the clusterheartbeat column in the cluster_id table.

func (*PgDB) UpdateResourceAllocationAggregation

func (db *PgDB) UpdateResourceAllocationAggregation() error

UpdateResourceAllocationAggregation updates the aggregated resource allocation table.

func (*PgDB) UpdateTrialFields

func (db *PgDB) UpdateTrialFields(id int, newRunnerMetadata *trialv1.TrialRunnerMetadata, newRunID,
	newRestarts int,
) error

UpdateTrialFields updates the specified fields of trial with ID id. Fields that are nil or zero are not updated.

func (*PgDB) ValidationByTotalBatches

func (db *PgDB) ValidationByTotalBatches(trialID, totalBatches int) (*model.TrialMetrics, error)

ValidationByTotalBatches looks up a validation by trial and total batches, returning nil if none exists.

func (*PgDB) ValidationMetricBatches

func (db *PgDB) ValidationMetricBatches(experimentID int, metricName string, startTime time.Time) (
	batches []int32, endTime time.Time, err error,
)

ValidationMetricBatches returns the milestones (in batches processed) at which a specific validation metric was recorded.

type RPWorkspaceBinding

type RPWorkspaceBinding struct {
	bun.BaseModel `bun:"table:rp_workspace_bindings"`

	WorkspaceID int    `bun:"workspace_id"`
	PoolName    string `bun:"pool_name"`
	Valid       bool   `bun:"valid"`
}

RPWorkspaceBinding is a struct reflecting the db table rp_workspace_bindings.

func GetAllBindings

func GetAllBindings(
	ctx context.Context,
) ([]*RPWorkspaceBinding, error)

GetAllBindings gets all valid rp-workspace bindings.

func ReadWorkspacesBoundToRP

func ReadWorkspacesBoundToRP(
	ctx context.Context, poolName string, offset, limit int32,
	resourcePools []config.ResourcePoolConfig,
) ([]*RPWorkspaceBinding, *apiv1.Pagination, error)

ReadWorkspacesBoundToRP gets the bindings between workspace IDs and the requested resource pool.

type SortDirection

type SortDirection string

SortDirection represents the order by in a query.

const (
	// SortDirectionAsc represents ordering by ascending.
	SortDirectionAsc SortDirection = "ASC"
	// SortDirectionDesc represents ordering by descending.
	SortDirectionDesc SortDirection = "DESC"
	// SortDirectionAscNullsFirst represents ordering by ascending with nulls first.
	SortDirectionAscNullsFirst SortDirection = "ASC NULLS FIRST"
	// SortDirectionDescNullsLast represents ordering by descending with nulls last.
	SortDirectionDescNullsLast SortDirection = "DESC NULLS LAST"
)

type StaticQueryMap

type StaticQueryMap struct {
	sync.Mutex
	// contains filtered or unexported fields
}

StaticQueryMap caches static sql files.

func (*StaticQueryMap) GetOrLoad

func (q *StaticQueryMap) GetOrLoad(queryName string) string

GetOrLoad fetches static sql from the cache or loads them from disk.

type TokenStore

type TokenStore struct {
	// contains filtered or unexported fields
}

TokenStore is a store for OAuth tokens. It is separate from PgDB so we can implement an interface of the external OAuth library without polluting PgDB's method set.

func (*TokenStore) Create

func (s *TokenStore) Create(info oauth2.TokenInfo) error

Create adds a new token to the database.

func (*TokenStore) GetByAccess

func (s *TokenStore) GetByAccess(access string) (oauth2.TokenInfo, error)

GetByAccess gets the token with the given access token value.

func (*TokenStore) GetByCode

func (s *TokenStore) GetByCode(code string) (oauth2.TokenInfo, error)

GetByCode gets the token with the given authorization code.

func (*TokenStore) GetByRefresh

func (s *TokenStore) GetByRefresh(refresh string) (oauth2.TokenInfo, error)

GetByRefresh gets the token with the given refresh token value.

func (*TokenStore) RemoveByAccess

func (s *TokenStore) RemoveByAccess(access string) error

RemoveByAccess deletes any tokens with the given access token value.

func (*TokenStore) RemoveByCode

func (s *TokenStore) RemoveByCode(code string) error

RemoveByCode deletes any tokens with the given authorization code.

func (*TokenStore) RemoveByRefresh

func (s *TokenStore) RemoveByRefresh(refresh string) error

RemoveByRefresh deletes any tokens with the given refresh token value.

Directories

Path Synopsis

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL