cohere

package
v0.0.0-...-85d02b3 (Latest)
Warning

This package is not in the latest version of its module.

Go to latest
Published: Jun 24, 2024 License: Apache-2.0 Imports: 15 Imported by: 0

Documentation

Index

Constants

This section is empty.

Variables

View Source
var (
	// Finish reasons reported by the Cohere chat API (the finish_reason value
	// that FinishReasonMapper translates). These are Cohere values, not OpenAI
	// ones.
	// Ref: https://docs.cohere.com/reference/chat
	//
	// NOTE(review): Cohere's reference lists these in upper case (e.g.
	// "COMPLETE", "MAX_TOKENS", "ERROR_TOXIC"); confirm the mapper compares
	// them case-insensitively.
	CompleteReason  = "complete"
	MaxTokensReason = "max_tokens"
	FilteredReason  = "error_toxic"
)

Functions

This section is empty.

Types

type ChatCompletion

// ChatCompletion models the Cohere chat response. Each field is decoded from
// the JSON key named in its tag.
type ChatCompletion struct {
	Text          string                 `json:"text"`
	GenerationID  string                 `json:"generation_id"`
	ResponseID    string                 `json:"response_id"`
	TokenCount    TokenCount             `json:"token_count"`
	Citations     []Citation             `json:"citations"`
	// Documents, SearchQueries and SearchResults are presumably populated only
	// when documents/connectors were involved in the request — confirm.
	Documents     []Documents            `json:"documents"`
	SearchQueries []SearchQuery          `json:"search_queries"`
	SearchResults []SearchResults        `json:"search_results"`
	Meta          Meta                   `json:"meta"`
	// ToolInputs is kept as an untyped map; its shape is not modeled here.
	ToolInputs    map[string]interface{} `json:"tool_inputs"`
}

ChatCompletion is the Cohere chat response.

type ChatCompletionChunk

// ChatCompletionChunk is one event of a streamed Cohere chat response.
type ChatCompletionChunk struct {
	IsFinished   bool           `json:"is_finished"`
	// EventType names the event; presumably one of the SupportedEventType
	// values ("stream-start", "text-generation", "stream-end") — confirm.
	EventType    string         `json:"event_type"`
	// GenerationID is a pointer so an event without it decodes to nil.
	GenerationID *string        `json:"generation_id"`
	Text         string         `json:"text"`
	// Response and FinishReason stay nil when the event JSON omits them
	// (presumably they are only sent with the final event — confirm).
	Response     *FinalResponse `json:"response,omitempty"`
	FinishReason *string        `json:"finish_reason,omitempty"`
}

ChatCompletionChunk represents a single streamed event; during chat streaming a chat response is broken down into a sequence of these events. Ref: https://docs.cohere.com/reference/about

type ChatRequest

// ChatRequest is the request body for a Cohere chat completion.
// Ref: https://docs.cohere.com/reference/chat
//
// Fields tagged omitempty are left out when zero/nil. Model, Message,
// ChatHistory, K, P, FrequencyPenalty, PresencePenalty and StopSequences have
// no omitempty, so they are always sent (nil slices encode as JSON null).
type ChatRequest struct {
	Model             string                `json:"model"`
	Message           string                `json:"message"`
	// NOTE(review): schemas.ChatMessage must encode in the shape Cohere expects
	// for chat_history entries — confirm.
	ChatHistory       []schemas.ChatMessage `json:"chat_history"`
	Temperature       float64               `json:"temperature,omitempty"`
	Preamble          string                `json:"preamble,omitempty"`
	PromptTruncation  *string               `json:"prompt_truncation,omitempty"`
	Connectors        []string              `json:"connectors,omitempty"`
	SearchQueriesOnly bool                  `json:"search_queries_only,omitempty"`
	Stream            bool                  `json:"stream,omitempty"`
	Seed              *int                  `json:"seed,omitempty"`
	MaxTokens         *int                  `json:"max_tokens,omitempty"`
	K                 int                   `json:"k"`
	P                 float32               `json:"p"`
	FrequencyPenalty  float32               `json:"frequency_penalty"`
	PresencePenalty   float32               `json:"presence_penalty"`
	StopSequences     []string              `json:"stop_sequences"`
}

