diff --git a/README.md b/README.md index 99b2c64..71bb145 100644 --- a/README.md +++ b/README.md @@ -151,18 +151,88 @@ ensure the session is maintained across multiple calls. ### Embedding Generation -TODO +You can generate embedding vectors using an appropriate model with Ollama or Mistral models: + +```go +import ( + "github.com/mutablelogic/go-llm" +) + +func embedding(ctx context.Context, agent llm.Agent) error { + // Create a new chat session + vector, err := agent.Model(ctx, "mistral-embed").Embedding(ctx, "hello") + // ... +} +``` ### Attachments & Image Caption Generation -TODO +Some models have `vision` capability and others can also summarize text. For example, to +generate captions for an image, + +```go +import ( + "github.com/mutablelogic/go-llm" +) + +func generate_image_caption(ctx context.Context, agent llm.Agent, path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + + // Describe an image + r, err := agent.Model("claude-3-5-sonnet-20241022").UserPrompt( + ctx, model.UserPrompt("Provide a short caption for this image", llm.WithAttachment(f)) + ) + if err != nil { + return "", err + } + + // Return success + return r.Text(0), err +} +``` + +To summarize a text or PDF docment is exactly the same using an Anthropic model, but maybe with a +different prompt. ### Streaming -TODO +Streaming is supported with all providers, but Ollama cannot be used with streaming and tools +simultaneously. You provide a callback function of signature `func(llm.Completion)` which will +be called as a completion is received. + +```go +import ( + "github.com/mutablelogic/go-llm" +) + +func generate_completion(ctx context.Context, agent llm.Agent, prompt string) (string, error) { + r, err := agent.Model("claude-3-5-sonnet-20241022").UserPrompt( + ctx, model.UserPrompt("What is the weather in London?"), + llm.WithStream(stream_callback), + ) + if err != nil { + return "", err + } + + // Return success + return r.Text(0), err +} + +func stream_callback(completion llm.Completion) { + // Print out the completion text on each call + fmt.Println(completion.Text(0)) +} + +``` ### Tool Support +All providers support tools, but not all models. + TODO ## Options diff --git a/cmd/llm/main.go b/cmd/llm/main.go index 205ad86..ff82c92 100644 --- a/cmd/llm/main.go +++ b/cmd/llm/main.go @@ -107,11 +107,9 @@ func main() { if cli.OllamaEndpoint != "" { opts = append(opts, agent.WithOllama(cli.OllamaEndpoint, clientopts...)) } - /* - if cli.AnthropicKey != "" { - opts = append(opts, agent.WithAnthropic(cli.AnthropicKey, clientopts...)) - } - */ + if cli.AnthropicKey != "" { + opts = append(opts, agent.WithAnthropic(cli.AnthropicKey, clientopts...)) + } if cli.MistralKey != "" { opts = append(opts, agent.WithMistral(cli.MistralKey, clientopts...)) } diff --git a/pkg/agent/opt.go b/pkg/agent/opt.go index 36fe662..9e54647 100644 --- a/pkg/agent/opt.go +++ b/pkg/agent/opt.go @@ -4,6 +4,7 @@ import ( // Packages client "github.com/mutablelogic/go-client" llm "github.com/mutablelogic/go-llm" + "github.com/mutablelogic/go-llm/pkg/anthropic" mistral "github.com/mutablelogic/go-llm/pkg/mistral" ollama "github.com/mutablelogic/go-llm/pkg/ollama" ) @@ -22,18 +23,17 @@ func WithOllama(endpoint string, opts ...client.ClientOpt) llm.Opt { } } -/* - func WithAnthropic(key string, opts ...client.ClientOpt) llm.Opt { - return func(o *llm.Opts) error { - client, err := anthropic.New(key, opts...) - if err != nil { - return err - } else { - return llm.WithAgent(client)(o) - } +func WithAnthropic(key string, opts ...client.ClientOpt) llm.Opt { + return func(o *llm.Opts) error { + client, err := anthropic.New(key, opts...) + if err != nil { + return err + } else { + return llm.WithAgent(client)(o) } } -*/ +} + func WithMistral(key string, opts ...client.ClientOpt) llm.Opt { return func(o *llm.Opts) error { client, err := mistral.New(key, opts...) diff --git a/pkg/anthropic/client.go b/pkg/anthropic/client.go index 8bb617a..f446edc 100644 --- a/pkg/anthropic/client.go +++ b/pkg/anthropic/client.go @@ -5,6 +5,8 @@ package anthropic import ( // Packages + "context" + client "github.com/mutablelogic/go-client" llm "github.com/mutablelogic/go-llm" ) @@ -42,10 +44,7 @@ func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) { } // Return the client - return &Client{ - Client: client, - cache: make(map[string]llm.Model), - }, nil + return &Client{client, nil}, nil } /////////////////////////////////////////////////////////////////////////////// @@ -55,3 +54,36 @@ func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) { func (*Client) Name() string { return defaultName } + +// Return the models +func (anthropic *Client) Models(ctx context.Context) ([]llm.Model, error) { + // Cache models + if anthropic.cache == nil { + models, err := anthropic.ListModels(ctx) + if err != nil { + return nil, err + } + anthropic.cache = make(map[string]llm.Model, len(models)) + for _, model := range models { + anthropic.cache[model.Name()] = model + } + } + + // Return models + result := make([]llm.Model, 0, len(anthropic.cache)) + for _, model := range anthropic.cache { + result = append(result, model) + } + return result, nil +} + +// Return a model by name, or nil if not found. +// Panics on error. +func (anthropic *Client) Model(ctx context.Context, name string) llm.Model { + if anthropic.cache == nil { + if _, err := anthropic.Models(ctx); err != nil { + panic(err) + } + } + return anthropic.cache[name] +} diff --git a/pkg/anthropic/client_test.go b/pkg/anthropic/client_test.go index ea7e2ee..0ecabcd 100644 --- a/pkg/anthropic/client_test.go +++ b/pkg/anthropic/client_test.go @@ -1,7 +1,10 @@ package anthropic_test import ( + "flag" + "log" "os" + "strconv" "testing" // Packages @@ -10,23 +13,46 @@ import ( assert "github.com/stretchr/testify/assert" ) -func Test_client_001(t *testing.T) { - assert := assert.New(t) - client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) - if assert.NoError(err) { - assert.NotNil(client) - t.Log(client) +/////////////////////////////////////////////////////////////////////////////// +// TEST SET-UP + +var ( + client *anthropic.Client +) + +func TestMain(m *testing.M) { + var verbose bool + + // Verbose output + flag.Parse() + if f := flag.Lookup("test.v"); f != nil { + if v, err := strconv.ParseBool(f.Value.String()); err == nil { + verbose = v + } } + + // API KEY + api_key := os.Getenv("ANTHROPIC_API_KEY") + if api_key == "" { + log.Print("ANTHROPIC_API_KEY not set") + os.Exit(0) + } + + // Create client + var err error + client, err = anthropic.New(api_key, opts.OptTrace(os.Stderr, verbose)) + if err != nil { + log.Println(err) + os.Exit(-1) + } + os.Exit(m.Run()) } /////////////////////////////////////////////////////////////////////////////// -// ENVIRONMENT +// TESTS -func GetApiKey(t *testing.T) string { - key := os.Getenv("ANTHROPIC_API_KEY") - if key == "" { - t.Skip("ANTHROPIC_API_KEY not set, skipping tests") - t.SkipNow() - } - return key +func Test_client_001(t *testing.T) { + assert := assert.New(t) + assert.NotNil(client) + t.Log(client) } diff --git a/pkg/anthropic/completion.go b/pkg/anthropic/completion.go new file mode 100644 index 0000000..6c25a11 --- /dev/null +++ b/pkg/anthropic/completion.go @@ -0,0 +1,243 @@ +package anthropic + +import ( + "context" + "encoding/json" + "fmt" + + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// Chat Completion Response +type Response struct { + Id string `json:"id"` + Type string `json:"type"` + Model string `json:"model"` + Reason string `json:"stop_reason,omitempty"` + StopSequence *string `json:"stop_sequence,omitempty"` + Message + Metrics `json:"usage,omitempty"` +} + +// Metrics +type Metrics struct { + CacheCreationInputTokens uint `json:"cache_creation_input_tokens,omitempty"` + CacheReadInputTokens uint `json:"cache_read_input_tokens,omitempty"` + InputTokens uint `json:"input_tokens,omitempty"` + OutputTokens uint `json:"output_tokens,omitempty"` +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (r Response) String() string { + data, err := json.MarshalIndent(r, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +type reqMessages struct { + Model string `json:"model"` + MaxTokens uint64 `json:"max_tokens,omitempty"` + Metadata *optmetadata `json:"metadata,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + System string `json:"system,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopK uint64 `json:"top_k,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Messages []*Message `json:"messages"` + Tools []llm.Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` +} + +func (anthropic *Client) Messages(ctx context.Context, context llm.Context, opts ...llm.Opt) (*Response, error) { + // Apply options + opt, err := llm.ApplyOpts(opts...) + if err != nil { + return nil, err + } + + // Request + req, err := client.NewJSONRequest(reqMessages{ + Model: context.(*session).model.Name(), + MaxTokens: optMaxTokens(context.(*session).model, opt), + Metadata: optMetadata(opt), + StopSequences: optStopSequences(opt), + Stream: optStream(opt), + System: optSystemPrompt(opt), + Temperature: optTemperature(opt), + TopK: optTopK(opt), + TopP: optTopP(opt), + Messages: context.(*session).seq, + Tools: optTools(anthropic, opt), + ToolChoice: optToolChoice(opt), + }) + if err != nil { + return nil, err + } + + // Stream + var response Response + reqopts := []client.RequestOpt{ + client.OptPath("messages"), + } + if optStream(opt) { + reqopts = append(reqopts, client.OptTextStreamCallback(func(evt client.TextStreamEvent) error { + if err := streamEvent(&response, evt); err != nil { + return err + } + if fn := opt.StreamFn(); fn != nil { + fn(&response) + } + return nil + })) + } + + // Response + if err := anthropic.DoWithContext(ctx, req, &response, reqopts...); err != nil { + return nil, err + } + + // Return success + return &response, nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +// Handle streaming events +func streamEvent(response *Response, evt client.TextStreamEvent) error { + switch evt.Event { + case "message_start": + // Start of a message + var r struct { + Type string `json:"type"` + Response Response `json:"message"` + } + if err := evt.Json(&r); err != nil { + return err + } else { + response.Id = r.Response.Id + response.Type = r.Response.Type + response.Model = r.Response.Model + response.Message = r.Response.Message + response.Metrics = r.Response.Metrics + response.Reason = r.Response.Reason + response.StopSequence = r.Response.StopSequence + } + case "content_block_start": + // Start of a content block, append to response + var r struct { + Type string `json:"type"` + Index uint `json:"index"` + Content Content `json:"content_block"` + } + if err := evt.Json(&r); err != nil { + return err + } else if int(r.Index) != len(response.Message.Content) { + return fmt.Errorf("%s: unexpected index %d", r.Type, r.Index) + } else { + response.Message.Content = append(response.Message.Content, &r.Content) + } + case "content_block_delta": + // Continuation of a content block, append to content + var r struct { + Type string `json:"type"` + Index uint `json:"index"` + Content Content `json:"delta"` + } + if err := evt.Json(&r); err != nil { + return err + } else if int(r.Index) != len(response.Message.Content)-1 { + return fmt.Errorf("%s: unexpected index %d", r.Type, r.Index) + } else if content, err := appendDelta(response.Message.Content, &r.Content); err != nil { + return err + } else { + response.Message.Content = content + } + case "content_block_stop": + // End of a content block + var r struct { + Type string `json:"type"` + Index uint `json:"index"` + } + if err := evt.Json(&r); err != nil { + return err + } else if int(r.Index) != len(response.Message.Content)-1 { + return fmt.Errorf("%s: unexpected index %d", r.Type, r.Index) + } + // We need to convert the partial_json response into a full json object + content := response.Message.Content[r.Index] + if content.Type == "tool_use" && content.InputJson != "" { + if err := json.Unmarshal([]byte(content.InputJson), &content.Input); err != nil { + return err + } + } + case "message_delta": + // Message update + var r struct { + Type string `json:"type"` + Delta Response `json:"delta"` + Usage Metrics `json:"usage"` + } + if err := evt.Json(&r); err != nil { + return err + } + + // Update stop reason + response.Reason = r.Delta.Reason + response.StopSequence = r.Delta.StopSequence + + // Update metrics + response.Metrics.InputTokens += r.Usage.InputTokens + response.Metrics.OutputTokens += r.Usage.OutputTokens + response.Metrics.CacheCreationInputTokens += r.Usage.CacheCreationInputTokens + response.Metrics.CacheReadInputTokens += r.Usage.CacheReadInputTokens + case "message_stop": + // NO-OP + return nil + case "ping": + // NO-OP + return nil + default: + // NO-OP + return nil + } + + // Return success + return nil +} + +// Append delta to content +func appendDelta(content []*Content, delta *Content) ([]*Content, error) { + if len(content) == 0 { + return nil, fmt.Errorf("unexpected delta") + } + + // Get the content block we want to append to + last := content[len(content)-1] + + // Append text_delta + switch { + case last.Type == "text" && delta.Type == "text_delta": + last.Text += delta.Text + case last.Type == "tool_use" && delta.Type == "input_json_delta": + last.InputJson += delta.InputJson + default: + return nil, fmt.Errorf("unexpected delta %s for %s", delta.Type, last.Type) + } + + // Return the content + return content, nil +} diff --git a/pkg/anthropic/completion_test.go b/pkg/anthropic/completion_test.go new file mode 100644 index 0000000..2726c7e --- /dev/null +++ b/pkg/anthropic/completion_test.go @@ -0,0 +1,222 @@ +package anthropic_test + +import ( + "context" + "fmt" + "os" + "strings" + "testing" + + // Packages + + llm "github.com/mutablelogic/go-llm" + anthropic "github.com/mutablelogic/go-llm/pkg/anthropic" + "github.com/mutablelogic/go-llm/pkg/tool" + assert "github.com/stretchr/testify/assert" +) + +func Test_chat_001(t *testing.T) { + assert := assert.New(t) + model := client.Model(context.TODO(), "claude-3-5-haiku-20241022") + + if assert.NotNil(model) { + response, err := client.Messages(context.TODO(), model.UserPrompt("Hello, how are you?")) + assert.NoError(err) + assert.NotEmpty(response) + t.Log(response) + } +} + +func Test_chat_002(t *testing.T) { + assert := assert.New(t) + model := client.Model(context.TODO(), "claude-3-5-haiku-20241022") + if !assert.NotNil(model) { + t.FailNow() + } + + t.Run("Temperature", func(t *testing.T) { + r, err := client.Messages(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithTemperature(0.5)) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + t.Run("TopP", func(t *testing.T) { + r, err := client.Messages(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithTopP(0.5)) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + t.Run("TopK", func(t *testing.T) { + r, err := client.Messages(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithTopK(90)) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + + t.Run("MaxTokens", func(t *testing.T) { + r, err := client.Messages(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithMaxTokens(10)) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + t.Run("Stream", func(t *testing.T) { + r, err := client.Messages(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithStream(func(r llm.Completion) { + t.Log(r.Role(), "=>", r.Text(0)) + })) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + t.Run("Stop", func(t *testing.T) { + r, err := client.Messages(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithStopSequence("weather")) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + t.Run("System", func(t *testing.T) { + r, err := client.Messages(context.TODO(), model.UserPrompt("What is the temperature in London?"), llm.WithSystemPrompt("You reply in shakespearian language, in one sentence")) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) + t.Run("User", func(t *testing.T) { + r, err := client.Messages(context.TODO(), model.UserPrompt("What is the temperature in London?"), anthropic.WithUser("username")) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r) + } + }) +} + +func Test_chat_003(t *testing.T) { + assert := assert.New(t) + model := client.Model(context.TODO(), "claude-3-5-sonnet-20241022") + if !assert.NotNil(model) { + t.FailNow() + } + + t.Run("ImageCaption", func(t *testing.T) { + f, err := os.Open("testdata/guggenheim.jpg") + if !assert.NoError(err) { + t.FailNow() + } + defer f.Close() + + // Describe an image + r, err := client.Messages(context.TODO(), model.UserPrompt("Provide a short caption for this image", llm.WithAttachment(f))) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r.Text(0)) + } + }) + + t.Run("DocSummarize", func(t *testing.T) { + f, err := os.Open("testdata/LICENSE") + if !assert.NoError(err) { + t.FailNow() + } + defer f.Close() + + // Summarize a document + r, err := client.Messages(context.TODO(), model.UserPrompt("Summarize this document", llm.WithAttachment(f))) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + assert.NotEmpty(r.Text(0)) + t.Log(r.Text(0)) + } + }) +} + +func Test_chat_004(t *testing.T) { + assert := assert.New(t) + model := client.Model(context.TODO(), "claude-3-5-haiku-20241022") + if !assert.NotNil(model) { + t.FailNow() + } + + toolkit := tool.NewToolKit() + toolkit.Register(&weather{}) + + t.Run("ToolChoiceAuto", func(t *testing.T) { + // Get the weather for a city + r, err := client.Messages( + context.TODO(), + model.UserPrompt("What is the weather in the capital city of germany?"), + llm.WithToolKit(toolkit), + llm.WithToolChoice("auto"), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + + calls := r.ToolCalls(0) + assert.NotEmpty(calls) + + var w weather + assert.NoError(calls[0].Decode(&w)) + assert.Equal("berlin", strings.ToLower(w.City)) + } + }) + t.Run("ToolChoiceFunc", func(t *testing.T) { + // Get the weather for a city + r, err := client.Messages( + context.TODO(), + model.UserPrompt("What is the weather in the capital city of germany?"), + llm.WithToolKit(toolkit), + llm.WithToolChoice("weather_in_city"), + ) + if assert.NoError(err) { + assert.Equal("assistant", r.Role()) + assert.Equal(1, r.Num()) + + calls := r.ToolCalls(0) + assert.NotEmpty(calls) + + var w weather + assert.NoError(calls[0].Decode(&w)) + assert.Equal("berlin", strings.ToLower(w.City)) + } + }) +} + +type weather struct { + City string `json:"city" help:"The city to get the weather for"` +} + +func (weather) Name() string { + return "weather_in_city" +} + +func (weather) Description() string { + return "Get the weather for a city" +} + +func (w weather) Run(ctx context.Context) (any, error) { + return fmt.Sprintf("The weather in %q is sunny and warm", w.City), nil +} diff --git a/pkg/anthropic/message.go b/pkg/anthropic/message.go index 3a8acdd..33aecb3 100644 --- a/pkg/anthropic/message.go +++ b/pkg/anthropic/message.go @@ -2,22 +2,28 @@ package anthropic import ( "encoding/json" - "net/http" "strings" // Packages llm "github.com/mutablelogic/go-llm" + "github.com/mutablelogic/go-llm/pkg/tool" ) /////////////////////////////////////////////////////////////////////////////// // TYPES // Message with text or object content -type MessageMeta struct { +type Message struct { + RoleContent +} + +type RoleContent struct { Role string `json:"role"` Content []*Content `json:"content,omitempty"` } +var _ llm.Completion = (*Message)(nil) + type Content struct { Type string `json:"type"` // image, document, text, tool_use ContentText @@ -64,18 +70,34 @@ type contentcitation struct { Enabled bool `json:"enabled"` // true } +/////////////////////////////////////////////////////////////////////////////// +// GLOBALS + +var ( + supportedAttachments = map[string]string{ + "image/jpeg": "image", + "image/png": "image", + "image/gif": "image", + "image/webp": "image", + "application/pdf": "document", + "text/plain": "text", + } +) + /////////////////////////////////////////////////////////////////////////////// // LIFECYCLE -// Return a Content object with text content +// Return a content object with text content func NewTextContent(v string) *Content { - content := new(Content) - content.Type = "text" - content.ContentText.Text = v - return content + return &Content{ + Type: "text", + ContentText: ContentText{ + Text: v, + }, + } } -// Return a Content object with tool result +// Return a content object with tool result func NewToolResultContent(v llm.ToolResult) *Content { content := new(Content) content.Type = "tool_result" @@ -93,51 +115,10 @@ func NewToolResultContent(v llm.ToolResult) *Content { return content } -/////////////////////////////////////////////////////////////////////////////// -// STRINGIFY - -func (m MessageMeta) String() string { - data, err := json.MarshalIndent(m, "", " ") - if err != nil { - return err.Error() - } - return string(data) -} - -/////////////////////////////////////////////////////////////////////////////// -// PRIVATE METHODS - -func (m MessageMeta) Text() string { - if len(m.Content) == 0 { - return "" - } - var text []string - for _, content := range m.Content { - if content.Type == "text" { - text = append(text, content.ContentText.Text) - } - } - return strings.Join(text, "\n") -} - -/////////////////////////////////////////////////////////////////////////////// -// PRIVATE METHODS - -var ( - supportedAttachments = map[string]string{ - "image/jpeg": "image", - "image/png": "image", - "image/gif": "image", - "image/webp": "image", - "application/pdf": "document", - "text/plain": "text", - } -) - -// Read content from an io.Reader -func attachmentContent(attachment *llm.Attachment, ephemeral, citations bool) (*Content, error) { +// Make attachment content +func NewAttachment(attachment *llm.Attachment, ephemeral, citations bool) (*Content, error) { // Detect mimetype - mimetype := http.DetectContentType(attachment.Data()) + mimetype := attachment.Type() if strings.HasPrefix(mimetype, "text/") { // Switch to text/plain - TODO: charsets? mimetype = "text/plain" @@ -191,3 +172,53 @@ func attachmentContent(attachment *llm.Attachment, ephemeral, citations bool) (* // Return success return content, nil } + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (m Message) String() string { + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - MESSAGE + +func (m Message) Num() int { + return 1 +} + +func (m Message) Role() string { + return m.RoleContent.Role +} + +func (m Message) Text(index int) string { + if index != 0 { + return "" + } + var text []string + for _, content := range m.RoleContent.Content { + if content.Type == "text" { + text = append(text, content.ContentText.Text) + } + } + return strings.Join(text, "\n") +} + +func (m Message) ToolCalls(index int) []llm.ToolCall { + if index != 0 { + return nil + } + + // Gather tool calls + var result []llm.ToolCall + for _, content := range m.Content { + if content.Type == "tool_use" { + result = append(result, tool.NewCall(content.ContentTool.Id, content.ContentTool.Name, content.ContentTool.Input)) + } + } + return result +} diff --git a/pkg/anthropic/messages.go b/pkg/anthropic/messages.go deleted file mode 100644 index d8becd3..0000000 --- a/pkg/anthropic/messages.go +++ /dev/null @@ -1,239 +0,0 @@ -package anthropic - -import ( - "context" - "encoding/json" - "fmt" - - // Packages - client "github.com/mutablelogic/go-client" - llm "github.com/mutablelogic/go-llm" -) - -/////////////////////////////////////////////////////////////////////////////// -// TYPES - -// Messages Response -type Response struct { - Type string `json:"type"` - Model string `json:"model"` - Id string `json:"id"` - MessageMeta - Reason string `json:"stop_reason,omitempty"` - StopSequence *string `json:"stop_sequence,omitempty"` - Metrics `json:"usage,omitempty"` -} - -// Metrics -type Metrics struct { - CacheCreationInputTokens uint `json:"cache_creation_input_tokens,omitempty"` - CacheReadInputTokens uint `json:"cache_read_input_tokens,omitempty"` - InputTokens uint `json:"input_tokens,omitempty"` - OutputTokens uint `json:"output_tokens,omitempty"` -} - -/////////////////////////////////////////////////////////////////////////////// -// STRINGIFY - -func (r Response) String() string { - data, err := json.MarshalIndent(r, "", " ") - if err != nil { - return err.Error() - } - return string(data) -} - -/////////////////////////////////////////////////////////////////////////////// -// PUBLIC METHODS - -type reqMessages struct { - Model string `json:"model"` - MaxTokens uint `json:"max_tokens,omitempty"` - Metadata *optmetadata `json:"metadata,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` - Stream bool `json:"stream,omitempty"` - System string `json:"system,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopK uint64 `json:"top_k,omitempty"` - TopP float64 `json:"top_p,omitempty"` - Messages []*MessageMeta `json:"messages"` - Tools []llm.Tool `json:"tools,omitempty"` -} - -func (anthropic *Client) Messages(ctx context.Context, context llm.Context, opts ...llm.Opt) (*Response, error) { - // Apply options - opt, err := llm.ApplyOpts(opts...) - if err != nil { - return nil, err - } - - // Request - req, err := client.NewJSONRequest(reqMessages{ - Model: context.(*session).model.Name(), - Messages: context.(*session).seq, - Tools: optTools(anthropic, opt), - MaxTokens: optMaxTokens(context.(*session).model, opt), - Metadata: optMetadata(opt), - StopSequences: optStopSequences(opt), - Stream: optStream(opt), - System: optSystemPrompt(opt), - Temperature: optTemperature(opt), - TopK: optTopK(opt), - TopP: optTopP(opt), - }) - if err != nil { - return nil, err - } - - // Stream - var response Response - reqopts := []client.RequestOpt{ - client.OptPath("messages"), - } - if optStream(opt) { - // Append delta to content - appendDelta := func(content []*Content, delta *Content) ([]*Content, error) { - if len(content) == 0 { - return nil, fmt.Errorf("unexpected delta") - } - - // Get the content block we want to append to - last := content[len(content)-1] - - // Append text_delta - switch { - case last.Type == "text" && delta.Type == "text_delta": - last.Text += delta.Text - case last.Type == "tool_use" && delta.Type == "input_json_delta": - last.InputJson += delta.InputJson - default: - return nil, fmt.Errorf("unexpected delta %s for %s", delta.Type, last.Type) - } - - // Return the content - return content, nil - } - reqopts = append(reqopts, client.OptTextStreamCallback(func(evt client.TextStreamEvent) error { - switch evt.Event { - case "message_start": - // Start of a message - var r struct { - Type string `json:"type"` - Response Response `json:"message"` - } - if err := evt.Json(&r); err != nil { - return err - } else { - response = r.Response - } - case "content_block_start": - // Start of a content block, append to response - var r struct { - Type string `json:"type"` - Index uint `json:"index"` - Content Content `json:"content_block"` - } - if err := evt.Json(&r); err != nil { - return err - } else if int(r.Index) != len(response.MessageMeta.Content) { - return fmt.Errorf("%s: unexpected index %d", r.Type, r.Index) - } else { - response.MessageMeta.Content = append(response.MessageMeta.Content, &r.Content) - } - case "content_block_delta": - // Continuation of a content block, append to content - var r struct { - Type string `json:"type"` - Index uint `json:"index"` - Content Content `json:"delta"` - } - if err := evt.Json(&r); err != nil { - return err - } else if int(r.Index) != len(response.MessageMeta.Content)-1 { - return fmt.Errorf("%s: unexpected index %d", r.Type, r.Index) - } else if content, err := appendDelta(response.MessageMeta.Content, &r.Content); err != nil { - return err - } else { - response.MessageMeta.Content = content - } - case "content_block_stop": - // End of a content block - var r struct { - Type string `json:"type"` - Index uint `json:"index"` - } - if err := evt.Json(&r); err != nil { - return err - } else if int(r.Index) != len(response.MessageMeta.Content)-1 { - return fmt.Errorf("%s: unexpected index %d", r.Type, r.Index) - } - // We need to convert the partial_json response into a full json object - content := response.MessageMeta.Content[r.Index] - if content.Type == "tool_use" && content.InputJson != "" { - if err := json.Unmarshal([]byte(content.InputJson), &content.Input); err != nil { - return err - } - } - case "message_delta": - // Message update - var r struct { - Type string `json:"type"` - Delta Response `json:"delta"` - Usage Metrics `json:"usage"` - } - if err := evt.Json(&r); err != nil { - return err - } - - // Update stop reason - response.Reason = r.Delta.Reason - response.StopSequence = r.Delta.StopSequence - - // Update metrics - response.Metrics.InputTokens += r.Usage.InputTokens - response.Metrics.OutputTokens += r.Usage.OutputTokens - response.Metrics.CacheCreationInputTokens += r.Usage.CacheCreationInputTokens - response.Metrics.CacheReadInputTokens += r.Usage.CacheReadInputTokens - case "message_stop": - // NO-OP - return nil - case "ping": - // NO-OP - return nil - default: - // NO-OP - return nil - } - - if fn := opt.StreamFn(); fn != nil { - fn(&response) - } - - // Return success - return nil - })) - } - - // Response - if err := anthropic.DoWithContext(ctx, req, &response, reqopts...); err != nil { - return nil, err - } - - // Return success - return &response, nil -} - -/////////////////////////////////////////////////////////////////////////////// -// INTERFACE - CONTEXT CONTENT - -func (response Response) Role() string { - return response.MessageMeta.Role -} - -func (response Response) Text() string { - return response.MessageMeta.Text() -} - -func (response Response) ToolCalls() []llm.ToolCall { - return nil -} diff --git a/pkg/anthropic/messages_test.go b/pkg/anthropic/messages_test.go deleted file mode 100644 index 41e9152..0000000 --- a/pkg/anthropic/messages_test.go +++ /dev/null @@ -1,180 +0,0 @@ -package anthropic_test - -import ( - "context" - "encoding/json" - "log" - "os" - "testing" - - // Packages - opts "github.com/mutablelogic/go-client" - llm "github.com/mutablelogic/go-llm" - anthropic "github.com/mutablelogic/go-llm/pkg/anthropic" - tool "github.com/mutablelogic/go-llm/pkg/tool" - assert "github.com/stretchr/testify/assert" -) - -func Test_messages_001(t *testing.T) { - assert := assert.New(t) - client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) - if assert.NoError(err) { - assert.NotNil(client) - t.Log(client) - } - - model, err := client.GetModel(context.TODO(), "claude-3-haiku-20240307") - if assert.NoError(err) { - assert.NotNil(client) - t.Log(client) - } else { - t.FailNow() - } - - f, err := os.Open("testdata/guggenheim.jpg") - if !assert.NoError(err) { - t.FailNow() - } - defer f.Close() - - response, err := client.Messages(context.TODO(), model.UserPrompt("what is this image?", llm.WithAttachment(f))) - if assert.NoError(err) { - t.Log(response) - } -} - -func Test_messages_002(t *testing.T) { - assert := assert.New(t) - client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) - if assert.NoError(err) { - assert.NotNil(client) - t.Log(client) - } - - model, err := client.GetModel(context.TODO(), "claude-3-haiku-20240307") - if assert.NoError(err) { - assert.NotNil(client) - t.Log(client) - } else { - t.FailNow() - } - - f, err := os.Open("testdata/LICENSE") - if !assert.NoError(err) { - t.FailNow() - } - defer f.Close() - - response, err := client.Messages(context.TODO(), model.UserPrompt("summarize this document for me", llm.WithAttachment(f))) - if assert.NoError(err) { - t.Log(response) - } -} - -func Test_messages_003(t *testing.T) { - assert := assert.New(t) - client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) - if assert.NoError(err) { - assert.NotNil(client) - t.Log(client) - } - - model, err := client.GetModel(context.TODO(), "claude-3-haiku-20240307") - if assert.NoError(err) { - assert.NotNil(client) - t.Log(client) - } else { - t.FailNow() - } - - response, err := client.Messages(context.TODO(), model.UserPrompt("why is the sky blue"), llm.WithStream(func(r llm.ContextContent) { - t.Log(r) - })) - if assert.NoError(err) { - t.Log(response) - } -} - -func Test_messages_004(t *testing.T) { - assert := assert.New(t) - client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) - if assert.NoError(err) { - assert.NotNil(client) - t.Log(client) - } - - model, err := client.GetModel(context.TODO(), "claude-3-haiku-20240307") - if assert.NoError(err) { - assert.NotNil(client) - t.Log(client) - } else { - t.FailNow() - } - - toolkit := tool.NewToolKit() - if err := toolkit.Register(new(weather)); !assert.NoError(err) { - t.FailNow() - } - - response, err := client.Messages(context.TODO(), model.UserPrompt("why is the sky blue"), llm.WithToolKit(toolkit)) - if assert.NoError(err) { - t.Log(response) - } -} - -func Test_messages_005(t *testing.T) { - assert := assert.New(t) - client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) - if assert.NoError(err) { - assert.NotNil(client) - t.Log(client) - } - - model, err := client.GetModel(context.TODO(), "claude-3-haiku-20240307") - if assert.NoError(err) { - assert.NotNil(client) - t.Log(client) - } else { - t.FailNow() - } - - toolkit := tool.NewToolKit() - if err := toolkit.Register(new(weather)); !assert.NoError(err) { - t.FailNow() - } - - response, err := client.Messages(context.TODO(), model.UserPrompt("why is the sky blue"), llm.WithStream(func(r llm.ContextContent) { - t.Log(r) - }), llm.WithToolKit(toolkit)) - if assert.NoError(err) { - t.Log(response) - } -} - -//////////////////////////////////////////////////////////////////////////////// -// TOOLS - -type weather struct { - Location string `json:"location" name:"location" help:"The location to get the weather for" required:"true"` -} - -func (*weather) Name() string { - return "weather_in_location" -} - -func (*weather) Description() string { - return "Get the weather in a location" -} - -func (weather *weather) String() string { - data, err := json.MarshalIndent(weather, "", " ") - if err != nil { - return err.Error() - } - return string(data) -} - -func (weather *weather) Run(ctx context.Context) (any, error) { - log.Println("weather_in_location", "=>", weather) - return "very sunny today", nil -} diff --git a/pkg/anthropic/model.go b/pkg/anthropic/model.go index 04baa49..2b50ee6 100644 --- a/pkg/anthropic/model.go +++ b/pkg/anthropic/model.go @@ -2,6 +2,7 @@ package anthropic import ( "context" + "encoding/json" "net/url" "time" @@ -13,16 +14,14 @@ import ( /////////////////////////////////////////////////////////////////////////////// // TYPES -// model is the implementation of the llm.Model interface type model struct { - client *Client - ModelMeta + *Client `json:"-"` + meta Model } var _ llm.Model = (*model)(nil) -// ModelMeta is the metadata for an anthropic model -type ModelMeta struct { +type Model struct { Name string `json:"id"` Description string `json:"display_name,omitempty"` Type string `json:"type,omitempty"` @@ -30,65 +29,44 @@ type ModelMeta struct { } /////////////////////////////////////////////////////////////////////////////// -// PUBLIC METHODS - -// Agent interface -func (anthropic *Client) Models(ctx context.Context) ([]llm.Model, error) { - // Cache models - if len(anthropic.cache) == 0 { - models, err := anthropic.ListModels(ctx) - if err != nil { - return nil, err - } - for _, model := range models { - name := model.Name() - anthropic.cache[name] = model - } - } +// STRINGIFY - // Return models - result := make([]llm.Model, 0, len(anthropic.cache)) - for _, model := range anthropic.cache { - result = append(result, model) - } - return result, nil +func (m model) MarshalJSON() ([]byte, error) { + return json.Marshal(m.meta) } -// Agent interface -func (anthropic *Client) Model(ctx context.Context, model string) llm.Model { - // Cache models - if len(anthropic.cache) == 0 { - _, err := anthropic.Models(ctx) - if err != nil { - panic(err) - } +func (m model) String() string { + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return err.Error() } - - // Return model - return anthropic.cache[model] + return string(data) } +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - API + // Get a model by name func (anthropic *Client) GetModel(ctx context.Context, name string) (llm.Model, error) { - var response ModelMeta + var response Model if err := anthropic.DoWithContext(ctx, nil, &response, client.OptPath("models", name)); err != nil { return nil, err } // Return success - return &model{client: anthropic, ModelMeta: response}, nil + return &model{anthropic, response}, nil } // List models func (anthropic *Client) ListModels(ctx context.Context) ([]llm.Model, error) { - // Send the request var response struct { - Body []ModelMeta `json:"data"` - HasMore bool `json:"has_more"` - FirstId string `json:"first_id"` - LastId string `json:"last_id"` + Body []Model `json:"data"` + HasMore bool `json:"has_more"` + FirstId string `json:"first_id"` + LastId string `json:"last_id"` } + // Request request := url.Values{} result := make([]llm.Model, 0, 100) for { @@ -98,10 +76,7 @@ func (anthropic *Client) ListModels(ctx context.Context) ([]llm.Model, error) { // Convert to llm.Model for _, meta := range response.Body { - result = append(result, &model{ - client: anthropic, - ModelMeta: meta, - }) + result = append(result, &model{anthropic, meta}) } // If there are no more models, return @@ -118,7 +93,7 @@ func (anthropic *Client) ListModels(ctx context.Context) ([]llm.Model, error) { // Return the name of a model func (model *model) Name() string { - return model.ModelMeta.Name + return model.meta.Name } // Embedding vector generation - not supported on Anthropic diff --git a/pkg/anthropic/model_test.go b/pkg/anthropic/model_test.go new file mode 100644 index 0000000..54aca5d --- /dev/null +++ b/pkg/anthropic/model_test.go @@ -0,0 +1,22 @@ +package anthropic_test + +import ( + "context" + "encoding/json" + "testing" + + // Packages + assert "github.com/stretchr/testify/assert" +) + +func Test_models_001(t *testing.T) { + assert := assert.New(t) + + response, err := client.ListModels(context.TODO()) + assert.NoError(err) + assert.NotEmpty(response) + + data, err := json.MarshalIndent(response, "", " ") + assert.NoError(err) + t.Log(string(data)) +} diff --git a/pkg/anthropic/opt.go b/pkg/anthropic/opt.go index 5461b59..862721b 100644 --- a/pkg/anthropic/opt.go +++ b/pkg/anthropic/opt.go @@ -15,7 +15,7 @@ type optmetadata struct { } //////////////////////////////////////////////////////////////////////////////// -// OPTIONS +// PUBLIC METHODS func WithUser(v string) llm.Opt { return func(o *llm.Opts) error { @@ -49,6 +49,13 @@ func optEphemeral(opt *llm.Opts) bool { return opt.GetBool("ephemeral") } +func optMetadata(opt *llm.Opts) *optmetadata { + if user, ok := opt.Get("user").(string); ok { + return &optmetadata{User: user} + } + return nil +} + func optTools(agent *Client, opts *llm.Opts) []llm.Tool { toolkit := opts.ToolKit() if toolkit == nil { @@ -57,25 +64,44 @@ func optTools(agent *Client, opts *llm.Opts) []llm.Tool { return toolkit.Tools(agent) } -func optMaxTokens(model llm.Model, opt *llm.Opts) uint { +func optToolChoice(opts *llm.Opts) any { + choices, ok := opts.Get("tool_choice").([]string) + if !ok || len(choices) == 0 { + return nil + } + + // We only support one choice + var result struct { + Type string `json:"type"` + Name string `json:"name,omitempty"` + DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"` + } + choice := strings.TrimSpace(strings.ToLower(choices[0])) + switch choice { + case "": + return nil + case "auto", "any": + result.Type = choice + default: + result.Type = "tool" + result.Name = choice + } + return result +} + +func optMaxTokens(model llm.Model, opt *llm.Opts) uint64 { + if opt.Has("max_tokens") { + return opt.GetUint64("max_tokens") + } // https://docs.anthropic.com/en/docs/about-claude/models switch { - case strings.Contains(model.Name(), "claude-3-5-haiku"): - return 8192 - case strings.Contains(model.Name(), "claude-3-5-sonnet"): + case strings.HasPrefix(model.Name(), "claude-3-5"): return 8192 default: return 4096 } } -func optMetadata(opt *llm.Opts) *optmetadata { - if user, ok := opt.Get("user").(string); ok { - return &optmetadata{User: user} - } - return nil -} - func optStopSequences(opt *llm.Opts) []string { if opt.Has("stop") { if stop, ok := opt.Get("stop").([]string); ok { diff --git a/pkg/anthropic/session.go b/pkg/anthropic/session.go index d545582..bc7d062 100644 --- a/pkg/anthropic/session.go +++ b/pkg/anthropic/session.go @@ -6,7 +6,6 @@ import ( // Packages llm "github.com/mutablelogic/go-llm" - tool "github.com/mutablelogic/go-llm/pkg/tool" ) ////////////////////////////////////////////////////////////////// @@ -15,7 +14,7 @@ import ( type session struct { model *model opts []llm.Opt - seq []*MessageMeta + seq []*Message } var _ llm.Context = (*session)(nil) @@ -28,6 +27,7 @@ func (model *model) Context(opts ...llm.Opt) llm.Context { return &session{ model: model, opts: opts, + seq: make([]*Message, 0, 10), } } @@ -36,13 +36,13 @@ func (model *model) Context(opts ...llm.Opt) llm.Context { func (model *model) UserPrompt(prompt string, opts ...llm.Opt) llm.Context { context := model.Context(opts...) - meta, err := userPrompt(prompt, opts...) + message, err := userPrompt(prompt, opts...) if err != nil { panic(err) } // Add to the sequence - context.(*session).seq = append(context.(*session).seq, meta) + context.(*session).seq = append(context.(*session).seq, message) // Return success return context @@ -68,69 +68,64 @@ func (session session) String() string { ////////////////////////////////////////////////////////////////// // PUBLIC METHODS +// Return the number of completions +func (session *session) Num() int { + if len(session.seq) == 0 { + return 0 + } + return 1 +} + // Return the role of the last message func (session *session) Role() string { if len(session.seq) == 0 { return "" } - return session.seq[len(session.seq)-1].Role + return session.seq[len(session.seq)-1].Role() } // Return the text of the last message -func (session *session) Text() string { +func (session *session) Text(index int) string { if len(session.seq) == 0 { return "" } - meta := session.seq[len(session.seq)-1] - return meta.Text() + return session.seq[len(session.seq)-1].Text(index) } // Return the current session tool calls, or empty if no tool calls were made -func (session *session) ToolCalls() []llm.ToolCall { +func (session *session) ToolCalls(index int) []llm.ToolCall { // Sanity check for tool call if len(session.seq) == 0 { return nil } - meta := session.seq[len(session.seq)-1] - if meta.Role != "assistant" { - return nil - } - - // Gather tool calls - var result []llm.ToolCall - for _, content := range meta.Content { - if content.Type == "tool_use" { - result = append(result, tool.NewCall(content.ContentTool.Id, content.ContentTool.Name, content.ContentTool.Input)) - } - } - return result + return session.seq[len(session.seq)-1].ToolCalls(index) } // Generate a response from a user prompt (with attachments) and // other empheral options func (session *session) FromUser(ctx context.Context, prompt string, opts ...llm.Opt) error { - // Append the user prompt to the sequence - meta, err := userPrompt(prompt, opts...) + message, err := userPrompt(prompt, opts...) if err != nil { return err - } else { - session.seq = append(session.seq, meta) } + // Append the user prompt to the sequence + session.seq = append(session.seq, message) + // The options come from the session options and the user options chatopts := make([]llm.Opt, 0, len(session.opts)+len(opts)) chatopts = append(chatopts, session.opts...) chatopts = append(chatopts, opts...) // Call the 'chat' method - client := session.model.client - r, err := client.Messages(ctx, session, chatopts...) + r, err := session.model.Messages(ctx, session, chatopts...) if err != nil { return err - } else { - session.seq = append(session.seq, &r.MessageMeta) } + // Append the first message from the set of completions + session.seq = append(session.seq, &r.Message) + // Return success return nil } @@ -138,22 +133,23 @@ func (session *session) FromUser(ctx context.Context, prompt string, opts ...llm // Generate a response from a tool, passing the call identifier or // function name, and the result func (session *session) FromTool(ctx context.Context, results ...llm.ToolResult) error { - meta, err := toolResults(results...) + message, err := toolResults(results...) if err != nil { return err - } else { - session.seq = append(session.seq, meta) } + // Append the tool results to the sequence + session.seq = append(session.seq, message) + // Call the 'chat' method - client := session.model.client - r, err := client.Messages(ctx, session, session.opts...) + r, err := session.model.Messages(ctx, session, session.opts...) if err != nil { return err - } else { - session.seq = append(session.seq, &r.MessageMeta) } + // Append the first message from the set of completions + session.seq = append(session.seq, &r.Message) + // Return success return nil } @@ -161,53 +157,53 @@ func (session *session) FromTool(ctx context.Context, results ...llm.ToolResult) /////////////////////////////////////////////////////////////////////////////// // PRIVATE METHODS -func userPrompt(prompt string, opts ...llm.Opt) (*MessageMeta, error) { - // Apply attachments - opt, err := llm.ApplyOpts(opts...) +func userPrompt(prompt string, opts ...llm.Opt) (*Message, error) { + // Get attachments + opt, err := llm.ApplyPromptOpts(opts...) if err != nil { return nil, err } - // Get attachments + // Get attachments, allocate content attachments := opt.Attachments() + content := make([]*Content, 1, len(attachments)+1) - // Create user message - meta := MessageMeta{ - Role: "user", - Content: make([]*Content, 1, len(attachments)+1), - } - - // Append the text - meta.Content[0] = NewTextContent(prompt) - - // Append any additional data + // Append the text and the attachments + content[0] = NewTextContent(prompt) for _, attachment := range attachments { - content, err := attachmentContent(attachment, optEphemeral(opt), optCitations(opt)) + contentData, err := NewAttachment(attachment, optEphemeral(opt), optCitations(opt)) if err != nil { return nil, err } - meta.Content = append(meta.Content, content) + content = append(content, contentData) } // Return success - return &meta, nil + return &Message{ + RoleContent: RoleContent{ + Role: "user", + Content: content, + }, + }, nil } -func toolResults(results ...llm.ToolResult) (*MessageMeta, error) { +func toolResults(results ...llm.ToolResult) (*Message, error) { // Check for no results if len(results) == 0 { return nil, llm.ErrBadParameter.Withf("No tool results") } // Create user message - meta := MessageMeta{ - Role: "user", - Content: make([]*Content, 0, len(results)), + message := Message{ + RoleContent{ + Role: "user", + Content: make([]*Content, 0, len(results)), + }, } for _, result := range results { - meta.Content = append(meta.Content, NewToolResultContent(result)) + message.RoleContent.Content = append(message.RoleContent.Content, NewToolResultContent(result)) } // Return success - return &meta, nil + return &message, nil } diff --git a/pkg/anthropic/session_test.go b/pkg/anthropic/session_test.go index e27c078..0835e16 100644 --- a/pkg/anthropic/session_test.go +++ b/pkg/anthropic/session_test.go @@ -2,91 +2,55 @@ package anthropic_test import ( "context" - "os" "testing" // Packages - opts "github.com/mutablelogic/go-client" llm "github.com/mutablelogic/go-llm" - anthropic "github.com/mutablelogic/go-llm/pkg/anthropic" tool "github.com/mutablelogic/go-llm/pkg/tool" assert "github.com/stretchr/testify/assert" ) func Test_session_001(t *testing.T) { - client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) - if err != nil { + assert := assert.New(t) + model := client.Model(context.TODO(), "claude-3-5-haiku-20241022") + if !assert.NotNil(model) { t.FailNow() } - model, err := client.GetModel(context.TODO(), "claude-3-haiku-20240307") - if err != nil { - t.FailNow() + session := model.Context() + if assert.NotNil(session) { + err := session.FromUser(context.TODO(), "Hello, how are you?") + assert.NoError(err) + t.Log(session) } - - // Session with a single user prompt - streaming - t.Run("stream", func(t *testing.T) { - assert := assert.New(t) - session := model.Context(llm.WithStream(func(stream llm.ContextContent) { - t.Log("SESSION DELTA", stream) - })) - assert.NotNil(session) - - err := session.FromUser(context.TODO(), "Why is the grass green?") - if !assert.NoError(err) { - t.FailNow() - } - assert.Equal("assistant", session.Role()) - assert.NotEmpty(session.Text()) - }) - - // Session with a single user prompt - not streaming - t.Run("nostream", func(t *testing.T) { - assert := assert.New(t) - session := model.Context() - assert.NotNil(session) - - err := session.FromUser(context.TODO(), "Why is the sky blue?") - if !assert.NoError(err) { - t.FailNow() - } - assert.Equal("assistant", session.Role()) - assert.NotEmpty(session.Text()) - }) } func Test_session_002(t *testing.T) { - client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) - if err != nil { + assert := assert.New(t) + model := client.Model(context.TODO(), "claude-3-5-haiku-20241022") + if !assert.NotNil(model) { t.FailNow() } - model, err := client.GetModel(context.TODO(), "claude-3-haiku-20240307") - if err != nil { + toolkit := tool.NewToolKit() + toolkit.Register(&weather{}) + + session := model.Context(llm.WithToolKit(toolkit)) + if !assert.NotNil(session) { t.FailNow() } - // Session with a tool call - t.Run("toolcall", func(t *testing.T) { - assert := assert.New(t) + assert.NoError(session.FromUser(context.TODO(), "What is the weather like in London today?")) + calls := session.ToolCalls(0) + if assert.Len(calls, 1) { + assert.Equal("weather_in_city", calls[0].Name()) - toolkit := tool.NewToolKit() - if err := toolkit.Register(new(weather)); !assert.NoError(err) { - t.FailNow() - } + result, err := toolkit.Run(context.TODO(), calls...) + assert.NoError(err) + assert.Len(result, 1) - session := model.Context(llm.WithToolKit(toolkit)) - assert.NotNil(session) - - err = session.FromUser(context.TODO(), "What is today's weather, in Berlin?") - if !assert.NoError(err) { - t.FailNow() - } + assert.NoError(session.FromTool(context.TODO(), result...)) + } - result, err := toolkit.Run(context.TODO(), session.ToolCalls()...) - if !assert.NoError(err) { - t.FailNow() - } - assert.NotEmpty(result) - }) + t.Log(session) } diff --git a/pkg/mistral/client.go b/pkg/mistral/client.go index 7e643dc..f70f8bf 100644 --- a/pkg/mistral/client.go +++ b/pkg/mistral/client.go @@ -4,11 +4,11 @@ mistral implements an API client for mistral (https://docs.mistral.ai/api/) package mistral import ( - // Packages "context" - "github.com/mutablelogic/go-client" - "github.com/mutablelogic/go-llm" + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" ) /////////////////////////////////////////////////////////////////////////////// diff --git a/pkg/mistral/chat_completion.go b/pkg/mistral/completion.go similarity index 100% rename from pkg/mistral/chat_completion.go rename to pkg/mistral/completion.go diff --git a/pkg/mistral/chat_completion_test.go b/pkg/mistral/completion_test.go similarity index 98% rename from pkg/mistral/chat_completion_test.go rename to pkg/mistral/completion_test.go index b44d5ff..899d364 100644 --- a/pkg/mistral/chat_completion_test.go +++ b/pkg/mistral/completion_test.go @@ -8,10 +8,9 @@ import ( "testing" // Packages - - "github.com/mutablelogic/go-llm" + llm "github.com/mutablelogic/go-llm" mistral "github.com/mutablelogic/go-llm/pkg/mistral" - "github.com/mutablelogic/go-llm/pkg/tool" + tool "github.com/mutablelogic/go-llm/pkg/tool" assert "github.com/stretchr/testify/assert" ) diff --git a/pkg/mistral/model.go b/pkg/mistral/model.go index 24420b6..ab7374b 100644 --- a/pkg/mistral/model.go +++ b/pkg/mistral/model.go @@ -16,6 +16,8 @@ type model struct { meta Model } +var _ llm.Model = (*model)(nil) + type Model struct { Name string `json:"id"` Description string `json:"description,omitempty"` diff --git a/pkg/mistral/model_test.go b/pkg/mistral/model_test.go index 812be6e..65269e3 100644 --- a/pkg/mistral/model_test.go +++ b/pkg/mistral/model_test.go @@ -6,7 +6,6 @@ import ( "testing" // Packages - assert "github.com/stretchr/testify/assert" )