diff --git a/README.md b/README.md
index f8de370..2cea2bd 100644
--- a/README.md
+++ b/README.md
@@ -1,17 +1,21 @@
 # go-llm
 
 Large Language Model API interface. This is a simple API interface for large language models
-which run on [Ollama](https://github.com/ollama/ollama/blob/main/docs/api.md)
-and [Anthopic](https://docs.anthropic.com/en/api/getting-started).
+which run on [Ollama](https://github.com/ollama/ollama/blob/main/docs/api.md),
+[Anthropic](https://docs.anthropic.com/en/api/getting-started) and [Mistral](https://docs.mistral.ai/)
+(OpenAI may be added later).
 
 The module includes the ability to utilize:
 
 * Maintaining a session of messages
-* Tool calling support
+* Tool calling support, including your own tools (tool plugins)
+* Creating embedding vectors from text
 * Streaming responses
+* Multi-modal support (images and attachments)
 
 There is a command-line tool included in the module which can be used to interact with the API.
-For example,
+If you have docker installed, you can use the following command to run the tool without
+installing it:
 
 ```bash
 # Display help
@@ -20,15 +24,23 @@ docker run ghcr.io/mutablelogic/go-llm:latest --help
 
-# Interact with Claude to retrieve news headlines, assuming
-# you have an API key for Anthropic and NewsAPI
+# Interact with Mistral to retrieve news headlines, assuming
+# you have an API key for Mistral and NewsAPI
 docker run \
-  --interactive -e ANTHROPIC_API_KEY -e NEWSAPI_KEY \
+  --interactive -e MISTRAL_API_KEY -e NEWSAPI_KEY \
   ghcr.io/mutablelogic/go-llm:latest \
-  chat claude-3-5-haiku-20241022
+  chat mistral-small-latest --prompt "What is the latest news?"
 ```
 
+See below for more information on how to use the command-line tool (or how to install it
+if you have a `go` compiler).
+
 ## Programmatic Usage
 
 See the documentation [here](https://pkg.go.dev/github.com/mutablelogic/go-llm)
-for integration into your own Go programs. To create an
+for integration into your own Go programs.
+
+### Agent Instantiation
+
+For each LLM provider, you create an agent which can be used to interact with the API.
+To create an
-[Ollama](https://pkg.go.dev/github.com/mutablelogic/go-llm/pkg/anthropic)
+[Ollama](https://pkg.go.dev/github.com/mutablelogic/go-llm/pkg/ollama)
 agent,
@@ -38,7 +50,7 @@ import (
 )
 
 func main() {
-	// Create a new agent
+	// Create a new agent - replace the URL with that of your Ollama instance
 	agent, err := ollama.New("https://ollama.com/api/v1/")
 	if err != nil {
 		panic(err)
@@ -49,7 +61,7 @@ func main() {
 
 To create an
 [Anthropic](https://pkg.go.dev/github.com/mutablelogic/go-llm/pkg/anthropic)
-agent,
+agent with an API key stored as an environment variable,
 
 ```go
 import (
@@ -58,7 +70,49 @@ import (
 
 func main() {
 	// Create a new agent
-	agent, err := anthropic.New(os.Getev("ANTHROPIC_API_KEY"))
+	agent, err := anthropic.New(os.Getenv("ANTHROPIC_API_KEY"))
 	if err != nil {
 		panic(err)
 	}
 	// ...
 }
 ```
+
+For [Mistral](https://pkg.go.dev/github.com/mutablelogic/go-llm/pkg/mistral) models, you can use:
+
+```go
+import (
+	"github.com/mutablelogic/go-llm/pkg/mistral"
+)
+
+func main() {
+	// Create a new agent
+	agent, err := mistral.New(os.Getenv("MISTRAL_API_KEY"))
+	if err != nil {
+		panic(err)
+	}
+	// ...
+}
+```
+
+You can pass options when creating an agent to configure client/server communication,
+such as user-agent strings, timeouts, debugging, rate limiting and custom headers.
+See [here](https://pkg.go.dev/github.com/mutablelogic/go-client#readme-basic-usage) for more information.
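+
+For example, to create a Mistral agent which traces client/server communication to
+stderr (a minimal sketch, using the `OptTrace` option from go-client, as used in this
+module's tests):
+
+```go
+import (
+	"os"
+
+	client "github.com/mutablelogic/go-client"
+	mistral "github.com/mutablelogic/go-llm/pkg/mistral"
+)
+
+func main() {
+	// Create a new agent which logs request and response data to stderr
+	agent, err := mistral.New(os.Getenv("MISTRAL_API_KEY"), client.OptTrace(os.Stderr, true))
+	if err != nil {
+		panic(err)
+	}
+	// ...
+}
+```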
+
+There is also an _aggregated_ agent which can be used to interact with multiple
+providers at once. This is useful if you want to use models from different providers
+simultaneously.
+
+```go
+import (
+	"github.com/mutablelogic/go-llm/pkg/agent"
+)
+
+func main() {
+	// Create a new agent which aggregates multiple providers
+	agent, err := agent.New(
+		agent.WithAnthropic(os.Getenv("ANTHROPIC_API_KEY")),
+		agent.WithMistral(os.Getenv("MISTRAL_API_KEY")),
+		agent.WithOllama(os.Getenv("OLLAMA_URL")),
+	)
 	if err != nil {
 		panic(err)
@@ -66,6 +120,8 @@ func main() {
 	}
 ```
 
+### Chat Sessions
+
 You create a **chat session** with a model as follows,
 
 ```go
@@ -75,7 +131,7 @@ import (
 )
 
 func session(ctx context.Context, agent llm.Agent) error {
 	// Create a new chat session
-	session := agent.Model("claude-3-5-haiku-20241022").Context()
+	session := agent.Model(ctx, "claude-3-5-haiku-20241022").Context()
 
 	// Repeat forever
 	for {
@@ -84,12 +140,114 @@ func session(ctx context.Context, agent llm.Agent) error {
 			return err
 		}
 
-		// Print the response
-		fmt.Println(session.Text())
+		// Print the response for the zeroth completion
+		fmt.Println(session.Text(0))
 	}
 }
 ```
 
+The `Context` object stores the current session and options, and ensures that the
+session is maintained across multiple calls.
+
+### Embedding Generation
+
+You can generate embedding vectors from text with a model which supports embeddings,
+such as `mistral-embed`.
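+A minimal sketch, assuming an agent has been created as above:
+
+```go
+import (
+	"context"
+	"fmt"
+
+	llm "github.com/mutablelogic/go-llm"
+)
+
+func embedding(ctx context.Context, agent llm.Agent) error {
+	// Generate an embedding vector for a text prompt
+	vector, err := agent.Model(ctx, "mistral-embed").Embedding(ctx, "Hello, how are you?")
+	if err != nil {
+		return err
+	}
+	fmt.Println(vector)
+	return nil
+}
+```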
+
+### Attachments & Image Caption Generation
+
+You can attach images and other documents to a user prompt with the `llm.WithAttachment`
+option, and use a vision model (`pixtral-12b-2409`, for example) to generate captions or
+answer questions about the attachment.
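+A minimal sketch, assuming a vision-capable model has been obtained from an agent:
+
+```go
+import (
+	"context"
+	"fmt"
+	"os"
+
+	llm "github.com/mutablelogic/go-llm"
+)
+
+func caption(ctx context.Context, model llm.Model) error {
+	f, err := os.Open("guggenheim.jpg")
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+
+	// Attach the image to the user prompt
+	session := model.Context()
+	if err := session.FromUser(ctx, "Provide a short caption for this image", llm.WithAttachment(f)); err != nil {
+		return err
+	}
+
+	// Print the caption
+	fmt.Println(session.Text(0))
+	return nil
+}
+```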
+
+### Streaming
+
+You can stream a response as it is generated by providing a callback with the
+`llm.WithStream` option, either when creating a session or when calling the model.
+The callback receives the partial completion as the response grows. See the `chat`
+command in `cmd/llm` for a working example.
+
+### Tool Support
+
+You can define your own tools, register them with a toolkit, and pass the toolkit to
+a session with the `llm.WithToolKit` option. When the model responds with tool calls,
+run the calls with the toolkit and pass the results back to the session, as sketched
+below.
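+A minimal sketch, assuming `weather` implements the `llm.Tool` interface (see the
+`pkg/tool` package and the tests for a complete example):
+
+```go
+import (
+	"context"
+	"fmt"
+
+	llm "github.com/mutablelogic/go-llm"
+	tool "github.com/mutablelogic/go-llm/pkg/tool"
+)
+
+func run(ctx context.Context, model llm.Model, weather llm.Tool) error {
+	// Register the tool and attach the toolkit to the session
+	toolkit := tool.NewToolKit()
+	toolkit.Register(weather)
+	session := model.Context(llm.WithToolKit(toolkit))
+
+	if err := session.FromUser(ctx, "What is the weather like in London today?"); err != nil {
+		return err
+	}
+
+	// Run any tool calls and pass the results back to the session
+	// (a full implementation would repeat this until no more calls are made)
+	if calls := session.ToolCalls(0); len(calls) > 0 {
+		results, err := toolkit.Run(ctx, calls...)
+		if err != nil {
+			return err
+		}
+		if err := session.FromTool(ctx, results...); err != nil {
+			return err
+		}
+	}
+
+	// Print the final response
+	fmt.Println(session.Text(0))
+	return nil
+}
+```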
+
+## Options
+
+You can add options to sessions or to individual prompts. Different providers and models
+support different options.
+
+```go
+type Model interface {
+	// Set session-wide options
+	Context(...Opt) Context
+
+	// Add attachments (images, PDFs) to a user prompt for completion
+	UserPrompt(string, ...Opt) Context
+
+	// Create an embedding vector with embedding options
+	Embedding(context.Context, string, ...Opt) ([]float64, error)
+}
+
+type Context interface {
+	// Add single-use options when calling the model, which override
+	// session options. You can attach files to a user prompt.
+	FromUser(context.Context, string, ...Opt) error
+}
+```
+
+The options are as follows:
+
+| Option | Ollama | Anthropic | Mistral | OpenAI | Description |
+|--------|--------|-----------|---------|--------|-------------|
+| `llm.WithTemperature(float64)` | Yes | Yes | Yes | - | What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.7 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. |
+| `llm.WithTopP(float64)` | Yes | Yes | Yes | - | Nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. |
+| `llm.WithTopK(uint64)` | Yes | Yes | No | - | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. |
+| `llm.WithMaxTokens(uint64)` | No | Yes | Yes | - | The maximum number of tokens to generate in the response. |
+| `llm.WithStream(func(llm.Completion))` | Can be enabled when tools are not used | Yes | Yes | - | Stream the response to a function. |
+| `llm.WithToolChoice(string, string, ...)` | No | Yes | Use `auto`, `any`, `none`, `required` or a function name. Only the first argument is used. | - | The tool to use for the model. |
+| `llm.WithToolKit(llm.ToolKit)` | Cannot be combined with streaming | Yes | Yes | - | The set of tools to use. |
+| `llm.WithStopSequence(string, string, ...)` | Yes | Yes | Yes | - | Stop generation if one of these tokens is detected. |
+| `llm.WithSystemPrompt(string)` | No | Yes | Yes | - | Set the system prompt for the model. |
+| `llm.WithSeed(uint64)` | Yes | Yes | Yes | - | The seed to use for random sampling. If set, different calls will generate deterministic results. |
+| `llm.WithFormat(string)` | Use `json` | Yes | Use `json_object` or `text` | - | The format of the response. For Mistral, you must also instruct the model to produce JSON yourself with a system or a user message. |
+| `llm.WithPresencePenalty(float64)` | Yes | No | Yes | - | Determines how much the model penalizes the repetition of words or phrases. A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative. |
+| `llm.WithFrequencyPenalty(float64)` | Yes | No | Yes | - | Penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition. |
+| `mistral.WithPrediction(string)` | No | No | Yes | - | Enable users to specify expected results, optimizing response times by leveraging known or predictable content. This approach is especially effective for updating text documents or code files with minimal changes, reducing latency while maintaining high-quality results. |
+| `llm.WithSafePrompt()` | No | No | Yes | - | Inject a safety prompt before all conversations. |
+| `llm.WithNumCompletions(uint64)` | No | No | Yes | - | Number of completions to return for each request. |
+| `llm.WithAttachment(io.Reader)` | Yes | Yes | Yes | - | Attach a file to a user prompt. It is the responsibility of the caller to close the reader. |
+| `anthropic.WithEphemeral()` | No | Yes | No | - | Attachments should be cached server-side. |
+| `anthropic.WithCitations()` | No | Yes | No | - | Attachments should be used in citations. |
+| `anthropic.WithUser(string)` | No | Yes | No | - | Indicate the user name for the request, for debugging. |
+
+## The Command Line Tool
+
+You can use the command-line tool to interact with the API. To install the tool, you can
+use the following command:
+
+```bash
+go install github.com/mutablelogic/go-llm/cmd/llm@latest
+llm --help
+```
+
+The output is something like:
+
+```text
+Usage: llm <command> [flags]
+
+LLM agent command line interface
+
+Flags:
+  -h, --help                     Show context-sensitive help.
+      --debug                    Enable debug output
+      --verbose                  Enable verbose output
+      --ollama-endpoint=STRING   Ollama endpoint ($OLLAMA_URL)
+      --anthropic-key=STRING     Anthropic API Key ($ANTHROPIC_API_KEY)
+      --mistral-key=STRING       Mistral API Key ($MISTRAL_API_KEY)
+      --news-key=STRING          News API Key ($NEWSAPI_KEY)
+
+Commands:
+  agents      Return a list of agents
+  models      Return a list of models
+  tools       Return a list of tools
+  download    Download a model
+  chat        Start a chat session
+
+Run "llm <command> --help" for more information on a command.
+```
+
 ## Contributing & Distribution
 
 *This module is currently in development and subject to change*. Please do file
diff --git a/agent.go b/agent.go
index b7658dd..1e75eea 100644
--- a/agent.go
+++ b/agent.go
@@ -11,4 +11,8 @@ type Agent interface {
 
 	// Return the models
 	Models(context.Context) ([]Model, error)
+
+	// Return a model by name, or nil if not found.
+	// Panics on error.
+	Model(context.Context, string) Model
 }
diff --git a/attachment.go b/attachment.go
index c7733c4..5987a9d 100644
--- a/attachment.go
+++ b/attachment.go
@@ -1,8 +1,13 @@
 package llm
 
 import (
+	"encoding/base64"
+	"encoding/json"
 	"io"
+	"mime"
+	"net/http"
 	"os"
+	"path/filepath"
 )
 
 ///////////////////////////////////////////////////////////////////////////////
@@ -31,6 +36,25 @@ func ReadAttachment(r io.Reader) (*Attachment, error) {
 	return &Attachment{filename: filename, data: data}, nil
 }
 
+////////////////////////////////////////////////////////////////////////////////
+// STRINGIFY
+
+func (a *Attachment) String() string {
+	var j struct {
+		Filename string `json:"filename"`
+		Type     string `json:"type"`
+		Bytes    uint64 `json:"bytes"`
+	}
+	j.Filename = a.filename
+	j.Type = a.Type()
+	j.Bytes = uint64(len(a.data))
+	data, err := json.MarshalIndent(j, "", "  ")
+	if err != nil {
+		return err.Error()
+	}
+	return string(data)
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // PUBLIC METHODS
 
@@ -41,3 +65,17 @@ func (a *Attachment) Filename() string {
 func (a *Attachment) Data() []byte {
 	return a.data
 }
+
+func (a *Attachment) Type() string {
+	// Mimetype based on content
+	mimetype := http.DetectContentType(a.data)
+	if mimetype == "application/octet-stream" && a.filename != "" {
+		// Detect mimetype from extension
+		mimetype = mime.TypeByExtension(filepath.Ext(a.filename))
+	}
+	return mimetype
+}
+
+func (a *Attachment) Url() string {
+	return "data:" + a.Type() + ";base64," + base64.StdEncoding.EncodeToString(a.data)
+}
diff --git a/cmd/agent/chat.go b/cmd/llm/chat.go
similarity index 69%
rename from cmd/agent/chat.go
rename to cmd/llm/chat.go
index 4067d75..bd14f6e 100644
--- a/cmd/agent/chat.go
+++ b/cmd/llm/chat.go
@@ -18,6 +18,8 @@ import (
 type ChatCmd struct {
 	Model    string `arg:"" help:"Model name"`
 	NoStream bool   `flag:"nostream" help:"Disable streaming"`
+	NoTools  bool   `flag:"notools" help:"Disable tool calling"`
+	Prompt   string `flag:"prompt" help:"Set the initial user prompt"`
 	System   string `flag:"system" help:"Set the system prompt"`
 }
 
@@ -39,16 +41,18 @@ func (cmd *ChatCmd) Run(globals *Globals) error {
 	// Set the options
 	opts := []llm.Opt{}
 	if !cmd.NoStream {
-		opts = append(opts, llm.WithStream(func(cc llm.ContextContent) {
-			if text := cc.Text(); text != "" {
-				fmt.Println(text)
+		opts = append(opts, llm.WithStream(func(cc llm.Completion) {
+			if text := cc.Text(0); text != "" {
+				count := strings.Count(text, "\n")
+				fmt.Print(strings.Repeat("\033[F", count) + strings.Repeat(" ", count) + "\r")
+				fmt.Print(text)
 			}
 		}))
 	}
 	if cmd.System != "" {
 		opts = append(opts, llm.WithSystemPrompt(cmd.System))
 	}
-	if globals.toolkit != nil {
+	if globals.toolkit != nil && !cmd.NoTools {
 		opts = append(opts, llm.WithToolKit(globals.toolkit))
 	}
 
@@ -57,11 +61,17 @@ func (cmd *ChatCmd) Run(globals *Globals) error {
 
 	// Continue looping until end of input
 	for {
-		input, err := globals.term.ReadLine(model.Name() + "> ")
-		if errors.Is(err, io.EOF) {
-			return nil
-		} else if err != nil {
-			return err
+		var input string
+		if cmd.Prompt != "" {
+			input = cmd.Prompt
+			cmd.Prompt = ""
+		} else {
+			input, err = globals.term.ReadLine(model.Name() + "> ")
+			if errors.Is(err, io.EOF) {
+				return nil
+			} else if err != nil {
+				return err
+			}
 		}
 
 		// Ignore empty input
@@ -77,12 +87,12 @@ func (cmd *ChatCmd) Run(globals *Globals) error {
 
-			// Repeat call tools until no more calls are made
+			// Repeatedly call tools until no more calls are made
 			for {
-				calls := session.ToolCalls()
+				calls := session.ToolCalls(0)
 				if len(calls) == 0 {
 					break
 				}
-				if session.Text() != "" {
-					globals.term.Println(session.Text())
+				if session.Text(0) != "" {
+					globals.term.Println(session.Text(0))
 				} else {
 					var names []string
 					for _, call := range calls {
@@ -98,7 +108,7 @@ func (cmd *ChatCmd) Run(globals *Globals) error {
 			}
 
 			// Print the response
-			globals.term.Println("\n" + session.Text() + "\n")
+			globals.term.Println("\n" + session.Text(0) + "\n")
 		}
 	})
 }
diff --git a/cmd/agent/main.go b/cmd/llm/main.go
similarity index 88%
rename from cmd/agent/main.go
rename to cmd/llm/main.go
index 8d0970c..205ad86 100644
--- a/cmd/agent/main.go
+++ b/cmd/llm/main.go
@@ -12,8 +12,8 @@ import (
 	client "github.com/mutablelogic/go-client"
 	llm "github.com/mutablelogic/go-llm"
 	agent "github.com/mutablelogic/go-llm/pkg/agent"
-	"github.com/mutablelogic/go-llm/pkg/newsapi"
-	"github.com/mutablelogic/go-llm/pkg/tool"
+	newsapi "github.com/mutablelogic/go-llm/pkg/newsapi"
+	tool "github.com/mutablelogic/go-llm/pkg/tool"
 )
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -27,6 +27,7 @@ type Globals struct {
 	// Agents
 	Ollama    `embed:"" help:"Ollama configuration"`
 	Anthropic `embed:"" help:"Anthropic configuration"`
+	Mistral   `embed:"" help:"Mistral configuration"`
 
 	// Tools
 	NewsAPI `embed:"" help:"NewsAPI configuration"`
@@ -46,6 +47,10 @@ type Anthropic struct {
 	AnthropicKey string `env:"ANTHROPIC_API_KEY" help:"Anthropic API Key"`
 }
 
+type Mistral struct {
+	MistralKey string `env:"MISTRAL_API_KEY" help:"Mistral API Key"`
+}
+
 type NewsAPI struct {
 	NewsKey string `env:"NEWSAPI_KEY" help:"News API Key"`
 }
@@ -102,8 +107,13 @@ func main() {
 	if cli.OllamaEndpoint != "" {
 		opts = append(opts, agent.WithOllama(cli.OllamaEndpoint, clientopts...))
 	}
-	if cli.AnthropicKey != "" {
-		opts = append(opts, agent.WithAnthropic(cli.AnthropicKey, clientopts...))
+	/*
+		if cli.AnthropicKey != "" {
+			opts = append(opts, agent.WithAnthropic(cli.AnthropicKey, clientopts...))
+		}
+	*/
+	if cli.MistralKey != "" {
+		opts = append(opts, agent.WithMistral(cli.MistralKey, clientopts...))
 	}
 
 	// Make a toolkit
diff --git a/cmd/agent/models.go b/cmd/llm/models.go
similarity index 98%
rename from cmd/agent/models.go
rename to cmd/llm/models.go
index 1bb96ee..e304507 100644
--- a/cmd/agent/models.go
+++ b/cmd/llm/models.go
@@ -60,7 +60,7 @@ func (*ListAgentsCmd) Run(globals *Globals) error {
 		return fmt.Errorf("No agents found")
 	}
 
-	var agents []string
+	agents := make([]string, 0, len(agent.Agents()))
 	for _, agent := range agent.Agents() {
 		agents = append(agents, agent.Name())
 	}
diff --git a/cmd/agent/term.go b/cmd/llm/term.go
similarity index 100%
rename from cmd/agent/term.go
rename to cmd/llm/term.go
diff --git a/context.go b/context.go
index cc8e83c..0aad95d 100644
--- a/context.go
+++ b/context.go
@@ -5,21 +5,29 @@ import "context"
 //////////////////////////////////////////////////////////////////
 // TYPES
 
-// ContextContent is the content of the last context message
-type ContextContent interface {
+// Completion is the content of the last context message
+type Completion interface {
+	// Return the number of completions, which is usually 1 unless
+	// WithNumCompletions was used when calling the model
+	Num() int
+
 	// Return the current session role, which can be system, assistant, user, tool, tool_result, ...
+ // If this is a completion, the role is usually 'assistant' Role() string - // Return the current session text, or empty string if no text was returned - Text() string + // Return the text for the last completion, with the argument as the + // completion index (usually 0). If multiple completions are not + // supported, the argument is ignored. + Text(int) string - // Return the current session tool calls, or empty if no tool calls were made - ToolCalls() []ToolCall + // Return the current session tool calls given the completion index. + // Will return nil if no tool calls were returned. + ToolCalls(int) []ToolCall } // Context is fed to the agent to generate a response type Context interface { - ContextContent + Completion // Generate a response from a user prompt (with attachments and // other options) diff --git a/etc/docker/Dockerfile b/etc/docker/Dockerfile index b612cda..3704fec 100644 --- a/etc/docker/Dockerfile +++ b/etc/docker/Dockerfile @@ -25,4 +25,4 @@ RUN apt update -y && apt install -y ca-certificates LABEL org.opencontainers.image.source=https://${SOURCE} # Entrypoint when running the server -ENTRYPOINT [ "/usr/local/bin/agent" ] +ENTRYPOINT [ "/usr/local/bin/llm" ] diff --git a/opt.go b/opt.go index df91705..a378a91 100644 --- a/opt.go +++ b/opt.go @@ -1,6 +1,7 @@ package llm import ( + "encoding/json" "io" "time" ) @@ -13,12 +14,13 @@ type Opt func(*Opts) error // set of options type Opts struct { - agents map[string]Agent // Set of agents - toolkit ToolKit // Toolkit for tools - callback func(ContextContent) // Streaming callback - attachments []*Attachment // Attachments - system string // System prompt - options map[string]any // Additional options + prompt bool + agents map[string]Agent // Set of agents + toolkit ToolKit // Toolkit for tools + callback func(Completion) // Streaming callback + attachments []*Attachment // Attachments + system string // System prompt + options map[string]any // Additional options } //////////////////////////////////////////////////////////////////////////////// @@ -26,7 +28,22 @@ type Opts struct { // ApplyOpts returns a structure of options func ApplyOpts(opts ...Opt) (*Opts, error) { + return applyOpts(false, opts...) 
+}
+
+// ApplyPromptOpts returns a structure of options for a prompt
+func ApplyPromptOpts(opts ...Opt) (*Opts, error) {
+	return applyOpts(true, opts...)
+}
+
+// applyOpts returns a structure of options, setting the prompt flag when the
+// options are for a prompt
+func applyOpts(prompt bool, opts ...Opt) (*Opts, error) {
 	o := new(Opts)
+	o.prompt = prompt
 	o.agents = make(map[string]Agent)
 	o.options = make(map[string]any)
 	for _, opt := range opts {
@@ -37,6 +54,33 @@ func ApplyOpts(opts ...Opt) (*Opts, error) {
 	}
 	return o, nil
 }
 
+///////////////////////////////////////////////////////////////////////////////
+// STRINGIFY
+
+func (o Opts) MarshalJSON() ([]byte, error) {
+	var j struct {
+		ToolKit     ToolKit          `json:"toolkit,omitempty"`
+		Agents      map[string]Agent `json:"agents,omitempty"`
+		System      string           `json:"system,omitempty"`
+		Attachments []*Attachment    `json:"attachments,omitempty"`
+		Options     map[string]any   `json:"options,omitempty"`
+	}
+	j.ToolKit = o.toolkit
+	j.Agents = o.agents
+	j.Attachments = o.attachments
+	j.System = o.system
+	j.Options = o.options
+	return json.Marshal(j)
+}
+
+func (o Opts) String() string {
+	data, err := json.Marshal(o)
+	if err != nil {
+		return err.Error()
+	}
+	return string(data)
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // PUBLIC METHODS - PROPERTIES
 
@@ -46,7 +90,7 @@ func (o *Opts) ToolKit() ToolKit {
 }
 
 // Return the stream function
-func (o *Opts) StreamFn() func(ContextContent) {
+func (o *Opts) StreamFn() func(Completion) {
 	return o.callback
 }
 
@@ -150,7 +194,7 @@ func WithToolKit(toolkit ToolKit) Opt {
 }
 
 // Set chat streaming function
-func WithStream(fn func(ContextContent)) Opt {
+func WithStream(fn func(Completion)) Opt {
 	return func(o *Opts) error {
 		o.callback = fn
 		return nil
@@ -182,6 +226,10 @@ func WithAgent(agent Agent) Opt {
 // Create an attachment
 func WithAttachment(r io.Reader) Opt {
 	return func(o *Opts) error {
+		// Attachments can only be added to a prompt; ignore for session options
+		if !o.prompt {
+			return nil
+		}
 		if attachment, err := ReadAttachment(r); err != nil {
 			return err
 		} else {
@@ -216,13 +264,41 @@ func WithTopP(v float64) Opt {
 
 // Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more
 // diverse answers, while a lower value (e.g. 10) will be more conservative.
-func WithTopK(v uint) Opt {
+func WithTopK(v uint64) Opt {
 	return func(o *Opts) error {
 		o.Set("top_k", v)
 		return nil
 	}
 }
 
+// Determines how much the model penalizes the repetition of words or phrases;
+// the value must be between -2.0 and 2.0
+func WithPresencePenalty(v float64) Opt {
+	return func(o *Opts) error {
+		if v < -2 || v > 2 {
+			return ErrBadParameter.With("presence_penalty")
+		}
+		o.Set("presence_penalty", v)
+		return nil
+	}
+}
+
+// Penalizes the repetition of words based on their frequency in the generated
+// text; the value must be between -2.0 and 2.0
+func WithFrequencyPenalty(v float64) Opt {
+	return func(o *Opts) error {
+		if v < -2 || v > 2 {
+			return ErrBadParameter.With("frequency_penalty")
+		}
+		o.Set("frequency_penalty", v)
+		return nil
+	}
+}
+
+// The maximum number of tokens to generate in the completion.
+func WithMaxTokens(v uint64) Opt { + return func(o *Opts) error { + o.Set("max_tokens", v) + return nil + } +} + // Set system prompt func WithSystemPrompt(v string) Opt { return func(o *Opts) error { @@ -230,3 +306,54 @@ func WithSystemPrompt(v string) Opt { return nil } } + +// Set stop sequence +func WithStopSequence(v ...string) Opt { + return func(o *Opts) error { + o.Set("stop", v) + return nil + } +} + +// Set random seed for deterministic behavior +func WithSeed(v uint64) Opt { + return func(o *Opts) error { + o.Set("seed", v) + return nil + } +} + +// Set format +func WithFormat(v any) Opt { + return func(o *Opts) error { + o.Set("format", v) + return nil + } +} + +// Set tool choices: can be auto, none, required, any or a list of tool names +func WithToolChoice(v ...string) Opt { + return func(o *Opts) error { + o.Set("tool_choice", v) + return nil + } +} + +// Number of completions to return for each request +func WithNumCompletions(v uint64) Opt { + return func(o *Opts) error { + if v < 1 || v > 8 { + return ErrBadParameter.With("num_completions must be between 1 and 8") + } + o.Set("num_completions", v) + return nil + } +} + +// Inject a safety prompt before all conversations. +func WithSafePrompt() Opt { + return func(o *Opts) error { + o.Set("safe_prompt", true) + return nil + } +} diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 312c327..cd50c87 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -105,6 +105,15 @@ func (a *Agent) Models(ctx context.Context) ([]llm.Model, error) { return a.ListModels(ctx) } +// Return a model +func (a *Agent) Model(ctx context.Context, name string) llm.Model { + model, err := a.GetModel(ctx, name) + if err != nil { + panic(err) + } + return model +} + // Return the models from list of agents func (a *Agent) ListModels(ctx context.Context, names ...string) ([]llm.Model, error) { var result error diff --git a/pkg/agent/opt.go b/pkg/agent/opt.go index a316881..36fe662 100644 --- a/pkg/agent/opt.go +++ b/pkg/agent/opt.go @@ -4,7 +4,7 @@ import ( // Packages client "github.com/mutablelogic/go-client" llm "github.com/mutablelogic/go-llm" - anthropic "github.com/mutablelogic/go-llm/pkg/anthropic" + mistral "github.com/mutablelogic/go-llm/pkg/mistral" ollama "github.com/mutablelogic/go-llm/pkg/ollama" ) @@ -22,9 +22,21 @@ func WithOllama(endpoint string, opts ...client.ClientOpt) llm.Opt { } } -func WithAnthropic(key string, opts ...client.ClientOpt) llm.Opt { +/* + func WithAnthropic(key string, opts ...client.ClientOpt) llm.Opt { + return func(o *llm.Opts) error { + client, err := anthropic.New(key, opts...) + if err != nil { + return err + } else { + return llm.WithAgent(client)(o) + } + } + } +*/ +func WithMistral(key string, opts ...client.ClientOpt) llm.Opt { return func(o *llm.Opts) error { - client, err := anthropic.New(key, opts...) + client, err := mistral.New(key, opts...) 
if err != nil { return err } else { diff --git a/pkg/anthropic/client.go b/pkg/anthropic/client.go index 6f4a13b..8bb617a 100644 --- a/pkg/anthropic/client.go +++ b/pkg/anthropic/client.go @@ -14,6 +14,7 @@ import ( type Client struct { *client.Client + cache map[string]llm.Model } var _ llm.Agent = (*Client)(nil) @@ -41,7 +42,10 @@ func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) { } // Return the client - return &Client{client}, nil + return &Client{ + Client: client, + cache: make(map[string]llm.Model), + }, nil } /////////////////////////////////////////////////////////////////////////////// diff --git a/pkg/anthropic/model.go b/pkg/anthropic/model.go index 288cb47..04baa49 100644 --- a/pkg/anthropic/model.go +++ b/pkg/anthropic/model.go @@ -34,7 +34,38 @@ type ModelMeta struct { // Agent interface func (anthropic *Client) Models(ctx context.Context) ([]llm.Model, error) { - return anthropic.ListModels(ctx) + // Cache models + if len(anthropic.cache) == 0 { + models, err := anthropic.ListModels(ctx) + if err != nil { + return nil, err + } + for _, model := range models { + name := model.Name() + anthropic.cache[name] = model + } + } + + // Return models + result := make([]llm.Model, 0, len(anthropic.cache)) + for _, model := range anthropic.cache { + result = append(result, model) + } + return result, nil +} + +// Agent interface +func (anthropic *Client) Model(ctx context.Context, model string) llm.Model { + // Cache models + if len(anthropic.cache) == 0 { + _, err := anthropic.Models(ctx) + if err != nil { + panic(err) + } + } + + // Return model + return anthropic.cache[model] } // Get a model by name diff --git a/pkg/anthropic/opt.go b/pkg/anthropic/opt.go index 3f64a42..5461b59 100644 --- a/pkg/anthropic/opt.go +++ b/pkg/anthropic/opt.go @@ -17,13 +17,6 @@ type optmetadata struct { //////////////////////////////////////////////////////////////////////////////// // OPTIONS -func WithMaxTokens(v uint) llm.Opt { - return func(o *llm.Opts) error { - o.Set("max_tokens", v) - return nil - } -} - func WithUser(v string) llm.Opt { return func(o *llm.Opts) error { o.Set("user", v) @@ -31,13 +24,6 @@ func WithUser(v string) llm.Opt { } } -func WithStopSequences(v ...string) llm.Opt { - return func(o *llm.Opts) error { - o.Set("stop", v) - return nil - } -} - func WithEphemeral() llm.Opt { return func(o *llm.Opts) error { o.Set("ephemeral", true) diff --git a/pkg/anthropic/session_test.go b/pkg/anthropic/session_test.go index 78c01de..e27c078 100644 --- a/pkg/anthropic/session_test.go +++ b/pkg/anthropic/session_test.go @@ -83,9 +83,10 @@ func Test_session_002(t *testing.T) { t.FailNow() } - err := toolkit.Run(context.TODO(), session.ToolCalls()...) + result, err := toolkit.Run(context.TODO(), session.ToolCalls()...) 
 		if !assert.NoError(err) {
 			t.FailNow()
 		}
+		assert.NotEmpty(result)
 	})
 }
diff --git a/pkg/mistral/chat_completion.go b/pkg/mistral/chat_completion.go
new file mode 100644
index 0000000..b0e4bb0
--- /dev/null
+++ b/pkg/mistral/chat_completion.go
@@ -0,0 +1,200 @@
+package mistral
+
+import (
+	"context"
+	"encoding/json"
+	"strings"
+
+	"github.com/mutablelogic/go-client"
+	"github.com/mutablelogic/go-llm"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+// Chat Completion Response
+type Response struct {
+	Id          string `json:"id"`
+	Type        string `json:"object"`
+	Created     uint64 `json:"created"`
+	Model       string `json:"model"`
+	Completions `json:"choices"`
+	Metrics     `json:"usage,omitempty"`
+}
+
+// Metrics
+type Metrics struct {
+	InputTokens  uint64 `json:"prompt_tokens,omitempty"`
+	OutputTokens uint64 `json:"completion_tokens,omitempty"`
+	TotalTokens  uint64 `json:"total_tokens,omitempty"`
+}
+
+var _ llm.Completion = (*Response)(nil)
+
+///////////////////////////////////////////////////////////////////////////////
+// STRINGIFY
+
+func (r Response) String() string {
+	data, err := json.MarshalIndent(r, "", "  ")
+	if err != nil {
+		return err.Error()
+	}
+	return string(data)
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+type reqChatCompletion struct {
+	Model            string     `json:"model"`
+	Temperature      float64    `json:"temperature,omitempty"`
+	TopP             float64    `json:"top_p,omitempty"`
+	MaxTokens        uint64     `json:"max_tokens,omitempty"`
+	Stream           bool       `json:"stream,omitempty"`
+	StopSequences    []string   `json:"stop,omitempty"`
+	Seed             uint64     `json:"random_seed,omitempty"`
+	Messages         []*Message `json:"messages"`
+	Format           any        `json:"response_format,omitempty"`
+	Tools            []llm.Tool `json:"tools,omitempty"`
+	ToolChoice       any        `json:"tool_choice,omitempty"`
+	PresencePenalty  float64    `json:"presence_penalty,omitempty"`
+	FrequencyPenalty float64    `json:"frequency_penalty,omitempty"`
+	NumChoices       uint64     `json:"n,omitempty"`
+	Prediction       *Content   `json:"prediction,omitempty"`
+	SafePrompt       bool       `json:"safe_prompt,omitempty"`
+}
+
+func (mistral *Client) ChatCompletion(ctx context.Context, chat llm.Context, opts ...llm.Opt) (*Response, error) {
+	// Apply options
+	opt, err := llm.ApplyOpts(opts...)
+	if err != nil {
+		return nil, err
+	}
+
+	// Append the system prompt at the beginning
+	messages := make([]*Message, 0, len(chat.(*session).seq)+1)
+	if system := opt.SystemPrompt(); system != "" {
+		messages = append(messages, systemPrompt(system))
+	}
+
+	// Append the previous messages from the session
+	for _, message := range chat.(*session).seq {
+		messages = append(messages, message)
+	}
+
+	// Request
+	req, err := client.NewJSONRequest(reqChatCompletion{
+		Model:            chat.(*session).model.Name(),
+		Temperature:      optTemperature(opt),
+		TopP:             optTopP(opt),
+		MaxTokens:        optMaxTokens(opt),
+		Stream:           optStream(opt),
+		StopSequences:    optStopSequences(opt),
+		Seed:             optSeed(opt),
+		Messages:         messages,
+		Format:           optFormat(opt),
+		Tools:            optTools(mistral, opt),
+		ToolChoice:       optToolChoice(opt),
+		PresencePenalty:  optPresencePenalty(opt),
+		FrequencyPenalty: optFrequencyPenalty(opt),
+		NumChoices:       optNumCompletions(opt),
+		Prediction:       optPrediction(opt),
+		SafePrompt:       optSafePrompt(opt),
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	var response Response
+	reqopts := []client.RequestOpt{
+		client.OptPath("chat", "completions"),
+	}
+	if optStream(opt) {
+		reqopts = append(reqopts, client.OptTextStreamCallback(func(evt client.TextStreamEvent) error {
+			if err := streamEvent(&response, evt); err != nil {
+				return err
+			}
+			if fn := opt.StreamFn(); fn != nil {
+				fn(&response)
+			}
+			return nil
+		}))
+	}
+
+	// Response
+	if err := mistral.DoWithContext(ctx, req, &response, reqopts...); err != nil {
+		return nil, err
+	}
+
+	// Return success
+	return &response, nil
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PRIVATE METHODS
+
+func streamEvent(response *Response, evt client.TextStreamEvent) error {
+	var delta Response
+	// If we are done, ignore
+	if strings.TrimSpace(evt.Data) == "[DONE]" {
+		return nil
+	}
+	// Decode the event
+	if err := evt.Json(&delta); err != nil {
+		return err
+	}
+	// Append the delta to the response
+	if delta.Id != "" {
+		response.Id = delta.Id
+	}
+	if delta.Created != 0 {
+		response.Created = delta.Created
+	}
+	if delta.Model != "" {
+		response.Model = delta.Model
+	}
+	for _, completion := range delta.Completions {
+		appendCompletion(response, &completion)
+	}
+	if delta.Metrics.InputTokens > 0 {
+		response.Metrics.InputTokens += delta.Metrics.InputTokens
+	}
+	if delta.Metrics.OutputTokens > 0 {
+		response.Metrics.OutputTokens += delta.Metrics.OutputTokens
+	}
+	if delta.Metrics.TotalTokens > 0 {
+		response.Metrics.TotalTokens += delta.Metrics.TotalTokens
+	}
+	return nil
+}
+
+func appendCompletion(response *Response, c *Completion) {
+	// Grow the set of completions until it includes the completion index
+	for c.Index >= uint64(len(response.Completions)) {
+		response.Completions = append(response.Completions, Completion{
+			Index: c.Index,
+			Message: &Message{
+				RoleContent: RoleContent{
+					Role:    c.Delta.Role(),
+					Content: "",
+				},
+			},
+		})
+	}
+	// Add the completion delta
+	if c.Reason != "" {
+		response.Completions[c.Index].Reason = c.Reason
+	}
+	if role := c.Delta.Role(); role != "" {
+		response.Completions[c.Index].Message.RoleContent.Role = role
+	}
+
+	// TODO: We only allow deltas which are strings at the moment...
+ if str, ok := c.Delta.Content.(string); ok && str != "" { + if text, ok := response.Completions[c.Index].Message.Content.(string); ok { + response.Completions[c.Index].Message.Content = text + str + } + } +} diff --git a/pkg/mistral/chat_completion_test.go b/pkg/mistral/chat_completion_test.go new file mode 100644 index 0000000..41afb88 --- /dev/null +++ b/pkg/mistral/chat_completion_test.go @@ -0,0 +1,239 @@ +package mistral_test + +import ( + "context" + "fmt" + "os" + "strings" + "testing" + + // Packages + + "github.com/mutablelogic/go-llm" + mistral "github.com/mutablelogic/go-llm/pkg/mistral" + "github.com/mutablelogic/go-llm/pkg/tool" + assert "github.com/stretchr/testify/assert" +) + +func Test_chat_001(t *testing.T) { + assert := assert.New(t) + model := client.Model(context.TODO(), "mistral-small-latest") + + if assert.NotNil(model) { + response, err := client.ChatCompletion(context.TODO(), model.UserPrompt("Hello, how are you?")) + assert.NoError(err) + assert.NotEmpty(response) + t.Log(response) + } +} + +func Test_chat_002(t *testing.T) { + assert := assert.New(t) + model := client.Model(context.TODO(), "mistral-large-latest") + if !assert.NotNil(model) { + t.FailNow() + } + + t.Run("Temperature", func(t *testing.T) { + r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithTemperature(0.5)) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + t.Run("TopP", func(t *testing.T) { + r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithTopP(0.5)) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + t.Run("MaxTokens", func(t *testing.T) { + r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithMaxTokens(10)) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + t.Run("Stream", func(t *testing.T) { + r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithNumCompletions(2), llm.WithStream(func(r llm.Completion) { + t.Log(r.Role(), "=>", r.Text(0)) + })) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(2, r.Num()) + assert.NotEmpty(r.Text(0)) + assert.NotEmpty(r.Text(1)) + t.Log(r) + } + }) + t.Run("Stop", func(t *testing.T) { + r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithStopSequence("STOP")) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + t.Run("System", func(t *testing.T) { + r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithSystemPrompt("You are shakespearian")) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + t.Run("Seed", func(t *testing.T) { + r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithSeed(123)) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + t.Run("Format", func(t *testing.T) { + 
r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithFormat("json_object"), llm.WithSystemPrompt("Return a JSON object"))
+		if assert.NoError(err) {
+			assert.Equal("assistant", r.Role())
+			assert.Equal(1, r.Num())
+			assert.NotEmpty(r.Text(0))
+			t.Log(r)
+		}
+	})
+	t.Run("ToolChoiceAuto", func(t *testing.T) {
+		r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithToolChoice("auto"))
+		if assert.NoError(err) {
+			assert.Equal("assistant", r.Role())
+			assert.Equal(1, r.Num())
+			assert.NotEmpty(r.Text(0))
+			t.Log(r)
+		}
+	})
+	t.Run("ToolChoiceFunc", func(t *testing.T) {
+		r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithToolChoice("get_weather"))
+		if assert.NoError(err) {
+			assert.Equal("assistant", r.Role())
+			assert.Equal(1, r.Num())
+			assert.NotEmpty(r.Text(0))
+			t.Log(r)
+		}
+	})
+	t.Run("PresencePenalty", func(t *testing.T) {
+		r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithPresencePenalty(-2))
+		if assert.NoError(err) {
+			assert.Equal("assistant", r.Role())
+			assert.Equal(1, r.Num())
+			assert.NotEmpty(r.Text(0))
+			t.Log(r)
+		}
+	})
+	t.Run("FrequencyPenalty", func(t *testing.T) {
+		r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithFrequencyPenalty(-2))
+		if assert.NoError(err) {
+			assert.Equal("assistant", r.Role())
+			assert.Equal(1, r.Num())
+			assert.NotEmpty(r.Text(0))
+			t.Log(r)
+		}
+	})
+	t.Run("NumChoices", func(t *testing.T) {
+		r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithNumCompletions(3))
+		if assert.NoError(err) {
+			assert.Equal("assistant", r.Role())
+			assert.Equal(3, r.Num())
+			assert.NotEmpty(r.Text(0))
+			t.Log(r)
+		}
+	})
+	t.Run("Prediction", func(t *testing.T) {
+		r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), mistral.WithPrediction("The temperature in London today is"))
+		if assert.NoError(err) {
+			assert.Equal("assistant", r.Role())
+			assert.Equal(1, r.Num())
+			assert.NotEmpty(r.Text(0))
+			t.Log(r)
+		}
+	})
+	t.Run("SafePrompt", func(t *testing.T) {
+		r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithSafePrompt())
+		if assert.NoError(err) {
+			assert.Equal("assistant", r.Role())
+			assert.Equal(1, r.Num())
+			assert.NotEmpty(r.Text(0))
+			t.Log(r)
+		}
+	})
+}
+
+func Test_chat_003(t *testing.T) {
+	assert := assert.New(t)
+	model := client.Model(context.TODO(), "pixtral-12b-2409")
+	if !assert.NotNil(model) {
+		t.FailNow()
+	}
+
+	f, err := os.Open("testdata/guggenheim.jpg")
+	if !assert.NoError(err) {
+		t.FailNow()
+	}
+	defer f.Close()
+
+	// Describe an image
+	r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("Provide a short caption for this image", llm.WithAttachment(f)))
+	if assert.NoError(err) {
+		assert.Equal("assistant", r.Role())
+		assert.Equal(1, r.Num())
+		assert.NotEmpty(r.Text(0))
+		t.Log(r.Text(0))
+	}
+}
+
+func Test_chat_004(t *testing.T) {
+	assert := assert.New(t)
+	model := client.Model(context.TODO(), "mistral-small-latest")
+	if !assert.NotNil(model) {
+		t.FailNow()
+	}
+
+	toolkit := tool.NewToolKit()
+	toolkit.Register(&weather{})
+
+	// Get the weather for a city
+	r, err := client.ChatCompletion(context.TODO(), model.UserPrompt("What is the weather in the capital city of Germany?"), llm.WithToolKit(toolkit))
+	if assert.NoError(err) {
+		assert.Equal("assistant", r.Role())
+		assert.Equal(1, r.Num())
+
+		calls := r.ToolCalls(0)
+		assert.NotEmpty(calls)
+
+		var w weather
+		assert.NoError(calls[0].Decode(&w))
+		assert.Equal("berlin", strings.ToLower(w.City))
+	}
+}
+
+type weather struct {
+	City string `json:"city" help:"The city to get the weather for"`
+}
+
+func (weather) Name() string {
+	return "weather_in_city"
+}
+
+func (weather) Description() string {
+	return "Get the weather for a city"
+}
+
+func (w weather) Run(ctx context.Context) (any, error) {
+	return fmt.Sprintf("The weather in %q is sunny and warm", w.City), nil
+}
diff --git a/pkg/mistral/client.go b/pkg/mistral/client.go
new file mode 100644
index 0000000..7e643dc
--- /dev/null
+++ b/pkg/mistral/client.go
@@ -0,0 +1,91 @@
+/*
+Package mistral implements an API client for Mistral (https://docs.mistral.ai/api/)
+*/
+package mistral
+
+import (
+	"context"
+
+	// Packages
+	"github.com/mutablelogic/go-client"
+	"github.com/mutablelogic/go-llm"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+type Client struct {
+	*client.Client
+	cache map[string]llm.Model
+}
+
+var _ llm.Agent = (*Client)(nil)
+
+///////////////////////////////////////////////////////////////////////////////
+// GLOBALS
+
+const (
+	endPoint    = "https://api.mistral.ai/v1"
+	defaultName = "mistral"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// LIFECYCLE
+
+// Create a new client
+func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) {
+	// Create client
+	opts = append(opts, client.OptEndpoint(endPoint))
+	opts = append(opts, client.OptReqToken(client.Token{
+		Scheme: client.Bearer,
+		Value:  ApiKey,
+	}))
+	client, err := client.New(opts...)
+	if err != nil {
+		return nil, err
+	}
+
+	// Return the client
+	return &Client{client, nil}, nil
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+// Return the name of the agent
+func (Client) Name() string {
+	return defaultName
+}
+
+// Return the models
+func (c *Client) Models(ctx context.Context) ([]llm.Model, error) {
+	// Cache models
+	if c.cache == nil {
+		models, err := c.ListModels(ctx)
+		if err != nil {
+			return nil, err
+		}
+		c.cache = make(map[string]llm.Model, len(models))
+		for _, model := range models {
+			c.cache[model.Name()] = model
+		}
+	}
+
+	// Return models
+	result := make([]llm.Model, 0, len(c.cache))
+	for _, model := range c.cache {
+		result = append(result, model)
+	}
+	return result, nil
+}
+
+// Return a model by name, or nil if not found.
+// Panics on error.
+func (c *Client) Model(ctx context.Context, name string) llm.Model {
+	if c.cache == nil {
+		if _, err := c.Models(ctx); err != nil {
+			panic(err)
+		}
+	}
+	return c.cache[name]
+}
diff --git a/pkg/mistral/client_test.go b/pkg/mistral/client_test.go
new file mode 100644
index 0000000..93fed56
--- /dev/null
+++ b/pkg/mistral/client_test.go
@@ -0,0 +1,58 @@
+package mistral_test
+
+import (
+	"flag"
+	"log"
+	"os"
+	"strconv"
+	"testing"
+
+	// Packages
+	opts "github.com/mutablelogic/go-client"
+	mistral "github.com/mutablelogic/go-llm/pkg/mistral"
+	assert "github.com/stretchr/testify/assert"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TEST SET-UP
+
+var (
+	client *mistral.Client
+)
+
+func TestMain(m *testing.M) {
+	var verbose bool
+
+	// Verbose output
+	flag.Parse()
+	if f := flag.Lookup("test.v"); f != nil {
+		if v, err := strconv.ParseBool(f.Value.String()); err == nil {
+			verbose = v
+		}
+	}
+
+	// API KEY
+	api_key := os.Getenv("MISTRAL_API_KEY")
+	if api_key == "" {
+		log.Print("MISTRAL_API_KEY not set")
+		os.Exit(0)
+	}
+
+	// Create client
+	var err error
+	client, err = mistral.New(api_key, opts.OptTrace(os.Stderr, verbose))
+	if err != nil {
+		log.Println(err)
+		os.Exit(-1)
+	}
+	os.Exit(m.Run())
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// TESTS
+
+func Test_client_001(t *testing.T) {
+	assert := assert.New(t)
+	assert.NotNil(client)
+	t.Log(client)
+}
diff --git a/pkg/mistral/embeddings.go b/pkg/mistral/embeddings.go
new file mode 100644
index 0000000..48555f3
--- /dev/null
+++ b/pkg/mistral/embeddings.go
@@ -0,0 +1,101 @@
+package mistral
+
+import (
+	"context"
+	"encoding/json"
+
+	// Packages
+	client "github.com/mutablelogic/go-client"
+	llm "github.com/mutablelogic/go-llm"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+// embeddings is the implementation of the llm.Embedding interface
+type embeddings struct {
+	Embeddings
+}
+
+// Embeddings is the metadata for a generated embedding vector
+type Embeddings struct {
+	Id    string      `json:"id"`
+	Type  string      `json:"object"`
+	Model string      `json:"model"`
+	Data  []Embedding `json:"data"`
+	Metrics
+}
+
+// Embedding is a single vector
+type Embedding struct {
+	Type   string    `json:"object"`
+	Index  uint64    `json:"index"`
+	Vector []float64 `json:"embedding"`
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// STRINGIFY
+
+func (m Embedding) MarshalJSON() ([]byte, error) {
+	return json.Marshal(m.Vector)
+}
+
+func (m embeddings) MarshalJSON() ([]byte, error) {
+	return json.Marshal(m.Embeddings)
+}
+
+func (m embeddings) String() string {
+	data, err := json.MarshalIndent(m, "", "  ")
+	if err != nil {
+		return err.Error()
+	}
+	return string(data)
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+type reqEmbedding struct {
+	Model  string   `json:"model"`
+	Input  []string `json:"input"`
+	Format string   `json:"encoding_format,omitempty"`
+}
+
+func (mistral *Client) GenerateEmbedding(ctx context.Context, name string, prompt []string, _ ...llm.Opt) (*embeddings, error) {
+	// Options are currently ignored
+
+	// Bail out if no prompt was provided
+	if len(prompt) == 0 {
+		return nil, llm.ErrBadParameter.With("missing prompt")
+	}
+
+	// Request
+	req, err := client.NewJSONRequest(reqEmbedding{
+		Model: name,
+		Input: prompt,
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	// Response
+	var response embeddings
+	if err := mistral.DoWithContext(ctx, req, &response, client.OptPath("embeddings")); err != nil {
+		return nil, err
+	}
+
+	// Return success
+	return &response, nil
+}
+
+// Generate one vector
+func (model *model) Embedding(ctx context.Context, prompt string, opts ...llm.Opt) ([]float64, error) {
+	response, err := model.GenerateEmbedding(ctx, model.Name(), []string{prompt}, opts...)
+	if err != nil {
+		return nil, err
+	}
+	if len(response.Embeddings.Data) == 0 {
+		return nil, llm.ErrNotFound.With("no embeddings returned")
+	}
+	return response.Embeddings.Data[0].Vector, nil
+}
diff --git a/pkg/mistral/embeddings_test.go b/pkg/mistral/embeddings_test.go
new file mode 100644
index 0000000..bc09454
--- /dev/null
+++ b/pkg/mistral/embeddings_test.go
@@ -0,0 +1,20 @@
+package mistral_test
+
+import (
+	"context"
+	"testing"
+
+	// Packages
+	assert "github.com/stretchr/testify/assert"
+)
+
+func Test_embeddings_001(t *testing.T) {
+	assert := assert.New(t)
+	model := client.Model(context.TODO(), "mistral-embed")
+	if assert.NotNil(model) {
+		response, err := model.Embedding(context.TODO(), "Hello, how are you?")
+		assert.NoError(err)
+		assert.NotEmpty(response)
+		t.Log(response)
+	}
+}
diff --git a/pkg/mistral/message.go b/pkg/mistral/message.go
new file mode 100644
index 0000000..6300b9e
--- /dev/null
+++ b/pkg/mistral/message.go
@@ -0,0 +1,175 @@
+package mistral
+
+import (
+	"encoding/json"
+
+	// Packages
+	"github.com/mutablelogic/go-llm"
+	"github.com/mutablelogic/go-llm/pkg/tool"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+// The set of possible completions
+type Completions []Completion
+
+var _ llm.Completion = Completions{}
+
+// Message with text or object content
+type Message struct {
+	RoleContent
+	ToolCallArray `json:"tool_calls,omitempty"`
+}
+
+type RoleContent struct {
+	Role    string `json:"role,omitempty"`         // assistant, user, tool, system
+	Content any    `json:"content,omitempty"`      // string or array of text, reference, image_url
+	Id      string `json:"tool_call_id,omitempty"` // tool call - when role is tool
+	Name    string `json:"name,omitempty"`         // function name - when role is tool
+}
+
+// Completion variation
+type Completion struct {
+	Index   uint64   `json:"index"`
+	Message *Message `json:"message"`
+	Delta   *Message `json:"delta,omitempty"` // For streaming
+	Reason  string   `json:"finish_reason,omitempty"`
+}
+
+var _ llm.Completion = (*Message)(nil)
+
+type Content struct {
+	Type        string `json:"type,omitempty"` // text, reference, image_url
+	*Text       `json:"text,omitempty"`        // text content
+	*Prediction `json:"content,omitempty"`     // prediction content
+	*Image      `json:"image_url,omitempty"`   // image_url content
+}
+
+// A set of tool calls
+type ToolCallArray []ToolCall
+
+// Text content
+type Text string
+
+// Predicted output content
+type Prediction string
+
+// Image content: either a URL, or "data:image/png;base64," followed by the
+// base64-encoded image
+type Image string
+
+///////////////////////////////////////////////////////////////////////////////
+// LIFECYCLE
+
+// Return a Content object with text content (either in the "text" or the
+// "prediction" field)
+func NewContent(t, v, p string) *Content {
+	content := new(Content)
+	content.Type = t
+	if v != "" {
+		content.Text = (*Text)(&v)
+	}
+	if p != "" {
+		content.Prediction = (*Prediction)(&p)
+	}
+	return content
+}
+
+// Return a Content object with text content
+func NewTextContent(v string) *Content {
+	return NewContent("text", v, "")
+}
+
+// Return an image attachment
+func NewImageAttachment(a *llm.Attachment)
*Content { + content := new(Content) + image := a.Url() + content.Type = "image_url" + content.Image = (*Image)(&image) + return content +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - MESSAGE + +func (m Message) Num() int { + return 1 +} + +func (m Message) Role() string { + return m.RoleContent.Role +} + +func (m Message) Text(index int) string { + if index != 0 { + return "" + } + // If content is text, return it + if text, ok := m.Content.(string); ok { + return text + } + // For other kinds, return empty string for the moment + return "" +} + +func (m Message) ToolCalls(index int) []llm.ToolCall { + if index != 0 { + return nil + } + + // Make the tool calls + calls := make([]llm.ToolCall, 0, len(m.ToolCallArray)) + for _, call := range m.ToolCallArray { + var args map[string]any + if call.Function.Arguments != "" { + if err := json.Unmarshal([]byte(call.Function.Arguments), &args); err != nil { + return nil + } + } + calls = append(calls, tool.NewCall(call.Id, call.Function.Name, args)) + } + + // Return success + return calls +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - COMPLETIONS + +// Return the number of completions +func (c Completions) Num() int { + return len(c) +} + +// Return message for a specific completion +func (c Completions) Message(index int) *Message { + if index < 0 || index >= len(c) { + return nil + } + return c[index].Message +} + +// Return the role of the completion +func (c Completions) Role() string { + // The role should be the same for all completions, let's use the first one + if len(c) == 0 { + return "" + } + return c[0].Message.Role() +} + +// Return the text content for a specific completion +func (c Completions) Text(index int) string { + if index < 0 || index >= len(c) { + return "" + } + return c[index].Message.Text(0) +} + +// Return the current session tool calls given the completion index. +// Will return nil if no tool calls were returned. 
+func (c Completions) ToolCalls(index int) []llm.ToolCall { + if index < 0 || index >= len(c) { + return nil + } + return c[index].Message.ToolCalls(0) +} diff --git a/pkg/mistral/model.go b/pkg/mistral/model.go new file mode 100644 index 0000000..24420b6 --- /dev/null +++ b/pkg/mistral/model.go @@ -0,0 +1,82 @@ +package mistral + +import ( + "context" + "encoding/json" + + "github.com/mutablelogic/go-client" + "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type model struct { + *Client `json:"-"` + meta Model +} + +type Model struct { + Name string `json:"id"` + Description string `json:"description,omitempty"` + Type string `json:"type,omitempty"` + CreatedAt *uint64 `json:"created,omitempty"` + OwnedBy string `json:"owned_by,omitempty"` + MaxContextLength uint64 `json:"max_context_length,omitempty"` + Aliases []string `json:"aliases,omitempty"` + Deprecation *string `json:"deprecation,omitempty"` + DefaultModelTemperature *float64 `json:"default_model_temperature,omitempty"` + Capabilities struct { + CompletionChat bool `json:"completion_chat,omitempty"` + CompletionFim bool `json:"completion_fim,omitempty"` + FunctionCalling bool `json:"function_calling,omitempty"` + FineTuning bool `json:"fine_tuning,omitempty"` + Vision bool `json:"vision,omitempty"` + } `json:"capabilities,omitempty"` +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (m model) MarshalJSON() ([]byte, error) { + return json.Marshal(m.meta) +} + +func (m model) String() string { + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - API + +// ListModels returns all the models +func (c *Client) ListModels(ctx context.Context) ([]llm.Model, error) { + // Response + var response struct { + Data []Model `json:"data"` + } + if err := c.DoWithContext(ctx, nil, &response, client.OptPath("models")); err != nil { + return nil, err + } + + // Make models + result := make([]llm.Model, 0, len(response.Data)) + for _, meta := range response.Data { + result = append(result, &model{c, meta}) + } + + // Return models + return result, nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - MODEL + +// Return the name of the model +func (m model) Name() string { + return m.meta.Name +} diff --git a/pkg/mistral/model_test.go b/pkg/mistral/model_test.go new file mode 100644 index 0000000..812be6e --- /dev/null +++ b/pkg/mistral/model_test.go @@ -0,0 +1,23 @@ +package mistral_test + +import ( + "context" + "encoding/json" + "testing" + + // Packages + + assert "github.com/stretchr/testify/assert" +) + +func Test_models_001(t *testing.T) { + assert := assert.New(t) + + response, err := client.ListModels(context.TODO()) + assert.NoError(err) + assert.NotEmpty(response) + + data, err := json.MarshalIndent(response, "", " ") + assert.NoError(err) + t.Log(string(data)) +} diff --git a/pkg/mistral/opt.go b/pkg/mistral/opt.go new file mode 100644 index 0000000..71e82da --- /dev/null +++ b/pkg/mistral/opt.go @@ -0,0 +1,120 @@ +package mistral + +import ( + "strings" + + "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func WithPrediction(v string) llm.Opt { + return func(o *llm.Opts) error { + 
o.Set("prediction", v) + return nil + } +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func optTemperature(opts *llm.Opts) float64 { + return opts.GetFloat64("temperature") +} + +func optTopP(opts *llm.Opts) float64 { + return opts.GetFloat64("top_p") +} + +func optMaxTokens(opts *llm.Opts) uint64 { + return opts.GetUint64("max_tokens") +} + +func optStream(opts *llm.Opts) bool { + return opts.StreamFn() != nil +} + +func optStopSequences(opts *llm.Opts) []string { + if opts.Has("stop") { + if stop, ok := opts.Get("stop").([]string); ok { + return stop + } + } + return nil +} + +func optSeed(opts *llm.Opts) uint64 { + return opts.GetUint64("seed") +} + +func optFormat(opts *llm.Opts) any { + var fmt struct { + Type string `json:"type"` + } + format := opts.GetString("format") + if format == "" { + return nil + } else { + fmt.Type = format + } + return fmt +} + +func optTools(agent llm.Agent, opts *llm.Opts) []llm.Tool { + toolkit := opts.ToolKit() + if toolkit == nil { + return nil + } + return toolkit.Tools(agent) +} + +func optToolChoice(opts *llm.Opts) any { + choices, ok := opts.Get("tool_choice").([]string) + if !ok || len(choices) == 0 { + return nil + } + + // We only support one choice + choice := strings.TrimSpace(strings.ToLower(choices[0])) + switch choice { + case "auto", "none", "any", "required": + return choice + case "": + return nil + default: + var fn struct { + Type string `json:"type"` + Function struct { + Name string `json:"name"` + } `json:"function"` + } + fn.Type = "function" + fn.Function.Name = choice + return fn + } +} + +func optPresencePenalty(opts *llm.Opts) float64 { + return opts.GetFloat64("presence_penalty") +} + +func optFrequencyPenalty(opts *llm.Opts) float64 { + return opts.GetFloat64("frequency_penalty") +} + +func optNumCompletions(opts *llm.Opts) uint64 { + return opts.GetUint64("num_completions") +} + +func optPrediction(opts *llm.Opts) *Content { + prediction := strings.TrimSpace(opts.GetString("prediction")) + if prediction == "" { + return nil + } + return NewContent("content", "", prediction) +} + +func optSafePrompt(opts *llm.Opts) bool { + return opts.GetBool("safe_prompt") +} diff --git a/pkg/mistral/session.go b/pkg/mistral/session.go new file mode 100644 index 0000000..3b50539 --- /dev/null +++ b/pkg/mistral/session.go @@ -0,0 +1,219 @@ +package mistral + +import ( + "context" + "encoding/json" + + // Packages + llm "github.com/mutablelogic/go-llm" +) + +////////////////////////////////////////////////////////////////// +// TYPES + +type session struct { + model *model // The model used for the session + opts []llm.Opt // Options to apply to the session + seq []*Message // Sequence of messages +} + +var _ llm.Context = (*session)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Return an empty session context object for the model, setting session options +func (model *model) Context(opts ...llm.Opt) llm.Context { + return &session{ + model: model, + opts: opts, + seq: make([]*Message, 0, 10), + } +} + +// Convenience method to create a session context object with a user prompt, which +// panics on error +func (model *model) UserPrompt(prompt string, opts ...llm.Opt) llm.Context { + context := model.Context(opts...) + + // Create a user prompt + message, err := userPrompt(prompt, opts...) 
+ if err != nil { + panic(err) + } + + // Add to the sequence + context.(*session).seq = append(context.(*session).seq, message) + + // Return success + return context +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (session session) String() string { + var data []byte + var err error + if len(session.seq) == 1 { + data, err = json.MarshalIndent(session.seq[0], "", " ") + } else { + data, err = json.MarshalIndent(session.seq, "", " ") + } + if err != nil { + return err.Error() + } + return string(data) +} + +////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return the number of completions +func (session *session) Num() int { + if len(session.seq) == 0 { + return 0 + } + return 1 +} + +// Return the role of the last message +func (session *session) Role() string { + if len(session.seq) == 0 { + return "" + } + return session.seq[len(session.seq)-1].Role() +} + +// Return the text of the last message +func (session *session) Text(index int) string { + if len(session.seq) == 0 { + return "" + } + return session.seq[len(session.seq)-1].Text(index) +} + +// Return tool calls for the last message +func (session *session) ToolCalls(index int) []llm.ToolCall { + if len(session.seq) == 0 { + return nil + } + return session.seq[len(session.seq)-1].ToolCalls(index) +} + +// Generate a response from a user prompt (with attachments and +// other options) +func (session *session) FromUser(ctx context.Context, prompt string, opts ...llm.Opt) error { + message, err := userPrompt(prompt, opts...) + if err != nil { + return err + } + + // Append the user prompt to the sequence + session.seq = append(session.seq, message) + + // The options come from the session options and the user options + chatopts := make([]llm.Opt, 0, len(session.opts)+len(opts)) + chatopts = append(chatopts, session.opts...) + chatopts = append(chatopts, opts...) + + // Call the 'chat completion' method + r, err := session.model.ChatCompletion(ctx, session, chatopts...) + if err != nil { + return err + } + + // Append the first message from the set of completions + session.seq = append(session.seq, r.Completions.Message(0)) + + // Return success + return nil +} + +// Generate a response from a tool, passing the results from the tool call +func (session *session) FromTool(ctx context.Context, results ...llm.ToolResult) error { + messages, err := toolResults(results...) + if err != nil { + return err + } + + // Append the tool results to the sequence + session.seq = append(session.seq, messages...) + + // Call the 'chat' method + r, err := session.model.ChatCompletion(ctx, session, session.opts...) + if err != nil { + return err + } + + // Append the first message from the set of completions + session.seq = append(session.seq, r.Completions.Message(0)) + + // Return success + return nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func systemPrompt(prompt string) *Message { + return &Message{ + RoleContent: RoleContent{ + Role: "system", + Content: prompt, + }, + } +} + +func userPrompt(prompt string, opts ...llm.Opt) (*Message, error) { + // Get attachments + opt, err := llm.ApplyPromptOpts(opts...) 
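+	// (prompt-level options such as llm.WithAttachment are gathered here so
+	// that any attachments can be appended to the message content below)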
+ if err != nil { + return nil, err + } + + // Get attachments, allocate content + attachments := opt.Attachments() + content := make([]*Content, 1, len(attachments)+1) + + // Append the text and the attachments + content[0] = NewTextContent(prompt) + for _, attachment := range attachments { + content = append(content, NewImageAttachment(attachment)) + } + + // Return success + return &Message{ + RoleContent: RoleContent{ + Role: "user", + Content: content, + }, + }, nil +} + +func toolResults(results ...llm.ToolResult) ([]*Message, error) { + // Check for no results + if len(results) == 0 { + return nil, llm.ErrBadParameter.Withf("No tool results") + } + + // Create results + messages := make([]*Message, 0, len(results)) + for _, result := range results { + value, err := json.Marshal(result.Value()) + if err != nil { + return nil, err + } + messages = append(messages, &Message{ + RoleContent: RoleContent{ + Role: "tool", + Id: result.Call().Id(), + Name: result.Call().Name(), + Content: string(value), + }, + }) + } + + // Return success + return messages, nil +} diff --git a/pkg/mistral/session_test.go b/pkg/mistral/session_test.go new file mode 100644 index 0000000..7fbcaa3 --- /dev/null +++ b/pkg/mistral/session_test.go @@ -0,0 +1,56 @@ +package mistral_test + +import ( + "context" + "testing" + + // Packages + llm "github.com/mutablelogic/go-llm" + tool "github.com/mutablelogic/go-llm/pkg/tool" + assert "github.com/stretchr/testify/assert" +) + +func Test_session_001(t *testing.T) { + assert := assert.New(t) + model := client.Model(context.TODO(), "mistral-small-latest") + if !assert.NotNil(model) { + t.FailNow() + } + + session := model.Context() + if assert.NotNil(session) { + err := session.FromUser(context.TODO(), "Hello, how are you?") + assert.NoError(err) + t.Log(session) + } +} + +func Test_session_002(t *testing.T) { + assert := assert.New(t) + model := client.Model(context.TODO(), "mistral-small-latest") + if !assert.NotNil(model) { + t.FailNow() + } + + toolkit := tool.NewToolKit() + toolkit.Register(&weather{}) + + session := model.Context(llm.WithToolKit(toolkit)) + if !assert.NotNil(session) { + t.FailNow() + } + + assert.NoError(session.FromUser(context.TODO(), "What is the weather like in London today?")) + calls := session.ToolCalls(0) + if assert.Len(calls, 1) { + assert.Equal("weather_in_city", calls[0].Name()) + + result, err := toolkit.Run(context.TODO(), calls...) + assert.NoError(err) + assert.Len(result, 1) + + assert.NoError(session.FromTool(context.TODO(), result...)) + } + + t.Log(session) +} diff --git a/pkg/mistral/testdata/LICENSE b/pkg/mistral/testdata/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/pkg/mistral/testdata/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/pkg/mistral/testdata/guggenheim.jpg b/pkg/mistral/testdata/guggenheim.jpg
new file mode 100644
index 0000000..7e16517
Binary files /dev/null and b/pkg/mistral/testdata/guggenheim.jpg differ
diff --git a/pkg/mistral/tool.go b/pkg/mistral/tool.go
new file mode 100644
index 0000000..255146e
--- /dev/null
+++ b/pkg/mistral/tool.go
@@ -0,0 +1,36 @@
+package mistral
+
+import (
+	"encoding/json"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+type ToolCall struct {
+	Id       string `json:"id,omitempty"`    // tool id
+	Index    uint64 `json:"index,omitempty"` // tool index
+	Function struct {
+		Name      string `json:"name,omitempty"`      // tool name
+		Arguments string `json:"arguments,omitempty"` // tool arguments
+	} `json:"function"`
+}
+
+type toolcall struct {
+	meta ToolCall
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// STRINGIFY
+
+func (t toolcall) MarshalJSON() ([]byte, error) {
+	return json.Marshal(t.meta)
+}
+
+func (t toolcall) String() string {
+	data, err := json.MarshalIndent(t, "", "  ")
+	if err != nil {
+		return err.Error()
+	}
+	return string(data)
+}
diff --git a/pkg/ollama/chat.go b/pkg/ollama/chat.go
index 14bf5f5..175e0b5 100644
--- a/pkg/ollama/chat.go
+++ b/pkg/ollama/chat.go
@@ -13,13 +13,13 @@ import (
 ///////////////////////////////////////////////////////////////////////////////
 // TYPES
 
-// Chat Response
+// Chat Completion Response
 type Response struct {
-	Model     string      `json:"model"`
-	CreatedAt time.Time   `json:"created_at"`
-	Message   MessageMeta `json:"message"`
-	Done      bool        `json:"done"`
-	Reason    string      `json:"done_reason,omitempty"`
+	Model     string    `json:"model"`
+	CreatedAt time.Time `json:"created_at"`
+	Done      bool      `json:"done"`
+	Reason    string    `json:"done_reason,omitempty"`
+	Message   `json:"message"`
 	Metrics
 }
 
@@ -33,6 +33,8 @@ type Metrics struct {
 	EvalDuration time.Duration `json:"eval_duration,omitempty"`
 }
 
+var _ llm.Completion = (*Response)(nil)
+
 ///////////////////////////////////////////////////////////////////////////////
 // STRINGIFY
 
@@ -49,34 +51,36 @@ func (r Response) String() string {
 
 type reqChat struct {
 	Model     string                 `json:"model"`
-	Messages  []*MessageMeta         `json:"messages"`
-	Tools     []ToolFunction         `json:"tools,omitempty"`
+	Messages  []*Message             `json:"messages"`
+	Tools     []llm.Tool             `json:"tools,omitempty"`
 	Format    string                 `json:"format,omitempty"`
 	Options   map[string]interface{} `json:"options,omitempty"`
 	Stream    bool                   `json:"stream"`
 	KeepAlive *time.Duration         `json:"keep_alive,omitempty"`
 }
 
-func (ollama *Client) Chat(ctx context.Context, prompt llm.Context, opts ...llm.Opt) (*Response, error) {
+func (ollama *Client) Chat(ctx context.Context, context llm.Context, opts ...llm.Opt) (*Response, error) {
+	// Apply options
 	opt, err := llm.ApplyOpts(opts...)
 	if err != nil {
 		return nil, err
 	}
 
 	// Append the system prompt at the beginning
-	seq := make([]*MessageMeta, 0, len(prompt.(*session).seq)+1)
+	messages := make([]*Message, 0, len(context.(*session).seq)+1)
 	if system := opt.SystemPrompt(); system != "" {
-		seq = append(seq, &MessageMeta{
-			Role:    "system",
-			Content: opt.SystemPrompt(),
-		})
+		messages = append(messages, systemPrompt(system))
+	}
+
+	// Append every message from the session sequence
+	for _, message := range context.(*session).seq {
+		messages = append(messages, message)
 	}
-	seq = append(seq, prompt.(*session).seq...)
// Request req, err := client.NewJSONRequest(reqChat{ - Model: prompt.(*session).model.Name(), - Messages: seq, + Model: context.(*session).model.Name(), + Messages: messages, Tools: optTools(ollama, opt), Format: optFormat(opt), Options: optOptions(opt), @@ -89,33 +93,29 @@ func (ollama *Client) Chat(ctx context.Context, prompt llm.Context, opts ...llm. // Response var response, delta Response - if err := ollama.DoWithContext(ctx, req, &delta, client.OptPath("chat"), client.OptJsonStreamCallback(func(v any) error { - if v, ok := v.(*Response); !ok || v == nil { - return llm.ErrConflict.Withf("Invalid stream response: %v", v) - } else { - response.Model = v.Model - response.CreatedAt = v.CreatedAt - response.Message.Role = v.Message.Role - response.Message.Content += v.Message.Content - if v.Done { - response.Done = v.Done - response.Metrics = v.Metrics - response.Reason = v.Reason + reqopts := []client.RequestOpt{ + client.OptPath("chat"), + } + if optStream(ollama, opt) { + reqopts = append(reqopts, client.OptJsonStreamCallback(func(v any) error { + if v, ok := v.(*Response); !ok || v == nil { + return llm.ErrConflict.Withf("Invalid stream response: %v", v) + } else if err := streamEvent(&response, v); err != nil { + return err } - } - - //Call the chat callback - if optStream(ollama, opt) { if fn := opt.StreamFn(); fn != nil { fn(&response) } - } - return nil - })); err != nil { + return nil + })) + } + + // Response + if err := ollama.DoWithContext(ctx, req, &delta, reqopts...); err != nil { return nil, err } - // We return the delta or the response + // Return success if optStream(ollama, opt) { return &response, nil } else { @@ -124,16 +124,25 @@ func (ollama *Client) Chat(ctx context.Context, prompt llm.Context, opts ...llm. } /////////////////////////////////////////////////////////////////////////////// -// INTERFACE - CONTEXT CONTENT +// PRIVATE METHODS -func (response Response) Role() string { - return response.Message.Role -} - -func (response Response) Text() string { - return response.Message.Content -} - -func (response Response) ToolCalls() []llm.ToolCall { +func streamEvent(response, delta *Response) error { + if delta.Model != "" { + response.Model = delta.Model + } + if !delta.CreatedAt.IsZero() { + response.CreatedAt = delta.CreatedAt + } + if delta.Message.RoleContent.Role != "" { + response.Message.RoleContent.Role = delta.Message.RoleContent.Role + } + if delta.Message.RoleContent.Content != "" { + response.Message.RoleContent.Content += delta.Message.RoleContent.Content + } + if delta.Done { + response.Done = delta.Done + response.Metrics = delta.Metrics + response.Reason = delta.Reason + } return nil } diff --git a/pkg/ollama/chat_test.go b/pkg/ollama/chat_test.go index cbc77c5..b0ea189 100644 --- a/pkg/ollama/chat_test.go +++ b/pkg/ollama/chat_test.go @@ -2,13 +2,13 @@ package ollama_test import ( "context" - "encoding/json" - "log" + "fmt" "os" + "strings" "testing" // Packages - opts "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" ollama "github.com/mutablelogic/go-llm/pkg/ollama" tool "github.com/mutablelogic/go-llm/pkg/tool" @@ -16,11 +16,6 @@ import ( ) func Test_chat_001(t *testing.T) { - client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true)) - if err != nil { - t.FailNow() - } - // Pull the model model, err := client.PullModel(context.TODO(), "qwen:0.5b", ollama.WithPullStatus(func(status *ollama.PullStatus) { t.Log(status) @@ -29,87 +24,91 @@ func Test_chat_001(t *testing.T) { t.FailNow() } - 
t.Run("ChatStream", func(t *testing.T) { + t.Run("Temperature", func(t *testing.T) { assert := assert.New(t) - response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithStream(func(stream llm.Context) { - t.Log(stream) - })) + response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithTemperature(0.5)) if !assert.NoError(err) { t.FailNow() } t.Log(response) }) - t.Run("ChatNoStream", func(t *testing.T) { + t.Run("TopP", func(t *testing.T) { assert := assert.New(t) - response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky green?")) + response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithTopP(0.5)) + if !assert.NoError(err) { + t.FailNow() + } + t.Log(response) + }) + t.Run("TopK", func(t *testing.T) { + assert := assert.New(t) + response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithTopK(50)) if !assert.NoError(err) { t.FailNow() } t.Log(response) }) -} - -func Test_chat_002(t *testing.T) { - client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true)) - if err != nil { - t.FailNow() - } - - // Pull the model - model, err := client.PullModel(context.TODO(), "llama3.2:1b", ollama.WithPullStatus(func(status *ollama.PullStatus) { - t.Log(status) - })) - if err != nil { - t.FailNow() - } - // Make a toolkit - toolkit := tool.NewToolKit() - if err := toolkit.Register(new(weather)); err != nil { - t.FailNow() - } + t.Run("Stream", func(t *testing.T) { + assert := assert.New(t) + response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithStream(func(stream llm.Completion) { + t.Log(stream) + })) + if !assert.NoError(err) { + t.FailNow() + } + t.Log(response) + }) - t.Run("Tools", func(t *testing.T) { + t.Run("Stop", func(t *testing.T) { assert := assert.New(t) - response, err := client.Chat(context.TODO(), - model.UserPrompt("what is the weather in berlin?"), - llm.WithToolKit(toolkit), - ) + response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithStopSequence("sky")) if !assert.NoError(err) { t.FailNow() } t.Log(response) }) -} -func Test_chat_003(t *testing.T) { - client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, false)) - if err != nil { - t.FailNow() - } + t.Run("System", func(t *testing.T) { + assert := assert.New(t) + response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithSystemPrompt("reply as if you are shakespeare")) + if !assert.NoError(err) { + t.FailNow() + } + t.Log(response) + }) - // Pull the model - model, err := client.PullModel(context.TODO(), "llava", ollama.WithPullStatus(func(status *ollama.PullStatus) { - t.Log(status) - })) - if err != nil { - t.FailNow() - } + t.Run("Seed", func(t *testing.T) { + assert := assert.New(t) + response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithSeed(1234)) + if !assert.NoError(err) { + t.FailNow() + } + t.Log(response) + }) - // Explain the content of an image - t.Run("Image", func(t *testing.T) { + t.Run("Format", func(t *testing.T) { assert := assert.New(t) + response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue? 
Reply in JSON format"), llm.WithFormat("json"))
+		if !assert.NoError(err) {
+			t.FailNow()
+		}
+		t.Log(response)
+	})
 
-		f, err := os.Open("testdata/guggenheim.jpg")
+	t.Run("PresencePenalty", func(t *testing.T) {
+		assert := assert.New(t)
+		response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithPresencePenalty(-1.0))
 		if !assert.NoError(err) {
 			t.FailNow()
 		}
-		defer f.Close()
+		t.Log(response)
+	})
 
-		response, err := client.Chat(context.TODO(),
-			model.UserPrompt("describe this photo to me", llm.WithAttachment(f)),
-		)
+	t.Run("FrequencyPenalty", func(t *testing.T) {
+		assert := assert.New(t)
+		response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithFrequencyPenalty(1.0))
 		if !assert.NoError(err) {
 			t.FailNow()
 		}
@@ -117,30 +116,74 @@ func Test_chat_003(t *testing.T) {
 	})
 }
 
-////////////////////////////////////////////////////////////////////////////////
-// TOOLS
+func Test_chat_002(t *testing.T) {
+	assert := assert.New(t)
+	model, err := client.PullModel(context.TODO(), "llava:7b")
+	if !assert.NoError(err) {
+		t.FailNow()
+	}
+	assert.NotNil(model)
 
-type weather struct {
-	Location string `json:"location" name:"location" help:"The location to get the weather for" required:"true"`
+	f, err := os.Open("testdata/guggenheim.jpg")
+	if !assert.NoError(err) {
+		t.FailNow()
+	}
+	defer f.Close()
+
+	// Describe an image
+	r, err := client.Chat(context.TODO(), model.UserPrompt("Provide a short caption for this image", llm.WithAttachment(f)))
+	if assert.NoError(err) {
+		assert.Equal("assistant", r.Role())
+		assert.Equal(1, r.Num())
+		assert.NotEmpty(r.Text(0))
+		t.Log(r.Text(0))
+	}
+}
+
+func Test_chat_003(t *testing.T) {
+	assert := assert.New(t)
+	model, err := client.PullModel(context.TODO(), "llama3.2")
+	if !assert.NoError(err) {
+		t.FailNow()
+	}
+	assert.NotNil(model)
+
+	toolkit := tool.NewToolKit()
+	toolkit.Register(&weather{})
+
+	// Get the weather for a city
+	r, err := client.Chat(context.TODO(), model.UserPrompt("What is the weather in the capital city of Germany?"), llm.WithToolKit(toolkit))
+	if assert.NoError(err) {
+		assert.Equal("assistant", r.Role())
+		assert.Equal(1, r.Num())
+
+		calls := r.ToolCalls(0)
+		assert.NotEmpty(calls)
+
+		var w weather
+		assert.NoError(calls[0].Decode(&w))
+		assert.Equal("berlin", strings.ToLower(w.City))
+	}
 }
 
-func (*weather) Name() string {
-	return "weather_in_location"
+type weather struct {
+	City string `json:"city" help:"The city to get the weather for"`
 }
 
-func (*weather) Description() string {
-	return "Get the weather in a location"
+func (weather) Name() string {
+	return "weather_in_city"
 }
 
-func (weather *weather) String() string {
-	data, err := json.MarshalIndent(weather, "", "  ")
-	if err != nil {
-		return err.Error()
-	}
-	return string(data)
+func (weather) Description() string {
+	return "Get the weather for a city"
 }
 
-func (weather *weather) Run(ctx context.Context) (any, error) {
-	log.Println("weather_in_location", "=>", weather)
-	return "very sunny today", nil
+func (w weather) Run(ctx context.Context) (any, error) {
+	var result struct {
+		City    string `json:"city"`
+		Weather string `json:"weather"`
+	}
+	result.City = w.City
+	result.Weather = fmt.Sprintf("The weather in %q is sunny and warm", w.City)
+	return result, nil
 }
diff --git a/pkg/ollama/client_test.go b/pkg/ollama/client_test.go
index 851b98b..e1987c2 100644
--- a/pkg/ollama/client_test.go
+++ b/pkg/ollama/client_test.go
@@ -1,7 +1,10 @@
 package ollama_test
 
 import (
+	"flag"
+	"log"
"os" + "strconv" "testing" // Packages @@ -10,23 +13,46 @@ import ( assert "github.com/stretchr/testify/assert" ) -func Test_client_001(t *testing.T) { - assert := assert.New(t) - client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true)) - if assert.NoError(err) { - assert.NotNil(client) - t.Log(client) +/////////////////////////////////////////////////////////////////////////////// +// TEST SET-UP + +var ( + client *ollama.Client +) + +func TestMain(m *testing.M) { + var verbose bool + + // Verbose output + flag.Parse() + if f := flag.Lookup("test.v"); f != nil { + if v, err := strconv.ParseBool(f.Value.String()); err == nil { + verbose = v + } } + + // Endpoint + endpoint_url := os.Getenv("OLLAMA_URL") + if endpoint_url == "" { + log.Print("OLLAMA_URL not set") + os.Exit(0) + } + + // Create client + var err error + client, err = ollama.New(endpoint_url, opts.OptTrace(os.Stderr, verbose)) + if err != nil { + log.Println(err) + os.Exit(-1) + } + os.Exit(m.Run()) } /////////////////////////////////////////////////////////////////////////////// -// ENVIRONMENT +// TESTS -func GetEndpoint(t *testing.T) string { - key := os.Getenv("OLLAMA_URL") - if key == "" { - t.Skip("OLLAMA_URL not set, skipping tests") - t.SkipNow() - } - return key +func Test_client_001(t *testing.T) { + assert := assert.New(t) + assert.NotNil(client) + t.Log(client) } diff --git a/pkg/ollama/embedding_test.go b/pkg/ollama/embedding_test.go index 77c854f..2b0fc2c 100644 --- a/pkg/ollama/embedding_test.go +++ b/pkg/ollama/embedding_test.go @@ -2,27 +2,29 @@ package ollama_test import ( "context" - "os" "testing" // Packages - opts "github.com/mutablelogic/go-client" - ollama "github.com/mutablelogic/go-llm/pkg/ollama" + assert "github.com/stretchr/testify/assert" ) -func Test_embed_001(t *testing.T) { - client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true)) - if err != nil { - t.FailNow() - } +func Test_embeddings_001(t *testing.T) { + t.Run("Embedding1", func(t *testing.T) { + assert := assert.New(t) + embedding, err := client.GenerateEmbedding(context.TODO(), "qwen:0.5b", []string{"hello, world"}) + if !assert.NoError(err) { + t.FailNow() + } + assert.Equal(1, len(embedding.Embeddings)) + }) - t.Run("Embedding", func(t *testing.T) { + t.Run("Embedding2", func(t *testing.T) { assert := assert.New(t) - embedding, err := client.GenerateEmbedding(context.TODO(), "qwen:0.5b", []string{"world"}) + embedding, err := client.GenerateEmbedding(context.TODO(), "qwen:0.5b", []string{"hello, world", "goodbye cruel world"}) if !assert.NoError(err) { t.FailNow() } - t.Log(embedding) + assert.Equal(2, len(embedding.Embeddings)) }) } diff --git a/pkg/ollama/message.go b/pkg/ollama/message.go index 53efe76..d00c6b8 100644 --- a/pkg/ollama/message.go +++ b/pkg/ollama/message.go @@ -1,36 +1,80 @@ package ollama import ( + "fmt" + + // Packages llm "github.com/mutablelogic/go-llm" + tool "github.com/mutablelogic/go-llm/pkg/tool" ) /////////////////////////////////////////////////////////////////////////////// // TYPES -// Chat Message -type MessageMeta struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - FunctionName string `json:"name,omitempty"` // Function name for a tool result - Images []Data `json:"images,omitempty"` // Image attachments - ToolCalls []ToolCall `json:"tool_calls,omitempty"` // Tool calls from the assistant +// Message with text or object content +type Message struct { + RoleContent + ToolCallArray `json:"tool_calls,omitempty"` +} + +type RoleContent 
struct { + Role string `json:"role,omitempty"` // assistant, user, tool, system + Content string `json:"content,omitempty"` // string or array of text, reference, image_url + Images []Data `json:"images,omitempty"` // Image attachments + ToolResult } +// A set of tool calls +type ToolCallArray []ToolCall + type ToolCall struct { + Type string `json:"type"` // function Function ToolCallFunction `json:"function"` } type ToolCallFunction struct { Index int `json:"index,omitempty"` Name string `json:"name"` - Arguments map[string]any `json:"arguments"` + Arguments map[string]any `json:"arguments,omitempty"` } // Data represents the raw binary data of an image file. type Data []byte -// ToolFunction -type ToolFunction struct { - Type string `json:"type"` // function - Function llm.Tool `json:"function"` +// ToolResult +type ToolResult struct { + Name string `json:"name,omitempty"` // function name - when role is tool +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - MESSAGE + +func (m Message) Num() int { + return 1 +} + +func (m Message) Role() string { + return m.RoleContent.Role +} + +func (m Message) Text(index int) string { + if index != 0 { + return "" + } + return m.Content +} + +func (m Message) ToolCalls(index int) []llm.ToolCall { + if index != 0 { + return nil + } + + // Make the tool calls + calls := make([]llm.ToolCall, 0, len(m.ToolCallArray)) + for _, call := range m.ToolCallArray { + calls = append(calls, tool.NewCall(fmt.Sprint(call.Function.Index), call.Function.Name, call.Function.Arguments)) + } + + // Return success + return calls } diff --git a/pkg/ollama/model.go b/pkg/ollama/model.go index 246d68d..94fb7d4 100644 --- a/pkg/ollama/model.go +++ b/pkg/ollama/model.go @@ -16,7 +16,7 @@ import ( // model is the implementation of the llm.Model interface type model struct { - client *Client + *Client `json:"-"` ModelMeta } @@ -60,8 +60,12 @@ type PullStatus struct { /////////////////////////////////////////////////////////////////////////////// // STRINGIFY +func (m model) MarshalJSON() ([]byte, error) { + return json.Marshal(m.ModelMeta) +} + func (m model) String() string { - data, err := json.MarshalIndent(m.ModelMeta, "", " ") + data, err := json.MarshalIndent(m, "", " ") if err != nil { return err.Error() } @@ -91,10 +95,19 @@ func (ollama *Client) Models(ctx context.Context) ([]llm.Model, error) { return ollama.ListModels(ctx) } +// Agent interface +func (ollama *Client) Model(ctx context.Context, name string) llm.Model { + model, err := ollama.GetModel(ctx, name) + if err != nil { + panic(err) + } + return model +} + // List models func (ollama *Client) ListModels(ctx context.Context) ([]llm.Model, error) { type respListModel struct { - Models []*model `json:"models"` + Models []ModelMeta `json:"models"` } // Send the request @@ -105,9 +118,8 @@ func (ollama *Client) ListModels(ctx context.Context) ([]llm.Model, error) { // Convert to llm.Model result := make([]llm.Model, 0, len(response.Models)) - for _, model := range response.Models { - model.client = ollama - result = append(result, model) + for _, meta := range response.Models { + result = append(result, &model{ollama, meta}) } // Return models @@ -117,7 +129,7 @@ func (ollama *Client) ListModels(ctx context.Context) ([]llm.Model, error) { // List running models func (ollama *Client) ListRunningModels(ctx context.Context) ([]llm.Model, error) { type respListModel struct { - Models []*model `json:"models"` + Models []ModelMeta `json:"models"` } // Send the 
request
@@ -128,9 +140,8 @@ func (ollama *Client) ListRunningModels(ctx context.Context) ([]llm.Model, error
 
 	// Convert to llm.Model
 	result := make([]llm.Model, 0, len(response.Models))
-	for _, model := range response.Models {
-		model.client = ollama
-		result = append(result, model)
+	for _, meta := range response.Models {
+		result = append(result, &model{ollama, meta})
 	}
 
 	// Return models
@@ -152,16 +163,15 @@
 	}
 
 	// Response
-	var response model
+	var response ModelMeta
 	if err := ollama.DoWithContext(ctx, req, &response, client.OptPath("show")); err != nil {
 		return nil, err
 	} else {
-		response.client = ollama
-		response.ModelMeta.Name = name
+		response.Name = name
 	}
 
 	// Return success
-	return &response, nil
+	return &model{ollama, response}, nil
 }
 
 // Copy a local model by name
@@ -242,3 +252,46 @@ func (ollama *Client) PullModel(ctx context.Context, name string, opts ...llm.Op
 	// Return success
 	return ollama.GetModel(ctx, name)
 }
+
+// Load a model into memory
+func (ollama *Client) LoadModel(ctx context.Context, name string) (llm.Model, error) {
+	type reqLoadModel struct {
+		Model string `json:"model"`
+	}
+
+	// Request
+	req, err := client.NewJSONRequest(reqLoadModel{
+		Model: name,
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	// Response
+	if err := ollama.DoWithContext(ctx, req, nil, client.OptPath("generate")); err != nil {
+		return nil, err
+	}
+
+	// Return success
+	return ollama.GetModel(ctx, name)
+}
+
+// Unload a model from memory
+func (ollama *Client) UnloadModel(ctx context.Context, name string) error {
+	type reqUnloadModel struct {
+		Model     string `json:"model"`
+		KeepAlive uint   `json:"keep_alive"`
+	}
+
+	// Request: a keep_alive of zero unloads the model immediately
+	req, err := client.NewJSONRequest(reqUnloadModel{
+		Model:     name,
+		KeepAlive: 0,
+	})
+	if err != nil {
+		return err
+	}
+
+	// Response
+	return ollama.DoWithContext(ctx, req, nil, client.OptPath("generate"))
+}
diff --git a/pkg/ollama/model_test.go b/pkg/ollama/model_test.go
index db14c9d..5a02d7f 100644
--- a/pkg/ollama/model_test.go
+++ b/pkg/ollama/model_test.go
@@ -2,23 +2,19 @@ package ollama_test
 
 import (
 	"context"
-	"os"
 	"testing"
 
 	// Packages
-	opts "github.com/mutablelogic/go-client"
+
 	ollama "github.com/mutablelogic/go-llm/pkg/ollama"
 	assert "github.com/stretchr/testify/assert"
 )
 
 func Test_model_001(t *testing.T) {
-	client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true))
-	if err != nil {
-		t.FailNow()
-	}
-
 	var names []string
+
 	t.Run("Models", func(t *testing.T) {
+		// Get all models
 		assert := assert.New(t)
 		models, err := client.Models(context.TODO())
 		if !assert.NoError(err) {
@@ -29,19 +25,19 @@ func Test_model_001(t *testing.T) {
 			names = append(names, model.Name())
 		}
 	})
 
 	t.Run("Model", func(t *testing.T) {
+		// Get models one by one
 		assert := assert.New(t)
 		for _, name := range names {
 			model, err := client.GetModel(context.TODO(), name)
 			if !assert.NoError(err) {
 				t.FailNow()
 			}
-			t.Log(model)
+			assert.NotNil(model)
 		}
 	})
 
 	t.Run("PullModel", func(t *testing.T) {
+		// Pull a model
 		assert := assert.New(t)
 		model, err := client.PullModel(context.TODO(), "qwen:0.5b", ollama.WithPullStatus(func(status *ollama.PullStatus) {
 			t.Log(status)
@@ -53,6 +49,7 @@
 		}
 	})
 
 	t.Run("CopyModel", func(t2 *testing.T) {
+		// Copy a model
 		assert := assert.New(t)
 		err := client.CopyModel(context.TODO(), "qwen:0.5b", t.Name())
 		if !assert.NoError(err) {
@@ -60,15 +57,28 @@
 		}
 	})
 
+	t.Run("LoadModel", func(t2
*testing.T) { + // Load model into memory + assert := assert.New(t) + _, err := client.LoadModel(context.TODO(), t.Name()) + assert.NoError(err) + }) + + t.Run("UnloadModel", func(t2 *testing.T) { + // Unload model from memory + assert := assert.New(t) + err := client.UnloadModel(context.TODO(), t.Name()) + assert.NoError(err) + }) + t.Run("DeleteModel", func(t2 *testing.T) { + // Delete a model assert := assert.New(t) - _, err = client.GetModel(context.TODO(), t.Name()) - if !assert.NoError(err) { - t.FailNow() - } - err := client.DeleteModel(context.TODO(), t.Name()) - if !assert.NoError(err) { - t.FailNow() + _, err := client.GetModel(context.TODO(), t.Name()) + if assert.NoError(err) { + err = client.DeleteModel(context.TODO(), t.Name()) + assert.NoError(err) } }) + } diff --git a/pkg/ollama/opt.go b/pkg/ollama/opt.go index f5a28d0..2b1afd4 100644 --- a/pkg/ollama/opt.go +++ b/pkg/ollama/opt.go @@ -79,30 +79,27 @@ func optPullStatus(opts *llm.Opts) func(*PullStatus) { return nil } -func optSystemPrompt(opts *llm.Opts) string { - return opts.SystemPrompt() -} - -func optTools(agent *Client, opts *llm.Opts) []ToolFunction { +func optTools(agent *Client, opts *llm.Opts) []llm.Tool { toolkit := opts.ToolKit() if toolkit == nil { return nil } - tools := toolkit.Tools(agent) - result := make([]ToolFunction, 0, len(tools)) - for _, tool := range tools { - result = append(result, ToolFunction{ - Type: "function", - Function: tool, - }) - } - return result + return toolkit.Tools(agent) } func optFormat(opts *llm.Opts) string { return opts.GetString("format") } +func optStopSequence(opts *llm.Opts) []string { + if opts.Has("stop") { + if stop, ok := opts.Get("stop").([]string); ok { + return stop + } + } + return nil +} + func optOptions(opts *llm.Opts) map[string]any { result := make(map[string]any) if o, ok := opts.Get("options").(map[string]any); ok { @@ -113,13 +110,25 @@ func optOptions(opts *llm.Opts) map[string]any { // copy across temperature, top_p and top_k if opts.Has("temperature") { - result["temperature"] = opts.Get("temperature") + result["temperature"] = opts.Get("temperature").(float64) } if opts.Has("top_p") { - result["top_p"] = opts.Get("top_p") + result["top_p"] = opts.GetFloat64("top_p") } if opts.Has("top_k") { - result["top_k"] = opts.Get("top_k") + result["top_k"] = opts.GetUint64("top_k") + } + if opts.Has("stop") { + result["stop"] = opts.Get("stop").([]string) + } + if opts.Has("seed") { + result["seed"] = opts.GetUint64("seed") + } + if opts.Has("presence_penalty") { + result["presence_penalty"] = opts.GetFloat64("presence_penalty") + } + if opts.Has("frequency_penalty") { + result["frequency_penalty"] = opts.GetFloat64("frequency_penalty") } // Return result diff --git a/pkg/ollama/session.go b/pkg/ollama/session.go index 867f60a..50c702c 100644 --- a/pkg/ollama/session.go +++ b/pkg/ollama/session.go @@ -3,11 +3,9 @@ package ollama import ( "context" "encoding/json" - "fmt" // Packages llm "github.com/mutablelogic/go-llm" - "github.com/mutablelogic/go-llm/pkg/tool" ) /////////////////////////////////////////////////////////////////////////////// @@ -15,9 +13,9 @@ import ( // Implementation of a message session, which is a sequence of messages type session struct { - opts []llm.Opt - model *model - seq []*MessageMeta + model *model // The model used for the session + opts []llm.Opt // Options to apply to the session + seq []*Message // Sequence of messages } var _ llm.Context = (*session)(nil) @@ -25,21 +23,30 @@ var _ llm.Context = (*session)(nil) 
/////////////////////////////////////////////////////////////////////////////// // LIFECYCLE -// Create a new empty context +// Return an empty session context object for the model, setting session options func (model *model) Context(opts ...llm.Opt) llm.Context { return &session{ model: model, opts: opts, + seq: make([]*Message, 0, 10), } } -// Create a new context with a user prompt +// Convenience method to create a session context object with a user prompt, which +// panics on error func (model *model) UserPrompt(prompt string, opts ...llm.Opt) llm.Context { context := model.Context(opts...) - context.(*session).seq = append(context.(*session).seq, &MessageMeta{ - Role: "user", - Content: prompt, - }) + + // Create a user prompt + message, err := userPrompt(prompt, opts...) + if err != nil { + panic(err) + } + + // Add to the sequence + context.(*session).seq = append(context.(*session).seq, message) + + // Return success return context } @@ -60,137 +67,155 @@ func (session session) String() string { return string(data) } +////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return the number of completions +func (session *session) Num() int { + if len(session.seq) == 0 { + return 0 + } + return 1 +} + +// Return the role of the last message +func (session *session) Role() string { + if len(session.seq) == 0 { + return "" + } + return session.seq[len(session.seq)-1].Role() +} + +// Return the text of the last message +func (session *session) Text(index int) string { + if len(session.seq) == 0 { + return "" + } + return session.seq[len(session.seq)-1].Text(index) +} + +// Return tool calls for the last message +func (session *session) ToolCalls(index int) []llm.ToolCall { + if len(session.seq) == 0 { + return nil + } + return session.seq[len(session.seq)-1].ToolCalls(index) +} + /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS // Generate a response from a user prompt (with attachments) -func (s *session) FromUser(ctx context.Context, prompt string, opts ...llm.Opt) error { - // Append the user prompt - if user, err := userPrompt(prompt, opts...); err != nil { +func (session *session) FromUser(ctx context.Context, prompt string, opts ...llm.Opt) error { + message, err := userPrompt(prompt, opts...) + if err != nil { return err - } else { - s.seq = append(s.seq, user) } + // Append the user prompt to the sequence + session.seq = append(session.seq, message) + // The options come from the session options and the user options - chatopts := make([]llm.Opt, 0, len(s.opts)+len(opts)) - chatopts = append(chatopts, s.opts...) + chatopts := make([]llm.Opt, 0, len(session.opts)+len(opts)) + chatopts = append(chatopts, session.opts...) chatopts = append(chatopts, opts...) // Call the 'chat' method - client := s.model.client - r, err := client.Chat(ctx, s, chatopts...) + r, err := session.model.Chat(ctx, session, chatopts...) if err != nil { return err - } else { - s.seq = append(s.seq, &r.Message) } + // Append the message to the sequence + session.seq = append(session.seq, &r.Message) + // Return success return nil } // Generate a response from a tool calling result -func (s *session) FromTool(ctx context.Context, results ...llm.ToolResult) error { - if len(results) == 0 { - return llm.ErrConflict.Withf("No tool results") +func (session *session) FromTool(ctx context.Context, results ...llm.ToolResult) error { + messages, err := toolResults(results...) 
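+	// Convert each tool result into a "tool" role message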
+ if err != nil { + return err } - // Append the tool results - for _, result := range results { - if message, err := toolResult(result); err != nil { - return err - } else { - s.seq = append(s.seq, message) - } - } + // Append the tool results to the sequence + session.seq = append(session.seq, messages...) // Call the 'chat' method - r, err := s.model.client.Chat(ctx, s, s.opts...) + r, err := session.model.Chat(ctx, session, session.opts...) if err != nil { return err - } else { - s.seq = append(s.seq, &r.Message) } + // Append the first message from the set of completions + session.seq = append(session.seq, &r.Message) + // Return success return nil } -// Return the role of the last message -func (session *session) Role() string { - if len(session.seq) == 0 { - return "" - } - return session.seq[len(session.seq)-1].Role -} - -// Return the text of the last message -func (session *session) Text() string { - if len(session.seq) == 0 { - return "" - } - return session.seq[len(session.seq)-1].Content -} - -// Return the tool calls of the last message -func (session *session) ToolCalls() []llm.ToolCall { - // Sanity check for tool call - if len(session.seq) == 0 { - return nil - } - meta := session.seq[len(session.seq)-1] - if meta.Role != "assistant" { - return nil - } +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS - // Gather tool calls - var result []llm.ToolCall - for _, call := range meta.ToolCalls { - result = append(result, tool.NewCall(fmt.Sprint(call.Function.Index), call.Function.Name, call.Function.Arguments)) +func systemPrompt(prompt string) *Message { + return &Message{ + RoleContent: RoleContent{ + Role: "system", + Content: prompt, + }, } - return result } -/////////////////////////////////////////////////////////////////////////////// -// PRIVATE METHODS - -func userPrompt(prompt string, opts ...llm.Opt) (*MessageMeta, error) { - // Apply options for attachments - opt, err := llm.ApplyOpts(opts...) +func userPrompt(prompt string, opts ...llm.Opt) (*Message, error) { + // Get attachments + opt, err := llm.ApplyPromptOpts(opts...) 
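+	// (prompt-level options such as llm.WithAttachment are gathered here so
+	// that attachments can be encoded as image data alongside the prompt)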
if err != nil { return nil, err } - // Create a new message - var meta MessageMeta - meta.Role = "user" - meta.Content = prompt - - if attachments := opt.Attachments(); len(attachments) > 0 { - meta.Images = make([]Data, len(attachments)) - for i, attachment := range attachments { - meta.Images[i] = attachment.Data() - } + // Get attachments, allocate content + attachments := opt.Attachments() + data := make([]Data, 0, len(attachments)) + for _, attachment := range attachments { + data = append(data, attachment.Data()) } // Return success - return &meta, nil + return &Message{ + RoleContent: RoleContent{ + Role: "user", + Content: prompt, + Images: data, + }, + }, nil } -func toolResult(result llm.ToolResult) (*MessageMeta, error) { - // Turn result into JSON - data, err := json.Marshal(result.Value()) - if err != nil { - return nil, err +func toolResults(results ...llm.ToolResult) ([]*Message, error) { + // Check for no results + if len(results) == 0 { + return nil, llm.ErrBadParameter.Withf("No tool results") } - // Create a new message - var meta MessageMeta - meta.Role = "tool" - meta.FunctionName = result.Call().Name() - meta.Content = string(data) + // Create results + messages := make([]*Message, 0, len(results)) + for _, result := range results { + value, err := json.Marshal(result.Value()) + if err != nil { + return nil, err + } + messages = append(messages, &Message{ + RoleContent: RoleContent{ + Role: "tool", + ToolResult: ToolResult{ + Name: result.Call().Name(), + }, + Content: string(value), + }, + }) + } // Return success - return &meta, nil + return messages, nil } diff --git a/pkg/ollama/session_test.go b/pkg/ollama/session_test.go index e4df9d4..e343eff 100644 --- a/pkg/ollama/session_test.go +++ b/pkg/ollama/session_test.go @@ -2,89 +2,57 @@ package ollama_test import ( "context" - "os" "testing" // Packages - opts "github.com/mutablelogic/go-client" llm "github.com/mutablelogic/go-llm" - ollama "github.com/mutablelogic/go-llm/pkg/ollama" - "github.com/mutablelogic/go-llm/pkg/tool" + tool "github.com/mutablelogic/go-llm/pkg/tool" assert "github.com/stretchr/testify/assert" ) func Test_session_001(t *testing.T) { - client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true)) - if err != nil { + assert := assert.New(t) + model, err := client.PullModel(context.TODO(), "llama3.2") + if !assert.NoError(err) { t.FailNow() } + assert.NotNil(model) - // Pull the model - model, err := client.PullModel(context.TODO(), "qwen:0.5b") - if err != nil { - t.FailNow() + session := model.Context() + if assert.NotNil(session) { + err := session.FromUser(context.TODO(), "Hello, how are you?") + assert.NoError(err) + t.Log(session) } - - // Session with a single user prompt - streaming - t.Run("stream", func(t *testing.T) { - assert := assert.New(t) - session := model.Context(llm.WithStream(func(stream llm.Context) { - t.Log("SESSION DELTA", stream) - })) - assert.NotNil(session) - - err := session.FromUser(context.TODO(), "Why is the grass green?") - if !assert.NoError(err) { - t.FailNow() - } - assert.Equal("assistant", session.Role()) - assert.NotEmpty(session.Text()) - }) - - // Session with a single user prompt - not streaming - t.Run("nostream", func(t *testing.T) { - assert := assert.New(t) - session := model.Context() - assert.NotNil(session) - - err := session.FromUser(context.TODO(), "Why is the sky blue?") - if !assert.NoError(err) { - t.FailNow() - } - assert.Equal("assistant", session.Role()) - assert.NotEmpty(session.Text()) - }) } func Test_session_002(t 
*testing.T) {
-    client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true))
-    if err != nil {
-        t.FailNow()
-    }
-
-    // Pull the model
+    assert := assert.New(t)
     model, err := client.PullModel(context.TODO(), "llama3.2")
-    if err != nil {
+    if !assert.NoError(err) {
         t.FailNow()
     }
+    assert.NotNil(model)

-    // Make a toolkit
     toolkit := tool.NewToolKit()
-    if err := toolkit.Register(new(weather)); err != nil {
+    toolkit.Register(&weather{})
+
+    session := model.Context(llm.WithToolKit(toolkit))
+    if !assert.NotNil(session) {
         t.FailNow()
     }

-    // Session with a tool call
-    t.Run("toolcall", func(t *testing.T) {
-        assert := assert.New(t)
+    assert.NoError(session.FromUser(context.TODO(), "What is the weather like in London today?"))
+    calls := session.ToolCalls(0)
+    if assert.Len(calls, 1) {
+        assert.Equal("weather_in_city", calls[0].Name())

-        session := model.Context(llm.WithToolKit(toolkit))
-        assert.NotNil(session)
+        result, err := toolkit.Run(context.TODO(), calls...)
+        assert.NoError(err)
+        assert.Len(result, 1)

-        err = session.FromUser(context.TODO(), "What is today's weather in Berlin?", llm.WithTemperature(0.5))
-        if !assert.NoError(err) {
-            t.FailNow()
-        }
-        t.Log(session)
-    })
+        assert.NoError(session.FromTool(context.TODO(), result...))
+    }
+
+    t.Log(session)
 }
diff --git a/pkg/tool/old/tool.go_old b/pkg/tool/old/tool.go_old
deleted file mode 100644
index 2da2bee..0000000
--- a/pkg/tool/old/tool.go_old
+++ /dev/null
@@ -1,220 +0,0 @@
-package ollama
-
-import (
-    "encoding/json"
-    "errors"
-    "fmt"
-    "reflect"
-    "strings"
-
-    // Packages
-    llm "github.com/mutablelogic/go-llm"
-)
-
-///////////////////////////////////////////////////////////////////////////////
-// TYPES
-
-type Tool struct {
-    Type     string       `json:"type"`
-    Function ToolFunction `json:"function"`
-}
-
-type ToolFunction struct {
-    Name        string `json:"name"`
-    Description string `json:"description"`
-    Parameters  struct {
-        Type       string                   `json:"type,omitempty"`
-        Required   []string                 `json:"required,omitempty"`
-        Properties map[string]ToolParameter `json:"properties,omitempty"`
-    } `json:"parameters"`
-    proto reflect.Type // Prototype for parameter return
-}
-
-type ToolParameter struct {
-    Name        string   `json:"-"`
-    Type        string   `json:"type"`
-    Description string   `json:"description,omitempty"`
-    Enum        []string `json:"enum,omitempty"`
-    required    bool
-    index       []int // Field index into prototype for setting a field
-}
-
-///////////////////////////////////////////////////////////////////////////////
-// LIFECYCLE
-
-// Return a tool, or panic if there is an error
-func MustTool(name, description string, params any) *Tool {
-    tool, err := NewTool(name, description, params)
-    if err != nil {
-        panic(err)
-    }
-    return tool
-}
-
-// Return a new tool definition
-func NewTool(name, description string, params any) (*Tool, error) {
-    tool := Tool{
-        Type:     "function",
-        Function: ToolFunction{Name: name, Description: description, proto: reflect.TypeOf(params)},
-    }
-
-    // Add parameters
-    tool.Function.Parameters.Type = "object"
-    if params, err := paramsFor(params); err != nil {
-        return nil, err
-    } else {
-        tool.Function.Parameters.Required = make([]string, 0, len(params))
-        tool.Function.Parameters.Properties = make(map[string]ToolParameter, len(params))
-        for _, param := range params {
-            if _, exists := tool.Function.Parameters.Properties[param.Name]; exists {
-                return nil, llm.ErrConflict.Withf("parameter %q already exists", param.Name)
-            } else {
-                tool.Function.Parameters.Properties[param.Name] = param
-            }
-            if param.required {
-                tool.Function.Parameters.Required = append(tool.Function.Parameters.Required, param.Name)
-            }
-        }
-    }
-
-    // Return success
-    return &tool, nil
-}
-
-// Return a new tool call
-func NewToolCall(v ToolCall) *ToolCallFunction {
-    return &v.Function
-}
-
-///////////////////////////////////////////////////////////////////////////////
-// STRINGIFY
-
-func (t Tool) String() string {
-    data, err := json.MarshalIndent(t, "", " ")
-    if err != nil {
-        return err.Error()
-    }
-    return string(data)
-}
-
-///////////////////////////////////////////////////////////////////////////////
-// PUBLIC METHODS
-
-func (t *Tool) Params(call ToolCall) (any, error) {
-    if call.Function.Name != t.Function.Name {
-        return nil, llm.ErrBadParameter.Withf("invalid function %q, expected %q", call.Function.Name, t.Function.Name)
-    }
-
-    // Create parameters
-    params := reflect.New(t.Function.proto).Elem()
-
-    // Iterate over arguments
-    var result error
-    for name, value := range call.Function.Arguments {
-        param, exists := t.Function.Parameters.Properties[name]
-        if !exists {
-            return nil, llm.ErrBadParameter.Withf("invalid argument %q", name)
-        }
-        result = errors.Join(result, paramSet(params.FieldByIndex(param.index), value))
-    }
-
-    // Return any errors
-    return params.Interface(), result
-}
-
-///////////////////////////////////////////////////////////////////////////////
-// PRIVATE METHODS
-
-// Return tool parameters from a struct
-func paramsFor(params any) ([]ToolParameter, error) {
-    if params == nil {
-        return []ToolParameter{}, nil
-    }
-    rt := reflect.TypeOf(params)
-    if rt.Kind() == reflect.Ptr {
-        rt = rt.Elem()
-    }
-    if rt.Kind() != reflect.Struct {
-        return nil, llm.ErrBadParameter.With("params must be a struct")
-    }
-
-    // Iterate over fields
-    fields := reflect.VisibleFields(rt)
-    result := make([]ToolParameter, 0, len(fields))
-    for _, field := range fields {
-        if param, err := paramFor(field); err != nil {
-            return nil, err
-        } else {
-            result = append(result, param)
-        }
-    }
-
-    // Return success
-    return result, nil
-}
-
-// Return tool parameters from a struct field
-func paramFor(field reflect.StructField) (ToolParameter, error) {
-    // Name
-    name := field.Tag.Get("name")
-    if name == "" {
-        name = field.Name
-    }
-
-    // Type
-    typ, err := paramType(field)
-    if err != nil {
-        return ToolParameter{}, err
-    }
-
-    // Required
-    _, required := field.Tag.Lookup("required")
-
-    // Enum
-    enum := []string{}
-    if enum_ := field.Tag.Get("enum"); enum_ != "" {
-        enum = strings.Split(enum_, ",")
-    }
-
-    // Return success
-    return ToolParameter{
-        Name:        field.Name,
-        Type:        typ,
-        Description: field.Tag.Get("help"),
-        Enum:        enum,
-        required:    required,
-        index:       field.Index,
-    }, nil
-}
-
-var (
-    typeString  = reflect.TypeOf("")
-    typeUint    = reflect.TypeOf(uint(0))
-    typeInt     = reflect.TypeOf(int(0))
-    typeFloat64 = reflect.TypeOf(float64(0))
-    typeFloat32 = reflect.TypeOf(float32(0))
-)
-
-// Return parameter type from a struct field
-func paramType(field reflect.StructField) (string, error) {
-    t := field.Type
-    if t.Kind() == reflect.Ptr {
-        t = t.Elem()
-    }
-    switch field.Type {
-    case typeString:
-        return "string", nil
-    case typeUint, typeInt:
-        return "integer", nil
-    case typeFloat64, typeFloat32:
-        return "number", nil
-    default:
-        return "", llm.ErrBadParameter.Withf("unsupported type %v for field %q", field.Type, field.Name)
-    }
-}
-
-// Set a field parameter
-func paramSet(field reflect.Value, v any) error {
-    fmt.Println("TODO", field, "=>", v)
-    return nil
-}
diff --git a/pkg/tool/old/tool.go_old_old b/pkg/tool/old/tool.go_old_old
deleted file mode 100644
index 2d0a24b..0000000
--- a/pkg/tool/old/tool.go_old_old
+++ /dev/null
@@ -1,216 +0,0 @@
-package anthropic
-
-import (
-    "encoding/json"
-    "reflect"
-    "strings"
-
-    // Packages
-    llm "github.com/mutablelogic/go-llm"
-)
-
-///////////////////////////////////////////////////////////////////////////////
-// TYPES
-
-type Tool struct {
-    Name        string `json:"name"`
-    Description string `json:"description"`
-    Parameters  struct {
-        Type       string                   `json:"type,omitempty"`
-        Required   []string                 `json:"required,omitempty"`
-        Properties map[string]ToolParameter `json:"properties,omitempty"`
-    } `json:"input_schema"`
-    proto reflect.Type // Prototype for parameter return
-}
-
-type ToolParameter struct {
-    Name        string   `json:"-"`
-    Type        string   `json:"type"`
-    Description string   `json:"description,omitempty"`
-    Enum        []string `json:"enum,omitempty"`
-    required    bool
-    index       []int // Field index into prototype for setting a field
-}
-
-type toolcall struct {
-    ContentTool
-}
-
-///////////////////////////////////////////////////////////////////////////////
-// LIFECYCLE
-
-// Return a tool, or panic if there is an error
-func MustTool(name, description string, params any) *Tool {
-    tool, err := NewTool(name, description, params)
-    if err != nil {
-        panic(err)
-    }
-    return tool
-}
-
-// Return a new tool definition
-func NewTool(name, description string, params any) (*Tool, error) {
-    tool := Tool{
-        Name:        name,
-        Description: description,
-        proto:       reflect.TypeOf(params),
-    }
-
-    // Add parameters
-    tool.Parameters.Type = "object"
-    toolparams, err := paramsFor(params)
-    if err != nil {
-        return nil, err
-    }
-
-    // Set parameters
-    tool.Parameters.Required = make([]string, 0, len(toolparams))
-    tool.Parameters.Properties = make(map[string]ToolParameter, len(toolparams))
-    for _, param := range toolparams {
-        if _, exists := tool.Parameters.Properties[param.Name]; exists {
-            return nil, llm.ErrConflict.Withf("parameter %q already exists", param.Name)
-        } else {
-            tool.Parameters.Properties[param.Name] = param
-        }
-        if param.required {
-            tool.Parameters.Required = append(tool.Parameters.Required, param.Name)
-        }
-    }
-
-    // Return success
-    return &tool, nil
-}
-
-// Return a new tool call from a content parameter
-func NewToolCall(content *Content) *toolcall {
-    if content == nil || content.ContentTool.Id == "" || content.ContentTool.Name == "" {
-        return nil
-    }
-    return &toolcall{content.ContentTool}
-}
-
-///////////////////////////////////////////////////////////////////////////////
-// STRINGIFY
-
-func (t Tool) String() string {
-    data, err := json.MarshalIndent(t, "", " ")
-    if err != nil {
-        return err.Error()
-    }
-    return string(data)
-}
-
-func (t toolcall) String() string {
-    data, err := json.MarshalIndent(t, "", " ")
-    if err != nil {
-        return err.Error()
-    }
-    return string(data)
-}
-
-///////////////////////////////////////////////////////////////////////////////
-// PUBLIC METHODS
-
-func (t *toolcall) Name() string {
-    return t.ContentTool.Name
-}
-
-func (t *toolcall) Id() string {
-    return t.ContentTool.Id
-}
-
-func (t *toolcall) Params() any {
-    // TODO: Convert
-    return t.ContentTool.Input
-}
-
-///////////////////////////////////////////////////////////////////////////////
-// PRIVATE METHODS
-
-// Return tool parameters from a struct
-func paramsFor(params any) ([]ToolParameter, error) {
-    if params == nil {
-        return []ToolParameter{}, nil
-    }
-    rt := reflect.TypeOf(params)
-    if rt.Kind() == reflect.Ptr {
-        rt = rt.Elem()
-    }
-    if rt.Kind() != reflect.Struct {
-        return nil, llm.ErrBadParameter.With("params must be a struct")
-    }
-
-    // Iterate over fields
-    fields := reflect.VisibleFields(rt)
-    result := make([]ToolParameter, 0, len(fields))
-    for _, field := range fields {
-        if param, err := paramFor(field); err != nil {
-            return nil, err
-        } else {
-            result = append(result, param)
-        }
-    }
-
-    // Return success
-    return result, nil
-}
-
-// Return tool parameters from a struct field
-func paramFor(field reflect.StructField) (ToolParameter, error) {
-    // Name
-    name := field.Tag.Get("name")
-    if name == "" {
-        name = field.Name
-    }
-
-    // Type
-    typ, err := paramType(field)
-    if err != nil {
-        return ToolParameter{}, err
-    }
-
-    // Required
-    _, required := field.Tag.Lookup("required")
-
-    // Enum
-    enum := []string{}
-    if enum_ := field.Tag.Get("enum"); enum_ != "" {
-        enum = strings.Split(enum_, ",")
-    }
-
-    // Return success
-    return ToolParameter{
-        Name:        field.Name,
-        Type:        typ,
-        Description: field.Tag.Get("help"),
-        Enum:        enum,
-        required:    required,
-        index:       field.Index,
-    }, nil
-}
-
-var (
-    typeString  = reflect.TypeOf("")
-    typeUint    = reflect.TypeOf(uint(0))
-    typeInt     = reflect.TypeOf(int(0))
-    typeFloat64 = reflect.TypeOf(float64(0))
-    typeFloat32 = reflect.TypeOf(float32(0))
-)
-
-// Return parameter type from a struct field
-func paramType(field reflect.StructField) (string, error) {
-    t := field.Type
-    if t.Kind() == reflect.Ptr {
-        t = t.Elem()
-    }
-    switch field.Type {
-    case typeString:
-        return "string", nil
-    case typeUint, typeInt:
-        return "integer", nil
-    case typeFloat64, typeFloat32:
-        return "number", nil
-    default:
-        return "", llm.ErrBadParameter.Withf("unsupported type %v for field %q", field.Type, field.Name)
-    }
-}
diff --git a/pkg/tool/old/tool_test.go_old b/pkg/tool/old/tool_test.go_old
deleted file mode 100644
index f4d1d30..0000000
--- a/pkg/tool/old/tool_test.go_old
+++ /dev/null
@@ -1,29 +0,0 @@
-package ollama_test
-
-import (
-    "testing"
-
-    // Packagees
-
-    ollama "github.com/mutablelogic/go-llm/pkg/ollama"
-)
-
-func Test_tool_001(t *testing.T) {
-    tool, err := ollama.NewTool("test", "test_tool", struct{}{})
-    if err != nil {
-        t.FailNow()
-    }
-    t.Log(tool)
-}
-
-func Test_tool_002(t *testing.T) {
-    tool, err := ollama.NewTool("test", "test_tool", struct {
-        A string  `help:"A string"`
-        B int     `help:"An integer"`
-        C float64 `help:"A float" required:""`
-    }{})
-    if err != nil {
-        t.FailNow()
-    }
-    t.Log(tool)
-}
diff --git a/pkg/tool/tool.go b/pkg/tool/tool.go
index 5b22920..23515d5 100644
--- a/pkg/tool/tool.go
+++ b/pkg/tool/tool.go
@@ -43,6 +43,11 @@ type ToolParameter struct {
     index []int // Field index into prototype for setting a field
 }
 
+type ToolFunction struct {
+    Type     string `json:"type"` // function
+    llm.Tool `json:"function"`
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // STRINGIFY
 
@@ -54,6 +59,14 @@ func (t tool) String() string {
     return string(data)
 }
 
+func (t ToolFunction) String() string {
+    data, err := json.MarshalIndent(t, "", " ")
+    if err != nil {
+        return err.Error()
+    }
+    return string(data)
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // PUBLIC METHODS
diff --git a/pkg/tool/toolkit.go b/pkg/tool/toolkit.go
index 6d9fd34..79eb86a 100644
--- a/pkg/tool/toolkit.go
+++ b/pkg/tool/toolkit.go
@@ -37,9 +37,12 @@ func (kit *ToolKit) Tools(agent llm.Agent) []llm.Tool {
     result := make([]llm.Tool, 0, len(kit.functions))
     for _, t := range kit.functions {
         switch agent.Name() {
-        case "ollama":
+        case "ollama", "mistral":
             t.InputSchema = nil
-            result = append(result, t)
+            result = append(result, ToolFunction{
+                Type: "function",
+                Tool: t,
+            })
         default:
             t.Parameters = nil
             result = append(result, t)
diff --git a/toolkit.go b/toolkit.go
index e61d98d..8b180e2 100644
--- a/toolkit.go
+++ b/toolkit.go
@@ -28,7 +28,6 @@ type Tool interface {
     Description() string
 
     // Run the tool with a deadline and return the result
-    // TODO: Change 'any' to ToolResult
     Run(context.Context) (any, error)
 }
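The `ToolFunction` wrapper added in the `pkg/tool/tool.go` hunk above exists because Ollama and Mistral expect each tool definition nested under a `"function"` key beside a fixed `"type": "function"` field, whereas Anthropic takes the bare definition. The sketch below uses simplified stand-in types rather than the real `llm.Tool` (which also carries the parameter schema) to show how the embedded-struct JSON tag produces that wire shape:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Tool is a simplified stand-in for llm.Tool, for illustration only.
type Tool struct {
	Name        string `json:"name"`
	Description string `json:"description"`
}

// ToolFunction mirrors the wrapper in pkg/tool/tool.go: the embedded
// tool serializes under the "function" key, next to a fixed "type" field.
type ToolFunction struct {
	Type string `json:"type"` // always "function"
	Tool `json:"function"`
}

func main() {
	t := Tool{Name: "weather_in_city", Description: "Get the weather in a city"}

	// Anthropic path: the bare tool definition.
	bare, _ := json.Marshal(t)
	fmt.Println(string(bare))
	// {"name":"weather_in_city","description":"Get the weather in a city"}

	// Ollama/Mistral path: the wrapped form produced by ToolKit.Tools.
	wrapped, _ := json.Marshal(ToolFunction{Type: "function", Tool: t})
	fmt.Println(string(wrapped))
	// {"type":"function","function":{"name":"weather_in_city","description":"Get the weather in a city"}}
}
```

Against the real types, `ToolKit.Tools` performs the equivalent wrapping per provider, as the `pkg/tool/toolkit.go` hunk above shows.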