ChatRequest is a request for a Cohere chat completion. Ref: https://docs.cohere.com/reference/chat

func NewChatRequestFromConfig

func NewChatRequestFromConfig(cfg *Config) *ChatRequest

NewChatRequestFromConfig builds a ChatRequest from the given config. Fields are copied explicitly rather than through reflection, to avoid reflection's performance penalty.

func (*ChatRequest) ApplyParams

func (r *ChatRequest) ApplyParams(params *schemas.ChatParams)

type ChatStream

type ChatStream struct {
	// contains filtered or unexported fields
}

ChatStream represents cohere chat stream for a specific request

func NewChatStream

func NewChatStream(
	tel *telemetry.Telemetry,
	client *http.Client,
	req *http.Request,
	modelName string,
	errMapper *ErrorMapper,
	finishReasonMapper *FinishReasonMapper,
) *ChatStream

func (*ChatStream) Close

func (s *ChatStream) Close() error

func (*ChatStream) Open

func (s *ChatStream) Open() error

func (*ChatStream) Recv

func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error)

type Citation

// Citation marks a span of the generated text that is backed by documents.
type Citation struct {
	// Start, End and Text describe the cited span.
	// NOTE(review): presumably Start/End are character offsets into the
	// response text — confirm against the Cohere docs.
	Start int    `json:"start"`
	End   int    `json:"end"`
	Text  string `json:"text"`

	// DocumentID lists the IDs of the documents supporting this span.
	// Cohere's chat API names this key "document_ids" (plural); the former
	// "document_id" tag never matched it, so decoding left the list empty.
	DocumentID []string `json:"document_ids"`
}

type Client

type Client struct {
	// contains filtered or unexported fields
}

Client is a client for accessing the Cohere API.

func NewClient

func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error)

NewClient creates a new Cohere client for the Cohere API.

func (*Client) Chat

func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error)

Chat sends a chat request to the specified Cohere model.

func (*Client) ChatStream

func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error)

func (*Client) ModelName

func (c *Client) ModelName() string

func (*Client) Provider

func (c *Client) Provider() string

func (*Client) SupportChatStream

func (c *Client) SupportChatStream() bool

type Config

// Config configures the Cohere provider client: endpoint location, model name,
// API key and optional default request parameters.
//
// APIKey is tagged json:"-", so it is never included in JSON output.
// NOTE(review): the validate tags are presumably checked by
// go-playground/validator — confirm.
// NOTE(review): DefaultParams uses the camelCase JSON key "defaultParams" while
// the other keys are snake_case; changing it would change the serialized form,
// so confirm intent before touching it.
type Config struct {
	BaseURL       string        `yaml:"base_url" json:"base_url" validate:"required,http_url"`
	ChatEndpoint  string        `yaml:"chat_endpoint" json:"chat_endpoint" validate:"required"`
	ModelName     string        `yaml:"model" json:"model" validate:"required"` // https://docs.cohere.com/docs/models#command
	APIKey        fields.Secret `yaml:"api_key" json:"-" validate:"required"`
	DefaultParams *Params       `yaml:"default_params,omitempty" json:"defaultParams"`
}

func DefaultConfig

func DefaultConfig() *Config

DefaultConfig returns the default configuration for Cohere models.

func (*Config) UnmarshalYAML

func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error

type Connectors

// Connectors describes a connector entry: its ID, a user access token, a
// continue-on-failure setting and connector-specific string options.
//
// NOTE(review): ChatRequest.Connectors is []string, so this type is not what
// ChatRequest sends. Cohere documents continue_on_failure as a boolean, while
// ContOnFail is a string and encodes as a JSON string — confirm.
type Connectors struct {
	ID              string            `json:"id"`
	UserAccessToken string            `json:"user_access_token"`
	ContOnFail      string            `json:"continue_on_failure"`
	Options         map[string]string `json:"options"`
}

type ConnectorsResponse

