training

package
v0.8.0 (Latest)
Warning

This package is not in the latest version of its module.

Go to latest
Published: Jun 20, 2024 License: Apache-2.0 Imports: 14 Imported by: 0

Documentation

Index

Constants

This section is empty.

Variables

View Source
// CreateCmd represents the create command.
//
// It validates the --destination identifier and the model identifier in
// args[0], resolves the model version to train (the model's latest version
// when the identifier carries no version), parses inputs from the remaining
// positional arguments plus any piped stdin, and creates the training.
// After creation it prints the training URL, then either opens the URL in a
// browser (--web) or, when --json is set or util.IsTTY reports false, prints
// the training as JSON.
var CreateCmd = &cobra.Command{
	Use:     "create <owner/model[:version]> --destination <owner/model> [input=value] ... [flags]",
	Short:   "Create a training",
	Args:    cobra.MinimumNArgs(1),
	Aliases: []string{"new", "train"},
	RunE: func(cmd *cobra.Command, args []string) error {

		destination := cmd.Flag("destination").Value.String()
		if _, err := identifier.ParseIdentifier(destination); err != nil {
			return fmt.Errorf("invalid destination specified: %s", destination)
		}

		id, err := identifier.ParseIdentifier(args[0])
		if err != nil {
			return fmt.Errorf("invalid model specified: %s", args[0])
		}

		s := spinner.New(spinner.CharSets[21], 100*time.Millisecond)
		s.FinalMSG = ""

		ctx := cmd.Context()

		r8, err := client.NewClient()
		if err != nil {
			return err
		}

		// Resolve which model version to train: an explicit version from the
		// identifier, otherwise the model's latest version.
		var version *replicate.ModelVersion
		if id.Version == "" {
			model, err := r8.GetModel(ctx, id.Owner, id.Name)
			if err != nil {
				return fmt.Errorf("failed to get model: %w", err)
			}

			if model.LatestVersion == nil {
				return fmt.Errorf("no versions found for model %s", args[0])
			}

			version = model.LatestVersion
		} else {
			version, err = r8.GetModelVersion(ctx, id.Owner, id.Name, id.Version)
			if err != nil {
				return fmt.Errorf("failed to get model version: %w", err)
			}
		}

		stdin, err := util.GetPipedArgs()
		if err != nil {
			return fmt.Errorf("failed to get stdin info: %w", err)
		}

		separator := cmd.Flag("separator").Value.String()
		inputs, err := util.ParseInputs(ctx, r8, args[1:], stdin, separator)
		if err != nil {
			return fmt.Errorf("failed to parse inputs: %w", err)
		}

		coercedInputs, err := util.CoerceTypes(inputs, nil)
		if err != nil {
			return fmt.Errorf("failed to coerce inputs: %w", err)
		}

		s.Start()
		training, err := r8.CreateTraining(ctx, id.Owner, id.Name, version.ID, destination, coercedInputs, nil)
		// Stop the spinner before inspecting the result so it is cleared on
		// the error path too; previously an error returned with it running.
		s.Stop()
		if err != nil {
			return fmt.Errorf("failed to create training: %w", err)
		}

		url := fmt.Sprintf("https://replicate.com/p/%s", training.ID)
		// NOTE(review): this line goes to stdout even when the JSON branch
		// below runs, so piped JSON output is preceded by this text line.
		// Confirm whether it should go to stderr instead.
		fmt.Printf("Training created: %s\n", url)

		if cmd.Flags().Changed("web") {
			if util.IsTTY() {
				fmt.Println("Opening in browser...")
			}

			err = browser.OpenURL(url)
			if err != nil {
				return fmt.Errorf("failed to open browser: %w", err)
			}

			return nil
		}

		if cmd.Flags().Changed("json") || !util.IsTTY() {
			b, err := json.Marshal(training)
			if err != nil {
				return fmt.Errorf("failed to marshal training: %w", err)
			}

			fmt.Println(string(b))
			return nil
		}

		return nil
	},
}

CreateCmd represents the create command.

View Source
// RootCmd is the parent "training" command; it groups the training
// subcommands and is also reachable as "trainings" or "t".
var RootCmd = &cobra.Command{
	Aliases: []string{"trainings", "t"},
	Short:   "Interact with trainings",
	Use:     "training [subcommand]",
}

Functions

func AddCreateFlags

func AddCreateFlags(cmd *cobra.Command)

Types

This section is empty.

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL