Documentation ¶
Index ¶
Constants ¶
This section is empty.
Variables ¶
View Source
var CreateCmd = &cobra.Command{ Use: "create <owner/model[:version]> [input=value] ... [flags]", Short: "Create a prediction", Args: cobra.MinimumNArgs(1), Aliases: []string{"new", "run"}, RunE: func(cmd *cobra.Command, args []string) error { id, err := identifier.ParseIdentifier(args[0]) if err != nil { return fmt.Errorf("invalid model specified: %s", args[0]) } s := spinner.New(spinner.CharSets[21], 100*time.Millisecond) s.FinalMSG = "" ctx := cmd.Context() r8, err := client.NewClient() if err != nil { return err } var version *replicate.ModelVersion if id.Version == "" { if model, err := r8.GetModel(ctx, id.Owner, id.Name); err == nil { version = model.LatestVersion } } else { if v, err := r8.GetModelVersion(ctx, id.Owner, id.Name, id.Version); err == nil { version = v } } stdin, err := util.GetPipedArgs() if err != nil { return fmt.Errorf("failed to get stdin info: %w", err) } separator := cmd.Flag("separator").Value.String() inputs, err := util.ParseInputs(ctx, r8, args[1:], stdin, separator) if err != nil { return fmt.Errorf("failed to parse inputs: %w", err) } var inputSchema *openapi3.Schema var outputSchema *openapi3.Schema if version != nil { inputSchema, outputSchema, err = util.GetSchemas(*version) if err != nil { return fmt.Errorf("failed to get input schema for version: %w", err) } } coercedInputs, err := util.CoerceTypes(inputs, inputSchema) if err != nil { return fmt.Errorf("failed to coerce inputs: %w", err) } shouldWait := (cmd.Flags().Changed("wait") || !cmd.Flags().Changed("no-wait")) canStream := (outputSchema != nil && outputSchema.Type.Is("array") && outputSchema.Items.Value.Type.Is("string") && outputSchema.Extensions["x-cog-array-type"] == "iterator" && outputSchema.Extensions["x-cog-array-display"] == "concatenate") shouldStream := canStream && !cmd.Flags().Changed("wait") && (cmd.Flags().Changed("stream") || !cmd.Flags().Changed("no-stream")) s.Start() var prediction *replicate.Prediction if id.Version == "" { prediction, err = 
r8.CreatePredictionWithModel(ctx, id.Owner, id.Name, coercedInputs, nil, shouldStream) if err != nil { if version != nil { prediction, err = r8.CreatePrediction(ctx, version.ID, coercedInputs, nil, shouldStream) } } } else { prediction, err = r8.CreatePrediction(ctx, id.Version, coercedInputs, nil, shouldStream) } if err != nil { return fmt.Errorf("failed to create prediction: %w", err) } s.Stop() hasStream := prediction.URLs["stream"] != "" if !util.IsTTY() || cmd.Flags().Changed("json") { if hasStream { events, _ := r8.StreamPrediction(ctx, prediction) if cmd.Flags().Changed("json") { fmt.Print("[") defer fmt.Print("]") prefix := "" for event := range events { if event.Type != replicate.SSETypeOutput { continue } if event.Data == "" { continue } b, err := json.Marshal(event.Data) if err != nil { return fmt.Errorf("failed to marshal event: %w", err) } fmt.Printf("%s%s", prefix, string(b)) prefix = ", " } } else { for event := range events { if event.Type != replicate.SSETypeOutput { continue } fmt.Print(event.Data) } fmt.Println("") } return nil } if shouldWait { err = r8.Wait(ctx, prediction) if err != nil { return fmt.Errorf("failed to wait for prediction: %w", err) } } b, err := json.Marshal(prediction) if err != nil { return fmt.Errorf("failed to marshal prediction: %w", err) } fmt.Println(string(b)) return nil } url := fmt.Sprintf("https://replicate.com/p/%s", prediction.ID) if !hasStream { fmt.Printf("Prediction created: %s\n", url) } if cmd.Flags().Changed("web") { if util.IsTTY() { fmt.Println("Opening in browser...") } err = browser.OpenURL(url) if err != nil { return fmt.Errorf("failed to open browser: %w", err) } return nil } if hasStream { sseChan, errChan := r8.StreamPrediction(ctx, prediction) tokens := []string{} for { select { case event, ok := <-sseChan: if !ok { return nil } switch event.Type { case replicate.SSETypeOutput: token := event.Data tokens = append(tokens, token) fmt.Print(token) case replicate.SSETypeLogs: case replicate.SSETypeDone: 
return nil default: } case err, ok := <-errChan: if !ok { return nil } return fmt.Errorf("streaming error: %w", err) } if cmd.Flags().Changed("save") { var dirname string if cmd.Flags().Changed("output-directory") { dirname = cmd.Flag("output-directory").Value.String() } else { dirname = fmt.Sprintf("./%s", prediction.ID) } dir, err := filepath.Abs(dirname) if err != nil { return fmt.Errorf("failed to create output directory: %w", err) } err = os.MkdirAll(dir, 0o755) if err != nil { return fmt.Errorf("failed to create directory: %w", err) } err = os.WriteFile(filepath.Join(dir, "output.txt"), []byte(strings.Join(tokens, "")), 0o644) if err != nil { return fmt.Errorf("failed to write output: %w", err) } } } } else if shouldWait { bar := progressbar.Default(100) bar.Describe("processing") predChan, errChan := r8.WaitAsync(ctx, prediction) for pred := range predChan { progress := pred.Progress() if progress != nil { bar.ChangeMax(progress.Total) _ = bar.Set(progress.Current) } if pred.Status.Terminated() { _ = bar.Finish() break } } if err := <-errChan; err != nil { return fmt.Errorf("failed to wait for prediction: %w", err) } switch prediction.Status { case replicate.Succeeded: fmt.Println("✅ Succeeded") bytes, err := json.MarshalIndent(prediction.Output, "", " ") if err != nil { return fmt.Errorf("failed to marshal output: %w", err) } fmt.Println(string(bytes)) case replicate.Failed: fmt.Println("❌ Failed") fmt.Println(*prediction.Logs) bytes, err := json.MarshalIndent(prediction.Error, "", " ") if err != nil { return fmt.Errorf("error: %v", prediction.Error) } fmt.Println(string(bytes)) case replicate.Canceled: fmt.Println("🚫 Canceled") fmt.Println(prediction.Logs) } if cmd.Flags().Changed("save") && prediction.Status == replicate.Succeeded { var dirname string if cmd.Flags().Changed("output-directory") { dirname = cmd.Flag("output-directory").Value.String() } else { dirname = fmt.Sprintf("./%s", prediction.ID) } dir, err := filepath.Abs(dirname) if err != nil { 
return fmt.Errorf("failed to create output directory: %w", err) } err = util.DownloadPrediction(ctx, *prediction, dir) if err != nil { return fmt.Errorf("failed to save output: %w", err) } } } return nil }, }
View Source
// RootCmd is the parent "prediction" command, also reachable through the
// aliases "predictions" and "p". It is invoked as `prediction [subcommand]`.
var RootCmd = &cobra.Command{
	Aliases: []string{"predictions", "p"},
	Short:   "Interact with predictions",
	Use:     "prediction [subcommand]",
}
Functions ¶
func AddCreateFlags ¶
Types ¶
This section is empty.
Click to show internal directories.
Click to hide internal directories.