diff --git a/README.md b/README.md index ad60a12..a482350 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,16 @@ # go-llm -Large Language Model API interface. This is a simple API interface for large language models +The module implements a simple API interface for large language models which run on [Ollama](https://github.com/ollama/ollama/blob/main/docs/api.md), -[Anthopic](https://docs.anthropic.com/en/api/getting-started) and [Mistral](https://docs.mistral.ai/) -(OpenAI might be added later). +[Anthopic](https://docs.anthropic.com/en/api/getting-started), [Mistral](https://docs.mistral.ai/) +and [OpenAI](https://platform.openai.com/docs/api-reference). The module implements the ability to: -The module includes the ability to utilize: - -* Maintaining a session of messages -* Tool calling support, including using your own tools (aka Tool plugins) -* Creating embedding vectors from text -* Streaming responses -* Multi-modal support (aka, Images and Attachments) +* Maintain a session of messages; +* Tool calling support, including using your own tools (aka Tool plugins); +* Create embedding vectors from text; +* Stream responses; +* Multi-modal support (aka, Images, Audio and Attachments); +* Text-to-speech (OpenAI only) for completions There is a command-line tool included in the module which can be used to interact with the API. If you have docker installed, you can use the following command to run the tool, without @@ -24,7 +23,8 @@ docker run ghcr.io/mutablelogic/go-llm:latest --help # Interact with Claude to retrieve news headlines, assuming # you have an API key for Anthropic and NewsAPI docker run \ - -e OLLAMA_URL -e MISTRAL_API_KEY -e NEWSAPI_KEY \ + -e OLLAMA_URL -e MISTRAL_API_KEY -e ANTHROPIC_API_KEY -e OPENAI_API_KEY \ + -e NEWSAPI_KEY \ ghcr.io/mutablelogic/go-llm:latest \ chat mistral-small-latest --prompt "What is the latest news?" --no-stream ``` @@ -35,7 +35,7 @@ install it if you have a `go` compiler). ## Programmatic Usage See the documentation [here](https://pkg.go.dev/github.com/mutablelogic/go-llm) -for integration into your own Go programs. +for integration into your own code. ### Agent Instantiation @@ -95,6 +95,24 @@ func main() { } ``` +Similarly for [OpenAI](https://pkg.go.dev/github.com/mutablelogic/go-llm/pkg/openai) +models, you can use: + +```go +import ( + "github.com/mutablelogic/go-llm/pkg/openai" +) + +func main() { + // Create a new agent + agent, err := openai.New(os.Getenv("OPENAI_API_KEY")) + if err != nil { + panic(err) + } + // ... +} +``` + You can append options to the agent creation to set the client/server communication options, such as user agent strings, timeouts, debugging, rate limiting, adding custom headers, etc. See [here](https://pkg.go.dev/github.com/mutablelogic/go-client#readme-basic-usage) for more information. @@ -111,6 +129,7 @@ func main() { agent, err := agent.New( agent.WithAnthropic(os.Getenv("ANTHROPIC_API_KEY")), agent.WithMistral(os.Getenv("MISTRAL_API_KEY")), + agent.WithOpenAI(os.Getenv("OPENAI_API_KEY")), agent.WithOllama(os.Getenv("OLLAMA_URL")), ) if err != nil { @@ -120,6 +139,30 @@ func main() { } ``` +### Completion + +You can generate a **completion** as follows, + +```go +import ( + "github.com/mutablelogic/go-llm" +) + +func completion(ctx context.Context, agent llm.Agent) (string, error) { + completion, err := agent. + Model(ctx, "claude-3-5-haiku-20241022"). 
+    Completion(ctx, "Why is the sky blue?")
+  if err != nil {
+    return "", err
+  } else {
+    return completion.Text(0), nil
+  }
+}
+```
+
+The zero index argument to `completion.Text(int)` selects the text of the first completion
+choice, for providers which can generate several different choices simultaneously.
+
 ### Chat Sessions
 
 You create a **chat session** with a model as follows,
@@ -131,7 +174,9 @@ import (
 
 func session(ctx context.Context, agent llm.Agent) error {
   // Create a new chat session
-  session := agent.Model(context.TODO(), "claude-3-5-haiku-20241022").Context()
+  session := agent.
+    Model(ctx, "claude-3-5-haiku-20241022").
+    Context()
 
   // Repeat forever
   for {
@@ -147,11 +192,11 @@ func session(ctx context.Context, agent llm.Agent) error {
 ```
 
 The `Context` object will continue to store the current session and options, and will
-ensure the session is maintained across multiple calls.
+ensure the session is maintained across multiple completion calls.
 
 ### Embedding Generation
 
-You can generate embedding vectors using an appropriate model with Ollama or Mistral models:
+You can generate embedding vectors using an appropriate Ollama, OpenAI or Mistral model:
 
 ```go
 import (
@@ -159,8 +204,9 @@ import (
 )
 
 func embedding(ctx context.Context, agent llm.Agent) error {
-  // Create a new chat session
-  vector, err := agent.Model(ctx, "mistral-embed").Embedding(ctx, "hello")
+  vector, err := agent.
+    Model(ctx, "mistral-embed").
+    Embedding(ctx, "hello")
   // ...
 }
 ```
@@ -182,21 +228,19 @@ func generate_image_caption(ctx context.Context, agent llm.Agent, path string) (
   }
   defer f.Close()
 
-  // Describe an image
-  r, err := agent.Model("claude-3-5-sonnet-20241022").UserPrompt(
-    ctx, model.UserPrompt("Provide a short caption for this image", llm.WithAttachment(f))
-  )
+  completion, err := agent.
+    Model(ctx, "claude-3-5-sonnet-20241022").
+    Completion(ctx, "Provide a short caption for this image", llm.WithAttachment(f))
   if err != nil {
-    return "", err
-  }
+    return "", err
+  }
 
-  // Return success
-  return r.Text(0), err
+  return completion.Text(0), nil
 }
 ```
 
-To summarize a text or PDF docment is exactly the same using an Anthropic model, but maybe with a
-different prompt.
+Summarizing a text or PDF document works in exactly the same way with an Anthropic model, perhaps
+with a different prompt.
 
 ### Streaming
 
@@ -210,16 +254,14 @@ import (
 )
 
 func generate_completion(ctx context.Context, agent llm.Agent, prompt string) (string, error) {
-  r, err := agent.Model("claude-3-5-sonnet-20241022").UserPrompt(
-    ctx, model.UserPrompt("What is the weather in London?"),
-    llm.WithStream(stream_callback),
-  )
+  completion, err := agent.
+    Model(ctx, "claude-3-5-haiku-20241022").
+    Completion(ctx, "Why is the sky blue?", llm.WithStream(stream_callback))
   if err != nil {
     return "", err
+  } else {
+    return completion.Text(0), nil
   }
-
-  // Return success
-  return r.Text(0), err
 }
 
 func stream_callback(completion llm.Completion) {
@@ -231,30 +273,232 @@ func stream_callback(completion llm.Completion) {
 
 ### Tool Support
 
-All providers support tools, but not all models.
+All providers support tools, but not all models. 
Your own tools should implement the
+following interface:
 
-TODO
+```go
+package llm
+
+// Definition of a tool
+type Tool interface {
+  Name() string                     // The name of the tool
+  Description() string              // The description of the tool
+  Run(context.Context) (any, error) // Run the tool with a deadline and
+                                    // return the result
+}
+```
 
-## Options
+For example, if you want to implement a tool which adds two numbers,
+
+```go
+package addition
+
+type Adder struct {
+  A float64 `name:"a" help:"The first number" required:"true"`
+  B float64 `name:"b" help:"The second number" required:"true"`
+}
+
+func (Adder) Name() string {
+  return "add_two_numbers"
+}
+
+func (Adder) Description() string {
+  return "Add two numbers together and return the result"
+}
+
+func (a Adder) Run(context.Context) (any, error) {
+  return a.A + a.B, nil
+}
+```
+
+Then you can include your tool as part of the completion. It's possible that a
+completion will continue to call additional tools, in which case you should
+loop through completions until no more tool calls are made.
+
+```go
+import (
+  "context"
+  "fmt"
+
+  "github.com/mutablelogic/go-llm"
+  "github.com/mutablelogic/go-llm/pkg/tool"
+)
+
+func add_two_numbers(ctx context.Context, agent llm.Agent) (string, error) {
+  session := agent.Model(ctx, "claude-3-5-haiku-20241022").Context()
+  toolkit := tool.NewToolKit()
+  toolkit.Register(Adder{})
+
+  // Get the tool call
+  if err := session.FromUser(ctx, "What is five plus seven?", llm.WithToolKit(toolkit)); err != nil {
+    return "", err
+  }
+
+  // Call tools
+  for {
+    calls := session.ToolCalls(0)
+    if len(calls) == 0 {
+      break
+    }
+
+    // Print out any intermediate messages
+    if session.Text(0) != "" {
+      fmt.Println(session.Text(0))
+    }
+
+    // Get the results from the toolkit
+    results, err := toolkit.Run(ctx, calls...)
+    if err != nil {
+      return "", err
+    }
+
+    // Get another tool call or a user response
+    if err := session.FromTool(ctx, results...); err != nil {
+      return "", err
+    }
+  }
+
+  // Return the result
+  return session.Text(0), nil
+}
+```
+
+Parameters are implemented as struct fields, with tags. The tags you can include are:
+
+* `name:""` - Set the name for the parameter
+* `json:""` - If `name` is not used, then the name is set from the `json` tag
+* `help:""` - Set the description for the parameter
+* `required:""` - The parameter is required as part of the tool call
+* `enum:"a,b,c"` - The parameter value should be one of these comma-separated options
+
+The translation of field types is as follows:
+
+* `string` - Translates to JSON `string`
+* `uint`, `int` - Translates to JSON `integer`
+* `float32`, `float64` - Translates to JSON `number`
+
+## Complete and Chat Options
+
+These are the options you can use with the `Completion` and `Chat` methods.
+
+| Option | Ollama | Anthropic | Mistral | OpenAI | Gemini | Description |
+|--------|--------|-----------|---------|--------|--------|-------------|
+| `llm.WithTemperature(float64)` | Yes | Yes | Yes | Yes | Yes | What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.7 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. |
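+
+For example, a low temperature makes completions more focused and repeatable. This is a minimal
+sketch of passing an option to the `Completion` call; the model name and temperature value are
+illustrative only:
+
+```go
+import (
+  "context"
+
+  "github.com/mutablelogic/go-llm"
+)
+
+func deterministic_completion(ctx context.Context, agent llm.Agent) (string, error) {
+  // A lower temperature makes the output more deterministic
+  completion, err := agent.
+    Model(ctx, "claude-3-5-haiku-20241022").
+    Completion(ctx, "Why is the sky blue?", llm.WithTemperature(0.2))
+  if err != nil {
+    return "", err
+  }
+  return completion.Text(0), nil
+}
+```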
+
+## Embedding Options
+
+These are the options you can include for the `Embedding` method.
+
+| Option | Ollama | Anthropic | Mistral | OpenAI | Gemini | Description |
+|--------|--------|-----------|---------|--------|--------|-------------|
+| `ollama.WithKeepAlive(time.Duration)` | Yes | No | No | No | No | Controls how long the model will stay loaded into memory following the request. |
+| `ollama.WithTruncate()` | Yes | No | No | No | No | Does not truncate the end of each input to fit within the context length. Returns an error if the context length is exceeded. |
+| `ollama.WithOption(string, any)` | Yes | No | No | No | No | Set a model-specific option value. |
+| `openai.WithDimensions(uint64)` | No | No | No | Yes | No | The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models. |
+| `llm.WithFormat(string)` | No | No | 'float' | 'float' or 'base64' | No | The format to return the embeddings in. |
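+
+For example, a minimal sketch of passing an embedding option. This assumes an OpenAI
+text-embedding-3 model such as `text-embedding-3-small` is available to your agent; the model
+name and dimension value are illustrative:
+
+```go
+import (
+  "context"
+
+  "github.com/mutablelogic/go-llm"
+  "github.com/mutablelogic/go-llm/pkg/openai"
+)
+
+func embed_with_dimensions(ctx context.Context, agent llm.Agent) ([]float64, error) {
+  // Request a reduced-dimension embedding vector from an OpenAI model
+  return agent.
+    Model(ctx, "text-embedding-3-small").
+    Embedding(ctx, "The quick brown fox", openai.WithDimensions(256))
+}
+```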
+ +## Older Content You can add options to sessions, or to prompts. Different providers and models support different options. ```go +package llm + type Model interface { // Set session-wide options Context(...Opt) Context - // Add attachments (images, PDF's) to a user prompt for completion - UserPrompt(string, ...Opt) Context + // Create a completion from a text prompt + Completion(context.Context, string, ...Opt) (Completion, error) - // Create an embedding vector with embedding options + // Embedding vector generation Embedding(context.Context, string, ...Opt) ([]float64, error) } type Context interface { - // Add single-use options when calling the model, which override - // session options. You can attach files to a user prompt. + // Generate a response from a user prompt (with attachments and + // other options) FromUser(context.Context, string, ...Opt) error } ``` @@ -263,26 +507,38 @@ The options are as follows: | Option | Ollama | Anthropic | Mistral | OpenAI | Description | |--------|--------|-----------|---------|--------|-------------| -| `llm.WithTemperature(float64)` | Yes | Yes | Yes | - | What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.7 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. | -| `llm.WithTopP(float64)` | Yes | Yes | Yes | - | Nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. | -| `llm.WithTopK(uint64)` | Yes | Yes | No | - | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. | -| `llm.WithMaxTokens(uint64)` | No | Yes | Yes | - | The maximum number of tokens to generate in the response. | -| `llm.WithStream(func(llm.Completion))` | Can be enabled when tools are not used | Yes | Yes | - | Stream the response to a function. | -| `llm.WithToolChoice(string, string, ...)` | No | Use `auto`, `any` or a function name. Only the first argument is used. | Use `auto`, `any`, `none`, `required` or a function name. Only the first argument is used. | - | The tool to use for the model. | -| `llm.WithToolKit(llm.ToolKit)` | Cannot be combined with streaming | Yes | Yes | - | The set of tools to use. | -| `llm.WithStopSequence(string, string, ...)` | Yes | Yes | Yes | - | Stop generation if one of these tokens is detected. | -| `llm.WithSystemPrompt(string)` | No | Yes | Yes | - | Set the system prompt for the model. | -| `llm.WithSeed(uint64)` | Yes | No | Yes | - | The seed to use for random sampling. If set, different calls will generate deterministic results. | -| `llm.WithFormat(string)` | Use `json` | No | Use `json_format` or `text` | - | The format of the response. For Mistral, you must also instruct the model to produce JSON yourself with a system or a user message. | -| `llm.WithPresencePenalty(float64)` | Yes | No | Yes | - | Determines how much the model penalizes the repetition of words or phrases. A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative. | -| `llm.WithFequencyPenalty(float64)` | Yes | No | Yes | - | Penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition. 
| -| `mistral.WithPrediction(string)` | No | No | Yes | - | Enable users to specify expected results, optimizing response times by leveraging known or predictable content. This approach is especially effective for updating text documents or code files with minimal changes, reducing latency while maintaining high-quality results. | -| `llm.WithSafePrompt()` | No | No | Yes | - | Whether to inject a safety prompt before all conversations. | -| `llm.WithNumCompletions(uint64)` | No | No | Yes | - | Number of completions to return for each request. | +| `llm.WithTemperature(float64)` | Yes | Yes | Yes | Yes | What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.7 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. | +| `llm.WithTopP(float64)` | Yes | Yes | Yes | Yes | Nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. | +| `llm.WithTopK(uint64)` | Yes | Yes | No | No | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. | +| `llm.WithMaxTokens(uint64)` | No | Yes | Yes | Yes | The maximum number of tokens to generate in the response. | +| `llm.WithStream(func(llm.Completion))` | Can be enabled when tools are not used | Yes | Yes | Yes | Stream the response to a function. | +| `llm.WithToolChoice(string, string, ...)` | No | Use `auto`, `any` or a function name. Only the first argument is used. | Use `auto`, `any`, `none`, `required` or a function name. Only the first argument is used. | Use `auto`, `none`, `required` or a function name. Only the first argument is used. | The tool to use for the model. | +| `llm.WithToolKit(llm.ToolKit)` | Cannot be combined with streaming | Yes | Yes | Yes | The set of tools to use. | +| `llm.WithStopSequence(string, string, ...)` | Yes | Yes | Yes | Yes | Stop generation if one of these tokens is detected. | +| `llm.WithSystemPrompt(string)` | No | Yes | Yes | Yes | Set the system prompt for the model. | +| `llm.WithSeed(uint64)` | Yes | No | Yes | Yes | The seed to use for random sampling. If set, different calls will generate deterministic results. | +| `llm.WithFormat(string)` | Use `json` | No | Use `json_format` or `text` | Use `json_format` or `text` | The format of the response. For Mistral, you must also instruct the model to produce JSON yourself with a system or a user message. | +| `llm.WithPresencePenalty(float64)` | Yes | No | Yes | Yes | Determines how much the model penalizes the repetition of words or phrases. A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative. | +| `llm.WithFequencyPenalty(float64)` | Yes | No | Yes | Yes | Penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition. | +| `llm.WithPrediction(string)` | No | No | Yes | Yes | Enable users to specify expected results, optimizing response times by leveraging known or predictable content. This approach is especially effective for updating text documents or code files with minimal changes, reducing latency while maintaining high-quality results. 
|
+| `llm.WithSafePrompt()` | No | No | Yes | No | Whether to inject a safety prompt before all conversations. |
+| `llm.WithNumCompletions(uint64)` | No | No | Yes | Yes | Number of completions to return for each request. |
 | `llm.WithAttachment(io.Reader)` | Yes | Yes | Yes | - | Attach a file to a user prompt. It is the responsibility of the caller to close the reader. |
+| `llm.WithUser(string)` | No | Yes | No | Yes | A unique identifier representing your end-user |
 | `antropic.WithEphemeral()` | No | Yes | No | - | Attachments should be cached server-side |
 | `antropic.WithCitations()` | No | Yes | No | - | Attachments should be used in citations |
-| `antropic.WithUser(string)` | No | Yes | No | - | Indicate the user for the request, for debugging |
+| `openai.WithStore(bool)` | No | No | No | Yes | Whether or not to store the output of this chat completion request |
+| `openai.WithDimensions(uint64)` | No | No | No | Yes | The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models |
+| `openai.WithReasoningEffort(string)` | No | No | No | Yes | The level of effort the model should put into reasoning. |
+| `openai.WithMetadata(string, string)` | No | No | No | Yes | Metadata to be logged with the completion. |
+| `openai.WithLogitBias(uint64, int64)` | No | No | No | Yes | A token and its logit bias value. Call multiple times to add additional tokens |
+| `openai.WithLogProbs()` | No | No | No | Yes | Include the log probabilities on the completion. |
+| `openai.WithTopLogProbs(uint64)` | No | No | No | Yes | An integer between 0 and 20 specifying the number of most likely tokens to return at each token position. |
+| `openai.WithAudio(string, string)` | No | No | No | Yes | Output audio (voice, format) for the completion. Can be used with certain models. |
+| `openai.WithServiceTier(string)` | No | No | No | Yes | Specifies the latency tier to use for processing the request. |
+| `openai.WithStreamOptions(func(llm.Completion), bool)` | No | No | No | Yes | Include usage information in the stream response |
+| `openai.WithDisableParallelToolCalls()` | No | No | No | Yes | Call tools in serial, rather than in parallel |
 
 ## The Command Line Tool
 
diff --git a/attachment.go b/attachment.go
index 5987a9d..ab35210 100644
--- a/attachment.go
+++ b/attachment.go
@@ -13,15 +13,31 @@ import (
 
 ///////////////////////////////////////////////////////////////////////////////
 // TYPES
 
+type AttachmentMeta struct {
+  Id        string `json:"id,omitempty"`
+  Filename  string `json:"filename,omitempty"`
+  ExpiresAt uint64 `json:"expires_at,omitempty"`
+  Caption   string `json:"transcript,omitempty"`
+  Data      []byte `json:"data"`
+}
+
 // Attachment for messages
 type Attachment struct {
-  filename string
-  data     []byte
+  meta AttachmentMeta
 }
 
+const (
+  defaultMimetype = "application/octet-stream"
+)
+
 ////////////////////////////////////////////////////////////////////////////////
 // LIFECYCLE
 
+// NewAttachment creates a new, empty attachment
+func NewAttachment() *Attachment {
+  return new(Attachment)
+}
+
 // ReadAttachment returns an attachment from a reader object.
 // It is the responsibility of the caller to close the reader. 
func ReadAttachment(r io.Reader) (*Attachment, error) { @@ -33,22 +49,43 @@ func ReadAttachment(r io.Reader) (*Attachment, error) { if f, ok := r.(*os.File); ok { filename = f.Name() } - return &Attachment{filename: filename, data: data}, nil + return &Attachment{ + meta: AttachmentMeta{ + Filename: filename, + Data: data, + }, + }, nil } //////////////////////////////////////////////////////////////////////////////// // STRINGIFY -func (a *Attachment) String() string { +// Convert JSON into an attachment +func (a *Attachment) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &a.meta) +} + +// Convert an attachment into JSON +func (a *Attachment) MarshalJSON() ([]byte, error) { + // Create a JSON representation var j struct { - Filename string `json:"filename"` + Id string `json:"id,omitempty"` + Filename string `json:"filename,omitempty"` Type string `json:"type"` Bytes uint64 `json:"bytes"` + Caption string `json:"transcript,omitempty"` } - j.Filename = a.filename + j.Id = a.meta.Id + j.Filename = a.meta.Filename j.Type = a.Type() - j.Bytes = uint64(len(a.data)) - data, err := json.MarshalIndent(j, "", " ") + j.Bytes = uint64(len(a.meta.Data)) + j.Caption = a.meta.Caption + return json.Marshal(j) +} + +// Stringify an attachment +func (a *Attachment) String() string { + data, err := json.MarshalIndent(a.meta, "", " ") if err != nil { return err.Error() } @@ -58,24 +95,68 @@ func (a *Attachment) String() string { //////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS +// Return the filename of an attachment func (a *Attachment) Filename() string { - return a.filename + return a.meta.Filename } +// Return the raw attachment data func (a *Attachment) Data() []byte { - return a.data + return a.meta.Data +} + +// Return the caption for the attachment +func (a *Attachment) Caption() string { + return a.meta.Caption } +// Return the mime media type for the attachment, based +// on the data and/or filename extension. Returns an empty string if +// there is no data or filename func (a *Attachment) Type() string { + // If there's no data or filename, return empty + if len(a.meta.Data) == 0 && a.meta.Filename == "" { + return "" + } + // Mimetype based on content - mimetype := http.DetectContentType(a.data) - if mimetype == "application/octet-stream" && a.filename != "" { + mimetype := defaultMimetype + if len(a.meta.Data) > 0 { + mimetype = http.DetectContentType(a.meta.Data) + if mimetype != defaultMimetype { + return mimetype + } + } + + // Mimetype based on filename + if a.meta.Filename != "" { // Detect mimetype from extension - mimetype = mime.TypeByExtension(filepath.Ext(a.filename)) + mimetype = mime.TypeByExtension(filepath.Ext(a.meta.Filename)) } + + // Return the default mimetype return mimetype } func (a *Attachment) Url() string { - return "data:" + a.Type() + ";base64," + base64.StdEncoding.EncodeToString(a.data) + return "data:" + a.Type() + ";base64," + base64.StdEncoding.EncodeToString(a.meta.Data) +} + +// Streaming includes the ability to append data +func (a *Attachment) Append(other *Attachment) { + if other.meta.Id != "" { + a.meta.Id = other.meta.Id + } + if other.meta.Filename != "" { + a.meta.Filename = other.meta.Filename + } + if other.meta.ExpiresAt != 0 { + a.meta.ExpiresAt = other.meta.ExpiresAt + } + if other.meta.Caption != "" { + a.meta.Caption += other.meta.Caption + } + if len(other.meta.Data) > 0 { + a.meta.Data = append(a.meta.Data, other.meta.Data...) 
+ } } diff --git a/cmd/llm/chat.go b/cmd/llm/chat.go index bd14f6e..d072514 100644 --- a/cmd/llm/chat.go +++ b/cmd/llm/chat.go @@ -9,7 +9,6 @@ import ( // Packages llm "github.com/mutablelogic/go-llm" - agent "github.com/mutablelogic/go-llm/pkg/agent" ) //////////////////////////////////////////////////////////////////////////////// @@ -27,26 +26,17 @@ type ChatCmd struct { // PUBLIC METHODS func (cmd *ChatCmd) Run(globals *Globals) error { - return runagent(globals, func(ctx context.Context, client llm.Agent) error { - // Get the model - a, ok := client.(*agent.Agent) - if !ok { - return fmt.Errorf("No agents found") - } - model, err := a.GetModel(ctx, cmd.Model) - if err != nil { - return err - } + return run(globals, cmd.Model, func(ctx context.Context, model llm.Model) error { + // Current buffer + var buf string // Set the options opts := []llm.Opt{} if !cmd.NoStream { opts = append(opts, llm.WithStream(func(cc llm.Completion) { - if text := cc.Text(0); text != "" { - count := strings.Count(text, "\n") - fmt.Print(strings.Repeat("\033[F", count) + strings.Repeat(" ", count) + "\r") - fmt.Print(text) - } + text := cc.Text(0) + fmt.Print(strings.TrimPrefix(text, buf)) + buf = text })) } if cmd.System != "" { @@ -66,6 +56,7 @@ func (cmd *ChatCmd) Run(globals *Globals) error { input = cmd.Prompt cmd.Prompt = "" } else { + var err error input, err = globals.term.ReadLine(model.Name() + "> ") if errors.Is(err, io.EOF) { return nil @@ -91,6 +82,7 @@ func (cmd *ChatCmd) Run(globals *Globals) error { if len(calls) == 0 { break } + if session.Text(0) != "" { globals.term.Println(session.Text(0)) } else { @@ -100,6 +92,7 @@ func (cmd *ChatCmd) Run(globals *Globals) error { } globals.term.Println("Calling ", strings.Join(names, ", ")) } + if results, err := globals.toolkit.Run(ctx, calls...); err != nil { return err } else if err := session.FromTool(ctx, results...); err != nil { @@ -107,8 +100,12 @@ func (cmd *ChatCmd) Run(globals *Globals) error { } } - // Print the response - globals.term.Println("\n" + session.Text(0) + "\n") + // Print the response, if not streaming + if cmd.NoStream { + globals.term.Println("\n" + session.Text(0) + "\n") + } else { + globals.term.Println() + } } }) } diff --git a/cmd/llm/complete.go b/cmd/llm/complete.go new file mode 100644 index 0000000..f6de7fd --- /dev/null +++ b/cmd/llm/complete.go @@ -0,0 +1,119 @@ +package main + +import ( + "context" + "fmt" + "io" + "os" + "strings" + + // Packages + llm "github.com/mutablelogic/go-llm" +) + +//////////////////////////////////////////////////////////////////////////////// +// TYPES + +type CompleteCmd struct { + Model string `arg:"" help:"Model name"` + Prompt string `arg:"" optional:"" help:"Prompt"` + File []string `type:"file" short:"f" help:"Files to attach"` + System string `flag:"system" help:"Set the system prompt"` + NoStream bool `flag:"no-stream" help:"Do not stream output"` + Format string `flag:"format" enum:"text,markdown,json" default:"text" help:"Output format"` + Temperature *float64 `flag:"temperature" short:"t" help:"Temperature for sampling"` +} + +//////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (cmd *CompleteCmd) Run(globals *Globals) error { + return run(globals, cmd.Model, func(ctx context.Context, model llm.Model) error { + var prompt []byte + + // If we are pipeline content in via stdin + fileInfo, err := os.Stdin.Stat() + if err != nil { + return llm.ErrInternalServerError.Withf("Failed to get stdin stat: %v", err) + } + if 
(fileInfo.Mode() & os.ModeCharDevice) == 0 { + if data, err := io.ReadAll(os.Stdin); err != nil { + return err + } else if len(data) > 0 { + prompt = data + } + } + + // Append any further prompt + if len(cmd.Prompt) > 0 { + prompt = append(prompt, []byte("\n\n")...) + prompt = append(prompt, []byte(cmd.Prompt)...) + } + + opts := cmd.opts() + if !cmd.NoStream { + // Add streaming callback + var buf string + opts = append(opts, llm.WithStream(func(c llm.Completion) { + fmt.Print(strings.TrimPrefix(c.Text(0), buf)) + buf = c.Text(0) + })) + } + + // Add attachments + for _, file := range cmd.File { + f, err := os.Open(file) + if err != nil { + return err + } + defer f.Close() + opts = append(opts, llm.WithAttachment(f)) + } + + // Make the completion + completion, err := model.Completion(ctx, string(prompt), opts...) + if err != nil { + return err + } + + // Print the completion + if cmd.NoStream { + fmt.Println(completion.Text(0)) + } else { + fmt.Println("") + } + + // Return success + return nil + }) +} + +func (cmd *CompleteCmd) opts() []llm.Opt { + opts := []llm.Opt{} + + // Set system prompt + var system []string + if cmd.Format == "markdown" { + system = append(system, "Structure your output in markdown format.") + } else if cmd.Format == "json" { + system = append(system, "Structure your output in JSON format.") + } + if cmd.System != "" { + system = append(system, cmd.System) + } + if len(system) > 0 { + opts = append(opts, llm.WithSystemPrompt(strings.Join(system, "\n"))) + } + + // Set format + if cmd.Format == "json" { + opts = append(opts, llm.WithFormat("json")) + } + + // Set temperature + if cmd.Temperature != nil { + opts = append(opts, llm.WithTemperature(*cmd.Temperature)) + } + + return opts +} diff --git a/cmd/llm/embedding.go b/cmd/llm/embedding.go new file mode 100644 index 0000000..8f8b469 --- /dev/null +++ b/cmd/llm/embedding.go @@ -0,0 +1,36 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + + // Packages + llm "github.com/mutablelogic/go-llm" +) + +//////////////////////////////////////////////////////////////////////////////// +// TYPES + +type EmbeddingCmd struct { + Model string `arg:"" help:"Model name"` + Prompt string `arg:"" help:"Prompt"` +} + +//////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (cmd *EmbeddingCmd) Run(globals *Globals) error { + return run(globals, cmd.Model, func(ctx context.Context, model llm.Model) error { + vector, err := model.Embedding(ctx, cmd.Prompt) + if err != nil { + return err + } + data, err := json.Marshal(vector) + if err != nil { + return err + } + fmt.Println(string(data)) + return nil + }) +} diff --git a/cmd/llm/main.go b/cmd/llm/main.go index ff82c92..ce6b532 100644 --- a/cmd/llm/main.go +++ b/cmd/llm/main.go @@ -6,6 +6,7 @@ import ( "os/signal" "path/filepath" "syscall" + "time" // Packages kong "github.com/alecthomas/kong" @@ -21,20 +22,23 @@ import ( type Globals struct { // Debugging - Debug bool `name:"debug" help:"Enable debug output"` - Verbose bool `name:"verbose" help:"Enable verbose output"` + Debug bool `name:"debug" help:"Enable debug output"` + Verbose bool `name:"verbose" short:"v" help:"Enable verbose output"` + Timeout time.Duration `name:"timeout" help:"Agent connection timeout"` // Agents Ollama `embed:"" help:"Ollama configuration"` Anthropic `embed:"" help:"Anthropic configuration"` Mistral `embed:"" help:"Mistral configuration"` + OpenAI `embed:"" help:"OpenAI configuration"` + Gemini `embed:"" help:"Gemini configuration"` // 
Tools NewsAPI `embed:"" help:"NewsAPI configuration"` // Context ctx context.Context - agent llm.Agent + agent *agent.Agent toolkit *tool.ToolKit term *Term } @@ -51,6 +55,14 @@ type Mistral struct { MistralKey string `env:"MISTRAL_API_KEY" help:"Mistral API Key"` } +type OpenAI struct { + OpenAIKey string `env:"OPENAI_API_KEY" help:"OpenAI API Key"` +} + +type Gemini struct { + GeminiKey string `env:"GEMINI_API_KEY" help:"Gemini API Key"` +} + type NewsAPI struct { NewsKey string `env:"NEWSAPI_KEY" help:"News API Key"` } @@ -64,8 +76,11 @@ type CLI struct { Tools ListToolsCmd `cmd:"" help:"Return a list of tools"` // Commands - Download DownloadModelCmd `cmd:"" help:"Download a model"` - Chat ChatCmd `cmd:"" help:"Start a chat session"` + Download DownloadModelCmd `cmd:"" help:"Download a model"` + Chat ChatCmd `cmd:"" help:"Start a chat session"` + Complete CompleteCmd `cmd:"" help:"Complete a prompt"` + Embedding EmbeddingCmd `cmd:"" help:"Generate an embedding"` + Version VersionCmd `cmd:"" help:"Print the version of this tool"` } //////////////////////////////////////////////////////////////////////////////// @@ -101,6 +116,9 @@ func main() { if cli.Debug || cli.Verbose { clientopts = append(clientopts, client.OptTrace(os.Stderr, cli.Verbose)) } + if cli.Timeout > 0 { + clientopts = append(clientopts, client.OptTimeout(cli.Timeout)) + } // Create an agent opts := []llm.Opt{} @@ -113,6 +131,12 @@ func main() { if cli.MistralKey != "" { opts = append(opts, agent.WithMistral(cli.MistralKey, clientopts...)) } + if cli.OpenAIKey != "" { + opts = append(opts, agent.WithOpenAI(cli.OpenAIKey, clientopts...)) + } + if cli.GeminiKey != "" { + opts = append(opts, agent.WithGemini(cli.GeminiKey, clientopts...)) + } // Make a toolkit toolkit := tool.NewToolKit() @@ -162,3 +186,16 @@ func clientOpts(cli *CLI) []client.ClientOpt { } return result } + +//////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func run(globals *Globals, name string, fn func(ctx context.Context, model llm.Model) error) error { + model, err := globals.agent.GetModel(globals.ctx, name) + if err != nil { + return err + } + + // Get the model + return fn(globals.ctx, model) +} diff --git a/cmd/llm/models.go b/cmd/llm/models.go index e304507..5cec47b 100644 --- a/cmd/llm/models.go +++ b/cmd/llm/models.go @@ -1,14 +1,16 @@ package main import ( - "context" "encoding/json" "fmt" + "os" + "sort" + "strings" // Packages - llm "github.com/mutablelogic/go-llm" + tablewriter "github.com/djthorpe/go-tablewriter" agent "github.com/mutablelogic/go-llm/pkg/agent" - ollama "github.com/mutablelogic/go-llm/pkg/ollama" + "github.com/mutablelogic/go-llm/pkg/ollama" ) //////////////////////////////////////////////////////////////////////////////// @@ -31,87 +33,90 @@ type DownloadModelCmd struct { // PUBLIC METHODS func (cmd *ListToolsCmd) Run(globals *Globals) error { - return runagent(globals, func(ctx context.Context, client llm.Agent) error { - tools := globals.toolkit.Tools(client) - fmt.Println(tools) - return nil - }) + tools := globals.toolkit.Tools(globals.agent) + fmt.Println(tools) + return nil } func (cmd *ListModelsCmd) Run(globals *Globals) error { - return runagent(globals, func(ctx context.Context, client llm.Agent) error { - agent, ok := client.(*agent.Agent) - if !ok { - return fmt.Errorf("No agents found") - } - models, err := agent.ListModels(ctx, cmd.Agent...) 
- if err != nil { - return err - } - fmt.Println(models) - return nil - }) -} - -func (*ListAgentsCmd) Run(globals *Globals) error { - return runagent(globals, func(ctx context.Context, client llm.Agent) error { - agent, ok := client.(*agent.Agent) - if !ok { - return fmt.Errorf("No agents found") - } + models, err := globals.agent.ListModels(globals.ctx, cmd.Agent...) + if err != nil { + return err + } + result := make(ModelList, 0, len(models)) + for _, model := range models { + result = append(result, Model{ + Agent: model.(*agent.Model).Agent, + Model: model.Name(), + Description: model.Description(), + Aliases: strings.Join(model.Aliases(), ", "), + }) + } - agents := make([]string, 0, len(agent.Agents())) - for _, agent := range agent.Agents() { - agents = append(agents, agent.Name()) - } + // Sort models by name + sort.Sort(result) - data, err := json.MarshalIndent(agents, "", " ") - if err != nil { - return err - } - fmt.Println(string(data)) + // Write out the models + return tablewriter.New(os.Stdout).Write(result, tablewriter.OptOutputText(), tablewriter.OptHeader()) +} - return nil - }) +func (*ListAgentsCmd) Run(globals *Globals) error { + agents := globals.agent.AgentNames() + data, err := json.MarshalIndent(agents, "", " ") + if err != nil { + return err + } + fmt.Println(string(data)) + return nil } func (cmd *DownloadModelCmd) Run(globals *Globals) error { - return runagent(globals, func(ctx context.Context, client llm.Agent) error { - agent := getagent(client, cmd.Agent) - if agent == nil { - return fmt.Errorf("No agents found with name %q", cmd.Agent) - } - // Download the model - switch agent.Name() { - case "ollama": - model, err := agent.(*ollama.Client).PullModel(ctx, cmd.Model) - if err != nil { - return err + agents := globals.agent.AgentsWithName(cmd.Agent) + if len(agents) == 0 { + return fmt.Errorf("No agents found with name %q", cmd.Agent) + } + switch agents[0].Name() { + case "ollama": + model, err := agents[0].(*ollama.Client).PullModel(globals.ctx, cmd.Model, ollama.WithPullStatus(func(status *ollama.PullStatus) { + var pct int64 + if status.TotalBytes > 0 { + pct = status.CompletedBytes * 100 / status.TotalBytes } - fmt.Println(model) - default: - return fmt.Errorf("Agent %q does not support model download", agent.Name()) + fmt.Print("\r", status.Status, " ", pct, "%") + if status.Status == "success" { + fmt.Println("") + } + })) + if err != nil { + return err } - return nil - }) + fmt.Println(model) + default: + return fmt.Errorf("Agent %q does not support model download", agents[0].Name()) + } + return nil } //////////////////////////////////////////////////////////////////////////////// -// PRIVATE METHODS +// MODEL LIST -func runagent(globals *Globals, fn func(ctx context.Context, agent llm.Agent) error) error { - return fn(globals.ctx, globals.agent) +type Model struct { + Agent string `json:"agent" writer:"Agent,width:10"` + Model string `json:"model" writer:"Model,wrap,width:40"` + Description string `json:"description" writer:"Description,wrap,width:60"` + Aliases string `json:"aliases" writer:"Aliases,wrap,width:30"` } -func getagent(client llm.Agent, name string) llm.Agent { - agent, ok := client.(*agent.Agent) - if !ok { - return nil - } - for _, agent := range agent.Agents() { - if agent.Name() == name { - return agent - } - } - return nil +type ModelList []Model + +func (models ModelList) Len() int { + return len(models) +} + +func (models ModelList) Less(a, b int) bool { + return strings.Compare(models[a].Model, models[b].Model) < 0 +} + 
+func (models ModelList) Swap(a, b int) { + models[a], models[b] = models[b], models[a] } diff --git a/cmd/llm/version.go b/cmd/llm/version.go new file mode 100644 index 0000000..35515a0 --- /dev/null +++ b/cmd/llm/version.go @@ -0,0 +1,34 @@ +package main + +import ( + "fmt" + "os" + "runtime" + + // Packages + "github.com/mutablelogic/go-llm/pkg/version" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type VersionCmd struct{} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Run the version command +func (cmd *VersionCmd) Run() error { + w := os.Stdout + if version.GitSource != "" { + fmt.Fprintf(w, "Source: https://%v\n", version.GitSource) + } + if version.GitTag != "" || version.GitBranch != "" { + fmt.Fprintf(w, "Version: %v (branch: %q hash:%q)\n", version.GitTag, version.GitBranch, version.GitHash) + } + if version.GoBuildTime != "" { + fmt.Fprintf(w, "Build: %v\n", version.GoBuildTime) + } + fmt.Fprintf(w, "Go: %v (%v/%v)\n", runtime.Version(), runtime.GOOS, runtime.GOARCH) + return nil +} diff --git a/context.go b/context.go index 0aad95d..db5dd12 100644 --- a/context.go +++ b/context.go @@ -5,27 +5,35 @@ import "context" ////////////////////////////////////////////////////////////////// // TYPES -// Completion is the content of the last context message +// Completion is the content of the last message type Completion interface { // Return the number of completions, which is ususally 1 unless - // WithNumCompletions was used when calling the model + // WithNumCompletions was used Num() int - // Return the current session role, which can be system, assistant, user, tool, tool_result, ... + // Return a specific completion + Choice(int) Completion + + // Return the current session role, which can be system, assistant, + // user, tool, tool_result, ... // If this is a completion, the role is usually 'assistant' Role() string // Return the text for the last completion, with the argument as the - // completion index (usually 0). If multiple completions are not - // supported, the argument is ignored. + // completion index (usually 0). Text(int) string + // Return audio for the last completion, with the argument as the + // completion index (usually 0). + Audio(int) *Attachment + // Return the current session tool calls given the completion index. // Will return nil if no tool calls were returned. 
ToolCalls(int) []ToolCall } -// Context is fed to the agent to generate a response +// Context is a context window fed to the agent to generate a response, +// with the ability to create the next completion type Context interface { Completion diff --git a/error.go b/error.go index 4bb0c9d..bd7a9c0 100644 --- a/error.go +++ b/error.go @@ -13,6 +13,7 @@ const ( ErrBadParameter ErrNotImplemented ErrConflict + ErrInternalServerError ) //////////////////////////////////////////////////////////////////////////////// @@ -36,6 +37,8 @@ func (e Err) Error() string { return "not implemented" case ErrConflict: return "conflict" + case ErrInternalServerError: + return "internal server error" } return fmt.Sprintf("error code %d", int(e)) } diff --git a/pkg/anthropic/testdata/LICENSE b/etc/testdata/LICENSE similarity index 100% rename from pkg/anthropic/testdata/LICENSE rename to etc/testdata/LICENSE diff --git a/pkg/anthropic/testdata/guggenheim.jpg b/etc/testdata/guggenheim.jpg similarity index 100% rename from pkg/anthropic/testdata/guggenheim.jpg rename to etc/testdata/guggenheim.jpg diff --git a/go.mod b/go.mod index 199ec9e..8446bb4 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,9 @@ require ( github.com/MichaelMure/go-term-text v0.3.1 github.com/alecthomas/kong v1.7.0 github.com/djthorpe/go-errors v1.0.3 + github.com/djthorpe/go-tablewriter v0.0.7 github.com/fatih/color v1.9.0 + github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 github.com/mutablelogic/go-client v1.0.10 github.com/stretchr/testify v1.10.0 golang.org/x/term v0.28.0 diff --git a/go.sum b/go.sum index e9cc101..48692ca 100644 --- a/go.sum +++ b/go.sum @@ -11,8 +11,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/djthorpe/go-errors v1.0.3 h1:GZeMPkC1mx2vteXLI/gvxZS0Ee9zxzwD1mcYyKU5jD0= github.com/djthorpe/go-errors v1.0.3/go.mod h1:HtfrZnMd6HsX75Mtbv9Qcnn0BqOrrFArvCaj3RMnZhY= +github.com/djthorpe/go-tablewriter v0.0.7 h1:jnNsJDjjLLCt0OAqB5DzGZN7V3beT1IpNMQ8GcOwZDU= +github.com/djthorpe/go-tablewriter v0.0.7/go.mod h1:NVBvytpL+6fHfCKn0+3lSi15/G3A1HWf2cLNeHg6YBg= github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= +github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc= +github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= diff --git a/model.go b/model.go index 221105b..f388fad 100644 --- a/model.go +++ b/model.go @@ -1,21 +1,28 @@ package llm -import "context" +import ( + "context" +) // An Model can be used to generate a response to a user prompt, -// which is passed to an agent. The interaction occurs through +// which is passed to an agent. A back-and-forth interaction occurs through // a session context object. 
type Model interface { // Return the name of the model Name() string - // Return am empty session context object for the model, - // setting session options + // Return the description of the model + Description() string + + // Return any model aliases + Aliases() []string + + // Return am empty session context object for the model, setting + // session options Context(...Opt) Context - // Convenience method to create a session context object - // with a user prompt - UserPrompt(string, ...Opt) Context + // Create a completion from a text prompt + Completion(context.Context, string, ...Opt) (Completion, error) // Embedding vector generation Embedding(context.Context, string, ...Opt) ([]float64, error) diff --git a/opt.go b/opt.go index a378a91..112a6c8 100644 --- a/opt.go +++ b/opt.go @@ -3,6 +3,7 @@ package llm import ( "encoding/json" "io" + "strings" "time" ) @@ -197,6 +198,10 @@ func WithToolKit(toolkit ToolKit) Opt { func WithStream(fn func(Completion)) Opt { return func(o *Opts) error { o.callback = fn + + // We include usage metrics in the streaming response for openai + o.Set("stream_options_include_usage", true) + return nil } } @@ -326,6 +331,12 @@ func WithSeed(v uint64) Opt { // Set format func WithFormat(v any) Opt { return func(o *Opts) error { + if v_, ok := v.(string); ok { + v_ = strings.TrimSpace(strings.ToLower(v_)) + if v_ == "json" { + v = "json_object" + } + } o.Set("format", v) return nil } @@ -357,3 +368,20 @@ func WithSafePrompt() Opt { return nil } } + +// Predicted output, which is most common when you are regenerating a file +// with only minor changes to most of the content. +func WithPrediction(v string) Opt { + return func(o *Opts) error { + o.Set("prediction", v) + return nil + } +} + +// A unique identifier representing your end-user +func WithUser(v string) Opt { + return func(o *Opts) error { + o.Set("user", v) + return nil + } +} diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index cd50c87..3766917 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -18,7 +18,7 @@ type Agent struct { *llm.Opts } -type model struct { +type Model struct { Agent string `json:"agent"` llm.Model `json:"model"` } @@ -44,7 +44,7 @@ func New(opts ...llm.Opt) (*Agent, error) { /////////////////////////////////////////////////////////////////////////////// // STRINGIFY -func (m model) String() string { +func (m Model) String() string { data, err := json.MarshalIndent(m, "", " ") if err != nil { return err.Error() @@ -168,13 +168,27 @@ func modelsForAgent(ctx context.Context, agent llm.Agent, names ...string) ([]ll return nil, err } + match_model := func(model llm.Model, names ...string) bool { + if len(names) == 0 { + return true + } + if slices.Contains(names, model.Name()) { + return true + } + for _, alias := range model.Aliases() { + if slices.Contains(names, alias) { + return true + } + } + return false + } + // Filter models result := make([]llm.Model, 0, len(models)) for _, agentmodel := range models { - if len(names) > 0 && !slices.Contains(names, agentmodel.Name()) { - continue + if match_model(agentmodel, names...) 
{ + result = append(result, &Model{Agent: agent.Name(), Model: agentmodel}) } - result = append(result, &model{Agent: agent.Name(), Model: agentmodel}) } // Return success diff --git a/pkg/agent/opt.go b/pkg/agent/opt.go index 9e54647..ddc3350 100644 --- a/pkg/agent/opt.go +++ b/pkg/agent/opt.go @@ -4,9 +4,11 @@ import ( // Packages client "github.com/mutablelogic/go-client" llm "github.com/mutablelogic/go-llm" - "github.com/mutablelogic/go-llm/pkg/anthropic" + anthropic "github.com/mutablelogic/go-llm/pkg/anthropic" + gemini "github.com/mutablelogic/go-llm/pkg/gemini" mistral "github.com/mutablelogic/go-llm/pkg/mistral" ollama "github.com/mutablelogic/go-llm/pkg/ollama" + openai "github.com/mutablelogic/go-llm/pkg/openai" ) //////////////////////////////////////////////////////////////////////////////// @@ -44,3 +46,25 @@ func WithMistral(key string, opts ...client.ClientOpt) llm.Opt { } } } + +func WithOpenAI(key string, opts ...client.ClientOpt) llm.Opt { + return func(o *llm.Opts) error { + client, err := openai.New(key, opts...) + if err != nil { + return err + } else { + return llm.WithAgent(client)(o) + } + } +} + +func WithGemini(key string, opts ...client.ClientOpt) llm.Opt { + return func(o *llm.Opts) error { + client, err := gemini.New(key, opts...) + if err != nil { + return err + } else { + return llm.WithAgent(client)(o) + } + } +} diff --git a/pkg/anthropic/client.go b/pkg/anthropic/client.go index f446edc..a5ab813 100644 --- a/pkg/anthropic/client.go +++ b/pkg/anthropic/client.go @@ -1,14 +1,14 @@ /* -anthropic implements an API client for anthropic (https://docs.anthropic.com/en/api/getting-started) +anthropic implements an API client for anthropic +https://docs.anthropic.com/en/api/getting-started */ package anthropic import ( // Packages - "context" - client "github.com/mutablelogic/go-client" llm "github.com/mutablelogic/go-llm" + impl "github.com/mutablelogic/go-llm/pkg/internal/impl" ) /////////////////////////////////////////////////////////////////////////////// @@ -16,7 +16,7 @@ import ( type Client struct { *client.Client - cache map[string]llm.Model + *impl.ModelCache } var _ llm.Agent = (*Client)(nil) @@ -37,14 +37,15 @@ const ( func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) { // Create client opts = append(opts, client.OptEndpoint(endPoint)) - opts = append(opts, client.OptHeader("x-api-key", ApiKey), client.OptHeader("anthropic-version", defaultVersion)) + opts = append(opts, client.OptHeader("x-api-key", ApiKey)) + opts = append(opts, client.OptHeader("anthropic-version", defaultVersion)) client, err := client.New(opts...) 
if err != nil { return nil, err } // Return the client - return &Client{client, nil}, nil + return &Client{client, impl.NewModelCache()}, nil } /////////////////////////////////////////////////////////////////////////////// @@ -54,36 +55,3 @@ func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) { func (*Client) Name() string { return defaultName } - -// Return the models -func (anthropic *Client) Models(ctx context.Context) ([]llm.Model, error) { - // Cache models - if anthropic.cache == nil { - models, err := anthropic.ListModels(ctx) - if err != nil { - return nil, err - } - anthropic.cache = make(map[string]llm.Model, len(models)) - for _, model := range models { - anthropic.cache[model.Name()] = model - } - } - - // Return models - result := make([]llm.Model, 0, len(anthropic.cache)) - for _, model := range anthropic.cache { - result = append(result, model) - } - return result, nil -} - -// Return a model by name, or nil if not found. -// Panics on error. -func (anthropic *Client) Model(ctx context.Context, name string) llm.Model { - if anthropic.cache == nil { - if _, err := anthropic.Models(ctx); err != nil { - panic(err) - } - } - return anthropic.cache[name] -} diff --git a/pkg/anthropic/completion.go b/pkg/anthropic/completion.go index 6c25a11..f9e490c 100644 --- a/pkg/anthropic/completion.go +++ b/pkg/anthropic/completion.go @@ -21,7 +21,7 @@ type Response struct { Reason string `json:"stop_reason,omitempty"` StopSequence *string `json:"stop_sequence,omitempty"` Message - Metrics `json:"usage,omitempty"` + *Metrics `json:"usage,omitempty"` } // Metrics @@ -43,25 +43,43 @@ func (r Response) String() string { return string(data) } +func (m Metrics) String() string { + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS type reqMessages struct { - Model string `json:"model"` - MaxTokens uint64 `json:"max_tokens,omitempty"` - Metadata *optmetadata `json:"metadata,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` - Stream bool `json:"stream,omitempty"` - System string `json:"system,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopK uint64 `json:"top_k,omitempty"` - TopP float64 `json:"top_p,omitempty"` - Messages []*Message `json:"messages"` - Tools []llm.Tool `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` + Model string `json:"model"` + MaxTokens uint64 `json:"max_tokens,omitempty"` + Metadata *optmetadata `json:"metadata,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + System string `json:"system,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopK uint64 `json:"top_k,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Tools []llm.Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + Messages []llm.Completion `json:"messages"` } -func (anthropic *Client) Messages(ctx context.Context, context llm.Context, opts ...llm.Opt) (*Response, error) { +// Send a completion request with a single prompt, and return the next completion +func (model *model) Completion(ctx context.Context, prompt string, opts ...llm.Opt) (llm.Completion, error) { + message, err := messagefactory{}.UserPrompt(prompt, opts...) + if err != nil { + return nil, err + } + return model.Chat(ctx, []llm.Completion{message}, opts...) 
+} + +// Send a completion request with multiple completions, and return the next completion +func (model *model) Chat(ctx context.Context, completions []llm.Completion, opts ...llm.Opt) (llm.Completion, error) { // Apply options opt, err := llm.ApplyOpts(opts...) if err != nil { @@ -70,8 +88,8 @@ func (anthropic *Client) Messages(ctx context.Context, context llm.Context, opts // Request req, err := client.NewJSONRequest(reqMessages{ - Model: context.(*session).model.Name(), - MaxTokens: optMaxTokens(context.(*session).model, opt), + Model: model.Name(), + MaxTokens: optMaxTokens(model, opt), Metadata: optMetadata(opt), StopSequences: optStopSequences(opt), Stream: optStream(opt), @@ -79,19 +97,21 @@ func (anthropic *Client) Messages(ctx context.Context, context llm.Context, opts Temperature: optTemperature(opt), TopK: optTopK(opt), TopP: optTopP(opt), - Messages: context.(*session).seq, - Tools: optTools(anthropic, opt), + Tools: optTools(model.Client, opt), ToolChoice: optToolChoice(opt), + Messages: completions, }) if err != nil { return nil, err } - // Stream + // Response options var response Response reqopts := []client.RequestOpt{ client.OptPath("messages"), } + + // Streaming if optStream(opt) { reqopts = append(reqopts, client.OptTextStreamCallback(func(evt client.TextStreamEvent) error { if err := streamEvent(&response, evt); err != nil { @@ -105,7 +125,7 @@ func (anthropic *Client) Messages(ctx context.Context, context llm.Context, opts } // Response - if err := anthropic.DoWithContext(ctx, req, &response, reqopts...); err != nil { + if err := model.DoWithContext(ctx, req, &response, reqopts...); err != nil { return nil, err } diff --git a/pkg/anthropic/completion_test.go b/pkg/anthropic/completion_test.go_old similarity index 97% rename from pkg/anthropic/completion_test.go rename to pkg/anthropic/completion_test.go_old index 2726c7e..a5db6a4 100644 --- a/pkg/anthropic/completion_test.go +++ b/pkg/anthropic/completion_test.go_old @@ -10,7 +10,6 @@ import ( // Packages llm "github.com/mutablelogic/go-llm" - anthropic "github.com/mutablelogic/go-llm/pkg/anthropic" "github.com/mutablelogic/go-llm/pkg/tool" assert "github.com/stretchr/testify/assert" ) @@ -101,7 +100,7 @@ func Test_chat_002(t *testing.T) { } }) t.Run("User", func(t *testing.T) { - r, err := client.Messages(context.TODO(), model.UserPrompt("What is the temperature in London?"), anthropic.WithUser("username")) + r, err := client.Messages(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithUser("username")) if assert.NoError(err) { assert.Equal("assistant", r.Role()) assert.Equal(1, r.Num()) diff --git a/pkg/anthropic/content.go b/pkg/anthropic/content.go new file mode 100644 index 0000000..4f79eb4 --- /dev/null +++ b/pkg/anthropic/content.go @@ -0,0 +1,160 @@ +package anthropic + +import ( + "encoding/json" + "strings" + + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Content struct { + Type string `json:"type"` // image, document, text, tool_use + ContentText + ContentAttachment + *ContentTool + ContentToolResult + CacheControl *cachecontrol `json:"cache_control,omitempty"` // ephemeral +} + +type ContentText struct { + Text string `json:"text,omitempty"` // text content +} + +type ContentTool struct { + Id string `json:"id,omitempty"` // tool id + Name string `json:"name,omitempty"` // tool name + Input map[string]any `json:"input"` // tool input + InputJson string 
`json:"partial_json,omitempty"` // partial json input (for streaming) +} + +type ContentAttachment struct { + Title string `json:"title,omitempty"` // title of the document + Context string `json:"context,omitempty"` // context of the document + Citations *contentcitation `json:"citations,omitempty"` // citations of the document + Source *contentsource `json:"source,omitempty"` // image or document content +} + +type ContentToolResult struct { + Id string `json:"tool_use_id,omitempty"` // tool id + Content any `json:"content,omitempty"` +} + +type contentsource struct { + Type string `json:"type"` // base64 or text + MediaType string `json:"media_type"` // image/jpeg, image/png, image/gif, image/webp, application/pdf, text/plain + Data any `json:"data"` // ...base64 or text encoded data +} + +type cachecontrol struct { + Type string `json:"type"` // ephemeral +} + +type contentcitation struct { + Enabled bool `json:"enabled"` // true +} + +/////////////////////////////////////////////////////////////////////////////// +// GLOBALS + +var ( + supportedAttachments = map[string]string{ + "image/jpeg": "image", + "image/png": "image", + "image/gif": "image", + "image/webp": "image", + "application/pdf": "document", + "text/plain": "text", + } +) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Return a content object with text content +func NewTextContent(v string) *Content { + return &Content{ + Type: "text", + ContentText: ContentText{ + Text: v, + }, + } +} + +// Return a content object with tool result +func NewToolResultContent(v llm.ToolResult) *Content { + content := new(Content) + content.Type = "tool_result" + content.ContentToolResult.Id = v.Call().Id() + // content.ContentToolResult.Name = v.Call().Name() + + // We only support JSON encoding for the moment + data, err := json.Marshal(v.Value()) + if err != nil { + content.ContentToolResult.Content = err.Error() + } else { + content.ContentToolResult.Content = string(data) + } + + return content +} + +// Make attachment content +func NewAttachment(attachment *llm.Attachment, ephemeral, citations bool) (*Content, error) { + // Detect mimetype + mimetype := attachment.Type() + if strings.HasPrefix(mimetype, "text/") { + // Switch to text/plain - TODO: charsets? 
+ mimetype = "text/plain" + } + + // Check supported mimetype + typ, exists := supportedAttachments[mimetype] + if !exists { + return nil, llm.ErrBadParameter.Withf("unsupported or undetected mimetype %q", mimetype) + } + + // Create attachment + content := new(Content) + content.Type = typ + if ephemeral { + content.CacheControl = &cachecontrol{Type: "ephemeral"} + } + + // Handle by type + switch typ { + case "text": + content.Type = "document" + content.Title = attachment.Filename() + content.Source = &contentsource{ + Type: "text", + MediaType: mimetype, + Data: string(attachment.Data()), + } + if citations { + content.Citations = &contentcitation{Enabled: true} + } + case "document": + content.Source = &contentsource{ + Type: "base64", + MediaType: mimetype, + Data: attachment.Data(), + } + if citations { + content.Citations = &contentcitation{Enabled: true} + } + case "image": + content.Source = &contentsource{ + Type: "base64", + MediaType: mimetype, + Data: attachment.Data(), + } + default: + return nil, llm.ErrBadParameter.Withf("unsupported attachment type %q", typ) + } + + // Return success + return content, nil +} diff --git a/pkg/anthropic/message.go b/pkg/anthropic/message.go index 33aecb3..c5b8b59 100644 --- a/pkg/anthropic/message.go +++ b/pkg/anthropic/message.go @@ -6,7 +6,7 @@ import ( // Packages llm "github.com/mutablelogic/go-llm" - "github.com/mutablelogic/go-llm/pkg/tool" + tool "github.com/mutablelogic/go-llm/pkg/tool" ) /////////////////////////////////////////////////////////////////////////////// @@ -17,162 +17,13 @@ type Message struct { RoleContent } +var _ llm.Completion = (*Message)(nil) + type RoleContent struct { Role string `json:"role"` Content []*Content `json:"content,omitempty"` } -var _ llm.Completion = (*Message)(nil) - -type Content struct { - Type string `json:"type"` // image, document, text, tool_use - ContentText - ContentAttachment - *ContentTool - ContentToolResult - CacheControl *cachecontrol `json:"cache_control,omitempty"` // ephemeral -} - -type ContentText struct { - Text string `json:"text,omitempty"` // text content -} - -type ContentTool struct { - Id string `json:"id,omitempty"` // tool id - Name string `json:"name,omitempty"` // tool name - Input map[string]any `json:"input"` // tool input - InputJson string `json:"partial_json,omitempty"` // partial json input (for streaming) -} - -type ContentAttachment struct { - Title string `json:"title,omitempty"` // title of the document - Context string `json:"context,omitempty"` // context of the document - Citations *contentcitation `json:"citations,omitempty"` // citations of the document - Source *contentsource `json:"source,omitempty"` // image or document content -} - -type ContentToolResult struct { - Id string `json:"tool_use_id,omitempty"` // tool id - Content any `json:"content,omitempty"` -} - -type contentsource struct { - Type string `json:"type"` // base64 or text - MediaType string `json:"media_type"` // image/jpeg, image/png, image/gif, image/webp, application/pdf, text/plain - Data any `json:"data"` // ...base64 or text encoded data -} - -type cachecontrol struct { - Type string `json:"type"` // ephemeral -} - -type contentcitation struct { - Enabled bool `json:"enabled"` // true -} - -/////////////////////////////////////////////////////////////////////////////// -// GLOBALS - -var ( - supportedAttachments = map[string]string{ - "image/jpeg": "image", - "image/png": "image", - "image/gif": "image", - "image/webp": "image", - "application/pdf": "document", - "text/plain": 
"text", - } -) - -/////////////////////////////////////////////////////////////////////////////// -// LIFECYCLE - -// Return a content object with text content -func NewTextContent(v string) *Content { - return &Content{ - Type: "text", - ContentText: ContentText{ - Text: v, - }, - } -} - -// Return a content object with tool result -func NewToolResultContent(v llm.ToolResult) *Content { - content := new(Content) - content.Type = "tool_result" - content.ContentToolResult.Id = v.Call().Id() - // content.ContentToolResult.Name = v.Call().Name() - - // We only support JSON encoding for the moment - data, err := json.Marshal(v.Value()) - if err != nil { - content.ContentToolResult.Content = err.Error() - } else { - content.ContentToolResult.Content = string(data) - } - - return content -} - -// Make attachment content -func NewAttachment(attachment *llm.Attachment, ephemeral, citations bool) (*Content, error) { - // Detect mimetype - mimetype := attachment.Type() - if strings.HasPrefix(mimetype, "text/") { - // Switch to text/plain - TODO: charsets? - mimetype = "text/plain" - } - - // Check supported mimetype - typ, exists := supportedAttachments[mimetype] - if !exists { - return nil, llm.ErrBadParameter.Withf("unsupported or undetected mimetype %q", mimetype) - } - - // Create attachment - content := new(Content) - content.Type = typ - if ephemeral { - content.CacheControl = &cachecontrol{Type: "ephemeral"} - } - - // Handle by type - switch typ { - case "text": - content.Type = "document" - content.Title = attachment.Filename() - content.Source = &contentsource{ - Type: "text", - MediaType: mimetype, - Data: string(attachment.Data()), - } - if citations { - content.Citations = &contentcitation{Enabled: true} - } - case "document": - content.Source = &contentsource{ - Type: "base64", - MediaType: mimetype, - Data: attachment.Data(), - } - if citations { - content.Citations = &contentcitation{Enabled: true} - } - case "image": - content.Source = &contentsource{ - Type: "base64", - MediaType: mimetype, - Data: attachment.Data(), - } - default: - return nil, llm.ErrBadParameter.Withf("unsupported attachment type %q", typ) - } - - // Return success - return content, nil -} - /////////////////////////////////////////////////////////////////////////////// // STRINGIFY @@ -187,20 +38,31 @@ func (m Message) String() string { /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS - MESSAGE -func (m Message) Num() int { +// Return the number of completions +func (Message) Num() int { return 1 } -func (m Message) Role() string { - return m.RoleContent.Role +// Return the current session role +func (message Message) Role() string { + return message.RoleContent.Role } -func (m Message) Text(index int) string { +// Return the completion +func (message Message) Choice(index int) llm.Completion { + if index != 0 { + return nil + } + return message +} + +// Return the text for the last completion +func (message Message) Text(index int) string { if index != 0 { return "" } var text []string - for _, content := range m.RoleContent.Content { + for _, content := range message.RoleContent.Content { if content.Type == "text" { text = append(text, content.ContentText.Text) } @@ -208,14 +70,19 @@ func (m Message) Text(index int) string { return strings.Join(text, "\n") } -func (m Message) ToolCalls(index int) []llm.ToolCall { +// Return the audio - unsupported +func (Message) Audio(index int) *llm.Attachment { + return nil +} + +func (message Message) ToolCalls(index int) 
[]llm.ToolCall { if index != 0 { return nil } // Gather tool calls var result []llm.ToolCall - for _, content := range m.Content { + for _, content := range message.Content { if content.Type == "tool_use" { result = append(result, tool.NewCall(content.ContentTool.Id, content.ContentTool.Name, content.ContentTool.Input)) } diff --git a/pkg/anthropic/messagefactory.go b/pkg/anthropic/messagefactory.go new file mode 100644 index 0000000..1280ca5 --- /dev/null +++ b/pkg/anthropic/messagefactory.go @@ -0,0 +1,73 @@ +package anthropic + +import ( + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type messagefactory struct{} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - MESSAGE FACTORY + +func (messagefactory) SystemPrompt(prompt string) llm.Completion { + return &Message{ + RoleContent: RoleContent{ + Role: "system", + Content: []*Content{NewTextContent(prompt)}, + }, + } +} + +func (messagefactory) UserPrompt(prompt string, opts ...llm.Opt) (llm.Completion, error) { + // Get attachments + opt, err := llm.ApplyPromptOpts(opts...) + if err != nil { + return nil, err + } + + // Get attachments, allocate content + attachments := opt.Attachments() + contents := make([]*Content, 1, len(attachments)+1) + + // Append the text and the attachments + contents[0] = NewTextContent(prompt) + for _, attachment := range attachments { + if content, err := NewAttachment(attachment, optEphemeral(opt), optCitations(opt)); err != nil { + return nil, err + } else { + contents = append(contents, content) + } + } + + // Return success + return &Message{ + RoleContent: RoleContent{ + Role: "user", + Content: contents, + }, + }, nil +} + +func (messagefactory) ToolResults(results ...llm.ToolResult) ([]llm.Completion, error) { + // Check for no results + if len(results) == 0 { + return nil, llm.ErrBadParameter.Withf("No tool results") + } + + // Create user message + message := Message{ + RoleContent{ + Role: "user", + Content: make([]*Content, 0, len(results)), + }, + } + for _, result := range results { + message.RoleContent.Content = append(message.RoleContent.Content, NewToolResultContent(result)) + } + + // Return success + return []llm.Completion{message}, nil +} diff --git a/pkg/anthropic/model.go b/pkg/anthropic/model.go index 2b50ee6..6f86d8f 100644 --- a/pkg/anthropic/model.go +++ b/pkg/anthropic/model.go @@ -9,6 +9,7 @@ import ( // Packages client "github.com/mutablelogic/go-client" llm "github.com/mutablelogic/go-llm" + impl "github.com/mutablelogic/go-llm/pkg/internal/impl" ) /////////////////////////////////////////////////////////////////////////////// @@ -44,21 +45,68 @@ func (m model) String() string { } /////////////////////////////////////////////////////////////////////////////// -// PUBLIC METHODS - API +// PUBLIC METHODS - llm.Agent -// Get a model by name -func (anthropic *Client) GetModel(ctx context.Context, name string) (llm.Model, error) { - var response Model - if err := anthropic.DoWithContext(ctx, nil, &response, client.OptPath("models", name)); err != nil { +// Return the models +func (anthropic *Client) Models(ctx context.Context) ([]llm.Model, error) { + return anthropic.ModelCache.Load(func() ([]llm.Model, error) { + return anthropic.loadmodels(ctx) + }) +} + +// Return a model by name, or nil if not found. +// Panics on error. 
+func (anthropic *Client) Model(ctx context.Context, name string) llm.Model { + model, err := anthropic.ModelCache.Get(func() ([]llm.Model, error) { + return anthropic.loadmodels(ctx) + }, name) + if err != nil { + panic(err) + } + return model +} + +// Function called to load models +func (anthropic *Client) loadmodels(ctx context.Context) ([]llm.Model, error) { + if models, err := anthropic.ListModels(ctx); err != nil { return nil, err + } else { + result := make([]llm.Model, len(models)) + for i, meta := range models { + result[i] = &model{anthropic, meta} + } + return result, nil } +} - // Return success - return &model{anthropic, response}, nil +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - llm.Model + +// Return the name of a model +func (model *model) Name() string { + return model.meta.Name } +// Return model description +func (model model) Description() string { + return model.meta.Description +} + +// Return model aliases +func (model) Aliases() []string { + return nil +} + +// Return a new empty session +func (model *model) Context(opts ...llm.Opt) llm.Context { + return impl.NewSession(model, &messagefactory{}, opts...) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - API + // List models -func (anthropic *Client) ListModels(ctx context.Context) ([]llm.Model, error) { +func (anthropic *Client) ListModels(ctx context.Context) ([]Model, error) { var response struct { Body []Model `json:"data"` HasMore bool `json:"has_more"` @@ -68,7 +116,7 @@ func (anthropic *Client) ListModels(ctx context.Context) ([]llm.Model, error) { // Request request := url.Values{} - result := make([]llm.Model, 0, 100) + result := make([]Model, 0, 100) for { if err := anthropic.DoWithContext(ctx, nil, &response, client.OptPath("models"), client.OptQuery(request)); err != nil { return nil, err @@ -76,7 +124,7 @@ func (anthropic *Client) ListModels(ctx context.Context) ([]llm.Model, error) { // Convert to llm.Model for _, meta := range response.Body { - result = append(result, &model{anthropic, meta}) + result = append(result, meta) } // If there are no more models, return @@ -91,9 +139,15 @@ func (anthropic *Client) ListModels(ctx context.Context) ([]llm.Model, error) { return result, nil } -// Return the name of a model -func (model *model) Name() string { - return model.meta.Name +// Get a model by name +func (anthropic *Client) GetModel(ctx context.Context, name string) (*Model, error) { + var response Model + if err := anthropic.DoWithContext(ctx, nil, &response, client.OptPath("models", name)); err != nil { + return nil, err + } + + // Return success + return &response, nil } // Embedding vector generation - not supported on Anthropic diff --git a/pkg/anthropic/opt.go b/pkg/anthropic/opt.go index 862721b..b77fcdd 100644 --- a/pkg/anthropic/opt.go +++ b/pkg/anthropic/opt.go @@ -17,13 +17,6 @@ type optmetadata struct { //////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS -func WithUser(v string) llm.Opt { - return func(o *llm.Opts) error { - o.Set("user", v) - return nil - } -} - func WithEphemeral() llm.Opt { return func(o *llm.Opts) error { o.Set("ephemeral", true) diff --git a/pkg/anthropic/session.go b/pkg/anthropic/session.go_old similarity index 100% rename from pkg/anthropic/session.go rename to pkg/anthropic/session.go_old diff --git a/pkg/anthropic/session_test.go b/pkg/anthropic/session_test.go_old similarity index 100% rename from 
pkg/anthropic/session_test.go rename to pkg/anthropic/session_test.go_old diff --git a/pkg/gemini/client.go b/pkg/gemini/client.go new file mode 100644 index 0000000..508ebef --- /dev/null +++ b/pkg/gemini/client.go @@ -0,0 +1,60 @@ +/* +gemini implements an API client for Google's Gemini LLM (https://ai.google.dev/gemini-api/docs) +*/ +package gemini + +import ( + + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Client struct { + *client.Client + cache map[string]llm.Model +} + +var _ llm.Agent = (*Client)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// GLOBALS + +const ( + endPoint = "https://generativelanguage.googleapis.com/v1beta" + defaultName = "gemini" +) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Create a new client +func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) { + // Create client + opts = append(opts, client.OptEndpoint(endPointWithKey(endPoint, ApiKey))) + client, err := client.New(opts...) + if err != nil { + return nil, err + } + + // Return the client + return &Client{client, nil}, nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return the name of the agent +func (Client) Name() string { + return defaultName +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func endPointWithKey(endpoint, key string) string { + return endpoint + "?key=" + key +} diff --git a/pkg/gemini/client_test.go b/pkg/gemini/client_test.go new file mode 100644 index 0000000..f55f396 --- /dev/null +++ b/pkg/gemini/client_test.go @@ -0,0 +1,58 @@ +package gemini_test + +import ( + "flag" + "log" + "os" + "strconv" + "testing" + + // Packages + opts "github.com/mutablelogic/go-client" + gemini "github.com/mutablelogic/go-llm/pkg/gemini" + assert "github.com/stretchr/testify/assert" +) + +/////////////////////////////////////////////////////////////////////////////// +// TEST SET-UP + +var ( + client *gemini.Client +) + +func TestMain(m *testing.M) { + var verbose bool + + // Verbose output + flag.Parse() + if f := flag.Lookup("test.v"); f != nil { + if v, err := strconv.ParseBool(f.Value.String()); err == nil { + verbose = v + } + } + + // API KEY + api_key := os.Getenv("GEMINI_API_KEY") + if api_key == "" { + log.Print("GEMINI_API_KEY not set") + os.Exit(0) + } + + // Create client + var err error + client, err = gemini.New(api_key, opts.OptTrace(os.Stderr, verbose)) + if err != nil { + log.Println(err) + os.Exit(-1) + } + os.Exit(m.Run()) +} + +/////////////////////////////////////////////////////////////////////////////// +// TESTS + +func Test_client_001(t *testing.T) { + assert := assert.New(t) + assert.NotNil(client) + t.Log(client) +} diff --git a/pkg/gemini/model.go b/pkg/gemini/model.go new file mode 100644 index 0000000..435436e --- /dev/null +++ b/pkg/gemini/model.go @@ -0,0 +1,140 @@ +package gemini + +import ( + "context" + "encoding/json" + + // Packages + "github.com/mutablelogic/go-client" + "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type model struct { + *Client `json:"-"` + meta Model +} + +var _ llm.Model = (*model)(nil) + +type Model struct { + Name string `json:"name"` + Version string `json:"version"` 
+	DisplayName                string   `json:"displayName"`
+	Description                string   `json:"description"`
+	InputTokenLimit            uint64   `json:"inputTokenLimit"`
+	OutputTokenLimit           uint64   `json:"outputTokenLimit"`
+	SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
+	Temperature                float64  `json:"temperature"`
+	TopP                       float64  `json:"topP"`
+	TopK                       uint64   `json:"topK"`
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// STRINGIFY
+
+func (m model) MarshalJSON() ([]byte, error) {
+	return json.Marshal(m.meta)
+}
+
+func (m model) String() string {
+	data, err := json.MarshalIndent(m, "", " ")
+	if err != nil {
+		return err.Error()
+	}
+	return string(data)
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS - llm.Model implementation
+
+// Return model name
+func (m model) Name() string {
+	return m.meta.Name
+}
+
+// Return model aliases
+func (model model) Aliases() []string {
+	return nil
+}
+
+// Return model description
+func (model model) Description() string {
+	return model.meta.Description
+}
+
+// Return the models
+func (gemini *Client) Models(ctx context.Context) ([]llm.Model, error) {
+	// Cache models
+	if gemini.cache == nil {
+		models, err := gemini.ListModels(ctx)
+		if err != nil {
+			return nil, err
+		}
+		gemini.cache = make(map[string]llm.Model, len(models))
+		for _, m := range models {
+			gemini.cache[m.Name] = &model{gemini, m}
+		}
+	}
+
+	// Return models
+	result := make([]llm.Model, 0, len(gemini.cache))
+	for _, model := range gemini.cache {
+		result = append(result, model)
+	}
+	return result, nil
+}
+
+// Return a model by name, or nil if not found.
+// Panics on error.
+func (gemini *Client) Model(ctx context.Context, name string) llm.Model {
+	if gemini.cache == nil {
+		if _, err := gemini.Models(ctx); err != nil {
+			panic(err)
+		}
+	}
+	return gemini.cache[name]
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS - API
+
+// ListModels returns all the models
+func (c *Client) ListModels(ctx context.Context) ([]Model, error) {
+	// Response
+	var response struct {
+		Data []Model `json:"models"`
+	}
+	if err := c.DoWithContext(ctx, nil, &response, client.OptPath("models")); err != nil {
+		return nil, err
+	}
+
+	// Return success
+	return response.Data, nil
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS - MODEL
+
+// Return an empty session context object for the model,
+// setting session options
+func (m model) Context(...llm.Opt) llm.Context {
+	return nil
+}
+
+// Create a completion from a text prompt
+func (m model) Completion(context.Context, string, ...llm.Opt) (llm.Completion, error) {
+	return nil, llm.ErrNotImplemented
+}
+
+// Create a completion from a chat session
+func (m model) Chat(context.Context, []llm.Completion, ...llm.Opt) (llm.Completion, error) {
+	return nil, llm.ErrNotImplemented
+}
+
+// Embedding vector generation
+func (m model) Embedding(context.Context, string, ...llm.Opt) ([]float64, error) {
+	return nil, llm.ErrNotImplemented
+}
diff --git a/pkg/gemini/model_test.go b/pkg/gemini/model_test.go
new file mode 100644
index 0000000..3d18877
--- /dev/null
+++ b/pkg/gemini/model_test.go
@@ -0,0 +1,22 @@
+package gemini_test
+
+import (
+	"context"
+	"encoding/json"
+	"testing"
+
+	// Packages
+	assert "github.com/stretchr/testify/assert"
+)
+
+func Test_models_001(t *testing.T) {
+	assert := assert.New(t)
+
+	response, err :=
client.ListModels(context.TODO()) + assert.NoError(err) + assert.NotEmpty(response) + + data, err := json.MarshalIndent(response, "", " ") + assert.NoError(err) + t.Log(string(data)) +} diff --git a/pkg/internal/impl/modelcache.go b/pkg/internal/impl/modelcache.go new file mode 100644 index 0000000..1c91376 --- /dev/null +++ b/pkg/internal/impl/modelcache.go @@ -0,0 +1,66 @@ +package impl + +import ( + // Packages + "sync" + + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type ModelCache struct { + sync.RWMutex + cache map[string]llm.Model +} + +type ModelLoadFunc func() ([]llm.Model, error) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +func NewModelCache() *ModelCache { + cache := new(ModelCache) + cache.cache = make(map[string]llm.Model, 20) + return cache +} + +/////////////////////////////////////////////////////////////////////////////// +// METHODS + +// Load models and return them +func (c *ModelCache) Load(fn ModelLoadFunc) ([]llm.Model, error) { + c.Lock() + defer c.Unlock() + + // Load models + if len(c.cache) == 0 { + if models, err := fn(); err != nil { + return nil, err + } else { + for _, m := range models { + c.cache[m.Name()] = m + } + } + } + + // Return models + result := make([]llm.Model, 0, len(c.cache)) + for _, model := range c.cache { + result = append(result, model) + } + return result, nil +} + +// Return a model by name +func (c *ModelCache) Get(fn ModelLoadFunc, name string) (llm.Model, error) { + if len(c.cache) == 0 { + if _, err := c.Load(fn); err != nil { + return nil, err + } + } + c.RLock() + defer c.RUnlock() + return c.cache[name], nil +} diff --git a/pkg/internal/impl/session.go b/pkg/internal/impl/session.go new file mode 100644 index 0000000..d56846e --- /dev/null +++ b/pkg/internal/impl/session.go @@ -0,0 +1,186 @@ +package impl + +import ( + "context" + "encoding/json" + + // Packages + "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// INTERFACE + +// Abstract interface for a message factory +type MessageFactory interface { + // Generate a system prompt + SystemPrompt(prompt string) llm.Completion + + // Generate a user prompt, with attachments and other options + UserPrompt(string, ...llm.Opt) (llm.Completion, error) + + // Generate an array of results from calling tools + ToolResults(...llm.ToolResult) ([]llm.Completion, error) +} + +type Model interface { + // Additional method for a context object + Chat(ctx context.Context, completions []llm.Completion, opts ...llm.Opt) (llm.Completion, error) +} + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// A chat session with history +type session struct { + model Model // The model used for the session + opts []llm.Opt // Options to apply to the session + seq []llm.Completion // Sequence of messages + factory MessageFactory // Factory for generating messages +} + +var _ llm.Context = (*session)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Create a new empty session to store a context window +func NewSession(model llm.Model, factory MessageFactory, opts ...llm.Opt) *session { + chatmodel, ok := model.(Model) + if !ok || model == nil { + panic("Model does not implement the session.Model interface") + } + return &session{ + model: chatmodel, + opts: opts, + seq: make([]llm.Completion, 
0, 10), + factory: factory, + } +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (session session) MarshalJSON() ([]byte, error) { + return json.Marshal(session.seq) +} + +func (session session) String() string { + data, err := json.MarshalIndent(session, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return an array of messages in the session with system prompt. If the +// prompt is empty, no system prompt is prepended +func (session *session) WithSystemPrompt(prompt string) []llm.Completion { + messages := make([]llm.Completion, 0, len(session.seq)+1) + if prompt != "" { + messages = append(messages, session.factory.SystemPrompt(prompt)) + } + return append(messages, session.seq...) +} + +// Append a message to the session +// TODO: Trim the context window to a maximum size +func (session *session) Append(messages ...llm.Completion) { + session.seq = append(session.seq, messages...) +} + +// Generate a response from a user prompt (with attachments and other options) +func (session *session) FromUser(ctx context.Context, prompt string, opts ...llm.Opt) error { + // Append the user prompt + message, err := session.factory.UserPrompt(prompt, opts...) + if err != nil { + return err + } else { + return session.chat(ctx, message) + } +} + +// Generate a response from a tool, passing the results from the tool call +func (session *session) FromTool(ctx context.Context, results ...llm.ToolResult) error { + // Append the tool results + if results, err := session.factory.ToolResults(results...); err != nil { + return err + } else { + return session.chat(ctx, results...) + } +} + +func (session *session) chat(ctx context.Context, messages ...llm.Completion) error { + // Append the messages to the chat + session.Append(messages...) + + // Generate the completion, and append the first choice of the completion + // TODO: Use Opts to select which completion choice to use + completion, err := session.model.Chat(ctx, session.seq, session.opts...) + if err != nil { + return err + } else if completion.Num() == 0 { + return llm.ErrBadParameter.With("No completion choices returned") + } + + // Append the first choice + session.Append(completion.Choice(0)) + + // Success + return nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - COMPLETION + +// Return the number of completions +func (session *session) Num() int { + if len(session.seq) == 0 { + return 0 + } + return session.seq[len(session.seq)-1].Num() +} + +// Return the current session role +func (session *session) Role() string { + if len(session.seq) == 0 { + return "" + } + return session.seq[len(session.seq)-1].Role() +} + +// Return the current session choice +func (session *session) Choice(index int) llm.Completion { + if len(session.seq) == 0 { + return nil + } + return session.seq[len(session.seq)-1].Choice(index) +} + +// Return the text for the last completion +func (session *session) Text(index int) string { + if len(session.seq) == 0 { + return "" + } + return session.seq[len(session.seq)-1].Text(index) +} + +// Return audio for the last completion +func (session *session) Audio(index int) *llm.Attachment { + if len(session.seq) == 0 { + return nil + } + return session.seq[len(session.seq)-1].Audio(index) +} + +// Return the current session tool calls given the completion index. 
+// Will return nil if no tool calls were returned. +func (session *session) ToolCalls(index int) []llm.ToolCall { + if len(session.seq) == 0 { + return nil + } + return session.seq[len(session.seq)-1].ToolCalls(index) +} diff --git a/pkg/mistral/client.go b/pkg/mistral/client.go index f70f8bf..ac7a523 100644 --- a/pkg/mistral/client.go +++ b/pkg/mistral/client.go @@ -1,14 +1,14 @@ /* -mistral implements an API client for mistral (https://docs.mistral.ai/api/) +mistral implements an API client for mistral +https://docs.mistral.ai/api/ */ package mistral import ( - "context" - // Packages client "github.com/mutablelogic/go-client" llm "github.com/mutablelogic/go-llm" + impl "github.com/mutablelogic/go-llm/pkg/internal/impl" ) /////////////////////////////////////////////////////////////////////////////// @@ -16,7 +16,7 @@ import ( type Client struct { *client.Client - cache map[string]llm.Model + *impl.ModelCache } var _ llm.Agent = (*Client)(nil) @@ -46,7 +46,7 @@ func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) { } // Return the client - return &Client{client, nil}, nil + return &Client{client, impl.NewModelCache()}, nil } /////////////////////////////////////////////////////////////////////////////// @@ -56,36 +56,3 @@ func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) { func (Client) Name() string { return defaultName } - -// Return the models -func (c *Client) Models(ctx context.Context) ([]llm.Model, error) { - // Cache models - if c.cache == nil { - models, err := c.ListModels(ctx) - if err != nil { - return nil, err - } - c.cache = make(map[string]llm.Model, len(models)) - for _, model := range models { - c.cache[model.Name()] = model - } - } - - // Return models - result := make([]llm.Model, 0, len(c.cache)) - for _, model := range c.cache { - result = append(result, model) - } - return result, nil -} - -// Return a model by name, or nil if not found. -// Panics on error. 
-func (c *Client) Model(ctx context.Context, name string) llm.Model { - if c.cache == nil { - if _, err := c.Models(ctx); err != nil { - panic(err) - } - } - return c.cache[name] -} diff --git a/pkg/mistral/completion.go b/pkg/mistral/completion.go index b0e4bb0..75aab09 100644 --- a/pkg/mistral/completion.go +++ b/pkg/mistral/completion.go @@ -5,6 +5,7 @@ import ( "encoding/json" "strings" + // Packages "github.com/mutablelogic/go-client" "github.com/mutablelogic/go-llm" ) @@ -19,7 +20,18 @@ type Response struct { Created uint64 `json:"created"` Model string `json:"model"` Completions `json:"choices"` - Metrics `json:"usage,omitempty"` + *Metrics `json:"usage,omitempty"` +} + +// Possible completions +type Completions []Completion + +// Completion Variation +type Completion struct { + Index uint64 `json:"index"` + Message *Message `json:"message"` + Delta *Message `json:"delta,omitempty"` // For streaming + Reason string `json:"finish_reason,omitempty"` } // Metrics @@ -42,73 +54,98 @@ func (r Response) String() string { return string(data) } +func (c Completion) String() string { + data, err := json.MarshalIndent(c, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +func (m Metrics) String() string { + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS type reqChatCompletion struct { - Model string `json:"model"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - MaxTokens uint64 `json:"max_tokens,omitempty"` - Stream bool `json:"stream,omitempty"` - StopSequences []string `json:"stop,omitempty"` - Seed uint64 `json:"random_seed,omitempty"` - Messages []*Message `json:"messages"` - Format any `json:"response_format,omitempty"` - Tools []llm.Tool `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - NumChoices uint64 `json:"n,omitempty"` - Prediction *Content `json:"prediction,omitempty"` - SafePrompt bool `json:"safe_prompt,omitempty"` -} - -func (mistral *Client) ChatCompletion(ctx context.Context, context llm.Context, opts ...llm.Opt) (*Response, error) { + Model string `json:"model"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens uint64 `json:"max_tokens,omitempty"` + Stream bool `json:"stream,omitempty"` + StopSequences []string `json:"stop,omitempty"` + Seed uint64 `json:"random_seed,omitempty"` + Format any `json:"response_format,omitempty"` + Tools []llm.Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + NumCompletions uint64 `json:"n,omitempty"` + Prediction *Content `json:"prediction,omitempty"` + SafePrompt bool `json:"safe_prompt,omitempty"` + Messages []llm.Completion `json:"messages"` +} + +// Send a completion request with a single prompt, and return the next completion +func (model *model) Completion(ctx context.Context, prompt string, opts ...llm.Opt) (llm.Completion, error) { + message, err := messagefactory{}.UserPrompt(prompt, opts...) + if err != nil { + return nil, err + } + return model.Chat(ctx, []llm.Completion{message}, opts...) 
+}
+
+// Send a completion request with multiple completions, and return the next completion
+func (model *model) Chat(ctx context.Context, completions []llm.Completion, opts ...llm.Opt) (llm.Completion, error) {
 	// Apply options
 	opt, err := llm.ApplyOpts(opts...)
 	if err != nil {
 		return nil, err
 	}
 
-	// Append the system prompt at the beginning
-	messages := make([]*Message, 0, len(context.(*session).seq)+1)
+	// Create the completions including the system prompt
+	messages := make([]llm.Completion, 0, len(completions)+1)
 	if system := opt.SystemPrompt(); system != "" {
-		messages = append(messages, systemPrompt(system))
-	}
-
-	// Always append the first message of each completion
-	for _, message := range context.(*session).seq {
-		messages = append(messages, message)
+		messages = append(messages, messagefactory{}.SystemPrompt(system))
 	}
+	messages = append(messages, completions...)
 
 	// Request
 	req, err := client.NewJSONRequest(reqChatCompletion{
-		Model:            context.(*session).model.Name(),
+		Model:            model.Name(),
 		Temperature:      optTemperature(opt),
 		TopP:             optTopP(opt),
 		MaxTokens:        optMaxTokens(opt),
 		Stream:           optStream(opt),
 		StopSequences:    optStopSequences(opt),
 		Seed:             optSeed(opt),
-		Messages:         messages,
 		Format:           optFormat(opt),
-		Tools:            optTools(mistral, opt),
+		Tools:            optTools(model.Client, opt),
 		ToolChoice:       optToolChoice(opt),
 		PresencePenalty:  optPresencePenalty(opt),
 		FrequencyPenalty: optFrequencyPenalty(opt),
-		NumChoices:       optNumCompletions(opt),
+		NumCompletions:   optNumCompletions(opt),
 		Prediction:       optPrediction(opt),
 		SafePrompt:       optSafePrompt(opt),
+		Messages:         messages,
 	})
 	if err != nil {
 		return nil, err
 	}
 
+	// Response options
 	var response Response
 	reqopts := []client.RequestOpt{
 		client.OptPath("chat", "completions"),
 	}
+
+	// Streaming
 	if optStream(opt) {
 		reqopts = append(reqopts, client.OptTextStreamCallback(func(evt client.TextStreamEvent) error {
 			if err := streamEvent(&response, evt); err != nil {
@@ -122,7 +159,7 @@ func (mistral *Client) ChatCompletion(ctx context.Context, context llm.Context,
 	}
 
 	// Response
-	if err := mistral.DoWithContext(ctx, req, &response, reqopts...); err != nil {
+	if err := model.DoWithContext(ctx, req, &response, reqopts...); err != nil {
 		return nil, err
 	}
 
@@ -131,7 +168,7 @@ func (mistral *Client) ChatCompletion(ctx context.Context, context llm.Context,
 }
 
 ///////////////////////////////////////////////////////////////////////////////
-// PRIVATE METHODS
+// PRIVATE METHODS - STREAMING
 
 func streamEvent(response *Response, evt client.TextStreamEvent) error {
 	var delta Response
@@ -147,28 +184,32 @@ func streamEvent(response *Response, evt client.TextStreamEvent) error {
 	if delta.Id != "" {
 		response.Id = delta.Id
 	}
+	if delta.Type != "" {
+		response.Type = delta.Type
+	}
 	if delta.Created != 0 {
 		response.Created = delta.Created
 	}
 	if delta.Model != "" {
 		response.Model = delta.Model
 	}
+
+	// Append the delta to the response
 	for _, completion := range delta.Completions {
-		appendCompletion(response, &completion)
-	}
-	if delta.Metrics.InputTokens > 0 {
-		response.Metrics.InputTokens += delta.Metrics.InputTokens
-	}
-	if delta.Metrics.OutputTokens > 0 {
-		response.Metrics.OutputTokens += delta.Metrics.OutputTokens
+		if err := appendCompletion(response, &completion); err != nil {
+			return err
+		}
 	}
-	if delta.Metrics.TotalTokens > 0 {
-		response.Metrics.TotalTokens += delta.Metrics.TotalTokens
+
+	// Append the metrics to the response
+	if delta.Metrics != nil {
+		response.Metrics = delta.Metrics
 	}
 	return nil
 }
 
-func appendCompletion(response *Response, c *Completion) {
+func appendCompletion(response *Response, c *Completion) error { + // Append a new completion for { if c.Index < uint64(len(response.Completions)) { break @@ -183,18 +224,111 @@ func appendCompletion(response *Response, c *Completion) { }, }) } - // Add the completion delta + + // Add the reason if c.Reason != "" { response.Completions[c.Index].Reason = c.Reason } + + // Get the completion + message := response.Completions[c.Index].Message + if message == nil { + return llm.ErrBadParameter + } + + // Add the role if role := c.Delta.Role(); role != "" { - response.Completions[c.Index].Message.RoleContent.Role = role + message.RoleContent.Role = role + } + + // We only allow deltas which are strings at the moment + if c.Delta.Content != nil { + if str, ok := c.Delta.Content.(string); ok { + if text, ok := message.Content.(string); ok { + message.Content = text + str + } else { + message.Content = str + } + } else { + return llm.ErrNotImplemented.Withf("appendCompletion not implemented: %T", c.Delta.Content) + } } - // TODO: We only allow deltas which are strings at the moment... - if str, ok := c.Delta.Content.(string); ok && str != "" { - if text, ok := response.Completions[c.Index].Message.Content.(string); ok { - response.Completions[c.Index].Message.Content = text + str + // Append tool calls + for i := range c.Delta.Calls { + if i >= len(message.Calls) { + message.Calls = append(message.Calls, toolcall{}) + } + } + + for i, call := range c.Delta.Calls { + if call.meta.Id != "" { + message.Calls[i].meta.Id = call.meta.Id + } + if call.meta.Index != 0 { + message.Calls[i].meta.Index = call.meta.Index + } + if call.meta.Type != "" { + message.Calls[i].meta.Type = call.meta.Type } + if call.meta.Function.Name != "" { + message.Calls[i].meta.Function.Name = call.meta.Function.Name + } + if call.meta.Function.Arguments != "" { + message.Calls[i].meta.Function.Arguments += call.meta.Function.Arguments + } + } + + // Return success + return nil +} + +/////////////////////////////////////////////////////////////////////////////// +// COMPLETIONS + +// Return the number of completions +func (c Completions) Num() int { + return len(c) +} + +// Return message for a specific completion +func (c Completions) Choice(index int) llm.Completion { + if index < 0 || index >= len(c) { + return nil + } + return c[index].Message +} + +// Return the role of the completion +func (c Completions) Role() string { + // The role should be the same for all completions, let's use the first one + if len(c) == 0 { + return "" + } + return c[0].Message.Role() +} + +// Return the text content for a specific completion +func (c Completions) Text(index int) string { + if index < 0 || index >= len(c) { + return "" + } + return c[index].Message.Text(0) +} + +// Return audio content for a specific completion +func (c Completions) Audio(index int) *llm.Attachment { + if index < 0 || index >= len(c) { + return nil + } + return c[index].Message.Audio(0) +} + +// Return the current session tool calls given the completion index. +// Will return nil if no tool calls were returned. 
+func (c Completions) ToolCalls(index int) []llm.ToolCall { + if index < 0 || index >= len(c) { + return nil } + return c[index].Message.ToolCalls(0) } diff --git a/pkg/mistral/completion_test.go b/pkg/mistral/completion_test.go index 899d364..d5fd83c 100644 --- a/pkg/mistral/completion_test.go +++ b/pkg/mistral/completion_test.go @@ -4,37 +4,110 @@ import ( "context" "fmt" "os" - "strings" "testing" // Packages llm "github.com/mutablelogic/go-llm" - mistral "github.com/mutablelogic/go-llm/pkg/mistral" tool "github.com/mutablelogic/go-llm/pkg/tool" assert "github.com/stretchr/testify/assert" ) -func Test_chat_001(t *testing.T) { +func Test_completion_001(t *testing.T) { assert := assert.New(t) model := client.Model(context.TODO(), "mistral-small-latest") + if !assert.NotNil(model) { + t.FailNow() + } - if assert.NotNil(model) { - response, err := client.ChatCompletion(context.TODO(), model.UserPrompt("Hello, how are you?")) - assert.NoError(err) + response, err := model.Completion(context.TODO(), "Hello, how are you?") + if assert.NoError(err) { assert.NotEmpty(response) t.Log(response) } } +func Test_completion_004(t *testing.T) { + assert := assert.New(t) + + model := client.Model(context.TODO(), "mistral-small-latest") + if !assert.NotNil(model) { + t.FailNow() + } + + // Test tool support + t.Run("Toolkit", func(t *testing.T) { + toolkit := tool.NewToolKit() + toolkit.Register(&weather{}) + session := model.Context(llm.WithToolKit(toolkit)) + + assert.NoError(session.FromUser(context.TODO(), "What is the weather in the capital city of Germany?")) + + assert.Equal("assistant", session.Role()) + assert.Equal(1, session.Num()) + assert.NotEmpty(session.ToolCalls(0)) + + results, err := toolkit.Run(context.TODO(), session.ToolCalls(0)...) + assert.NoError(err) + assert.NotEmpty(results) + + assert.NoError(session.FromTool(context.TODO(), results...)) + + }) +} + +type weather struct { + City string `json:"city" help:"The city to get the weather for" required:"true"` +} + +func (weather) Name() string { + return "weather_in_city" +} + +func (weather) Description() string { + return "Get the weather for a city" +} + +func (w weather) Run(ctx context.Context) (any, error) { + return fmt.Sprintf("The weather in %q is sunny and warm", w.City), nil +} + +func Test_completion_005(t *testing.T) { + assert := assert.New(t) + model := client.Model(context.TODO(), "pixtral-12b-2409") + if !assert.NotNil(model) { + t.FailNow() + } + + // Test image captioning + t.Run("ImageCaption", func(t *testing.T) { + f, err := os.Open("../../etc/testdata/guggenheim.jpg") + if !assert.NoError(err) { + t.FailNow() + } + defer f.Close() + + r, err := model.Completion( + context.TODO(), + "Describe this picture", + llm.WithAttachment(f), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + } + }) +} + +/* func Test_chat_002(t *testing.T) { assert := assert.New(t) - model := client.Model(context.TODO(), "mistral-large-latest") + model := client.Model(context.TODO(), "mistral-small-latest") if !assert.NotNil(model) { t.FailNow() } t.Run("Temperature", func(t *testing.T) { - r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithTemperature(0.5)) + r, err := client.Completion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithTemperature(0.5)) if assert.NoError(err) { assert.Equal("assistant", r.Role()) assert.Equal(1, r.Num()) @@ -236,3 +309,4 @@ func (weather) Description() string { func 
(w weather) Run(ctx context.Context) (any, error) {
 	return fmt.Sprintf("The weather in %q is sunny and warm", w.City), nil
 }
+*/
diff --git a/pkg/mistral/content.go b/pkg/mistral/content.go
new file mode 100644
index 0000000..994ddd0
--- /dev/null
+++ b/pkg/mistral/content.go
@@ -0,0 +1,47 @@
+package mistral
+
+import (
+	"net/url"
+
+	"github.com/mutablelogic/go-llm"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+type Content struct {
+	Type        string `json:"type"`               // text or content
+	*Text       `json:"text,omitempty"`            // text content
+	*Prediction `json:"content,omitempty"`         // prediction
+	*Image      `json:"image_url,omitempty"`       // image_url
+}
+
+// text content
+type Text string
+
+// prediction content
+type Prediction string
+
+// either a URL or "data:image/png;base64," followed by the base64 encoded image
+type Image string
+
+///////////////////////////////////////////////////////////////////////////////
+// LIFECYCLE
+
+func NewPrediction(content Prediction) *Content {
+	return &Content{Type: "content", Prediction: &content}
+}
+
+func NewTextContext(text Text) *Content {
+	return &Content{Type: "text", Text: &text}
+}
+
+func NewImageData(image *llm.Attachment) *Content {
+	url := Image(image.Url())
+	return &Content{Type: "image_url", Image: &url}
+}
+
+func NewImageUrl(u *url.URL) *Content {
+	url := Image(u.String())
+	return &Content{Type: "image_url", Image: &url}
+}
diff --git a/pkg/mistral/message.go b/pkg/mistral/message.go
index 6300b9e..8af898a 100644
--- a/pkg/mistral/message.go
+++ b/pkg/mistral/message.go
@@ -1,175 +1,72 @@
 package mistral
 
 import (
-	"encoding/json"
-
 	// Packages
 	"github.com/mutablelogic/go-llm"
-	"github.com/mutablelogic/go-llm/pkg/tool"
 )
 
 ///////////////////////////////////////////////////////////////////////////////
 // TYPES
 
-// Possible completions
-type Completions []Completion
-
-var _ llm.Completion = Completions{}
-
 // Message with text or object content
 type Message struct {
 	RoleContent
-	ToolCallArray `json:"tool_calls,omitempty"`
+	Calls ToolCalls `json:"tool_calls,omitempty"`
 }
 
 type RoleContent struct {
 	Role    string `json:"role,omitempty"`    // assistant, user, tool, system
 	Content any    `json:"content,omitempty"` // string or array of text, reference, image_url
-	Id      string `json:"tool_call_id,omitempty"` // tool call - when role is tool
 	Name    string `json:"name,omitempty"`    // function name - when role is tool
-}
-
-// Completion Variation
-type Completion struct {
-	Index   uint64   `json:"index"`
-	Message *Message `json:"message"`
-	Delta   *Message `json:"delta,omitempty"` // For streaming
-	Reason  string   `json:"finish_reason,omitempty"`
+	Id      string `json:"tool_call_id,omitempty"` // tool call - when role is tool
 }
 
 var _ llm.Completion = (*Message)(nil)
 
-type Content struct {
-	Type        string `json:"type,omitempty"`      // text, reference, image_url
-	*Text       `json:"text,omitempty"`             // text content
-	*Prediction `json:"content,omitempty"`          // prediction
-	*Image      `json:"image_url,omitempty"`        // image_url
-}
-
-// A set of tool calls
-type ToolCallArray []ToolCall
-
-// text content
-type Text string
-
-// text content
-type Prediction string
-
-// either a URL or "data:image/png;base64," followed by the base64 encoded image
-type Image string
-
-///////////////////////////////////////////////////////////////////////////////
-// LIFECYCLE
-
-// Return a Content object with text content (either in "text" or "prediction" field)
-func NewContent(t, v, p string) *Content {
-	content := new(Content)
-	content.Type = t
-	if v
!= "" { - content.Text = (*Text)(&v) - } - if p != "" { - content.Prediction = (*Prediction)(&p) - } - return content -} - -// Return a Content object with text content -func NewTextContent(v string) *Content { - return NewContent("text", v, "") -} - -// Return an image attachment -func NewImageAttachment(a *llm.Attachment) *Content { - content := new(Content) - image := a.Url() - content.Type = "image_url" - content.Image = (*Image)(&image) - return content -} - /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS - MESSAGE -func (m Message) Num() int { +func (Message) Num() int { return 1 } -func (m Message) Role() string { - return m.RoleContent.Role +func (message *Message) Role() string { + return message.RoleContent.Role +} + +// Return the completion +func (message *Message) Choice(index int) llm.Completion { + if index != 0 { + return nil + } + return message } -func (m Message) Text(index int) string { +func (message *Message) Text(index int) string { if index != 0 { return "" } // If content is text, return it - if text, ok := m.Content.(string); ok { + if text, ok := message.Content.(string); ok { return text } // For other kinds, return empty string for the moment return "" } -func (m Message) ToolCalls(index int) []llm.ToolCall { - if index != 0 { - return nil - } - - // Make the tool calls - calls := make([]llm.ToolCall, 0, len(m.ToolCallArray)) - for _, call := range m.ToolCallArray { - var args map[string]any - if call.Function.Arguments != "" { - if err := json.Unmarshal([]byte(call.Function.Arguments), &args); err != nil { - return nil - } - } - calls = append(calls, tool.NewCall(call.Id, call.Function.Name, args)) - } - - // Return success - return calls -} - -/////////////////////////////////////////////////////////////////////////////// -// PUBLIC METHODS - COMPLETIONS - -// Return the number of completions -func (c Completions) Num() int { - return len(c) +// Unsupported +func (message *Message) Audio(index int) *llm.Attachment { + return nil } -// Return message for a specific completion -func (c Completions) Message(index int) *Message { - if index < 0 || index >= len(c) { +// Return all the tool calls +func (message *Message) ToolCalls(index int) []llm.ToolCall { + if index != 0 { return nil } - return c[index].Message -} - -// Return the role of the completion -func (c Completions) Role() string { - // The role should be the same for all completions, let's use the first one - if len(c) == 0 { - return "" + calls := make([]llm.ToolCall, 0, len(message.Calls)) + for _, call := range message.Calls { + calls = append(calls, call) } - return c[0].Message.Role() -} - -// Return the text content for a specific completion -func (c Completions) Text(index int) string { - if index < 0 || index >= len(c) { - return "" - } - return c[index].Message.Text(0) -} - -// Return the current session tool calls given the completion index. -// Will return nil if no tool calls were returned. 
-func (c Completions) ToolCalls(index int) []llm.ToolCall { - if index < 0 || index >= len(c) { - return nil - } - return c[index].Message.ToolCalls(0) + return calls } diff --git a/pkg/mistral/messagefactory.go b/pkg/mistral/messagefactory.go new file mode 100644 index 0000000..07b8f10 --- /dev/null +++ b/pkg/mistral/messagefactory.go @@ -0,0 +1,78 @@ +package mistral + +import ( + "encoding/json" + + // Packages + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type messagefactory struct{} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - MESSAGE FACTORY + +func (messagefactory) SystemPrompt(prompt string) llm.Completion { + return &Message{ + RoleContent: RoleContent{ + Role: "system", + Content: prompt, + }, + } +} + +func (messagefactory) UserPrompt(prompt string, opts ...llm.Opt) (llm.Completion, error) { + // Get attachments + opt, err := llm.ApplyPromptOpts(opts...) + if err != nil { + return nil, err + } + + // Get attachments, allocate content + attachments := opt.Attachments() + content := make([]*Content, 1, len(attachments)+1) + + // Append the text and the attachments + content[0] = NewTextContext(Text(prompt)) + for _, attachment := range attachments { + content = append(content, NewImageData(attachment)) + } + + // Return success + return &Message{ + RoleContent: RoleContent{ + Role: "user", + Content: content, + }, + }, nil +} + +func (messagefactory) ToolResults(results ...llm.ToolResult) ([]llm.Completion, error) { + // Check for no results + if len(results) == 0 { + return nil, llm.ErrBadParameter.Withf("No tool results") + } + + // Create results + messages := make([]llm.Completion, 0, len(results)) + for _, result := range results { + value, err := json.Marshal(result.Value()) + if err != nil { + return nil, err + } + messages = append(messages, &Message{ + RoleContent: RoleContent{ + Role: "tool", + Name: result.Call().Name(), + Content: string(value), + Id: result.Call().Id(), + }, + }) + } + + // Return success + return messages, nil +} diff --git a/pkg/mistral/model.go b/pkg/mistral/model.go index ab7374b..06d8320 100644 --- a/pkg/mistral/model.go +++ b/pkg/mistral/model.go @@ -4,8 +4,10 @@ import ( "context" "encoding/json" - "github.com/mutablelogic/go-client" - "github.com/mutablelogic/go-llm" + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" + impl "github.com/mutablelogic/go-llm/pkg/internal/impl" ) /////////////////////////////////////////////////////////////////////////////// @@ -52,33 +54,94 @@ func (m model) String() string { return string(data) } +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - llm.Agent + +// Return the models +func (mistral *Client) Models(ctx context.Context) ([]llm.Model, error) { + return mistral.ModelCache.Load(func() ([]llm.Model, error) { + return mistral.loadmodels(ctx) + }) +} + +// Return a model by name, or nil if not found. +// Panics on error. 
+func (mistral *Client) Model(ctx context.Context, name string) llm.Model { + model, err := mistral.ModelCache.Get(func() ([]llm.Model, error) { + return mistral.loadmodels(ctx) + }, name) + if err != nil { + panic(err) + } + return model +} + +// Function called to load models +func (mistral *Client) loadmodels(ctx context.Context) ([]llm.Model, error) { + if models, err := mistral.ListModels(ctx); err != nil { + return nil, err + } else { + result := make([]llm.Model, len(models)) + for i, meta := range models { + result[i] = &model{mistral, meta} + } + return result, nil + } +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - llm.Model + +// Return model name +func (model model) Name() string { + return model.meta.Name +} + +// Return model aliases +func (model model) Aliases() []string { + return model.meta.Aliases +} + +// Return model description +func (model model) Description() string { + return model.meta.Description +} + +// Return a new empty session +func (model *model) Context(opts ...llm.Opt) llm.Context { + return impl.NewSession(model, &messagefactory{}, opts...) +} + /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS - API // ListModels returns all the models -func (c *Client) ListModels(ctx context.Context) ([]llm.Model, error) { +func (mistral *Client) ListModels(ctx context.Context) ([]Model, error) { // Response var response struct { Data []Model `json:"data"` } - if err := c.DoWithContext(ctx, nil, &response, client.OptPath("models")); err != nil { + if err := mistral.DoWithContext(ctx, nil, &response, client.OptPath("models")); err != nil { return nil, err } - // Make models - result := make([]llm.Model, 0, len(response.Data)) - for _, meta := range response.Data { - result = append(result, &model{c, meta}) + // Return success + return response.Data, nil +} + +// GetModel returns one model +func (mistral *Client) GetModel(ctx context.Context, model string) (*Model, error) { + // Return the response + var response Model + if err := mistral.DoWithContext(ctx, nil, &response, client.OptPath("models", model)); err != nil { + return nil, err } - // Return models - return result, nil + // Return success + return &response, nil } -/////////////////////////////////////////////////////////////////////////////// -// PUBLIC METHODS - MODEL - -// Return the name of the model -func (m model) Name() string { - return m.meta.Name +// Delete a fine-tuned model +func (mistral *Client) DeleteModel(ctx context.Context, model string) error { + return mistral.DoWithContext(ctx, client.MethodDelete, nil, client.OptPath("models", model)) } diff --git a/pkg/mistral/opt.go b/pkg/mistral/opt.go index 71e82da..ceb4e74 100644 --- a/pkg/mistral/opt.go +++ b/pkg/mistral/opt.go @@ -3,19 +3,10 @@ package mistral import ( "strings" + // Packages "github.com/mutablelogic/go-llm" ) -/////////////////////////////////////////////////////////////////////////////// -// PUBLIC METHODS - -func WithPrediction(v string) llm.Opt { - return func(o *llm.Opts) error { - o.Set("prediction", v) - return nil - } -} - /////////////////////////////////////////////////////////////////////////////// // PRIVATE METHODS @@ -112,7 +103,7 @@ func optPrediction(opts *llm.Opts) *Content { if prediction == "" { return nil } - return NewContent("content", "", prediction) + return NewPrediction(Prediction(prediction)) } func optSafePrompt(opts *llm.Opts) bool { diff --git a/pkg/mistral/session.go b/pkg/mistral/session.go deleted 
file mode 100644 index 3b50539..0000000 --- a/pkg/mistral/session.go +++ /dev/null @@ -1,219 +0,0 @@ -package mistral - -import ( - "context" - "encoding/json" - - // Packages - llm "github.com/mutablelogic/go-llm" -) - -////////////////////////////////////////////////////////////////// -// TYPES - -type session struct { - model *model // The model used for the session - opts []llm.Opt // Options to apply to the session - seq []*Message // Sequence of messages -} - -var _ llm.Context = (*session)(nil) - -/////////////////////////////////////////////////////////////////////////////// -// LIFECYCLE - -// Return an empty session context object for the model, setting session options -func (model *model) Context(opts ...llm.Opt) llm.Context { - return &session{ - model: model, - opts: opts, - seq: make([]*Message, 0, 10), - } -} - -// Convenience method to create a session context object with a user prompt, which -// panics on error -func (model *model) UserPrompt(prompt string, opts ...llm.Opt) llm.Context { - context := model.Context(opts...) - - // Create a user prompt - message, err := userPrompt(prompt, opts...) - if err != nil { - panic(err) - } - - // Add to the sequence - context.(*session).seq = append(context.(*session).seq, message) - - // Return success - return context -} - -/////////////////////////////////////////////////////////////////////////////// -// STRINGIFY - -func (session session) String() string { - var data []byte - var err error - if len(session.seq) == 1 { - data, err = json.MarshalIndent(session.seq[0], "", " ") - } else { - data, err = json.MarshalIndent(session.seq, "", " ") - } - if err != nil { - return err.Error() - } - return string(data) -} - -////////////////////////////////////////////////////////////////// -// PUBLIC METHODS - -// Return the number of completions -func (session *session) Num() int { - if len(session.seq) == 0 { - return 0 - } - return 1 -} - -// Return the role of the last message -func (session *session) Role() string { - if len(session.seq) == 0 { - return "" - } - return session.seq[len(session.seq)-1].Role() -} - -// Return the text of the last message -func (session *session) Text(index int) string { - if len(session.seq) == 0 { - return "" - } - return session.seq[len(session.seq)-1].Text(index) -} - -// Return tool calls for the last message -func (session *session) ToolCalls(index int) []llm.ToolCall { - if len(session.seq) == 0 { - return nil - } - return session.seq[len(session.seq)-1].ToolCalls(index) -} - -// Generate a response from a user prompt (with attachments and -// other options) -func (session *session) FromUser(ctx context.Context, prompt string, opts ...llm.Opt) error { - message, err := userPrompt(prompt, opts...) - if err != nil { - return err - } - - // Append the user prompt to the sequence - session.seq = append(session.seq, message) - - // The options come from the session options and the user options - chatopts := make([]llm.Opt, 0, len(session.opts)+len(opts)) - chatopts = append(chatopts, session.opts...) - chatopts = append(chatopts, opts...) - - // Call the 'chat completion' method - r, err := session.model.ChatCompletion(ctx, session, chatopts...) 
- if err != nil { - return err - } - - // Append the first message from the set of completions - session.seq = append(session.seq, r.Completions.Message(0)) - - // Return success - return nil -} - -// Generate a response from a tool, passing the results from the tool call -func (session *session) FromTool(ctx context.Context, results ...llm.ToolResult) error { - messages, err := toolResults(results...) - if err != nil { - return err - } - - // Append the tool results to the sequence - session.seq = append(session.seq, messages...) - - // Call the 'chat' method - r, err := session.model.ChatCompletion(ctx, session, session.opts...) - if err != nil { - return err - } - - // Append the first message from the set of completions - session.seq = append(session.seq, r.Completions.Message(0)) - - // Return success - return nil -} - -/////////////////////////////////////////////////////////////////////////////// -// PRIVATE METHODS - -func systemPrompt(prompt string) *Message { - return &Message{ - RoleContent: RoleContent{ - Role: "system", - Content: prompt, - }, - } -} - -func userPrompt(prompt string, opts ...llm.Opt) (*Message, error) { - // Get attachments - opt, err := llm.ApplyPromptOpts(opts...) - if err != nil { - return nil, err - } - - // Get attachments, allocate content - attachments := opt.Attachments() - content := make([]*Content, 1, len(attachments)+1) - - // Append the text and the attachments - content[0] = NewTextContent(prompt) - for _, attachment := range attachments { - content = append(content, NewImageAttachment(attachment)) - } - - // Return success - return &Message{ - RoleContent: RoleContent{ - Role: "user", - Content: content, - }, - }, nil -} - -func toolResults(results ...llm.ToolResult) ([]*Message, error) { - // Check for no results - if len(results) == 0 { - return nil, llm.ErrBadParameter.Withf("No tool results") - } - - // Create results - messages := make([]*Message, 0, len(results)) - for _, result := range results { - value, err := json.Marshal(result.Value()) - if err != nil { - return nil, err - } - messages = append(messages, &Message{ - RoleContent: RoleContent{ - Role: "tool", - Id: result.Call().Id(), - Name: result.Call().Name(), - Content: string(value), - }, - }) - } - - // Return success - return messages, nil -} diff --git a/pkg/mistral/tool.go b/pkg/mistral/tool.go index 255146e..0eecbf9 100644 --- a/pkg/mistral/tool.go +++ b/pkg/mistral/tool.go @@ -7,8 +7,11 @@ import ( /////////////////////////////////////////////////////////////////////////////// // TYPES +type ToolCalls []toolcall + type ToolCall struct { Id string `json:"id,omitempty"` // tool id + Type string `json:"type,omitempty"` // tool type (function) Index uint64 `json:"index,omitempty"` // tool index Function struct { Name string `json:"name,omitempty"` // tool name @@ -23,6 +26,10 @@ type toolcall struct { /////////////////////////////////////////////////////////////////////////////// // STRINGIFY +func (t *toolcall) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &t.meta) +} + func (t toolcall) MarshalJSON() ([]byte, error) { return json.Marshal(t.meta) } @@ -34,3 +41,21 @@ func (t toolcall) String() string { } return string(data) } + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// The tool name +func (t toolcall) Name() string { + return t.meta.Function.Name +} + +// The tool identifier +func (t toolcall) Id() string { + return t.meta.Id +} + +// Decode the calling parameters +func (t toolcall) Decode(v 
any) error { + return json.Unmarshal([]byte(t.meta.Function.Arguments), v) +} diff --git a/pkg/ollama/chat.go b/pkg/ollama/chat.go deleted file mode 100644 index 175e0b5..0000000 --- a/pkg/ollama/chat.go +++ /dev/null @@ -1,148 +0,0 @@ -package ollama - -import ( - "context" - "encoding/json" - "time" - - // Packages - client "github.com/mutablelogic/go-client" - llm "github.com/mutablelogic/go-llm" -) - -/////////////////////////////////////////////////////////////////////////////// -// TYPES - -// Chat Completion Response -type Response struct { - Model string `json:"model"` - CreatedAt time.Time `json:"created_at"` - Done bool `json:"done"` - Reason string `json:"done_reason,omitempty"` - Message `json:"message"` - Metrics -} - -// Metrics -type Metrics struct { - TotalDuration time.Duration `json:"total_duration,omitempty"` - LoadDuration time.Duration `json:"load_duration,omitempty"` - PromptEvalCount int `json:"prompt_eval_count,omitempty"` - PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` - EvalCount int `json:"eval_count,omitempty"` - EvalDuration time.Duration `json:"eval_duration,omitempty"` -} - -var _ llm.Completion = (*Response)(nil) - -/////////////////////////////////////////////////////////////////////////////// -// STRINGIFY - -func (r Response) String() string { - data, err := json.MarshalIndent(r, "", " ") - if err != nil { - return err.Error() - } - return string(data) -} - -/////////////////////////////////////////////////////////////////////////////// -// PUBLIC METHODS - -type reqChat struct { - Model string `json:"model"` - Messages []*Message `json:"messages"` - Tools []llm.Tool `json:"tools,omitempty"` - Format string `json:"format,omitempty"` - Options map[string]interface{} `json:"options,omitempty"` - Stream bool `json:"stream"` - KeepAlive *time.Duration `json:"keep_alive,omitempty"` -} - -func (ollama *Client) Chat(ctx context.Context, context llm.Context, opts ...llm.Opt) (*Response, error) { - // Apply options - opt, err := llm.ApplyOpts(opts...) 
- if err != nil { - return nil, err - } - - // Append the system prompt at the beginning - messages := make([]*Message, 0, len(context.(*session).seq)+1) - if system := opt.SystemPrompt(); system != "" { - messages = append(messages, systemPrompt(system)) - } - - // Always append the first message of each completion - for _, message := range context.(*session).seq { - messages = append(messages, message) - } - - // Request - req, err := client.NewJSONRequest(reqChat{ - Model: context.(*session).model.Name(), - Messages: messages, - Tools: optTools(ollama, opt), - Format: optFormat(opt), - Options: optOptions(opt), - Stream: optStream(ollama, opt), - KeepAlive: optKeepAlive(opt), - }) - if err != nil { - return nil, err - } - - // Response - var response, delta Response - reqopts := []client.RequestOpt{ - client.OptPath("chat"), - } - if optStream(ollama, opt) { - reqopts = append(reqopts, client.OptJsonStreamCallback(func(v any) error { - if v, ok := v.(*Response); !ok || v == nil { - return llm.ErrConflict.Withf("Invalid stream response: %v", v) - } else if err := streamEvent(&response, v); err != nil { - return err - } - if fn := opt.StreamFn(); fn != nil { - fn(&response) - } - return nil - })) - } - - // Response - if err := ollama.DoWithContext(ctx, req, &delta, reqopts...); err != nil { - return nil, err - } - - // Return success - if optStream(ollama, opt) { - return &response, nil - } else { - return &delta, nil - } -} - -/////////////////////////////////////////////////////////////////////////////// -// PRIVATE METHODS - -func streamEvent(response, delta *Response) error { - if delta.Model != "" { - response.Model = delta.Model - } - if !delta.CreatedAt.IsZero() { - response.CreatedAt = delta.CreatedAt - } - if delta.Message.RoleContent.Role != "" { - response.Message.RoleContent.Role = delta.Message.RoleContent.Role - } - if delta.Message.RoleContent.Content != "" { - response.Message.RoleContent.Content += delta.Message.RoleContent.Content - } - if delta.Done { - response.Done = delta.Done - response.Metrics = delta.Metrics - response.Reason = delta.Reason - } - return nil -} diff --git a/pkg/ollama/chat_test.go b/pkg/ollama/chat_test.go deleted file mode 100644 index b0ea189..0000000 --- a/pkg/ollama/chat_test.go +++ /dev/null @@ -1,189 +0,0 @@ -package ollama_test - -import ( - "context" - "fmt" - "os" - "strings" - "testing" - - // Packages - - llm "github.com/mutablelogic/go-llm" - ollama "github.com/mutablelogic/go-llm/pkg/ollama" - tool "github.com/mutablelogic/go-llm/pkg/tool" - assert "github.com/stretchr/testify/assert" -) - -func Test_chat_001(t *testing.T) { - // Pull the model - model, err := client.PullModel(context.TODO(), "qwen:0.5b", ollama.WithPullStatus(func(status *ollama.PullStatus) { - t.Log(status) - })) - if err != nil { - t.FailNow() - } - - t.Run("Temperature", func(t *testing.T) { - assert := assert.New(t) - response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithTemperature(0.5)) - if !assert.NoError(err) { - t.FailNow() - } - t.Log(response) - }) - - t.Run("TopP", func(t *testing.T) { - assert := assert.New(t) - response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithTopP(0.5)) - if !assert.NoError(err) { - t.FailNow() - } - t.Log(response) - }) - t.Run("TopK", func(t *testing.T) { - assert := assert.New(t) - response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithTopK(50)) - if !assert.NoError(err) { - t.FailNow() - } - t.Log(response) 
- }) - - t.Run("Stream", func(t *testing.T) { - assert := assert.New(t) - response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithStream(func(stream llm.Completion) { - t.Log(stream) - })) - if !assert.NoError(err) { - t.FailNow() - } - t.Log(response) - }) - - t.Run("Stop", func(t *testing.T) { - assert := assert.New(t) - response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithStopSequence("sky")) - if !assert.NoError(err) { - t.FailNow() - } - t.Log(response) - }) - - t.Run("System", func(t *testing.T) { - assert := assert.New(t) - response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithSystemPrompt("reply as if you are shakespeare")) - if !assert.NoError(err) { - t.FailNow() - } - t.Log(response) - }) - - t.Run("Seed", func(t *testing.T) { - assert := assert.New(t) - response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithSeed(1234)) - if !assert.NoError(err) { - t.FailNow() - } - t.Log(response) - }) - - t.Run("Format", func(t *testing.T) { - assert := assert.New(t) - response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue? Reply in JSON format"), llm.WithFormat("json")) - if !assert.NoError(err) { - t.FailNow() - } - t.Log(response) - }) - - t.Run("PresencePenalty", func(t *testing.T) { - assert := assert.New(t) - response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?t"), llm.WithPresencePenalty(-1.0)) - if !assert.NoError(err) { - t.FailNow() - } - t.Log(response) - }) - - t.Run("FrequencyPenalty", func(t *testing.T) { - assert := assert.New(t) - response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?t"), llm.WithFrequencyPenalty(1.0)) - if !assert.NoError(err) { - t.FailNow() - } - t.Log(response) - }) -} - -func Test_chat_002(t *testing.T) { - assert := assert.New(t) - model, err := client.PullModel(context.TODO(), "llava:7b") - if !assert.NoError(err) { - t.FailNow() - } - assert.NotNil(model) - - f, err := os.Open("testdata/guggenheim.jpg") - if !assert.NoError(err) { - t.FailNow() - } - defer f.Close() - - // Describe an image - r, err := client.Chat(context.TODO(), model.UserPrompt("Provide a short caption for this image", llm.WithAttachment(f))) - if assert.NoError(err) { - assert.Equal("assistant", r.Role()) - assert.Equal(1, r.Num()) - assert.NotEmpty(r.Text(0)) - t.Log(r.Text(0)) - } -} - -func Test_chat_003(t *testing.T) { - assert := assert.New(t) - model, err := client.PullModel(context.TODO(), "llama3.2") - if !assert.NoError(err) { - t.FailNow() - } - assert.NotNil(model) - - toolkit := tool.NewToolKit() - toolkit.Register(&weather{}) - - // Get the weather for a city - r, err := client.Chat(context.TODO(), model.UserPrompt("What is the weather in the capital city of germany?"), llm.WithToolKit(toolkit)) - if assert.NoError(err) { - assert.Equal("assistant", r.Role()) - assert.Equal(1, r.Num()) - - calls := r.ToolCalls(0) - assert.NotEmpty(calls) - - var w weather - assert.NoError(calls[0].Decode(&w)) - assert.Equal("berlin", strings.ToLower(w.City)) - } -} - -type weather struct { - City string `json:"city" help:"The city to get the weather for"` -} - -func (weather) Name() string { - return "weather_in_city" -} - -func (weather) Description() string { - return "Get the weather for a city" -} - -func (w weather) Run(ctx context.Context) (any, error) { - var result struct { - City string `json:"city"` - Weather string 
`json:"weather"` - } - result.City = w.City - result.Weather = fmt.Sprintf("The weather in %q is sunny and warm", w.City) - return result, nil -} diff --git a/pkg/ollama/client.go b/pkg/ollama/client.go index 56d9c62..33b0b8b 100644 --- a/pkg/ollama/client.go +++ b/pkg/ollama/client.go @@ -1,7 +1,10 @@ +/* +ollama implements an API client for ollama +https://github.com/ollama/ollama/blob/main/docs/api.md +*/ package ollama import ( - // Packages client "github.com/mutablelogic/go-client" llm "github.com/mutablelogic/go-llm" @@ -14,7 +17,6 @@ type Client struct { *client.Client } -// Ensure it satisfies the agent.Agent interface var _ llm.Agent = (*Client)(nil) /////////////////////////////////////////////////////////////////////////////// diff --git a/pkg/ollama/client_test.go b/pkg/ollama/client_test.go index e1987c2..a7f3452 100644 --- a/pkg/ollama/client_test.go +++ b/pkg/ollama/client_test.go @@ -6,6 +6,7 @@ import ( "os" "strconv" "testing" + "time" // Packages opts "github.com/mutablelogic/go-client" @@ -40,7 +41,7 @@ func TestMain(m *testing.M) { // Create client var err error - client, err = ollama.New(endpoint_url, opts.OptTrace(os.Stderr, verbose)) + client, err = ollama.New(endpoint_url, opts.OptTrace(os.Stderr, verbose), opts.OptTimeout(5*time.Minute)) if err != nil { log.Println(err) os.Exit(-1) diff --git a/pkg/ollama/completion.go b/pkg/ollama/completion.go new file mode 100644 index 0000000..6d13a9e --- /dev/null +++ b/pkg/ollama/completion.go @@ -0,0 +1,242 @@ +package ollama + +import ( + "context" + "encoding/json" + "strings" + "time" + + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// Chat Response +type Response struct { + Model string `json:"model"` + CreatedAt time.Time `json:"created_at"` + Done bool `json:"done"` + Reason string `json:"done_reason,omitempty"` + Response *string `json:"response,omitempty"` // For completion + Message `json:"message"` // For chat + Metrics +} + +var _ llm.Completion = (*Response)(nil) + +// Metrics +type Metrics struct { + TotalDuration time.Duration `json:"total_duration,omitempty"` + LoadDuration time.Duration `json:"load_duration,omitempty"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + EvalDuration time.Duration `json:"eval_duration,omitempty"` +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (r Response) String() string { + data, err := json.MarshalIndent(r, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// https://github.com/ollama/ollama/blob/main/api/types.go +type reqCompletion struct { + // Model name + Model string `json:"model"` + + // Prompt is the textual prompt to send to the model. + Prompt string `json:"prompt"` + + // Suffix is the text that comes after the inserted text. + Suffix string `json:"suffix,omitempty"` + + // System overrides the model's default system message/prompt. + System string `json:"system,omitempty"` + + // Template overrides the model's default prompt template. + Template string `json:"template,omitempty"` + + // Stream specifies whether the response is streaming; it is true by default. 
+ Stream *bool `json:"stream,omitempty"` + + // Raw set to true means that no formatting will be applied to the prompt. + Raw bool `json:"raw,omitempty"` + + // Format specifies the format to return a response in. + Format json.RawMessage `json:"format,omitempty"` + + // KeepAlive controls how long the model will stay loaded in memory following + // this request. + KeepAlive *time.Duration `json:"keep_alive,omitempty"` + + // Images is an optional list of base64-encoded images accompanying this + // request, for multimodal models. + Images []ImageData `json:"images,omitempty"` + + // Options lists model-specific options. For example, temperature can be + // set through this field, if the model supports it. + Options map[string]any `json:"options,omitempty"` +} + +// Create a completion from a prompt +func (model *model) Completion(ctx context.Context, prompt string, opts ...llm.Opt) (llm.Completion, error) { + // Apply options - including prompt options + opt, err := llm.ApplyPromptOpts(opts...) + if err != nil { + return nil, err + } + + // Make images + images := make([]ImageData, 0, len(opt.Attachments())) + for _, attachment := range opt.Attachments() { + if !strings.HasPrefix(attachment.Type(), "image/") { + return nil, llm.ErrBadParameter.Withf("Attachment is not an image: %v", attachment.Filename()) + } + images = append(images, attachment.Data()) + } + + // Request + req, err := client.NewJSONRequest(reqCompletion{ + Model: model.Name(), + Prompt: prompt, + System: opt.SystemPrompt(), + Stream: optStream(model.Client, opt), + Format: json.RawMessage(optFormat(opt)), + KeepAlive: optKeepAlive(opt), + Images: images, + Options: optOptions(opt), + }) + if err != nil { + return nil, err + } + + // Make the request + return model.request(ctx, req, opt.StreamFn(), client.OptPath("generate")) +} + +type reqChat struct { + Model string `json:"model"` + Messages []llm.Completion `json:"messages"` + Tools []llm.Tool `json:"tools,omitempty"` + Format string `json:"format,omitempty"` + Options map[string]any `json:"options,omitempty"` + Stream *bool `json:"stream"` + KeepAlive *time.Duration `json:"keep_alive,omitempty"` +} + +// Create a completion from a chat session +func (model *model) Chat(ctx context.Context, completions []llm.Completion, opts ...llm.Opt) (llm.Completion, error) { + // Apply options + opt, err := llm.ApplyOpts(opts...) + if err != nil { + return nil, err + } + + // Create the completions including the system prompt + messages := make([]llm.Completion, 0, len(completions)+1) + if system := opt.SystemPrompt(); system != "" { + messages = append(messages, messagefactory{}.SystemPrompt(system)) + } + messages = append(messages, completions...) 
+ + // Request + req, err := client.NewJSONRequest(reqChat{ + Model: model.Name(), + Messages: messages, + Tools: optTools(model.Client, opt), + Format: optFormat(opt), + Options: optOptions(opt), + Stream: optStream(model.Client, opt), + KeepAlive: optKeepAlive(opt), + }) + if err != nil { + return nil, err + } + + // Make the request + return model.request(ctx, req, opt.StreamFn(), client.OptPath("chat")) +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func (model *model) request(ctx context.Context, req client.Payload, streamfn func(llm.Completion), opts ...client.RequestOpt) (*Response, error) { + var delta, response Response + if streamfn != nil { + opts = append(opts, client.OptJsonStreamCallback(func(v any) error { + if v, ok := v.(*Response); !ok || v == nil { + return llm.ErrConflict.Withf("Invalid stream response: %v", delta) + } else if err := streamEvent(&response, v); err != nil { + return err + } + if fn := streamfn; fn != nil { + fn(&response) + } + return nil + })) + } + + // Response + if err := model.DoWithContext(ctx, req, &delta, opts...); err != nil { + return nil, err + } + + // Return success + if streamfn != nil { + return &response, nil + } else if delta.Response != nil { + delta.Message = Message{ + RoleContent: RoleContent{ + Role: "user", + Content: *delta.Response, + }, + } + } + return &delta, nil +} + +func streamEvent(response, delta *Response) error { + // Completion instead of chat + if delta.Response != nil { + delta.Message = Message{ + RoleContent: RoleContent{ + Role: "user", + Content: *delta.Response, + }, + } + } + + // Update response from the delta + if delta.Model != "" { + response.Model = delta.Model + } + if !delta.CreatedAt.IsZero() { + response.CreatedAt = delta.CreatedAt + } + if delta.Message.RoleContent.Role != "" { + response.Message.RoleContent.Role = delta.Message.RoleContent.Role + } + if delta.Message.RoleContent.Content != "" { + response.Message.RoleContent.Content += delta.Message.RoleContent.Content + } + if delta.Done { + response.Done = delta.Done + response.Metrics = delta.Metrics + response.Reason = delta.Reason + } + + // Return success + return nil +} diff --git a/pkg/ollama/completion_test.go b/pkg/ollama/completion_test.go new file mode 100644 index 0000000..0ae884d --- /dev/null +++ b/pkg/ollama/completion_test.go @@ -0,0 +1,187 @@ +package ollama_test + +import ( + "context" + "fmt" + "os" + "testing" + + // Packages + + llm "github.com/mutablelogic/go-llm" + ollama "github.com/mutablelogic/go-llm/pkg/ollama" + tool "github.com/mutablelogic/go-llm/pkg/tool" + assert "github.com/stretchr/testify/assert" +) + +func Test_completion_001(t *testing.T) { + assert := assert.New(t) + + // Pull the model + model, err := client.PullModel(context.TODO(), "qwen:0.5b", ollama.WithPullStatus(func(status *ollama.PullStatus) { + t.Log(status) + })) + if err != nil { + t.FailNow() + } + + // Get a completion + response, err := model.Completion(context.TODO(), "Hello, how are you?") + if assert.NoError(err) { + assert.NotEmpty(response) + } +} + +func Test_completion_002(t *testing.T) { + assert := assert.New(t) + + // Pull the model + model, err := client.PullModel(context.TODO(), "qwen:0.5b", ollama.WithPullStatus(func(status *ollama.PullStatus) { + t.Log(status) + })) + if err != nil { + t.FailNow() + } + + t.Run("Temperature", func(t *testing.T) { + response, err := model.Completion(context.TODO(), "Tell me in less than five words why the sky is blue?", 
llm.WithTemperature(0.5)) + if assert.NoError(err) { + t.Log(response) + } + }) + + t.Run("TopP", func(t *testing.T) { + response, err := model.Completion(context.TODO(), "Tell me in less than five words why the sky is blue?", llm.WithTopP(0.5)) + if assert.NoError(err) { + t.Log(response) + } + }) + + t.Run("TopK", func(t *testing.T) { + response, err := model.Completion(context.TODO(), "Tell me in less than five words why the sky is blue?", llm.WithTopK(50)) + if assert.NoError(err) { + t.Log(response) + } + }) + + t.Run("Stop", func(t *testing.T) { + response, err := model.Completion(context.TODO(), "Tell me in less than five words why the sky is blue?", llm.WithStopSequence("sky")) + if assert.NoError(err) { + t.Log(response) + } + }) + + t.Run("System", func(t *testing.T) { + response, err := model.Completion(context.TODO(), "Tell me in less than five words why the sky is blue?", llm.WithSystemPrompt("reply as if you are shakespeare")) + if assert.NoError(err) { + t.Log(response) + } + }) + + t.Run("Seed", func(t *testing.T) { + response, err := model.Completion(context.TODO(), "Tell me in less than five words why the sky is blue?", llm.WithSeed(123)) + if assert.NoError(err) { + t.Log(response) + } + }) + + t.Run("Format", func(t *testing.T) { + response, err := model.Completion(context.TODO(), "Why the sky is blue? Reply in JSON format", llm.WithFormat("json")) + if assert.NoError(err) { + t.Log(response) + } + }) + + t.Run("FrequencyPenalty", func(t *testing.T) { + response, err := model.Completion(context.TODO(), "Why the sky is blue?", llm.WithFrequencyPenalty(1.0)) + if assert.NoError(err) { + t.Log(response) + } + }) +} + +func Test_completion_003(t *testing.T) { + assert := assert.New(t) + + // Pull the model + model, err := client.PullModel(context.TODO(), "llama3.2-vision", ollama.WithPullStatus(func(status *ollama.PullStatus) { + t.Log(status) + })) + if err != nil { + t.FailNow() + } + + t.Run("Vision", func(t *testing.T) { + f, err := os.Open("testdata/guggenheim.jpg") + if !assert.NoError(err) { + t.FailNow() + } + defer f.Close() + response, err := model.Completion(context.TODO(), "Describe this image", llm.WithAttachment(f)) + if assert.NoError(err) { + t.Log(response) + } + }) +} + +func Test_completion_004(t *testing.T) { + assert := assert.New(t) + + // Pull the model + model, err := client.PullModel(context.TODO(), "mistral", ollama.WithPullStatus(func(status *ollama.PullStatus) { + t.Log(status) + })) + if err != nil { + t.FailNow() + } + + // Test tool support + t.Run("Toolkit", func(t *testing.T) { + toolkit := tool.NewToolKit() + toolkit.Register(&weather{}) + + session := model.Context(llm.WithToolKit(toolkit)) + err := session.FromUser(context.TODO(), + "What is the weather in the capital city of Germany?", + ) + if !assert.NoError(err) { + t.FailNow() + } + + assert.Equal("assistant", session.Role()) + assert.Greater(session.Num(), 0) + assert.NotEmpty(session.ToolCalls(0)) + + toolcalls := session.ToolCalls(0) + assert.NotEmpty(toolcalls) + assert.Equal("weather_in_city", toolcalls[0].Name()) + + results, err := toolkit.Run(context.TODO(), toolcalls...) + if !assert.NoError(err) { + t.FailNow() + } + + assert.Len(results, len(toolcalls)) + + err = session.FromTool(context.TODO(), results...) 
+ if !assert.NoError(err) { + t.FailNow() + } + }) +} + +type weather struct { + City string `json:"city" help:"The city to get the weather for" required:"true"` +} + +func (weather) Name() string { + return "weather_in_city" +} + +func (weather) Description() string { + return "Get the weather for a city" +} + +func (w weather) Run(ctx context.Context) (any, error) { + return fmt.Sprintf("The weather in %q is sunny and warm", w.City), nil +} diff --git a/pkg/ollama/doc.go b/pkg/ollama/doc.go deleted file mode 100644 index c652fb9..0000000 --- a/pkg/ollama/doc.go +++ /dev/null @@ -1,5 +0,0 @@ -/* -ollama implements an API client for ollama -https://github.com/ollama/ollama/blob/main/docs/api.md -*/ -package ollama diff --git a/pkg/ollama/embedding.go b/pkg/ollama/embedding.go index ceae604..235ee15 100644 --- a/pkg/ollama/embedding.go +++ b/pkg/ollama/embedding.go @@ -90,6 +90,13 @@ func (ollama *Client) GenerateEmbedding(ctx context.Context, name string, prompt } // Embedding vector generation -func (model *model) Embedding(context.Context, string, ...llm.Opt) ([]float64, error) { - return nil, llm.ErrNotImplemented +func (model *model) Embedding(ctx context.Context, prompt string, opts ...llm.Opt) ([]float64, error) { + embedding, err := model.GenerateEmbedding(ctx, model.Name(), []string{prompt}, opts...) + if err != nil { + return nil, err + } + if len(embedding.Embeddings) > 0 { + return embedding.Embeddings[0], nil + } + return nil, llm.ErrNotFound.With("no embeddings returned") } diff --git a/pkg/ollama/message.go b/pkg/ollama/message.go index d00c6b8..4ef8432 100644 --- a/pkg/ollama/message.go +++ b/pkg/ollama/message.go @@ -1,6 +1,7 @@ package ollama import ( + "encoding/json" "fmt" // Packages @@ -11,21 +12,28 @@ import ( /////////////////////////////////////////////////////////////////////////////// // TYPES +type messagefactory struct{} + // Message with text or object content type Message struct { RoleContent - ToolCallArray `json:"tool_calls,omitempty"` + Images []ImageData `json:"images,omitempty"` + Calls ToolCalls `json:"tool_calls,omitempty"` + *ToolResults } +var _ llm.Completion = (*Message)(nil) + type RoleContent struct { Role string `json:"role,omitempty"` // assistant, user, tool, system Content string `json:"content,omitempty"` // string or array of text, reference, image_url - Images []Data `json:"images,omitempty"` // Image attachments - ToolResult } -// A set of tool calls -type ToolCallArray []ToolCall +type ToolCalls []ToolCall + +type ToolResults struct { + Name string `json:"name,omitempty"` // function name - when role is tool +} type ToolCall struct { Type string `json:"type"` // function @@ -39,24 +47,94 @@ type ToolCallFunction struct { } // Data represents the raw binary data of an image file. -type Data []byte +type ImageData []byte -// ToolResult -type ToolResult struct { - Name string `json:"name,omitempty"` // function name - when role is tool +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - MESSAGE FACTORY + +func (messagefactory) SystemPrompt(prompt string) llm.Completion { + return &Message{ + RoleContent: RoleContent{ + Role: "system", + Content: prompt, + }, + } +} + +func (messagefactory) UserPrompt(prompt string, opts ...llm.Opt) (llm.Completion, error) { + // Get attachments + opt, err := llm.ApplyPromptOpts(opts...) 
+ if err != nil { + return nil, err + } + + // Append image attachments + attachments := opt.Attachments() + images := make([]ImageData, 0, len(attachments)) + for _, attachment := range attachments { + images = append(images, attachment.Data()) + } + + // Return success + return &Message{ + RoleContent: RoleContent{ + Role: "user", + Content: prompt, + }, + Images: images, + }, nil +} + +func (messagefactory) ToolResults(results ...llm.ToolResult) ([]llm.Completion, error) { + // Check for no results + if len(results) == 0 { + return nil, llm.ErrBadParameter.Withf("No tool results") + } + + // Create results + messages := make([]llm.Completion, 0, len(results)) + for _, result := range results { + value, err := json.Marshal(result.Value()) + if err != nil { + return nil, err + } + messages = append(messages, &Message{ + RoleContent: RoleContent{ + Role: "tool", + Content: string(value), + }, + ToolResults: &ToolResults{ + Name: result.Call().Name(), + }, + }) + } + + // Return success + return messages, nil } /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS - MESSAGE +// Return the number of completions func (m Message) Num() int { return 1 } +// Return the current session role func (m Message) Role() string { return m.RoleContent.Role } +// Return the completion +func (message *Message) Choice(index int) llm.Completion { + if index != 0 { + return nil + } + return message +} + +// Return the text func (m Message) Text(index int) string { if index != 0 { return "" @@ -64,14 +142,21 @@ func (m Message) Text(index int) string { return m.Content } +// Return the audio - not supported on ollama +func (message *Message) Audio(index int) *llm.Attachment { + return nil +} + +// Return the current session tool calls given the completion index. +// Will return nil if no tool calls were returned. 
func (m Message) ToolCalls(index int) []llm.ToolCall { if index != 0 { return nil } // Make the tool calls - calls := make([]llm.ToolCall, 0, len(m.ToolCallArray)) - for _, call := range m.ToolCallArray { + calls := make([]llm.ToolCall, 0, len(m.Calls)) + for _, call := range m.Calls { calls = append(calls, tool.NewCall(fmt.Sprint(call.Function.Index), call.Function.Name, call.Function.Arguments)) } diff --git a/pkg/ollama/model.go b/pkg/ollama/model.go index 94fb7d4..9e62452 100644 --- a/pkg/ollama/model.go +++ b/pkg/ollama/model.go @@ -4,11 +4,13 @@ import ( "context" "encoding/json" "net/http" + "strings" "time" // Packages client "github.com/mutablelogic/go-client" llm "github.com/mutablelogic/go-llm" + impl "github.com/mutablelogic/go-llm/pkg/internal/impl" ) /////////////////////////////////////////////////////////////////////////////// @@ -81,29 +83,51 @@ func (m PullStatus) String() string { } /////////////////////////////////////////////////////////////////////////////// -// INTERFACE IMPLEMENTATION +// PUBLIC METHODS - llm.Model implementation func (m model) Name() string { return m.ModelMeta.Name } -/////////////////////////////////////////////////////////////////////////////// -// PUBLIC METHODS +// Return model name +func (model) Aliases() []string { + return nil +} + +// Return model description +func (model model) Description() string { + return strings.Join(model.ModelMeta.Details.Families, ", ") +} // Agent interface func (ollama *Client) Models(ctx context.Context) ([]llm.Model, error) { + // We don't explicitly cache models return ollama.ListModels(ctx) } -// Agent interface +// Return the a model by name func (ollama *Client) Model(ctx context.Context, name string) llm.Model { model, err := ollama.GetModel(ctx, name) if err != nil { panic(err) } + + // In the ollama version, we attempt to load the model into + // memory here, so that we can use it immediately + ollama.LoadModel(ctx, name) + + // Return the model return model } +// Return a new empty session +func (model *model) Context(opts ...llm.Opt) llm.Context { + return impl.NewSession(model, &messagefactory{}, opts...) +} + +/////////////////////////////////////////////////////////////////////////////// +// API CALLS + // List models func (ollama *Client) ListModels(ctx context.Context) ([]llm.Model, error) { type respListModel struct { diff --git a/pkg/ollama/opt.go b/pkg/ollama/opt.go index 2b1afd4..769da41 100644 --- a/pkg/ollama/opt.go +++ b/pkg/ollama/opt.go @@ -1,6 +1,7 @@ package ollama import ( + "strings" "time" // Packages @@ -19,9 +20,9 @@ func WithInsecure() llm.Opt { } // Embeddings: Does not truncate the end of each input to fit within context length. Returns error if context length is exceeded. 
-func WithTruncate(v bool) llm.Opt { +func WithTruncate() llm.Opt { return func(o *llm.Opts) error { - o.Set("truncate", v) + o.Set("truncate", true) return nil } } @@ -88,7 +89,14 @@ func optTools(agent *Client, opts *llm.Opts) []llm.Tool { } func optFormat(opts *llm.Opts) string { - return opts.GetString("format") + format := strings.ToLower(opts.GetString("format")) + if format == "" { + return "" + } + if format == "json_format" { + return "json" + } + return format } func optStopSequence(opts *llm.Opts) []string { @@ -135,15 +143,24 @@ func optOptions(opts *llm.Opts) map[string]any { return result } -func optStream(agent *Client, opts *llm.Opts) bool { +func optStream(agent *Client, opts *llm.Opts) *bool { + var stream bool + + // Based on stream function + if opts.StreamFn() != nil { + stream = true + } + // Streaming only if there is a stream function and no tools toolkit := opts.ToolKit() if toolkit != nil { if tools := toolkit.Tools(agent); len(tools) > 0 { - return false + stream = false } } - return opts.StreamFn() != nil + + // Return the value + return &stream } func optKeepAlive(opts *llm.Opts) *time.Duration { diff --git a/pkg/ollama/session.go b/pkg/ollama/session.go deleted file mode 100644 index 50c702c..0000000 --- a/pkg/ollama/session.go +++ /dev/null @@ -1,221 +0,0 @@ -package ollama - -import ( - "context" - "encoding/json" - - // Packages - llm "github.com/mutablelogic/go-llm" -) - -/////////////////////////////////////////////////////////////////////////////// -// TYPES - -// Implementation of a message session, which is a sequence of messages -type session struct { - model *model // The model used for the session - opts []llm.Opt // Options to apply to the session - seq []*Message // Sequence of messages -} - -var _ llm.Context = (*session)(nil) - -/////////////////////////////////////////////////////////////////////////////// -// LIFECYCLE - -// Return an empty session context object for the model, setting session options -func (model *model) Context(opts ...llm.Opt) llm.Context { - return &session{ - model: model, - opts: opts, - seq: make([]*Message, 0, 10), - } -} - -// Convenience method to create a session context object with a user prompt, which -// panics on error -func (model *model) UserPrompt(prompt string, opts ...llm.Opt) llm.Context { - context := model.Context(opts...) - - // Create a user prompt - message, err := userPrompt(prompt, opts...) 
- if err != nil { - panic(err) - } - - // Add to the sequence - context.(*session).seq = append(context.(*session).seq, message) - - // Return success - return context -} - -/////////////////////////////////////////////////////////////////////////////// -// STRINGIFY - -func (session session) String() string { - var data []byte - var err error - if len(session.seq) == 1 { - data, err = json.MarshalIndent(session.seq[0], "", " ") - } else { - data, err = json.MarshalIndent(session.seq, "", " ") - } - if err != nil { - return err.Error() - } - return string(data) -} - -////////////////////////////////////////////////////////////////// -// PUBLIC METHODS - -// Return the number of completions -func (session *session) Num() int { - if len(session.seq) == 0 { - return 0 - } - return 1 -} - -// Return the role of the last message -func (session *session) Role() string { - if len(session.seq) == 0 { - return "" - } - return session.seq[len(session.seq)-1].Role() -} - -// Return the text of the last message -func (session *session) Text(index int) string { - if len(session.seq) == 0 { - return "" - } - return session.seq[len(session.seq)-1].Text(index) -} - -// Return tool calls for the last message -func (session *session) ToolCalls(index int) []llm.ToolCall { - if len(session.seq) == 0 { - return nil - } - return session.seq[len(session.seq)-1].ToolCalls(index) -} - -/////////////////////////////////////////////////////////////////////////////// -// PUBLIC METHODS - -// Generate a response from a user prompt (with attachments) -func (session *session) FromUser(ctx context.Context, prompt string, opts ...llm.Opt) error { - message, err := userPrompt(prompt, opts...) - if err != nil { - return err - } - - // Append the user prompt to the sequence - session.seq = append(session.seq, message) - - // The options come from the session options and the user options - chatopts := make([]llm.Opt, 0, len(session.opts)+len(opts)) - chatopts = append(chatopts, session.opts...) - chatopts = append(chatopts, opts...) - - // Call the 'chat' method - r, err := session.model.Chat(ctx, session, chatopts...) - if err != nil { - return err - } - - // Append the message to the sequence - session.seq = append(session.seq, &r.Message) - - // Return success - return nil -} - -// Generate a response from a tool calling result -func (session *session) FromTool(ctx context.Context, results ...llm.ToolResult) error { - messages, err := toolResults(results...) - if err != nil { - return err - } - - // Append the tool results to the sequence - session.seq = append(session.seq, messages...) - - // Call the 'chat' method - r, err := session.model.Chat(ctx, session, session.opts...) - if err != nil { - return err - } - - // Append the first message from the set of completions - session.seq = append(session.seq, &r.Message) - - // Return success - return nil -} - -/////////////////////////////////////////////////////////////////////////////// -// PRIVATE METHODS - -func systemPrompt(prompt string) *Message { - return &Message{ - RoleContent: RoleContent{ - Role: "system", - Content: prompt, - }, - } -} - -func userPrompt(prompt string, opts ...llm.Opt) (*Message, error) { - // Get attachments - opt, err := llm.ApplyPromptOpts(opts...) 
- if err != nil { - return nil, err - } - - // Get attachments, allocate content - attachments := opt.Attachments() - data := make([]Data, 0, len(attachments)) - for _, attachment := range attachments { - data = append(data, attachment.Data()) - } - - // Return success - return &Message{ - RoleContent: RoleContent{ - Role: "user", - Content: prompt, - Images: data, - }, - }, nil -} - -func toolResults(results ...llm.ToolResult) ([]*Message, error) { - // Check for no results - if len(results) == 0 { - return nil, llm.ErrBadParameter.Withf("No tool results") - } - - // Create results - messages := make([]*Message, 0, len(results)) - for _, result := range results { - value, err := json.Marshal(result.Value()) - if err != nil { - return nil, err - } - messages = append(messages, &Message{ - RoleContent: RoleContent{ - Role: "tool", - ToolResult: ToolResult{ - Name: result.Call().Name(), - }, - Content: string(value), - }, - }) - } - - // Return success - return messages, nil -} diff --git a/pkg/ollama/session_test.go b/pkg/ollama/session_test.go deleted file mode 100644 index e343eff..0000000 --- a/pkg/ollama/session_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package ollama_test - -import ( - "context" - "testing" - - // Packages - llm "github.com/mutablelogic/go-llm" - tool "github.com/mutablelogic/go-llm/pkg/tool" - assert "github.com/stretchr/testify/assert" -) - -func Test_session_001(t *testing.T) { - assert := assert.New(t) - model, err := client.PullModel(context.TODO(), "llama3.2") - if !assert.NoError(err) { - t.FailNow() - } - assert.NotNil(model) - - session := model.Context() - if assert.NotNil(session) { - err := session.FromUser(context.TODO(), "Hello, how are you?") - assert.NoError(err) - t.Log(session) - } -} - -func Test_session_002(t *testing.T) { - assert := assert.New(t) - model, err := client.PullModel(context.TODO(), "llama3.2") - if !assert.NoError(err) { - t.FailNow() - } - assert.NotNil(model) - - toolkit := tool.NewToolKit() - toolkit.Register(&weather{}) - - session := model.Context(llm.WithToolKit(toolkit)) - if !assert.NotNil(session) { - t.FailNow() - } - - assert.NoError(session.FromUser(context.TODO(), "What is the weather like in London today?")) - calls := session.ToolCalls(0) - if assert.Len(calls, 1) { - assert.Equal("weather_in_city", calls[0].Name()) - - result, err := toolkit.Run(context.TODO(), calls...) 
- assert.NoError(err) - assert.Len(result, 1) - - assert.NoError(session.FromTool(context.TODO(), result...)) - } - - t.Log(session) -} diff --git a/pkg/openai/client.go b/pkg/openai/client.go new file mode 100644 index 0000000..23e9086 --- /dev/null +++ b/pkg/openai/client.go @@ -0,0 +1,58 @@ +/* +openai implements an API client for OpenAI +https://platform.openai.com/docs/api-reference +*/ +package openai + +import ( + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" + impl "github.com/mutablelogic/go-llm/pkg/internal/impl" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Client struct { + *client.Client + *impl.ModelCache +} + +var _ llm.Agent = (*Client)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// GLOBALS + +const ( + endPoint = "https://api.openai.com/v1" + defaultName = "openai" +) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Create a new client +func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) { + // Create client + opts = append(opts, client.OptEndpoint(endPoint)) + opts = append(opts, client.OptReqToken(client.Token{ + Scheme: client.Bearer, + Value: ApiKey, + })) + client, err := client.New(opts...) + if err != nil { + return nil, err + } + + // Return the client + return &Client{client, impl.NewModelCache()}, nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return the name of the agent +func (*Client) Name() string { + return defaultName +} diff --git a/pkg/openai/client_test.go b/pkg/openai/client_test.go new file mode 100644 index 0000000..6b3b890 --- /dev/null +++ b/pkg/openai/client_test.go @@ -0,0 +1,58 @@ +package openai_test + +import ( + "flag" + "log" + "os" + "strconv" + "testing" + + // Packages + opts "github.com/mutablelogic/go-client" + openai "github.com/mutablelogic/go-llm/pkg/openai" + assert "github.com/stretchr/testify/assert" +) + +/////////////////////////////////////////////////////////////////////////////// +// TEST SET-UP + +var ( + client *openai.Client +) + +func TestMain(m *testing.M) { + var verbose bool + + // Verbose output + flag.Parse() + if f := flag.Lookup("test.v"); f != nil { + if v, err := strconv.ParseBool(f.Value.String()); err == nil { + verbose = v + } + } + + // API KEY + api_key := os.Getenv("OPENAI_API_KEY") + if api_key == "" { + log.Print("OPENAI_API_KEY not set") + os.Exit(0) + } + + // Create client + var err error + client, err = openai.New(api_key, opts.OptTrace(os.Stderr, verbose)) + if err != nil { + log.Println(err) + os.Exit(-1) + } + os.Exit(m.Run()) +} + +/////////////////////////////////////////////////////////////////////////////// +// TESTS + +func Test_client_001(t *testing.T) { + assert := assert.New(t) + assert.NotNil(client) + t.Log(client) +} diff --git a/pkg/openai/completion.go b/pkg/openai/completion.go new file mode 100644 index 0000000..a4f4315 --- /dev/null +++ b/pkg/openai/completion.go @@ -0,0 +1,381 @@ +package openai + +import ( + "context" + "encoding/json" + "strings" + + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// Completion Response +type Response struct { + Id string `json:"id"` + Type string `json:"object"` + Created uint64 `json:"created"` + Model string 
`json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + ServiceTier string `json:"service_tier"` + Completions `json:"choices"` + *Metrics `json:"usage,omitempty"` +} + +// Completion choices +type Completions []Completion + +// Completion Variation +type Completion struct { + Index uint64 `json:"index"` + Message *Message `json:"message"` + Delta *Message `json:"delta,omitempty"` // For streaming + Reason string `json:"finish_reason,omitempty"` +} + +// Metrics +type Metrics struct { + PromptTokens uint64 `json:"prompt_tokens,omitempty"` + CompletionTokens uint64 `json:"completion_tokens,omitempty"` + TotalTokens uint64 `json:"total_tokens,omitempty"` + PromptTokenDetails struct { + CachedTokens uint64 `json:"cached_tokens,omitempty"` + AudioTokens uint64 `json:"audio_tokens,omitempty"` + } `json:"prompt_tokens_details,omitempty"` + CompletionTokenDetails struct { + ReasoningTokens uint64 `json:"reasoning_tokens,omitempty"` + AcceptedPredictionTokens uint64 `json:"accepted_prediction_tokens,omitempty"` + RejectedPredictionTokens uint64 `json:"rejected_prediction_tokens,omitempty"` + } `json:"completion_tokens_details,omitempty"` +} + +var _ llm.Completion = (*Response)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (r Response) String() string { + data, err := json.MarshalIndent(r, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +func (c Completion) String() string { + data, err := json.MarshalIndent(c, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +func (m Metrics) String() string { + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +type reqCompletion struct { + Model string `json:"model"` + Store *bool `json:"store,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + LogitBias map[uint64]int64 `json:"logit_bias,omitempty"` + LogProbs bool `json:"logprobs,omitempty"` + TopLogProbs uint64 `json:"top_logprobs,omitempty"` + MaxTokens uint64 `json:"max_completion_tokens,omitempty"` + NumCompletions uint64 `json:"n,omitempty"` + Modalties []string `json:"modalities,omitempty"` + Prediction *Content `json:"prediction,omitempty"` + Audio *Audio `json:"audio,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + ResponseFormat *Format `json:"response_format,omitempty"` + Seed uint64 `json:"seed,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` + StopSequences []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Tools []llm.Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + User string `json:"user,omitempty"` + Messages []llm.Completion `json:"messages"` +} + +// Send a completion request with a single prompt, and return the next completion +func (model *model) Completion(ctx context.Context, prompt string, opts ...llm.Opt) (llm.Completion, error) { + message, err := messagefactory{}.UserPrompt(prompt, opts...) 
+ if err != nil { + return nil, err + } + return model.Chat(ctx, []llm.Completion{message}, opts...) +} + +// Send a completion request with multiple completions, and return the next completion +func (model *model) Chat(ctx context.Context, completions []llm.Completion, opts ...llm.Opt) (llm.Completion, error) { + // Apply options + opt, err := llm.ApplyOpts(opts...) + if err != nil { + return nil, err + } + + // Create the completions including the system prompt + messages := make([]llm.Completion, 0, len(completions)+1) + if system := opt.SystemPrompt(); system != "" { + messages = append(messages, messagefactory{}.SystemPrompt(system)) + } + messages = append(messages, completions...) + + // Request + req, err := client.NewJSONRequest(reqCompletion{ + Model: model.Name(), + Store: optStore(opt), + ReasoningEffort: optReasoningEffort(opt), + Metadata: optMetadata(opt), + FrequencyPenalty: optFrequencyPenalty(opt), + LogitBias: optLogitBias(opt), + LogProbs: optLogProbs(opt), + TopLogProbs: optTopLogProbs(opt), + MaxTokens: optMaxTokens(opt), + NumCompletions: optNumCompletions(opt), + Modalties: optModalities(opt), + Prediction: optPrediction(opt), + Audio: optAudio(opt), + PresencePenalty: optPresencePenalty(opt), + ResponseFormat: optResponseFormat(opt), + Seed: optSeed(opt), + ServiceTier: optServiceTier(opt), + StreamOptions: optStreamOptions(opt), + Temperature: optTemperature(opt), + TopP: optTopP(opt), + Stream: optStream(opt), + StopSequences: optStopSequences(opt), + Tools: optTools(model, opt), + ToolChoice: optToolChoice(opt), + ParallelToolCalls: optParallelToolCalls(opt), + User: optUser(opt), + Messages: messages, + }) + if err != nil { + return nil, err + } + + // Response options + var response Response + reqopts := []client.RequestOpt{ + client.OptPath("chat", "completions"), + } + + // Streaming + if optStream(opt) { + reqopts = append(reqopts, client.OptTextStreamCallback(func(evt client.TextStreamEvent) error { + if err := streamEvent(&response, evt); err != nil { + return err + } + if fn := opt.StreamFn(); fn != nil { + fn(&response) + } + return nil + })) + } + + // Response + if err := model.DoWithContext(ctx, req, &response, reqopts...); err != nil { + return nil, err + } + + // Return success + return &response, nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS - STREAMING + +func streamEvent(response *Response, evt client.TextStreamEvent) error { + var delta Response + // If we are done, ignore + if strings.TrimSpace(evt.Data) == "[DONE]" { + return nil + } + // Decode the event + if err := evt.Json(&delta); err != nil { + return err + } + // Append the delta to the response + if delta.Id != "" { + response.Id = delta.Id + } + if delta.Type != "" { + response.Type = delta.Type + } + if delta.Created != 0 { + response.Created = delta.Created + } + if delta.Model != "" { + response.Model = delta.Model + } + if delta.SystemFingerprint != "" { + response.SystemFingerprint = delta.SystemFingerprint + } + if delta.ServiceTier != "" { + response.ServiceTier = delta.ServiceTier + } + + // Append the delta to the response + for _, completion := range delta.Completions { + if err := appendCompletion(response, &completion); err != nil { + return err + } + } + + // Apend the metrics to the response + if delta.Metrics != nil { + response.Metrics = delta.Metrics + } + return nil +} + +func appendCompletion(response *Response, c *Completion) error { + // Append a new completion + for { + if c.Index < 
uint64(len(response.Completions)) { + break + } + response.Completions = append(response.Completions, Completion{ + Index: c.Index, + Message: &Message{ + RoleContent: RoleContent{ + Role: c.Delta.Role(), + Content: "", + }, + }, + }) + } + + // Add the reason + if c.Reason != "" { + response.Completions[c.Index].Reason = c.Reason + } + + // Get the completion + message := response.Completions[c.Index].Message + if message == nil { + return llm.ErrBadParameter + } + + // Add the role + if role := c.Delta.Role(); role != "" { + message.RoleContent.Role = role + } + + // We only allow deltas which are strings at the moment + if c.Delta.Content != nil { + if str, ok := c.Delta.Content.(string); ok { + if text, ok := message.Content.(string); ok { + message.Content = text + str + } else { + message.Content = str + } + } else { + return llm.ErrNotImplemented.Withf("appendCompletion not implemented: %T", c.Delta.Content) + } + } + + // Append audio data + if c.Delta.Media != nil { + if message.Media == nil { + message.Media = llm.NewAttachment() + } + message.Media.Append(c.Delta.Media) + } + + // Append tool calls + for i := range c.Delta.Calls { + if i >= len(message.Calls) { + message.Calls = append(message.Calls, toolcall{}) + } + } + + for i, call := range c.Delta.Calls { + if call.meta.Id != "" { + message.Calls[i].meta.Id = call.meta.Id + } + if call.meta.Index != 0 { + message.Calls[i].meta.Index = call.meta.Index + } + if call.meta.Type != "" { + message.Calls[i].meta.Type = call.meta.Type + } + if call.meta.Function.Name != "" { + message.Calls[i].meta.Function.Name = call.meta.Function.Name + } + if call.meta.Function.Arguments != "" { + message.Calls[i].meta.Function.Arguments += call.meta.Function.Arguments + } + } + + // Return success + return nil +} + +/////////////////////////////////////////////////////////////////////////////// +// COMPLETIONS + +// Return the number of completions +func (c Completions) Num() int { + return len(c) +} + +// Return message for a specific completion +func (c Completions) Choice(index int) llm.Completion { + if index < 0 || index >= len(c) { + return nil + } + return c[index].Message +} + +// Return the role of the completion +func (c Completions) Role() string { + // The role should be the same for all completions, let's use the first one + if len(c) == 0 { + return "" + } + return c[0].Message.Role() +} + +// Return the text content for a specific completion +func (c Completions) Text(index int) string { + if index < 0 || index >= len(c) { + return "" + } + return c[index].Message.Text(0) +} + +// Return audio content for a specific completion +func (c Completions) Audio(index int) *llm.Attachment { + if index < 0 || index >= len(c) { + return nil + } + return c[index].Message.Audio(0) +} + +// Return the current session tool calls given the completion index. +// Will return nil if no tool calls were returned. 
+func (c Completions) ToolCalls(index int) []llm.ToolCall { + if index < 0 || index >= len(c) { + return nil + } + return c[index].Message.ToolCalls(0) +} diff --git a/pkg/openai/completion_test.go b/pkg/openai/completion_test.go new file mode 100644 index 0000000..fc5a3c6 --- /dev/null +++ b/pkg/openai/completion_test.go @@ -0,0 +1,426 @@ +package openai_test + +import ( + "context" + "fmt" + "os" + "testing" + + llm "github.com/mutablelogic/go-llm" + openai "github.com/mutablelogic/go-llm/pkg/openai" + "github.com/mutablelogic/go-llm/pkg/tool" + assert "github.com/stretchr/testify/assert" +) + +func Test_completion_001(t *testing.T) { + assert := assert.New(t) + model := client.Model(context.TODO(), "gpt-4o-mini") + if !assert.NotNil(model) { + t.FailNow() + } + + response, err := model.Completion(context.TODO(), "Hello, how are you?") + if assert.NoError(err) { + assert.NotEmpty(response) + t.Log(response) + } +} + +func Test_completion_002(t *testing.T) { + assert := assert.New(t) + + // Test options + model := client.Model(context.TODO(), "gpt-4o-mini") + if !assert.NotNil(model) { + t.FailNow() + } + + o3_model := client.Model(context.TODO(), "o3-mini") + if !assert.NotNil(o3_model) { + t.FailNow() + } + + audio_model := client.Model(context.TODO(), "gpt-4o-audio-preview") + if !assert.NotNil(audio_model) { + t.FailNow() + } + + t.Run("Store", func(t *testing.T) { + r, err := model.Completion(context.TODO(), "What is the temperature in London?", openai.WithStore(true)) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("ReasoningEffort", func(t *testing.T) { + r, err := o3_model.Completion(context.TODO(), "What is the temperature in London?", openai.WithReasoningEffort("low")) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("Metadata", func(t *testing.T) { + r, err := model.Completion(context.TODO(), "What is the temperature in London?", openai.WithMetadata("a", "b")) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("FrequencyPenalty", func(t *testing.T) { + r, err := model.Completion(context.TODO(), "What is the temperature in London?", llm.WithFrequencyPenalty(-0.5)) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("LogitBias", func(t *testing.T) { + r, err := model.Completion(context.TODO(), "What is the temperature in London?", openai.WithLogitBias(56, 22)) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("LogProbs", func(t *testing.T) { + r, err := model.Completion(context.TODO(), "What is the temperature in London?", openai.WithLogProbs()) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("TopLogProbs", func(t *testing.T) { + r, err := model.Completion(context.TODO(), "What is the temperature in London?", openai.WithTopLogProbs(3)) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("MaxTokens", func(t *testing.T) { + r, err := model.Completion(context.TODO(), "What is the 
temperature in London?", llm.WithMaxTokens(20)) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("Completions", func(t *testing.T) { + r, err := model.Completion(context.TODO(), "What is the temperature in London?", llm.WithNumCompletions(3)) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(3, r.Num()) + assert.NotEmpty(r.Text(0)) + assert.NotEmpty(r.Text(1)) + assert.NotEmpty(r.Text(2)) + t.Log(r) + } + }) + + t.Run("Modalties", func(t *testing.T) { + r, err := model.Completion(context.TODO(), "What is the temperature in London?", openai.WithModalities("text")) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("Prediction", func(t *testing.T) { + r, err := model.Completion(context.TODO(), "Why is the sky blue", llm.WithPrediction("The sky is blue due to Rayleigh scattering")) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("Audio", func(t *testing.T) { + r, err := audio_model.Completion( + context.TODO(), + "Tell me in no more than ten words why is the sky blue", + openai.WithAudio("ash", "mp3"), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) // Returns the audio transcript + assert.NotEmpty(r.Audio(0)) + t.Log(r) + } + }) + + t.Run("PresencePenalty", func(t *testing.T) { + r, err := model.Completion( + context.TODO(), + "Tell me in no more than ten words why is the sky blue", + llm.WithPresencePenalty(1.0), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("ResponseFormat", func(t *testing.T) { + r, err := model.Completion( + context.TODO(), + "Tell me in no more than ten words why is the sky blue, and response in JSON format", + llm.WithFormat("json_object"), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("Seed", func(t *testing.T) { + r, err := model.Completion( + context.TODO(), + "Tell me in no more than ten words why is the sky blue", + llm.WithSeed(123), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("ServiceTier", func(t *testing.T) { + r, err := model.Completion( + context.TODO(), + "Tell me in no more than ten words why is the sky blue", + openai.WithServiceTier("default"), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("Stop", func(t *testing.T) { + r, err := model.Completion( + context.TODO(), + "Tell me in no more than ten words why is the sky blue", + llm.WithStopSequence("sky", "blue"), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("TopP", func(t *testing.T) { + r, err := model.Completion( + context.TODO(), + "Tell me in no more than ten words why is the sky blue", + llm.WithTopP(0.1), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + 
t.Run("User", func(t *testing.T) { + r, err := model.Completion( + context.TODO(), + "Tell me in no more than ten words why is the sky blue", + llm.WithUser("test_user"), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + +} + +func Test_completion_003(t *testing.T) { + assert := assert.New(t) + + model := client.Model(context.TODO(), "gpt-4o-mini") + if !assert.NotNil(model) { + t.FailNow() + } + + audio_model := client.Model(context.TODO(), "gpt-4o-audio-preview") + if !assert.NotNil(audio_model) { + t.FailNow() + } + + // Test streaming + t.Run("Streaming", func(t *testing.T) { + r, err := model.Completion( + context.TODO(), + "Tell me in no more than ten words why is the sky blue", + llm.WithStream(func(message llm.Completion) { + // TODO + }), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + } + }) + + t.Run("StreamingCompletions", func(t *testing.T) { + r, err := model.Completion( + context.TODO(), + "Tell me in no more than ten words why is the sky blue", + llm.WithNumCompletions(2), + llm.WithStream(func(message llm.Completion) { + // TODO + }), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(2, r.Num()) + assert.NotEmpty(r.Text(0)) + assert.NotEmpty(r.Text(1)) + } + }) + + t.Run("StreamingUsage", func(t *testing.T) { + r, err := model.Completion( + context.TODO(), + "Tell me in no more than ten words why is the sky blue", + openai.WithStreamOptions(func(message llm.Completion) { + // TODO + }, true), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + } + }) + + t.Run("StreamingAudio", func(t *testing.T) { + r, err := audio_model.Completion( + context.TODO(), + "Tell me in exactly three words why is the sky blue", + openai.WithStreamOptions(func(message llm.Completion) { + // TODO + }, true), + openai.WithAudio("ash", "pcm16"), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + assert.NotEmpty(r.Audio(0)) + } + }) + +} + +func Test_completion_004(t *testing.T) { + assert := assert.New(t) + + model := client.Model(context.TODO(), "gpt-4o-mini") + if !assert.NotNil(model) { + t.FailNow() + } + + // Test tool support + t.Run("Toolkit", func(t *testing.T) { + toolkit := tool.NewToolKit() + toolkit.Register(weather{}) + + r, err := model.Completion( + context.TODO(), + "What is the weather in the capital city of Germany?", + llm.WithToolKit(toolkit), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.ToolCalls(0)) + + toolcalls := r.ToolCalls(0) + assert.Len(toolcalls, 1) + assert.Equal("weather_in_city", toolcalls[0].Name()) + } + }) +} + +type weather struct { + City string `json:"city" help:"The city to get the weather for" required:"true"` +} + +func (weather) Name() string { + return "weather_in_city" +} + +func (weather) Description() string { + return "Get the weather for a city" +} + +func (w weather) Run(ctx context.Context) (any, error) { + return fmt.Sprintf("The weather in %q is sunny and warm", w.City), nil +} + +func Test_completion_005(t *testing.T) { + assert := assert.New(t) + model := client.Model(context.TODO(), "gpt-4o-mini") + if !assert.NotNil(model) { + t.FailNow() + } + + // Test image captioning + t.Run("ImageCaption", func(t *testing.T) { + 
f, err := os.Open("testdata/guggenheim.jpg") + if !assert.NoError(err) { + t.FailNow() + } + defer f.Close() + + r, err := model.Completion( + context.TODO(), + "Describe this picture", + llm.WithAttachment(f), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + } + }) +} diff --git a/pkg/openai/content.go b/pkg/openai/content.go new file mode 100644 index 0000000..2b71a6d --- /dev/null +++ b/pkg/openai/content.go @@ -0,0 +1,51 @@ +package openai + +import ( + "net/url" + + "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Content struct { + Type string `json:"type"` // text or content + Content string `json:"content,omitempty"` // content content ;-) + Text string `json:"text,omitempty"` // text content + Audio *llm.Attachment `json:"audio,omitempty"` // audio content + Image *Image `json:"image_url,omitempty"` // image content +} + +// A set of tool calls +type ToolCallArray []ToolCall + +// text content +type Text string + +// text content +type Prediction string + +// either a URL or "data:image/png;base64," followed by the base64 encoded image +type Image struct { + Url string `json:"url,omitempty"` +} + +/////////////////////////////////////////////////////////////////////////////// +// LICECYCLE + +func NewContentString(typ, content string) *Content { + return &Content{Type: typ, Content: content} +} + +func NewTextContext(content string) *Content { + return &Content{Type: "text", Text: content} +} + +func NewImageData(image *llm.Attachment) *Content { + return &Content{Type: "image_url", Image: &Image{Url: image.Url()}} +} + +func NewImageUrl(url *url.URL) *Content { + return &Content{Type: "image_url", Image: &Image{Url: url.String()}} +} diff --git a/pkg/openai/embeddings.go b/pkg/openai/embeddings.go new file mode 100644 index 0000000..8d85ad3 --- /dev/null +++ b/pkg/openai/embeddings.go @@ -0,0 +1,109 @@ +package openai + +import ( + "context" + "encoding/json" + + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// embeddings is the implementation of the llm.Embedding interface +type embeddings struct { + Embeddings +} + +// Embeddings is the metadata for a generated embedding vector +type Embeddings struct { + Type string `json:"object"` + Model string `json:"model"` + Data []Embedding `json:"data"` + Metrics +} + +// Embedding is a single vector +type Embedding struct { + Type string `json:"object"` + Index uint64 `json:"index"` + Vector []float64 `json:"embedding"` +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (m Embedding) MarshalJSON() ([]byte, error) { + return json.Marshal(m.Vector) +} + +func (m embeddings) MarshalJSON() ([]byte, error) { + return json.Marshal(m.Embeddings) +} + +func (m embeddings) String() string { + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +type reqEmbedding struct { + Model string `json:"model"` + Input []string `json:"input"` + Format string `json:"encoding_format,omitempty"` + Dimensions uint64 `json:"dimensions,omitempty"` + User string `json:"user,omitempty"` +} + +func (openai *Client) GenerateEmbedding(ctx context.Context, model string, 
prompt []string, opts ...llm.Opt) (*embeddings, error) { + // Bail out is no prompt + if len(prompt) == 0 { + return nil, llm.ErrBadParameter.With("missing prompt") + } + + // Apply options + opt, err := llm.ApplyOpts(opts...) + if err != nil { + return nil, err + } + + // Request + req, err := client.NewJSONRequest(reqEmbedding{ + Model: model, + Input: prompt, + Format: optFormat(opt), + Dimensions: optDimensions(opt), + User: optUser(opt), + }) + if err != nil { + return nil, err + } + + // Response + var response embeddings + if err := openai.DoWithContext(ctx, req, &response, client.OptPath("embeddings")); err != nil { + return nil, err + } + + // Return success + return &response, nil +} + +// Generate one vector +func (model *model) Embedding(ctx context.Context, prompt string, opts ...llm.Opt) ([]float64, error) { + response, err := model.GenerateEmbedding(ctx, model.Name(), []string{prompt}, opts...) + if err != nil { + return nil, err + } + if len(response.Embeddings.Data) == 0 { + return nil, llm.ErrNotFound.With("no embeddings returned") + } + return response.Embeddings.Data[0].Vector, nil +} diff --git a/pkg/openai/embeddings_test.go b/pkg/openai/embeddings_test.go new file mode 100644 index 0000000..d0cd29a --- /dev/null +++ b/pkg/openai/embeddings_test.go @@ -0,0 +1,20 @@ +package openai_test + +import ( + "context" + "testing" + + // Packages + assert "github.com/stretchr/testify/assert" +) + +func Test_embeddings_001(t *testing.T) { + assert := assert.New(t) + model := client.Model(context.TODO(), "text-embedding-ada-002") + if assert.NotNil(model) { + response, err := model.Embedding(context.TODO(), "Hello, how are you?") + assert.NoError(err) + assert.NotEmpty(response) + t.Log(response) + } +} diff --git a/pkg/openai/message.go b/pkg/openai/message.go new file mode 100644 index 0000000..37128f2 --- /dev/null +++ b/pkg/openai/message.go @@ -0,0 +1,84 @@ +package openai + +import ( + // Packages + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// Message with text or object content +type Message struct { + RoleContent + Media *llm.Attachment `json:"audio,omitempty"` + Calls ToolCalls `json:"tool_calls,omitempty"` + *ToolResults +} + +var _ llm.Completion = (*Message)(nil) + +type RoleContent struct { + Role string `json:"role,omitempty"` // assistant, user, tool, system + Content any `json:"content,omitempty"` // string or array of text, reference, image_url +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return the number of completions +func (Message) Num() int { + return 1 +} + +// Return the current session role +func (message *Message) Role() string { + return message.RoleContent.Role +} + +// Return the completion +func (message *Message) Choice(index int) llm.Completion { + if index != 0 { + return nil + } + return message +} + +// Return the text for the last completion +func (message *Message) Text(index int) string { + if index != 0 { + return "" + } + // If content is text, return it + if text, ok := message.Content.(string); ok && text != "" { + return text + } + // If content is audio, and there is a caption, return it + if audio := message.Audio(0); audio != nil && audio.Caption() != "" { + return audio.Caption() + } + + // For other kinds, return empty string for the moment + return "" +} + +// Return the audio +func (message *Message) Audio(index int) *llm.Attachment { + if index != 0 { + return nil + } + 
return message.Media +} + +// Return the current session tool calls given the completion index. +// Will return nil if no tool calls were returned. +func (message *Message) ToolCalls(index int) []llm.ToolCall { + if index != 0 { + return nil + } + calls := make([]llm.ToolCall, 0, len(message.Calls)) + for _, call := range message.Calls { + calls = append(calls, call) + } + return calls +} diff --git a/pkg/openai/messagefactory.go b/pkg/openai/messagefactory.go new file mode 100644 index 0000000..9d05210 --- /dev/null +++ b/pkg/openai/messagefactory.go @@ -0,0 +1,79 @@ +package openai + +import ( + "encoding/json" + + // Packages + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type messagefactory struct{} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - MESSAGE FACTORY + +func (messagefactory) SystemPrompt(prompt string) llm.Completion { + return &Message{ + RoleContent: RoleContent{ + Role: "system", + Content: prompt, + }, + } +} + +func (messagefactory) UserPrompt(prompt string, opts ...llm.Opt) (llm.Completion, error) { + // Get attachments + opt, err := llm.ApplyPromptOpts(opts...) + if err != nil { + return nil, err + } + + // Get attachments, allocate content + attachments := opt.Attachments() + content := make([]*Content, 1, len(attachments)+1) + + // Append the text and the attachments + content[0] = NewTextContext(prompt) + for _, attachment := range attachments { + content = append(content, NewImageData(attachment)) + } + + // Return success + return &Message{ + RoleContent: RoleContent{ + Role: "user", + Content: content, + }, + }, nil +} + +func (messagefactory) ToolResults(results ...llm.ToolResult) ([]llm.Completion, error) { + // Check for no results + if len(results) == 0 { + return nil, llm.ErrBadParameter.Withf("No tool results") + } + + // Create results + messages := make([]llm.Completion, 0, len(results)) + for _, result := range results { + value, err := json.Marshal(result.Value()) + if err != nil { + return nil, err + } + messages = append(messages, &Message{ + RoleContent: RoleContent{ + Role: "tool", + Content: string(value), + }, + ToolResults: &ToolResults{ + Id: result.Call().Id(), + }, + }) + } + + // Return success + return messages, nil +} diff --git a/pkg/openai/model.go b/pkg/openai/model.go new file mode 100644 index 0000000..ae8ba89 --- /dev/null +++ b/pkg/openai/model.go @@ -0,0 +1,137 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" + impl "github.com/mutablelogic/go-llm/pkg/internal/impl" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type model struct { + *Client `json:"-"` + meta Model +} + +var _ llm.Model = (*model)(nil) + +type Model struct { + Name string `json:"id"` + Type string `json:"object,omitempty"` + CreatedAt uint64 `json:"created,omitempty"` + OwnedBy string `json:"owned_by,omitempty"` +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (m model) MarshalJSON() ([]byte, error) { + return json.Marshal(m.meta) +} + +func (m model) String() string { + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - llm.Agent 
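+//
+// A minimal usage sketch (assumes a *Client created with openai.New and a valid
+// API key; the model name below is only an example):
+//
+//	models, err := client.Models(ctx)          // list (and cache) all models
+//	model := client.Model(ctx, "gpt-4o-mini")  // look one up by name; panics on error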
+ +// Return the models +func (openai *Client) Models(ctx context.Context) ([]llm.Model, error) { + return openai.ModelCache.Load(func() ([]llm.Model, error) { + return openai.loadmodels(ctx) + }) +} + +// Return a model by name, or nil if not found. +// Panics on error. +func (openai *Client) Model(ctx context.Context, name string) llm.Model { + model, err := openai.ModelCache.Get(func() ([]llm.Model, error) { + return openai.loadmodels(ctx) + }, name) + if err != nil { + panic(err) + } + return model +} + +// Function called to load models +func (openai *Client) loadmodels(ctx context.Context) ([]llm.Model, error) { + if models, err := openai.ListModels(ctx); err != nil { + return nil, err + } else { + result := make([]llm.Model, len(models)) + for i, meta := range models { + result[i] = &model{openai, meta} + } + return result, nil + } +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - llm.Model + +// Return model name +func (model model) Name() string { + return model.meta.Name +} + +// Return model description +func (model model) Description() string { + return fmt.Sprintf("Owner: %q", model.meta.OwnedBy) +} + +// Return model aliases +func (model) Aliases() []string { + return nil +} + +// Return a new empty session +func (model *model) Context(opts ...llm.Opt) llm.Context { + return impl.NewSession(model, &messagefactory{}, opts...) +} + +/////////////////////////////////////////////////////////////////////////////// +// API CALLS + +// ListModels returns all the models +func (openai *Client) ListModels(ctx context.Context) ([]Model, error) { + // Return the response + var response struct { + Data []Model `json:"data"` + } + if err := openai.DoWithContext(ctx, nil, &response, client.OptPath("models")); err != nil { + return nil, err + } + + // Return success + return response.Data, nil +} + +// GetModel returns one model +func (openai *Client) GetModel(ctx context.Context, model string) (*Model, error) { + // Return the response + var response Model + if err := openai.DoWithContext(ctx, nil, &response, client.OptPath("models", model)); err != nil { + return nil, err + } + + // Return success + return &response, nil +} + +// Delete a fine-tuned model. You must have the Owner role in your organization +// to delete a model. 
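+// A sketch of intended use (the name must refer to a fine-tuned model your
+// organization owns): err := client.DeleteModel(ctx, name)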
+func (openai *Client) DeleteModel(ctx context.Context, model string) error { + return openai.DoWithContext(ctx, client.MethodDelete, nil, client.OptPath("models", model)) +} diff --git a/pkg/openai/model_test.go b/pkg/openai/model_test.go new file mode 100644 index 0000000..c4b6425 --- /dev/null +++ b/pkg/openai/model_test.go @@ -0,0 +1,43 @@ +package openai_test + +import ( + "context" + "testing" + + // Packages + assert "github.com/stretchr/testify/assert" +) + +func Test_models_001(t *testing.T) { + assert := assert.New(t) + + response, err := client.ListModels(context.TODO()) + assert.NoError(err) + assert.NotEmpty(response) + + t.Run("models", func(t *testing.T) { + for _, model := range response { + model_, err := client.GetModel(context.TODO(), model.Name) + if assert.NoError(err) { + assert.NotNil(model_) + assert.Equal(*model_, model) + } + } + }) +} + +func Test_models_002(t *testing.T) { + assert := assert.New(t) + + response, err := client.Models(context.TODO()) + assert.NoError(err) + assert.NotEmpty(response) + + t.Run("models", func(t *testing.T) { + for _, model := range response { + model_ := client.Model(context.TODO(), model.Name()) + assert.NotNil(model_) + assert.Equal(model_, model) + } + }) +} diff --git a/pkg/openai/opt.go b/pkg/openai/opt.go new file mode 100644 index 0000000..73efa36 --- /dev/null +++ b/pkg/openai/opt.go @@ -0,0 +1,330 @@ +package openai + +import ( + // Packages + "slices" + "strings" + + "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Embeddings: The number of dimensions the resulting output embeddings +// should have. Only supported in text-embedding-3 and later models. +func WithDimensions(v uint64) llm.Opt { + return func(o *llm.Opts) error { + o.Set("dimensions", v) + return nil + } +} + +// Whether or not to store the output of this chat completion request for use in +// model distillation or evals products. +func WithStore(v bool) llm.Opt { + return func(o *llm.Opts) error { + o.Set("store", v) + return nil + } +} + +// Constrains effort on reasoning for reasoning models. Currently supported values are +// low, medium, and high. Reducing reasoning effort can result in faster responses +// and fewer tokens used on reasoning in a response. +func WithReasoningEffort(v string) llm.Opt { + return func(o *llm.Opts) error { + o.Set("reasoning_effort", v) + return nil + } +} + +// Key-value pair that can be attached to an object. This can be useful for storing +// additional information about the object in a structured format, and querying for objects +// via API or the dashboard. +func WithMetadata(k, v string) llm.Opt { + return func(o *llm.Opts) error { + // Set store to true + if err := WithStore(true)(o); err != nil { + return err + } + + // Add metadata + metadata, ok := o.Get("metadata").(map[string]string) + if !ok { + metadata = make(map[string]string, 16) + } + metadata[k] = v + o.Set("metadata", metadata) + return nil + } +} + +// Tokens (specified by their token ID in the tokenizer) to an associated bias +// value from -100 to 100. Mathematically, the bias is added to the logits +// generated by the model prior to sampling. The exact effect will vary per model, +// but values between -1 and 1 should decrease or increase likelihood of selection; +// values like -100 or 100 should result in a ban or exclusive selection of the +// relevant token. 
+func WithLogitBias(token uint64, bias int64) llm.Opt { + return func(o *llm.Opts) error { + logit_bias, ok := o.Get("logit_bias").(map[uint64]int64) + if !ok { + logit_bias = make(map[uint64]int64, 16) + } + logit_bias[token] = bias + o.Set("logit_bias", logit_bias) + return nil + } +} + +// Whether to return log probabilities of the output tokens or not. +func WithLogProbs() llm.Opt { + return func(o *llm.Opts) error { + o.Set("logprobs", true) + return nil + } +} + +// An integer between 0 and 20 specifying the number of most likely tokens +// to return at each token position, each with an associated log probability. +func WithTopLogProbs(v uint64) llm.Opt { + return func(o *llm.Opts) error { + if v > 20 { + return llm.ErrBadParameter.With("top_logprobs") + } + o.Set("logprobs", true) + o.Set("top_logprobs", v) + return nil + } +} + +// Output types that you would like the model to generate for this request. +// Supported values are: "text", "audio" +func WithModalities(v ...string) llm.Opt { + return func(o *llm.Opts) error { + arr, ok := o.Get("modalities").([]string) + if !ok { + arr = make([]string, 0, 16) + } + for _, v := range v { + v = strings.ToLower(strings.TrimSpace(v)) + if !slices.Contains(arr, v) { + arr = append(arr, v) + } + } + o.Set("modalities", arr) + return nil + } +} + +// Parameters for audio output +func WithAudio(voice, format string) llm.Opt { + return func(o *llm.Opts) error { + if err := WithModalities("text", "audio")(o); err != nil { + return err + } + if audio := NewAudio(voice, format); audio != nil { + o.Set("audio", audio) + } else { + return llm.ErrBadParameter.With("audio") + } + return nil + } +} + +// Specifies the latency tier to use for processing the request. Values +// can be auto or default +func WithServiceTier(v string) llm.Opt { + return func(o *llm.Opts) error { + o.Set("service_tier", v) + return nil + } +} + +// Enable streaming and include usage information in the streaming response +func WithStreamOptions(fn func(llm.Completion), include_usage bool) llm.Opt { + return func(o *llm.Opts) error { + if err := llm.WithStream(fn)(o); err != nil { + return err + } + o.Set("stream_options_include_usage", include_usage) + return nil + } +} + +// Disable parallel tool calling +func WithDisableParallelToolCalls() llm.Opt { + return func(o *llm.Opts) error { + o.Set("parallel_tool_calls", false) + return nil + } +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +// For embedding +func optFormat(opts *llm.Opts) string { + return opts.GetString("format") +} + +// For embedding +func optDimensions(opts *llm.Opts) uint64 { + return opts.GetUint64("dimensions") +} + +// For embedding and completions +func optUser(opts *llm.Opts) string { + return opts.GetString("user") +} + +func optStore(opts *llm.Opts) *bool { + if v, ok := opts.Get("store").(bool); ok { + return &v + } + return nil +} + +func optReasoningEffort(opts *llm.Opts) string { + return opts.GetString("reasoning_effort") +} + +func optMetadata(opts *llm.Opts) map[string]string { + if metadata, ok := opts.Get("metadata").(map[string]string); ok { + return metadata + } + return nil +} + +func optFrequencyPenalty(opts *llm.Opts) float64 { + return opts.GetFloat64("frequency_penalty") +} + +func optLogitBias(opts *llm.Opts) map[uint64]int64 { + if logit_bias, ok := opts.Get("logit_bias").(map[uint64]int64); ok { + return logit_bias + } + return nil +} + +func optLogProbs(opts *llm.Opts) bool { + return opts.GetBool("logprobs") +} + +func 
optTopLogProbs(opts *llm.Opts) uint64 { + return opts.GetUint64("top_logprobs") +} + +func optMaxTokens(opts *llm.Opts) uint64 { + return opts.GetUint64("max_tokens") +} + +func optNumCompletions(opts *llm.Opts) uint64 { + return opts.GetUint64("num_completions") +} + +func optModalities(opts *llm.Opts) []string { + if v, ok := opts.Get("modalities").([]string); ok { + return v + } + return nil +} + +func optPrediction(opts *llm.Opts) *Content { + v := strings.TrimSpace(opts.GetString("prediction")) + if v != "" { + return NewContentString("content", v) + } + return nil +} + +func optAudio(opts *llm.Opts) *Audio { + if v, ok := opts.Get("audio").(*Audio); ok { + return v + } + return nil +} + +func optPresencePenalty(opts *llm.Opts) float64 { + return opts.GetFloat64("presence_penalty") +} + +func optResponseFormat(opts *llm.Opts) *Format { + if format := NewFormat(optFormat(opts)); format != nil { + return format + } else { + return nil + } +} + +func optSeed(opts *llm.Opts) uint64 { + return opts.GetUint64("seed") +} + +func optServiceTier(opts *llm.Opts) string { + return opts.GetString("service_tier") +} + +func optStreamOptions(opts *llm.Opts) *StreamOptions { + if opts.Has("stream_options_include_usage") { + return NewStreamOptions(opts.GetBool("stream_options_include_usage")) + } else { + return nil + } +} + +func optStream(opts *llm.Opts) bool { + return opts.StreamFn() != nil +} + +func optTemperature(opts *llm.Opts) float64 { + return opts.GetFloat64("temperature") +} + +func optTopP(opts *llm.Opts) float64 { + return opts.GetFloat64("top_p") +} + +func optStopSequences(opts *llm.Opts) []string { + if opts.Has("stop") { + if stop, ok := opts.Get("stop").([]string); ok { + return stop + } + } + return nil +} + +func optTools(agent llm.Agent, opts *llm.Opts) []llm.Tool { + toolkit := opts.ToolKit() + if toolkit == nil { + return nil + } + return toolkit.Tools(agent) +} + +func optToolChoice(opts *llm.Opts) any { + choices, ok := opts.Get("tool_choice").([]string) + if !ok || len(choices) == 0 { + return nil + } + + // We only support one choice + choice := strings.TrimSpace(strings.ToLower(choices[0])) + switch choice { + case "auto", "none", "required": + return choice + case "": + return nil + default: + return NewToolChoice(choice) + } +} + +func optParallelToolCalls(opts *llm.Opts) *bool { + if opts.Has("parallel_tool_calls") { + v := opts.GetBool("parallel_tool_calls") + return &v + } + return nil +} diff --git a/pkg/openai/opt_audio.go b/pkg/openai/opt_audio.go new file mode 100644 index 0000000..ed230bc --- /dev/null +++ b/pkg/openai/opt_audio.go @@ -0,0 +1,26 @@ +package openai + +import "strings" + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Audio struct { + // Supported voices include ash, ballad, coral, sage, and verse + Voice string `json:"voice"` + + // Supported formats: wav, mp3, flac, opus, or pcm16 + Format string `json:"format"` +} + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +func NewAudio(voice, format string) *Audio { + voice = strings.TrimSpace(strings.ToLower(voice)) + format = strings.TrimSpace(strings.ToLower(format)) + if voice == "" || format == "" { + return nil + } + return &Audio{Voice: voice, Format: format} +} diff --git a/pkg/openai/opt_format.go b/pkg/openai/opt_format.go new file mode 100644 index 0000000..34e4623 --- /dev/null +++ b/pkg/openai/opt_format.go @@ -0,0 +1,25 @@ +package openai + +import "strings" + 
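+// This file models the response_format request field. Only the "text" and
+// "json_object" types are currently mapped; "json_schema" is not yet supported
+// (see NewFormat below).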
+/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Format struct { + // Supported response format types are text, json_object or json_schema + Type string `json:"type"` +} + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +func NewFormat(format string) *Format { + format = strings.TrimSpace(strings.ToLower(format)) + switch format { + case "text", "json_object": + return &Format{Type: format} + default: + // json_schema is not yet supported + return nil + } +} diff --git a/pkg/openai/opt_stream.go b/pkg/openai/opt_stream.go new file mode 100644 index 0000000..c82da5e --- /dev/null +++ b/pkg/openai/opt_stream.go @@ -0,0 +1,15 @@ +package openai + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type StreamOptions struct { + IncludeUsage bool `json:"include_usage"` +} + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +func NewStreamOptions(include_usage bool) *StreamOptions { + return &StreamOptions{IncludeUsage: include_usage} +} diff --git a/pkg/openai/opt_toolchoice.go b/pkg/openai/opt_toolchoice.go new file mode 100644 index 0000000..c89c5c9 --- /dev/null +++ b/pkg/openai/opt_toolchoice.go @@ -0,0 +1,23 @@ +package openai + +import "strings" + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type ToolChoice struct { + Type string `json:"type"` + Function struct { + Name string `json:"name"` + } `json:"function"` +} + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +func NewToolChoice(function string) *ToolChoice { + choice := new(ToolChoice) + choice.Type = "function" + choice.Function.Name = strings.TrimSpace(strings.ToLower(function)) + return choice +} diff --git a/pkg/mistral/session_test.go b/pkg/openai/session_test.go similarity index 88% rename from pkg/mistral/session_test.go rename to pkg/openai/session_test.go index 7fbcaa3..0f1004c 100644 --- a/pkg/mistral/session_test.go +++ b/pkg/openai/session_test.go @@ -1,4 +1,4 @@ -package mistral_test +package openai_test import ( "context" @@ -12,7 +12,7 @@ import ( func Test_session_001(t *testing.T) { assert := assert.New(t) - model := client.Model(context.TODO(), "mistral-small-latest") + model := client.Model(context.TODO(), "gpt-4o-mini") if !assert.NotNil(model) { t.FailNow() } @@ -27,7 +27,7 @@ func Test_session_001(t *testing.T) { func Test_session_002(t *testing.T) { assert := assert.New(t) - model := client.Model(context.TODO(), "mistral-small-latest") + model := client.Model(context.TODO(), "gpt-4o-mini") if !assert.NotNil(model) { t.FailNow() } diff --git a/pkg/mistral/testdata/LICENSE b/pkg/openai/testdata/LICENSE similarity index 100% rename from pkg/mistral/testdata/LICENSE rename to pkg/openai/testdata/LICENSE diff --git a/pkg/mistral/testdata/guggenheim.jpg b/pkg/openai/testdata/guggenheim.jpg similarity index 100% rename from pkg/mistral/testdata/guggenheim.jpg rename to pkg/openai/testdata/guggenheim.jpg diff --git a/pkg/openai/tool.go b/pkg/openai/tool.go new file mode 100644 index 0000000..68865a3 --- /dev/null +++ b/pkg/openai/tool.go @@ -0,0 +1,65 @@ +package openai + +import ( + "encoding/json" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type ToolCall struct { + Id string `json:"id,omitempty"` // tool id + Type string `json:"type,omitempty"` // tool type (function) + Index 
uint64 `json:"index,omitempty"` // tool index + Function struct { + Name string `json:"name,omitempty"` // tool name + Arguments string `json:"arguments,omitempty"` // tool arguments + } `json:"function"` +} + +type toolcall struct { + meta ToolCall +} + +type ToolCalls []toolcall + +type ToolResults struct { + Id string `json:"tool_call_id,omitempty"` +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (t *toolcall) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &t.meta) +} + +func (t toolcall) MarshalJSON() ([]byte, error) { + return json.Marshal(t.meta) +} + +func (t toolcall) String() string { + data, err := json.MarshalIndent(t, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// The tool name +func (t toolcall) Name() string { + return t.meta.Function.Name +} + +// The tool identifier +func (t toolcall) Id() string { + return t.meta.Id +} + +// Decode the calling parameters +func (t toolcall) Decode(v any) error { + return json.Unmarshal([]byte(t.meta.Function.Arguments), v) +} diff --git a/pkg/tool/toolkit.go b/pkg/tool/toolkit.go index 79eb86a..5c9efd3 100644 --- a/pkg/tool/toolkit.go +++ b/pkg/tool/toolkit.go @@ -37,15 +37,15 @@ func (kit *ToolKit) Tools(agent llm.Agent) []llm.Tool { result := make([]llm.Tool, 0, len(kit.functions)) for _, t := range kit.functions { switch agent.Name() { - case "ollama", "mistral": + case "anthropic": + t.Parameters = nil + result = append(result, t) + default: t.InputSchema = nil result = append(result, ToolFunction{ Type: "function", Tool: t, }) - default: - t.Parameters = nil - result = append(result, t) } } return result diff --git a/pkg/ui/telegram/telegram.go b/pkg/ui/telegram/telegram.go new file mode 100644 index 0000000..052161a --- /dev/null +++ b/pkg/ui/telegram/telegram.go @@ -0,0 +1,61 @@ +package telegram + +import ( + "context" + "fmt" + + // Packages + telegram "github.com/go-telegram-bot-api/telegram-bot-api/v5" +) + +///////////////////////////////////////////////////////////////////// +// TYPES + +type t struct { + *telegram.BotAPI +} + +///////////////////////////////////////////////////////////////////// +// LIFECYCLE + +func NewTelegram(token string) (*t, error) { + bot, err := telegram.NewBotAPI(token) + if err != nil { + return nil, err + } + + // Create a new telegram instance + telegram := &t{bot} + + // Return the instance + return telegram, nil +} + +///////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (t *t) Run(ctx context.Context) error { + updates := t.GetUpdatesChan(telegram.NewUpdate(0)) +FOR_LOOP: + for { + select { + case <-ctx.Done(): + break FOR_LOOP + case evt := <-updates: + if evt.Message != nil && !evt.Message.IsCommand() { + t.handleMessage(evt.Message) + } + } + } + + // Return success + return nil +} + +///////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func (t *t) handleMessage(update *telegram.Message) { + fmt.Println("Received message from", update.From.UserName) + fmt.Println(" => ", update.Text) +} diff --git a/pkg/ui/term/term.go b/pkg/ui/term/term.go new file mode 100644 index 0000000..f7bb5da --- /dev/null +++ b/pkg/ui/term/term.go @@ -0,0 +1,89 @@ +package term + +import ( + "fmt" + "io" + "os" + + // Packages + format "github.com/MichaelMure/go-term-text" + color "github.com/fatih/color" + term 
"golang.org/x/term" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Term struct { + r io.Reader + fd int + *term.Terminal +} + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +func NewTerm(r io.Reader) (*Term, error) { + t := new(Term) + t.r = r + + // Set file descriptor + if osf, ok := r.(*os.File); ok { + t.fd = int(osf.Fd()) + if term.IsTerminal(t.fd) { + t.Terminal = term.NewTerminal(osf, "") + } + } + + // Return success + return t, nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Returns the width and height of the terminal, or (0,0) if we aren't in +// a terminal +func (t *Term) Size() (int, int) { + if t.Terminal != nil { + if w, h, err := term.GetSize(t.fd); err == nil { + return w, h + } + } + // Unable to get the size + return 0, 0 +} + +func (t *Term) Println(v ...any) { + text := fmt.Sprint(v...) + w, _ := t.Size() + if w > 0 { + text, _ = format.Wrap(text, w) + } + fmt.Fprintln(os.Stdout, text) +} + +func (t *Term) ReadLine(prompt string) (string, error) { + // Set terminal raw mode + if t.Terminal != nil { + state, err := term.MakeRaw(t.fd) + if err != nil { + return "", err + } + defer term.Restore(t.fd, state) + } + + // Set the prompt with color + if t.Terminal != nil { + prompt = color.New(color.Bold).Sprint(prompt) + t.Terminal.SetPrompt(prompt) + } + + // Read the line + if t.Terminal != nil { + return t.Terminal.ReadLine() + } else { + // Don't support non-terminal input yet + return "", io.EOF + } +} diff --git a/pkg/ui/ui.go b/pkg/ui/ui.go new file mode 100644 index 0000000..25ef44e --- /dev/null +++ b/pkg/ui/ui.go @@ -0,0 +1,14 @@ +package ui + +import "context" + +////////////////////////////////////////////////////////////////////////////// +// TYPES + +type UI interface { + // Run the runloop for the UI + Run(ctx context.Context) error + + // Send a system message + SysPrint(format string, args ...interface{}) error +} diff --git a/pkg/version/version.go b/pkg/version/version.go new file mode 100644 index 0000000..da44dc0 --- /dev/null +++ b/pkg/version/version.go @@ -0,0 +1,12 @@ +package version + +/////////////////////////////////////////////////////////////////////////////// +// GLOBALS + +var ( + GitSource string + GitTag string + GitBranch string + GitHash string + GoBuildTime string +)