// ConnectorsResponse is a connector entry as it appears in a response
// (SearchResults.Connectors). It has the same fields and JSON keys as
// Connectors.
//
// NOTE(review): if the API sends continue_on_failure as a JSON boolean,
// decoding it into the string ContOnFail fails with a json.UnmarshalTypeError
// — confirm.
type ConnectorsResponse struct {
	ID              string            `json:"id"`
	UserAccessToken string            `json:"user_access_token"`
	ContOnFail      string            `json:"continue_on_failure"`
	Options         map[string]string `json:"options"`
}

type Documents

// Documents is one entry of ChatCompletion.Documents.
type Documents struct {
	ID   string            `json:"id"`
	// NOTE(review): only a nested "data" object with string values is decoded.
	// If Cohere returns document fields (title, snippet, url, ...) at the top
	// level, encoding/json silently ignores them — confirm.
	Data map[string]string `json:"data"` // TODO: This needs to be updated
}

type ErrorMapper

type ErrorMapper struct {
	// contains filtered or unexported fields
}

func NewErrorMapper

func NewErrorMapper(tel *telemetry.Telemetry) *ErrorMapper

func (*ErrorMapper) Map

func (m *ErrorMapper) Map(resp *http.Response) error

type FinalResponse

// FinalResponse is the response payload carried in ChatCompletionChunk.Response.
// It mirrors the text, IDs, token count and metadata of ChatCompletion.
// NOTE(review): presumably only sent with the final ("stream-end") event —
// confirm.
type FinalResponse struct {
	ResponseID   string     `json:"response_id"`
	Text         string     `json:"text"`
	GenerationID string     `json:"generation_id"`
	TokenCount   TokenCount `json:"token_count"`
	Meta         Meta       `json:"meta"`
}

type FinishReasonMapper

type FinishReasonMapper struct {
	// contains filtered or unexported fields
}

func NewFinishReasonMapper

func NewFinishReasonMapper(tel *telemetry.Telemetry) *FinishReasonMapper

func (*FinishReasonMapper) Map

func (m *FinishReasonMapper) Map(finishReason *string) *schemas.FinishReason

type Meta

// Meta holds response metadata: the API version ("api_version.version") and
// the billed input/output token units ("billed_units").
type Meta struct {
	APIVersion struct {
		Version string `json:"version"`
	} `json:"api_version"`
	BilledUnits struct {
		InputTokens  int `json:"input_tokens"`
		OutputTokens int `json:"output_tokens"`
	} `json:"billed_units"`
}

type Params

// Params holds Cohere-specific model parameters. The yaml/json tags give their
// config and wire names, and the validate tags give their allowed ranges
// (e.g. K in [0, 500], P in [0.01, 0.99], penalties in [0, 1], at most 5 stop
// sequences).
//
// NOTE(review): assuming go-playground/validator enforces these tags,
// "required" on the non-pointer Temperature rejects the zero value, so a
// temperature of exactly 0 cannot be configured — confirm this is intended.
type Params struct {
	Seed              *int     `yaml:"seed,omitempty" json:"seed,omitempty" validate:"omitempty,number"`
	Temperature       float64  `yaml:"temperature,omitempty" json:"temperature" validate:"required,number"`
	MaxTokens         *int     `yaml:"max_tokens,omitempty" json:"max_tokens,omitempty" validate:"omitempty,number"`
	K                 int      `yaml:"k,omitempty" json:"k" validate:"number,gte=0,lte=500"`
	P                 float32  `yaml:"p,omitempty" json:"p" validate:"number,gte=0.01,lte=0.99"`
	FrequencyPenalty  float32  `yaml:"frequency_penalty,omitempty" json:"frequency_penalty" validate:"gte=0.0,lte=1.0"`
	PresencePenalty   float32  `yaml:"presence_penalty,omitempty" json:"presence_penalty" validate:"gte=0.0,lte=1.0"`
	Preamble          string   `yaml:"preamble,omitempty" json:"preamble,omitempty"`
	StopSequences     []string `yaml:"stop_sequences,omitempty" json:"stop_sequences" validate:"max=5"`
	PromptTruncation  *string  `yaml:"prompt_truncation,omitempty" json:"prompt_truncation,omitempty"`
	Connectors        []string `yaml:"connectors,omitempty" json:"connectors,omitempty"`
	SearchQueriesOnly bool     `yaml:"search_queries_only,omitempty" json:"search_queries_only,omitempty"`
}

