From 94d0ea53b2245c07e736db03b89d2de08505a91e Mon Sep 17 00:00:00 2001 From: tom-fitz <16616192+tom-fitz@users.noreply.github.com> Date: Tue, 11 Jun 2024 12:31:06 -0600 Subject: [PATCH 1/9] 265: unifying roles, setting default as user --- pkg/api/schemas/chat.go | 12 ++++++++++-- pkg/api/schemas/chat_stream.go | 2 +- pkg/api/schemas/chat_test.go | 18 +++++++++--------- pkg/providers/anthropic/chat.go | 2 +- pkg/providers/azureopenai/chat_stream_test.go | 4 ++-- pkg/providers/azureopenai/client_test.go | 4 ++-- pkg/providers/bedrock/client_test.go | 2 +- pkg/providers/bedrock/testdata/chat.req.json | 2 +- pkg/providers/octoml/client_test.go | 2 +- pkg/providers/ollama/chat.go | 2 +- pkg/providers/ollama/client_test.go | 6 +++--- pkg/providers/openai/chat_stream_test.go | 4 ++-- pkg/providers/openai/chat_test.go | 2 +- 13 files changed, 35 insertions(+), 27 deletions(-) diff --git a/pkg/api/schemas/chat.go b/pkg/api/schemas/chat.go index bb846043..2ca0af06 100644 --- a/pkg/api/schemas/chat.go +++ b/pkg/api/schemas/chat.go @@ -62,7 +62,7 @@ func (r *ChatRequest) Params(modelID string, modelName string) *ChatParams { func NewChatFromStr(message string) *ChatRequest { return &ChatRequest{ Message: ChatMessage{ - "user", + RoleUser, message, }, } @@ -93,10 +93,18 @@ type TokenUsage struct { TotalTokens int `json:"total_tokens"` } +type Role string + +const ( + RoleSystem Role = "system" + RoleUser Role = "user" + RoleAssistant Role = "assistant" +) + // ChatMessage is a message in a chat request. type ChatMessage struct { // The role of the author of this message. One of system, user, or assistant. - Role string `json:"role" validate:"required"` + Role Role `json:"role" validate:"required"` // The content of the message. Content string `json:"content" validate:"required"` } diff --git a/pkg/api/schemas/chat_stream.go b/pkg/api/schemas/chat_stream.go index f7cf8b27..bdcf8fcd 100644 --- a/pkg/api/schemas/chat_stream.go +++ b/pkg/api/schemas/chat_stream.go @@ -30,7 +30,7 @@ func NewChatStreamFromStr(message string) *ChatStreamRequest { return &ChatStreamRequest{ ChatRequest: &ChatRequest{ Message: ChatMessage{ - "user", + RoleUser, message, }, }, diff --git a/pkg/api/schemas/chat_test.go b/pkg/api/schemas/chat_test.go index 9b5ce407..f4cfe9fa 100644 --- a/pkg/api/schemas/chat_test.go +++ b/pkg/api/schemas/chat_test.go @@ -30,7 +30,7 @@ func TestChatRequest_DefaultParams(t *testing.T) { chatReq := ChatRequest{ Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: defaultMessage, }, MessageHistory: []ChatMessage{ @@ -42,7 +42,7 @@ func TestChatRequest_DefaultParams(t *testing.T) { OverrideParams: &map[string]ModelParamsOverride{ modelID: { Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: myModelMessage, }, }, @@ -66,7 +66,7 @@ func TestChatRequest_ModelIDOverride(t *testing.T) { chatReq := ChatRequest{ Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: defaultMessage, }, MessageHistory: []ChatMessage{ @@ -78,7 +78,7 @@ func TestChatRequest_ModelIDOverride(t *testing.T) { OverrideParams: &map[string]ModelParamsOverride{ modelID: { Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: myModelMessage, }, }, @@ -102,7 +102,7 @@ func TestChatRequest_ModelNameOverride(t *testing.T) { chatReq := ChatRequest{ Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: defaultMessage, }, MessageHistory: []ChatMessage{ @@ -114,7 +114,7 @@ func TestChatRequest_ModelNameOverride(t *testing.T) { OverrideParams: &map[string]ModelParamsOverride{ modelName: { Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: myModelMessage, }, }, @@ -139,7 +139,7 @@ func TestChatRequest_ModelNameIDOverride(t *testing.T) { chatReq := ChatRequest{ Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: defaultMessage, }, MessageHistory: []ChatMessage{ @@ -151,13 +151,13 @@ func TestChatRequest_ModelNameIDOverride(t *testing.T) { OverrideParams: &map[string]ModelParamsOverride{ modelName: { Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: myModelNameMessage, }, }, modelID: { Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: myModelIDMessage, }, }, diff --git a/pkg/providers/anthropic/chat.go b/pkg/providers/anthropic/chat.go index 80b45f2b..f0e14fb2 100644 --- a/pkg/providers/anthropic/chat.go +++ b/pkg/providers/anthropic/chat.go @@ -139,7 +139,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche ModelResponse: schemas.ModelResponse{ Metadata: map[string]string{}, Message: schemas.ChatMessage{ - Role: completion.Type, + Role: schemas.Role(completion.Type), Content: completion.Text, }, TokenUsage: schemas.TokenUsage{ diff --git a/pkg/providers/azureopenai/chat_stream_test.go b/pkg/providers/azureopenai/chat_stream_test.go index 5aade1f5..efb70a0c 100644 --- a/pkg/providers/azureopenai/chat_stream_test.go +++ b/pkg/providers/azureopenai/chat_stream_test.go @@ -72,7 +72,7 @@ func TestAzureOpenAIClient_ChatStreamRequest(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the capital of the United Kingdom?", }}} @@ -140,7 +140,7 @@ func TestAzureOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the biggest animal?", }}} diff --git a/pkg/providers/azureopenai/client_test.go b/pkg/providers/azureopenai/client_test.go index 1700bca0..3529413c 100644 --- a/pkg/providers/azureopenai/client_test.go +++ b/pkg/providers/azureopenai/client_test.go @@ -56,7 +56,7 @@ func TestAzureOpenAIClient_ChatRequest(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the capital of the United Kingdom?", }}} @@ -116,7 +116,7 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the dealio?", }}} diff --git a/pkg/providers/bedrock/client_test.go b/pkg/providers/bedrock/client_test.go index cdae1f68..57056150 100644 --- a/pkg/providers/bedrock/client_test.go +++ b/pkg/providers/bedrock/client_test.go @@ -62,7 +62,7 @@ func TestBedrockClient_ChatRequest(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the biggest animal?", }}} diff --git a/pkg/providers/bedrock/testdata/chat.req.json b/pkg/providers/bedrock/testdata/chat.req.json index c2e941d2..9466eda7 100644 --- a/pkg/providers/bedrock/testdata/chat.req.json +++ b/pkg/providers/bedrock/testdata/chat.req.json @@ -2,7 +2,7 @@ "model": "amazon.titan-text-express-v1", "messages": [ { - "role": "user", + "role": schemas.RoleUser, "content": "What's the biggest animal?" } ], diff --git a/pkg/providers/octoml/client_test.go b/pkg/providers/octoml/client_test.go index f35de1f7..353dbe9d 100644 --- a/pkg/providers/octoml/client_test.go +++ b/pkg/providers/octoml/client_test.go @@ -121,7 +121,7 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) { // Create a chat request payload chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the dealeo?", }}} diff --git a/pkg/providers/ollama/chat.go b/pkg/providers/ollama/chat.go index b93f5b10..40356d7b 100644 --- a/pkg/providers/ollama/chat.go +++ b/pkg/providers/ollama/chat.go @@ -172,7 +172,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche Cached: false, ModelResponse: schemas.ModelResponse{ Message: schemas.ChatMessage{ - Role: ollamaCompletion.Message.Role, + Role: schemas.Role(ollamaCompletion.Message.Role), Content: ollamaCompletion.Message.Content, }, TokenUsage: schemas.TokenUsage{ diff --git a/pkg/providers/ollama/client_test.go b/pkg/providers/ollama/client_test.go index e6c584cf..091a8804 100644 --- a/pkg/providers/ollama/client_test.go +++ b/pkg/providers/ollama/client_test.go @@ -57,7 +57,7 @@ func TestOllamaClient_ChatRequest(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the biggest animal?", }}} @@ -85,7 +85,7 @@ func TestOllamaClient_ChatRequest_Non200Response(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the capital of the United Kingdom?", }}} @@ -122,7 +122,7 @@ func TestOllamaClient_ChatRequest_SuccessfulResponse(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the capital of the United Kingdom?", }}} diff --git a/pkg/providers/openai/chat_stream_test.go b/pkg/providers/openai/chat_stream_test.go index 1ab8483b..236b9d6f 100644 --- a/pkg/providers/openai/chat_stream_test.go +++ b/pkg/providers/openai/chat_stream_test.go @@ -72,7 +72,7 @@ func TestOpenAIClient_ChatStreamRequest(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the capital of the United Kingdom?", }}} @@ -140,7 +140,7 @@ func TestOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the capital of the United Kingdom?", }}} diff --git a/pkg/providers/openai/chat_test.go b/pkg/providers/openai/chat_test.go index 3109f150..209507ad 100644 --- a/pkg/providers/openai/chat_test.go +++ b/pkg/providers/openai/chat_test.go @@ -57,7 +57,7 @@ func TestOpenAIClient_ChatRequest(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the capital of the United Kingdom?", }}} From 5f6c31fd6608e9d57f365905995dd05e3d62b764 Mon Sep 17 00:00:00 2001 From: tom-fitz <16616192+tom-fitz@users.noreply.github.com> Date: Tue, 11 Jun 2024 12:45:59 -0600 Subject: [PATCH 2/9] 265: fix missing assistant role updates --- pkg/providers/bedrock/chat.go | 2 +- pkg/providers/cohere/chat.go | 2 +- pkg/providers/ollama/client_test.go | 2 +- pkg/providers/openai/chat_stream.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/providers/bedrock/chat.go b/pkg/providers/bedrock/chat.go index 658c1769..dda0637e 100644 --- a/pkg/providers/bedrock/chat.go +++ b/pkg/providers/bedrock/chat.go @@ -104,7 +104,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche Cached: false, ModelResponse: schemas.ModelResponse{ Message: schemas.ChatMessage{ - Role: "assistant", + Role: schemas.RoleAssistant, Content: modelResult.OutputText, }, TokenUsage: schemas.TokenUsage{ diff --git a/pkg/providers/cohere/chat.go b/pkg/providers/cohere/chat.go index ddf75680..d5ef55d1 100644 --- a/pkg/providers/cohere/chat.go +++ b/pkg/providers/cohere/chat.go @@ -127,7 +127,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche "responseId": cohereCompletion.ResponseID, }, Message: schemas.ChatMessage{ - Role: "assistant", + Role: schemas.RoleAssistant, Content: cohereCompletion.Text, }, TokenUsage: schemas.TokenUsage{ diff --git a/pkg/providers/ollama/client_test.go b/pkg/providers/ollama/client_test.go index 091a8804..e7af71d9 100644 --- a/pkg/providers/ollama/client_test.go +++ b/pkg/providers/ollama/client_test.go @@ -130,6 +130,6 @@ func TestOllamaClient_ChatRequest_SuccessfulResponse(t *testing.T) { require.NoError(t, err) require.NotNil(t, response) - require.Equal(t, "assistant", response.ModelResponse.Message.Role) + require.Equal(t, schemas.RoleAssistant, response.ModelResponse.Message.Role) require.Equal(t, "London", response.ModelResponse.Message.Content) } diff --git a/pkg/providers/openai/chat_stream.go b/pkg/providers/openai/chat_stream.go index 08ca2b21..659d8b8d 100644 --- a/pkg/providers/openai/chat_stream.go +++ b/pkg/providers/openai/chat_stream.go @@ -120,7 +120,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { "generated_at": completionChunk.Created, }, Message: schemas.ChatMessage{ - Role: "assistant", // doesn't present in all chunks + Role: schemas.RoleAssistant, // doesn't present in all chunks Content: responseChunk.Delta.Content, }, }, From c6b465088225bd8d0a5d48c20958d0b0cf39979d Mon Sep 17 00:00:00 2001 From: tom-fitz <16616192+tom-fitz@users.noreply.github.com> Date: Tue, 11 Jun 2024 12:51:10 -0600 Subject: [PATCH 3/9] 265: updating swagger docs --- docs/docs.go | 19 ++++++++++++++++++- docs/swagger.json | 19 ++++++++++++++++++- docs/swagger.yaml | 13 ++++++++++++- 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/docs/docs.go b/docs/docs.go index 51a45f21..67918956 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -203,7 +203,11 @@ const docTemplate = `{ }, "role": { "description": "The role of the author of this message. One of system, user, or assistant.", - "type": "string" + "allOf": [ + { + "$ref": "#/definitions/schemas.Role" + } + ] } } }, @@ -308,6 +312,19 @@ const docTemplate = `{ } } }, + "schemas.Role": { + "type": "string", + "enum": [ + "system", + "user", + "assistant" + ], + "x-enum-varnames": [ + "RoleSystem", + "RoleUser", + "RoleAssistant" + ] + }, "schemas.RouterListSchema": { "type": "object", "properties": { diff --git a/docs/swagger.json b/docs/swagger.json index aa0d3b25..1912039e 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -200,7 +200,11 @@ }, "role": { "description": "The role of the author of this message. One of system, user, or assistant.", - "type": "string" + "allOf": [ + { + "$ref": "#/definitions/schemas.Role" + } + ] } } }, @@ -305,6 +309,19 @@ } } }, + "schemas.Role": { + "type": "string", + "enum": [ + "system", + "user", + "assistant" + ], + "x-enum-varnames": [ + "RoleSystem", + "RoleUser", + "RoleAssistant" + ] + }, "schemas.RouterListSchema": { "type": "object", "properties": { diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 82dcc4a6..4a683f2c 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -6,9 +6,10 @@ definitions: description: The content of the message. type: string role: + allOf: + - $ref: '#/definitions/schemas.Role' description: The role of the author of this message. One of system, user, or assistant. - type: string required: - content - role @@ -77,6 +78,16 @@ definitions: token_usage: $ref: '#/definitions/schemas.TokenUsage' type: object + schemas.Role: + enum: + - system + - user + - assistant + type: string + x-enum-varnames: + - RoleSystem + - RoleUser + - RoleAssistant schemas.RouterListSchema: properties: routers: From a491e62149759597915b5e8a124e574789d47655 Mon Sep 17 00:00:00 2001 From: tom-fitz <16616192+tom-fitz@users.noreply.github.com> Date: Thu, 13 Jun 2024 16:05:12 -0600 Subject: [PATCH 4/9] 265: creating mapper for interal role to provider role --- pkg/api/schemas/chat.go | 17 +++++++++++++++++ pkg/providers/cohere/chat.go | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/pkg/api/schemas/chat.go b/pkg/api/schemas/chat.go index 2ca0af06..c9f8505e 100644 --- a/pkg/api/schemas/chat.go +++ b/pkg/api/schemas/chat.go @@ -108,3 +108,20 @@ type ChatMessage struct { // The content of the message. Content string `json:"content" validate:"required"` } + +func MapToProviderRole(provider string, role Role) Role { + switch provider { + case "cohere": + switch role { + case RoleAssistant: + return "CHATBOT" + case RoleSystem: + return "SYSTEM" + case RoleUser: + return "USER" + } + case "openai": + return role + } + return role +} diff --git a/pkg/providers/cohere/chat.go b/pkg/providers/cohere/chat.go index d5ef55d1..7e64573d 100644 --- a/pkg/providers/cohere/chat.go +++ b/pkg/providers/cohere/chat.go @@ -127,7 +127,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche "responseId": cohereCompletion.ResponseID, }, Message: schemas.ChatMessage{ - Role: schemas.RoleAssistant, + Role: schemas.MapToProviderRole(providerName, schemas.RoleAssistant), Content: cohereCompletion.Text, }, TokenUsage: schemas.TokenUsage{ From 3d1bf874925c953f65e4df48e1fd7663c8ef8fb4 Mon Sep 17 00:00:00 2001 From: tom-fitz <16616192+tom-fitz@users.noreply.github.com> Date: Thu, 13 Jun 2024 16:54:25 -0600 Subject: [PATCH 5/9] 265: adding tests for role mapping function --- pkg/api/schemas/chat.go | 6 +++++- pkg/api/schemas/chat_test.go | 41 ++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/pkg/api/schemas/chat.go b/pkg/api/schemas/chat.go index c9f8505e..30fff5e8 100644 --- a/pkg/api/schemas/chat.go +++ b/pkg/api/schemas/chat.go @@ -109,7 +109,9 @@ type ChatMessage struct { Content string `json:"content" validate:"required"` } +// MapToProviderRole maps the internal role to the role the specific provider is expecting func MapToProviderRole(provider string, role Role) Role { + // TODO: possibly return errors here if inputs are empty? switch provider { case "cohere": switch role { @@ -120,8 +122,10 @@ func MapToProviderRole(provider string, role Role) Role { case RoleUser: return "USER" } + case "openai": return role } - return role + + return "" } diff --git a/pkg/api/schemas/chat_test.go b/pkg/api/schemas/chat_test.go index f4cfe9fa..34cf4dfd 100644 --- a/pkg/api/schemas/chat_test.go +++ b/pkg/api/schemas/chat_test.go @@ -168,3 +168,44 @@ func TestChatRequest_ModelNameIDOverride(t *testing.T) { require.Equal(t, []string{backstory, myModelIDMessage}, ToSlice(params.Messages)) } + +func TestMapToProviderRole(t *testing.T) { + tests := []struct { + name string + provider string + role Role + expected Role + }{ + { + name: "should return CHATBOT if provider is Cohere and role is assistant", + provider: "cohere", + role: RoleAssistant, + expected: "CHATBOT", + }, + { + name: "should return SYSTEM if provider is Cohere and role is system", + provider: "cohere", + role: RoleSystem, + expected: "SYSTEM", + }, + { + name: "should return USER if provider is Cohere and role is user", + provider: "cohere", + role: RoleUser, + expected: "USER", + }, + { + name: "should return the role parameter if provider is openai", + provider: "openai", + role: RoleUser, + expected: RoleUser, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mappedRole := MapToProviderRole(tt.provider, tt.role) + require.Equal(t, tt.expected, mappedRole) + }) + } +} From 0b2c8c6620a0f7b844f8d72a789ffb17182c47f4 Mon Sep 17 00:00:00 2001 From: tom-fitz <16616192+tom-fitz@users.noreply.github.com> Date: Fri, 14 Jun 2024 14:42:48 -0600 Subject: [PATCH 6/9] 265: handling role mapping for cohere provider. adding in tests --- pkg/api/schemas/chat_test.go | 41 -------------------- pkg/providers/cohere/schemas.go | 21 ++++++++++- pkg/providers/cohere/schemas_test.go | 56 ++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 43 deletions(-) create mode 100644 pkg/providers/cohere/schemas_test.go diff --git a/pkg/api/schemas/chat_test.go b/pkg/api/schemas/chat_test.go index 34cf4dfd..f4cfe9fa 100644 --- a/pkg/api/schemas/chat_test.go +++ b/pkg/api/schemas/chat_test.go @@ -168,44 +168,3 @@ func TestChatRequest_ModelNameIDOverride(t *testing.T) { require.Equal(t, []string{backstory, myModelIDMessage}, ToSlice(params.Messages)) } - -func TestMapToProviderRole(t *testing.T) { - tests := []struct { - name string - provider string - role Role - expected Role - }{ - { - name: "should return CHATBOT if provider is Cohere and role is assistant", - provider: "cohere", - role: RoleAssistant, - expected: "CHATBOT", - }, - { - name: "should return SYSTEM if provider is Cohere and role is system", - provider: "cohere", - role: RoleSystem, - expected: "SYSTEM", - }, - { - name: "should return USER if provider is Cohere and role is user", - provider: "cohere", - role: RoleUser, - expected: "USER", - }, - { - name: "should return the role parameter if provider is openai", - provider: "openai", - role: RoleUser, - expected: RoleUser, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mappedRole := MapToProviderRole(tt.provider, tt.role) - require.Equal(t, tt.expected, mappedRole) - }) - } -} diff --git a/pkg/providers/cohere/schemas.go b/pkg/providers/cohere/schemas.go index 9dc9bb09..a048a30f 100644 --- a/pkg/providers/cohere/schemas.go +++ b/pkg/providers/cohere/schemas.go @@ -2,7 +2,7 @@ package cohere import "github.com/EinStack/glide/pkg/api/schemas" -// Cohere Chat Response +// ChatCompletion Cohere Chat Response type ChatCompletion struct { Text string `json:"text"` GenerationID string `json:"generation_id"` @@ -112,7 +112,24 @@ func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { message := params.Messages[len(params.Messages)-1] messageHistory := params.Messages[:len(params.Messages)-1] - // TODO: Map chat message roles to Cohere roles: CHATBOT, SYSTEM, USER + mapRole := func(role schemas.Role) string { + switch role { + case schemas.RoleSystem: + return "SYSTEM" + case schemas.RoleUser: + return "USER" + case schemas.RoleAssistant: + return "CHATBOT" + default: + return "USER" + } + } + + for i := range messageHistory { + messageHistory[i].Role = schemas.Role(mapRole(messageHistory[i].Role)) + } + + message.Role = schemas.Role(mapRole(message.Role)) r.Message = message.Content r.ChatHistory = messageHistory diff --git a/pkg/providers/cohere/schemas_test.go b/pkg/providers/cohere/schemas_test.go new file mode 100644 index 00000000..5396a290 --- /dev/null +++ b/pkg/providers/cohere/schemas_test.go @@ -0,0 +1,56 @@ +package cohere + +import ( + "github.com/EinStack/glide/pkg/api/schemas" + "github.com/stretchr/testify/require" + "testing" +) + +func TestChatRequest_ApplyParams(t *testing.T) { + tests := []struct { + name string + chatReq ChatRequest + params *schemas.ChatParams + expected ChatRequest + }{ + { + name: "should set role to default USER when role is empty string", + chatReq: ChatRequest{}, + params: &schemas.ChatParams{ + Messages: []schemas.ChatMessage{ + {Role: "", Content: "Hello"}, + {Role: schemas.RoleAssistant, Content: "Hi there!"}, + }, + }, + expected: ChatRequest{ + Message: "Hi there!", + ChatHistory: []schemas.ChatMessage{ + {Role: "USER", Content: "Hello"}, + }, + }, + }, + { + name: "should set role to default USER when role is RoleUser", + chatReq: ChatRequest{}, + params: &schemas.ChatParams{ + Messages: []schemas.ChatMessage{ + {Role: schemas.RoleUser, Content: "Hello"}, + {Role: schemas.RoleAssistant, Content: "Hi there!"}, + }, + }, + expected: ChatRequest{ + Message: "Hi there!", + ChatHistory: []schemas.ChatMessage{ + {Role: "USER", Content: "Hello"}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.chatReq.ApplyParams(tt.params) + require.Equal(t, tt.expected, tt.chatReq) + }) + } +} From ed57b3e0468410920ffe819c28a8ac99c76d47be Mon Sep 17 00:00:00 2001 From: tom-fitz <16616192+tom-fitz@users.noreply.github.com> Date: Fri, 14 Jun 2024 14:48:06 -0600 Subject: [PATCH 7/9] 256: fixing tests --- pkg/providers/cohere/schemas_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/providers/cohere/schemas_test.go b/pkg/providers/cohere/schemas_test.go index 5396a290..a71042b8 100644 --- a/pkg/providers/cohere/schemas_test.go +++ b/pkg/providers/cohere/schemas_test.go @@ -1,9 +1,10 @@ package cohere import ( + "testing" + "github.com/EinStack/glide/pkg/api/schemas" "github.com/stretchr/testify/require" - "testing" ) func TestChatRequest_ApplyParams(t *testing.T) { From 516457bf935c59d8a9e68e6352e0888c9e5f45d6 Mon Sep 17 00:00:00 2001 From: tom-fitz <16616192+tom-fitz@users.noreply.github.com> Date: Sat, 15 Jun 2024 14:24:44 -0600 Subject: [PATCH 8/9] 265: removing unused role mapping func. handling role set from payload in cohere --- pkg/api/schemas/chat.go | 21 --------------------- pkg/providers/cohere/chat.go | 2 +- 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/pkg/api/schemas/chat.go b/pkg/api/schemas/chat.go index 30fff5e8..2ca0af06 100644 --- a/pkg/api/schemas/chat.go +++ b/pkg/api/schemas/chat.go @@ -108,24 +108,3 @@ type ChatMessage struct { // The content of the message. Content string `json:"content" validate:"required"` } - -// MapToProviderRole maps the internal role to the role the specific provider is expecting -func MapToProviderRole(provider string, role Role) Role { - // TODO: possibly return errors here if inputs are empty? - switch provider { - case "cohere": - switch role { - case RoleAssistant: - return "CHATBOT" - case RoleSystem: - return "SYSTEM" - case RoleUser: - return "USER" - } - - case "openai": - return role - } - - return "" -} diff --git a/pkg/providers/cohere/chat.go b/pkg/providers/cohere/chat.go index 7e64573d..3011a4b4 100644 --- a/pkg/providers/cohere/chat.go +++ b/pkg/providers/cohere/chat.go @@ -127,7 +127,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche "responseId": cohereCompletion.ResponseID, }, Message: schemas.ChatMessage{ - Role: schemas.MapToProviderRole(providerName, schemas.RoleAssistant), + Role: payload.ChatHistory[len(payload.ChatHistory)-1].Role, Content: cohereCompletion.Text, }, TokenUsage: schemas.TokenUsage{ From 201ef76a252b6902e948c7126de6000b10feba24 Mon Sep 17 00:00:00 2001 From: tom-fitz <16616192+tom-fitz@users.noreply.github.com> Date: Mon, 17 Jun 2024 11:51:21 -0600 Subject: [PATCH 9/9] 265: pushing wip. Need clarification on role response logic --- pkg/providers/cohere/chat.go | 2 +- pkg/providers/cohere/schemas.go | 2 + pkg/providers/cohere/schemas_test.go | 57 ---------------------------- 3 files changed, 3 insertions(+), 58 deletions(-) delete mode 100644 pkg/providers/cohere/schemas_test.go diff --git a/pkg/providers/cohere/chat.go b/pkg/providers/cohere/chat.go index 3011a4b4..969d7da9 100644 --- a/pkg/providers/cohere/chat.go +++ b/pkg/providers/cohere/chat.go @@ -127,7 +127,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche "responseId": cohereCompletion.ResponseID, }, Message: schemas.ChatMessage{ - Role: payload.ChatHistory[len(payload.ChatHistory)-1].Role, + Role: payload.Role, Content: cohereCompletion.Text, }, TokenUsage: schemas.TokenUsage{ diff --git a/pkg/providers/cohere/schemas.go b/pkg/providers/cohere/schemas.go index a048a30f..2c31ce13 100644 --- a/pkg/providers/cohere/schemas.go +++ b/pkg/providers/cohere/schemas.go @@ -92,6 +92,7 @@ type FinalResponse struct { type ChatRequest struct { Model string `json:"model"` Message string `json:"message"` + Role schemas.Role `json:"role"` ChatHistory []schemas.ChatMessage `json:"chat_history"` Temperature float64 `json:"temperature,omitempty"` Preamble string `json:"preamble,omitempty"` @@ -131,6 +132,7 @@ func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { message.Role = schemas.Role(mapRole(message.Role)) + r.Role = message.Role r.Message = message.Content r.ChatHistory = messageHistory } diff --git a/pkg/providers/cohere/schemas_test.go b/pkg/providers/cohere/schemas_test.go deleted file mode 100644 index a71042b8..00000000 --- a/pkg/providers/cohere/schemas_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package cohere - -import ( - "testing" - - "github.com/EinStack/glide/pkg/api/schemas" - "github.com/stretchr/testify/require" -) - -func TestChatRequest_ApplyParams(t *testing.T) { - tests := []struct { - name string - chatReq ChatRequest - params *schemas.ChatParams - expected ChatRequest - }{ - { - name: "should set role to default USER when role is empty string", - chatReq: ChatRequest{}, - params: &schemas.ChatParams{ - Messages: []schemas.ChatMessage{ - {Role: "", Content: "Hello"}, - {Role: schemas.RoleAssistant, Content: "Hi there!"}, - }, - }, - expected: ChatRequest{ - Message: "Hi there!", - ChatHistory: []schemas.ChatMessage{ - {Role: "USER", Content: "Hello"}, - }, - }, - }, - { - name: "should set role to default USER when role is RoleUser", - chatReq: ChatRequest{}, - params: &schemas.ChatParams{ - Messages: []schemas.ChatMessage{ - {Role: schemas.RoleUser, Content: "Hello"}, - {Role: schemas.RoleAssistant, Content: "Hi there!"}, - }, - }, - expected: ChatRequest{ - Message: "Hi there!", - ChatHistory: []schemas.ChatMessage{ - {Role: "USER", Content: "Hello"}, - }, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tt.chatReq.ApplyParams(tt.params) - require.Equal(t, tt.expected, tt.chatReq) - }) - } -}