Params defines Cohere-specific model parameters along with validation rules for their values. TODO: Add validations.

func DefaultParams

func DefaultParams() Params

func (*Params) UnmarshalYAML

func (p *Params) UnmarshalYAML(unmarshal func(interface{}) error) error

type SearchQuery

// SearchQuery is one entry of ChatCompletion.SearchQueries: the query text
// and the generation ID it belongs to (presumably a query Cohere generated
// while answering — confirm).
type SearchQuery struct {
	Text         string `json:"text"`
	GenerationID string `json:"generation_id"`
}

type SearchQueryObject

// SearchQueryObject is the search query attached to a search result
// (SearchResults.SearchQuery).
//
// GenerationID uses the snake_case key "generation_id", matching SearchQuery
// and Cohere's other response keys. The former camelCase "generationId" tag
// never matched it, so the ID was always left empty on decode.
type SearchQueryObject struct {
	Text         string `json:"text"`
	GenerationID string `json:"generation_id"`
}

type SearchResults

// SearchResults is one entry of ChatCompletion.SearchResults.
//
// NOTE(review): Cohere's chat reference documents a search result as
// {"search_query": {...}, "connector": {...}, "document_ids": [...]}: a single
// query object, a single "connector" and a snake_case "document_ids" key. If
// that holds, the slice-typed SearchQuery makes decoding the enclosing response
// fail, and Connectors/DocumentID are never populated. Fixing this changes
// field types, so confirm against the API before changing it.
type SearchResults struct {
	SearchQuery []SearchQueryObject  `json:"search_query"`
	Connectors  []ConnectorsResponse `json:"connectors"`
	DocumentID  []string             `json:"documentId"`
}

type StreamReader

type StreamReader struct {
	// contains filtered or unexported fields
}

StreamReader reads Cohere streaming chat chunks formatted as one serialized JSON chunk per line (a.k.a. application/stream+json).

func NewStreamReader

func NewStreamReader(stream io.Reader, maxBufferSize int) *StreamReader

NewStreamReader creates an instance of StreamReader

func (*StreamReader) ReadEvent

func (r *StreamReader) ReadEvent() ([]byte, error)

ReadEvent scans the EventStream for events.

type SupportedEventType

// SupportedEventType is an alias of string naming a Cohere stream event type;
// presumably compared against ChatCompletionChunk.EventType — confirm.
type SupportedEventType = string

SupportedEventType is a stream event type handled by this package. Cohere defines other event types too; see https://docs.cohere.com/reference/chat (Chat -> Responses -> StreamedChatResponse).

// Stream event types handled by this package. Cohere emits other event types
// as well; see https://docs.cohere.com/reference/chat (StreamedChatResponse).
var (
	// StreamStartEvent is the "stream-start" event type.
	StreamStartEvent = SupportedEventType("stream-start")

	// TextGenEvent is the "text-generation" event type (presumably the events
	// carrying generated text fragments).
	TextGenEvent = SupportedEventType("text-generation")

	// StreamEndEvent is the "stream-end" event type.
	StreamEndEvent = SupportedEventType("stream-end")
)

type TokenCount

// TokenCount reports token usage, decoded from the "token_count" object of a
// chat response (ChatCompletion.TokenCount, FinalResponse.TokenCount).
// NOTE(review): presumably TotalTokens = PromptTokens + ResponseTokens, and
// BilledTokens may differ from TotalTokens — confirm against the Cohere docs.
type TokenCount struct {
	PromptTokens   int `json:"prompt_tokens"`
	ResponseTokens int `json:"response_tokens"`
	TotalTokens    int `json:"total_tokens"`
	BilledTokens   int `json:"billed_tokens"`
}

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL