From 8819af87299f72e6c5d33593e775499a2ae01da6 Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Tue, 18 Jun 2024 22:52:20 +0300 Subject: [PATCH 01/18] #67: Inited a new embedding router --- pkg/routers/config.go | 178 +----------------- pkg/routers/embed/config.go | 15 ++ pkg/routers/embed/router.go | 22 +++ pkg/routers/{router.go => lang.go} | 0 pkg/routers/lang/config.go | 179 +++++++++++++++++++ pkg/routers/{ => lang}/config_test.go | 17 +- pkg/routers/{router_test.go => lang_test.go} | 0 7 files changed, 227 insertions(+), 184 deletions(-) create mode 100644 pkg/routers/embed/config.go create mode 100644 pkg/routers/embed/router.go rename pkg/routers/{router.go => lang.go} (100%) create mode 100644 pkg/routers/lang/config.go rename pkg/routers/{ => lang}/config_test.go (97%) rename pkg/routers/{router_test.go => lang_test.go} (100%) diff --git a/pkg/routers/config.go b/pkg/routers/config.go index c7651f6d..3ae65a3a 100644 --- a/pkg/routers/config.go +++ b/pkg/routers/config.go @@ -2,21 +2,16 @@ package routers import ( "fmt" - "time" - + "github.com/EinStack/glide/pkg/routers/lang" "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/routers/routing" - - "github.com/EinStack/glide/pkg/routers/retry" - - "github.com/EinStack/glide/pkg/providers" "go.uber.org/multierr" "go.uber.org/zap" ) type Config struct { - LanguageRouters []LangRouterConfig `yaml:"language" validate:"required,gte=1,dive"` // the list of language routers + LanguageRouters []lang.LangRouterConfig `yaml:"language" validate:"required,dive"` // the list of language routers + EmbeddingRouters []EmbeddingRouterConfig `yaml:"embedding" validate:"required,dive"` } func (c *Config) BuildLangRouters(tel *telemetry.Telemetry) ([]*LangRouter, error) { @@ -54,170 +49,3 @@ func (c *Config) BuildLangRouters(tel *telemetry.Telemetry) ([]*LangRouter, erro return routers, nil } - -// TODO: how to specify other backoff strategies? -// TODO: Had to keep RoutingStrategy because of https://github.com/swaggo/swag/issues/1738 -// LangRouterConfig -type LangRouterConfig struct { - ID string `yaml:"id" json:"routers" validate:"required"` // Unique router ID - Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is router enabled? - Retry *retry.ExpRetryConfig `yaml:"retry" json:"retry" validate:"required"` // retry when no healthy model is available to router - RoutingStrategy routing.Strategy `yaml:"strategy" json:"strategy" swaggertype:"primitive,string" validate:"required"` // strategy on picking the next model to serve the request - Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests -} - -// BuildModels creates LanguageModel slice out of the given config -func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*providers.LanguageModel, []*providers.LanguageModel, error) { //nolint: cyclop - var errs error - - seenIDs := make(map[string]bool, len(c.Models)) - chatModels := make([]*providers.LanguageModel, 0, len(c.Models)) - chatStreamModels := make([]*providers.LanguageModel, 0, len(c.Models)) - - for _, modelConfig := range c.Models { - if _, ok := seenIDs[modelConfig.ID]; ok { - return nil, nil, fmt.Errorf( - "ID \"%v\" is specified for more than one model in router \"%v\", while it should be unique in scope of that pool", - modelConfig.ID, - c.ID, - ) - } - - seenIDs[modelConfig.ID] = true - - if !modelConfig.Enabled { - tel.L().Info( - "ModelName is disabled, skipping", - zap.String("router", c.ID), - zap.String("model", modelConfig.ID), - ) - - continue - } - - tel.L().Debug( - "Init lang model", - zap.String("router", c.ID), - zap.String("model", modelConfig.ID), - ) - - model, err := modelConfig.ToModel(tel) - if err != nil { - errs = multierr.Append(errs, err) - continue - } - - chatModels = append(chatModels, model) - - if !model.SupportChatStream() { - tel.L().WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn( - "Provider doesn't support or have not been yet integrated with streaming chat, it won't serve streaming chat requests", - zap.String("routerID", c.ID), - zap.String("modelID", model.ID()), - zap.String("provider", model.Provider()), - ) - - continue - } - - chatStreamModels = append(chatStreamModels, model) - } - - if errs != nil { - return nil, nil, errs - } - - if len(chatModels) == 0 { - return nil, nil, fmt.Errorf("router \"%v\" must have at least one active model, zero defined", c.ID) - } - - if len(chatModels) == 1 { - tel.L().WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn( - fmt.Sprintf("Router \"%v\" has only one active model defined. "+ - "This is not recommended for production setups. "+ - "Define at least a few models to leverage resiliency logic Glide provides", - c.ID, - ), - ) - } - - if len(chatStreamModels) == 1 { - tel.L().WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn( - fmt.Sprintf("Router \"%v\" has only one active model defined with streaming chat support. "+ - "This is not recommended for production setups. "+ - "Define at least a few models to leverage resiliency logic Glide provides", - c.ID, - ), - ) - } - - if len(chatStreamModels) == 0 { - tel.L().WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn( - fmt.Sprintf("Router \"%v\" has only no model with streaming chat support. "+ - "The streaming chat workflow won't work until you define any", - c.ID, - ), - ) - } - - return chatModels, chatStreamModels, nil -} - -func (c *LangRouterConfig) BuildRetry() *retry.ExpRetry { - retryConfig := c.Retry - maxDelay := time.Duration(*retryConfig.MaxDelay) - - return retry.NewExpRetry( - retryConfig.MaxRetries, - retryConfig.BaseMultiplier, - time.Duration(retryConfig.MinDelay), - &maxDelay, - ) -} - -func (c *LangRouterConfig) BuildRouting( - chatModels []*providers.LanguageModel, - chatStreamModels []*providers.LanguageModel, -) (routing.LangModelRouting, routing.LangModelRouting, error) { - chatModelPool := make([]providers.Model, 0, len(chatModels)) - chatStreamModelPool := make([]providers.Model, 0, len(chatStreamModels)) - - for _, model := range chatModels { - chatModelPool = append(chatModelPool, model) - } - - for _, model := range chatStreamModels { - chatStreamModelPool = append(chatStreamModelPool, model) - } - - switch c.RoutingStrategy { - case routing.Priority: - return routing.NewPriority(chatModelPool), routing.NewPriority(chatStreamModelPool), nil - case routing.RoundRobin: - return routing.NewRoundRobinRouting(chatModelPool), routing.NewRoundRobinRouting(chatStreamModelPool), nil - case routing.WeightedRoundRobin: - return routing.NewWeightedRoundRobin(chatModelPool), routing.NewWeightedRoundRobin(chatStreamModelPool), nil - case routing.LeastLatency: - return routing.NewLeastLatencyRouting(providers.ChatLatency, chatModelPool), - routing.NewLeastLatencyRouting(providers.ChatStreamLatency, chatStreamModelPool), - nil - } - - return nil, nil, fmt.Errorf("routing strategy \"%v\" is not supported, please make sure there is no typo", c.RoutingStrategy) -} - -func DefaultLangRouterConfig() LangRouterConfig { - return LangRouterConfig{ - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - } -} - -func (c *LangRouterConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - *c = DefaultLangRouterConfig() - - type plain LangRouterConfig // to avoid recursion - - return unmarshal((*plain)(c)) -} diff --git a/pkg/routers/embed/config.go b/pkg/routers/embed/config.go new file mode 100644 index 00000000..93346b43 --- /dev/null +++ b/pkg/routers/embed/config.go @@ -0,0 +1,15 @@ +package embed + +import ( + "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/routers/retry" + "github.com/EinStack/glide/pkg/routers/routing" +) + +type EmbeddingRouterConfig struct { + ID string `yaml:"id" json:"routers" validate:"required"` // Unique router ID + Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is router enabled? + Retry *retry.ExpRetryConfig `yaml:"retry" json:"retry" validate:"required"` // retry when no healthy model is available to router + RoutingStrategy routing.Strategy `yaml:"strategy" json:"strategy" swaggertype:"primitive,string" validate:"required"` // strategy on picking the next model to serve the request + Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests +} diff --git a/pkg/routers/embed/router.go b/pkg/routers/embed/router.go new file mode 100644 index 00000000..dd3542c3 --- /dev/null +++ b/pkg/routers/embed/router.go @@ -0,0 +1,22 @@ +package embed + +import ( + "context" + "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/routers" + "github.com/EinStack/glide/pkg/routers/retry" + "github.com/EinStack/glide/pkg/telemetry" + "go.uber.org/zap" +) + +type EmbeddingRouter struct { + routerID routers.RouterID + Config *LangRouterConfig + retry *retry.ExpRetry + tel *telemetry.Telemetry + logger *zap.Logger +} + +func (r *routers.LangRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) { + +} diff --git a/pkg/routers/router.go b/pkg/routers/lang.go similarity index 100% rename from pkg/routers/router.go rename to pkg/routers/lang.go diff --git a/pkg/routers/lang/config.go b/pkg/routers/lang/config.go new file mode 100644 index 00000000..f35d1109 --- /dev/null +++ b/pkg/routers/lang/config.go @@ -0,0 +1,179 @@ +package lang + +import ( + "fmt" + "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/routers/retry" + "github.com/EinStack/glide/pkg/routers/routing" + "github.com/EinStack/glide/pkg/telemetry" + "go.uber.org/multierr" + "go.uber.org/zap" + "time" +) + +// TODO: how to specify other backoff strategies? +// TODO: Had to keep RoutingStrategy because of https://github.com/swaggo/swag/issues/1738 +// LangRouterConfig +type LangRouterConfig struct { + ID string `yaml:"id" json:"routers" validate:"required"` // Unique router ID + Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is router enabled? + Retry *retry.ExpRetryConfig `yaml:"retry" json:"retry" validate:"required"` // retry when no healthy model is available to router + RoutingStrategy routing.Strategy `yaml:"strategy" json:"strategy" swaggertype:"primitive,string" validate:"required"` // strategy on picking the next model to serve the request + Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests +} + +// BuildModels creates LanguageModel slice out of the given config +func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*providers.LanguageModel, []*providers.LanguageModel, error) { //nolint: cyclop + var errs error + + seenIDs := make(map[string]bool, len(c.Models)) + chatModels := make([]*providers.LanguageModel, 0, len(c.Models)) + chatStreamModels := make([]*providers.LanguageModel, 0, len(c.Models)) + + for _, modelConfig := range c.Models { + if _, ok := seenIDs[modelConfig.ID]; ok { + return nil, nil, fmt.Errorf( + "ID \"%v\" is specified for more than one model in router \"%v\", while it should be unique in scope of that pool", + modelConfig.ID, + c.ID, + ) + } + + seenIDs[modelConfig.ID] = true + + if !modelConfig.Enabled { + tel.L().Info( + "ModelName is disabled, skipping", + zap.String("router", c.ID), + zap.String("model", modelConfig.ID), + ) + + continue + } + + tel.L().Debug( + "Init lang model", + zap.String("router", c.ID), + zap.String("model", modelConfig.ID), + ) + + model, err := modelConfig.ToModel(tel) + if err != nil { + errs = multierr.Append(errs, err) + continue + } + + chatModels = append(chatModels, model) + + if !model.SupportChatStream() { + tel.L().WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn( + "Provider doesn't support or have not been yet integrated with streaming chat, it won't serve streaming chat requests", + zap.String("routerID", c.ID), + zap.String("modelID", model.ID()), + zap.String("provider", model.Provider()), + ) + + continue + } + + chatStreamModels = append(chatStreamModels, model) + } + + if errs != nil { + return nil, nil, errs + } + + if len(chatModels) == 0 { + return nil, nil, fmt.Errorf("router \"%v\" must have at least one active model, zero defined", c.ID) + } + + if len(chatModels) == 1 { + tel.L().WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn( + fmt.Sprintf("Router \"%v\" has only one active model defined. "+ + "This is not recommended for production setups. "+ + "Define at least a few models to leverage resiliency logic Glide provides", + c.ID, + ), + ) + } + + if len(chatStreamModels) == 1 { + tel.L().WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn( + fmt.Sprintf("Router \"%v\" has only one active model defined with streaming chat support. "+ + "This is not recommended for production setups. "+ + "Define at least a few models to leverage resiliency logic Glide provides", + c.ID, + ), + ) + } + + if len(chatStreamModels) == 0 { + tel.L().WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn( + fmt.Sprintf("Router \"%v\" has only no model with streaming chat support. "+ + "The streaming chat workflow won't work until you define any", + c.ID, + ), + ) + } + + return chatModels, chatStreamModels, nil +} + +func (c *LangRouterConfig) BuildRetry() *retry.ExpRetry { + retryConfig := c.Retry + maxDelay := time.Duration(*retryConfig.MaxDelay) + + return retry.NewExpRetry( + retryConfig.MaxRetries, + retryConfig.BaseMultiplier, + time.Duration(retryConfig.MinDelay), + &maxDelay, + ) +} + +func (c *LangRouterConfig) BuildRouting( + chatModels []*providers.LanguageModel, + chatStreamModels []*providers.LanguageModel, +) (routing.LangModelRouting, routing.LangModelRouting, error) { + chatModelPool := make([]providers.Model, 0, len(chatModels)) + chatStreamModelPool := make([]providers.Model, 0, len(chatStreamModels)) + + for _, model := range chatModels { + chatModelPool = append(chatModelPool, model) + } + + for _, model := range chatStreamModels { + chatStreamModelPool = append(chatStreamModelPool, model) + } + + switch c.RoutingStrategy { + case routing.Priority: + return routing.NewPriority(chatModelPool), routing.NewPriority(chatStreamModelPool), nil + case routing.RoundRobin: + return routing.NewRoundRobinRouting(chatModelPool), routing.NewRoundRobinRouting(chatStreamModelPool), nil + case routing.WeightedRoundRobin: + return routing.NewWeightedRoundRobin(chatModelPool), routing.NewWeightedRoundRobin(chatStreamModelPool), nil + case routing.LeastLatency: + return routing.NewLeastLatencyRouting(providers.ChatLatency, chatModelPool), + routing.NewLeastLatencyRouting(providers.ChatStreamLatency, chatStreamModelPool), + nil + } + + return nil, nil, fmt.Errorf("routing strategy \"%v\" is not supported, please make sure there is no typo", c.RoutingStrategy) +} + +func DefaultLangRouterConfig() LangRouterConfig { + return LangRouterConfig{ + Enabled: true, + RoutingStrategy: routing.Priority, + Retry: retry.DefaultExpRetryConfig(), + } +} + +func (c *LangRouterConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + *c = DefaultLangRouterConfig() + + type plain LangRouterConfig // to avoid recursion + + return unmarshal((*plain)(c)) +} diff --git a/pkg/routers/config_test.go b/pkg/routers/lang/config_test.go similarity index 97% rename from pkg/routers/config_test.go rename to pkg/routers/lang/config_test.go index d740df2c..abbb5bcd 100644 --- a/pkg/routers/config_test.go +++ b/pkg/routers/lang/config_test.go @@ -1,11 +1,10 @@ -package routers +package lang import ( - "testing" - "github.com/EinStack/glide/pkg/providers/cohere" - + routers2 "github.com/EinStack/glide/pkg/routers" "github.com/EinStack/glide/pkg/telemetry" + "testing" "github.com/EinStack/glide/pkg/routers/routing" @@ -27,7 +26,7 @@ import ( func TestRouterConfig_BuildModels(t *testing.T) { defaultParams := openai.DefaultParams() - cfg := Config{ + cfg := routers2.Config{ LanguageRouters: []LangRouterConfig{ { ID: "first_router", @@ -128,11 +127,11 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { tests := []struct { name string - config Config + config routers2.Config }{ { "duplicated router IDs", - Config{ + routers2.Config{ LanguageRouters: []LangRouterConfig{ { ID: "first_router", @@ -177,7 +176,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { }, { "duplicated model IDs", - Config{ + routers2.Config{ LanguageRouters: []LangRouterConfig{ { ID: "first_router", @@ -214,7 +213,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { }, { "no models", - Config{ + routers2.Config{ LanguageRouters: []LangRouterConfig{ { ID: "first_router", diff --git a/pkg/routers/router_test.go b/pkg/routers/lang_test.go similarity index 100% rename from pkg/routers/router_test.go rename to pkg/routers/lang_test.go From 27347ecf29d513e1977b9c30cbdf57a6559b9a0d Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Mon, 24 Jun 2024 12:44:32 +0300 Subject: [PATCH 02/18] #67: Moved resiliency and client packages on higher level out of providers & routers --- pkg/{providers => }/clients/config.go | 0 pkg/{providers => }/clients/config_test.go | 0 pkg/{providers => }/clients/errors.go | 0 pkg/{providers => }/clients/errors_test.go | 0 pkg/{providers => }/clients/sse.go | 0 pkg/{providers => }/clients/sse_test.go | 0 pkg/{providers => }/clients/stream.go | 0 pkg/providers/anthropic/chat.go | 3 +-- pkg/providers/anthropic/chat_stream.go | 7 +++---- pkg/providers/anthropic/client.go | 3 +-- pkg/providers/anthropic/client_test.go | 3 +-- pkg/providers/anthropic/errors.go | 2 +- pkg/providers/azureopenai/chat.go | 3 +-- pkg/providers/azureopenai/chat_stream.go | 8 ++++---- pkg/providers/azureopenai/chat_stream_test.go | 11 +++++------ pkg/providers/azureopenai/client.go | 3 +-- pkg/providers/azureopenai/client_test.go | 3 +-- pkg/providers/azureopenai/errors.go | 2 +- pkg/providers/bedrock/chat_stream.go | 7 +++---- pkg/providers/bedrock/client.go | 3 +-- pkg/providers/bedrock/client_test.go | 3 +-- pkg/providers/cohere/chat.go | 3 +-- pkg/providers/cohere/chat_stream.go | 7 +++---- pkg/providers/cohere/chat_stream_test.go | 3 +-- pkg/providers/cohere/client.go | 3 +-- pkg/providers/cohere/client_test.go | 3 +-- pkg/providers/cohere/errors.go | 2 +- pkg/providers/config.go | 6 ++---- pkg/providers/lang.go | 12 +++++------- pkg/providers/octoml/chat_stream.go | 7 +++---- pkg/providers/octoml/client.go | 3 +-- pkg/providers/octoml/client_test.go | 3 +-- pkg/providers/octoml/errors.go | 2 +- pkg/providers/ollama/chat.go | 3 +-- pkg/providers/ollama/chat_stream.go | 7 +++---- pkg/providers/ollama/client.go | 3 +-- pkg/providers/ollama/client_test.go | 3 +-- pkg/providers/openai/chat.go | 3 +-- pkg/providers/openai/chat_stream.go | 8 ++++---- pkg/providers/openai/chat_stream_test.go | 11 +++++------ pkg/providers/openai/chat_test.go | 9 ++++----- pkg/providers/openai/client.go | 3 +-- pkg/providers/openai/errors.go | 2 +- pkg/providers/testing/lang.go | 9 ++++----- pkg/{routers => resiliency}/health/buckets.go | 0 pkg/{routers => resiliency}/health/buckets_test.go | 0 pkg/{routers => resiliency}/health/error_budget.go | 0 .../health/error_budget_test.go | 0 pkg/{routers => resiliency}/health/ratelimit.go | 0 pkg/{routers => resiliency}/health/ratelimit_test.go | 0 pkg/{routers => resiliency}/health/tracker.go | 3 +-- pkg/{routers => resiliency}/health/tracker_test.go | 2 +- pkg/{routers => resiliency}/retry/config.go | 0 pkg/{routers => resiliency}/retry/config_test.go | 0 pkg/{routers => resiliency}/retry/exp.go | 0 pkg/{routers => resiliency}/retry/exp_test.go | 0 pkg/routers/config.go | 6 +++--- pkg/routers/embed/config.go | 2 +- pkg/routers/embed/router.go | 8 ++++---- pkg/routers/lang/config.go | 10 +++++----- pkg/routers/lang/config_test.go | 9 +++------ pkg/routers/{lang.go => lang/router.go} | 8 ++++++-- pkg/routers/{lang_test.go => lang/router_test.go} | 8 ++++---- pkg/routers/manager.go | 11 ++++++----- 64 files changed, 106 insertions(+), 137 deletions(-) rename pkg/{providers => }/clients/config.go (100%) rename pkg/{providers => }/clients/config_test.go (100%) rename pkg/{providers => }/clients/errors.go (100%) rename pkg/{providers => }/clients/errors_test.go (100%) rename pkg/{providers => }/clients/sse.go (100%) rename pkg/{providers => }/clients/sse_test.go (100%) rename pkg/{providers => }/clients/stream.go (100%) rename pkg/{routers => resiliency}/health/buckets.go (100%) rename pkg/{routers => resiliency}/health/buckets_test.go (100%) rename pkg/{routers => resiliency}/health/error_budget.go (100%) rename pkg/{routers => resiliency}/health/error_budget_test.go (100%) rename pkg/{routers => resiliency}/health/ratelimit.go (100%) rename pkg/{routers => resiliency}/health/ratelimit_test.go (100%) rename pkg/{routers => resiliency}/health/tracker.go (94%) rename pkg/{routers => resiliency}/health/tracker_test.go (93%) rename pkg/{routers => resiliency}/retry/config.go (100%) rename pkg/{routers => resiliency}/retry/config_test.go (100%) rename pkg/{routers => resiliency}/retry/exp.go (100%) rename pkg/{routers => resiliency}/retry/exp_test.go (100%) rename pkg/routers/{lang.go => lang/router.go} (97%) rename pkg/routers/{lang_test.go => lang/router_test.go} (98%) diff --git a/pkg/providers/clients/config.go b/pkg/clients/config.go similarity index 100% rename from pkg/providers/clients/config.go rename to pkg/clients/config.go diff --git a/pkg/providers/clients/config_test.go b/pkg/clients/config_test.go similarity index 100% rename from pkg/providers/clients/config_test.go rename to pkg/clients/config_test.go diff --git a/pkg/providers/clients/errors.go b/pkg/clients/errors.go similarity index 100% rename from pkg/providers/clients/errors.go rename to pkg/clients/errors.go diff --git a/pkg/providers/clients/errors_test.go b/pkg/clients/errors_test.go similarity index 100% rename from pkg/providers/clients/errors_test.go rename to pkg/clients/errors_test.go diff --git a/pkg/providers/clients/sse.go b/pkg/clients/sse.go similarity index 100% rename from pkg/providers/clients/sse.go rename to pkg/clients/sse.go diff --git a/pkg/providers/clients/sse_test.go b/pkg/clients/sse_test.go similarity index 100% rename from pkg/providers/clients/sse_test.go rename to pkg/clients/sse_test.go diff --git a/pkg/providers/clients/stream.go b/pkg/clients/stream.go similarity index 100% rename from pkg/providers/clients/stream.go rename to pkg/clients/stream.go diff --git a/pkg/providers/anthropic/chat.go b/pkg/providers/anthropic/chat.go index 80b45f2b..03f7591f 100644 --- a/pkg/providers/anthropic/chat.go +++ b/pkg/providers/anthropic/chat.go @@ -5,12 +5,11 @@ import ( "context" "encoding/json" "fmt" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) diff --git a/pkg/providers/anthropic/chat_stream.go b/pkg/providers/anthropic/chat_stream.go index 5a6f2112..6040d2c1 100644 --- a/pkg/providers/anthropic/chat_stream.go +++ b/pkg/providers/anthropic/chat_stream.go @@ -2,8 +2,7 @@ package anthropic import ( "context" - - "github.com/EinStack/glide/pkg/providers/clients" + clients2 "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/api/schemas" ) @@ -12,6 +11,6 @@ func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { - return nil, clients.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) { + return nil, clients2.ErrChatStreamNotImplemented } diff --git a/pkg/providers/anthropic/client.go b/pkg/providers/anthropic/client.go index bb34fe07..11da9173 100644 --- a/pkg/providers/anthropic/client.go +++ b/pkg/providers/anthropic/client.go @@ -1,13 +1,12 @@ package anthropic import ( + "github.com/EinStack/glide/pkg/clients" "net/http" "net/url" "time" "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/providers/clients" ) const ( diff --git a/pkg/providers/anthropic/client_test.go b/pkg/providers/anthropic/client_test.go index b0c11f36..75be00f1 100644 --- a/pkg/providers/anthropic/client_test.go +++ b/pkg/providers/anthropic/client_test.go @@ -3,6 +3,7 @@ package anthropic import ( "context" "encoding/json" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -10,8 +11,6 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/anthropic/errors.go b/pkg/providers/anthropic/errors.go index 126de68d..222b1921 100644 --- a/pkg/providers/anthropic/errors.go +++ b/pkg/providers/anthropic/errors.go @@ -2,13 +2,13 @@ package anthropic import ( "fmt" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" "go.uber.org/zap" ) diff --git a/pkg/providers/azureopenai/chat.go b/pkg/providers/azureopenai/chat.go index 22005fa3..c3c73656 100644 --- a/pkg/providers/azureopenai/chat.go +++ b/pkg/providers/azureopenai/chat.go @@ -5,11 +5,10 @@ import ( "context" "encoding/json" "fmt" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/EinStack/glide/pkg/providers/openai" "github.com/EinStack/glide/pkg/api/schemas" diff --git a/pkg/providers/azureopenai/chat_stream.go b/pkg/providers/azureopenai/chat_stream.go index 8e73a556..bf9bd215 100644 --- a/pkg/providers/azureopenai/chat_stream.go +++ b/pkg/providers/azureopenai/chat_stream.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + clients2 "github.com/EinStack/glide/pkg/clients" "io" "net/http" @@ -12,7 +13,6 @@ import ( "github.com/EinStack/glide/pkg/providers/openai" - "github.com/EinStack/glide/pkg/providers/clients" "github.com/r3labs/sse/v2" "go.uber.org/zap" @@ -82,7 +82,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { // if err is io.EOF, this still means that the stream is interrupted unexpectedly // because the normal stream termination is done via finding out streamDoneMarker - return nil, clients.ErrProviderUnavailable + return nil, clients2.ErrProviderUnavailable } s.tel.L().Debug( @@ -91,7 +91,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { zap.ByteString("rawChunk", rawEvent), ) - event, err := clients.ParseSSEvent(rawEvent) + event, err := clients2.ParseSSEvent(rawEvent) if bytes.Equal(event.Data, openai.StreamDoneMarker) { s.tel.L().Info( @@ -155,7 +155,7 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients2.ChatStream, error) { // Create a new chat request httpRequest, err := c.makeStreamReq(ctx, params) if err != nil { diff --git a/pkg/providers/azureopenai/chat_stream_test.go b/pkg/providers/azureopenai/chat_stream_test.go index 5aade1f5..49d792d8 100644 --- a/pkg/providers/azureopenai/chat_stream_test.go +++ b/pkg/providers/azureopenai/chat_stream_test.go @@ -3,6 +3,7 @@ package azureopenai import ( "context" "encoding/json" + clients2 "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -14,14 +15,12 @@ import ( "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/stretchr/testify/require" ) func TestAzureOpenAIClient_ChatStreamSupported(t *testing.T) { providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) @@ -64,7 +63,7 @@ func TestAzureOpenAIClient_ChatStreamRequest(t *testing.T) { ctx := context.Background() providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() providerCfg.BaseURL = AzureopenAIServer.URL @@ -132,7 +131,7 @@ func TestAzureOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { ctx := context.Background() providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() providerCfg.BaseURL = openAIServer.URL @@ -153,7 +152,7 @@ func TestAzureOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { for { chunk, err := stream.Recv() if err != nil { - require.ErrorIs(t, err, clients.ErrProviderUnavailable) + require.ErrorIs(t, err, clients2.ErrProviderUnavailable) return } diff --git a/pkg/providers/azureopenai/client.go b/pkg/providers/azureopenai/client.go index 0f594805..c1399307 100644 --- a/pkg/providers/azureopenai/client.go +++ b/pkg/providers/azureopenai/client.go @@ -2,14 +2,13 @@ package azureopenai import ( "fmt" + "github.com/EinStack/glide/pkg/clients" "net/http" "time" "github.com/EinStack/glide/pkg/providers/openai" "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/providers/clients" ) const ( diff --git a/pkg/providers/azureopenai/client_test.go b/pkg/providers/azureopenai/client_test.go index 1700bca0..b92c9142 100644 --- a/pkg/providers/azureopenai/client_test.go +++ b/pkg/providers/azureopenai/client_test.go @@ -3,6 +3,7 @@ package azureopenai import ( "context" "encoding/json" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -10,8 +11,6 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/azureopenai/errors.go b/pkg/providers/azureopenai/errors.go index 6a30e989..b7bb4e14 100644 --- a/pkg/providers/azureopenai/errors.go +++ b/pkg/providers/azureopenai/errors.go @@ -2,13 +2,13 @@ package azureopenai import ( "fmt" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" "go.uber.org/zap" ) diff --git a/pkg/providers/bedrock/chat_stream.go b/pkg/providers/bedrock/chat_stream.go index bb07da7d..99f3c8d1 100644 --- a/pkg/providers/bedrock/chat_stream.go +++ b/pkg/providers/bedrock/chat_stream.go @@ -2,8 +2,7 @@ package bedrock import ( "context" - - "github.com/EinStack/glide/pkg/providers/clients" + clients2 "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/api/schemas" ) @@ -12,6 +11,6 @@ func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { - return nil, clients.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) { + return nil, clients2.ErrChatStreamNotImplemented } diff --git a/pkg/providers/bedrock/client.go b/pkg/providers/bedrock/client.go index 0567b9fc..5385691d 100644 --- a/pkg/providers/bedrock/client.go +++ b/pkg/providers/bedrock/client.go @@ -3,14 +3,13 @@ package bedrock import ( "context" "errors" + "github.com/EinStack/glide/pkg/clients" "net/http" "net/url" "time" "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" diff --git a/pkg/providers/bedrock/client_test.go b/pkg/providers/bedrock/client_test.go index cdae1f68..957a754d 100644 --- a/pkg/providers/bedrock/client_test.go +++ b/pkg/providers/bedrock/client_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -11,8 +12,6 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/cohere/chat.go b/pkg/providers/cohere/chat.go index ddf75680..12ec6206 100644 --- a/pkg/providers/cohere/chat.go +++ b/pkg/providers/cohere/chat.go @@ -5,12 +5,11 @@ import ( "context" "encoding/json" "fmt" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" diff --git a/pkg/providers/cohere/chat_stream.go b/pkg/providers/cohere/chat_stream.go index 1d8ed243..8fb670a3 100644 --- a/pkg/providers/cohere/chat_stream.go +++ b/pkg/providers/cohere/chat_stream.go @@ -5,13 +5,12 @@ import ( "context" "encoding/json" "fmt" + clients2 "github.com/EinStack/glide/pkg/clients" "io" "net/http" "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" - "go.uber.org/zap" "github.com/EinStack/glide/pkg/api/schemas" @@ -96,7 +95,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { // if io.EOF occurred in the middle of the stream, then the stream was interrupted - return nil, clients.ErrProviderUnavailable + return nil, clients2.ErrProviderUnavailable } s.tel.L().Debug( @@ -178,7 +177,7 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients2.ChatStream, error) { // Create a new chat request httpRequest, err := c.makeStreamReq(ctx, params) if err != nil { diff --git a/pkg/providers/cohere/chat_stream_test.go b/pkg/providers/cohere/chat_stream_test.go index 7deb5b88..3d9410be 100644 --- a/pkg/providers/cohere/chat_stream_test.go +++ b/pkg/providers/cohere/chat_stream_test.go @@ -3,6 +3,7 @@ package cohere import ( "context" "encoding/json" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -14,8 +15,6 @@ import ( "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/stretchr/testify/require" ) diff --git a/pkg/providers/cohere/client.go b/pkg/providers/cohere/client.go index c13ff64b..c8a00b7f 100644 --- a/pkg/providers/cohere/client.go +++ b/pkg/providers/cohere/client.go @@ -1,13 +1,12 @@ package cohere import ( + "github.com/EinStack/glide/pkg/clients" "net/http" "net/url" "time" "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/providers/clients" ) const ( diff --git a/pkg/providers/cohere/client_test.go b/pkg/providers/cohere/client_test.go index 2e5ab487..959de556 100644 --- a/pkg/providers/cohere/client_test.go +++ b/pkg/providers/cohere/client_test.go @@ -4,6 +4,7 @@ package cohere import ( "context" "encoding/json" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -15,8 +16,6 @@ import ( "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/stretchr/testify/require" ) diff --git a/pkg/providers/cohere/errors.go b/pkg/providers/cohere/errors.go index 118ef719..bac434ff 100644 --- a/pkg/providers/cohere/errors.go +++ b/pkg/providers/cohere/errors.go @@ -2,13 +2,13 @@ package cohere import ( "fmt" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" "go.uber.org/zap" ) diff --git a/pkg/providers/config.go b/pkg/providers/config.go index 206be273..e656822c 100644 --- a/pkg/providers/config.go +++ b/pkg/providers/config.go @@ -3,17 +3,15 @@ package providers import ( "errors" "fmt" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/resiliency/health" "github.com/EinStack/glide/pkg/routers/latency" "github.com/EinStack/glide/pkg/providers/ollama" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/EinStack/glide/pkg/providers/bedrock" - "github.com/EinStack/glide/pkg/routers/health" - "github.com/EinStack/glide/pkg/providers/openai" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/lang.go b/pkg/providers/lang.go index d2a6aa06..4e16b979 100644 --- a/pkg/providers/lang.go +++ b/pkg/providers/lang.go @@ -2,17 +2,15 @@ package providers import ( "context" + "github.com/EinStack/glide/pkg/clients" + health2 "github.com/EinStack/glide/pkg/resiliency/health" "io" "time" "github.com/EinStack/glide/pkg/config/fields" - "github.com/EinStack/glide/pkg/routers/health" - "github.com/EinStack/glide/pkg/routers/latency" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/EinStack/glide/pkg/api/schemas" ) @@ -42,17 +40,17 @@ type LanguageModel struct { modelID string weight int client LangProvider - healthTracker *health.Tracker + healthTracker *health2.Tracker chatLatency *latency.MovingAverage chatStreamLatency *latency.MovingAverage latencyUpdateInterval *fields.Duration } -func NewLangModel(modelID string, client LangProvider, budget *health.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel { +func NewLangModel(modelID string, client LangProvider, budget *health2.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel { return &LanguageModel{ modelID: modelID, client: client, - healthTracker: health.NewTracker(budget), + healthTracker: health2.NewTracker(budget), chatLatency: latency.NewMovingAverage(latencyConfig.Decay, latencyConfig.WarmupSamples), chatStreamLatency: latency.NewMovingAverage(latencyConfig.Decay, latencyConfig.WarmupSamples), latencyUpdateInterval: latencyConfig.UpdateInterval, diff --git a/pkg/providers/octoml/chat_stream.go b/pkg/providers/octoml/chat_stream.go index d0e33420..999612bc 100644 --- a/pkg/providers/octoml/chat_stream.go +++ b/pkg/providers/octoml/chat_stream.go @@ -2,8 +2,7 @@ package octoml import ( "context" - - "github.com/EinStack/glide/pkg/providers/clients" + clients2 "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/api/schemas" ) @@ -12,6 +11,6 @@ func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { - return nil, clients.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) { + return nil, clients2.ErrChatStreamNotImplemented } diff --git a/pkg/providers/octoml/client.go b/pkg/providers/octoml/client.go index 07f889bb..11e3b269 100644 --- a/pkg/providers/octoml/client.go +++ b/pkg/providers/octoml/client.go @@ -2,13 +2,12 @@ package octoml import ( "errors" + "github.com/EinStack/glide/pkg/clients" "net/http" "net/url" "time" "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/providers/clients" ) const ( diff --git a/pkg/providers/octoml/client_test.go b/pkg/providers/octoml/client_test.go index f35de1f7..485d0474 100644 --- a/pkg/providers/octoml/client_test.go +++ b/pkg/providers/octoml/client_test.go @@ -3,6 +3,7 @@ package octoml import ( "context" "encoding/json" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -12,8 +13,6 @@ import ( "github.com/EinStack/glide/pkg/api/schemas" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/EinStack/glide/pkg/telemetry" "github.com/stretchr/testify/require" diff --git a/pkg/providers/octoml/errors.go b/pkg/providers/octoml/errors.go index 97f16840..9f446f67 100644 --- a/pkg/providers/octoml/errors.go +++ b/pkg/providers/octoml/errors.go @@ -2,13 +2,13 @@ package octoml import ( "fmt" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" "go.uber.org/zap" ) diff --git a/pkg/providers/ollama/chat.go b/pkg/providers/ollama/chat.go index b93f5b10..87acef9b 100644 --- a/pkg/providers/ollama/chat.go +++ b/pkg/providers/ollama/chat.go @@ -5,12 +5,11 @@ import ( "context" "encoding/json" "fmt" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/google/uuid" "github.com/EinStack/glide/pkg/api/schemas" diff --git a/pkg/providers/ollama/chat_stream.go b/pkg/providers/ollama/chat_stream.go index 31075ca1..a5a265d4 100644 --- a/pkg/providers/ollama/chat_stream.go +++ b/pkg/providers/ollama/chat_stream.go @@ -2,8 +2,7 @@ package ollama import ( "context" - - "github.com/EinStack/glide/pkg/providers/clients" + clients2 "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/api/schemas" ) @@ -12,6 +11,6 @@ func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { - return nil, clients.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) { + return nil, clients2.ErrChatStreamNotImplemented } diff --git a/pkg/providers/ollama/client.go b/pkg/providers/ollama/client.go index 5a61898e..d54e43ed 100644 --- a/pkg/providers/ollama/client.go +++ b/pkg/providers/ollama/client.go @@ -1,13 +1,12 @@ package ollama import ( + "github.com/EinStack/glide/pkg/clients" "net/http" "net/url" "time" "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/providers/clients" ) const ( diff --git a/pkg/providers/ollama/client_test.go b/pkg/providers/ollama/client_test.go index e6c584cf..61958fa9 100644 --- a/pkg/providers/ollama/client_test.go +++ b/pkg/providers/ollama/client_test.go @@ -3,6 +3,7 @@ package ollama import ( "context" "encoding/json" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -10,8 +11,6 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 519d7d43..efc5edcf 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -5,11 +5,10 @@ import ( "context" "encoding/json" "fmt" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) diff --git a/pkg/providers/openai/chat_stream.go b/pkg/providers/openai/chat_stream.go index 08ca2b21..5b8f8b41 100644 --- a/pkg/providers/openai/chat_stream.go +++ b/pkg/providers/openai/chat_stream.go @@ -5,10 +5,10 @@ import ( "context" "encoding/json" "fmt" + clients2 "github.com/EinStack/glide/pkg/clients" "io" "net/http" - "github.com/EinStack/glide/pkg/providers/clients" "github.com/r3labs/sse/v2" "go.uber.org/zap" @@ -74,7 +74,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { // if err is io.EOF, this still means that the stream is interrupted unexpectedly // because the normal stream termination is done via finding out streamDoneMarker - return nil, clients.ErrProviderUnavailable + return nil, clients2.ErrProviderUnavailable } s.logger.Debug( @@ -82,7 +82,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { zap.ByteString("rawChunk", rawEvent), ) - event, err := clients.ParseSSEvent(rawEvent) + event, err := clients2.ParseSSEvent(rawEvent) if bytes.Equal(event.Data, StreamDoneMarker) { return nil, io.EOF @@ -141,7 +141,7 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients2.ChatStream, error) { // Create a new chat request httpRequest, err := c.makeStreamReq(ctx, params) if err != nil { diff --git a/pkg/providers/openai/chat_stream_test.go b/pkg/providers/openai/chat_stream_test.go index 1ab8483b..459192b7 100644 --- a/pkg/providers/openai/chat_stream_test.go +++ b/pkg/providers/openai/chat_stream_test.go @@ -3,6 +3,7 @@ package openai import ( "context" "encoding/json" + clients2 "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -14,14 +15,12 @@ import ( "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/stretchr/testify/require" ) func TestOpenAIClient_ChatStreamSupported(t *testing.T) { providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) @@ -64,7 +63,7 @@ func TestOpenAIClient_ChatStreamRequest(t *testing.T) { ctx := context.Background() providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() providerCfg.BaseURL = openAIServer.URL @@ -132,7 +131,7 @@ func TestOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { ctx := context.Background() providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() providerCfg.BaseURL = openAIServer.URL @@ -153,7 +152,7 @@ func TestOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { for { chunk, err := stream.Recv() if err != nil { - require.ErrorIs(t, err, clients.ErrProviderUnavailable) + require.ErrorIs(t, err, clients2.ErrProviderUnavailable) return } diff --git a/pkg/providers/openai/chat_test.go b/pkg/providers/openai/chat_test.go index 3109f150..4d626e81 100644 --- a/pkg/providers/openai/chat_test.go +++ b/pkg/providers/openai/chat_test.go @@ -3,6 +3,7 @@ package openai import ( "context" "encoding/json" + clients2 "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -10,8 +11,6 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" @@ -49,7 +48,7 @@ func TestOpenAIClient_ChatRequest(t *testing.T) { ctx := context.Background() providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() providerCfg.BaseURL = openAIServer.URL @@ -78,7 +77,7 @@ func TestOpenAIClient_RateLimit(t *testing.T) { ctx := context.Background() providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() providerCfg.BaseURL = openAIServer.URL @@ -93,5 +92,5 @@ func TestOpenAIClient_RateLimit(t *testing.T) { _, err = client.Chat(ctx, &chatParams) require.Error(t, err) - require.IsType(t, &clients.RateLimitError{}, err) + require.IsType(t, &clients2.RateLimitError{}, err) } diff --git a/pkg/providers/openai/client.go b/pkg/providers/openai/client.go index 832ade57..ec20b3ca 100644 --- a/pkg/providers/openai/client.go +++ b/pkg/providers/openai/client.go @@ -1,6 +1,7 @@ package openai import ( + "github.com/EinStack/glide/pkg/clients" "net/http" "net/url" "time" @@ -8,8 +9,6 @@ import ( "go.uber.org/zap" "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/providers/clients" ) const ( diff --git a/pkg/providers/openai/errors.go b/pkg/providers/openai/errors.go index 14978f8c..0cf2a418 100644 --- a/pkg/providers/openai/errors.go +++ b/pkg/providers/openai/errors.go @@ -2,13 +2,13 @@ package openai import ( "fmt" + "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" "go.uber.org/zap" ) diff --git a/pkg/providers/testing/lang.go b/pkg/providers/testing/lang.go index 0f7f1f4e..39389cc8 100644 --- a/pkg/providers/testing/lang.go +++ b/pkg/providers/testing/lang.go @@ -2,10 +2,9 @@ package testing import ( "context" + clients2 "github.com/EinStack/glide/pkg/clients" "io" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/EinStack/glide/pkg/api/schemas" ) @@ -124,7 +123,7 @@ func (c *ProviderMock) SupportChatStream() bool { func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResponse, error) { if c.chatResps == nil { - return nil, clients.ErrProviderUnavailable + return nil, clients2.ErrProviderUnavailable } responses := *c.chatResps @@ -139,9 +138,9 @@ func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas. return response.Resp(), nil } -func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { +func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) { if c.chatStreams == nil || c.idx >= len(*c.chatStreams) { - return nil, clients.ErrProviderUnavailable + return nil, clients2.ErrProviderUnavailable } streams := *c.chatStreams diff --git a/pkg/routers/health/buckets.go b/pkg/resiliency/health/buckets.go similarity index 100% rename from pkg/routers/health/buckets.go rename to pkg/resiliency/health/buckets.go diff --git a/pkg/routers/health/buckets_test.go b/pkg/resiliency/health/buckets_test.go similarity index 100% rename from pkg/routers/health/buckets_test.go rename to pkg/resiliency/health/buckets_test.go diff --git a/pkg/routers/health/error_budget.go b/pkg/resiliency/health/error_budget.go similarity index 100% rename from pkg/routers/health/error_budget.go rename to pkg/resiliency/health/error_budget.go diff --git a/pkg/routers/health/error_budget_test.go b/pkg/resiliency/health/error_budget_test.go similarity index 100% rename from pkg/routers/health/error_budget_test.go rename to pkg/resiliency/health/error_budget_test.go diff --git a/pkg/routers/health/ratelimit.go b/pkg/resiliency/health/ratelimit.go similarity index 100% rename from pkg/routers/health/ratelimit.go rename to pkg/resiliency/health/ratelimit.go diff --git a/pkg/routers/health/ratelimit_test.go b/pkg/resiliency/health/ratelimit_test.go similarity index 100% rename from pkg/routers/health/ratelimit_test.go rename to pkg/resiliency/health/ratelimit_test.go diff --git a/pkg/routers/health/tracker.go b/pkg/resiliency/health/tracker.go similarity index 94% rename from pkg/routers/health/tracker.go rename to pkg/resiliency/health/tracker.go index 8cba6e65..13e89b54 100644 --- a/pkg/routers/health/tracker.go +++ b/pkg/resiliency/health/tracker.go @@ -2,8 +2,7 @@ package health import ( "errors" - - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/clients" ) // Tracker tracks errors and general health of model provider diff --git a/pkg/routers/health/tracker_test.go b/pkg/resiliency/health/tracker_test.go similarity index 93% rename from pkg/routers/health/tracker_test.go rename to pkg/resiliency/health/tracker_test.go index 8927a041..032da2ed 100644 --- a/pkg/routers/health/tracker_test.go +++ b/pkg/resiliency/health/tracker_test.go @@ -1,10 +1,10 @@ package health import ( + "github.com/EinStack/glide/pkg/clients" "testing" "time" - "github.com/EinStack/glide/pkg/providers/clients" "github.com/stretchr/testify/require" ) diff --git a/pkg/routers/retry/config.go b/pkg/resiliency/retry/config.go similarity index 100% rename from pkg/routers/retry/config.go rename to pkg/resiliency/retry/config.go diff --git a/pkg/routers/retry/config_test.go b/pkg/resiliency/retry/config_test.go similarity index 100% rename from pkg/routers/retry/config_test.go rename to pkg/resiliency/retry/config_test.go diff --git a/pkg/routers/retry/exp.go b/pkg/resiliency/retry/exp.go similarity index 100% rename from pkg/routers/retry/exp.go rename to pkg/resiliency/retry/exp.go diff --git a/pkg/routers/retry/exp_test.go b/pkg/resiliency/retry/exp_test.go similarity index 100% rename from pkg/routers/retry/exp_test.go rename to pkg/resiliency/retry/exp_test.go diff --git a/pkg/routers/config.go b/pkg/routers/config.go index 3ae65a3a..557ecd91 100644 --- a/pkg/routers/config.go +++ b/pkg/routers/config.go @@ -14,9 +14,9 @@ type Config struct { EmbeddingRouters []EmbeddingRouterConfig `yaml:"embedding" validate:"required,dive"` } -func (c *Config) BuildLangRouters(tel *telemetry.Telemetry) ([]*LangRouter, error) { +func (c *Config) BuildLangRouters(tel *telemetry.Telemetry) ([]*lang.LangRouter, error) { seenIDs := make(map[string]bool, len(c.LanguageRouters)) - routers := make([]*LangRouter, 0, len(c.LanguageRouters)) + routers := make([]*lang.LangRouter, 0, len(c.LanguageRouters)) var errs error @@ -34,7 +34,7 @@ func (c *Config) BuildLangRouters(tel *telemetry.Telemetry) ([]*LangRouter, erro tel.L().Debug("Init router", zap.String("routerID", routerConfig.ID)) - router, err := NewLangRouter(&c.LanguageRouters[idx], tel) + router, err := lang.NewLangRouter(&c.LanguageRouters[idx], tel) if err != nil { errs = multierr.Append(errs, err) continue diff --git a/pkg/routers/embed/config.go b/pkg/routers/embed/config.go index 93346b43..52d77eef 100644 --- a/pkg/routers/embed/config.go +++ b/pkg/routers/embed/config.go @@ -2,7 +2,7 @@ package embed import ( "github.com/EinStack/glide/pkg/providers" - "github.com/EinStack/glide/pkg/routers/retry" + "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/routers/routing" ) diff --git a/pkg/routers/embed/router.go b/pkg/routers/embed/router.go index dd3542c3..94a87fbb 100644 --- a/pkg/routers/embed/router.go +++ b/pkg/routers/embed/router.go @@ -3,20 +3,20 @@ package embed import ( "context" "github.com/EinStack/glide/pkg/api/schemas" - "github.com/EinStack/glide/pkg/routers" - "github.com/EinStack/glide/pkg/routers/retry" + "github.com/EinStack/glide/pkg/resiliency/retry" + "github.com/EinStack/glide/pkg/routers/lang" "github.com/EinStack/glide/pkg/telemetry" "go.uber.org/zap" ) type EmbeddingRouter struct { - routerID routers.RouterID + routerID lang.RouterID Config *LangRouterConfig retry *retry.ExpRetry tel *telemetry.Telemetry logger *zap.Logger } -func (r *routers.LangRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) { +func (r *lang.LangRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) { } diff --git a/pkg/routers/lang/config.go b/pkg/routers/lang/config.go index f35d1109..f4245d34 100644 --- a/pkg/routers/lang/config.go +++ b/pkg/routers/lang/config.go @@ -3,7 +3,7 @@ package lang import ( "fmt" "github.com/EinStack/glide/pkg/providers" - "github.com/EinStack/glide/pkg/routers/retry" + retry2 "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/routers/routing" "github.com/EinStack/glide/pkg/telemetry" "go.uber.org/multierr" @@ -17,7 +17,7 @@ import ( type LangRouterConfig struct { ID string `yaml:"id" json:"routers" validate:"required"` // Unique router ID Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is router enabled? - Retry *retry.ExpRetryConfig `yaml:"retry" json:"retry" validate:"required"` // retry when no healthy model is available to router + Retry *retry2.ExpRetryConfig `yaml:"retry" json:"retry" validate:"required"` // retry when no healthy model is available to router RoutingStrategy routing.Strategy `yaml:"strategy" json:"strategy" swaggertype:"primitive,string" validate:"required"` // strategy on picking the next model to serve the request Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests } @@ -119,11 +119,11 @@ func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*providers.L return chatModels, chatStreamModels, nil } -func (c *LangRouterConfig) BuildRetry() *retry.ExpRetry { +func (c *LangRouterConfig) BuildRetry() *retry2.ExpRetry { retryConfig := c.Retry maxDelay := time.Duration(*retryConfig.MaxDelay) - return retry.NewExpRetry( + return retry2.NewExpRetry( retryConfig.MaxRetries, retryConfig.BaseMultiplier, time.Duration(retryConfig.MinDelay), @@ -166,7 +166,7 @@ func DefaultLangRouterConfig() LangRouterConfig { return LangRouterConfig{ Enabled: true, RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), + Retry: retry2.DefaultExpRetryConfig(), } } diff --git a/pkg/routers/lang/config_test.go b/pkg/routers/lang/config_test.go index abbb5bcd..fde65344 100644 --- a/pkg/routers/lang/config_test.go +++ b/pkg/routers/lang/config_test.go @@ -1,23 +1,20 @@ package lang import ( + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/providers/cohere" + "github.com/EinStack/glide/pkg/resiliency/health" + "github.com/EinStack/glide/pkg/resiliency/retry" routers2 "github.com/EinStack/glide/pkg/routers" "github.com/EinStack/glide/pkg/telemetry" "testing" "github.com/EinStack/glide/pkg/routers/routing" - "github.com/EinStack/glide/pkg/routers/retry" - "github.com/EinStack/glide/pkg/routers/latency" - "github.com/EinStack/glide/pkg/routers/health" - "github.com/EinStack/glide/pkg/providers/openai" - "github.com/EinStack/glide/pkg/providers/clients" - "github.com/EinStack/glide/pkg/providers" "github.com/stretchr/testify/require" diff --git a/pkg/routers/lang.go b/pkg/routers/lang/router.go similarity index 97% rename from pkg/routers/lang.go rename to pkg/routers/lang/router.go index 4a7d0d0f..368ae260 100644 --- a/pkg/routers/lang.go +++ b/pkg/routers/lang/router.go @@ -1,10 +1,10 @@ -package routers +package lang import ( "context" "errors" + "github.com/EinStack/glide/pkg/resiliency/retry" - "github.com/EinStack/glide/pkg/routers/retry" "go.uber.org/zap" "github.com/EinStack/glide/pkg/providers" @@ -238,3 +238,7 @@ func (r *LangRouter) ChatStream( &schemas.ReasonError, ) } + +func (r *LangRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) { + +} diff --git a/pkg/routers/lang_test.go b/pkg/routers/lang/router_test.go similarity index 98% rename from pkg/routers/lang_test.go rename to pkg/routers/lang/router_test.go index f56216e3..8641cb0a 100644 --- a/pkg/routers/lang_test.go +++ b/pkg/routers/lang/router_test.go @@ -1,17 +1,17 @@ -package routers +package lang import ( "context" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/resiliency/health" + "github.com/EinStack/glide/pkg/resiliency/retry" "testing" "time" "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/providers" - "github.com/EinStack/glide/pkg/providers/clients" ptesting "github.com/EinStack/glide/pkg/providers/testing" - "github.com/EinStack/glide/pkg/routers/health" "github.com/EinStack/glide/pkg/routers/latency" - "github.com/EinStack/glide/pkg/routers/retry" "github.com/EinStack/glide/pkg/routers/routing" "github.com/EinStack/glide/pkg/telemetry" "github.com/stretchr/testify/require" diff --git a/pkg/routers/manager.go b/pkg/routers/manager.go index 123ea09e..7516e03f 100644 --- a/pkg/routers/manager.go +++ b/pkg/routers/manager.go @@ -2,14 +2,15 @@ package routers import ( "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/routers/lang" "github.com/EinStack/glide/pkg/telemetry" ) type RouterManager struct { Config *Config tel *telemetry.Telemetry - langRouterMap *map[string]*LangRouter - langRouters []*LangRouter + langRouterMap *map[string]*lang.LangRouter + langRouters []*lang.LangRouter } // NewManager creates a new instance of Router Manager that creates, holds and returns all routers @@ -19,7 +20,7 @@ func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) { return nil, err } - langRouterMap := make(map[string]*LangRouter, len(langRouters)) + langRouterMap := make(map[string]*lang.LangRouter, len(langRouters)) for _, router := range langRouters { langRouterMap[router.ID()] = router @@ -35,12 +36,12 @@ func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) { return &manager, err } -func (r *RouterManager) GetLangRouters() []*LangRouter { +func (r *RouterManager) GetLangRouters() []*lang.LangRouter { return r.langRouters } // GetLangRouter returns a router by type and ID -func (r *RouterManager) GetLangRouter(routerID string) (*LangRouter, error) { +func (r *RouterManager) GetLangRouter(routerID string) (*lang.LangRouter, error) { if router, found := (*r.langRouterMap)[routerID]; found { return router, nil } From d842fa062e8e958a424a158d9aedbe0bb06a148b Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Fri, 5 Jul 2024 13:22:29 +0300 Subject: [PATCH 03/18] #67: Restructure router config, model & provider interfaces to incorporate the new embedding router --- .gitignore | 2 + pkg/api/http/handlers.go | 11 ++- pkg/api/http/server.go | 8 +- pkg/api/servers.go | 6 +- pkg/config/config.go | 12 +-- pkg/gateway.go | 5 +- pkg/models/config.go | 40 ++++++++ pkg/{providers => models}/lang.go | 23 ++--- pkg/models/model.go | 11 +++ pkg/providers/anthropic/chat.go | 3 +- pkg/providers/anthropic/chat_stream.go | 1 + pkg/providers/anthropic/client.go | 3 +- pkg/providers/anthropic/client_test.go | 3 +- pkg/providers/anthropic/errors.go | 3 +- pkg/providers/azureopenai/chat.go | 3 +- pkg/providers/azureopenai/chat_stream.go | 3 +- pkg/providers/azureopenai/chat_stream_test.go | 3 +- pkg/providers/azureopenai/client.go | 3 +- pkg/providers/azureopenai/client_test.go | 3 +- pkg/providers/azureopenai/errors.go | 3 +- pkg/providers/bedrock/chat_stream.go | 1 + pkg/providers/bedrock/client.go | 3 +- pkg/providers/bedrock/client_test.go | 3 +- pkg/providers/cohere/chat.go | 3 +- pkg/providers/cohere/chat_stream.go | 3 +- pkg/providers/cohere/chat_stream_test.go | 3 +- pkg/providers/cohere/client.go | 3 +- pkg/providers/cohere/client_test.go | 3 +- pkg/providers/cohere/errors.go | 3 +- pkg/providers/config.go | 73 ++++---------- pkg/providers/octoml/chat_stream.go | 1 + pkg/providers/octoml/client.go | 3 +- pkg/providers/octoml/client_test.go | 3 +- pkg/providers/octoml/errors.go | 3 +- pkg/providers/ollama/chat.go | 3 +- pkg/providers/ollama/chat_stream.go | 1 + pkg/providers/ollama/client.go | 3 +- pkg/providers/ollama/client_test.go | 3 +- pkg/providers/openai/chat.go | 3 +- pkg/providers/openai/chat_stream.go | 3 +- pkg/providers/openai/chat_stream_test.go | 3 +- pkg/providers/openai/chat_test.go | 3 +- pkg/providers/openai/client.go | 3 +- pkg/providers/openai/embed.go | 13 +++ pkg/providers/openai/errors.go | 3 +- pkg/providers/provider.go | 19 ++-- pkg/providers/testing/lang.go | 3 +- pkg/resiliency/health/tracker.go | 1 + pkg/resiliency/health/tracker_test.go | 3 +- pkg/routers/config.go | 52 ++-------- pkg/routers/embed/config.go | 10 +- pkg/routers/embed/router.go | 4 +- pkg/routers/lang/config.go | 99 +++++++++++++------ pkg/routers/lang/config_test.go | 5 +- pkg/routers/lang/router.go | 40 +++----- pkg/routers/lang/router_test.go | 57 ++++++----- pkg/routers/manager/config.go | 9 ++ pkg/routers/{ => manager}/manager.go | 14 +-- pkg/routers/routing/least_latency_test.go | 2 +- 59 files changed, 343 insertions(+), 273 deletions(-) create mode 100644 pkg/models/config.go rename pkg/{providers => models}/lang.go (89%) create mode 100644 pkg/models/model.go create mode 100644 pkg/providers/openai/embed.go create mode 100644 pkg/routers/manager/config.go rename pkg/routers/{ => manager}/manager.go (71%) diff --git a/.gitignore b/.gitignore index 066b8f56..18d81aa2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,11 @@ .idea dist .env +.env.bak config.yaml bin glide +glide.exe tmp coverage.txt precommit.txt diff --git a/pkg/api/http/handlers.go b/pkg/api/http/handlers.go index 98e9f3a3..cc2ac3d3 100644 --- a/pkg/api/http/handlers.go +++ b/pkg/api/http/handlers.go @@ -4,8 +4,9 @@ import ( "context" "sync" + "github.com/EinStack/glide/pkg/routers/manager" + "github.com/EinStack/glide/pkg/api/schemas" - "github.com/EinStack/glide/pkg/routers" "github.com/EinStack/glide/pkg/telemetry" "github.com/gofiber/contrib/websocket" "github.com/gofiber/fiber/v2" @@ -31,7 +32,7 @@ type Handler = func(c *fiber.Ctx) error // @Failure 400 {object} schemas.Error // @Failure 404 {object} schemas.Error // @Router /v1/language/{router}/chat [POST] -func LangChatHandler(routerManager *routers.RouterManager) Handler { +func LangChatHandler(routerManager *manager.RouterManager) Handler { return func(c *fiber.Ctx) error { if !c.Is("json") { return c.Status(fiber.StatusBadRequest).JSON(schemas.ErrUnsupportedMediaType) @@ -72,7 +73,7 @@ func LangChatHandler(routerManager *routers.RouterManager) Handler { } } -func LangStreamRouterValidator(routerManager *routers.RouterManager) Handler { +func LangStreamRouterValidator(routerManager *manager.RouterManager) Handler { return func(c *fiber.Ctx) error { if websocket.IsWebSocketUpgrade(c) { routerID := c.Params("router") @@ -107,7 +108,7 @@ func LangStreamRouterValidator(routerManager *routers.RouterManager) Handler { // @Failure 426 // @Failure 404 {object} schemas.Error // @Router /v1/language/{router}/chatStream [GET] -func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.RouterManager) Handler { +func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *manager.RouterManager) Handler { // TODO: expose websocket connection configs https://github.com/gofiber/contrib/tree/main/websocket return websocket.New(func(c *websocket.Conn) { routerID := c.Params("router") @@ -175,7 +176,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout // @Produce json // @Success 200 {object} schemas.RouterListSchema // @Router /v1/language/ [GET] -func LangRoutersHandler(routerManager *routers.RouterManager) Handler { +func LangRoutersHandler(routerManager *manager.RouterManager) Handler { return func(c *fiber.Ctx) error { configuredRouters := routerManager.GetLangRouters() cfgs := make([]interface{}, 0, len(configuredRouters)) // opaque by design diff --git a/pkg/api/http/server.go b/pkg/api/http/server.go index 9b70a05f..35899963 100644 --- a/pkg/api/http/server.go +++ b/pkg/api/http/server.go @@ -6,6 +6,8 @@ import ( "fmt" "time" + "github.com/EinStack/glide/pkg/routers/manager" + "github.com/gofiber/contrib/otelfiber" "github.com/gofiber/swagger" @@ -17,19 +19,17 @@ import ( "github.com/gofiber/fiber/v2" - "github.com/EinStack/glide/pkg/routers" - "github.com/EinStack/glide/pkg/telemetry" ) type Server struct { config *ServerConfig telemetry *telemetry.Telemetry - routerManager *routers.RouterManager + routerManager *manager.RouterManager server *fiber.App } -func NewServer(config *ServerConfig, tel *telemetry.Telemetry, routerManager *routers.RouterManager) (*Server, error) { +func NewServer(config *ServerConfig, tel *telemetry.Telemetry, routerManager *manager.RouterManager) (*Server, error) { srv := config.ToServer() return &Server{ diff --git a/pkg/api/servers.go b/pkg/api/servers.go index 3588e257..da2d130a 100644 --- a/pkg/api/servers.go +++ b/pkg/api/servers.go @@ -4,9 +4,9 @@ import ( "context" "sync" - "go.uber.org/zap" + "github.com/EinStack/glide/pkg/routers/manager" - "github.com/EinStack/glide/pkg/routers" + "go.uber.org/zap" "github.com/EinStack/glide/pkg/telemetry" @@ -19,7 +19,7 @@ type ServerManager struct { telemetry *telemetry.Telemetry } -func NewServerManager(cfg *Config, tel *telemetry.Telemetry, router *routers.RouterManager) (*ServerManager, error) { +func NewServerManager(cfg *Config, tel *telemetry.Telemetry, router *manager.RouterManager) (*ServerManager, error) { httpServer, err := http.NewServer(cfg.HTTP, tel, router) if err != nil { return nil, err diff --git a/pkg/config/config.go b/pkg/config/config.go index dd520a9a..cacdc2a9 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,18 +1,16 @@ package config import ( - "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/routers" - "github.com/EinStack/glide/pkg/api" + routerconfig "github.com/EinStack/glide/pkg/routers/manager" + "github.com/EinStack/glide/pkg/telemetry" ) // Config is a general top-level Glide configuration type Config struct { - Telemetry *telemetry.Config `yaml:"telemetry" validate:"required"` - API *api.Config `yaml:"api" validate:"required"` - Routers routers.Config `yaml:"routers" validate:"required"` + Telemetry *telemetry.Config `yaml:"telemetry" validate:"required"` + API *api.Config `yaml:"api" validate:"required"` + Routers routerconfig.Config `yaml:"routers" validate:"required"` } func DefaultConfig() *Config { diff --git a/pkg/gateway.go b/pkg/gateway.go index 950ec26d..a3c8969d 100644 --- a/pkg/gateway.go +++ b/pkg/gateway.go @@ -7,7 +7,8 @@ import ( "os/signal" "syscall" - "github.com/EinStack/glide/pkg/routers" + "github.com/EinStack/glide/pkg/routers/manager" + "github.com/EinStack/glide/pkg/version" "go.opentelemetry.io/contrib/instrumentation/host" "go.opentelemetry.io/contrib/instrumentation/runtime" @@ -49,7 +50,7 @@ func NewGateway(configProvider *config.Provider) (*Gateway, error) { tel.L().Info("🐦Glide is starting up", zap.String("version", version.FullVersion)) tel.L().Debug("✅ Config loaded successfully:\n" + configProvider.GetStr()) - routerManager, err := routers.NewManager(&cfg.Routers, tel) + routerManager, err := manager.NewManager(&cfg.Routers, tel) if err != nil { return nil, err } diff --git a/pkg/models/config.go b/pkg/models/config.go new file mode 100644 index 00000000..5289ea03 --- /dev/null +++ b/pkg/models/config.go @@ -0,0 +1,40 @@ +package models + +import ( + "fmt" + + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/resiliency/health" + "github.com/EinStack/glide/pkg/routers/latency" + "github.com/EinStack/glide/pkg/telemetry" +) + +type Config[P any] struct { + ID string `yaml:"id" json:"id" validate:"required"` // Model instance ID (unique in scope of the router) + Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is the model enabled? + ErrorBudget *health.ErrorBudget `yaml:"error_budget" json:"error_budget" swaggertype:"primitive,string"` + Latency *latency.Config `yaml:"latency" json:"latency"` + Weight int `yaml:"weight" json:"weight"` + Client *clients.ClientConfig `yaml:"client" json:"client"` + + Provider P `yaml:"provider" json:"provider"` +} + +func DefaultConfig[P any]() Config[P] { + return Config[P]{ + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Weight: 1, + } +} + +func (c *Config) ToModel(tel *telemetry.Telemetry) (*LanguageModel, error) { + client, err := c.Provider.ToClient(tel, c.Client) + if err != nil { + return nil, fmt.Errorf("error initializing client: %w", err) + } + + return NewLangModel(c.ID, client, c.ErrorBudget, *c.Latency, c.Weight), nil +} diff --git a/pkg/providers/lang.go b/pkg/models/lang.go similarity index 89% rename from pkg/providers/lang.go rename to pkg/models/lang.go index 4e16b979..299111b6 100644 --- a/pkg/providers/lang.go +++ b/pkg/models/lang.go @@ -1,12 +1,15 @@ -package providers +package models import ( "context" - "github.com/EinStack/glide/pkg/clients" - health2 "github.com/EinStack/glide/pkg/resiliency/health" "io" "time" + "github.com/EinStack/glide/pkg/providers" + + "github.com/EinStack/glide/pkg/clients" + health2 "github.com/EinStack/glide/pkg/resiliency/health" + "github.com/EinStack/glide/pkg/config/fields" "github.com/EinStack/glide/pkg/routers/latency" @@ -14,16 +17,6 @@ import ( "github.com/EinStack/glide/pkg/api/schemas" ) -// LangProvider defines an interface a provider should fulfill to be able to serve language chat requests -type LangProvider interface { - ModelProvider - - SupportChatStream() bool - - Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) - ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) -} - type LangModel interface { Model Provider() string @@ -39,14 +32,14 @@ type LangModel interface { type LanguageModel struct { modelID string weight int - client LangProvider + client providers.LangProvider healthTracker *health2.Tracker chatLatency *latency.MovingAverage chatStreamLatency *latency.MovingAverage latencyUpdateInterval *fields.Duration } -func NewLangModel(modelID string, client LangProvider, budget *health2.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel { +func NewLangModel(modelID string, client providers.LangProvider, budget *health2.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel { return &LanguageModel{ modelID: modelID, client: client, diff --git a/pkg/models/model.go b/pkg/models/model.go new file mode 100644 index 00000000..707efee3 --- /dev/null +++ b/pkg/models/model.go @@ -0,0 +1,11 @@ +package models + +import "github.com/EinStack/glide/pkg/config/fields" + +// Model represent a configured external modality-agnostic model with its routing properties and status +type Model interface { + ID() string + Healthy() bool + LatencyUpdateInterval() *fields.Duration + Weight() int +} diff --git a/pkg/providers/anthropic/chat.go b/pkg/providers/anthropic/chat.go index 03f7591f..a89515a8 100644 --- a/pkg/providers/anthropic/chat.go +++ b/pkg/providers/anthropic/chat.go @@ -5,11 +5,12 @@ import ( "context" "encoding/json" "fmt" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) diff --git a/pkg/providers/anthropic/chat_stream.go b/pkg/providers/anthropic/chat_stream.go index 6040d2c1..dbb0b8ff 100644 --- a/pkg/providers/anthropic/chat_stream.go +++ b/pkg/providers/anthropic/chat_stream.go @@ -2,6 +2,7 @@ package anthropic import ( "context" + clients2 "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/api/schemas" diff --git a/pkg/providers/anthropic/client.go b/pkg/providers/anthropic/client.go index 11da9173..e42ccc31 100644 --- a/pkg/providers/anthropic/client.go +++ b/pkg/providers/anthropic/client.go @@ -1,11 +1,12 @@ package anthropic import ( - "github.com/EinStack/glide/pkg/clients" "net/http" "net/url" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" ) diff --git a/pkg/providers/anthropic/client_test.go b/pkg/providers/anthropic/client_test.go index 75be00f1..70977bb0 100644 --- a/pkg/providers/anthropic/client_test.go +++ b/pkg/providers/anthropic/client_test.go @@ -3,7 +3,6 @@ package anthropic import ( "context" "encoding/json" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -11,6 +10,8 @@ import ( "path/filepath" "testing" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/anthropic/errors.go b/pkg/providers/anthropic/errors.go index 222b1921..5c7a1370 100644 --- a/pkg/providers/anthropic/errors.go +++ b/pkg/providers/anthropic/errors.go @@ -2,11 +2,12 @@ package anthropic import ( "fmt" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" "go.uber.org/zap" diff --git a/pkg/providers/azureopenai/chat.go b/pkg/providers/azureopenai/chat.go index c3c73656..2c62dc0f 100644 --- a/pkg/providers/azureopenai/chat.go +++ b/pkg/providers/azureopenai/chat.go @@ -5,10 +5,11 @@ import ( "context" "encoding/json" "fmt" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/providers/openai" "github.com/EinStack/glide/pkg/api/schemas" diff --git a/pkg/providers/azureopenai/chat_stream.go b/pkg/providers/azureopenai/chat_stream.go index bf9bd215..dfa787c4 100644 --- a/pkg/providers/azureopenai/chat_stream.go +++ b/pkg/providers/azureopenai/chat_stream.go @@ -5,10 +5,11 @@ import ( "context" "encoding/json" "fmt" - clients2 "github.com/EinStack/glide/pkg/clients" "io" "net/http" + clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" "github.com/EinStack/glide/pkg/providers/openai" diff --git a/pkg/providers/azureopenai/chat_stream_test.go b/pkg/providers/azureopenai/chat_stream_test.go index 49d792d8..39a5b93e 100644 --- a/pkg/providers/azureopenai/chat_stream_test.go +++ b/pkg/providers/azureopenai/chat_stream_test.go @@ -3,7 +3,6 @@ package azureopenai import ( "context" "encoding/json" - clients2 "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -11,6 +10,8 @@ import ( "path/filepath" "testing" + clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/azureopenai/client.go b/pkg/providers/azureopenai/client.go index c1399307..88c5b64d 100644 --- a/pkg/providers/azureopenai/client.go +++ b/pkg/providers/azureopenai/client.go @@ -2,10 +2,11 @@ package azureopenai import ( "fmt" - "github.com/EinStack/glide/pkg/clients" "net/http" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/providers/openai" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/azureopenai/client_test.go b/pkg/providers/azureopenai/client_test.go index b92c9142..5c390114 100644 --- a/pkg/providers/azureopenai/client_test.go +++ b/pkg/providers/azureopenai/client_test.go @@ -3,7 +3,6 @@ package azureopenai import ( "context" "encoding/json" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -11,6 +10,8 @@ import ( "path/filepath" "testing" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/azureopenai/errors.go b/pkg/providers/azureopenai/errors.go index b7bb4e14..d659c027 100644 --- a/pkg/providers/azureopenai/errors.go +++ b/pkg/providers/azureopenai/errors.go @@ -2,11 +2,12 @@ package azureopenai import ( "fmt" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" "go.uber.org/zap" diff --git a/pkg/providers/bedrock/chat_stream.go b/pkg/providers/bedrock/chat_stream.go index 99f3c8d1..57413043 100644 --- a/pkg/providers/bedrock/chat_stream.go +++ b/pkg/providers/bedrock/chat_stream.go @@ -2,6 +2,7 @@ package bedrock import ( "context" + clients2 "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/api/schemas" diff --git a/pkg/providers/bedrock/client.go b/pkg/providers/bedrock/client.go index 5385691d..673cb49f 100644 --- a/pkg/providers/bedrock/client.go +++ b/pkg/providers/bedrock/client.go @@ -3,11 +3,12 @@ package bedrock import ( "context" "errors" - "github.com/EinStack/glide/pkg/clients" "net/http" "net/url" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" "github.com/aws/aws-sdk-go-v2/config" diff --git a/pkg/providers/bedrock/client_test.go b/pkg/providers/bedrock/client_test.go index 957a754d..e99f8d9c 100644 --- a/pkg/providers/bedrock/client_test.go +++ b/pkg/providers/bedrock/client_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -12,6 +11,8 @@ import ( "path/filepath" "testing" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/cohere/chat.go b/pkg/providers/cohere/chat.go index 12ec6206..4729d55f 100644 --- a/pkg/providers/cohere/chat.go +++ b/pkg/providers/cohere/chat.go @@ -5,11 +5,12 @@ import ( "context" "encoding/json" "fmt" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" diff --git a/pkg/providers/cohere/chat_stream.go b/pkg/providers/cohere/chat_stream.go index 8fb670a3..6f194945 100644 --- a/pkg/providers/cohere/chat_stream.go +++ b/pkg/providers/cohere/chat_stream.go @@ -5,10 +5,11 @@ import ( "context" "encoding/json" "fmt" - clients2 "github.com/EinStack/glide/pkg/clients" "io" "net/http" + clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" "go.uber.org/zap" diff --git a/pkg/providers/cohere/chat_stream_test.go b/pkg/providers/cohere/chat_stream_test.go index 3d9410be..82060f84 100644 --- a/pkg/providers/cohere/chat_stream_test.go +++ b/pkg/providers/cohere/chat_stream_test.go @@ -3,7 +3,6 @@ package cohere import ( "context" "encoding/json" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -11,6 +10,8 @@ import ( "path/filepath" "testing" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/cohere/client.go b/pkg/providers/cohere/client.go index c8a00b7f..a8426598 100644 --- a/pkg/providers/cohere/client.go +++ b/pkg/providers/cohere/client.go @@ -1,11 +1,12 @@ package cohere import ( - "github.com/EinStack/glide/pkg/clients" "net/http" "net/url" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" ) diff --git a/pkg/providers/cohere/client_test.go b/pkg/providers/cohere/client_test.go index 959de556..bb4f99e4 100644 --- a/pkg/providers/cohere/client_test.go +++ b/pkg/providers/cohere/client_test.go @@ -4,7 +4,6 @@ package cohere import ( "context" "encoding/json" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -12,6 +11,8 @@ import ( "path/filepath" "testing" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/cohere/errors.go b/pkg/providers/cohere/errors.go index bac434ff..5b5548c1 100644 --- a/pkg/providers/cohere/errors.go +++ b/pkg/providers/cohere/errors.go @@ -2,11 +2,12 @@ package cohere import ( "fmt" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" "go.uber.org/zap" diff --git a/pkg/providers/config.go b/pkg/providers/config.go index e656822c..8e1c66c2 100644 --- a/pkg/providers/config.go +++ b/pkg/providers/config.go @@ -3,37 +3,22 @@ package providers import ( "errors" "fmt" - "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/resiliency/health" - - "github.com/EinStack/glide/pkg/routers/latency" "github.com/EinStack/glide/pkg/providers/ollama" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/providers/anthropic" + "github.com/EinStack/glide/pkg/providers/azureopenai" "github.com/EinStack/glide/pkg/providers/bedrock" - + "github.com/EinStack/glide/pkg/providers/cohere" + "github.com/EinStack/glide/pkg/providers/octoml" "github.com/EinStack/glide/pkg/providers/openai" - "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/providers/octoml" - - "github.com/EinStack/glide/pkg/providers/cohere" - - "github.com/EinStack/glide/pkg/providers/azureopenai" - - "github.com/EinStack/glide/pkg/providers/anthropic" ) var ErrProviderNotFound = errors.New("provider not found") -type LangModelConfig struct { - ID string `yaml:"id" json:"id" validate:"required"` // Model instance ID (unique in scope of the router) - Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is the model enabled? - ErrorBudget *health.ErrorBudget `yaml:"error_budget" json:"error_budget" swaggertype:"primitive,string"` - Latency *latency.Config `yaml:"latency" json:"latency"` - Weight int `yaml:"weight" json:"weight"` - Client *clients.ClientConfig `yaml:"client" json:"client"` +type LangProviders struct { // Add other providers like OpenAI *openai.Config `yaml:"openai,omitempty" json:"openai,omitempty"` AzureOpenAI *azureopenai.Config `yaml:"azureopenai,omitempty" json:"azureopenai,omitempty"` @@ -44,47 +29,28 @@ type LangModelConfig struct { Ollama *ollama.Config `yaml:"ollama,omitempty" json:"ollama,omitempty"` } -func DefaultLangModelConfig() *LangModelConfig { - return &LangModelConfig{ - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), - Weight: 1, - } -} - -func (c *LangModelConfig) ToModel(tel *telemetry.Telemetry) (*LanguageModel, error) { - client, err := c.initClient(tel) - if err != nil { - return nil, fmt.Errorf("error initializing client: %v", err) - } - - return NewLangModel(c.ID, client, c.ErrorBudget, *c.Latency, c.Weight), nil -} - -// initClient initializes the language model client based on the provided configuration. +// ToClient initializes the language model client based on the provided configuration. // It takes a telemetry object as input and returns a LangModelProvider and an error. -func (c *LangModelConfig) initClient(tel *telemetry.Telemetry) (LangProvider, error) { +func (c *LangProviders) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) { switch { case c.OpenAI != nil: - return openai.NewClient(c.OpenAI, c.Client, tel) + return openai.NewClient(c.OpenAI, clientConfig, tel) case c.AzureOpenAI != nil: - return azureopenai.NewClient(c.AzureOpenAI, c.Client, tel) + return azureopenai.NewClient(c.AzureOpenAI, clientConfig, tel) case c.Cohere != nil: - return cohere.NewClient(c.Cohere, c.Client, tel) + return cohere.NewClient(c.Cohere, clientConfig, tel) case c.OctoML != nil: - return octoml.NewClient(c.OctoML, c.Client, tel) + return octoml.NewClient(c.OctoML, clientConfig, tel) case c.Anthropic != nil: - return anthropic.NewClient(c.Anthropic, c.Client, tel) + return anthropic.NewClient(c.Anthropic, clientConfig, tel) case c.Bedrock != nil: - return bedrock.NewClient(c.Bedrock, c.Client, tel) + return bedrock.NewClient(c.Bedrock, clientConfig, tel) default: return nil, ErrProviderNotFound } } -func (c *LangModelConfig) validateOneProvider() error { +func (c *LangProviders) validateOneProvider() error { providersConfigured := 0 if c.OpenAI != nil { @@ -117,13 +83,12 @@ func (c *LangModelConfig) validateOneProvider() error { // check other providers here if providersConfigured == 0 { - return fmt.Errorf("exactly one provider must be configured for model \"%v\", none is configured", c.ID) + return fmt.Errorf("exactly one provider must be configured, none is configured") } if providersConfigured > 1 { return fmt.Errorf( - "exactly one provider must be configured for model \"%v\", %v are configured", - c.ID, + "exactly one provider must be configured, but %v are configured", providersConfigured, ) } @@ -131,8 +96,8 @@ func (c *LangModelConfig) validateOneProvider() error { return nil } -func (c *LangModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - *c = *DefaultLangModelConfig() +func (c *LangProviders) UnmarshalYAML(unmarshal func(interface{}) error) error { + *c = DefaultConfig() type plain LangModelConfig // to avoid recursion diff --git a/pkg/providers/octoml/chat_stream.go b/pkg/providers/octoml/chat_stream.go index 999612bc..7b8a1766 100644 --- a/pkg/providers/octoml/chat_stream.go +++ b/pkg/providers/octoml/chat_stream.go @@ -2,6 +2,7 @@ package octoml import ( "context" + clients2 "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/api/schemas" diff --git a/pkg/providers/octoml/client.go b/pkg/providers/octoml/client.go index 11e3b269..420a991a 100644 --- a/pkg/providers/octoml/client.go +++ b/pkg/providers/octoml/client.go @@ -2,11 +2,12 @@ package octoml import ( "errors" - "github.com/EinStack/glide/pkg/clients" "net/http" "net/url" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" ) diff --git a/pkg/providers/octoml/client_test.go b/pkg/providers/octoml/client_test.go index 485d0474..128fd1f0 100644 --- a/pkg/providers/octoml/client_test.go +++ b/pkg/providers/octoml/client_test.go @@ -3,7 +3,6 @@ package octoml import ( "context" "encoding/json" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -11,6 +10,8 @@ import ( "path/filepath" "testing" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/octoml/errors.go b/pkg/providers/octoml/errors.go index 9f446f67..fe1d1198 100644 --- a/pkg/providers/octoml/errors.go +++ b/pkg/providers/octoml/errors.go @@ -2,11 +2,12 @@ package octoml import ( "fmt" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" "go.uber.org/zap" diff --git a/pkg/providers/ollama/chat.go b/pkg/providers/ollama/chat.go index 87acef9b..42ee1f99 100644 --- a/pkg/providers/ollama/chat.go +++ b/pkg/providers/ollama/chat.go @@ -5,11 +5,12 @@ import ( "context" "encoding/json" "fmt" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/google/uuid" "github.com/EinStack/glide/pkg/api/schemas" diff --git a/pkg/providers/ollama/chat_stream.go b/pkg/providers/ollama/chat_stream.go index a5a265d4..15d220e9 100644 --- a/pkg/providers/ollama/chat_stream.go +++ b/pkg/providers/ollama/chat_stream.go @@ -2,6 +2,7 @@ package ollama import ( "context" + clients2 "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/api/schemas" diff --git a/pkg/providers/ollama/client.go b/pkg/providers/ollama/client.go index d54e43ed..85192b6b 100644 --- a/pkg/providers/ollama/client.go +++ b/pkg/providers/ollama/client.go @@ -1,11 +1,12 @@ package ollama import ( - "github.com/EinStack/glide/pkg/clients" "net/http" "net/url" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" ) diff --git a/pkg/providers/ollama/client_test.go b/pkg/providers/ollama/client_test.go index 61958fa9..1c9dad49 100644 --- a/pkg/providers/ollama/client_test.go +++ b/pkg/providers/ollama/client_test.go @@ -3,7 +3,6 @@ package ollama import ( "context" "encoding/json" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -11,6 +10,8 @@ import ( "path/filepath" "testing" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index efc5edcf..4c4b21c8 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -5,10 +5,11 @@ import ( "context" "encoding/json" "fmt" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) diff --git a/pkg/providers/openai/chat_stream.go b/pkg/providers/openai/chat_stream.go index 5b8f8b41..d362cb67 100644 --- a/pkg/providers/openai/chat_stream.go +++ b/pkg/providers/openai/chat_stream.go @@ -5,10 +5,11 @@ import ( "context" "encoding/json" "fmt" - clients2 "github.com/EinStack/glide/pkg/clients" "io" "net/http" + clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/r3labs/sse/v2" "go.uber.org/zap" diff --git a/pkg/providers/openai/chat_stream_test.go b/pkg/providers/openai/chat_stream_test.go index 459192b7..6928e6f0 100644 --- a/pkg/providers/openai/chat_stream_test.go +++ b/pkg/providers/openai/chat_stream_test.go @@ -3,7 +3,6 @@ package openai import ( "context" "encoding/json" - clients2 "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -11,6 +10,8 @@ import ( "path/filepath" "testing" + clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/openai/chat_test.go b/pkg/providers/openai/chat_test.go index 4d626e81..0aae4d0e 100644 --- a/pkg/providers/openai/chat_test.go +++ b/pkg/providers/openai/chat_test.go @@ -3,7 +3,6 @@ package openai import ( "context" "encoding/json" - clients2 "github.com/EinStack/glide/pkg/clients" "io" "net/http" "net/http/httptest" @@ -11,6 +10,8 @@ import ( "path/filepath" "testing" + clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/openai/client.go b/pkg/providers/openai/client.go index ec20b3ca..bb49dab3 100644 --- a/pkg/providers/openai/client.go +++ b/pkg/providers/openai/client.go @@ -1,11 +1,12 @@ package openai import ( - "github.com/EinStack/glide/pkg/clients" "net/http" "net/url" "time" + "github.com/EinStack/glide/pkg/clients" + "go.uber.org/zap" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/providers/openai/embed.go b/pkg/providers/openai/embed.go new file mode 100644 index 00000000..48e69328 --- /dev/null +++ b/pkg/providers/openai/embed.go @@ -0,0 +1,13 @@ +package openai + +import ( + "context" + + "github.com/EinStack/glide/pkg/api/schemas" +) + +// Embed sends an embedding request to the specified OpenAI model. +func (c *Client) Embed(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { + // TODO: implement + return nil, nil +} diff --git a/pkg/providers/openai/errors.go b/pkg/providers/openai/errors.go index 0cf2a418..58e37292 100644 --- a/pkg/providers/openai/errors.go +++ b/pkg/providers/openai/errors.go @@ -2,11 +2,12 @@ package openai import ( "fmt" - "github.com/EinStack/glide/pkg/clients" "io" "net/http" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" "go.uber.org/zap" diff --git a/pkg/providers/provider.go b/pkg/providers/provider.go index 91aded44..2341ddc9 100644 --- a/pkg/providers/provider.go +++ b/pkg/providers/provider.go @@ -1,7 +1,10 @@ package providers import ( - "github.com/EinStack/glide/pkg/config/fields" + "context" + + "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" ) // ModelProvider exposes provider context @@ -10,10 +13,12 @@ type ModelProvider interface { ModelName() string } -// Model represent a configured external modality-agnostic model with its routing properties and status -type Model interface { - ID() string - Healthy() bool - LatencyUpdateInterval() *fields.Duration - Weight() int +// LangProvider defines an interface a provider should fulfill to be able to serve language chat requests +type LangProvider interface { + ModelProvider + + SupportChatStream() bool + + Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) + ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) } diff --git a/pkg/providers/testing/lang.go b/pkg/providers/testing/lang.go index 39389cc8..3c27792a 100644 --- a/pkg/providers/testing/lang.go +++ b/pkg/providers/testing/lang.go @@ -2,9 +2,10 @@ package testing import ( "context" - clients2 "github.com/EinStack/glide/pkg/clients" "io" + clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schemas" ) diff --git a/pkg/resiliency/health/tracker.go b/pkg/resiliency/health/tracker.go index 13e89b54..3d4a313b 100644 --- a/pkg/resiliency/health/tracker.go +++ b/pkg/resiliency/health/tracker.go @@ -2,6 +2,7 @@ package health import ( "errors" + "github.com/EinStack/glide/pkg/clients" ) diff --git a/pkg/resiliency/health/tracker_test.go b/pkg/resiliency/health/tracker_test.go index 032da2ed..279bd378 100644 --- a/pkg/resiliency/health/tracker_test.go +++ b/pkg/resiliency/health/tracker_test.go @@ -1,10 +1,11 @@ package health import ( - "github.com/EinStack/glide/pkg/clients" "testing" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/stretchr/testify/require" ) diff --git a/pkg/routers/config.go b/pkg/routers/config.go index 557ecd91..99cef09f 100644 --- a/pkg/routers/config.go +++ b/pkg/routers/config.go @@ -1,51 +1,13 @@ package routers import ( - "fmt" - "github.com/EinStack/glide/pkg/routers/lang" - "github.com/EinStack/glide/pkg/telemetry" - - "go.uber.org/multierr" - "go.uber.org/zap" + "github.com/EinStack/glide/pkg/resiliency/retry" + "github.com/EinStack/glide/pkg/routers/routing" ) -type Config struct { - LanguageRouters []lang.LangRouterConfig `yaml:"language" validate:"required,dive"` // the list of language routers - EmbeddingRouters []EmbeddingRouterConfig `yaml:"embedding" validate:"required,dive"` -} - -func (c *Config) BuildLangRouters(tel *telemetry.Telemetry) ([]*lang.LangRouter, error) { - seenIDs := make(map[string]bool, len(c.LanguageRouters)) - routers := make([]*lang.LangRouter, 0, len(c.LanguageRouters)) - - var errs error - - for idx, routerConfig := range c.LanguageRouters { - if _, ok := seenIDs[routerConfig.ID]; ok { - return nil, fmt.Errorf("ID \"%v\" is specified for more than one router while each ID should be unique", routerConfig.ID) - } - - seenIDs[routerConfig.ID] = true - - if !routerConfig.Enabled { - tel.L().Info(fmt.Sprintf("Router \"%v\" is disabled, skipping", routerConfig.ID)) - continue - } - - tel.L().Debug("Init router", zap.String("routerID", routerConfig.ID)) - - router, err := lang.NewLangRouter(&c.LanguageRouters[idx], tel) - if err != nil { - errs = multierr.Append(errs, err) - continue - } - - routers = append(routers, router) - } - - if errs != nil { - return nil, errs - } - - return routers, nil +type RouterConfig struct { + ID string `yaml:"id" json:"routers" validate:"required"` // Unique router ID + Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is router enabled? + Retry *retry.ExpRetryConfig `yaml:"retry" json:"retry" validate:"required"` // retry when no healthy model is available to router + RoutingStrategy routing.Strategy `yaml:"strategy" json:"strategy" swaggertype:"primitive,string" validate:"required"` // strategy on picking the next model to serve the request } diff --git a/pkg/routers/embed/config.go b/pkg/routers/embed/config.go index 52d77eef..63894e82 100644 --- a/pkg/routers/embed/config.go +++ b/pkg/routers/embed/config.go @@ -2,14 +2,10 @@ package embed import ( "github.com/EinStack/glide/pkg/providers" - "github.com/EinStack/glide/pkg/resiliency/retry" - "github.com/EinStack/glide/pkg/routers/routing" + "github.com/EinStack/glide/pkg/routers" ) type EmbeddingRouterConfig struct { - ID string `yaml:"id" json:"routers" validate:"required"` // Unique router ID - Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is router enabled? - Retry *retry.ExpRetryConfig `yaml:"retry" json:"retry" validate:"required"` // retry when no healthy model is available to router - RoutingStrategy routing.Strategy `yaml:"strategy" json:"strategy" swaggertype:"primitive,string" validate:"required"` // strategy on picking the next model to serve the request - Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests + routers.RouterConfig + Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests } diff --git a/pkg/routers/embed/router.go b/pkg/routers/embed/router.go index 94a87fbb..501f63e6 100644 --- a/pkg/routers/embed/router.go +++ b/pkg/routers/embed/router.go @@ -2,6 +2,7 @@ package embed import ( "context" + "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/routers/lang" @@ -17,6 +18,5 @@ type EmbeddingRouter struct { logger *zap.Logger } -func (r *lang.LangRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) { - +func (r *EmbeddingRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) { } diff --git a/pkg/routers/lang/config.go b/pkg/routers/lang/config.go index f4245d34..f68d1c8b 100644 --- a/pkg/routers/lang/config.go +++ b/pkg/routers/lang/config.go @@ -2,33 +2,34 @@ package lang import ( "fmt" + "time" + + "github.com/EinStack/glide/pkg/routers" + + "github.com/EinStack/glide/pkg/models" "github.com/EinStack/glide/pkg/providers" - retry2 "github.com/EinStack/glide/pkg/resiliency/retry" + "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/routers/routing" "github.com/EinStack/glide/pkg/telemetry" "go.uber.org/multierr" "go.uber.org/zap" - "time" ) // TODO: how to specify other backoff strategies? // TODO: Had to keep RoutingStrategy because of https://github.com/swaggo/swag/issues/1738 -// LangRouterConfig -type LangRouterConfig struct { - ID string `yaml:"id" json:"routers" validate:"required"` // Unique router ID - Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is router enabled? - Retry *retry2.ExpRetryConfig `yaml:"retry" json:"retry" validate:"required"` // retry when no healthy model is available to router - RoutingStrategy routing.Strategy `yaml:"strategy" json:"strategy" swaggertype:"primitive,string" validate:"required"` // strategy on picking the next model to serve the request - Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests +// RouterConfig +type RouterConfig struct { + routers.RouterConfig + Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests } // BuildModels creates LanguageModel slice out of the given config -func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*providers.LanguageModel, []*providers.LanguageModel, error) { //nolint: cyclop +func (c *RouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*models.LanguageModel, []*models.LanguageModel, error) { //nolint: cyclop var errs error seenIDs := make(map[string]bool, len(c.Models)) - chatModels := make([]*providers.LanguageModel, 0, len(c.Models)) - chatStreamModels := make([]*providers.LanguageModel, 0, len(c.Models)) + chatModels := make([]*models.LanguageModel, 0, len(c.Models)) + chatStreamModels := make([]*models.LanguageModel, 0, len(c.Models)) for _, modelConfig := range c.Models { if _, ok := seenIDs[modelConfig.ID]; ok { @@ -119,11 +120,11 @@ func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*providers.L return chatModels, chatStreamModels, nil } -func (c *LangRouterConfig) BuildRetry() *retry2.ExpRetry { +func (c *RouterConfig) BuildRetry() *retry.ExpRetry { retryConfig := c.Retry maxDelay := time.Duration(*retryConfig.MaxDelay) - return retry2.NewExpRetry( + return retry.NewExpRetry( retryConfig.MaxRetries, retryConfig.BaseMultiplier, time.Duration(retryConfig.MinDelay), @@ -131,12 +132,12 @@ func (c *LangRouterConfig) BuildRetry() *retry2.ExpRetry { ) } -func (c *LangRouterConfig) BuildRouting( - chatModels []*providers.LanguageModel, - chatStreamModels []*providers.LanguageModel, +func (c *RouterConfig) BuildRouting( + chatModels []*models.LanguageModel, + chatStreamModels []*models.LanguageModel, ) (routing.LangModelRouting, routing.LangModelRouting, error) { - chatModelPool := make([]providers.Model, 0, len(chatModels)) - chatStreamModelPool := make([]providers.Model, 0, len(chatStreamModels)) + chatModelPool := make([]models.Model, 0, len(chatModels)) + chatStreamModelPool := make([]models.Model, 0, len(chatStreamModels)) for _, model := range chatModels { chatModelPool = append(chatModelPool, model) @@ -154,26 +155,66 @@ func (c *LangRouterConfig) BuildRouting( case routing.WeightedRoundRobin: return routing.NewWeightedRoundRobin(chatModelPool), routing.NewWeightedRoundRobin(chatStreamModelPool), nil case routing.LeastLatency: - return routing.NewLeastLatencyRouting(providers.ChatLatency, chatModelPool), - routing.NewLeastLatencyRouting(providers.ChatStreamLatency, chatStreamModelPool), + return routing.NewLeastLatencyRouting(models.ChatLatency, chatModelPool), + routing.NewLeastLatencyRouting(models.ChatStreamLatency, chatStreamModelPool), nil } return nil, nil, fmt.Errorf("routing strategy \"%v\" is not supported, please make sure there is no typo", c.RoutingStrategy) } -func DefaultLangRouterConfig() LangRouterConfig { - return LangRouterConfig{ - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry2.DefaultExpRetryConfig(), +func DefaultRouterConfig() *RouterConfig { + return &RouterConfig{ + RouterConfig: routers.RouterConfig{ + Enabled: true, + RoutingStrategy: routing.Priority, + Retry: retry.DefaultExpRetryConfig(), + }, } } -func (c *LangRouterConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - *c = DefaultLangRouterConfig() +func (c *RouterConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + *c = *DefaultRouterConfig() - type plain LangRouterConfig // to avoid recursion + type plain RouterConfig // to avoid recursion return unmarshal((*plain)(c)) } + +type RoutersConfig []RouterConfig + +func (c RoutersConfig) Build(tel *telemetry.Telemetry) ([]*Router, error) { + seenIDs := make(map[string]bool, len(c)) + langRouters := make([]*Router, 0, len(c)) + + var errs error + + for idx, routerConfig := range c { + if _, ok := seenIDs[routerConfig.ID]; ok { + return nil, fmt.Errorf("ID \"%v\" is specified for more than one router while each ID should be unique", routerConfig.ID) + } + + seenIDs[routerConfig.ID] = true + + if !routerConfig.Enabled { + tel.L().Info(fmt.Sprintf("Router \"%v\" is disabled, skipping", routerConfig.ID)) + continue + } + + tel.L().Debug("Init router", zap.String("routerID", routerConfig.ID)) + + router, err := NewLangRouter(&c[idx], tel) + if err != nil { + errs = multierr.Append(errs, err) + continue + } + + langRouters = append(langRouters, router) + } + + if errs != nil { + return nil, errs + } + + return langRouters, nil +} diff --git a/pkg/routers/lang/config_test.go b/pkg/routers/lang/config_test.go index fde65344..79cbb210 100644 --- a/pkg/routers/lang/config_test.go +++ b/pkg/routers/lang/config_test.go @@ -1,13 +1,14 @@ package lang import ( + "testing" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/providers/cohere" "github.com/EinStack/glide/pkg/resiliency/health" "github.com/EinStack/glide/pkg/resiliency/retry" routers2 "github.com/EinStack/glide/pkg/routers" "github.com/EinStack/glide/pkg/telemetry" - "testing" "github.com/EinStack/glide/pkg/routers/routing" @@ -24,7 +25,7 @@ func TestRouterConfig_BuildModels(t *testing.T) { defaultParams := openai.DefaultParams() cfg := routers2.Config{ - LanguageRouters: []LangRouterConfig{ + LanguageRouters: []RouterConfig{ { ID: "first_router", Enabled: true, diff --git a/pkg/routers/lang/router.go b/pkg/routers/lang/router.go index 368ae260..caf98a4c 100644 --- a/pkg/routers/lang/router.go +++ b/pkg/routers/lang/router.go @@ -3,28 +3,24 @@ package lang import ( "context" "errors" - "github.com/EinStack/glide/pkg/resiliency/retry" - - "go.uber.org/zap" - - "github.com/EinStack/glide/pkg/providers" - - "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/routers/routing" "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/resiliency/retry" + "github.com/EinStack/glide/pkg/routers/routing" + "github.com/EinStack/glide/pkg/telemetry" + "go.uber.org/zap" ) var ErrNoModels = errors.New("no models configured for router") type RouterID = string -type LangRouter struct { +type Router struct { routerID RouterID - Config *LangRouterConfig - chatModels []*providers.LanguageModel - chatStreamModels []*providers.LanguageModel + Config *RouterConfig + chatModels []*models.LanguageModel + chatStreamModels []*models.LanguageModel chatRouting routing.LangModelRouting chatStreamRouting routing.LangModelRouting retry *retry.ExpRetry @@ -32,7 +28,7 @@ type LangRouter struct { logger *zap.Logger } -func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter, error) { +func NewLangRouter(cfg *RouterConfig, tel *telemetry.Telemetry) (*Router, error) { chatModels, chatStreamModels, err := cfg.BuildModels(tel) if err != nil { return nil, err @@ -43,7 +39,7 @@ func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter return nil, err } - router := &LangRouter{ + router := &Router{ routerID: cfg.ID, Config: cfg, chatModels: chatModels, @@ -58,11 +54,11 @@ func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter return router, err } -func (r *LangRouter) ID() RouterID { +func (r *Router) ID() RouterID { return r.routerID } -func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error) { +func (r *Router) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error) { if len(r.chatModels) == 0 { return nil, ErrNoModels } @@ -80,7 +76,7 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem break } - langModel := model.(providers.LangModel) + langModel := model.(models.LangModel) chatParams := req.Params(langModel.ID(), langModel.ModelName()) @@ -118,7 +114,7 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem return nil, &schemas.ErrNoModelAvailable } -func (r *LangRouter) ChatStream( +func (r *Router) ChatStream( ctx context.Context, req *schemas.ChatStreamRequest, respC chan<- *schemas.ChatStreamMessage, @@ -150,7 +146,7 @@ func (r *LangRouter) ChatStream( break } - langModel := model.(providers.LangModel) + langModel := model.(models.LangModel) chatParams := req.Params(langModel.ID(), langModel.ModelName()) modelRespC, err := langModel.ChatStream(ctx, chatParams) @@ -238,7 +234,3 @@ func (r *LangRouter) ChatStream( &schemas.ReasonError, ) } - -func (r *LangRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) { - -} diff --git a/pkg/routers/lang/router_test.go b/pkg/routers/lang/router_test.go index 8641cb0a..41515958 100644 --- a/pkg/routers/lang/router_test.go +++ b/pkg/routers/lang/router_test.go @@ -2,11 +2,14 @@ package lang import ( "context" + "testing" + "time" + + "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/resiliency/health" "github.com/EinStack/glide/pkg/resiliency/retry" - "testing" - "time" "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/providers" @@ -21,15 +24,15 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) { budget := health.NewErrorBudget(3, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*models.LanguageModel{ + models.NewLangModel( "first", ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}), budget, *latConfig, 1, ), - providers.NewLangModel( + models.NewLangModel( "second", ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}}), budget, @@ -68,22 +71,22 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) { func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { budget := health.NewErrorBudget(1, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*models.LanguageModel{ + models.NewLangModel( "first", ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "3"}}), budget, *latConfig, 1, ), - providers.NewLangModel( + models.NewLangModel( "second", ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "4"}}), budget, *latConfig, 1, ), - providers.NewLangModel( + models.NewLangModel( "third", ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}), budget, @@ -126,15 +129,15 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { budget := health.NewErrorBudget(1, health.MILLI) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*models.LanguageModel{ + models.NewLangModel( "first", ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "2"}}), budget, *latConfig, 1, ), - providers.NewLangModel( + models.NewLangModel( "second", ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "1"}}), budget, @@ -170,15 +173,15 @@ func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) { budget := health.NewErrorBudget(1, health.MIN) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*models.LanguageModel{ + models.NewLangModel( "first", ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: clients.ErrProviderUnavailable}, {Msg: "3"}}), budget, *latConfig, 1, ), - providers.NewLangModel( + models.NewLangModel( "second", ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}), budget, @@ -216,15 +219,15 @@ func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) { func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) { budget := health.NewErrorBudget(1, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*models.LanguageModel{ + models.NewLangModel( "first", ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), budget, *latConfig, 1, ), - providers.NewLangModel( + models.NewLangModel( "second", ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), budget, @@ -259,8 +262,8 @@ func TestLangRouter_ChatStream(t *testing.T) { budget := health.NewErrorBudget(3, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*models.LanguageModel{ + models.NewLangModel( "first", ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ ptesting.NewRespStreamMock(&[]ptesting.RespMock{ @@ -275,7 +278,7 @@ func TestLangRouter_ChatStream(t *testing.T) { *latConfig, 1, ), - providers.NewLangModel( + models.NewLangModel( "second", ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ ptesting.NewRespStreamMock(&[]ptesting.RespMock{ @@ -335,15 +338,15 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) { budget := health.NewErrorBudget(3, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*models.LanguageModel{ + models.NewLangModel( "first", ptesting.NewStreamProviderMock(nil, nil), budget, *latConfig, 1, ), - providers.NewLangModel( + models.NewLangModel( "second", ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ ptesting.NewRespStreamMock( @@ -405,8 +408,8 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { budget := health.NewErrorBudget(1, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*models.LanguageModel{ + models.NewLangModel( "first", ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ ptesting.NewRespStreamMock(&[]ptesting.RespMock{ @@ -417,7 +420,7 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { *latConfig, 1, ), - providers.NewLangModel( + models.NewLangModel( "second", ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ ptesting.NewRespStreamMock(&[]ptesting.RespMock{ diff --git a/pkg/routers/manager/config.go b/pkg/routers/manager/config.go new file mode 100644 index 00000000..aaaeac09 --- /dev/null +++ b/pkg/routers/manager/config.go @@ -0,0 +1,9 @@ +package manager + +import "github.com/EinStack/glide/pkg/routers/lang" + +// Config defines a config for a set of supported router types +type Config struct { + LanguageRouters lang.RoutersConfig `yaml:"language" validate:"required,dive"` // the list of language routers + // EmbeddingRouters []EmbeddingRouterConfig `yaml:"embedding" validate:"required,dive"` +} diff --git a/pkg/routers/manager.go b/pkg/routers/manager/manager.go similarity index 71% rename from pkg/routers/manager.go rename to pkg/routers/manager/manager.go index 7516e03f..add72012 100644 --- a/pkg/routers/manager.go +++ b/pkg/routers/manager/manager.go @@ -1,4 +1,4 @@ -package routers +package manager import ( "github.com/EinStack/glide/pkg/api/schemas" @@ -9,18 +9,18 @@ import ( type RouterManager struct { Config *Config tel *telemetry.Telemetry - langRouterMap *map[string]*lang.LangRouter - langRouters []*lang.LangRouter + langRouterMap *map[string]*lang.Router + langRouters []*lang.Router } // NewManager creates a new instance of Router Manager that creates, holds and returns all routers func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) { - langRouters, err := cfg.BuildLangRouters(tel) + langRouters, err := cfg.LanguageRouters.Build(tel) if err != nil { return nil, err } - langRouterMap := make(map[string]*lang.LangRouter, len(langRouters)) + langRouterMap := make(map[string]*lang.Router, len(langRouters)) for _, router := range langRouters { langRouterMap[router.ID()] = router @@ -36,12 +36,12 @@ func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) { return &manager, err } -func (r *RouterManager) GetLangRouters() []*lang.LangRouter { +func (r *RouterManager) GetLangRouters() []*lang.Router { return r.langRouters } // GetLangRouter returns a router by type and ID -func (r *RouterManager) GetLangRouter(routerID string) (*lang.LangRouter, error) { +func (r *RouterManager) GetLangRouter(routerID string) (*lang.Router, error) { if router, found := (*r.langRouterMap)[routerID]; found { return router, nil } diff --git a/pkg/routers/routing/least_latency_test.go b/pkg/routers/routing/least_latency_test.go index 0ed9c51b..d65ee0d2 100644 --- a/pkg/routers/routing/least_latency_test.go +++ b/pkg/routers/routing/least_latency_test.go @@ -150,7 +150,7 @@ func TestLeastLatencyRouting_NoHealthyModels(t *testing.T) { models = append(models, ptesting.NewLangModelMock(strconv.Itoa(idx), false, latency, 1)) } - routing := NewLeastLatencyRouting(providers.ChatLatency, models) + routing := NewLeastLatencyRouting(models.ChatLatency, models) iterator := routing.Iterator() _, err := iterator.Next() From c924140ec6b1a4169104fd1396e695ded2370cc2 Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Fri, 5 Jul 2024 15:16:43 +0300 Subject: [PATCH 04/18] #67: Fixed broken code after the first wave of refactoring/restructuring --- config.dev.yaml | 5 +- config.sample.yaml | 9 + pkg/models/config.go | 17 +- pkg/providers/config.go | 21 +- pkg/providers/openai/embed.go | 2 +- pkg/providers/provider.go | 3 + pkg/providers/testing/models.go | 6 +- pkg/routers/config.go | 11 + pkg/routers/embed/config.go | 3 +- pkg/routers/embed/router.go | 24 +- pkg/routers/lang/config.go | 37 ++- pkg/routers/lang/config_test.go | 216 ++++++++---------- pkg/routers/lang/router_test.go | 87 ++++--- pkg/routers/routing/least_latency.go | 14 +- pkg/routers/routing/least_latency_test.go | 16 +- pkg/routers/routing/priority.go | 16 +- pkg/routers/routing/priority_test.go | 14 +- pkg/routers/routing/round_robin.go | 8 +- pkg/routers/routing/round_robin_test.go | 14 +- pkg/routers/routing/strategies.go | 4 +- pkg/routers/routing/weighted_round_robin.go | 8 +- .../routing/weighted_round_robin_test.go | 14 +- 22 files changed, 286 insertions(+), 263 deletions(-) diff --git a/config.dev.yaml b/config.dev.yaml index 80c77a5b..8bd2af4a 100644 --- a/config.dev.yaml +++ b/config.dev.yaml @@ -8,5 +8,6 @@ routers: - id: default models: - id: openai - openai: - api_key: "${env:OPENAI_API_KEY}" + provider: + openai: + api_key: "${env:OPENAI_API_KEY}" diff --git a/config.sample.yaml b/config.sample.yaml index 3ce72055..7118a6f5 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -6,3 +6,12 @@ telemetry: #api: # http: # ... + +routers: + language: + - id: default + models: + - id: openai + provider: + openai: + api_key: "${env:OPENAI_API_KEY}" diff --git a/pkg/models/config.go b/pkg/models/config.go index 5289ea03..97c93443 100644 --- a/pkg/models/config.go +++ b/pkg/models/config.go @@ -3,13 +3,16 @@ package models import ( "fmt" + "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/resiliency/health" "github.com/EinStack/glide/pkg/routers/latency" "github.com/EinStack/glide/pkg/telemetry" ) -type Config[P any] struct { +// Config defines an extra configuration for a model wrapper around a provider +type Config[P providers.ProviderFactory] struct { ID string `yaml:"id" json:"id" validate:"required"` // Model instance ID (unique in scope of the router) Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is the model enabled? ErrorBudget *health.ErrorBudget `yaml:"error_budget" json:"error_budget" swaggertype:"primitive,string"` @@ -20,7 +23,15 @@ type Config[P any] struct { Provider P `yaml:"provider" json:"provider"` } -func DefaultConfig[P any]() Config[P] { +func NewConfig[P providers.ProviderFactory](ID string) *Config[P] { + config := DefaultConfig[P]() + + config.ID = ID + + return &config +} + +func DefaultConfig[P providers.ProviderFactory]() Config[P] { return Config[P]{ Enabled: true, Client: clients.DefaultClientConfig(), @@ -30,7 +41,7 @@ func DefaultConfig[P any]() Config[P] { } } -func (c *Config) ToModel(tel *telemetry.Telemetry) (*LanguageModel, error) { +func (c *Config[P]) ToModel(tel *telemetry.Telemetry) (*LanguageModel, error) { client, err := c.Provider.ToClient(tel, c.Client) if err != nil { return nil, fmt.Errorf("error initializing client: %w", err) diff --git a/pkg/providers/config.go b/pkg/providers/config.go index 8e1c66c2..d4145611 100644 --- a/pkg/providers/config.go +++ b/pkg/providers/config.go @@ -16,7 +16,16 @@ import ( "github.com/EinStack/glide/pkg/telemetry" ) -var ErrProviderNotFound = errors.New("provider not found") +// TODO: ProviderFactory should be more generic, not tied to LangProviders + +var ErrNoProviderConfigured = errors.New("exactly one provider must be configured, none is configured") + +type ProviderFactory interface { + ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) +} + +// TODO: LangProviders should be decoupled and +// represented as a registry where providers can add their factories dynamically type LangProviders struct { // Add other providers like @@ -29,9 +38,11 @@ type LangProviders struct { Ollama *ollama.Config `yaml:"ollama,omitempty" json:"ollama,omitempty"` } +var _ ProviderFactory = (*LangProviders)(nil) + // ToClient initializes the language model client based on the provided configuration. // It takes a telemetry object as input and returns a LangModelProvider and an error. -func (c *LangProviders) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) { +func (c LangProviders) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) { switch { case c.OpenAI != nil: return openai.NewClient(c.OpenAI, clientConfig, tel) @@ -83,7 +94,7 @@ func (c *LangProviders) validateOneProvider() error { // check other providers here if providersConfigured == 0 { - return fmt.Errorf("exactly one provider must be configured, none is configured") + return ErrNoProviderConfigured } if providersConfigured > 1 { @@ -97,9 +108,7 @@ func (c *LangProviders) validateOneProvider() error { } func (c *LangProviders) UnmarshalYAML(unmarshal func(interface{}) error) error { - *c = DefaultConfig() - - type plain LangModelConfig // to avoid recursion + type plain LangProviders // to avoid recursion if err := unmarshal((*plain)(c)); err != nil { return err diff --git a/pkg/providers/openai/embed.go b/pkg/providers/openai/embed.go index 48e69328..69f9aa27 100644 --- a/pkg/providers/openai/embed.go +++ b/pkg/providers/openai/embed.go @@ -7,7 +7,7 @@ import ( ) // Embed sends an embedding request to the specified OpenAI model. -func (c *Client) Embed(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Embed(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResponse, error) { // TODO: implement return nil, nil } diff --git a/pkg/providers/provider.go b/pkg/providers/provider.go index 2341ddc9..7fdefc25 100644 --- a/pkg/providers/provider.go +++ b/pkg/providers/provider.go @@ -2,11 +2,14 @@ package providers import ( "context" + "errors" "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/clients" ) +var ErrProviderNotFound = errors.New("provider not found") + // ModelProvider exposes provider context type ModelProvider interface { Provider() string diff --git a/pkg/providers/testing/models.go b/pkg/providers/testing/models.go index d4ac3840..57500d21 100644 --- a/pkg/providers/testing/models.go +++ b/pkg/providers/testing/models.go @@ -4,10 +4,8 @@ import ( "time" "github.com/EinStack/glide/pkg/config/fields" - + "github.com/EinStack/glide/pkg/models" "github.com/EinStack/glide/pkg/routers/latency" - - "github.com/EinStack/glide/pkg/providers" ) // LangModelMock @@ -55,6 +53,6 @@ func (m LangModelMock) Weight() int { return m.weight } -func ChatMockLatency(model providers.Model) *latency.MovingAverage { +func ChatMockLatency(model models.Model) *latency.MovingAverage { return model.(LangModelMock).chatLatency } diff --git a/pkg/routers/config.go b/pkg/routers/config.go index 99cef09f..a3c8f69a 100644 --- a/pkg/routers/config.go +++ b/pkg/routers/config.go @@ -5,9 +5,20 @@ import ( "github.com/EinStack/glide/pkg/routers/routing" ) +// TODO: how to specify other backoff strategies? +// TODO: Had to keep RoutingStrategy because of https://github.com/swaggo/swag/issues/1738 + type RouterConfig struct { ID string `yaml:"id" json:"routers" validate:"required"` // Unique router ID Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is router enabled? Retry *retry.ExpRetryConfig `yaml:"retry" json:"retry" validate:"required"` // retry when no healthy model is available to router RoutingStrategy routing.Strategy `yaml:"strategy" json:"strategy" swaggertype:"primitive,string" validate:"required"` // strategy on picking the next model to serve the request } + +func DefaultConfig() RouterConfig { + return RouterConfig{ + Enabled: true, + RoutingStrategy: routing.Priority, + Retry: retry.DefaultExpRetryConfig(), + } +} diff --git a/pkg/routers/embed/config.go b/pkg/routers/embed/config.go index 63894e82..49f4821b 100644 --- a/pkg/routers/embed/config.go +++ b/pkg/routers/embed/config.go @@ -1,11 +1,10 @@ package embed import ( - "github.com/EinStack/glide/pkg/providers" "github.com/EinStack/glide/pkg/routers" ) type EmbeddingRouterConfig struct { routers.RouterConfig - Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests + // Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests } diff --git a/pkg/routers/embed/router.go b/pkg/routers/embed/router.go index 501f63e6..9068537b 100644 --- a/pkg/routers/embed/router.go +++ b/pkg/routers/embed/router.go @@ -1,22 +1,12 @@ package embed -import ( - "context" - - "github.com/EinStack/glide/pkg/api/schemas" - "github.com/EinStack/glide/pkg/resiliency/retry" - "github.com/EinStack/glide/pkg/routers/lang" - "github.com/EinStack/glide/pkg/telemetry" - "go.uber.org/zap" -) - type EmbeddingRouter struct { - routerID lang.RouterID - Config *LangRouterConfig - retry *retry.ExpRetry - tel *telemetry.Telemetry - logger *zap.Logger + // routerID lang.RouterID + // Config *LangRouterConfig + // retry *retry.ExpRetry + // tel *telemetry.Telemetry + // logger *zap.Logger } -func (r *EmbeddingRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) { -} +//func (r *EmbeddingRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) { +//} diff --git a/pkg/routers/lang/config.go b/pkg/routers/lang/config.go index f68d1c8b..0be9ef06 100644 --- a/pkg/routers/lang/config.go +++ b/pkg/routers/lang/config.go @@ -15,12 +15,37 @@ import ( "go.uber.org/zap" ) -// TODO: how to specify other backoff strategies? -// TODO: Had to keep RoutingStrategy because of https://github.com/swaggo/swag/issues/1738 +type ( + ModelConfig = models.Config[providers.LangProviders] + ModelPoolConfig = []ModelConfig +) + // RouterConfig type RouterConfig struct { routers.RouterConfig - Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests + Models ModelPoolConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests +} + +type RouterConfigOption = func(*RouterConfig) + +func WithModels(models ModelPoolConfig) RouterConfigOption { + return func(c *RouterConfig) { + c.Models = models + } +} + +func NewRouterConfig(RouterID string, opt ...RouterConfigOption) *RouterConfig { + config := &RouterConfig{ + RouterConfig: routers.DefaultConfig(), + } + + config.ID = RouterID + + for _, o := range opt { + o(config) + } + + return config } // BuildModels creates LanguageModel slice out of the given config @@ -165,11 +190,7 @@ func (c *RouterConfig) BuildRouting( func DefaultRouterConfig() *RouterConfig { return &RouterConfig{ - RouterConfig: routers.RouterConfig{ - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - }, + RouterConfig: routers.DefaultConfig(), } } diff --git a/pkg/routers/lang/config_test.go b/pkg/routers/lang/config_test.go index 79cbb210..975cdcb6 100644 --- a/pkg/routers/lang/config_test.go +++ b/pkg/routers/lang/config_test.go @@ -4,70 +4,59 @@ import ( "testing" "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/providers" "github.com/EinStack/glide/pkg/providers/cohere" + "github.com/EinStack/glide/pkg/providers/openai" "github.com/EinStack/glide/pkg/resiliency/health" - "github.com/EinStack/glide/pkg/resiliency/retry" - routers2 "github.com/EinStack/glide/pkg/routers" - "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/routers/routing" - "github.com/EinStack/glide/pkg/routers/latency" - - "github.com/EinStack/glide/pkg/providers/openai" - - "github.com/EinStack/glide/pkg/providers" - + "github.com/EinStack/glide/pkg/routers/routing" + "github.com/EinStack/glide/pkg/telemetry" "github.com/stretchr/testify/require" ) func TestRouterConfig_BuildModels(t *testing.T) { defaultParams := openai.DefaultParams() - cfg := routers2.Config{ - LanguageRouters: []RouterConfig{ - { - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), + cfg := RoutersConfig{ + *NewRouterConfig( + "first_router", + WithModels(ModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: providers.LangProviders{ OpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, }, }, - }, - { - ID: "second_router", - Enabled: true, - RoutingStrategy: routing.LeastLatency, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), + }), + ), + *NewRouterConfig( + "second_router", + WithModels(ModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: providers.LangProviders{ OpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, }, }, - }, - }, + }), + ), } - routers, err := cfg.BuildLangRouters(telemetry.NewTelemetryMock()) + routers, err := cfg.Build(telemetry.NewTelemetryMock()) require.NoError(t, err) require.Len(t, routers, 2) @@ -82,21 +71,20 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { openAIParams := openai.DefaultParams() cohereParams := cohere.DefaultParams() - cfg := LangRouterConfig{ - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ + cfg := NewRouterConfig( + "first_router", + WithModels(ModelPoolConfig{ { ID: "first_model", Enabled: true, Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - OpenAI: &openai.Config{ - APIKey: "ABC", - DefaultParams: &openAIParams, + Provider: providers.LangProviders{ + OpenAI: &openai.Config{ + APIKey: "ABC", + DefaultParams: &openAIParams, + }, }, }, { @@ -105,13 +93,15 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Cohere: &cohere.Config{ - APIKey: "ABC", - DefaultParams: &cohereParams, + Provider: providers.LangProviders{ + Cohere: &cohere.Config{ + APIKey: "ABC", + DefaultParams: &cohereParams, + }, }, }, - }, - } + }), + ) chatModels, streamChatModels, err := cfg.BuildModels(tel) @@ -125,108 +115,98 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { tests := []struct { name string - config routers2.Config + config RoutersConfig }{ { "duplicated router IDs", - routers2.Config{ - LanguageRouters: []LangRouterConfig{ - { - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), + RoutersConfig{ + *NewRouterConfig( + "first_router", + WithModels(ModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: providers.LangProviders{ OpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, }, }, - }, - { - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.LeastLatency, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), + }), + ), + *NewRouterConfig( + "first_router", + WithModels(ModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: providers.LangProviders{ OpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, }, }, - }, - }, + }), + ), }, }, { "duplicated model IDs", - routers2.Config{ - LanguageRouters: []LangRouterConfig{ - { - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), + RoutersConfig{ + *NewRouterConfig( + "first_router", + WithModels(ModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: providers.LangProviders{ OpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, }, - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), + }, + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: providers.LangProviders{ OpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, }, }, - }, - }, + }), + ), }, }, { "no models", - routers2.Config{ - LanguageRouters: []LangRouterConfig{ - { - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{}, - }, - }, + RoutersConfig{ + *NewRouterConfig( + "first_router", + WithModels(ModelPoolConfig{}), + ), }, }, } for _, test := range tests { - _, err := test.config.BuildLangRouters(telemetry.NewTelemetryMock()) + _, err := test.config.Build(telemetry.NewTelemetryMock()) require.Error(t, err) } diff --git a/pkg/routers/lang/router_test.go b/pkg/routers/lang/router_test.go index 41515958..087e2a71 100644 --- a/pkg/routers/lang/router_test.go +++ b/pkg/routers/lang/router_test.go @@ -12,7 +12,6 @@ import ( "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/api/schemas" - "github.com/EinStack/glide/pkg/providers" ptesting "github.com/EinStack/glide/pkg/providers/testing" "github.com/EinStack/glide/pkg/routers/latency" "github.com/EinStack/glide/pkg/routers/routing" @@ -41,16 +40,15 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } - router := LangRouter{ + router := Router{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, tel: telemetry.NewTelemetryMock(), @@ -95,19 +93,18 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } expectedModels := []string{"third", "third"} - router := LangRouter{ + router := Router{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), - chatRouting: routing.NewPriority(models), - chatStreamRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), + chatStreamRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, tel: telemetry.NewTelemetryMock(), @@ -146,17 +143,16 @@ func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } - router := LangRouter{ + router := Router{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil), - chatRouting: routing.NewPriority(models), - chatStreamRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), + chatStreamRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, tel: telemetry.NewTelemetryMock(), @@ -190,19 +186,18 @@ func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } - router := LangRouter{ + router := Router{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, - chatStreamRouting: routing.NewPriority(models), + chatStreamRouting: routing.NewPriority(modelPool), tel: telemetry.NewTelemetryMock(), logger: telemetry.NewLoggerMock(), } @@ -236,19 +231,18 @@ func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } - router := LangRouter{ + router := Router{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, - chatStreamRouting: routing.NewPriority(models), + chatStreamRouting: routing.NewPriority(modelPool), tel: telemetry.NewTelemetryMock(), logger: telemetry.NewLoggerMock(), } @@ -293,18 +287,17 @@ func TestLangRouter_ChatStream(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } - router := LangRouter{ + router := Router{ routerID: "test_stream_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, - chatStreamRouting: routing.NewPriority(models), + chatStreamRouting: routing.NewPriority(modelPool), chatStreamModels: langModels, tel: telemetry.NewTelemetryMock(), logger: telemetry.NewLoggerMock(), @@ -363,18 +356,17 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } - router := LangRouter{ + router := Router{ routerID: "test_stream_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, - chatStreamRouting: routing.NewPriority(models), + chatStreamRouting: routing.NewPriority(modelPool), chatStreamModels: langModels, tel: telemetry.NewTelemetryMock(), logger: telemetry.NewLoggerMock(), @@ -433,19 +425,18 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } - router := LangRouter{ + router := Router{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, - chatStreamRouting: routing.NewPriority(models), + chatStreamRouting: routing.NewPriority(modelPool), tel: telemetry.NewTelemetryMock(), logger: telemetry.NewLoggerMock(), } diff --git a/pkg/routers/routing/least_latency.go b/pkg/routers/routing/least_latency.go index 015c044e..e6c56a6f 100644 --- a/pkg/routers/routing/least_latency.go +++ b/pkg/routers/routing/least_latency.go @@ -5,9 +5,9 @@ import ( "sync/atomic" "time" - "github.com/EinStack/glide/pkg/routers/latency" + "github.com/EinStack/glide/pkg/models" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/routers/latency" ) const ( @@ -15,16 +15,16 @@ const ( ) // LatencyGetter defines where to find latency for the specific model action -type LatencyGetter = func(model providers.Model) *latency.MovingAverage +type LatencyGetter = func(model models.Model) *latency.MovingAverage // ModelSchedule defines latency update schedule for models type ModelSchedule struct { mu sync.RWMutex - model providers.Model + model models.Model expireAt time.Time } -func NewSchedule(model providers.Model) *ModelSchedule { +func NewSchedule(model models.Model) *ModelSchedule { schedule := &ModelSchedule{ model: model, } @@ -67,7 +67,7 @@ type LeastLatencyRouting struct { schedules []*ModelSchedule } -func NewLeastLatencyRouting(latencyGetter LatencyGetter, models []providers.Model) *LeastLatencyRouting { +func NewLeastLatencyRouting(latencyGetter LatencyGetter, models []models.Model) *LeastLatencyRouting { schedules := make([]*ModelSchedule, 0, len(models)) for _, model := range models { @@ -95,7 +95,7 @@ func (r *LeastLatencyRouting) Iterator() LangModelIterator { // other model latencies that might have improved over time). // For that, we introduced expiration time after which the model receives a request // even if it was not the fastest to respond -func (r *LeastLatencyRouting) Next() (providers.Model, error) { //nolint:cyclop +func (r *LeastLatencyRouting) Next() (models.Model, error) { //nolint:cyclop coldSchedules := r.getColdModelSchedules() if len(coldSchedules) > 0 { diff --git a/pkg/routers/routing/least_latency_test.go b/pkg/routers/routing/least_latency_test.go index d65ee0d2..523b0790 100644 --- a/pkg/routers/routing/least_latency_test.go +++ b/pkg/routers/routing/least_latency_test.go @@ -5,9 +5,9 @@ import ( "testing" "time" - ptesting "github.com/EinStack/glide/pkg/providers/testing" + "github.com/EinStack/glide/pkg/models" - "github.com/EinStack/glide/pkg/providers" + ptesting "github.com/EinStack/glide/pkg/providers/testing" "github.com/stretchr/testify/require" ) @@ -33,13 +33,13 @@ func TestLeastLatencyRouting_Warmup(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - models := make([]providers.Model, 0, len(tc.models)) + modelPool := make([]models.Model, 0, len(tc.models)) for _, model := range tc.models { - models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, model.latency, 1)) + modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, model.latency, 1)) } - routing := NewLeastLatencyRouting(ptesting.ChatMockLatency, models) + routing := NewLeastLatencyRouting(ptesting.ChatMockLatency, modelPool) iterator := routing.Iterator() // loop three times over the whole pool to check if we return back to the begging of the list @@ -144,13 +144,13 @@ func TestLeastLatencyRouting_NoHealthyModels(t *testing.T) { for name, latencies := range tests { t.Run(name, func(t *testing.T) { - models := make([]providers.Model, 0, len(latencies)) + modelPool := make([]models.Model, 0, len(latencies)) for idx, latency := range latencies { - models = append(models, ptesting.NewLangModelMock(strconv.Itoa(idx), false, latency, 1)) + modelPool = append(modelPool, ptesting.NewLangModelMock(strconv.Itoa(idx), false, latency, 1)) } - routing := NewLeastLatencyRouting(models.ChatLatency, models) + routing := NewLeastLatencyRouting(models.ChatLatency, modelPool) iterator := routing.Iterator() _, err := iterator.Next() diff --git a/pkg/routers/routing/priority.go b/pkg/routers/routing/priority.go index f895458c..04d4d94e 100644 --- a/pkg/routers/routing/priority.go +++ b/pkg/routers/routing/priority.go @@ -3,7 +3,7 @@ package routing import ( "sync/atomic" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/models" ) const ( @@ -15,10 +15,10 @@ const ( // Priority of models are defined as position of the model on the list // (e.g. the first model definition has the highest priority, then the second model definition and so on) type PriorityRouting struct { - models []providers.Model + models []models.Model } -func NewPriority(models []providers.Model) *PriorityRouting { +func NewPriority(models []models.Model) *PriorityRouting { return &PriorityRouting{ models: models, } @@ -35,14 +35,14 @@ func (r *PriorityRouting) Iterator() LangModelIterator { type PriorityIterator struct { idx *atomic.Uint64 - models []providers.Model + models []models.Model } -func (r PriorityIterator) Next() (providers.Model, error) { - models := r.models +func (r PriorityIterator) Next() (models.Model, error) { + modelPool := r.models - for idx := int(r.idx.Load()); idx < len(models); idx = int(r.idx.Add(1)) { - model := models[idx] + for idx := int(r.idx.Load()); idx < len(modelPool); idx = int(r.idx.Add(1)) { + model := modelPool[idx] if !model.Healthy() { continue diff --git a/pkg/routers/routing/priority_test.go b/pkg/routers/routing/priority_test.go index cee98c60..eb090c76 100644 --- a/pkg/routers/routing/priority_test.go +++ b/pkg/routers/routing/priority_test.go @@ -3,9 +3,9 @@ package routing import ( "testing" - ptesting "github.com/EinStack/glide/pkg/providers/testing" + "github.com/EinStack/glide/pkg/models" - "github.com/EinStack/glide/pkg/providers" + ptesting "github.com/EinStack/glide/pkg/providers/testing" "github.com/stretchr/testify/require" ) @@ -29,13 +29,13 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - models := make([]providers.Model, 0, len(tc.models)) + modelPool := make([]models.Model, 0, len(tc.models)) for _, model := range tc.models { - models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1)) + modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1)) } - routing := NewPriority(models) + routing := NewPriority(modelPool) iterator := routing.Iterator() // loop three times over the whole pool to check if we return back to the begging of the list @@ -49,13 +49,13 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) { } func TestPriorityRouting_NoHealthyModels(t *testing.T) { - models := []providers.Model{ + modelPool := []models.Model{ ptesting.NewLangModelMock("first", false, 0, 1), ptesting.NewLangModelMock("second", false, 0, 1), ptesting.NewLangModelMock("third", false, 0, 1), } - routing := NewPriority(models) + routing := NewPriority(modelPool) iterator := routing.Iterator() _, err := iterator.Next() diff --git a/pkg/routers/routing/round_robin.go b/pkg/routers/routing/round_robin.go index e5a0f927..abd2ff96 100644 --- a/pkg/routers/routing/round_robin.go +++ b/pkg/routers/routing/round_robin.go @@ -3,7 +3,7 @@ package routing import ( "sync/atomic" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/models" ) const ( @@ -13,10 +13,10 @@ const ( // RoundRobinRouting routes request to the next model in the list in cycle type RoundRobinRouting struct { idx atomic.Uint64 - models []providers.Model + models []models.Model } -func NewRoundRobinRouting(models []providers.Model) *RoundRobinRouting { +func NewRoundRobinRouting(models []models.Model) *RoundRobinRouting { return &RoundRobinRouting{ models: models, } @@ -26,7 +26,7 @@ func (r *RoundRobinRouting) Iterator() LangModelIterator { return r } -func (r *RoundRobinRouting) Next() (providers.Model, error) { +func (r *RoundRobinRouting) Next() (models.Model, error) { modelLen := len(r.models) // in order to avoid infinite loop in case of no healthy model is available, diff --git a/pkg/routers/routing/round_robin_test.go b/pkg/routers/routing/round_robin_test.go index fc34ec48..2a6e579b 100644 --- a/pkg/routers/routing/round_robin_test.go +++ b/pkg/routers/routing/round_robin_test.go @@ -3,9 +3,9 @@ package routing import ( "testing" - ptesting "github.com/EinStack/glide/pkg/providers/testing" + "github.com/EinStack/glide/pkg/models" - "github.com/EinStack/glide/pkg/providers" + ptesting "github.com/EinStack/glide/pkg/providers/testing" "github.com/stretchr/testify/require" ) @@ -30,13 +30,13 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - models := make([]providers.Model, 0, len(tc.models)) + modelPool := make([]models.Model, 0, len(tc.models)) for _, model := range tc.models { - models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1)) + modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1)) } - routing := NewRoundRobinRouting(models) + routing := NewRoundRobinRouting(modelPool) iterator := routing.Iterator() for i := 0; i < 3; i++ { @@ -52,13 +52,13 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) { } func TestRoundRobinRouting_NoHealthyModels(t *testing.T) { - models := []providers.Model{ + modelPool := []models.Model{ ptesting.NewLangModelMock("first", false, 0, 1), ptesting.NewLangModelMock("second", false, 0, 1), ptesting.NewLangModelMock("third", false, 0, 1), } - routing := NewRoundRobinRouting(models) + routing := NewRoundRobinRouting(modelPool) iterator := routing.Iterator() _, err := iterator.Next() diff --git a/pkg/routers/routing/strategies.go b/pkg/routers/routing/strategies.go index 56f03676..960702a4 100644 --- a/pkg/routers/routing/strategies.go +++ b/pkg/routers/routing/strategies.go @@ -3,7 +3,7 @@ package routing import ( "errors" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/models" ) var ErrNoHealthyModels = errors.New("no healthy models found") @@ -16,5 +16,5 @@ type LangModelRouting interface { } type LangModelIterator interface { - Next() (providers.Model, error) + Next() (models.Model, error) } diff --git a/pkg/routers/routing/weighted_round_robin.go b/pkg/routers/routing/weighted_round_robin.go index 2e028408..dfbee414 100644 --- a/pkg/routers/routing/weighted_round_robin.go +++ b/pkg/routers/routing/weighted_round_robin.go @@ -3,7 +3,7 @@ package routing import ( "sync" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/models" ) const ( @@ -11,7 +11,7 @@ const ( ) type Weighter struct { - model providers.Model + model models.Model currentWeight int } @@ -36,7 +36,7 @@ type WRoundRobinRouting struct { weights []*Weighter } -func NewWeightedRoundRobin(models []providers.Model) *WRoundRobinRouting { +func NewWeightedRoundRobin(models []models.Model) *WRoundRobinRouting { weights := make([]*Weighter, 0, len(models)) for _, model := range models { @@ -55,7 +55,7 @@ func (r *WRoundRobinRouting) Iterator() LangModelIterator { return r } -func (r *WRoundRobinRouting) Next() (providers.Model, error) { +func (r *WRoundRobinRouting) Next() (models.Model, error) { r.mu.Lock() defer r.mu.Unlock() diff --git a/pkg/routers/routing/weighted_round_robin_test.go b/pkg/routers/routing/weighted_round_robin_test.go index f4b59bb3..8e4a9ee2 100644 --- a/pkg/routers/routing/weighted_round_robin_test.go +++ b/pkg/routers/routing/weighted_round_robin_test.go @@ -3,9 +3,9 @@ package routing import ( "testing" - ptesting "github.com/EinStack/glide/pkg/providers/testing" + "github.com/EinStack/glide/pkg/models" - "github.com/EinStack/glide/pkg/providers" + ptesting "github.com/EinStack/glide/pkg/providers/testing" "github.com/stretchr/testify/require" ) @@ -116,13 +116,13 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - models := make([]providers.Model, 0, len(tc.models)) + modelPool := make([]models.Model, 0, len(tc.models)) for _, model := range tc.models { - models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 0, model.weight)) + modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, 0, model.weight)) } - routing := NewWeightedRoundRobin(models) + routing := NewWeightedRoundRobin(modelPool) iterator := routing.Iterator() actualDistribution := make(map[string]int, len(tc.models)) @@ -142,13 +142,13 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) { } func TestWRoundRobinRouting_NoHealthyModels(t *testing.T) { - models := []providers.Model{ + modelPool := []models.Model{ ptesting.NewLangModelMock("first", false, 0, 1), ptesting.NewLangModelMock("second", false, 0, 2), ptesting.NewLangModelMock("third", false, 0, 3), } - routing := NewWeightedRoundRobin(models) + routing := NewWeightedRoundRobin(modelPool) iterator := routing.Iterator() _, err := iterator.Next() From 003691cd371cc3c2ae47152ade8f2b3ce2ebcb3e Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Wed, 31 Jul 2024 21:15:07 +0300 Subject: [PATCH 05/18] #67: Experimenting with dynamic lang provider config loading & validation --- pkg/models/config.go | 10 +- pkg/models/lang.go | 7 +- pkg/provider/config.go | 12 + pkg/{providers => provider}/provider.go | 15 +- pkg/providers/config.go | 256 ++++++++++++------ pkg/providers/config_test.go | 29 ++ pkg/providers/config_test.yaml | 9 + pkg/providers/openai/chat.go | 2 +- pkg/providers/openai/chat_stream.go | 2 +- pkg/providers/openai/client.go | 6 +- pkg/providers/openai/config.go | 10 + pkg/providers/openai/errors.go | 4 +- pkg/providers/openai/register.go | 9 + pkg/providers/registry.go | 42 +++ pkg/providers/testing/config.go | 29 ++ pkg/routers/lang/config.go | 4 +- pkg/routers/lang/config_test.go | 12 +- pkg/routers/lang/router_test.go | 2 +- pkg/routers/routing/least_latency_test.go | 3 +- pkg/routers/routing/priority_test.go | 3 +- pkg/routers/routing/round_robin_test.go | 3 +- .../routing/weighted_round_robin_test.go | 3 +- 22 files changed, 360 insertions(+), 112 deletions(-) create mode 100644 pkg/provider/config.go rename pkg/{providers => provider}/provider.go (65%) create mode 100644 pkg/providers/config_test.go create mode 100644 pkg/providers/config_test.yaml create mode 100644 pkg/providers/openai/register.go create mode 100644 pkg/providers/registry.go create mode 100644 pkg/providers/testing/config.go diff --git a/pkg/models/config.go b/pkg/models/config.go index 97c93443..5a28d50a 100644 --- a/pkg/models/config.go +++ b/pkg/models/config.go @@ -2,17 +2,15 @@ package models import ( "fmt" - - "github.com/EinStack/glide/pkg/providers" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/provider" "github.com/EinStack/glide/pkg/resiliency/health" "github.com/EinStack/glide/pkg/routers/latency" "github.com/EinStack/glide/pkg/telemetry" ) // Config defines an extra configuration for a model wrapper around a provider -type Config[P providers.ProviderFactory] struct { +type Config[P provider.ProviderConfig] struct { ID string `yaml:"id" json:"id" validate:"required"` // Model instance ID (unique in scope of the router) Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is the model enabled? ErrorBudget *health.ErrorBudget `yaml:"error_budget" json:"error_budget" swaggertype:"primitive,string"` @@ -23,7 +21,7 @@ type Config[P providers.ProviderFactory] struct { Provider P `yaml:"provider" json:"provider"` } -func NewConfig[P providers.ProviderFactory](ID string) *Config[P] { +func NewConfig[P provider.ProviderConfig](ID string) *Config[P] { config := DefaultConfig[P]() config.ID = ID @@ -31,7 +29,7 @@ func NewConfig[P providers.ProviderFactory](ID string) *Config[P] { return &config } -func DefaultConfig[P providers.ProviderFactory]() Config[P] { +func DefaultConfig[P provider.ProviderConfig]() Config[P] { return Config[P]{ Enabled: true, Client: clients.DefaultClientConfig(), diff --git a/pkg/models/lang.go b/pkg/models/lang.go index 299111b6..f16e6051 100644 --- a/pkg/models/lang.go +++ b/pkg/models/lang.go @@ -2,11 +2,10 @@ package models import ( "context" + "github.com/EinStack/glide/pkg/provider" "io" "time" - "github.com/EinStack/glide/pkg/providers" - "github.com/EinStack/glide/pkg/clients" health2 "github.com/EinStack/glide/pkg/resiliency/health" @@ -32,14 +31,14 @@ type LangModel interface { type LanguageModel struct { modelID string weight int - client providers.LangProvider + client provider.LangProvider healthTracker *health2.Tracker chatLatency *latency.MovingAverage chatStreamLatency *latency.MovingAverage latencyUpdateInterval *fields.Duration } -func NewLangModel(modelID string, client providers.LangProvider, budget *health2.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel { +func NewLangModel(modelID string, client provider.LangProvider, budget *health2.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel { return &LanguageModel{ modelID: modelID, client: client, diff --git a/pkg/provider/config.go b/pkg/provider/config.go new file mode 100644 index 00000000..0424e839 --- /dev/null +++ b/pkg/provider/config.go @@ -0,0 +1,12 @@ +package provider + +import ( + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" +) + +// TODO: ProviderConfig should be more generic, not tied to LangProviders +type ProviderConfig interface { + UnmarshalYAML(unmarshal func(interface{}) error) error + ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) +} diff --git a/pkg/providers/provider.go b/pkg/provider/provider.go similarity index 65% rename from pkg/providers/provider.go rename to pkg/provider/provider.go index 7fdefc25..d2b76419 100644 --- a/pkg/providers/provider.go +++ b/pkg/provider/provider.go @@ -1,4 +1,4 @@ -package providers +package provider import ( "context" @@ -10,9 +10,11 @@ import ( var ErrProviderNotFound = errors.New("provider not found") +type ProviderID = string + // ModelProvider exposes provider context type ModelProvider interface { - Provider() string + Provider() ProviderID ModelName() string } @@ -25,3 +27,12 @@ type LangProvider interface { Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) } + +// EmbeddingProvider defines an interface a provider should fulfill to be able to generate embeddings +type EmbeddingProvider interface { + ModelProvider + + SupportEmbedding() bool + + Embed(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) +} diff --git a/pkg/providers/config.go b/pkg/providers/config.go index d4145611..ca22cb7b 100644 --- a/pkg/providers/config.go +++ b/pkg/providers/config.go @@ -3,116 +3,220 @@ package providers import ( "errors" "fmt" + "github.com/EinStack/glide/pkg/provider" + "github.com/go-playground/validator/v10" + "strings" - "github.com/EinStack/glide/pkg/providers/ollama" + "gopkg.in/yaml.v3" "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/providers/anthropic" - "github.com/EinStack/glide/pkg/providers/azureopenai" - "github.com/EinStack/glide/pkg/providers/bedrock" - "github.com/EinStack/glide/pkg/providers/cohere" - "github.com/EinStack/glide/pkg/providers/octoml" - "github.com/EinStack/glide/pkg/providers/openai" "github.com/EinStack/glide/pkg/telemetry" ) -// TODO: ProviderFactory should be more generic, not tied to LangProviders - var ErrNoProviderConfigured = errors.New("exactly one provider must be configured, none is configured") -type ProviderFactory interface { - ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) -} +var validate *validator.Validate -// TODO: LangProviders should be decoupled and -// represented as a registry where providers can add their factories dynamically - -type LangProviders struct { - // Add other providers like - OpenAI *openai.Config `yaml:"openai,omitempty" json:"openai,omitempty"` - AzureOpenAI *azureopenai.Config `yaml:"azureopenai,omitempty" json:"azureopenai,omitempty"` - Cohere *cohere.Config `yaml:"cohere,omitempty" json:"cohere,omitempty"` - OctoML *octoml.Config `yaml:"octoml,omitempty" json:"octoml,omitempty"` - Anthropic *anthropic.Config `yaml:"anthropic,omitempty" json:"anthropic,omitempty"` - Bedrock *bedrock.Config `yaml:"bedrock,omitempty" json:"bedrock,omitempty"` - Ollama *ollama.Config `yaml:"ollama,omitempty" json:"ollama,omitempty"` +func init() { + validate = validator.New() } -var _ ProviderFactory = (*LangProviders)(nil) +// TODO: rename DynLangProvider to DynLangProviderConfig +type DynLangProvider map[provider.ProviderID]interface{} -// ToClient initializes the language model client based on the provided configuration. -// It takes a telemetry object as input and returns a LangModelProvider and an error. -func (c LangProviders) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) { - switch { - case c.OpenAI != nil: - return openai.NewClient(c.OpenAI, clientConfig, tel) - case c.AzureOpenAI != nil: - return azureopenai.NewClient(c.AzureOpenAI, clientConfig, tel) - case c.Cohere != nil: - return cohere.NewClient(c.Cohere, clientConfig, tel) - case c.OctoML != nil: - return octoml.NewClient(c.OctoML, clientConfig, tel) - case c.Anthropic != nil: - return anthropic.NewClient(c.Anthropic, clientConfig, tel) - case c.Bedrock != nil: - return bedrock.NewClient(c.Bedrock, clientConfig, tel) - default: - return nil, ErrProviderNotFound - } -} +var _ provider.ProviderConfig = (*DynLangProvider)(nil) -func (c *LangProviders) validateOneProvider() error { - providersConfigured := 0 +func (p DynLangProvider) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) { + for providerID, configValue := range p { + if configValue == nil { + continue + } - if c.OpenAI != nil { - providersConfigured++ - } + providerConfig, found := LangRegistry.Get(providerID) - if c.AzureOpenAI != nil { - providersConfigured++ - } + if !found { + return nil, fmt.Errorf( + "provider %s is not supported (available providers: %v)", + providerID, + strings.Join(LangRegistry.Available(), ", "), + ) + } - if c.Cohere != nil { - providersConfigured++ - } + providerConfigUnmarshaller := func(providerConfig interface{}) error { + providerConfigBytes, err := yaml.Marshal(configValue) - if c.OctoML != nil { - providersConfigured++ - } + if err != nil { + return err + } - if c.Anthropic != nil { - providersConfigured++ - } + return yaml.Unmarshal(providerConfigBytes, providerConfig) + } + + err := providerConfig.UnmarshalYAML(providerConfigUnmarshaller) - if c.Bedrock != nil { - providersConfigured++ + if err != nil { + return nil, err + } + + return providerConfig.ToClient(tel, clientConfig) } - if c.Ollama != nil { - providersConfigured++ + return nil, provider.ErrProviderNotFound +} + +// validate ensure there is only one provider configured and it's supported by Glide +func (p DynLangProvider) validate() error { + configuredProviders := make([]provider.ProviderID, 0, len(p)) + + for providerID, config := range p { + if config != nil { + configuredProviders = append(configuredProviders, providerID) + } } - // check other providers here - if providersConfigured == 0 { + if len(configuredProviders) == 0 { return ErrNoProviderConfigured } - if providersConfigured > 1 { + if len(configuredProviders) > 1 { return fmt.Errorf( - "exactly one provider must be configured, but %v are configured", - providersConfigured, + "exactly one provider must be configured, but %v are configured (%v)", + len(configuredProviders), + strings.Join(configuredProviders, ", "), ) } - return nil + providerID := configuredProviders[0] + providerConfig, found := LangRegistry.Get(providerID) + + if !found { + return fmt.Errorf( + "provider %s is not supported (available providers: %v)", + providerID, + strings.Join(LangRegistry.Available(), ", "), + ) + } + + providerConfigUnmarshaller := func(providerConfig interface{}) error { + configValue := p[providerID] + providerConfigBytes, err := yaml.Marshal(configValue) + if err != nil { + return err + } + + err = yaml.Unmarshal(providerConfigBytes, providerConfig) + + if err != nil { + return err + } + + return validate.Struct(providerConfig) + } + + return providerConfig.UnmarshalYAML(providerConfigUnmarshaller) } -func (c *LangProviders) UnmarshalYAML(unmarshal func(interface{}) error) error { - type plain LangProviders // to avoid recursion +func (p *DynLangProvider) UnmarshalYAML(unmarshal func(interface{}) error) error { + type plain DynLangProvider // to avoid recursion + temp := plain{} - if err := unmarshal((*plain)(c)); err != nil { + if err := unmarshal(&temp); err != nil { return err } - return c.validateOneProvider() + *p = DynLangProvider(temp) + + return p.validate() } + +// TODO: Remove this old LangProviders struct + +//type LangProviders struct { +// // Add other providers like +// OpenAI *openai.Config `yaml:"openai,omitempty" json:"openai,omitempty"` +// AzureOpenAI *azureopenai.Config `yaml:"azureopenai,omitempty" json:"azureopenai,omitempty"` +// Cohere *cohere.Config `yaml:"cohere,omitempty" json:"cohere,omitempty"` +// OctoML *octoml.Config `yaml:"octoml,omitempty" json:"octoml,omitempty"` +// Anthropic *anthropic.Config `yaml:"anthropic,omitempty" json:"anthropic,omitempty"` +// Bedrock *bedrock.Config `yaml:"bedrock,omitempty" json:"bedrock,omitempty"` +// Ollama *ollama.Config `yaml:"ollama,omitempty" json:"ollama,omitempty"` +//} +// +//var _ ProviderConfig = (*LangProviders)(nil) + +// ToClient initializes the language model client based on the provided configuration. +// It takes a telemetry object as input and returns a LangModelProvider and an error. +//func (c LangProviders) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) { +// switch { +// case c.OpenAI != nil: +// return openai.NewClient(c.OpenAI, clientConfig, tel) +// case c.AzureOpenAI != nil: +// return azureopenai.NewClient(c.AzureOpenAI, clientConfig, tel) +// case c.Cohere != nil: +// return cohere.NewClient(c.Cohere, clientConfig, tel) +// case c.OctoML != nil: +// return octoml.NewClient(c.OctoML, clientConfig, tel) +// case c.Anthropic != nil: +// return anthropic.NewClient(c.Anthropic, clientConfig, tel) +// case c.Bedrock != nil: +// return bedrock.NewClient(c.Bedrock, clientConfig, tel) +// default: +// return nil, ErrProviderNotFound +// } +//} + +//func (c *LangProviders) validateOneProvider() error { +// providersConfigured := 0 +// +// if c.OpenAI != nil { +// providersConfigured++ +// } +// +// if c.AzureOpenAI != nil { +// providersConfigured++ +// } +// +// if c.Cohere != nil { +// providersConfigured++ +// } +// +// if c.OctoML != nil { +// providersConfigured++ +// } +// +// if c.Anthropic != nil { +// providersConfigured++ +// } +// +// if c.Bedrock != nil { +// providersConfigured++ +// } +// +// if c.Ollama != nil { +// providersConfigured++ +// } +// +// // check other providers here +// if providersConfigured == 0 { +// return ErrNoProviderConfigured +// } +// +// if providersConfigured > 1 { +// return fmt.Errorf( +// "exactly one provider must be configured, but %v are configured", +// providersConfigured, +// ) +// } +// +// return nil +//} + +//func (c *LangProviders) UnmarshalYAML(unmarshal func(interface{}) error) error { +// type plain LangProviders // to avoid recursion +// +// if err := unmarshal((*plain)(c)); err != nil { +// return err +// } +// +// return c.validateOneProvider() +//} diff --git a/pkg/providers/config_test.go b/pkg/providers/config_test.go new file mode 100644 index 00000000..16a7d00d --- /dev/null +++ b/pkg/providers/config_test.go @@ -0,0 +1,29 @@ +package providers + +import ( + testprovider "github.com/EinStack/glide/pkg/providers/testing" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" + "os" + "path/filepath" + "testing" +) + +func TestDynLangProvider(t *testing.T) { + LangRegistry.Register(testprovider.ProviderTest, &testprovider.Config{}) + + type ProviderConfig struct { + Provider *DynLangProvider `yaml:"provider"` + } + + prConfig := make(DynLangProvider) + providerConfig := ProviderConfig{ + Provider: &prConfig, + } + + config, err := os.ReadFile(filepath.Clean("./config_test.yaml")) + require.NoError(t, err) + + err = yaml.Unmarshal(config, &providerConfig) + require.NoError(t, err) +} diff --git a/pkg/providers/config_test.yaml b/pkg/providers/config_test.yaml new file mode 100644 index 00000000..5bc1fb1c --- /dev/null +++ b/pkg/providers/config_test.yaml @@ -0,0 +1,9 @@ +provider: + testprovider: + base_url: "https://api.example.com" + chat_endpoint: "/chat/completions" + model: "example-model" + api_key: "example_api_key" + default_params: + param1: "value1" + param2: "value2" diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 4c4b21c8..06698295 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -126,7 +126,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche response := schemas.ChatResponse{ ID: chatCompletion.ID, Created: chatCompletion.Created, - Provider: providerName, + Provider: ProviderOpenAI, ModelName: chatCompletion.ModelName, Cached: false, ModelResponse: schemas.ModelResponse{ diff --git a/pkg/providers/openai/chat_stream.go b/pkg/providers/openai/chat_stream.go index d362cb67..8fd0a617 100644 --- a/pkg/providers/openai/chat_stream.go +++ b/pkg/providers/openai/chat_stream.go @@ -112,7 +112,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { // TODO: use objectpool here return &schemas.ChatStreamChunk{ Cached: false, - Provider: providerName, + Provider: ProviderOpenAI, ModelName: completionChunk.ModelName, ModelResponse: schemas.ModelChunkResponse{ Metadata: &schemas.Metadata{ diff --git a/pkg/providers/openai/client.go b/pkg/providers/openai/client.go index bb49dab3..795d94f6 100644 --- a/pkg/providers/openai/client.go +++ b/pkg/providers/openai/client.go @@ -13,7 +13,7 @@ import ( ) const ( - providerName = "openai" + ProviderOpenAI = "openai" ) // Client is a client for accessing OpenAI API @@ -37,7 +37,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * } logger := tel.L().With( - zap.String("provider", providerName), + zap.String("provider", ProviderOpenAI), ) c := &Client{ @@ -62,7 +62,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * } func (c *Client) Provider() string { - return providerName + return ProviderOpenAI } func (c *Client) ModelName() string { diff --git a/pkg/providers/openai/config.go b/pkg/providers/openai/config.go index 8342db41..02af8ce4 100644 --- a/pkg/providers/openai/config.go +++ b/pkg/providers/openai/config.go @@ -1,7 +1,11 @@ package openai import ( + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/config/fields" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/telemetry" ) // Params defines OpenAI-specific model params with the specific validation of values @@ -49,6 +53,8 @@ type Config struct { DefaultParams *Params `yaml:"default_params,omitempty" json:"default_params"` } +var _ providers.ProviderConfig = (*Config)(nil) + // DefaultConfig for OpenAI models func DefaultConfig() *Config { defaultParams := DefaultParams() @@ -61,6 +67,10 @@ func DefaultConfig() *Config { } } +func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) { + return NewClient(c, clientConfig, tel) +} + func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { *c = *DefaultConfig() diff --git a/pkg/providers/openai/errors.go b/pkg/providers/openai/errors.go index 58e37292..640962c4 100644 --- a/pkg/providers/openai/errors.go +++ b/pkg/providers/openai/errors.go @@ -28,7 +28,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error { if err != nil { m.tel.Logger.Error( "Failed to unmarshal chat response error", - zap.String("provider", providerName), + zap.String("provider", ProviderOpenAI), zap.Error(err), zap.ByteString("rawResponse", bodyBytes), ) @@ -38,7 +38,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error { m.tel.Logger.Error( "Chat request failed", - zap.String("provider", providerName), + zap.String("provider", ProviderOpenAI), zap.Int("statusCode", resp.StatusCode), zap.String("response", string(bodyBytes)), zap.Any("headers", resp.Header), diff --git a/pkg/providers/openai/register.go b/pkg/providers/openai/register.go new file mode 100644 index 00000000..baf37ac1 --- /dev/null +++ b/pkg/providers/openai/register.go @@ -0,0 +1,9 @@ +package openai + +import ( + "github.com/EinStack/glide/pkg/providers" +) + +func init() { + providers.LangRegistry.Register(ProviderOpenAI, &Config{}) +} diff --git a/pkg/providers/registry.go b/pkg/providers/registry.go new file mode 100644 index 00000000..b98f8030 --- /dev/null +++ b/pkg/providers/registry.go @@ -0,0 +1,42 @@ +package providers + +import ( + "fmt" + "github.com/EinStack/glide/pkg/provider" +) + +var LangRegistry = NewProviderRegistry() + +type ProviderRegistry struct { + providers map[provider.ProviderID]provider.ProviderConfig +} + +func NewProviderRegistry() *ProviderRegistry { + return &ProviderRegistry{ + providers: make(map[provider.ProviderID]provider.ProviderConfig), + } +} + +func (r *ProviderRegistry) Register(name provider.ProviderID, config provider.ProviderConfig) { + if _, ok := r.Get(name); ok { + panic(fmt.Sprintf("provider %s is already registered", name)) + } + + r.providers[name] = config +} + +func (r *ProviderRegistry) Get(name provider.ProviderID) (provider.ProviderConfig, bool) { + config, ok := r.providers[name] + + return config, ok +} + +func (r *ProviderRegistry) Available() []provider.ProviderID { + available := make([]provider.ProviderID, 0, len(r.providers)) + + for providerID, _ := range r.providers { + available = append(available, providerID) + } + + return available +} diff --git a/pkg/providers/testing/config.go b/pkg/providers/testing/config.go new file mode 100644 index 00000000..11237fc2 --- /dev/null +++ b/pkg/providers/testing/config.go @@ -0,0 +1,29 @@ +package testing + +import ( + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/config/fields" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/telemetry" +) + +const ( + ProviderTest = "testprovider" +) + +type Config struct { + BaseURL string `yaml:"base_url" json:"base_url" validate:"required"` + ChatEndpoint string `yaml:"chat_endpoint" json:"chat_endpoint" validate:"required"` + ModelName string `yaml:"model" json:"model" validate:"required"` + APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"` +} + +func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) { + return NewProviderMock(nil, []RespMock{}), nil +} + +func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + type plain Config // to avoid recursion + + return unmarshal((*plain)(c)) +} diff --git a/pkg/routers/lang/config.go b/pkg/routers/lang/config.go index 0be9ef06..14b6ab14 100644 --- a/pkg/routers/lang/config.go +++ b/pkg/routers/lang/config.go @@ -2,12 +2,12 @@ package lang import ( "fmt" + "github.com/EinStack/glide/pkg/providers" "time" "github.com/EinStack/glide/pkg/routers" "github.com/EinStack/glide/pkg/models" - "github.com/EinStack/glide/pkg/providers" "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/routers/routing" "github.com/EinStack/glide/pkg/telemetry" @@ -16,7 +16,7 @@ import ( ) type ( - ModelConfig = models.Config[providers.LangProviders] + ModelConfig = models.Config[providers.DynLangProvider] ModelPoolConfig = []ModelConfig ) diff --git a/pkg/routers/lang/config_test.go b/pkg/routers/lang/config_test.go index 975cdcb6..141c329e 100644 --- a/pkg/routers/lang/config_test.go +++ b/pkg/routers/lang/config_test.go @@ -27,8 +27,8 @@ func TestRouterConfig_BuildModels(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.LangProviders{ - OpenAI: &openai.Config{ + Provider: providers.DynLangProvider{ + openai.ProviderOpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, @@ -45,8 +45,8 @@ func TestRouterConfig_BuildModels(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.LangProviders{ - OpenAI: &openai.Config{ + Provider: providers.DynLangProvider{ + openai.ProviderOpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, @@ -80,8 +80,8 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.LangProviders{ - OpenAI: &openai.Config{ + Provider: providers.DynLangProvider{ + openai.ProviderOpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &openAIParams, }, diff --git a/pkg/routers/lang/router_test.go b/pkg/routers/lang/router_test.go index 087e2a71..ab6ab24a 100644 --- a/pkg/routers/lang/router_test.go +++ b/pkg/routers/lang/router_test.go @@ -2,6 +2,7 @@ package lang import ( "context" + ptesting "github.com/EinStack/glide/pkg/providers/testing" "testing" "time" @@ -12,7 +13,6 @@ import ( "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/api/schemas" - ptesting "github.com/EinStack/glide/pkg/providers/testing" "github.com/EinStack/glide/pkg/routers/latency" "github.com/EinStack/glide/pkg/routers/routing" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/routers/routing/least_latency_test.go b/pkg/routers/routing/least_latency_test.go index 523b0790..e9b73565 100644 --- a/pkg/routers/routing/least_latency_test.go +++ b/pkg/routers/routing/least_latency_test.go @@ -1,14 +1,13 @@ package routing import ( + ptesting "github.com/EinStack/glide/pkg/providers/testing" "strconv" "testing" "time" "github.com/EinStack/glide/pkg/models" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - "github.com/stretchr/testify/require" ) diff --git a/pkg/routers/routing/priority_test.go b/pkg/routers/routing/priority_test.go index eb090c76..9176ee52 100644 --- a/pkg/routers/routing/priority_test.go +++ b/pkg/routers/routing/priority_test.go @@ -1,12 +1,11 @@ package routing import ( + ptesting "github.com/EinStack/glide/pkg/providers/testing" "testing" "github.com/EinStack/glide/pkg/models" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - "github.com/stretchr/testify/require" ) diff --git a/pkg/routers/routing/round_robin_test.go b/pkg/routers/routing/round_robin_test.go index 2a6e579b..dbba794e 100644 --- a/pkg/routers/routing/round_robin_test.go +++ b/pkg/routers/routing/round_robin_test.go @@ -1,12 +1,11 @@ package routing import ( + ptesting "github.com/EinStack/glide/pkg/providers/testing" "testing" "github.com/EinStack/glide/pkg/models" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - "github.com/stretchr/testify/require" ) diff --git a/pkg/routers/routing/weighted_round_robin_test.go b/pkg/routers/routing/weighted_round_robin_test.go index 8e4a9ee2..e8a17491 100644 --- a/pkg/routers/routing/weighted_round_robin_test.go +++ b/pkg/routers/routing/weighted_round_robin_test.go @@ -1,12 +1,11 @@ package routing import ( + ptesting "github.com/EinStack/glide/pkg/providers/testing" "testing" "github.com/EinStack/glide/pkg/models" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - "github.com/stretchr/testify/require" ) From e043e4dd9ffe76877b9c0b722ed012f54f03fb30 Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Thu, 8 Aug 2024 16:30:54 +0300 Subject: [PATCH 06/18] #67: Fixed a part of linting errors --- pkg/models/config.go | 1 + pkg/models/lang.go | 3 ++- pkg/providers/config.go | 6 ++---- pkg/providers/config_test.go | 7 ++++--- pkg/providers/openai/config.go | 3 +-- pkg/providers/registry.go | 3 ++- pkg/routers/lang/config.go | 5 +++-- pkg/routers/lang/config_test.go | 6 +++--- pkg/routers/lang/router_test.go | 3 ++- pkg/routers/routing/least_latency_test.go | 3 ++- pkg/routers/routing/priority_test.go | 3 ++- pkg/routers/routing/round_robin_test.go | 3 ++- pkg/routers/routing/weighted_round_robin_test.go | 3 ++- 13 files changed, 28 insertions(+), 21 deletions(-) diff --git a/pkg/models/config.go b/pkg/models/config.go index 5a28d50a..edc67cae 100644 --- a/pkg/models/config.go +++ b/pkg/models/config.go @@ -2,6 +2,7 @@ package models import ( "fmt" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/provider" "github.com/EinStack/glide/pkg/resiliency/health" diff --git a/pkg/models/lang.go b/pkg/models/lang.go index f16e6051..2d50bd35 100644 --- a/pkg/models/lang.go +++ b/pkg/models/lang.go @@ -2,10 +2,11 @@ package models import ( "context" - "github.com/EinStack/glide/pkg/provider" "io" "time" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" health2 "github.com/EinStack/glide/pkg/resiliency/health" diff --git a/pkg/providers/config.go b/pkg/providers/config.go index ca22cb7b..f58ec046 100644 --- a/pkg/providers/config.go +++ b/pkg/providers/config.go @@ -3,9 +3,10 @@ package providers import ( "errors" "fmt" + "strings" + "github.com/EinStack/glide/pkg/provider" "github.com/go-playground/validator/v10" - "strings" "gopkg.in/yaml.v3" @@ -44,7 +45,6 @@ func (p DynLangProvider) ToClient(tel *telemetry.Telemetry, clientConfig *client providerConfigUnmarshaller := func(providerConfig interface{}) error { providerConfigBytes, err := yaml.Marshal(configValue) - if err != nil { return err } @@ -53,7 +53,6 @@ func (p DynLangProvider) ToClient(tel *telemetry.Telemetry, clientConfig *client } err := providerConfig.UnmarshalYAML(providerConfigUnmarshaller) - if err != nil { return nil, err } @@ -105,7 +104,6 @@ func (p DynLangProvider) validate() error { } err = yaml.Unmarshal(providerConfigBytes, providerConfig) - if err != nil { return err } diff --git a/pkg/providers/config_test.go b/pkg/providers/config_test.go index 16a7d00d..7bb7c40f 100644 --- a/pkg/providers/config_test.go +++ b/pkg/providers/config_test.go @@ -1,12 +1,13 @@ package providers import ( - testprovider "github.com/EinStack/glide/pkg/providers/testing" - "github.com/stretchr/testify/require" - "gopkg.in/yaml.v3" "os" "path/filepath" "testing" + + testprovider "github.com/EinStack/glide/pkg/providers/testing" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) func TestDynLangProvider(t *testing.T) { diff --git a/pkg/providers/openai/config.go b/pkg/providers/openai/config.go index 02af8ce4..4dcf9ff4 100644 --- a/pkg/providers/openai/config.go +++ b/pkg/providers/openai/config.go @@ -4,7 +4,6 @@ import ( "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/config/fields" "github.com/EinStack/glide/pkg/provider" - "github.com/EinStack/glide/pkg/providers" "github.com/EinStack/glide/pkg/telemetry" ) @@ -53,7 +52,7 @@ type Config struct { DefaultParams *Params `yaml:"default_params,omitempty" json:"default_params"` } -var _ providers.ProviderConfig = (*Config)(nil) +var _ provider.ProviderConfig = (*Config)(nil) // DefaultConfig for OpenAI models func DefaultConfig() *Config { diff --git a/pkg/providers/registry.go b/pkg/providers/registry.go index b98f8030..3f626f6e 100644 --- a/pkg/providers/registry.go +++ b/pkg/providers/registry.go @@ -2,6 +2,7 @@ package providers import ( "fmt" + "github.com/EinStack/glide/pkg/provider" ) @@ -34,7 +35,7 @@ func (r *ProviderRegistry) Get(name provider.ProviderID) (provider.ProviderConfi func (r *ProviderRegistry) Available() []provider.ProviderID { available := make([]provider.ProviderID, 0, len(r.providers)) - for providerID, _ := range r.providers { + for providerID := range r.providers { available = append(available, providerID) } diff --git a/pkg/routers/lang/config.go b/pkg/routers/lang/config.go index 14b6ab14..a8f3f6bb 100644 --- a/pkg/routers/lang/config.go +++ b/pkg/routers/lang/config.go @@ -2,9 +2,10 @@ package lang import ( "fmt" - "github.com/EinStack/glide/pkg/providers" "time" + "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/routers" "github.com/EinStack/glide/pkg/models" @@ -16,7 +17,7 @@ import ( ) type ( - ModelConfig = models.Config[providers.DynLangProvider] + ModelConfig = models.Config[*providers.DynLangProvider] ModelPoolConfig = []ModelConfig ) diff --git a/pkg/routers/lang/config_test.go b/pkg/routers/lang/config_test.go index 141c329e..aa782f96 100644 --- a/pkg/routers/lang/config_test.go +++ b/pkg/routers/lang/config_test.go @@ -27,7 +27,7 @@ func TestRouterConfig_BuildModels(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.DynLangProvider{ + Provider: &providers.DynLangProvider{ openai.ProviderOpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -45,7 +45,7 @@ func TestRouterConfig_BuildModels(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.DynLangProvider{ + Provider: &providers.DynLangProvider{ openai.ProviderOpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -80,7 +80,7 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.DynLangProvider{ + Provider: &providers.DynLangProvider{ openai.ProviderOpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &openAIParams, diff --git a/pkg/routers/lang/router_test.go b/pkg/routers/lang/router_test.go index ab6ab24a..65c9b2fe 100644 --- a/pkg/routers/lang/router_test.go +++ b/pkg/routers/lang/router_test.go @@ -2,10 +2,11 @@ package lang import ( "context" - ptesting "github.com/EinStack/glide/pkg/providers/testing" "testing" "time" + ptesting "github.com/EinStack/glide/pkg/providers/testing" + "github.com/EinStack/glide/pkg/models" "github.com/EinStack/glide/pkg/clients" diff --git a/pkg/routers/routing/least_latency_test.go b/pkg/routers/routing/least_latency_test.go index e9b73565..2f18b322 100644 --- a/pkg/routers/routing/least_latency_test.go +++ b/pkg/routers/routing/least_latency_test.go @@ -1,11 +1,12 @@ package routing import ( - ptesting "github.com/EinStack/glide/pkg/providers/testing" "strconv" "testing" "time" + ptesting "github.com/EinStack/glide/pkg/providers/testing" + "github.com/EinStack/glide/pkg/models" "github.com/stretchr/testify/require" diff --git a/pkg/routers/routing/priority_test.go b/pkg/routers/routing/priority_test.go index 9176ee52..c0713f35 100644 --- a/pkg/routers/routing/priority_test.go +++ b/pkg/routers/routing/priority_test.go @@ -1,9 +1,10 @@ package routing import ( - ptesting "github.com/EinStack/glide/pkg/providers/testing" "testing" + ptesting "github.com/EinStack/glide/pkg/providers/testing" + "github.com/EinStack/glide/pkg/models" "github.com/stretchr/testify/require" diff --git a/pkg/routers/routing/round_robin_test.go b/pkg/routers/routing/round_robin_test.go index dbba794e..c2c6d307 100644 --- a/pkg/routers/routing/round_robin_test.go +++ b/pkg/routers/routing/round_robin_test.go @@ -1,9 +1,10 @@ package routing import ( - ptesting "github.com/EinStack/glide/pkg/providers/testing" "testing" + ptesting "github.com/EinStack/glide/pkg/providers/testing" + "github.com/EinStack/glide/pkg/models" "github.com/stretchr/testify/require" diff --git a/pkg/routers/routing/weighted_round_robin_test.go b/pkg/routers/routing/weighted_round_robin_test.go index e8a17491..e24d4e81 100644 --- a/pkg/routers/routing/weighted_round_robin_test.go +++ b/pkg/routers/routing/weighted_round_robin_test.go @@ -1,9 +1,10 @@ package routing import ( - ptesting "github.com/EinStack/glide/pkg/providers/testing" "testing" + ptesting "github.com/EinStack/glide/pkg/providers/testing" + "github.com/EinStack/glide/pkg/models" "github.com/stretchr/testify/require" From 339e3f9f2f68791eccf05babc570f0a88a55e2ab Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Thu, 8 Aug 2024 17:13:03 +0300 Subject: [PATCH 07/18] #67: Picked more idiomatic naming for the ProviderID --- pkg/providers/cohere/chat.go | 2 +- pkg/providers/cohere/chat_stream.go | 10 +-- pkg/providers/cohere/client.go | 4 +- pkg/providers/cohere/errors.go | 4 +- pkg/providers/config.go | 94 +---------------------------- pkg/providers/openai/chat.go | 2 +- pkg/providers/openai/chat_stream.go | 2 +- pkg/providers/openai/client.go | 6 +- pkg/providers/openai/errors.go | 4 +- pkg/providers/openai/register.go | 2 +- pkg/providers/testing/config.go | 2 +- pkg/routers/lang/config_test.go | 26 ++++---- 12 files changed, 34 insertions(+), 124 deletions(-) diff --git a/pkg/providers/cohere/chat.go b/pkg/providers/cohere/chat.go index 4729d55f..754d8537 100644 --- a/pkg/providers/cohere/chat.go +++ b/pkg/providers/cohere/chat.go @@ -118,7 +118,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche response := schemas.ChatResponse{ ID: cohereCompletion.ResponseID, Created: int(time.Now().UTC().Unix()), // Cohere doesn't provide this - Provider: providerName, + Provider: ProviderID, ModelName: c.config.ModelName, Cached: false, ModelResponse: schemas.ModelResponse{ diff --git a/pkg/providers/cohere/chat_stream.go b/pkg/providers/cohere/chat_stream.go index 6f194945..392f0a27 100644 --- a/pkg/providers/cohere/chat_stream.go +++ b/pkg/providers/cohere/chat_stream.go @@ -90,7 +90,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { if err != nil { s.tel.L().Warn( "Chat stream is unexpectedly disconnected", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.Error(err), ) @@ -101,7 +101,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { s.tel.L().Debug( "Raw chat stream chunk", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.ByteString("rawChunk", rawChunk), ) @@ -119,7 +119,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { if responseChunk.EventType != TextGenEvent && responseChunk.EventType != StreamEndEvent { s.tel.L().Debug( "Unsupported stream chunk type, skipping it", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.ByteString("chunk", rawChunk), ) @@ -132,7 +132,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { // TODO: use objectpool here return &schemas.ChatStreamChunk{ Cached: false, - Provider: providerName, + Provider: ProviderID, ModelName: s.modelName, ModelResponse: schemas.ModelChunkResponse{ Metadata: &schemas.Metadata{ @@ -151,7 +151,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { // TODO: use objectpool here return &schemas.ChatStreamChunk{ Cached: false, - Provider: providerName, + Provider: ProviderID, ModelName: s.modelName, ModelResponse: schemas.ModelChunkResponse{ Metadata: &schemas.Metadata{ diff --git a/pkg/providers/cohere/client.go b/pkg/providers/cohere/client.go index a8426598..3393e010 100644 --- a/pkg/providers/cohere/client.go +++ b/pkg/providers/cohere/client.go @@ -11,7 +11,7 @@ import ( ) const ( - providerName = "cohere" + ProviderID = "cohere" ) // Client is a client for accessing Cohere API @@ -54,7 +54,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * } func (c *Client) Provider() string { - return providerName + return ProviderID } func (c *Client) ModelName() string { diff --git a/pkg/providers/cohere/errors.go b/pkg/providers/cohere/errors.go index 5b5548c1..5f8ea045 100644 --- a/pkg/providers/cohere/errors.go +++ b/pkg/providers/cohere/errors.go @@ -28,7 +28,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error { if err != nil { m.tel.Logger.Error( "Failed to unmarshal chat response error", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.Error(err), zap.ByteString("rawResponse", bodyBytes), ) @@ -38,7 +38,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error { m.tel.Logger.Error( "Chat request failed", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.Int("statusCode", resp.StatusCode), zap.String("response", string(bodyBytes)), zap.Any("headers", resp.Header), diff --git a/pkg/providers/config.go b/pkg/providers/config.go index f58ec046..466f07e3 100644 --- a/pkg/providers/config.go +++ b/pkg/providers/config.go @@ -98,6 +98,7 @@ func (p DynLangProvider) validate() error { providerConfigUnmarshaller := func(providerConfig interface{}) error { configValue := p[providerID] + providerConfigBytes, err := yaml.Marshal(configValue) if err != nil { return err @@ -116,6 +117,7 @@ func (p DynLangProvider) validate() error { func (p *DynLangProvider) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain DynLangProvider // to avoid recursion + temp := plain{} if err := unmarshal(&temp); err != nil { @@ -126,95 +128,3 @@ func (p *DynLangProvider) UnmarshalYAML(unmarshal func(interface{}) error) error return p.validate() } - -// TODO: Remove this old LangProviders struct - -//type LangProviders struct { -// // Add other providers like -// OpenAI *openai.Config `yaml:"openai,omitempty" json:"openai,omitempty"` -// AzureOpenAI *azureopenai.Config `yaml:"azureopenai,omitempty" json:"azureopenai,omitempty"` -// Cohere *cohere.Config `yaml:"cohere,omitempty" json:"cohere,omitempty"` -// OctoML *octoml.Config `yaml:"octoml,omitempty" json:"octoml,omitempty"` -// Anthropic *anthropic.Config `yaml:"anthropic,omitempty" json:"anthropic,omitempty"` -// Bedrock *bedrock.Config `yaml:"bedrock,omitempty" json:"bedrock,omitempty"` -// Ollama *ollama.Config `yaml:"ollama,omitempty" json:"ollama,omitempty"` -//} -// -//var _ ProviderConfig = (*LangProviders)(nil) - -// ToClient initializes the language model client based on the provided configuration. -// It takes a telemetry object as input and returns a LangModelProvider and an error. -//func (c LangProviders) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) { -// switch { -// case c.OpenAI != nil: -// return openai.NewClient(c.OpenAI, clientConfig, tel) -// case c.AzureOpenAI != nil: -// return azureopenai.NewClient(c.AzureOpenAI, clientConfig, tel) -// case c.Cohere != nil: -// return cohere.NewClient(c.Cohere, clientConfig, tel) -// case c.OctoML != nil: -// return octoml.NewClient(c.OctoML, clientConfig, tel) -// case c.Anthropic != nil: -// return anthropic.NewClient(c.Anthropic, clientConfig, tel) -// case c.Bedrock != nil: -// return bedrock.NewClient(c.Bedrock, clientConfig, tel) -// default: -// return nil, ErrProviderNotFound -// } -//} - -//func (c *LangProviders) validateOneProvider() error { -// providersConfigured := 0 -// -// if c.OpenAI != nil { -// providersConfigured++ -// } -// -// if c.AzureOpenAI != nil { -// providersConfigured++ -// } -// -// if c.Cohere != nil { -// providersConfigured++ -// } -// -// if c.OctoML != nil { -// providersConfigured++ -// } -// -// if c.Anthropic != nil { -// providersConfigured++ -// } -// -// if c.Bedrock != nil { -// providersConfigured++ -// } -// -// if c.Ollama != nil { -// providersConfigured++ -// } -// -// // check other providers here -// if providersConfigured == 0 { -// return ErrNoProviderConfigured -// } -// -// if providersConfigured > 1 { -// return fmt.Errorf( -// "exactly one provider must be configured, but %v are configured", -// providersConfigured, -// ) -// } -// -// return nil -//} - -//func (c *LangProviders) UnmarshalYAML(unmarshal func(interface{}) error) error { -// type plain LangProviders // to avoid recursion -// -// if err := unmarshal((*plain)(c)); err != nil { -// return err -// } -// -// return c.validateOneProvider() -//} diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 06698295..86bce6f1 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -126,7 +126,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche response := schemas.ChatResponse{ ID: chatCompletion.ID, Created: chatCompletion.Created, - Provider: ProviderOpenAI, + Provider: ProviderID, ModelName: chatCompletion.ModelName, Cached: false, ModelResponse: schemas.ModelResponse{ diff --git a/pkg/providers/openai/chat_stream.go b/pkg/providers/openai/chat_stream.go index 8fd0a617..ba219e30 100644 --- a/pkg/providers/openai/chat_stream.go +++ b/pkg/providers/openai/chat_stream.go @@ -112,7 +112,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { // TODO: use objectpool here return &schemas.ChatStreamChunk{ Cached: false, - Provider: ProviderOpenAI, + Provider: ProviderID, ModelName: completionChunk.ModelName, ModelResponse: schemas.ModelChunkResponse{ Metadata: &schemas.Metadata{ diff --git a/pkg/providers/openai/client.go b/pkg/providers/openai/client.go index 795d94f6..30a04385 100644 --- a/pkg/providers/openai/client.go +++ b/pkg/providers/openai/client.go @@ -13,7 +13,7 @@ import ( ) const ( - ProviderOpenAI = "openai" + ProviderID = "openai" ) // Client is a client for accessing OpenAI API @@ -37,7 +37,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * } logger := tel.L().With( - zap.String("provider", ProviderOpenAI), + zap.String("provider", ProviderID), ) c := &Client{ @@ -62,7 +62,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * } func (c *Client) Provider() string { - return ProviderOpenAI + return ProviderID } func (c *Client) ModelName() string { diff --git a/pkg/providers/openai/errors.go b/pkg/providers/openai/errors.go index 640962c4..d0389cbe 100644 --- a/pkg/providers/openai/errors.go +++ b/pkg/providers/openai/errors.go @@ -28,7 +28,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error { if err != nil { m.tel.Logger.Error( "Failed to unmarshal chat response error", - zap.String("provider", ProviderOpenAI), + zap.String("provider", ProviderID), zap.Error(err), zap.ByteString("rawResponse", bodyBytes), ) @@ -38,7 +38,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error { m.tel.Logger.Error( "Chat request failed", - zap.String("provider", ProviderOpenAI), + zap.String("provider", ProviderID), zap.Int("statusCode", resp.StatusCode), zap.String("response", string(bodyBytes)), zap.Any("headers", resp.Header), diff --git a/pkg/providers/openai/register.go b/pkg/providers/openai/register.go index baf37ac1..4435ac8d 100644 --- a/pkg/providers/openai/register.go +++ b/pkg/providers/openai/register.go @@ -5,5 +5,5 @@ import ( ) func init() { - providers.LangRegistry.Register(ProviderOpenAI, &Config{}) + providers.LangRegistry.Register(ProviderID, &Config{}) } diff --git a/pkg/providers/testing/config.go b/pkg/providers/testing/config.go index 11237fc2..dd7d0853 100644 --- a/pkg/providers/testing/config.go +++ b/pkg/providers/testing/config.go @@ -18,7 +18,7 @@ type Config struct { APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"` } -func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) { +func (c *Config) ToClient(_ *telemetry.Telemetry, _ *clients.ClientConfig) (provider.LangProvider, error) { return NewProviderMock(nil, []RespMock{}), nil } diff --git a/pkg/routers/lang/config_test.go b/pkg/routers/lang/config_test.go index aa782f96..1ed43363 100644 --- a/pkg/routers/lang/config_test.go +++ b/pkg/routers/lang/config_test.go @@ -28,7 +28,7 @@ func TestRouterConfig_BuildModels(t *testing.T) { ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), Provider: &providers.DynLangProvider{ - openai.ProviderOpenAI: &openai.Config{ + openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, @@ -46,7 +46,7 @@ func TestRouterConfig_BuildModels(t *testing.T) { ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), Provider: &providers.DynLangProvider{ - openai.ProviderOpenAI: &openai.Config{ + openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, @@ -81,7 +81,7 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), Provider: &providers.DynLangProvider{ - openai.ProviderOpenAI: &openai.Config{ + openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &openAIParams, }, @@ -93,8 +93,8 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.LangProviders{ - Cohere: &cohere.Config{ + Provider: &providers.DynLangProvider{ + cohere.ProviderID: &cohere.Config{ APIKey: "ABC", DefaultParams: &cohereParams, }, @@ -129,8 +129,8 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.LangProviders{ - OpenAI: &openai.Config{ + Provider: &providers.DynLangProvider{ + openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, @@ -147,8 +147,8 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.LangProviders{ - OpenAI: &openai.Config{ + Provider: &providers.DynLangProvider{ + openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, @@ -170,8 +170,8 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.LangProviders{ - OpenAI: &openai.Config{ + Provider: &providers.DynLangProvider{ + openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, @@ -183,8 +183,8 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.LangProviders{ - OpenAI: &openai.Config{ + Provider: &providers.DynLangProvider{ + openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, From b662b1e0b3832f531f2e31d6f71e73b100711d06 Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Thu, 8 Aug 2024 18:30:20 +0300 Subject: [PATCH 08/18] #67: Got rid of provider package & did other layout restructuring to fix circular dependency issues --- pkg/{models => extmodel}/config.go | 11 +- pkg/{models => extmodel}/lang.go | 14 +-- pkg/extmodel/model.go | 11 ++ .../testing/models.go => extmodel/testing.go} | 5 +- pkg/models/model.go | 11 -- pkg/provider/config.go | 12 -- pkg/providers/config.go | 26 ++-- pkg/providers/config_test.go | 7 +- .../provider.go => providers/interface.go} | 6 +- pkg/providers/openai/config.go | 6 +- pkg/providers/registry.go | 14 +-- pkg/providers/{testing/lang.go => testing.go} | 34 ++++- pkg/providers/testing/config.go | 29 ----- pkg/routers/lang/config.go | 23 ++-- pkg/routers/lang/config_test.go | 16 +-- pkg/routers/lang/router.go | 11 +- pkg/routers/lang/router_test.go | 116 +++++++++--------- pkg/routers/routing/least_latency.go | 12 +- pkg/routers/routing/least_latency_test.go | 20 ++- pkg/routers/routing/priority.go | 10 +- pkg/routers/routing/priority_test.go | 16 ++- pkg/routers/routing/round_robin.go | 8 +- pkg/routers/routing/round_robin_test.go | 16 ++- pkg/routers/routing/strategies.go | 4 +- pkg/routers/routing/weighted_round_robin.go | 8 +- .../routing/weighted_round_robin_test.go | 16 ++- 26 files changed, 217 insertions(+), 245 deletions(-) rename pkg/{models => extmodel}/config.go (86%) rename pkg/{models => extmodel}/lang.go (91%) create mode 100644 pkg/extmodel/model.go rename pkg/{providers/testing/models.go => extmodel/testing.go} (89%) delete mode 100644 pkg/models/model.go delete mode 100644 pkg/provider/config.go rename pkg/{provider/provider.go => providers/interface.go} (97%) rename pkg/providers/{testing/lang.go => testing.go} (74%) delete mode 100644 pkg/providers/testing/config.go diff --git a/pkg/models/config.go b/pkg/extmodel/config.go similarity index 86% rename from pkg/models/config.go rename to pkg/extmodel/config.go index edc67cae..2266fcaa 100644 --- a/pkg/models/config.go +++ b/pkg/extmodel/config.go @@ -1,17 +1,18 @@ -package models +package extmodel import ( "fmt" + "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/provider" "github.com/EinStack/glide/pkg/resiliency/health" "github.com/EinStack/glide/pkg/routers/latency" "github.com/EinStack/glide/pkg/telemetry" ) // Config defines an extra configuration for a model wrapper around a provider -type Config[P provider.ProviderConfig] struct { +type Config[P providers.Configurer] struct { ID string `yaml:"id" json:"id" validate:"required"` // Model instance ID (unique in scope of the router) Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is the model enabled? ErrorBudget *health.ErrorBudget `yaml:"error_budget" json:"error_budget" swaggertype:"primitive,string"` @@ -22,7 +23,7 @@ type Config[P provider.ProviderConfig] struct { Provider P `yaml:"provider" json:"provider"` } -func NewConfig[P provider.ProviderConfig](ID string) *Config[P] { +func NewConfig[P providers.Configurer](ID string) *Config[P] { config := DefaultConfig[P]() config.ID = ID @@ -30,7 +31,7 @@ func NewConfig[P provider.ProviderConfig](ID string) *Config[P] { return &config } -func DefaultConfig[P provider.ProviderConfig]() Config[P] { +func DefaultConfig[P providers.Configurer]() Config[P] { return Config[P]{ Enabled: true, Client: clients.DefaultClientConfig(), diff --git a/pkg/models/lang.go b/pkg/extmodel/lang.go similarity index 91% rename from pkg/models/lang.go rename to pkg/extmodel/lang.go index 2d50bd35..e3243cb2 100644 --- a/pkg/models/lang.go +++ b/pkg/extmodel/lang.go @@ -1,11 +1,11 @@ -package models +package extmodel import ( "context" "io" "time" - "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/providers" "github.com/EinStack/glide/pkg/clients" health2 "github.com/EinStack/glide/pkg/resiliency/health" @@ -18,7 +18,7 @@ import ( ) type LangModel interface { - Model + Interface Provider() string ModelName() string Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) @@ -32,14 +32,14 @@ type LangModel interface { type LanguageModel struct { modelID string weight int - client provider.LangProvider + client providers.LangProvider healthTracker *health2.Tracker chatLatency *latency.MovingAverage chatStreamLatency *latency.MovingAverage latencyUpdateInterval *fields.Duration } -func NewLangModel(modelID string, client provider.LangProvider, budget *health2.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel { +func NewLangModel(modelID string, client providers.LangProvider, budget *health2.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel { return &LanguageModel{ modelID: modelID, client: client, @@ -170,10 +170,10 @@ func (m *LanguageModel) ModelName() string { return m.client.ModelName() } -func ChatLatency(model Model) *latency.MovingAverage { +func ChatLatency(model Interface) *latency.MovingAverage { return model.(LanguageModel).ChatLatency() } -func ChatStreamLatency(model Model) *latency.MovingAverage { +func ChatStreamLatency(model Interface) *latency.MovingAverage { return model.(LanguageModel).ChatStreamLatency() } diff --git a/pkg/extmodel/model.go b/pkg/extmodel/model.go new file mode 100644 index 00000000..b250c470 --- /dev/null +++ b/pkg/extmodel/model.go @@ -0,0 +1,11 @@ +package extmodel + +import "github.com/EinStack/glide/pkg/config/fields" + +// Interface represent a configured external modality-agnostic model with its routing properties and status +type Interface interface { + ID() string + Healthy() bool + LatencyUpdateInterval() *fields.Duration + Weight() int +} diff --git a/pkg/providers/testing/models.go b/pkg/extmodel/testing.go similarity index 89% rename from pkg/providers/testing/models.go rename to pkg/extmodel/testing.go index 57500d21..6d51ca79 100644 --- a/pkg/providers/testing/models.go +++ b/pkg/extmodel/testing.go @@ -1,10 +1,9 @@ -package testing +package extmodel import ( "time" "github.com/EinStack/glide/pkg/config/fields" - "github.com/EinStack/glide/pkg/models" "github.com/EinStack/glide/pkg/routers/latency" ) @@ -53,6 +52,6 @@ func (m LangModelMock) Weight() int { return m.weight } -func ChatMockLatency(model models.Model) *latency.MovingAverage { +func ChatMockLatency(model Interface) *latency.MovingAverage { return model.(LangModelMock).chatLatency } diff --git a/pkg/models/model.go b/pkg/models/model.go deleted file mode 100644 index 707efee3..00000000 --- a/pkg/models/model.go +++ /dev/null @@ -1,11 +0,0 @@ -package models - -import "github.com/EinStack/glide/pkg/config/fields" - -// Model represent a configured external modality-agnostic model with its routing properties and status -type Model interface { - ID() string - Healthy() bool - LatencyUpdateInterval() *fields.Duration - Weight() int -} diff --git a/pkg/provider/config.go b/pkg/provider/config.go deleted file mode 100644 index 0424e839..00000000 --- a/pkg/provider/config.go +++ /dev/null @@ -1,12 +0,0 @@ -package provider - -import ( - "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/telemetry" -) - -// TODO: ProviderConfig should be more generic, not tied to LangProviders -type ProviderConfig interface { - UnmarshalYAML(unmarshal func(interface{}) error) error - ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) -} diff --git a/pkg/providers/config.go b/pkg/providers/config.go index 466f07e3..6469710d 100644 --- a/pkg/providers/config.go +++ b/pkg/providers/config.go @@ -5,7 +5,6 @@ import ( "fmt" "strings" - "github.com/EinStack/glide/pkg/provider" "github.com/go-playground/validator/v10" "gopkg.in/yaml.v3" @@ -22,12 +21,17 @@ func init() { validate = validator.New() } -// TODO: rename DynLangProvider to DynLangProviderConfig -type DynLangProvider map[provider.ProviderID]interface{} +// TODO: Configurer should be more generic, not tied to LangProviders +type Configurer interface { + UnmarshalYAML(unmarshal func(interface{}) error) error + ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) +} + +type Config map[ProviderID]interface{} -var _ provider.ProviderConfig = (*DynLangProvider)(nil) +var _ Configurer = (*Config)(nil) -func (p DynLangProvider) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) { +func (p Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) { for providerID, configValue := range p { if configValue == nil { continue @@ -60,12 +64,12 @@ func (p DynLangProvider) ToClient(tel *telemetry.Telemetry, clientConfig *client return providerConfig.ToClient(tel, clientConfig) } - return nil, provider.ErrProviderNotFound + return nil, ErrProviderNotFound } // validate ensure there is only one provider configured and it's supported by Glide -func (p DynLangProvider) validate() error { - configuredProviders := make([]provider.ProviderID, 0, len(p)) +func (p Config) validate() error { + configuredProviders := make([]ProviderID, 0, len(p)) for providerID, config := range p { if config != nil { @@ -115,8 +119,8 @@ func (p DynLangProvider) validate() error { return providerConfig.UnmarshalYAML(providerConfigUnmarshaller) } -func (p *DynLangProvider) UnmarshalYAML(unmarshal func(interface{}) error) error { - type plain DynLangProvider // to avoid recursion +func (p *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + type plain Config // to avoid recursion temp := plain{} @@ -124,7 +128,7 @@ func (p *DynLangProvider) UnmarshalYAML(unmarshal func(interface{}) error) error return err } - *p = DynLangProvider(temp) + *p = Config(temp) return p.validate() } diff --git a/pkg/providers/config_test.go b/pkg/providers/config_test.go index 7bb7c40f..7e1d18c8 100644 --- a/pkg/providers/config_test.go +++ b/pkg/providers/config_test.go @@ -5,19 +5,18 @@ import ( "path/filepath" "testing" - testprovider "github.com/EinStack/glide/pkg/providers/testing" "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" ) func TestDynLangProvider(t *testing.T) { - LangRegistry.Register(testprovider.ProviderTest, &testprovider.Config{}) + LangRegistry.Register(ProviderTest, &TestConfig{}) type ProviderConfig struct { - Provider *DynLangProvider `yaml:"provider"` + Provider *Config `yaml:"provider"` } - prConfig := make(DynLangProvider) + prConfig := make(Config) providerConfig := ProviderConfig{ Provider: &prConfig, } diff --git a/pkg/provider/provider.go b/pkg/providers/interface.go similarity index 97% rename from pkg/provider/provider.go rename to pkg/providers/interface.go index d2b76419..0b9fe45b 100644 --- a/pkg/provider/provider.go +++ b/pkg/providers/interface.go @@ -1,4 +1,4 @@ -package provider +package providers import ( "context" @@ -21,9 +21,7 @@ type ModelProvider interface { // LangProvider defines an interface a provider should fulfill to be able to serve language chat requests type LangProvider interface { ModelProvider - SupportChatStream() bool - Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) } @@ -31,8 +29,6 @@ type LangProvider interface { // EmbeddingProvider defines an interface a provider should fulfill to be able to generate embeddings type EmbeddingProvider interface { ModelProvider - SupportEmbedding() bool - Embed(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) } diff --git a/pkg/providers/openai/config.go b/pkg/providers/openai/config.go index 4dcf9ff4..fee9a589 100644 --- a/pkg/providers/openai/config.go +++ b/pkg/providers/openai/config.go @@ -3,7 +3,7 @@ package openai import ( "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/config/fields" - "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/providers" "github.com/EinStack/glide/pkg/telemetry" ) @@ -52,7 +52,7 @@ type Config struct { DefaultParams *Params `yaml:"default_params,omitempty" json:"default_params"` } -var _ provider.ProviderConfig = (*Config)(nil) +var _ providers.Configurer = (*Config)(nil) // DefaultConfig for OpenAI models func DefaultConfig() *Config { @@ -66,7 +66,7 @@ func DefaultConfig() *Config { } } -func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) { +func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (providers.LangProvider, error) { return NewClient(c, clientConfig, tel) } diff --git a/pkg/providers/registry.go b/pkg/providers/registry.go index 3f626f6e..8298ebfa 100644 --- a/pkg/providers/registry.go +++ b/pkg/providers/registry.go @@ -2,23 +2,21 @@ package providers import ( "fmt" - - "github.com/EinStack/glide/pkg/provider" ) var LangRegistry = NewProviderRegistry() type ProviderRegistry struct { - providers map[provider.ProviderID]provider.ProviderConfig + providers map[ProviderID]Configurer } func NewProviderRegistry() *ProviderRegistry { return &ProviderRegistry{ - providers: make(map[provider.ProviderID]provider.ProviderConfig), + providers: make(map[ProviderID]Configurer), } } -func (r *ProviderRegistry) Register(name provider.ProviderID, config provider.ProviderConfig) { +func (r *ProviderRegistry) Register(name ProviderID, config Configurer) { if _, ok := r.Get(name); ok { panic(fmt.Sprintf("provider %s is already registered", name)) } @@ -26,14 +24,14 @@ func (r *ProviderRegistry) Register(name provider.ProviderID, config provider.Pr r.providers[name] = config } -func (r *ProviderRegistry) Get(name provider.ProviderID) (provider.ProviderConfig, bool) { +func (r *ProviderRegistry) Get(name ProviderID) (Configurer, bool) { config, ok := r.providers[name] return config, ok } -func (r *ProviderRegistry) Available() []provider.ProviderID { - available := make([]provider.ProviderID, 0, len(r.providers)) +func (r *ProviderRegistry) Available() []ProviderID { + available := make([]ProviderID, 0, len(r.providers)) for providerID := range r.providers { available = append(available, providerID) diff --git a/pkg/providers/testing/lang.go b/pkg/providers/testing.go similarity index 74% rename from pkg/providers/testing/lang.go rename to pkg/providers/testing.go index 3c27792a..f4cda83b 100644 --- a/pkg/providers/testing/lang.go +++ b/pkg/providers/testing.go @@ -1,14 +1,36 @@ -package testing +package providers import ( "context" "io" - clients2 "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/config/fields" + "github.com/EinStack/glide/pkg/telemetry" +) + +const ( + ProviderTest = "testprovider" ) +type TestConfig struct { + BaseURL string `yaml:"base_url" json:"base_url" validate:"required"` + ChatEndpoint string `yaml:"chat_endpoint" json:"chat_endpoint" validate:"required"` + ModelName string `yaml:"model" json:"model" validate:"required"` + APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"` +} + +func (c *TestConfig) ToClient(_ *telemetry.Telemetry, _ *clients.ClientConfig) (LangProvider, error) { + return NewProviderMock(nil, []RespMock{}), nil +} + +func (c *TestConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + type plain TestConfig // to avoid recursion + + return unmarshal((*plain)(c)) +} + // RespMock mocks a chat response or a streaming chat chunk type RespMock struct { Msg string @@ -124,7 +146,7 @@ func (c *ProviderMock) SupportChatStream() bool { func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResponse, error) { if c.chatResps == nil { - return nil, clients2.ErrProviderUnavailable + return nil, clients.ErrProviderUnavailable } responses := *c.chatResps @@ -139,9 +161,9 @@ func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas. return response.Resp(), nil } -func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) { +func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { if c.chatStreams == nil || c.idx >= len(*c.chatStreams) { - return nil, clients2.ErrProviderUnavailable + return nil, clients.ErrProviderUnavailable } streams := *c.chatStreams diff --git a/pkg/providers/testing/config.go b/pkg/providers/testing/config.go deleted file mode 100644 index dd7d0853..00000000 --- a/pkg/providers/testing/config.go +++ /dev/null @@ -1,29 +0,0 @@ -package testing - -import ( - "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/config/fields" - "github.com/EinStack/glide/pkg/provider" - "github.com/EinStack/glide/pkg/telemetry" -) - -const ( - ProviderTest = "testprovider" -) - -type Config struct { - BaseURL string `yaml:"base_url" json:"base_url" validate:"required"` - ChatEndpoint string `yaml:"chat_endpoint" json:"chat_endpoint" validate:"required"` - ModelName string `yaml:"model" json:"model" validate:"required"` - APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"` -} - -func (c *Config) ToClient(_ *telemetry.Telemetry, _ *clients.ClientConfig) (provider.LangProvider, error) { - return NewProviderMock(nil, []RespMock{}), nil -} - -func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { - type plain Config // to avoid recursion - - return unmarshal((*plain)(c)) -} diff --git a/pkg/routers/lang/config.go b/pkg/routers/lang/config.go index a8f3f6bb..8817ab5b 100644 --- a/pkg/routers/lang/config.go +++ b/pkg/routers/lang/config.go @@ -4,11 +4,12 @@ import ( "fmt" "time" + "github.com/EinStack/glide/pkg/extmodel" + "github.com/EinStack/glide/pkg/providers" "github.com/EinStack/glide/pkg/routers" - "github.com/EinStack/glide/pkg/models" "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/routers/routing" "github.com/EinStack/glide/pkg/telemetry" @@ -17,7 +18,7 @@ import ( ) type ( - ModelConfig = models.Config[*providers.DynLangProvider] + ModelConfig = extmodel.Config[*providers.Config] ModelPoolConfig = []ModelConfig ) @@ -50,12 +51,12 @@ func NewRouterConfig(RouterID string, opt ...RouterConfigOption) *RouterConfig { } // BuildModels creates LanguageModel slice out of the given config -func (c *RouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*models.LanguageModel, []*models.LanguageModel, error) { //nolint: cyclop +func (c *RouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*extmodel.LanguageModel, []*extmodel.LanguageModel, error) { //nolint: cyclop var errs error seenIDs := make(map[string]bool, len(c.Models)) - chatModels := make([]*models.LanguageModel, 0, len(c.Models)) - chatStreamModels := make([]*models.LanguageModel, 0, len(c.Models)) + chatModels := make([]*extmodel.LanguageModel, 0, len(c.Models)) + chatStreamModels := make([]*extmodel.LanguageModel, 0, len(c.Models)) for _, modelConfig := range c.Models { if _, ok := seenIDs[modelConfig.ID]; ok { @@ -159,11 +160,11 @@ func (c *RouterConfig) BuildRetry() *retry.ExpRetry { } func (c *RouterConfig) BuildRouting( - chatModels []*models.LanguageModel, - chatStreamModels []*models.LanguageModel, + chatModels []*extmodel.LanguageModel, + chatStreamModels []*extmodel.LanguageModel, ) (routing.LangModelRouting, routing.LangModelRouting, error) { - chatModelPool := make([]models.Model, 0, len(chatModels)) - chatStreamModelPool := make([]models.Model, 0, len(chatStreamModels)) + chatModelPool := make([]extmodel.Interface, 0, len(chatModels)) + chatStreamModelPool := make([]extmodel.Interface, 0, len(chatStreamModels)) for _, model := range chatModels { chatModelPool = append(chatModelPool, model) @@ -181,8 +182,8 @@ func (c *RouterConfig) BuildRouting( case routing.WeightedRoundRobin: return routing.NewWeightedRoundRobin(chatModelPool), routing.NewWeightedRoundRobin(chatStreamModelPool), nil case routing.LeastLatency: - return routing.NewLeastLatencyRouting(models.ChatLatency, chatModelPool), - routing.NewLeastLatencyRouting(models.ChatStreamLatency, chatStreamModelPool), + return routing.NewLeastLatencyRouting(extmodel.ChatLatency, chatModelPool), + routing.NewLeastLatencyRouting(extmodel.ChatStreamLatency, chatStreamModelPool), nil } diff --git a/pkg/routers/lang/config_test.go b/pkg/routers/lang/config_test.go index 1ed43363..1b5975bc 100644 --- a/pkg/routers/lang/config_test.go +++ b/pkg/routers/lang/config_test.go @@ -27,7 +27,7 @@ func TestRouterConfig_BuildModels(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -45,7 +45,7 @@ func TestRouterConfig_BuildModels(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -80,7 +80,7 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &openAIParams, @@ -93,7 +93,7 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ cohere.ProviderID: &cohere.Config{ APIKey: "ABC", DefaultParams: &cohereParams, @@ -129,7 +129,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -147,7 +147,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -170,7 +170,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -183,7 +183,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, diff --git a/pkg/routers/lang/router.go b/pkg/routers/lang/router.go index caf98a4c..9dcc0c33 100644 --- a/pkg/routers/lang/router.go +++ b/pkg/routers/lang/router.go @@ -4,8 +4,9 @@ import ( "context" "errors" + "github.com/EinStack/glide/pkg/extmodel" + "github.com/EinStack/glide/pkg/api/schemas" - "github.com/EinStack/glide/pkg/models" "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/routers/routing" "github.com/EinStack/glide/pkg/telemetry" @@ -19,8 +20,8 @@ type RouterID = string type Router struct { routerID RouterID Config *RouterConfig - chatModels []*models.LanguageModel - chatStreamModels []*models.LanguageModel + chatModels []*extmodel.LanguageModel + chatStreamModels []*extmodel.LanguageModel chatRouting routing.LangModelRouting chatStreamRouting routing.LangModelRouting retry *retry.ExpRetry @@ -76,7 +77,7 @@ func (r *Router) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.C break } - langModel := model.(models.LangModel) + langModel := model.(extmodel.LangModel) chatParams := req.Params(langModel.ID(), langModel.ModelName()) @@ -146,7 +147,7 @@ func (r *Router) ChatStream( break } - langModel := model.(models.LangModel) + langModel := model.(extmodel.LangModel) chatParams := req.Params(langModel.ID(), langModel.ModelName()) modelRespC, err := langModel.ChatStream(ctx, chatParams) diff --git a/pkg/routers/lang/router_test.go b/pkg/routers/lang/router_test.go index 65c9b2fe..ce6c28dc 100644 --- a/pkg/routers/lang/router_test.go +++ b/pkg/routers/lang/router_test.go @@ -5,9 +5,9 @@ import ( "testing" "time" - ptesting "github.com/EinStack/glide/pkg/providers/testing" + "github.com/EinStack/glide/pkg/providers" - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/resiliency/health" @@ -24,24 +24,24 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) { budget := health.NewErrorBudget(3, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Msg: "1"}, {Msg: "2"}}), budget, *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Msg: "1"}}), budget, *latConfig, 1, ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } @@ -70,31 +70,31 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) { func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { budget := health.NewErrorBudget(1, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "3"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "3"}}), budget, *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "4"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "4"}}), budget, *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "third", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Msg: "1"}, {Msg: "2"}}), budget, *latConfig, 1, ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } @@ -127,24 +127,24 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { budget := health.NewErrorBudget(1, health.MILLI) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "2"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "2"}}), budget, *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "1"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "1"}}), budget, *latConfig, 1, ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } @@ -170,24 +170,24 @@ func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) { budget := health.NewErrorBudget(1, health.MIN) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: clients.ErrProviderUnavailable}, {Msg: "3"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Err: clients.ErrProviderUnavailable}, {Msg: "3"}}), budget, *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Msg: "1"}, {Msg: "2"}}), budget, *latConfig, 1, ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } @@ -215,24 +215,24 @@ func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) { func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) { budget := health.NewErrorBudget(1, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), + providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), budget, *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), + providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), budget, *latConfig, 1, ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } @@ -257,11 +257,11 @@ func TestLangRouter_ChatStream(t *testing.T) { budget := health.NewErrorBudget(3, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ - ptesting.NewRespStreamMock(&[]ptesting.RespMock{ + providers.NewStreamProviderMock(nil, []providers.RespStreamMock{ + providers.NewRespStreamMock(&[]providers.RespMock{ {Msg: "Bill"}, {Msg: "Gates"}, {Msg: "entered"}, @@ -273,10 +273,10 @@ func TestLangRouter_ChatStream(t *testing.T) { *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ - ptesting.NewRespStreamMock(&[]ptesting.RespMock{ + providers.NewStreamProviderMock(nil, []providers.RespStreamMock{ + providers.NewRespStreamMock(&[]providers.RespMock{ {Msg: "Knock"}, {Msg: "Knock"}, {Msg: "joke"}, @@ -288,7 +288,7 @@ func TestLangRouter_ChatStream(t *testing.T) { ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } @@ -332,19 +332,19 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) { budget := health.NewErrorBudget(3, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewStreamProviderMock(nil, nil), + providers.NewStreamProviderMock(nil, nil), budget, *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ - ptesting.NewRespStreamMock( - &[]ptesting.RespMock{ + providers.NewStreamProviderMock(nil, []providers.RespStreamMock{ + providers.NewRespStreamMock( + &[]providers.RespMock{ {Msg: "Knock"}, {Msg: "knock"}, {Msg: "joke"}, @@ -357,7 +357,7 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) { ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } @@ -401,11 +401,11 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { budget := health.NewErrorBudget(1, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ - ptesting.NewRespStreamMock(&[]ptesting.RespMock{ + providers.NewStreamProviderMock(nil, []providers.RespStreamMock{ + providers.NewRespStreamMock(&[]providers.RespMock{ {Err: clients.ErrProviderUnavailable}, }), }), @@ -413,10 +413,10 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ - ptesting.NewRespStreamMock(&[]ptesting.RespMock{ + providers.NewStreamProviderMock(nil, []providers.RespStreamMock{ + providers.NewRespStreamMock(&[]providers.RespMock{ {Err: clients.ErrProviderUnavailable}, }), }), @@ -426,7 +426,7 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } diff --git a/pkg/routers/routing/least_latency.go b/pkg/routers/routing/least_latency.go index e6c56a6f..d34f45e2 100644 --- a/pkg/routers/routing/least_latency.go +++ b/pkg/routers/routing/least_latency.go @@ -5,7 +5,7 @@ import ( "sync/atomic" "time" - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" "github.com/EinStack/glide/pkg/routers/latency" ) @@ -15,16 +15,16 @@ const ( ) // LatencyGetter defines where to find latency for the specific model action -type LatencyGetter = func(model models.Model) *latency.MovingAverage +type LatencyGetter = func(model extmodel.Interface) *latency.MovingAverage // ModelSchedule defines latency update schedule for models type ModelSchedule struct { mu sync.RWMutex - model models.Model + model extmodel.Interface expireAt time.Time } -func NewSchedule(model models.Model) *ModelSchedule { +func NewSchedule(model extmodel.Interface) *ModelSchedule { schedule := &ModelSchedule{ model: model, } @@ -67,7 +67,7 @@ type LeastLatencyRouting struct { schedules []*ModelSchedule } -func NewLeastLatencyRouting(latencyGetter LatencyGetter, models []models.Model) *LeastLatencyRouting { +func NewLeastLatencyRouting(latencyGetter LatencyGetter, models []extmodel.Interface) *LeastLatencyRouting { schedules := make([]*ModelSchedule, 0, len(models)) for _, model := range models { @@ -95,7 +95,7 @@ func (r *LeastLatencyRouting) Iterator() LangModelIterator { // other model latencies that might have improved over time). // For that, we introduced expiration time after which the model receives a request // even if it was not the fastest to respond -func (r *LeastLatencyRouting) Next() (models.Model, error) { //nolint:cyclop +func (r *LeastLatencyRouting) Next() (extmodel.Interface, error) { //nolint:cyclop coldSchedules := r.getColdModelSchedules() if len(coldSchedules) > 0 { diff --git a/pkg/routers/routing/least_latency_test.go b/pkg/routers/routing/least_latency_test.go index 2f18b322..0e6618f2 100644 --- a/pkg/routers/routing/least_latency_test.go +++ b/pkg/routers/routing/least_latency_test.go @@ -5,9 +5,7 @@ import ( "testing" "time" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" "github.com/stretchr/testify/require" ) @@ -33,13 +31,13 @@ func TestLeastLatencyRouting_Warmup(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - modelPool := make([]models.Model, 0, len(tc.models)) + modelPool := make([]extmodel.Interface, 0, len(tc.models)) for _, model := range tc.models { - modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, model.latency, 1)) + modelPool = append(modelPool, extmodel.NewLangModelMock(model.modelID, model.healthy, model.latency, 1)) } - routing := NewLeastLatencyRouting(ptesting.ChatMockLatency, modelPool) + routing := NewLeastLatencyRouting(extmodel.ChatMockLatency, modelPool) iterator := routing.Iterator() // loop three times over the whole pool to check if we return back to the begging of the list @@ -107,7 +105,7 @@ func TestLeastLatencyRouting_Routing(t *testing.T) { for _, model := range tc.models { schedules = append(schedules, &ModelSchedule{ - model: ptesting.NewLangModelMock( + model: extmodel.NewLangModelMock( model.modelID, model.healthy, model.latency, @@ -118,7 +116,7 @@ func TestLeastLatencyRouting_Routing(t *testing.T) { } routing := LeastLatencyRouting{ - latencyGetter: ptesting.ChatMockLatency, + latencyGetter: extmodel.ChatMockLatency, schedules: schedules, } @@ -144,13 +142,13 @@ func TestLeastLatencyRouting_NoHealthyModels(t *testing.T) { for name, latencies := range tests { t.Run(name, func(t *testing.T) { - modelPool := make([]models.Model, 0, len(latencies)) + modelPool := make([]extmodel.Interface, 0, len(latencies)) for idx, latency := range latencies { - modelPool = append(modelPool, ptesting.NewLangModelMock(strconv.Itoa(idx), false, latency, 1)) + modelPool = append(modelPool, extmodel.NewLangModelMock(strconv.Itoa(idx), false, latency, 1)) } - routing := NewLeastLatencyRouting(models.ChatLatency, modelPool) + routing := NewLeastLatencyRouting(extmodel.ChatLatency, modelPool) iterator := routing.Iterator() _, err := iterator.Next() diff --git a/pkg/routers/routing/priority.go b/pkg/routers/routing/priority.go index 04d4d94e..7cf5ceeb 100644 --- a/pkg/routers/routing/priority.go +++ b/pkg/routers/routing/priority.go @@ -3,7 +3,7 @@ package routing import ( "sync/atomic" - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" ) const ( @@ -15,10 +15,10 @@ const ( // Priority of models are defined as position of the model on the list // (e.g. the first model definition has the highest priority, then the second model definition and so on) type PriorityRouting struct { - models []models.Model + models []extmodel.Interface } -func NewPriority(models []models.Model) *PriorityRouting { +func NewPriority(models []extmodel.Interface) *PriorityRouting { return &PriorityRouting{ models: models, } @@ -35,10 +35,10 @@ func (r *PriorityRouting) Iterator() LangModelIterator { type PriorityIterator struct { idx *atomic.Uint64 - models []models.Model + models []extmodel.Interface } -func (r PriorityIterator) Next() (models.Model, error) { +func (r PriorityIterator) Next() (extmodel.Interface, error) { modelPool := r.models for idx := int(r.idx.Load()); idx < len(modelPool); idx = int(r.idx.Add(1)) { diff --git a/pkg/routers/routing/priority_test.go b/pkg/routers/routing/priority_test.go index c0713f35..98e27e7d 100644 --- a/pkg/routers/routing/priority_test.go +++ b/pkg/routers/routing/priority_test.go @@ -3,9 +3,7 @@ package routing import ( "testing" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" "github.com/stretchr/testify/require" ) @@ -29,10 +27,10 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - modelPool := make([]models.Model, 0, len(tc.models)) + modelPool := make([]extmodel.Interface, 0, len(tc.models)) for _, model := range tc.models { - modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1)) + modelPool = append(modelPool, extmodel.NewLangModelMock(model.modelID, model.healthy, 100, 1)) } routing := NewPriority(modelPool) @@ -49,10 +47,10 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) { } func TestPriorityRouting_NoHealthyModels(t *testing.T) { - modelPool := []models.Model{ - ptesting.NewLangModelMock("first", false, 0, 1), - ptesting.NewLangModelMock("second", false, 0, 1), - ptesting.NewLangModelMock("third", false, 0, 1), + modelPool := []extmodel.Interface{ + extmodel.NewLangModelMock("first", false, 0, 1), + extmodel.NewLangModelMock("second", false, 0, 1), + extmodel.NewLangModelMock("third", false, 0, 1), } routing := NewPriority(modelPool) diff --git a/pkg/routers/routing/round_robin.go b/pkg/routers/routing/round_robin.go index abd2ff96..7582cbcb 100644 --- a/pkg/routers/routing/round_robin.go +++ b/pkg/routers/routing/round_robin.go @@ -3,7 +3,7 @@ package routing import ( "sync/atomic" - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" ) const ( @@ -13,10 +13,10 @@ const ( // RoundRobinRouting routes request to the next model in the list in cycle type RoundRobinRouting struct { idx atomic.Uint64 - models []models.Model + models []extmodel.Interface } -func NewRoundRobinRouting(models []models.Model) *RoundRobinRouting { +func NewRoundRobinRouting(models []extmodel.Interface) *RoundRobinRouting { return &RoundRobinRouting{ models: models, } @@ -26,7 +26,7 @@ func (r *RoundRobinRouting) Iterator() LangModelIterator { return r } -func (r *RoundRobinRouting) Next() (models.Model, error) { +func (r *RoundRobinRouting) Next() (extmodel.Interface, error) { modelLen := len(r.models) // in order to avoid infinite loop in case of no healthy model is available, diff --git a/pkg/routers/routing/round_robin_test.go b/pkg/routers/routing/round_robin_test.go index c2c6d307..7287f468 100644 --- a/pkg/routers/routing/round_robin_test.go +++ b/pkg/routers/routing/round_robin_test.go @@ -3,9 +3,7 @@ package routing import ( "testing" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" "github.com/stretchr/testify/require" ) @@ -30,10 +28,10 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - modelPool := make([]models.Model, 0, len(tc.models)) + modelPool := make([]extmodel.Interface, 0, len(tc.models)) for _, model := range tc.models { - modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1)) + modelPool = append(modelPool, extmodel.NewLangModelMock(model.modelID, model.healthy, 100, 1)) } routing := NewRoundRobinRouting(modelPool) @@ -52,10 +50,10 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) { } func TestRoundRobinRouting_NoHealthyModels(t *testing.T) { - modelPool := []models.Model{ - ptesting.NewLangModelMock("first", false, 0, 1), - ptesting.NewLangModelMock("second", false, 0, 1), - ptesting.NewLangModelMock("third", false, 0, 1), + modelPool := []extmodel.Interface{ + extmodel.NewLangModelMock("first", false, 0, 1), + extmodel.NewLangModelMock("second", false, 0, 1), + extmodel.NewLangModelMock("third", false, 0, 1), } routing := NewRoundRobinRouting(modelPool) diff --git a/pkg/routers/routing/strategies.go b/pkg/routers/routing/strategies.go index 960702a4..48d18ab6 100644 --- a/pkg/routers/routing/strategies.go +++ b/pkg/routers/routing/strategies.go @@ -3,7 +3,7 @@ package routing import ( "errors" - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" ) var ErrNoHealthyModels = errors.New("no healthy models found") @@ -16,5 +16,5 @@ type LangModelRouting interface { } type LangModelIterator interface { - Next() (models.Model, error) + Next() (extmodel.Interface, error) } diff --git a/pkg/routers/routing/weighted_round_robin.go b/pkg/routers/routing/weighted_round_robin.go index dfbee414..418add91 100644 --- a/pkg/routers/routing/weighted_round_robin.go +++ b/pkg/routers/routing/weighted_round_robin.go @@ -3,7 +3,7 @@ package routing import ( "sync" - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" ) const ( @@ -11,7 +11,7 @@ const ( ) type Weighter struct { - model models.Model + model extmodel.Interface currentWeight int } @@ -36,7 +36,7 @@ type WRoundRobinRouting struct { weights []*Weighter } -func NewWeightedRoundRobin(models []models.Model) *WRoundRobinRouting { +func NewWeightedRoundRobin(models []extmodel.Interface) *WRoundRobinRouting { weights := make([]*Weighter, 0, len(models)) for _, model := range models { @@ -55,7 +55,7 @@ func (r *WRoundRobinRouting) Iterator() LangModelIterator { return r } -func (r *WRoundRobinRouting) Next() (models.Model, error) { +func (r *WRoundRobinRouting) Next() (extmodel.Interface, error) { r.mu.Lock() defer r.mu.Unlock() diff --git a/pkg/routers/routing/weighted_round_robin_test.go b/pkg/routers/routing/weighted_round_robin_test.go index e24d4e81..7ec9b24c 100644 --- a/pkg/routers/routing/weighted_round_robin_test.go +++ b/pkg/routers/routing/weighted_round_robin_test.go @@ -3,9 +3,7 @@ package routing import ( "testing" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" "github.com/stretchr/testify/require" ) @@ -116,10 +114,10 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - modelPool := make([]models.Model, 0, len(tc.models)) + modelPool := make([]extmodel.Interface, 0, len(tc.models)) for _, model := range tc.models { - modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, 0, model.weight)) + modelPool = append(modelPool, extmodel.NewLangModelMock(model.modelID, model.healthy, 0, model.weight)) } routing := NewWeightedRoundRobin(modelPool) @@ -142,10 +140,10 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) { } func TestWRoundRobinRouting_NoHealthyModels(t *testing.T) { - modelPool := []models.Model{ - ptesting.NewLangModelMock("first", false, 0, 1), - ptesting.NewLangModelMock("second", false, 0, 2), - ptesting.NewLangModelMock("third", false, 0, 3), + modelPool := []extmodel.Interface{ + extmodel.NewLangModelMock("first", false, 0, 1), + extmodel.NewLangModelMock("second", false, 0, 2), + extmodel.NewLangModelMock("third", false, 0, 3), } routing := NewWeightedRoundRobin(modelPool) From 5ad821c2c2c16927dd082aadd14b6711c097b31d Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Thu, 8 Aug 2024 18:38:09 +0300 Subject: [PATCH 09/18] #67: Renamed providers package to just provider --- pkg/extmodel/config.go | 8 ++-- pkg/extmodel/lang.go | 6 +-- pkg/{providers => provider}/anthropic/chat.go | 0 .../anthropic/chat_stream.go | 0 .../anthropic/client.go | 0 .../anthropic/client_test.go | 0 .../anthropic/config.go | 0 .../anthropic/errors.go | 0 .../anthropic/schamas.go | 0 .../anthropic/testdata/chat.req.json | 0 .../anthropic/testdata/chat.success.json | 0 .../azureopenai/chat.go | 2 +- .../azureopenai/chat_stream.go | 2 +- .../azureopenai/chat_stream_test.go | 0 .../azureopenai/client.go | 2 +- .../azureopenai/client_test.go | 0 .../azureopenai/config.go | 0 .../azureopenai/errors.go | 0 .../azureopenai/schemas.go | 0 .../azureopenai/testdata/chat.req.json | 0 .../azureopenai/testdata/chat.success.json | 0 .../testdata/chat_stream.empty.txt | 0 .../testdata/chat_stream.nodone.txt | 0 .../testdata/chat_stream.success.txt | 0 pkg/{providers => provider}/bedrock/chat.go | 0 .../bedrock/chat_stream.go | 0 pkg/{providers => provider}/bedrock/client.go | 0 .../bedrock/client_test.go | 0 pkg/{providers => provider}/bedrock/config.go | 0 .../bedrock/schemas.go | 0 .../bedrock/testdata/chat.req.json | 0 .../bedrock/testdata/chat.success.json | 0 pkg/{providers => provider}/cohere/chat.go | 0 .../cohere/chat_stream.go | 0 .../cohere/chat_stream_test.go | 0 pkg/{providers => provider}/cohere/client.go | 0 .../cohere/client_test.go | 0 pkg/{providers => provider}/cohere/config.go | 0 pkg/{providers => provider}/cohere/errors.go | 0 .../cohere/finish_reason.go | 0 pkg/{providers => provider}/cohere/schemas.go | 0 .../cohere/stream_reader.go | 0 .../cohere/testdata/chat.req.json | 0 .../cohere/testdata/chat.success.json | 0 .../testdata/chat_stream.interrupted.txt | 0 .../cohere/testdata/chat_stream.success.txt | 0 pkg/{providers => provider}/config.go | 6 +-- pkg/{providers => provider}/config_test.go | 2 +- pkg/{providers => provider}/config_test.yaml | 0 pkg/{providers => provider}/interface.go | 6 +-- pkg/{providers => provider}/octoml/chat.go | 2 +- .../octoml/chat_stream.go | 0 pkg/{providers => provider}/octoml/client.go | 0 .../octoml/client_test.go | 0 pkg/{providers => provider}/octoml/config.go | 0 pkg/{providers => provider}/octoml/errors.go | 0 .../octoml/testdata/chat.req.json | 0 .../octoml/testdata/chat.success.json | 0 pkg/{providers => provider}/ollama/chat.go | 0 .../ollama/chat_stream.go | 0 pkg/{providers => provider}/ollama/client.go | 0 .../ollama/client_test.go | 0 pkg/{providers => provider}/ollama/config.go | 0 pkg/{providers => provider}/ollama/schemas.go | 0 .../ollama/testdata/chat.req.json | 0 .../ollama/testdata/chat.success.json | 0 pkg/{providers => provider}/openai/chat.go | 0 .../openai/chat_stream.go | 0 .../openai/chat_stream_test.go | 0 .../openai/chat_test.go | 0 pkg/{providers => provider}/openai/client.go | 0 pkg/{providers => provider}/openai/config.go | 6 +-- pkg/{providers => provider}/openai/embed.go | 0 pkg/{providers => provider}/openai/errors.go | 0 .../openai/finish_reasons.go | 0 pkg/provider/openai/register.go | 7 +++ pkg/{providers => provider}/openai/schemas.go | 0 .../openai/testdata/chat.req.json | 0 .../openai/testdata/chat.success.json | 0 .../openai/testdata/chat_stream.empty.txt | 0 .../openai/testdata/chat_stream.nodone.txt | 0 .../openai/testdata/chat_stream.success.txt | 0 pkg/provider/registry.go | 41 ++++++++++++++++ pkg/{providers => provider}/testing.go | 26 +++++----- pkg/providers/openai/register.go | 9 ---- pkg/providers/registry.go | 41 ---------------- pkg/routers/lang/config.go | 6 +-- pkg/routers/lang/config_test.go | 23 ++++----- pkg/routers/lang/router_test.go | 48 +++++++++---------- 89 files changed, 121 insertions(+), 122 deletions(-) rename pkg/{providers => provider}/anthropic/chat.go (100%) rename pkg/{providers => provider}/anthropic/chat_stream.go (100%) rename pkg/{providers => provider}/anthropic/client.go (100%) rename pkg/{providers => provider}/anthropic/client_test.go (100%) rename pkg/{providers => provider}/anthropic/config.go (100%) rename pkg/{providers => provider}/anthropic/errors.go (100%) rename pkg/{providers => provider}/anthropic/schamas.go (100%) rename pkg/{providers => provider}/anthropic/testdata/chat.req.json (100%) rename pkg/{providers => provider}/anthropic/testdata/chat.success.json (100%) rename pkg/{providers => provider}/azureopenai/chat.go (98%) rename pkg/{providers => provider}/azureopenai/chat_stream.go (99%) rename pkg/{providers => provider}/azureopenai/chat_stream_test.go (100%) rename pkg/{providers => provider}/azureopenai/client.go (97%) rename pkg/{providers => provider}/azureopenai/client_test.go (100%) rename pkg/{providers => provider}/azureopenai/config.go (100%) rename pkg/{providers => provider}/azureopenai/errors.go (100%) rename pkg/{providers => provider}/azureopenai/schemas.go (100%) rename pkg/{providers => provider}/azureopenai/testdata/chat.req.json (100%) rename pkg/{providers => provider}/azureopenai/testdata/chat.success.json (100%) rename pkg/{providers => provider}/azureopenai/testdata/chat_stream.empty.txt (100%) rename pkg/{providers => provider}/azureopenai/testdata/chat_stream.nodone.txt (100%) rename pkg/{providers => provider}/azureopenai/testdata/chat_stream.success.txt (100%) rename pkg/{providers => provider}/bedrock/chat.go (100%) rename pkg/{providers => provider}/bedrock/chat_stream.go (100%) rename pkg/{providers => provider}/bedrock/client.go (100%) rename pkg/{providers => provider}/bedrock/client_test.go (100%) rename pkg/{providers => provider}/bedrock/config.go (100%) rename pkg/{providers => provider}/bedrock/schemas.go (100%) rename pkg/{providers => provider}/bedrock/testdata/chat.req.json (100%) rename pkg/{providers => provider}/bedrock/testdata/chat.success.json (100%) rename pkg/{providers => provider}/cohere/chat.go (100%) rename pkg/{providers => provider}/cohere/chat_stream.go (100%) rename pkg/{providers => provider}/cohere/chat_stream_test.go (100%) rename pkg/{providers => provider}/cohere/client.go (100%) rename pkg/{providers => provider}/cohere/client_test.go (100%) rename pkg/{providers => provider}/cohere/config.go (100%) rename pkg/{providers => provider}/cohere/errors.go (100%) rename pkg/{providers => provider}/cohere/finish_reason.go (100%) rename pkg/{providers => provider}/cohere/schemas.go (100%) rename pkg/{providers => provider}/cohere/stream_reader.go (100%) rename pkg/{providers => provider}/cohere/testdata/chat.req.json (100%) rename pkg/{providers => provider}/cohere/testdata/chat.success.json (100%) rename pkg/{providers => provider}/cohere/testdata/chat_stream.interrupted.txt (100%) rename pkg/{providers => provider}/cohere/testdata/chat_stream.success.txt (100%) rename pkg/{providers => provider}/config.go (96%) rename pkg/{providers => provider}/config_test.go (96%) rename pkg/{providers => provider}/config_test.yaml (100%) rename pkg/{providers => provider}/interface.go (93%) rename pkg/{providers => provider}/octoml/chat.go (98%) rename pkg/{providers => provider}/octoml/chat_stream.go (100%) rename pkg/{providers => provider}/octoml/client.go (100%) rename pkg/{providers => provider}/octoml/client_test.go (100%) rename pkg/{providers => provider}/octoml/config.go (100%) rename pkg/{providers => provider}/octoml/errors.go (100%) rename pkg/{providers => provider}/octoml/testdata/chat.req.json (100%) rename pkg/{providers => provider}/octoml/testdata/chat.success.json (100%) rename pkg/{providers => provider}/ollama/chat.go (100%) rename pkg/{providers => provider}/ollama/chat_stream.go (100%) rename pkg/{providers => provider}/ollama/client.go (100%) rename pkg/{providers => provider}/ollama/client_test.go (100%) rename pkg/{providers => provider}/ollama/config.go (100%) rename pkg/{providers => provider}/ollama/schemas.go (100%) rename pkg/{providers => provider}/ollama/testdata/chat.req.json (100%) rename pkg/{providers => provider}/ollama/testdata/chat.success.json (100%) rename pkg/{providers => provider}/openai/chat.go (100%) rename pkg/{providers => provider}/openai/chat_stream.go (100%) rename pkg/{providers => provider}/openai/chat_stream_test.go (100%) rename pkg/{providers => provider}/openai/chat_test.go (100%) rename pkg/{providers => provider}/openai/client.go (100%) rename pkg/{providers => provider}/openai/config.go (94%) rename pkg/{providers => provider}/openai/embed.go (100%) rename pkg/{providers => provider}/openai/errors.go (100%) rename pkg/{providers => provider}/openai/finish_reasons.go (100%) create mode 100644 pkg/provider/openai/register.go rename pkg/{providers => provider}/openai/schemas.go (100%) rename pkg/{providers => provider}/openai/testdata/chat.req.json (100%) rename pkg/{providers => provider}/openai/testdata/chat.success.json (100%) rename pkg/{providers => provider}/openai/testdata/chat_stream.empty.txt (100%) rename pkg/{providers => provider}/openai/testdata/chat_stream.nodone.txt (100%) rename pkg/{providers => provider}/openai/testdata/chat_stream.success.txt (100%) create mode 100644 pkg/provider/registry.go rename pkg/{providers => provider}/testing.go (83%) delete mode 100644 pkg/providers/openai/register.go delete mode 100644 pkg/providers/registry.go diff --git a/pkg/extmodel/config.go b/pkg/extmodel/config.go index 2266fcaa..5c1bfa91 100644 --- a/pkg/extmodel/config.go +++ b/pkg/extmodel/config.go @@ -3,7 +3,7 @@ package extmodel import ( "fmt" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/provider" "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/resiliency/health" @@ -12,7 +12,7 @@ import ( ) // Config defines an extra configuration for a model wrapper around a provider -type Config[P providers.Configurer] struct { +type Config[P provider.Configurer] struct { ID string `yaml:"id" json:"id" validate:"required"` // Model instance ID (unique in scope of the router) Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is the model enabled? ErrorBudget *health.ErrorBudget `yaml:"error_budget" json:"error_budget" swaggertype:"primitive,string"` @@ -23,7 +23,7 @@ type Config[P providers.Configurer] struct { Provider P `yaml:"provider" json:"provider"` } -func NewConfig[P providers.Configurer](ID string) *Config[P] { +func NewConfig[P provider.Configurer](ID string) *Config[P] { config := DefaultConfig[P]() config.ID = ID @@ -31,7 +31,7 @@ func NewConfig[P providers.Configurer](ID string) *Config[P] { return &config } -func DefaultConfig[P providers.Configurer]() Config[P] { +func DefaultConfig[P provider.Configurer]() Config[P] { return Config[P]{ Enabled: true, Client: clients.DefaultClientConfig(), diff --git a/pkg/extmodel/lang.go b/pkg/extmodel/lang.go index e3243cb2..6b345112 100644 --- a/pkg/extmodel/lang.go +++ b/pkg/extmodel/lang.go @@ -5,7 +5,7 @@ import ( "io" "time" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/provider" "github.com/EinStack/glide/pkg/clients" health2 "github.com/EinStack/glide/pkg/resiliency/health" @@ -32,14 +32,14 @@ type LangModel interface { type LanguageModel struct { modelID string weight int - client providers.LangProvider + client provider.LangProvider healthTracker *health2.Tracker chatLatency *latency.MovingAverage chatStreamLatency *latency.MovingAverage latencyUpdateInterval *fields.Duration } -func NewLangModel(modelID string, client providers.LangProvider, budget *health2.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel { +func NewLangModel(modelID string, client provider.LangProvider, budget *health2.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel { return &LanguageModel{ modelID: modelID, client: client, diff --git a/pkg/providers/anthropic/chat.go b/pkg/provider/anthropic/chat.go similarity index 100% rename from pkg/providers/anthropic/chat.go rename to pkg/provider/anthropic/chat.go diff --git a/pkg/providers/anthropic/chat_stream.go b/pkg/provider/anthropic/chat_stream.go similarity index 100% rename from pkg/providers/anthropic/chat_stream.go rename to pkg/provider/anthropic/chat_stream.go diff --git a/pkg/providers/anthropic/client.go b/pkg/provider/anthropic/client.go similarity index 100% rename from pkg/providers/anthropic/client.go rename to pkg/provider/anthropic/client.go diff --git a/pkg/providers/anthropic/client_test.go b/pkg/provider/anthropic/client_test.go similarity index 100% rename from pkg/providers/anthropic/client_test.go rename to pkg/provider/anthropic/client_test.go diff --git a/pkg/providers/anthropic/config.go b/pkg/provider/anthropic/config.go similarity index 100% rename from pkg/providers/anthropic/config.go rename to pkg/provider/anthropic/config.go diff --git a/pkg/providers/anthropic/errors.go b/pkg/provider/anthropic/errors.go similarity index 100% rename from pkg/providers/anthropic/errors.go rename to pkg/provider/anthropic/errors.go diff --git a/pkg/providers/anthropic/schamas.go b/pkg/provider/anthropic/schamas.go similarity index 100% rename from pkg/providers/anthropic/schamas.go rename to pkg/provider/anthropic/schamas.go diff --git a/pkg/providers/anthropic/testdata/chat.req.json b/pkg/provider/anthropic/testdata/chat.req.json similarity index 100% rename from pkg/providers/anthropic/testdata/chat.req.json rename to pkg/provider/anthropic/testdata/chat.req.json diff --git a/pkg/providers/anthropic/testdata/chat.success.json b/pkg/provider/anthropic/testdata/chat.success.json similarity index 100% rename from pkg/providers/anthropic/testdata/chat.success.json rename to pkg/provider/anthropic/testdata/chat.success.json diff --git a/pkg/providers/azureopenai/chat.go b/pkg/provider/azureopenai/chat.go similarity index 98% rename from pkg/providers/azureopenai/chat.go rename to pkg/provider/azureopenai/chat.go index 2c62dc0f..86aab1f2 100644 --- a/pkg/providers/azureopenai/chat.go +++ b/pkg/provider/azureopenai/chat.go @@ -10,7 +10,7 @@ import ( "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/providers/openai" + "github.com/EinStack/glide/pkg/provider/openai" "github.com/EinStack/glide/pkg/api/schemas" diff --git a/pkg/providers/azureopenai/chat_stream.go b/pkg/provider/azureopenai/chat_stream.go similarity index 99% rename from pkg/providers/azureopenai/chat_stream.go rename to pkg/provider/azureopenai/chat_stream.go index dfa787c4..7c0f5b2c 100644 --- a/pkg/providers/azureopenai/chat_stream.go +++ b/pkg/provider/azureopenai/chat_stream.go @@ -12,7 +12,7 @@ import ( "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/openai" + "github.com/EinStack/glide/pkg/provider/openai" "github.com/r3labs/sse/v2" diff --git a/pkg/providers/azureopenai/chat_stream_test.go b/pkg/provider/azureopenai/chat_stream_test.go similarity index 100% rename from pkg/providers/azureopenai/chat_stream_test.go rename to pkg/provider/azureopenai/chat_stream_test.go diff --git a/pkg/providers/azureopenai/client.go b/pkg/provider/azureopenai/client.go similarity index 97% rename from pkg/providers/azureopenai/client.go rename to pkg/provider/azureopenai/client.go index 88c5b64d..5c34a154 100644 --- a/pkg/providers/azureopenai/client.go +++ b/pkg/provider/azureopenai/client.go @@ -7,7 +7,7 @@ import ( "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/providers/openai" + "github.com/EinStack/glide/pkg/provider/openai" "github.com/EinStack/glide/pkg/telemetry" ) diff --git a/pkg/providers/azureopenai/client_test.go b/pkg/provider/azureopenai/client_test.go similarity index 100% rename from pkg/providers/azureopenai/client_test.go rename to pkg/provider/azureopenai/client_test.go diff --git a/pkg/providers/azureopenai/config.go b/pkg/provider/azureopenai/config.go similarity index 100% rename from pkg/providers/azureopenai/config.go rename to pkg/provider/azureopenai/config.go diff --git a/pkg/providers/azureopenai/errors.go b/pkg/provider/azureopenai/errors.go similarity index 100% rename from pkg/providers/azureopenai/errors.go rename to pkg/provider/azureopenai/errors.go diff --git a/pkg/providers/azureopenai/schemas.go b/pkg/provider/azureopenai/schemas.go similarity index 100% rename from pkg/providers/azureopenai/schemas.go rename to pkg/provider/azureopenai/schemas.go diff --git a/pkg/providers/azureopenai/testdata/chat.req.json b/pkg/provider/azureopenai/testdata/chat.req.json similarity index 100% rename from pkg/providers/azureopenai/testdata/chat.req.json rename to pkg/provider/azureopenai/testdata/chat.req.json diff --git a/pkg/providers/azureopenai/testdata/chat.success.json b/pkg/provider/azureopenai/testdata/chat.success.json similarity index 100% rename from pkg/providers/azureopenai/testdata/chat.success.json rename to pkg/provider/azureopenai/testdata/chat.success.json diff --git a/pkg/providers/azureopenai/testdata/chat_stream.empty.txt b/pkg/provider/azureopenai/testdata/chat_stream.empty.txt similarity index 100% rename from pkg/providers/azureopenai/testdata/chat_stream.empty.txt rename to pkg/provider/azureopenai/testdata/chat_stream.empty.txt diff --git a/pkg/providers/azureopenai/testdata/chat_stream.nodone.txt b/pkg/provider/azureopenai/testdata/chat_stream.nodone.txt similarity index 100% rename from pkg/providers/azureopenai/testdata/chat_stream.nodone.txt rename to pkg/provider/azureopenai/testdata/chat_stream.nodone.txt diff --git a/pkg/providers/azureopenai/testdata/chat_stream.success.txt b/pkg/provider/azureopenai/testdata/chat_stream.success.txt similarity index 100% rename from pkg/providers/azureopenai/testdata/chat_stream.success.txt rename to pkg/provider/azureopenai/testdata/chat_stream.success.txt diff --git a/pkg/providers/bedrock/chat.go b/pkg/provider/bedrock/chat.go similarity index 100% rename from pkg/providers/bedrock/chat.go rename to pkg/provider/bedrock/chat.go diff --git a/pkg/providers/bedrock/chat_stream.go b/pkg/provider/bedrock/chat_stream.go similarity index 100% rename from pkg/providers/bedrock/chat_stream.go rename to pkg/provider/bedrock/chat_stream.go diff --git a/pkg/providers/bedrock/client.go b/pkg/provider/bedrock/client.go similarity index 100% rename from pkg/providers/bedrock/client.go rename to pkg/provider/bedrock/client.go diff --git a/pkg/providers/bedrock/client_test.go b/pkg/provider/bedrock/client_test.go similarity index 100% rename from pkg/providers/bedrock/client_test.go rename to pkg/provider/bedrock/client_test.go diff --git a/pkg/providers/bedrock/config.go b/pkg/provider/bedrock/config.go similarity index 100% rename from pkg/providers/bedrock/config.go rename to pkg/provider/bedrock/config.go diff --git a/pkg/providers/bedrock/schemas.go b/pkg/provider/bedrock/schemas.go similarity index 100% rename from pkg/providers/bedrock/schemas.go rename to pkg/provider/bedrock/schemas.go diff --git a/pkg/providers/bedrock/testdata/chat.req.json b/pkg/provider/bedrock/testdata/chat.req.json similarity index 100% rename from pkg/providers/bedrock/testdata/chat.req.json rename to pkg/provider/bedrock/testdata/chat.req.json diff --git a/pkg/providers/bedrock/testdata/chat.success.json b/pkg/provider/bedrock/testdata/chat.success.json similarity index 100% rename from pkg/providers/bedrock/testdata/chat.success.json rename to pkg/provider/bedrock/testdata/chat.success.json diff --git a/pkg/providers/cohere/chat.go b/pkg/provider/cohere/chat.go similarity index 100% rename from pkg/providers/cohere/chat.go rename to pkg/provider/cohere/chat.go diff --git a/pkg/providers/cohere/chat_stream.go b/pkg/provider/cohere/chat_stream.go similarity index 100% rename from pkg/providers/cohere/chat_stream.go rename to pkg/provider/cohere/chat_stream.go diff --git a/pkg/providers/cohere/chat_stream_test.go b/pkg/provider/cohere/chat_stream_test.go similarity index 100% rename from pkg/providers/cohere/chat_stream_test.go rename to pkg/provider/cohere/chat_stream_test.go diff --git a/pkg/providers/cohere/client.go b/pkg/provider/cohere/client.go similarity index 100% rename from pkg/providers/cohere/client.go rename to pkg/provider/cohere/client.go diff --git a/pkg/providers/cohere/client_test.go b/pkg/provider/cohere/client_test.go similarity index 100% rename from pkg/providers/cohere/client_test.go rename to pkg/provider/cohere/client_test.go diff --git a/pkg/providers/cohere/config.go b/pkg/provider/cohere/config.go similarity index 100% rename from pkg/providers/cohere/config.go rename to pkg/provider/cohere/config.go diff --git a/pkg/providers/cohere/errors.go b/pkg/provider/cohere/errors.go similarity index 100% rename from pkg/providers/cohere/errors.go rename to pkg/provider/cohere/errors.go diff --git a/pkg/providers/cohere/finish_reason.go b/pkg/provider/cohere/finish_reason.go similarity index 100% rename from pkg/providers/cohere/finish_reason.go rename to pkg/provider/cohere/finish_reason.go diff --git a/pkg/providers/cohere/schemas.go b/pkg/provider/cohere/schemas.go similarity index 100% rename from pkg/providers/cohere/schemas.go rename to pkg/provider/cohere/schemas.go diff --git a/pkg/providers/cohere/stream_reader.go b/pkg/provider/cohere/stream_reader.go similarity index 100% rename from pkg/providers/cohere/stream_reader.go rename to pkg/provider/cohere/stream_reader.go diff --git a/pkg/providers/cohere/testdata/chat.req.json b/pkg/provider/cohere/testdata/chat.req.json similarity index 100% rename from pkg/providers/cohere/testdata/chat.req.json rename to pkg/provider/cohere/testdata/chat.req.json diff --git a/pkg/providers/cohere/testdata/chat.success.json b/pkg/provider/cohere/testdata/chat.success.json similarity index 100% rename from pkg/providers/cohere/testdata/chat.success.json rename to pkg/provider/cohere/testdata/chat.success.json diff --git a/pkg/providers/cohere/testdata/chat_stream.interrupted.txt b/pkg/provider/cohere/testdata/chat_stream.interrupted.txt similarity index 100% rename from pkg/providers/cohere/testdata/chat_stream.interrupted.txt rename to pkg/provider/cohere/testdata/chat_stream.interrupted.txt diff --git a/pkg/providers/cohere/testdata/chat_stream.success.txt b/pkg/provider/cohere/testdata/chat_stream.success.txt similarity index 100% rename from pkg/providers/cohere/testdata/chat_stream.success.txt rename to pkg/provider/cohere/testdata/chat_stream.success.txt diff --git a/pkg/providers/config.go b/pkg/provider/config.go similarity index 96% rename from pkg/providers/config.go rename to pkg/provider/config.go index 6469710d..23633e0b 100644 --- a/pkg/providers/config.go +++ b/pkg/provider/config.go @@ -1,4 +1,4 @@ -package providers +package provider import ( "errors" @@ -27,7 +27,7 @@ type Configurer interface { ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) } -type Config map[ProviderID]interface{} +type Config map[ID]interface{} var _ Configurer = (*Config)(nil) @@ -69,7 +69,7 @@ func (p Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientC // validate ensure there is only one provider configured and it's supported by Glide func (p Config) validate() error { - configuredProviders := make([]ProviderID, 0, len(p)) + configuredProviders := make([]ID, 0, len(p)) for providerID, config := range p { if config != nil { diff --git a/pkg/providers/config_test.go b/pkg/provider/config_test.go similarity index 96% rename from pkg/providers/config_test.go rename to pkg/provider/config_test.go index 7e1d18c8..451f92ea 100644 --- a/pkg/providers/config_test.go +++ b/pkg/provider/config_test.go @@ -1,4 +1,4 @@ -package providers +package provider import ( "os" diff --git a/pkg/providers/config_test.yaml b/pkg/provider/config_test.yaml similarity index 100% rename from pkg/providers/config_test.yaml rename to pkg/provider/config_test.yaml diff --git a/pkg/providers/interface.go b/pkg/provider/interface.go similarity index 93% rename from pkg/providers/interface.go rename to pkg/provider/interface.go index 0b9fe45b..b2e4ffbd 100644 --- a/pkg/providers/interface.go +++ b/pkg/provider/interface.go @@ -1,4 +1,4 @@ -package providers +package provider import ( "context" @@ -10,11 +10,11 @@ import ( var ErrProviderNotFound = errors.New("provider not found") -type ProviderID = string +type ID = string // ModelProvider exposes provider context type ModelProvider interface { - Provider() ProviderID + Provider() ID ModelName() string } diff --git a/pkg/providers/octoml/chat.go b/pkg/provider/octoml/chat.go similarity index 98% rename from pkg/providers/octoml/chat.go rename to pkg/provider/octoml/chat.go index 92f20fbf..9b2237f3 100644 --- a/pkg/providers/octoml/chat.go +++ b/pkg/provider/octoml/chat.go @@ -8,7 +8,7 @@ import ( "io" "net/http" - "github.com/EinStack/glide/pkg/providers/openai" + "github.com/EinStack/glide/pkg/provider/openai" "github.com/EinStack/glide/pkg/api/schemas" diff --git a/pkg/providers/octoml/chat_stream.go b/pkg/provider/octoml/chat_stream.go similarity index 100% rename from pkg/providers/octoml/chat_stream.go rename to pkg/provider/octoml/chat_stream.go diff --git a/pkg/providers/octoml/client.go b/pkg/provider/octoml/client.go similarity index 100% rename from pkg/providers/octoml/client.go rename to pkg/provider/octoml/client.go diff --git a/pkg/providers/octoml/client_test.go b/pkg/provider/octoml/client_test.go similarity index 100% rename from pkg/providers/octoml/client_test.go rename to pkg/provider/octoml/client_test.go diff --git a/pkg/providers/octoml/config.go b/pkg/provider/octoml/config.go similarity index 100% rename from pkg/providers/octoml/config.go rename to pkg/provider/octoml/config.go diff --git a/pkg/providers/octoml/errors.go b/pkg/provider/octoml/errors.go similarity index 100% rename from pkg/providers/octoml/errors.go rename to pkg/provider/octoml/errors.go diff --git a/pkg/providers/octoml/testdata/chat.req.json b/pkg/provider/octoml/testdata/chat.req.json similarity index 100% rename from pkg/providers/octoml/testdata/chat.req.json rename to pkg/provider/octoml/testdata/chat.req.json diff --git a/pkg/providers/octoml/testdata/chat.success.json b/pkg/provider/octoml/testdata/chat.success.json similarity index 100% rename from pkg/providers/octoml/testdata/chat.success.json rename to pkg/provider/octoml/testdata/chat.success.json diff --git a/pkg/providers/ollama/chat.go b/pkg/provider/ollama/chat.go similarity index 100% rename from pkg/providers/ollama/chat.go rename to pkg/provider/ollama/chat.go diff --git a/pkg/providers/ollama/chat_stream.go b/pkg/provider/ollama/chat_stream.go similarity index 100% rename from pkg/providers/ollama/chat_stream.go rename to pkg/provider/ollama/chat_stream.go diff --git a/pkg/providers/ollama/client.go b/pkg/provider/ollama/client.go similarity index 100% rename from pkg/providers/ollama/client.go rename to pkg/provider/ollama/client.go diff --git a/pkg/providers/ollama/client_test.go b/pkg/provider/ollama/client_test.go similarity index 100% rename from pkg/providers/ollama/client_test.go rename to pkg/provider/ollama/client_test.go diff --git a/pkg/providers/ollama/config.go b/pkg/provider/ollama/config.go similarity index 100% rename from pkg/providers/ollama/config.go rename to pkg/provider/ollama/config.go diff --git a/pkg/providers/ollama/schemas.go b/pkg/provider/ollama/schemas.go similarity index 100% rename from pkg/providers/ollama/schemas.go rename to pkg/provider/ollama/schemas.go diff --git a/pkg/providers/ollama/testdata/chat.req.json b/pkg/provider/ollama/testdata/chat.req.json similarity index 100% rename from pkg/providers/ollama/testdata/chat.req.json rename to pkg/provider/ollama/testdata/chat.req.json diff --git a/pkg/providers/ollama/testdata/chat.success.json b/pkg/provider/ollama/testdata/chat.success.json similarity index 100% rename from pkg/providers/ollama/testdata/chat.success.json rename to pkg/provider/ollama/testdata/chat.success.json diff --git a/pkg/providers/openai/chat.go b/pkg/provider/openai/chat.go similarity index 100% rename from pkg/providers/openai/chat.go rename to pkg/provider/openai/chat.go diff --git a/pkg/providers/openai/chat_stream.go b/pkg/provider/openai/chat_stream.go similarity index 100% rename from pkg/providers/openai/chat_stream.go rename to pkg/provider/openai/chat_stream.go diff --git a/pkg/providers/openai/chat_stream_test.go b/pkg/provider/openai/chat_stream_test.go similarity index 100% rename from pkg/providers/openai/chat_stream_test.go rename to pkg/provider/openai/chat_stream_test.go diff --git a/pkg/providers/openai/chat_test.go b/pkg/provider/openai/chat_test.go similarity index 100% rename from pkg/providers/openai/chat_test.go rename to pkg/provider/openai/chat_test.go diff --git a/pkg/providers/openai/client.go b/pkg/provider/openai/client.go similarity index 100% rename from pkg/providers/openai/client.go rename to pkg/provider/openai/client.go diff --git a/pkg/providers/openai/config.go b/pkg/provider/openai/config.go similarity index 94% rename from pkg/providers/openai/config.go rename to pkg/provider/openai/config.go index fee9a589..8bf383d8 100644 --- a/pkg/providers/openai/config.go +++ b/pkg/provider/openai/config.go @@ -3,7 +3,7 @@ package openai import ( "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/config/fields" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/provider" "github.com/EinStack/glide/pkg/telemetry" ) @@ -52,7 +52,7 @@ type Config struct { DefaultParams *Params `yaml:"default_params,omitempty" json:"default_params"` } -var _ providers.Configurer = (*Config)(nil) +var _ provider.Configurer = (*Config)(nil) // DefaultConfig for OpenAI models func DefaultConfig() *Config { @@ -66,7 +66,7 @@ func DefaultConfig() *Config { } } -func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (providers.LangProvider, error) { +func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) { return NewClient(c, clientConfig, tel) } diff --git a/pkg/providers/openai/embed.go b/pkg/provider/openai/embed.go similarity index 100% rename from pkg/providers/openai/embed.go rename to pkg/provider/openai/embed.go diff --git a/pkg/providers/openai/errors.go b/pkg/provider/openai/errors.go similarity index 100% rename from pkg/providers/openai/errors.go rename to pkg/provider/openai/errors.go diff --git a/pkg/providers/openai/finish_reasons.go b/pkg/provider/openai/finish_reasons.go similarity index 100% rename from pkg/providers/openai/finish_reasons.go rename to pkg/provider/openai/finish_reasons.go diff --git a/pkg/provider/openai/register.go b/pkg/provider/openai/register.go new file mode 100644 index 00000000..b79b77b1 --- /dev/null +++ b/pkg/provider/openai/register.go @@ -0,0 +1,7 @@ +package openai + +import "github.com/EinStack/glide/pkg/provider" + +func init() { + provider.LangRegistry.Register(ProviderID, &Config{}) +} diff --git a/pkg/providers/openai/schemas.go b/pkg/provider/openai/schemas.go similarity index 100% rename from pkg/providers/openai/schemas.go rename to pkg/provider/openai/schemas.go diff --git a/pkg/providers/openai/testdata/chat.req.json b/pkg/provider/openai/testdata/chat.req.json similarity index 100% rename from pkg/providers/openai/testdata/chat.req.json rename to pkg/provider/openai/testdata/chat.req.json diff --git a/pkg/providers/openai/testdata/chat.success.json b/pkg/provider/openai/testdata/chat.success.json similarity index 100% rename from pkg/providers/openai/testdata/chat.success.json rename to pkg/provider/openai/testdata/chat.success.json diff --git a/pkg/providers/openai/testdata/chat_stream.empty.txt b/pkg/provider/openai/testdata/chat_stream.empty.txt similarity index 100% rename from pkg/providers/openai/testdata/chat_stream.empty.txt rename to pkg/provider/openai/testdata/chat_stream.empty.txt diff --git a/pkg/providers/openai/testdata/chat_stream.nodone.txt b/pkg/provider/openai/testdata/chat_stream.nodone.txt similarity index 100% rename from pkg/providers/openai/testdata/chat_stream.nodone.txt rename to pkg/provider/openai/testdata/chat_stream.nodone.txt diff --git a/pkg/providers/openai/testdata/chat_stream.success.txt b/pkg/provider/openai/testdata/chat_stream.success.txt similarity index 100% rename from pkg/providers/openai/testdata/chat_stream.success.txt rename to pkg/provider/openai/testdata/chat_stream.success.txt diff --git a/pkg/provider/registry.go b/pkg/provider/registry.go new file mode 100644 index 00000000..8862ba88 --- /dev/null +++ b/pkg/provider/registry.go @@ -0,0 +1,41 @@ +package provider + +import ( + "fmt" +) + +var LangRegistry = NewRegistry() + +type Registry struct { + providers map[ID]Configurer +} + +func NewRegistry() *Registry { + return &Registry{ + providers: make(map[ID]Configurer), + } +} + +func (r *Registry) Register(name ID, config Configurer) { + if _, ok := r.Get(name); ok { + panic(fmt.Sprintf("provider %s is already registered", name)) + } + + r.providers[name] = config +} + +func (r *Registry) Get(name ID) (Configurer, bool) { + config, ok := r.providers[name] + + return config, ok +} + +func (r *Registry) Available() []ID { + available := make([]ID, 0, len(r.providers)) + + for providerID := range r.providers { + available = append(available, providerID) + } + + return available +} diff --git a/pkg/providers/testing.go b/pkg/provider/testing.go similarity index 83% rename from pkg/providers/testing.go rename to pkg/provider/testing.go index f4cda83b..ca9a00d7 100644 --- a/pkg/providers/testing.go +++ b/pkg/provider/testing.go @@ -1,4 +1,4 @@ -package providers +package provider import ( "context" @@ -22,7 +22,7 @@ type TestConfig struct { } func (c *TestConfig) ToClient(_ *telemetry.Telemetry, _ *clients.ClientConfig) (LangProvider, error) { - return NewProviderMock(nil, []RespMock{}), nil + return NewMock(nil, []RespMock{}), nil } func (c *TestConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { @@ -113,8 +113,8 @@ func (m *RespStreamMock) Close() error { return nil } -// ProviderMock mocks a model provider -type ProviderMock struct { +// Mock mocks a model provider +type Mock struct { idx int chatResps *[]RespMock chatStreams *[]RespStreamMock @@ -122,8 +122,8 @@ type ProviderMock struct { modelName *string } -func NewProviderMock(modelName *string, responses []RespMock) *ProviderMock { - return &ProviderMock{ +func NewMock(modelName *string, responses []RespMock) *Mock { + return &Mock{ idx: 0, chatResps: &responses, supportStreaming: false, @@ -131,8 +131,8 @@ func NewProviderMock(modelName *string, responses []RespMock) *ProviderMock { } } -func NewStreamProviderMock(modelName *string, chatStreams []RespStreamMock) *ProviderMock { - return &ProviderMock{ +func NewStreamProviderMock(modelName *string, chatStreams []RespStreamMock) *Mock { + return &Mock{ idx: 0, modelName: modelName, chatStreams: &chatStreams, @@ -140,11 +140,11 @@ func NewStreamProviderMock(modelName *string, chatStreams []RespStreamMock) *Pro } } -func (c *ProviderMock) SupportChatStream() bool { +func (c *Mock) SupportChatStream() bool { return c.supportStreaming } -func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Mock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResponse, error) { if c.chatResps == nil { return nil, clients.ErrProviderUnavailable } @@ -161,7 +161,7 @@ func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas. return response.Resp(), nil } -func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Mock) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { if c.chatStreams == nil || c.idx >= len(*c.chatStreams) { return nil, clients.ErrProviderUnavailable } @@ -174,11 +174,11 @@ func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatParams) (cli return &stream, nil } -func (c *ProviderMock) Provider() string { +func (c *Mock) Provider() string { return "provider_mock" } -func (c *ProviderMock) ModelName() string { +func (c *Mock) ModelName() string { if c.modelName == nil { return "model_mock" } diff --git a/pkg/providers/openai/register.go b/pkg/providers/openai/register.go deleted file mode 100644 index 4435ac8d..00000000 --- a/pkg/providers/openai/register.go +++ /dev/null @@ -1,9 +0,0 @@ -package openai - -import ( - "github.com/EinStack/glide/pkg/providers" -) - -func init() { - providers.LangRegistry.Register(ProviderID, &Config{}) -} diff --git a/pkg/providers/registry.go b/pkg/providers/registry.go deleted file mode 100644 index 8298ebfa..00000000 --- a/pkg/providers/registry.go +++ /dev/null @@ -1,41 +0,0 @@ -package providers - -import ( - "fmt" -) - -var LangRegistry = NewProviderRegistry() - -type ProviderRegistry struct { - providers map[ProviderID]Configurer -} - -func NewProviderRegistry() *ProviderRegistry { - return &ProviderRegistry{ - providers: make(map[ProviderID]Configurer), - } -} - -func (r *ProviderRegistry) Register(name ProviderID, config Configurer) { - if _, ok := r.Get(name); ok { - panic(fmt.Sprintf("provider %s is already registered", name)) - } - - r.providers[name] = config -} - -func (r *ProviderRegistry) Get(name ProviderID) (Configurer, bool) { - config, ok := r.providers[name] - - return config, ok -} - -func (r *ProviderRegistry) Available() []ProviderID { - available := make([]ProviderID, 0, len(r.providers)) - - for providerID := range r.providers { - available = append(available, providerID) - } - - return available -} diff --git a/pkg/routers/lang/config.go b/pkg/routers/lang/config.go index 8817ab5b..eb4a56c9 100644 --- a/pkg/routers/lang/config.go +++ b/pkg/routers/lang/config.go @@ -4,9 +4,9 @@ import ( "fmt" "time" - "github.com/EinStack/glide/pkg/extmodel" + "github.com/EinStack/glide/pkg/provider" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/extmodel" "github.com/EinStack/glide/pkg/routers" @@ -18,7 +18,7 @@ import ( ) type ( - ModelConfig = extmodel.Config[*providers.Config] + ModelConfig = extmodel.Config[*provider.Config] ModelPoolConfig = []ModelConfig ) diff --git a/pkg/routers/lang/config_test.go b/pkg/routers/lang/config_test.go index 1b5975bc..3f36c707 100644 --- a/pkg/routers/lang/config_test.go +++ b/pkg/routers/lang/config_test.go @@ -3,10 +3,11 @@ package lang import ( "testing" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/providers" - "github.com/EinStack/glide/pkg/providers/cohere" - "github.com/EinStack/glide/pkg/providers/openai" + "github.com/EinStack/glide/pkg/provider/cohere" + "github.com/EinStack/glide/pkg/provider/openai" "github.com/EinStack/glide/pkg/resiliency/health" "github.com/EinStack/glide/pkg/routers/latency" "github.com/EinStack/glide/pkg/routers/routing" @@ -27,7 +28,7 @@ func TestRouterConfig_BuildModels(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.Config{ + Provider: &provider.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -45,7 +46,7 @@ func TestRouterConfig_BuildModels(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.Config{ + Provider: &provider.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -80,7 +81,7 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.Config{ + Provider: &provider.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &openAIParams, @@ -93,7 +94,7 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.Config{ + Provider: &provider.Config{ cohere.ProviderID: &cohere.Config{ APIKey: "ABC", DefaultParams: &cohereParams, @@ -129,7 +130,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.Config{ + Provider: &provider.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -147,7 +148,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.Config{ + Provider: &provider.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -170,7 +171,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.Config{ + Provider: &provider.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -183,7 +184,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.Config{ + Provider: &provider.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, diff --git a/pkg/routers/lang/router_test.go b/pkg/routers/lang/router_test.go index ce6c28dc..2b71928a 100644 --- a/pkg/routers/lang/router_test.go +++ b/pkg/routers/lang/router_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/provider" "github.com/EinStack/glide/pkg/extmodel" @@ -27,14 +27,14 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) { langModels := []*extmodel.LanguageModel{ extmodel.NewLangModel( "first", - providers.NewProviderMock(nil, []providers.RespMock{{Msg: "1"}, {Msg: "2"}}), + provider.NewMock(nil, []provider.RespMock{{Msg: "1"}, {Msg: "2"}}), budget, *latConfig, 1, ), extmodel.NewLangModel( "second", - providers.NewProviderMock(nil, []providers.RespMock{{Msg: "1"}}), + provider.NewMock(nil, []provider.RespMock{{Msg: "1"}}), budget, *latConfig, 1, @@ -73,21 +73,21 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { langModels := []*extmodel.LanguageModel{ extmodel.NewLangModel( "first", - providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "3"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "3"}}), budget, *latConfig, 1, ), extmodel.NewLangModel( "second", - providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "4"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "4"}}), budget, *latConfig, 1, ), extmodel.NewLangModel( "third", - providers.NewProviderMock(nil, []providers.RespMock{{Msg: "1"}, {Msg: "2"}}), + provider.NewMock(nil, []provider.RespMock{{Msg: "1"}, {Msg: "2"}}), budget, *latConfig, 1, @@ -130,14 +130,14 @@ func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { langModels := []*extmodel.LanguageModel{ extmodel.NewLangModel( "first", - providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "2"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "2"}}), budget, *latConfig, 1, ), extmodel.NewLangModel( "second", - providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "1"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "1"}}), budget, *latConfig, 1, @@ -173,14 +173,14 @@ func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) { langModels := []*extmodel.LanguageModel{ extmodel.NewLangModel( "first", - providers.NewProviderMock(nil, []providers.RespMock{{Err: clients.ErrProviderUnavailable}, {Msg: "3"}}), + provider.NewMock(nil, []provider.RespMock{{Err: clients.ErrProviderUnavailable}, {Msg: "3"}}), budget, *latConfig, 1, ), extmodel.NewLangModel( "second", - providers.NewProviderMock(nil, []providers.RespMock{{Msg: "1"}, {Msg: "2"}}), + provider.NewMock(nil, []provider.RespMock{{Msg: "1"}, {Msg: "2"}}), budget, *latConfig, 1, @@ -218,14 +218,14 @@ func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) { langModels := []*extmodel.LanguageModel{ extmodel.NewLangModel( "first", - providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), budget, *latConfig, 1, ), extmodel.NewLangModel( "second", - providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), budget, *latConfig, 1, @@ -260,8 +260,8 @@ func TestLangRouter_ChatStream(t *testing.T) { langModels := []*extmodel.LanguageModel{ extmodel.NewLangModel( "first", - providers.NewStreamProviderMock(nil, []providers.RespStreamMock{ - providers.NewRespStreamMock(&[]providers.RespMock{ + provider.NewStreamProviderMock(nil, []provider.RespStreamMock{ + provider.NewRespStreamMock(&[]provider.RespMock{ {Msg: "Bill"}, {Msg: "Gates"}, {Msg: "entered"}, @@ -275,8 +275,8 @@ func TestLangRouter_ChatStream(t *testing.T) { ), extmodel.NewLangModel( "second", - providers.NewStreamProviderMock(nil, []providers.RespStreamMock{ - providers.NewRespStreamMock(&[]providers.RespMock{ + provider.NewStreamProviderMock(nil, []provider.RespStreamMock{ + provider.NewRespStreamMock(&[]provider.RespMock{ {Msg: "Knock"}, {Msg: "Knock"}, {Msg: "joke"}, @@ -335,16 +335,16 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) { langModels := []*extmodel.LanguageModel{ extmodel.NewLangModel( "first", - providers.NewStreamProviderMock(nil, nil), + provider.NewStreamProviderMock(nil, nil), budget, *latConfig, 1, ), extmodel.NewLangModel( "second", - providers.NewStreamProviderMock(nil, []providers.RespStreamMock{ - providers.NewRespStreamMock( - &[]providers.RespMock{ + provider.NewStreamProviderMock(nil, []provider.RespStreamMock{ + provider.NewRespStreamMock( + &[]provider.RespMock{ {Msg: "Knock"}, {Msg: "knock"}, {Msg: "joke"}, @@ -404,8 +404,8 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { langModels := []*extmodel.LanguageModel{ extmodel.NewLangModel( "first", - providers.NewStreamProviderMock(nil, []providers.RespStreamMock{ - providers.NewRespStreamMock(&[]providers.RespMock{ + provider.NewStreamProviderMock(nil, []provider.RespStreamMock{ + provider.NewRespStreamMock(&[]provider.RespMock{ {Err: clients.ErrProviderUnavailable}, }), }), @@ -415,8 +415,8 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { ), extmodel.NewLangModel( "second", - providers.NewStreamProviderMock(nil, []providers.RespStreamMock{ - providers.NewRespStreamMock(&[]providers.RespMock{ + provider.NewStreamProviderMock(nil, []provider.RespStreamMock{ + provider.NewRespStreamMock(&[]provider.RespMock{ {Err: clients.ErrProviderUnavailable}, }), }), From 4aa7b2bbe1cc2c4b04fc72ba26f01e8fcbbd9a52 Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Thu, 8 Aug 2024 18:50:06 +0300 Subject: [PATCH 10/18] #67: Onboarded cohere to the new Configurer interface --- pkg/provider/cohere/config.go | 7 +++++++ pkg/provider/cohere/register.go | 7 +++++++ pkg/provider/openai/config.go | 5 ++++- 3 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 pkg/provider/cohere/register.go diff --git a/pkg/provider/cohere/config.go b/pkg/provider/cohere/config.go index 8e7b8b1d..aedcbffa 100644 --- a/pkg/provider/cohere/config.go +++ b/pkg/provider/cohere/config.go @@ -1,7 +1,10 @@ package cohere import ( + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/config/fields" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/telemetry" ) // Params defines Cohere-specific model params with the specific validation of values @@ -58,6 +61,10 @@ func DefaultConfig() *Config { } } +func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) { + return NewClient(c, clientConfig, tel) +} + func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { *c = *DefaultConfig() diff --git a/pkg/provider/cohere/register.go b/pkg/provider/cohere/register.go new file mode 100644 index 00000000..3845e24a --- /dev/null +++ b/pkg/provider/cohere/register.go @@ -0,0 +1,7 @@ +package cohere + +import "github.com/EinStack/glide/pkg/provider" + +func init() { + provider.LangRegistry.Register(ProviderID, &Config{}) +} diff --git a/pkg/provider/openai/config.go b/pkg/provider/openai/config.go index 8bf383d8..c6500c1f 100644 --- a/pkg/provider/openai/config.go +++ b/pkg/provider/openai/config.go @@ -52,7 +52,10 @@ type Config struct { DefaultParams *Params `yaml:"default_params,omitempty" json:"default_params"` } -var _ provider.Configurer = (*Config)(nil) +// ensure interfaces +var ( + _ provider.Configurer = (*Config)(nil) +) // DefaultConfig for OpenAI models func DefaultConfig() *Config { From d1cb962f4a6ec823936cd66b6d5cfcd756b2fc07 Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Thu, 8 Aug 2024 18:55:02 +0300 Subject: [PATCH 11/18] #67: Onboarded anthropic to the new Configurer interface --- pkg/provider/anthropic/chat.go | 2 +- pkg/provider/anthropic/client.go | 4 ++-- pkg/provider/anthropic/config.go | 7 +++++++ pkg/provider/anthropic/register.go | 7 +++++++ 4 files changed, 17 insertions(+), 3 deletions(-) create mode 100644 pkg/provider/anthropic/register.go diff --git a/pkg/provider/anthropic/chat.go b/pkg/provider/anthropic/chat.go index a89515a8..c45efb76 100644 --- a/pkg/provider/anthropic/chat.go +++ b/pkg/provider/anthropic/chat.go @@ -133,7 +133,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche response := schemas.ChatResponse{ ID: anthropicResponse.ID, Created: int(time.Now().UTC().Unix()), // not provided by anthropic - Provider: providerName, + Provider: ProviderID, ModelName: anthropicResponse.Model, Cached: false, ModelResponse: schemas.ModelResponse{ diff --git a/pkg/provider/anthropic/client.go b/pkg/provider/anthropic/client.go index e42ccc31..ce697dbf 100644 --- a/pkg/provider/anthropic/client.go +++ b/pkg/provider/anthropic/client.go @@ -11,7 +11,7 @@ import ( ) const ( - providerName = "anthropic" + ProviderID = "anthropic" ) // Client is a client for accessing OpenAI API @@ -54,7 +54,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * } func (c *Client) Provider() string { - return providerName + return ProviderID } func (c *Client) ModelName() string { diff --git a/pkg/provider/anthropic/config.go b/pkg/provider/anthropic/config.go index abdb5b73..1d252811 100644 --- a/pkg/provider/anthropic/config.go +++ b/pkg/provider/anthropic/config.go @@ -1,7 +1,10 @@ package anthropic import ( + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/config/fields" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/telemetry" ) // Params defines OpenAI-specific model params with the specific validation of values @@ -57,6 +60,10 @@ func DefaultConfig() *Config { } } +func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) { + return NewClient(c, clientConfig, tel) +} + func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { *c = *DefaultConfig() diff --git a/pkg/provider/anthropic/register.go b/pkg/provider/anthropic/register.go new file mode 100644 index 00000000..9b00ffc6 --- /dev/null +++ b/pkg/provider/anthropic/register.go @@ -0,0 +1,7 @@ +package anthropic + +import "github.com/EinStack/glide/pkg/provider" + +func init() { + provider.LangRegistry.Register(ProviderID, &Config{}) +} From f872d721a221df7891da1abb22169e621d5b49ba Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Mon, 12 Aug 2024 21:43:25 +0300 Subject: [PATCH 12/18] #67: Moved lang & embed router into the routers package & ensured LangProvider interface --- pkg/api/http/handlers.go | 10 ++-- pkg/api/http/server.go | 6 +- pkg/api/servers.go | 4 +- pkg/config/config.go | 8 +-- pkg/gateway.go | 4 +- pkg/provider/anthropic/client.go | 7 +++ pkg/provider/azureopenai/client.go | 7 +++ pkg/provider/bedrock/client.go | 7 +++ pkg/provider/cohere/client.go | 7 +++ pkg/provider/octoml/client.go | 7 +++ pkg/provider/ollama/client.go | 7 +++ pkg/provider/openai/client.go | 7 +++ pkg/provider/testing.go | 5 ++ pkg/routers/config.go | 7 +++ pkg/routers/embed/config.go | 10 ---- pkg/routers/embed_config.go | 16 +++++ .../{embed/router.go => embed_router.go} | 2 +- .../{lang/config.go => lang_config.go} | 60 +++++++++---------- .../config_test.go => lang_config_test.go} | 26 ++++---- .../{lang/router.go => lang_router.go} | 16 ++--- .../router_test.go => lang_router_test.go} | 18 +++--- pkg/routers/{manager => }/manager.go | 17 +++--- pkg/routers/manager/config.go | 9 --- 23 files changed, 161 insertions(+), 106 deletions(-) delete mode 100644 pkg/routers/embed/config.go create mode 100644 pkg/routers/embed_config.go rename pkg/routers/{embed/router.go => embed_router.go} (94%) rename pkg/routers/{lang/config.go => lang_config.go} (80%) rename pkg/routers/{lang/config_test.go => lang_config_test.go} (92%) rename pkg/routers/{lang/router.go => lang_router.go} (93%) rename pkg/routers/{lang/router_test.go => lang_router_test.go} (98%) rename pkg/routers/{manager => }/manager.go (64%) delete mode 100644 pkg/routers/manager/config.go diff --git a/pkg/api/http/handlers.go b/pkg/api/http/handlers.go index cc2ac3d3..58c727ce 100644 --- a/pkg/api/http/handlers.go +++ b/pkg/api/http/handlers.go @@ -4,7 +4,7 @@ import ( "context" "sync" - "github.com/EinStack/glide/pkg/routers/manager" + "github.com/EinStack/glide/pkg/routers" "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" @@ -32,7 +32,7 @@ type Handler = func(c *fiber.Ctx) error // @Failure 400 {object} schemas.Error // @Failure 404 {object} schemas.Error // @Router /v1/language/{router}/chat [POST] -func LangChatHandler(routerManager *manager.RouterManager) Handler { +func LangChatHandler(routerManager *routers.RouterManager) Handler { return func(c *fiber.Ctx) error { if !c.Is("json") { return c.Status(fiber.StatusBadRequest).JSON(schemas.ErrUnsupportedMediaType) @@ -73,7 +73,7 @@ func LangChatHandler(routerManager *manager.RouterManager) Handler { } } -func LangStreamRouterValidator(routerManager *manager.RouterManager) Handler { +func LangStreamRouterValidator(routerManager *routers.RouterManager) Handler { return func(c *fiber.Ctx) error { if websocket.IsWebSocketUpgrade(c) { routerID := c.Params("router") @@ -108,7 +108,7 @@ func LangStreamRouterValidator(routerManager *manager.RouterManager) Handler { // @Failure 426 // @Failure 404 {object} schemas.Error // @Router /v1/language/{router}/chatStream [GET] -func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *manager.RouterManager) Handler { +func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.RouterManager) Handler { // TODO: expose websocket connection configs https://github.com/gofiber/contrib/tree/main/websocket return websocket.New(func(c *websocket.Conn) { routerID := c.Params("router") @@ -176,7 +176,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *manager.Rout // @Produce json // @Success 200 {object} schemas.RouterListSchema // @Router /v1/language/ [GET] -func LangRoutersHandler(routerManager *manager.RouterManager) Handler { +func LangRoutersHandler(routerManager *routers.RouterManager) Handler { return func(c *fiber.Ctx) error { configuredRouters := routerManager.GetLangRouters() cfgs := make([]interface{}, 0, len(configuredRouters)) // opaque by design diff --git a/pkg/api/http/server.go b/pkg/api/http/server.go index 35899963..1242830a 100644 --- a/pkg/api/http/server.go +++ b/pkg/api/http/server.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "github.com/EinStack/glide/pkg/routers/manager" + "github.com/EinStack/glide/pkg/routers" "github.com/gofiber/contrib/otelfiber" @@ -25,11 +25,11 @@ import ( type Server struct { config *ServerConfig telemetry *telemetry.Telemetry - routerManager *manager.RouterManager + routerManager *routers.RouterManager server *fiber.App } -func NewServer(config *ServerConfig, tel *telemetry.Telemetry, routerManager *manager.RouterManager) (*Server, error) { +func NewServer(config *ServerConfig, tel *telemetry.Telemetry, routerManager *routers.RouterManager) (*Server, error) { srv := config.ToServer() return &Server{ diff --git a/pkg/api/servers.go b/pkg/api/servers.go index da2d130a..4ce8b37d 100644 --- a/pkg/api/servers.go +++ b/pkg/api/servers.go @@ -4,7 +4,7 @@ import ( "context" "sync" - "github.com/EinStack/glide/pkg/routers/manager" + "github.com/EinStack/glide/pkg/routers" "go.uber.org/zap" @@ -19,7 +19,7 @@ type ServerManager struct { telemetry *telemetry.Telemetry } -func NewServerManager(cfg *Config, tel *telemetry.Telemetry, router *manager.RouterManager) (*ServerManager, error) { +func NewServerManager(cfg *Config, tel *telemetry.Telemetry, router *routers.RouterManager) (*ServerManager, error) { httpServer, err := http.NewServer(cfg.HTTP, tel, router) if err != nil { return nil, err diff --git a/pkg/config/config.go b/pkg/config/config.go index cacdc2a9..cd99540e 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -2,15 +2,15 @@ package config import ( "github.com/EinStack/glide/pkg/api" - routerconfig "github.com/EinStack/glide/pkg/routers/manager" + "github.com/EinStack/glide/pkg/routers" "github.com/EinStack/glide/pkg/telemetry" ) // Config is a general top-level Glide configuration type Config struct { - Telemetry *telemetry.Config `yaml:"telemetry" validate:"required"` - API *api.Config `yaml:"api" validate:"required"` - Routers routerconfig.Config `yaml:"routers" validate:"required"` + Telemetry *telemetry.Config `yaml:"telemetry" validate:"required"` + API *api.Config `yaml:"api" validate:"required"` + Routers routers.RoutersConfig `yaml:"routers" validate:"required"` } func DefaultConfig() *Config { diff --git a/pkg/gateway.go b/pkg/gateway.go index a3c8969d..b3ce904f 100644 --- a/pkg/gateway.go +++ b/pkg/gateway.go @@ -7,7 +7,7 @@ import ( "os/signal" "syscall" - "github.com/EinStack/glide/pkg/routers/manager" + "github.com/EinStack/glide/pkg/routers" "github.com/EinStack/glide/pkg/version" "go.opentelemetry.io/contrib/instrumentation/host" @@ -50,7 +50,7 @@ func NewGateway(configProvider *config.Provider) (*Gateway, error) { tel.L().Info("🐦Glide is starting up", zap.String("version", version.FullVersion)) tel.L().Debug("✅ Config loaded successfully:\n" + configProvider.GetStr()) - routerManager, err := manager.NewManager(&cfg.Routers, tel) + routerManager, err := routers.NewManager(&cfg.Routers, tel) if err != nil { return nil, err } diff --git a/pkg/provider/anthropic/client.go b/pkg/provider/anthropic/client.go index ce697dbf..2e08b2e2 100644 --- a/pkg/provider/anthropic/client.go +++ b/pkg/provider/anthropic/client.go @@ -5,6 +5,8 @@ import ( "net/url" "time" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -26,6 +28,11 @@ type Client struct { tel *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new OpenAI client for the OpenAI API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) diff --git a/pkg/provider/azureopenai/client.go b/pkg/provider/azureopenai/client.go index 5c34a154..6ec90469 100644 --- a/pkg/provider/azureopenai/client.go +++ b/pkg/provider/azureopenai/client.go @@ -5,6 +5,8 @@ import ( "net/http" "time" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/provider/openai" @@ -28,6 +30,11 @@ type Client struct { tel *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new Azure OpenAI client for the OpenAI API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL := fmt.Sprintf( diff --git a/pkg/provider/bedrock/client.go b/pkg/provider/bedrock/client.go index 673cb49f..aa3905fd 100644 --- a/pkg/provider/bedrock/client.go +++ b/pkg/provider/bedrock/client.go @@ -7,6 +7,8 @@ import ( "net/url" "time" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -36,6 +38,11 @@ type Client struct { telemetry *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new OpenAI client for the OpenAI API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint, providerConfig.ModelName, "/invoke") diff --git a/pkg/provider/cohere/client.go b/pkg/provider/cohere/client.go index 3393e010..c3e43eb0 100644 --- a/pkg/provider/cohere/client.go +++ b/pkg/provider/cohere/client.go @@ -5,6 +5,8 @@ import ( "net/url" "time" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -26,6 +28,11 @@ type Client struct { tel *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new Cohere client for the Cohere API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) diff --git a/pkg/provider/octoml/client.go b/pkg/provider/octoml/client.go index 420a991a..30ab7794 100644 --- a/pkg/provider/octoml/client.go +++ b/pkg/provider/octoml/client.go @@ -6,6 +6,8 @@ import ( "net/url" "time" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -31,6 +33,11 @@ type Client struct { telemetry *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new OctoML client for the OctoML API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) diff --git a/pkg/provider/ollama/client.go b/pkg/provider/ollama/client.go index 85192b6b..df624cd5 100644 --- a/pkg/provider/ollama/client.go +++ b/pkg/provider/ollama/client.go @@ -5,6 +5,8 @@ import ( "net/url" "time" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -24,6 +26,11 @@ type Client struct { telemetry *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new OpenAI client for the OpenAI API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) diff --git a/pkg/provider/openai/client.go b/pkg/provider/openai/client.go index 30a04385..8567e26c 100644 --- a/pkg/provider/openai/client.go +++ b/pkg/provider/openai/client.go @@ -5,6 +5,8 @@ import ( "net/url" "time" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" "go.uber.org/zap" @@ -29,6 +31,11 @@ type Client struct { logger *zap.Logger } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new OpenAI client for the OpenAI API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) diff --git a/pkg/provider/testing.go b/pkg/provider/testing.go index ca9a00d7..72133349 100644 --- a/pkg/provider/testing.go +++ b/pkg/provider/testing.go @@ -122,6 +122,11 @@ type Mock struct { modelName *string } +// ensure interfaces +var ( + _ LangProvider = (*Mock)(nil) +) + func NewMock(modelName *string, responses []RespMock) *Mock { return &Mock{ idx: 0, diff --git a/pkg/routers/config.go b/pkg/routers/config.go index a3c8f69a..6d4610ab 100644 --- a/pkg/routers/config.go +++ b/pkg/routers/config.go @@ -22,3 +22,10 @@ func DefaultConfig() RouterConfig { Retry: retry.DefaultExpRetryConfig(), } } + +// RoutersConfig defines a config for a set of supported router types +// TODO: remove nolint after renaming the package +type RoutersConfig struct { //nolint: revive + LanguageRouters LangRoutersConfig `yaml:"language" validate:"required,dive"` // the list of language routers + // EmbeddingRouters []EmbeddingRouterConfig `yaml:"embedding" validate:"required,dive"` +} diff --git a/pkg/routers/embed/config.go b/pkg/routers/embed/config.go deleted file mode 100644 index 49f4821b..00000000 --- a/pkg/routers/embed/config.go +++ /dev/null @@ -1,10 +0,0 @@ -package embed - -import ( - "github.com/EinStack/glide/pkg/routers" -) - -type EmbeddingRouterConfig struct { - routers.RouterConfig - // Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests -} diff --git a/pkg/routers/embed_config.go b/pkg/routers/embed_config.go new file mode 100644 index 00000000..a93e937b --- /dev/null +++ b/pkg/routers/embed_config.go @@ -0,0 +1,16 @@ +package routers + +import ( + "github.com/EinStack/glide/pkg/extmodel" + "github.com/EinStack/glide/pkg/provider" +) + +type ( + EmbedModelConfig = extmodel.Config[*provider.Config] + EmbedModelPoolConfig = []EmbedModelConfig +) + +type EmbeddingRouterConfig struct { + RouterConfig + Models EmbedModelPoolConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests +} diff --git a/pkg/routers/embed/router.go b/pkg/routers/embed_router.go similarity index 94% rename from pkg/routers/embed/router.go rename to pkg/routers/embed_router.go index 9068537b..ef81d59a 100644 --- a/pkg/routers/embed/router.go +++ b/pkg/routers/embed_router.go @@ -1,4 +1,4 @@ -package embed +package routers type EmbeddingRouter struct { // routerID lang.RouterID diff --git a/pkg/routers/lang/config.go b/pkg/routers/lang_config.go similarity index 80% rename from pkg/routers/lang/config.go rename to pkg/routers/lang_config.go index eb4a56c9..9a7b685b 100644 --- a/pkg/routers/lang/config.go +++ b/pkg/routers/lang_config.go @@ -1,4 +1,4 @@ -package lang +package routers import ( "fmt" @@ -8,8 +8,6 @@ import ( "github.com/EinStack/glide/pkg/extmodel" - "github.com/EinStack/glide/pkg/routers" - "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/routers/routing" "github.com/EinStack/glide/pkg/telemetry" @@ -18,40 +16,40 @@ import ( ) type ( - ModelConfig = extmodel.Config[*provider.Config] - ModelPoolConfig = []ModelConfig + LangModelConfig = extmodel.Config[*provider.Config] + LangModelPoolConfig = []LangModelConfig ) -// RouterConfig -type RouterConfig struct { - routers.RouterConfig - Models ModelPoolConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests +// LangRouterConfig +type LangRouterConfig struct { + RouterConfig + Models LangModelPoolConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests } -type RouterConfigOption = func(*RouterConfig) +type RouterConfigOption = func(*LangRouterConfig) -func WithModels(models ModelPoolConfig) RouterConfigOption { - return func(c *RouterConfig) { +func WithModels(models LangModelPoolConfig) RouterConfigOption { + return func(c *LangRouterConfig) { c.Models = models } } -func NewRouterConfig(RouterID string, opt ...RouterConfigOption) *RouterConfig { - config := &RouterConfig{ - RouterConfig: routers.DefaultConfig(), +func NewRouterConfig(RouterID string, opt ...RouterConfigOption) *LangRouterConfig { + cfg := &LangRouterConfig{ + RouterConfig: DefaultConfig(), } - config.ID = RouterID + cfg.ID = RouterID for _, o := range opt { - o(config) + o(cfg) } - return config + return cfg } // BuildModels creates LanguageModel slice out of the given config -func (c *RouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*extmodel.LanguageModel, []*extmodel.LanguageModel, error) { //nolint: cyclop +func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*extmodel.LanguageModel, []*extmodel.LanguageModel, error) { //nolint: cyclop var errs error seenIDs := make(map[string]bool, len(c.Models)) @@ -147,7 +145,7 @@ func (c *RouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*extmodel.Langua return chatModels, chatStreamModels, nil } -func (c *RouterConfig) BuildRetry() *retry.ExpRetry { +func (c *LangRouterConfig) BuildRetry() *retry.ExpRetry { retryConfig := c.Retry maxDelay := time.Duration(*retryConfig.MaxDelay) @@ -159,7 +157,7 @@ func (c *RouterConfig) BuildRetry() *retry.ExpRetry { ) } -func (c *RouterConfig) BuildRouting( +func (c *LangRouterConfig) BuildRouting( chatModels []*extmodel.LanguageModel, chatStreamModels []*extmodel.LanguageModel, ) (routing.LangModelRouting, routing.LangModelRouting, error) { @@ -190,25 +188,25 @@ func (c *RouterConfig) BuildRouting( return nil, nil, fmt.Errorf("routing strategy \"%v\" is not supported, please make sure there is no typo", c.RoutingStrategy) } -func DefaultRouterConfig() *RouterConfig { - return &RouterConfig{ - RouterConfig: routers.DefaultConfig(), +func DefaultRouterConfig() *LangRouterConfig { + return &LangRouterConfig{ + RouterConfig: DefaultConfig(), } } -func (c *RouterConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - *c = *DefaultRouterConfig() +func (c LangRouterConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + c = *DefaultRouterConfig() - type plain RouterConfig // to avoid recursion + type plain LangRouterConfig // to avoid recursion - return unmarshal((*plain)(c)) + return unmarshal((plain)(c)) } -type RoutersConfig []RouterConfig +type LangRoutersConfig []LangRouterConfig -func (c RoutersConfig) Build(tel *telemetry.Telemetry) ([]*Router, error) { +func (c LangRoutersConfig) Build(tel *telemetry.Telemetry) ([]*LangRouter, error) { seenIDs := make(map[string]bool, len(c)) - langRouters := make([]*Router, 0, len(c)) + langRouters := make([]*LangRouter, 0, len(c)) var errs error diff --git a/pkg/routers/lang/config_test.go b/pkg/routers/lang_config_test.go similarity index 92% rename from pkg/routers/lang/config_test.go rename to pkg/routers/lang_config_test.go index 3f36c707..81998f53 100644 --- a/pkg/routers/lang/config_test.go +++ b/pkg/routers/lang_config_test.go @@ -1,4 +1,4 @@ -package lang +package routers import ( "testing" @@ -18,10 +18,10 @@ import ( func TestRouterConfig_BuildModels(t *testing.T) { defaultParams := openai.DefaultParams() - cfg := RoutersConfig{ + cfg := LangRoutersConfig{ *NewRouterConfig( "first_router", - WithModels(ModelPoolConfig{ + WithModels(LangModelPoolConfig{ { ID: "first_model", Enabled: true, @@ -39,7 +39,7 @@ func TestRouterConfig_BuildModels(t *testing.T) { ), *NewRouterConfig( "second_router", - WithModels(ModelPoolConfig{ + WithModels(LangModelPoolConfig{ { ID: "first_model", Enabled: true, @@ -74,7 +74,7 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { cfg := NewRouterConfig( "first_router", - WithModels(ModelPoolConfig{ + WithModels(LangModelPoolConfig{ { ID: "first_model", Enabled: true, @@ -116,14 +116,14 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { tests := []struct { name string - config RoutersConfig + config LangRoutersConfig }{ { "duplicated router IDs", - RoutersConfig{ + LangRoutersConfig{ *NewRouterConfig( "first_router", - WithModels(ModelPoolConfig{ + WithModels(LangModelPoolConfig{ { ID: "first_model", Enabled: true, @@ -141,7 +141,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { ), *NewRouterConfig( "first_router", - WithModels(ModelPoolConfig{ + WithModels(LangModelPoolConfig{ { ID: "first_model", Enabled: true, @@ -161,10 +161,10 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { }, { "duplicated model IDs", - RoutersConfig{ + LangRoutersConfig{ *NewRouterConfig( "first_router", - WithModels(ModelPoolConfig{ + WithModels(LangModelPoolConfig{ { ID: "first_model", Enabled: true, @@ -197,10 +197,10 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { }, { "no models", - RoutersConfig{ + LangRoutersConfig{ *NewRouterConfig( "first_router", - WithModels(ModelPoolConfig{}), + WithModels(LangModelPoolConfig{}), ), }, }, diff --git a/pkg/routers/lang/router.go b/pkg/routers/lang_router.go similarity index 93% rename from pkg/routers/lang/router.go rename to pkg/routers/lang_router.go index 9dcc0c33..fbae5f2f 100644 --- a/pkg/routers/lang/router.go +++ b/pkg/routers/lang_router.go @@ -1,4 +1,4 @@ -package lang +package routers import ( "context" @@ -17,9 +17,9 @@ var ErrNoModels = errors.New("no models configured for router") type RouterID = string -type Router struct { +type LangRouter struct { routerID RouterID - Config *RouterConfig + Config *LangRouterConfig chatModels []*extmodel.LanguageModel chatStreamModels []*extmodel.LanguageModel chatRouting routing.LangModelRouting @@ -29,7 +29,7 @@ type Router struct { logger *zap.Logger } -func NewLangRouter(cfg *RouterConfig, tel *telemetry.Telemetry) (*Router, error) { +func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter, error) { chatModels, chatStreamModels, err := cfg.BuildModels(tel) if err != nil { return nil, err @@ -40,7 +40,7 @@ func NewLangRouter(cfg *RouterConfig, tel *telemetry.Telemetry) (*Router, error) return nil, err } - router := &Router{ + router := &LangRouter{ routerID: cfg.ID, Config: cfg, chatModels: chatModels, @@ -55,11 +55,11 @@ func NewLangRouter(cfg *RouterConfig, tel *telemetry.Telemetry) (*Router, error) return router, err } -func (r *Router) ID() RouterID { +func (r *LangRouter) ID() RouterID { return r.routerID } -func (r *Router) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error) { +func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error) { if len(r.chatModels) == 0 { return nil, ErrNoModels } @@ -115,7 +115,7 @@ func (r *Router) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.C return nil, &schemas.ErrNoModelAvailable } -func (r *Router) ChatStream( +func (r *LangRouter) ChatStream( ctx context.Context, req *schemas.ChatStreamRequest, respC chan<- *schemas.ChatStreamMessage, diff --git a/pkg/routers/lang/router_test.go b/pkg/routers/lang_router_test.go similarity index 98% rename from pkg/routers/lang/router_test.go rename to pkg/routers/lang_router_test.go index 2b71928a..671b0759 100644 --- a/pkg/routers/lang/router_test.go +++ b/pkg/routers/lang_router_test.go @@ -1,4 +1,4 @@ -package lang +package routers import ( "context" @@ -46,7 +46,7 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) { modelPool = append(modelPool, model) } - router := Router{ + router := LangRouter{ routerID: "test_router", retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), chatRouting: routing.NewPriority(modelPool), @@ -101,7 +101,7 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { expectedModels := []string{"third", "third"} - router := Router{ + router := LangRouter{ routerID: "test_router", retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), chatRouting: routing.NewPriority(modelPool), @@ -149,7 +149,7 @@ func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { modelPool = append(modelPool, model) } - router := Router{ + router := LangRouter{ routerID: "test_router", retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil), chatRouting: routing.NewPriority(modelPool), @@ -192,7 +192,7 @@ func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) { modelPool = append(modelPool, model) } - router := Router{ + router := LangRouter{ routerID: "test_router", retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil), chatRouting: routing.NewPriority(modelPool), @@ -237,7 +237,7 @@ func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) { modelPool = append(modelPool, model) } - router := Router{ + router := LangRouter{ routerID: "test_router", retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil), chatRouting: routing.NewPriority(modelPool), @@ -293,7 +293,7 @@ func TestLangRouter_ChatStream(t *testing.T) { modelPool = append(modelPool, model) } - router := Router{ + router := LangRouter{ routerID: "test_stream_router", retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), chatRouting: routing.NewPriority(modelPool), @@ -362,7 +362,7 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) { modelPool = append(modelPool, model) } - router := Router{ + router := LangRouter{ routerID: "test_stream_router", retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), chatRouting: routing.NewPriority(modelPool), @@ -431,7 +431,7 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { modelPool = append(modelPool, model) } - router := Router{ + router := LangRouter{ routerID: "test_router", retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil), chatRouting: routing.NewPriority(modelPool), diff --git a/pkg/routers/manager/manager.go b/pkg/routers/manager.go similarity index 64% rename from pkg/routers/manager/manager.go rename to pkg/routers/manager.go index add72012..f719d091 100644 --- a/pkg/routers/manager/manager.go +++ b/pkg/routers/manager.go @@ -1,26 +1,25 @@ -package manager +package routers import ( "github.com/EinStack/glide/pkg/api/schemas" - "github.com/EinStack/glide/pkg/routers/lang" "github.com/EinStack/glide/pkg/telemetry" ) type RouterManager struct { - Config *Config + Config *RoutersConfig tel *telemetry.Telemetry - langRouterMap *map[string]*lang.Router - langRouters []*lang.Router + langRouterMap *map[string]*LangRouter + langRouters []*LangRouter } // NewManager creates a new instance of Router Manager that creates, holds and returns all routers -func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) { +func NewManager(cfg *RoutersConfig, tel *telemetry.Telemetry) (*RouterManager, error) { langRouters, err := cfg.LanguageRouters.Build(tel) if err != nil { return nil, err } - langRouterMap := make(map[string]*lang.Router, len(langRouters)) + langRouterMap := make(map[string]*LangRouter, len(langRouters)) for _, router := range langRouters { langRouterMap[router.ID()] = router @@ -36,12 +35,12 @@ func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) { return &manager, err } -func (r *RouterManager) GetLangRouters() []*lang.Router { +func (r *RouterManager) GetLangRouters() []*LangRouter { return r.langRouters } // GetLangRouter returns a router by type and ID -func (r *RouterManager) GetLangRouter(routerID string) (*lang.Router, error) { +func (r *RouterManager) GetLangRouter(routerID string) (*LangRouter, error) { if router, found := (*r.langRouterMap)[routerID]; found { return router, nil } diff --git a/pkg/routers/manager/config.go b/pkg/routers/manager/config.go deleted file mode 100644 index aaaeac09..00000000 --- a/pkg/routers/manager/config.go +++ /dev/null @@ -1,9 +0,0 @@ -package manager - -import "github.com/EinStack/glide/pkg/routers/lang" - -// Config defines a config for a set of supported router types -type Config struct { - LanguageRouters lang.RoutersConfig `yaml:"language" validate:"required,dive"` // the list of language routers - // EmbeddingRouters []EmbeddingRouterConfig `yaml:"embedding" validate:"required,dive"` -} From f55b24e1ee03b1082f76f9a113cae8d2d8317e3c Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Mon, 12 Aug 2024 21:53:25 +0300 Subject: [PATCH 13/18] #67: Renamed routers to router package --- pkg/api/http/handlers.go | 10 +++++----- pkg/api/http/server.go | 6 +++--- pkg/api/servers.go | 4 ++-- pkg/config/config.go | 8 ++++---- pkg/extmodel/config.go | 2 +- pkg/extmodel/lang.go | 2 +- pkg/extmodel/testing.go | 2 +- pkg/gateway.go | 4 ++-- pkg/{routers => router}/config.go | 13 ++++++------- pkg/{routers => router}/embed_config.go | 4 ++-- pkg/{routers => router}/embed_router.go | 2 +- pkg/{routers => router}/lang_config.go | 16 ++++++++-------- pkg/{routers => router}/lang_config_test.go | 6 +++--- pkg/{routers => router}/lang_router.go | 10 +++++----- pkg/{routers => router}/lang_router_test.go | 6 +++--- pkg/{routers => router}/latency/config.go | 0 pkg/{routers => router}/latency/config_test.go | 0 .../latency/moving_average.go | 0 .../latency/moving_average_test.go | 0 pkg/{routers => router}/manager.go | 12 ++++++------ pkg/{routers => router}/routing/least_latency.go | 2 +- .../routing/least_latency_test.go | 0 pkg/{routers => router}/routing/priority.go | 0 pkg/{routers => router}/routing/priority_test.go | 0 pkg/{routers => router}/routing/round_robin.go | 0 .../routing/round_robin_test.go | 0 pkg/{routers => router}/routing/strategies.go | 0 .../routing/weighted_round_robin.go | 0 .../routing/weighted_round_robin_test.go | 0 29 files changed, 54 insertions(+), 55 deletions(-) rename pkg/{routers => router}/config.go (83%) rename pkg/{routers => router}/embed_config.go (92%) rename pkg/{routers => router}/embed_router.go (94%) rename pkg/{routers => router}/lang_config.go (94%) rename pkg/{routers => router}/lang_config_test.go (97%) rename pkg/{routers => router}/lang_router.go (97%) rename pkg/{routers => router}/lang_router_test.go (99%) rename pkg/{routers => router}/latency/config.go (100%) rename pkg/{routers => router}/latency/config_test.go (100%) rename pkg/{routers => router}/latency/moving_average.go (100%) rename pkg/{routers => router}/latency/moving_average_test.go (100%) rename pkg/{routers => router}/manager.go (80%) rename pkg/{routers => router}/routing/least_latency.go (98%) rename pkg/{routers => router}/routing/least_latency_test.go (100%) rename pkg/{routers => router}/routing/priority.go (100%) rename pkg/{routers => router}/routing/priority_test.go (100%) rename pkg/{routers => router}/routing/round_robin.go (100%) rename pkg/{routers => router}/routing/round_robin_test.go (100%) rename pkg/{routers => router}/routing/strategies.go (100%) rename pkg/{routers => router}/routing/weighted_round_robin.go (100%) rename pkg/{routers => router}/routing/weighted_round_robin_test.go (100%) diff --git a/pkg/api/http/handlers.go b/pkg/api/http/handlers.go index 58c727ce..374516b6 100644 --- a/pkg/api/http/handlers.go +++ b/pkg/api/http/handlers.go @@ -4,7 +4,7 @@ import ( "context" "sync" - "github.com/EinStack/glide/pkg/routers" + "github.com/EinStack/glide/pkg/router" "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" @@ -32,7 +32,7 @@ type Handler = func(c *fiber.Ctx) error // @Failure 400 {object} schemas.Error // @Failure 404 {object} schemas.Error // @Router /v1/language/{router}/chat [POST] -func LangChatHandler(routerManager *routers.RouterManager) Handler { +func LangChatHandler(routerManager *router.Manager) Handler { return func(c *fiber.Ctx) error { if !c.Is("json") { return c.Status(fiber.StatusBadRequest).JSON(schemas.ErrUnsupportedMediaType) @@ -73,7 +73,7 @@ func LangChatHandler(routerManager *routers.RouterManager) Handler { } } -func LangStreamRouterValidator(routerManager *routers.RouterManager) Handler { +func LangStreamRouterValidator(routerManager *router.Manager) Handler { return func(c *fiber.Ctx) error { if websocket.IsWebSocketUpgrade(c) { routerID := c.Params("router") @@ -108,7 +108,7 @@ func LangStreamRouterValidator(routerManager *routers.RouterManager) Handler { // @Failure 426 // @Failure 404 {object} schemas.Error // @Router /v1/language/{router}/chatStream [GET] -func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.RouterManager) Handler { +func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *router.Manager) Handler { // TODO: expose websocket connection configs https://github.com/gofiber/contrib/tree/main/websocket return websocket.New(func(c *websocket.Conn) { routerID := c.Params("router") @@ -176,7 +176,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout // @Produce json // @Success 200 {object} schemas.RouterListSchema // @Router /v1/language/ [GET] -func LangRoutersHandler(routerManager *routers.RouterManager) Handler { +func LangRoutersHandler(routerManager *router.Manager) Handler { return func(c *fiber.Ctx) error { configuredRouters := routerManager.GetLangRouters() cfgs := make([]interface{}, 0, len(configuredRouters)) // opaque by design diff --git a/pkg/api/http/server.go b/pkg/api/http/server.go index 1242830a..6422623a 100644 --- a/pkg/api/http/server.go +++ b/pkg/api/http/server.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "github.com/EinStack/glide/pkg/routers" + "github.com/EinStack/glide/pkg/router" "github.com/gofiber/contrib/otelfiber" @@ -25,11 +25,11 @@ import ( type Server struct { config *ServerConfig telemetry *telemetry.Telemetry - routerManager *routers.RouterManager + routerManager *router.Manager server *fiber.App } -func NewServer(config *ServerConfig, tel *telemetry.Telemetry, routerManager *routers.RouterManager) (*Server, error) { +func NewServer(config *ServerConfig, tel *telemetry.Telemetry, routerManager *router.Manager) (*Server, error) { srv := config.ToServer() return &Server{ diff --git a/pkg/api/servers.go b/pkg/api/servers.go index 4ce8b37d..fd0a281e 100644 --- a/pkg/api/servers.go +++ b/pkg/api/servers.go @@ -4,7 +4,7 @@ import ( "context" "sync" - "github.com/EinStack/glide/pkg/routers" + "github.com/EinStack/glide/pkg/router" "go.uber.org/zap" @@ -19,7 +19,7 @@ type ServerManager struct { telemetry *telemetry.Telemetry } -func NewServerManager(cfg *Config, tel *telemetry.Telemetry, router *routers.RouterManager) (*ServerManager, error) { +func NewServerManager(cfg *Config, tel *telemetry.Telemetry, router *router.Manager) (*ServerManager, error) { httpServer, err := http.NewServer(cfg.HTTP, tel, router) if err != nil { return nil, err diff --git a/pkg/config/config.go b/pkg/config/config.go index cd99540e..9f390a45 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -2,15 +2,15 @@ package config import ( "github.com/EinStack/glide/pkg/api" - "github.com/EinStack/glide/pkg/routers" + "github.com/EinStack/glide/pkg/router" "github.com/EinStack/glide/pkg/telemetry" ) // Config is a general top-level Glide configuration type Config struct { - Telemetry *telemetry.Config `yaml:"telemetry" validate:"required"` - API *api.Config `yaml:"api" validate:"required"` - Routers routers.RoutersConfig `yaml:"routers" validate:"required"` + Telemetry *telemetry.Config `yaml:"telemetry" validate:"required"` + API *api.Config `yaml:"api" validate:"required"` + Routers router.RoutersConfig `yaml:"routers" validate:"required"` } func DefaultConfig() *Config { diff --git a/pkg/extmodel/config.go b/pkg/extmodel/config.go index 5c1bfa91..3edd45f0 100644 --- a/pkg/extmodel/config.go +++ b/pkg/extmodel/config.go @@ -7,7 +7,7 @@ import ( "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/resiliency/health" - "github.com/EinStack/glide/pkg/routers/latency" + "github.com/EinStack/glide/pkg/router/latency" "github.com/EinStack/glide/pkg/telemetry" ) diff --git a/pkg/extmodel/lang.go b/pkg/extmodel/lang.go index 6b345112..0c29870b 100644 --- a/pkg/extmodel/lang.go +++ b/pkg/extmodel/lang.go @@ -12,7 +12,7 @@ import ( "github.com/EinStack/glide/pkg/config/fields" - "github.com/EinStack/glide/pkg/routers/latency" + "github.com/EinStack/glide/pkg/router/latency" "github.com/EinStack/glide/pkg/api/schemas" ) diff --git a/pkg/extmodel/testing.go b/pkg/extmodel/testing.go index 6d51ca79..86829610 100644 --- a/pkg/extmodel/testing.go +++ b/pkg/extmodel/testing.go @@ -4,7 +4,7 @@ import ( "time" "github.com/EinStack/glide/pkg/config/fields" - "github.com/EinStack/glide/pkg/routers/latency" + "github.com/EinStack/glide/pkg/router/latency" ) // LangModelMock diff --git a/pkg/gateway.go b/pkg/gateway.go index b3ce904f..fd6c1878 100644 --- a/pkg/gateway.go +++ b/pkg/gateway.go @@ -7,7 +7,7 @@ import ( "os/signal" "syscall" - "github.com/EinStack/glide/pkg/routers" + "github.com/EinStack/glide/pkg/router" "github.com/EinStack/glide/pkg/version" "go.opentelemetry.io/contrib/instrumentation/host" @@ -50,7 +50,7 @@ func NewGateway(configProvider *config.Provider) (*Gateway, error) { tel.L().Info("🐦Glide is starting up", zap.String("version", version.FullVersion)) tel.L().Debug("✅ Config loaded successfully:\n" + configProvider.GetStr()) - routerManager, err := routers.NewManager(&cfg.Routers, tel) + routerManager, err := router.NewManager(&cfg.Routers, tel) if err != nil { return nil, err } diff --git a/pkg/routers/config.go b/pkg/router/config.go similarity index 83% rename from pkg/routers/config.go rename to pkg/router/config.go index 6d4610ab..57ebb9c4 100644 --- a/pkg/routers/config.go +++ b/pkg/router/config.go @@ -1,22 +1,22 @@ -package routers +package router import ( "github.com/EinStack/glide/pkg/resiliency/retry" - "github.com/EinStack/glide/pkg/routers/routing" + "github.com/EinStack/glide/pkg/router/routing" ) // TODO: how to specify other backoff strategies? // TODO: Had to keep RoutingStrategy because of https://github.com/swaggo/swag/issues/1738 -type RouterConfig struct { +type Config struct { ID string `yaml:"id" json:"routers" validate:"required"` // Unique router ID Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is router enabled? Retry *retry.ExpRetryConfig `yaml:"retry" json:"retry" validate:"required"` // retry when no healthy model is available to router RoutingStrategy routing.Strategy `yaml:"strategy" json:"strategy" swaggertype:"primitive,string" validate:"required"` // strategy on picking the next model to serve the request } -func DefaultConfig() RouterConfig { - return RouterConfig{ +func DefaultConfig() Config { + return Config{ Enabled: true, RoutingStrategy: routing.Priority, Retry: retry.DefaultExpRetryConfig(), @@ -24,8 +24,7 @@ func DefaultConfig() RouterConfig { } // RoutersConfig defines a config for a set of supported router types -// TODO: remove nolint after renaming the package -type RoutersConfig struct { //nolint: revive +type RoutersConfig struct { LanguageRouters LangRoutersConfig `yaml:"language" validate:"required,dive"` // the list of language routers // EmbeddingRouters []EmbeddingRouterConfig `yaml:"embedding" validate:"required,dive"` } diff --git a/pkg/routers/embed_config.go b/pkg/router/embed_config.go similarity index 92% rename from pkg/routers/embed_config.go rename to pkg/router/embed_config.go index a93e937b..d0593e84 100644 --- a/pkg/routers/embed_config.go +++ b/pkg/router/embed_config.go @@ -1,4 +1,4 @@ -package routers +package router import ( "github.com/EinStack/glide/pkg/extmodel" @@ -11,6 +11,6 @@ type ( ) type EmbeddingRouterConfig struct { - RouterConfig + Config Models EmbedModelPoolConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests } diff --git a/pkg/routers/embed_router.go b/pkg/router/embed_router.go similarity index 94% rename from pkg/routers/embed_router.go rename to pkg/router/embed_router.go index ef81d59a..4276f9a2 100644 --- a/pkg/routers/embed_router.go +++ b/pkg/router/embed_router.go @@ -1,4 +1,4 @@ -package routers +package router type EmbeddingRouter struct { // routerID lang.RouterID diff --git a/pkg/routers/lang_config.go b/pkg/router/lang_config.go similarity index 94% rename from pkg/routers/lang_config.go rename to pkg/router/lang_config.go index 9a7b685b..007a3769 100644 --- a/pkg/routers/lang_config.go +++ b/pkg/router/lang_config.go @@ -1,4 +1,4 @@ -package routers +package router import ( "fmt" @@ -9,7 +9,7 @@ import ( "github.com/EinStack/glide/pkg/extmodel" "github.com/EinStack/glide/pkg/resiliency/retry" - "github.com/EinStack/glide/pkg/routers/routing" + "github.com/EinStack/glide/pkg/router/routing" "github.com/EinStack/glide/pkg/telemetry" "go.uber.org/multierr" "go.uber.org/zap" @@ -22,21 +22,21 @@ type ( // LangRouterConfig type LangRouterConfig struct { - RouterConfig + Config Models LangModelPoolConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests } -type RouterConfigOption = func(*LangRouterConfig) +type ConfigOption = func(*LangRouterConfig) -func WithModels(models LangModelPoolConfig) RouterConfigOption { +func WithModels(models LangModelPoolConfig) ConfigOption { return func(c *LangRouterConfig) { c.Models = models } } -func NewRouterConfig(RouterID string, opt ...RouterConfigOption) *LangRouterConfig { +func NewRouterConfig(RouterID string, opt ...ConfigOption) *LangRouterConfig { cfg := &LangRouterConfig{ - RouterConfig: DefaultConfig(), + Config: DefaultConfig(), } cfg.ID = RouterID @@ -190,7 +190,7 @@ func (c *LangRouterConfig) BuildRouting( func DefaultRouterConfig() *LangRouterConfig { return &LangRouterConfig{ - RouterConfig: DefaultConfig(), + Config: DefaultConfig(), } } diff --git a/pkg/routers/lang_config_test.go b/pkg/router/lang_config_test.go similarity index 97% rename from pkg/routers/lang_config_test.go rename to pkg/router/lang_config_test.go index 81998f53..71d8d829 100644 --- a/pkg/routers/lang_config_test.go +++ b/pkg/router/lang_config_test.go @@ -1,4 +1,4 @@ -package routers +package router import ( "testing" @@ -9,8 +9,8 @@ import ( "github.com/EinStack/glide/pkg/provider/cohere" "github.com/EinStack/glide/pkg/provider/openai" "github.com/EinStack/glide/pkg/resiliency/health" - "github.com/EinStack/glide/pkg/routers/latency" - "github.com/EinStack/glide/pkg/routers/routing" + "github.com/EinStack/glide/pkg/router/latency" + "github.com/EinStack/glide/pkg/router/routing" "github.com/EinStack/glide/pkg/telemetry" "github.com/stretchr/testify/require" ) diff --git a/pkg/routers/lang_router.go b/pkg/router/lang_router.go similarity index 97% rename from pkg/routers/lang_router.go rename to pkg/router/lang_router.go index fbae5f2f..04c64e9e 100644 --- a/pkg/routers/lang_router.go +++ b/pkg/router/lang_router.go @@ -1,4 +1,4 @@ -package routers +package router import ( "context" @@ -8,17 +8,17 @@ import ( "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/resiliency/retry" - "github.com/EinStack/glide/pkg/routers/routing" + "github.com/EinStack/glide/pkg/router/routing" "github.com/EinStack/glide/pkg/telemetry" "go.uber.org/zap" ) var ErrNoModels = errors.New("no models configured for router") -type RouterID = string +type ID = string type LangRouter struct { - routerID RouterID + routerID ID Config *LangRouterConfig chatModels []*extmodel.LanguageModel chatStreamModels []*extmodel.LanguageModel @@ -55,7 +55,7 @@ func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter return router, err } -func (r *LangRouter) ID() RouterID { +func (r *LangRouter) ID() ID { return r.routerID } diff --git a/pkg/routers/lang_router_test.go b/pkg/router/lang_router_test.go similarity index 99% rename from pkg/routers/lang_router_test.go rename to pkg/router/lang_router_test.go index 671b0759..92eb5d4b 100644 --- a/pkg/routers/lang_router_test.go +++ b/pkg/router/lang_router_test.go @@ -1,4 +1,4 @@ -package routers +package router import ( "context" @@ -14,8 +14,8 @@ import ( "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/api/schemas" - "github.com/EinStack/glide/pkg/routers/latency" - "github.com/EinStack/glide/pkg/routers/routing" + "github.com/EinStack/glide/pkg/router/latency" + "github.com/EinStack/glide/pkg/router/routing" "github.com/EinStack/glide/pkg/telemetry" "github.com/stretchr/testify/require" ) diff --git a/pkg/routers/latency/config.go b/pkg/router/latency/config.go similarity index 100% rename from pkg/routers/latency/config.go rename to pkg/router/latency/config.go diff --git a/pkg/routers/latency/config_test.go b/pkg/router/latency/config_test.go similarity index 100% rename from pkg/routers/latency/config_test.go rename to pkg/router/latency/config_test.go diff --git a/pkg/routers/latency/moving_average.go b/pkg/router/latency/moving_average.go similarity index 100% rename from pkg/routers/latency/moving_average.go rename to pkg/router/latency/moving_average.go diff --git a/pkg/routers/latency/moving_average_test.go b/pkg/router/latency/moving_average_test.go similarity index 100% rename from pkg/routers/latency/moving_average_test.go rename to pkg/router/latency/moving_average_test.go diff --git a/pkg/routers/manager.go b/pkg/router/manager.go similarity index 80% rename from pkg/routers/manager.go rename to pkg/router/manager.go index f719d091..b30afbcf 100644 --- a/pkg/routers/manager.go +++ b/pkg/router/manager.go @@ -1,11 +1,11 @@ -package routers +package router import ( "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" ) -type RouterManager struct { +type Manager struct { Config *RoutersConfig tel *telemetry.Telemetry langRouterMap *map[string]*LangRouter @@ -13,7 +13,7 @@ type RouterManager struct { } // NewManager creates a new instance of Router Manager that creates, holds and returns all routers -func NewManager(cfg *RoutersConfig, tel *telemetry.Telemetry) (*RouterManager, error) { +func NewManager(cfg *RoutersConfig, tel *telemetry.Telemetry) (*Manager, error) { langRouters, err := cfg.LanguageRouters.Build(tel) if err != nil { return nil, err @@ -25,7 +25,7 @@ func NewManager(cfg *RoutersConfig, tel *telemetry.Telemetry) (*RouterManager, e langRouterMap[router.ID()] = router } - manager := RouterManager{ + manager := Manager{ Config: cfg, tel: tel, langRouters: langRouters, @@ -35,12 +35,12 @@ func NewManager(cfg *RoutersConfig, tel *telemetry.Telemetry) (*RouterManager, e return &manager, err } -func (r *RouterManager) GetLangRouters() []*LangRouter { +func (r *Manager) GetLangRouters() []*LangRouter { return r.langRouters } // GetLangRouter returns a router by type and ID -func (r *RouterManager) GetLangRouter(routerID string) (*LangRouter, error) { +func (r *Manager) GetLangRouter(routerID string) (*LangRouter, error) { if router, found := (*r.langRouterMap)[routerID]; found { return router, nil } diff --git a/pkg/routers/routing/least_latency.go b/pkg/router/routing/least_latency.go similarity index 98% rename from pkg/routers/routing/least_latency.go rename to pkg/router/routing/least_latency.go index d34f45e2..e233c20e 100644 --- a/pkg/routers/routing/least_latency.go +++ b/pkg/router/routing/least_latency.go @@ -7,7 +7,7 @@ import ( "github.com/EinStack/glide/pkg/extmodel" - "github.com/EinStack/glide/pkg/routers/latency" + "github.com/EinStack/glide/pkg/router/latency" ) const ( diff --git a/pkg/routers/routing/least_latency_test.go b/pkg/router/routing/least_latency_test.go similarity index 100% rename from pkg/routers/routing/least_latency_test.go rename to pkg/router/routing/least_latency_test.go diff --git a/pkg/routers/routing/priority.go b/pkg/router/routing/priority.go similarity index 100% rename from pkg/routers/routing/priority.go rename to pkg/router/routing/priority.go diff --git a/pkg/routers/routing/priority_test.go b/pkg/router/routing/priority_test.go similarity index 100% rename from pkg/routers/routing/priority_test.go rename to pkg/router/routing/priority_test.go diff --git a/pkg/routers/routing/round_robin.go b/pkg/router/routing/round_robin.go similarity index 100% rename from pkg/routers/routing/round_robin.go rename to pkg/router/routing/round_robin.go diff --git a/pkg/routers/routing/round_robin_test.go b/pkg/router/routing/round_robin_test.go similarity index 100% rename from pkg/routers/routing/round_robin_test.go rename to pkg/router/routing/round_robin_test.go diff --git a/pkg/routers/routing/strategies.go b/pkg/router/routing/strategies.go similarity index 100% rename from pkg/routers/routing/strategies.go rename to pkg/router/routing/strategies.go diff --git a/pkg/routers/routing/weighted_round_robin.go b/pkg/router/routing/weighted_round_robin.go similarity index 100% rename from pkg/routers/routing/weighted_round_robin.go rename to pkg/router/routing/weighted_round_robin.go diff --git a/pkg/routers/routing/weighted_round_robin_test.go b/pkg/router/routing/weighted_round_robin_test.go similarity index 100% rename from pkg/routers/routing/weighted_round_robin_test.go rename to pkg/router/routing/weighted_round_robin_test.go From d1659d36c19e2262da56d54900f939af2330900d Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Mon, 12 Aug 2024 22:00:37 +0300 Subject: [PATCH 14/18] #67: Ensured the ChatStream interface --- pkg/provider/azureopenai/chat_stream.go | 13 +++++++++---- pkg/provider/cohere/chat_stream.go | 11 ++++++++--- pkg/provider/openai/chat_stream.go | 13 +++++++++---- pkg/provider/testing.go | 5 +++++ 4 files changed, 31 insertions(+), 11 deletions(-) diff --git a/pkg/provider/azureopenai/chat_stream.go b/pkg/provider/azureopenai/chat_stream.go index 7c0f5b2c..8e12c8e3 100644 --- a/pkg/provider/azureopenai/chat_stream.go +++ b/pkg/provider/azureopenai/chat_stream.go @@ -8,7 +8,7 @@ import ( "io" "net/http" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -34,6 +34,11 @@ type ChatStream struct { errMapper *ErrorMapper } +// ensure interface +var ( + _ clients.ChatStream = (*ChatStream)(nil) +) + func NewChatStream( tel *telemetry.Telemetry, client *http.Client, @@ -83,7 +88,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { // if err is io.EOF, this still means that the stream is interrupted unexpectedly // because the normal stream termination is done via finding out streamDoneMarker - return nil, clients2.ErrProviderUnavailable + return nil, clients.ErrProviderUnavailable } s.tel.L().Debug( @@ -92,7 +97,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { zap.ByteString("rawChunk", rawEvent), ) - event, err := clients2.ParseSSEvent(rawEvent) + event, err := clients.ParseSSEvent(rawEvent) if bytes.Equal(event.Data, openai.StreamDoneMarker) { s.tel.L().Info( @@ -156,7 +161,7 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients2.ChatStream, error) { +func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) { // Create a new chat request httpRequest, err := c.makeStreamReq(ctx, params) if err != nil { diff --git a/pkg/provider/cohere/chat_stream.go b/pkg/provider/cohere/chat_stream.go index 392f0a27..51bd8045 100644 --- a/pkg/provider/cohere/chat_stream.go +++ b/pkg/provider/cohere/chat_stream.go @@ -8,7 +8,7 @@ import ( "io" "net/http" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -41,6 +41,11 @@ type ChatStream struct { tel *telemetry.Telemetry } +// ensure interface +var ( + _ clients.ChatStream = (*ChatStream)(nil) +) + func NewChatStream( tel *telemetry.Telemetry, client *http.Client, @@ -96,7 +101,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { // if io.EOF occurred in the middle of the stream, then the stream was interrupted - return nil, clients2.ErrProviderUnavailable + return nil, clients.ErrProviderUnavailable } s.tel.L().Debug( @@ -178,7 +183,7 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients2.ChatStream, error) { +func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) { // Create a new chat request httpRequest, err := c.makeStreamReq(ctx, params) if err != nil { diff --git a/pkg/provider/openai/chat_stream.go b/pkg/provider/openai/chat_stream.go index ba219e30..0e4a341e 100644 --- a/pkg/provider/openai/chat_stream.go +++ b/pkg/provider/openai/chat_stream.go @@ -8,7 +8,7 @@ import ( "io" "net/http" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/clients" "github.com/r3labs/sse/v2" "go.uber.org/zap" @@ -29,6 +29,11 @@ type ChatStream struct { logger *zap.Logger } +// ensure interface +var ( + _ clients.ChatStream = (*ChatStream)(nil) +) + func NewChatStream( client *http.Client, req *http.Request, @@ -75,7 +80,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { // if err is io.EOF, this still means that the stream is interrupted unexpectedly // because the normal stream termination is done via finding out streamDoneMarker - return nil, clients2.ErrProviderUnavailable + return nil, clients.ErrProviderUnavailable } s.logger.Debug( @@ -83,7 +88,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { zap.ByteString("rawChunk", rawEvent), ) - event, err := clients2.ParseSSEvent(rawEvent) + event, err := clients.ParseSSEvent(rawEvent) if bytes.Equal(event.Data, StreamDoneMarker) { return nil, io.EOF @@ -142,7 +147,7 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients2.ChatStream, error) { +func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) { // Create a new chat request httpRequest, err := c.makeStreamReq(ctx, params) if err != nil { diff --git a/pkg/provider/testing.go b/pkg/provider/testing.go index 72133349..cef49cdb 100644 --- a/pkg/provider/testing.go +++ b/pkg/provider/testing.go @@ -68,6 +68,11 @@ type RespStreamMock struct { Chunks *[]RespMock } +// ensure interface +var ( + _ clients.ChatStream = (*RespStreamMock)(nil) +) + func NewRespStreamMock(chunk *[]RespMock) RespStreamMock { return RespStreamMock{ idx: 0, From e5266113abcbbdb9afb0fee8814a1aae2943f3ed Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Mon, 12 Aug 2024 22:29:56 +0300 Subject: [PATCH 15/18] #67: Started to connect EmbedRouter into the system --- pkg/api/schemas/embed.go | 9 ++++++++ pkg/router/config.go | 4 ++-- pkg/router/embed_config.go | 45 +++++++++++++++++++++++++++++++++++++- pkg/router/embed_router.go | 19 +++++++++++++--- 4 files changed, 71 insertions(+), 6 deletions(-) create mode 100644 pkg/api/schemas/embed.go diff --git a/pkg/api/schemas/embed.go b/pkg/api/schemas/embed.go new file mode 100644 index 00000000..16fd70c7 --- /dev/null +++ b/pkg/api/schemas/embed.go @@ -0,0 +1,9 @@ +package schemas + +type EmbedRequest struct { + // TODO: implement +} + +type EmbedResponse struct { + // TODO: implement +} diff --git a/pkg/router/config.go b/pkg/router/config.go index 57ebb9c4..0641a369 100644 --- a/pkg/router/config.go +++ b/pkg/router/config.go @@ -25,6 +25,6 @@ func DefaultConfig() Config { // RoutersConfig defines a config for a set of supported router types type RoutersConfig struct { - LanguageRouters LangRoutersConfig `yaml:"language" validate:"required,dive"` // the list of language routers - // EmbeddingRouters []EmbeddingRouterConfig `yaml:"embedding" validate:"required,dive"` + LanguageRouters LangRoutersConfig `yaml:"language" validate:"required,dive"` // the list of language routers + EmbeddingRouters EmbedRoutersConfig `yaml:"embedding" validate:"required,dive"` } diff --git a/pkg/router/embed_config.go b/pkg/router/embed_config.go index d0593e84..7eb442a4 100644 --- a/pkg/router/embed_config.go +++ b/pkg/router/embed_config.go @@ -1,8 +1,13 @@ package router import ( + "fmt" + "github.com/EinStack/glide/pkg/extmodel" "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/telemetry" + "go.uber.org/multierr" + "go.uber.org/zap" ) type ( @@ -10,7 +15,45 @@ type ( EmbedModelPoolConfig = []EmbedModelConfig ) -type EmbeddingRouterConfig struct { +type EmbedRouterConfig struct { Config Models EmbedModelPoolConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests } + +type EmbedRoutersConfig []EmbedRouterConfig + +func (c EmbedRoutersConfig) Build(tel *telemetry.Telemetry) ([]*EmbedRouter, error) { + seenIDs := make(map[string]bool, len(c)) + routers := make([]*EmbedRouter, 0, len(c)) + + var errs error + + for idx, routerConfig := range c { + if _, ok := seenIDs[routerConfig.ID]; ok { + return nil, fmt.Errorf("ID \"%v\" is specified for more than one router while each ID should be unique", routerConfig.ID) + } + + seenIDs[routerConfig.ID] = true + + if !routerConfig.Enabled { + tel.L().Info(fmt.Sprintf("Embed router \"%v\" is disabled, skipping", routerConfig.ID)) + continue + } + + tel.L().Debug("Init router", zap.String("routerID", routerConfig.ID)) + + r, err := NewEmbedRouter(&c[idx], tel) + if err != nil { + errs = multierr.Append(errs, err) + continue + } + + routers = append(routers, r) + } + + if errs != nil { + return nil, errs + } + + return routers, nil +} diff --git a/pkg/router/embed_router.go b/pkg/router/embed_router.go index 4276f9a2..5d62b522 100644 --- a/pkg/router/embed_router.go +++ b/pkg/router/embed_router.go @@ -1,6 +1,13 @@ package router -type EmbeddingRouter struct { +import ( + "context" + + "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/telemetry" +) + +type EmbedRouter struct { // routerID lang.RouterID // Config *LangRouterConfig // retry *retry.ExpRetry @@ -8,5 +15,11 @@ type EmbeddingRouter struct { // logger *zap.Logger } -//func (r *EmbeddingRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) { -//} +func NewEmbedRouter(_ *EmbedRouterConfig, _ *telemetry.Telemetry) (*EmbedRouter, error) { + // TODO: implement + return &EmbedRouter{}, nil +} + +func (r *EmbedRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) { + // TODO: implement +} From 24be91c6afbdd8e53bb22fb4d3c43ba3625edf6d Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Mon, 12 Aug 2024 22:37:57 +0300 Subject: [PATCH 16/18] #67: Moved the cmd out of pkgs --- {pkg/cmd => cmd}/cli.go | 0 main.go | 2 +- pkg/api/http/handlers.go | 12 ++++++------ 3 files changed, 7 insertions(+), 7 deletions(-) rename {pkg/cmd => cmd}/cli.go (100%) diff --git a/pkg/cmd/cli.go b/cmd/cli.go similarity index 100% rename from pkg/cmd/cli.go rename to cmd/cli.go diff --git a/main.go b/main.go index a6d84381..122d45ba 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,7 @@ package main import ( "log" - "github.com/EinStack/glide/pkg/cmd" + "github.com/EinStack/glide/cmd" ) // @title Glide diff --git a/pkg/api/http/handlers.go b/pkg/api/http/handlers.go index 374516b6..26b96535 100644 --- a/pkg/api/http/handlers.go +++ b/pkg/api/http/handlers.go @@ -50,7 +50,7 @@ func LangChatHandler(routerManager *router.Manager) Handler { // Get router ID from path routerID := c.Params("router") - router, err := routerManager.GetLangRouter(routerID) + r, err := routerManager.GetLangRouter(routerID) if err != nil { httpErr := schemas.FromErr(err) @@ -61,7 +61,7 @@ func LangChatHandler(routerManager *router.Manager) Handler { resp := schemas.GetChatResponse() defer schemas.ReleaseChatResponse(resp) - resp, err = router.Chat(c.Context(), req) + resp, err = r.Chat(c.Context(), req) if err != nil { httpErr := schemas.FromErr(err) @@ -121,7 +121,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *router.Manag chatStreamC := make(chan *schemas.ChatStreamMessage) - router, _ := routerManager.GetLangRouter(routerID) + r, _ := routerManager.GetLangRouter(routerID) defer close(chatStreamC) defer c.Conn.Close() @@ -158,7 +158,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *router.Manag go func(chatRequest schemas.ChatStreamRequest) { defer wg.Done() - router.ChatStream(context.Background(), &chatRequest, chatStreamC) + r.ChatStream(context.Background(), &chatRequest, chatStreamC) }(chatRequest) } @@ -181,8 +181,8 @@ func LangRoutersHandler(routerManager *router.Manager) Handler { configuredRouters := routerManager.GetLangRouters() cfgs := make([]interface{}, 0, len(configuredRouters)) // opaque by design - for _, router := range configuredRouters { - cfgs = append(cfgs, router.Config) + for _, r := range configuredRouters { + cfgs = append(cfgs, r.Config) } return c.Status(fiber.StatusOK).JSON(schemas.RouterListSchema{Routers: cfgs}) From 65d4fa09a9c8cb65d24ec3262920d3d0b730d5cb Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Mon, 12 Aug 2024 22:46:15 +0300 Subject: [PATCH 17/18] #67: Renamed schemas to schema --- pkg/api/http/handlers.go | 31 ++++++----- pkg/api/{schemas => schema}/chat.go | 2 +- pkg/api/{schemas => schema}/chat_stream.go | 2 +- pkg/api/{schemas => schema}/chat_test.go | 2 +- pkg/api/{schemas => schema}/embed.go | 2 +- pkg/api/{schemas => schema}/errors.go | 2 +- pkg/api/{schemas => schema}/health_checks.go | 2 +- pkg/api/{schemas => schema}/pool.go | 2 +- pkg/api/{schemas => schema}/routers.go | 2 +- pkg/clients/stream.go | 12 ++--- pkg/extmodel/lang.go | 12 ++--- pkg/provider/anthropic/chat.go | 37 +++++++------- pkg/provider/anthropic/chat_stream.go | 8 +-- pkg/provider/anthropic/client_test.go | 8 +-- pkg/provider/azureopenai/chat.go | 16 +++--- pkg/provider/azureopenai/chat_stream.go | 18 +++---- pkg/provider/azureopenai/chat_stream_test.go | 8 +-- pkg/provider/azureopenai/client_test.go | 10 ++-- pkg/provider/azureopenai/schemas.go | 48 ++++++++--------- pkg/provider/bedrock/chat.go | 16 +++--- pkg/provider/bedrock/chat_stream.go | 8 +-- pkg/provider/bedrock/client_test.go | 6 +-- pkg/provider/cohere/chat.go | 16 +++--- pkg/provider/cohere/chat_stream.go | 26 +++++----- pkg/provider/cohere/chat_stream_test.go | 8 +-- pkg/provider/cohere/client_test.go | 6 +-- pkg/provider/cohere/finish_reason.go | 15 +++--- pkg/provider/cohere/schemas.go | 36 ++++++------- pkg/provider/interface.go | 9 ++-- pkg/provider/octoml/chat.go | 36 ++++++------- pkg/provider/octoml/chat_stream.go | 8 +-- pkg/provider/octoml/client_test.go | 10 ++-- pkg/provider/ollama/chat.go | 54 ++++++++++---------- pkg/provider/ollama/chat_stream.go | 8 +-- pkg/provider/ollama/client_test.go | 10 ++-- pkg/provider/openai/chat.go | 15 +++--- pkg/provider/openai/chat_stream.go | 18 +++---- pkg/provider/openai/chat_stream_test.go | 8 +-- pkg/provider/openai/chat_test.go | 8 +-- pkg/provider/openai/embed.go | 4 +- pkg/provider/openai/finish_reasons.go | 14 ++--- pkg/provider/openai/schemas.go | 52 +++++++++---------- pkg/provider/testing.go | 25 ++++----- pkg/router/embed_router.go | 5 +- pkg/router/lang_router.go | 35 +++++++------ pkg/router/lang_router_test.go | 37 +++++++------- pkg/router/manager.go | 4 +- 47 files changed, 362 insertions(+), 359 deletions(-) rename pkg/api/{schemas => schema}/chat.go (99%) rename pkg/api/{schemas => schema}/chat_stream.go (99%) rename pkg/api/{schemas => schema}/chat_test.go (99%) rename pkg/api/{schemas => schema}/embed.go (86%) rename pkg/api/{schemas => schema}/errors.go (99%) rename pkg/api/{schemas => schema}/health_checks.go (79%) rename pkg/api/{schemas => schema}/pool.go (97%) rename pkg/api/{schemas => schema}/routers.go (95%) diff --git a/pkg/api/http/handlers.go b/pkg/api/http/handlers.go index 26b96535..b85e07ad 100644 --- a/pkg/api/http/handlers.go +++ b/pkg/api/http/handlers.go @@ -6,7 +6,6 @@ import ( "github.com/EinStack/glide/pkg/router" - "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" "github.com/gofiber/contrib/websocket" "github.com/gofiber/fiber/v2" @@ -35,16 +34,16 @@ type Handler = func(c *fiber.Ctx) error func LangChatHandler(routerManager *router.Manager) Handler { return func(c *fiber.Ctx) error { if !c.Is("json") { - return c.Status(fiber.StatusBadRequest).JSON(schemas.ErrUnsupportedMediaType) + return c.Status(fiber.StatusBadRequest).JSON(schema.ErrUnsupportedMediaType) } // Unmarshal request body - req := schemas.GetChatRequest() - defer schemas.ReleaseChatRequest(req) + req := schema.GetChatRequest() + defer schema.ReleaseChatRequest(req) err := c.BodyParser(&req) if err != nil { - return c.Status(fiber.StatusBadRequest).JSON(schemas.NewPayloadParseErr(err)) + return c.Status(fiber.StatusBadRequest).JSON(schema.NewPayloadParseErr(err)) } // Get router ID from path @@ -52,18 +51,18 @@ func LangChatHandler(routerManager *router.Manager) Handler { r, err := routerManager.GetLangRouter(routerID) if err != nil { - httpErr := schemas.FromErr(err) + httpErr := schema.FromErr(err) return c.Status(httpErr.Status).JSON(httpErr) } // Chat with router - resp := schemas.GetChatResponse() - defer schemas.ReleaseChatResponse(resp) + resp := schema.GetChatResponse() + defer schema.ReleaseChatResponse(resp) resp, err = r.Chat(c.Context(), req) if err != nil { - httpErr := schemas.FromErr(err) + httpErr := schema.FromErr(err) return c.Status(httpErr.Status).JSON(httpErr) } @@ -80,7 +79,7 @@ func LangStreamRouterValidator(routerManager *router.Manager) Handler { _, err := routerManager.GetLangRouter(routerID) if err != nil { - httpErr := schemas.FromErr(err) + httpErr := schema.FromErr(err) return c.Status(httpErr.Status).JSON(httpErr) } @@ -119,7 +118,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *router.Manag wg sync.WaitGroup ) - chatStreamC := make(chan *schemas.ChatStreamMessage) + chatStreamC := make(chan *schema.ChatStreamMessage) r, _ := routerManager.GetLangRouter(routerID) @@ -139,7 +138,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *router.Manag }() for { - var chatRequest schemas.ChatStreamRequest + var chatRequest schema.ChatStreamRequest if err = c.ReadJSON(&chatRequest); err != nil { // TODO: handle bad request schemas gracefully and return back validation errors @@ -155,7 +154,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *router.Manag // TODO: handle termination gracefully wg.Add(1) - go func(chatRequest schemas.ChatStreamRequest) { + go func(chatRequest schema.ChatStreamRequest) { defer wg.Done() r.ChatStream(context.Background(), &chatRequest, chatStreamC) @@ -185,7 +184,7 @@ func LangRoutersHandler(routerManager *router.Manager) Handler { cfgs = append(cfgs, r.Config) } - return c.Status(fiber.StatusOK).JSON(schemas.RouterListSchema{Routers: cfgs}) + return c.Status(fiber.StatusOK).JSON(schema.RouterListSchema{Routers: cfgs}) } } @@ -200,9 +199,9 @@ func LangRoutersHandler(routerManager *router.Manager) Handler { // @Success 200 {object} schemas.HealthSchema // @Router /v1/health/ [get] func HealthHandler(c *fiber.Ctx) error { - return c.Status(fiber.StatusOK).JSON(schemas.HealthSchema{Healthy: true}) + return c.Status(fiber.StatusOK).JSON(schema.HealthSchema{Healthy: true}) } func NotFoundHandler(c *fiber.Ctx) error { - return c.Status(fiber.StatusNotFound).JSON(schemas.ErrRouteNotFound) + return c.Status(fiber.StatusNotFound).JSON(schema.ErrRouteNotFound) } diff --git a/pkg/api/schemas/chat.go b/pkg/api/schema/chat.go similarity index 99% rename from pkg/api/schemas/chat.go rename to pkg/api/schema/chat.go index bb846043..b833b367 100644 --- a/pkg/api/schemas/chat.go +++ b/pkg/api/schema/chat.go @@ -1,4 +1,4 @@ -package schemas +package schema // ChatRequest defines Glide's Chat Request Schema unified across all language models type ChatRequest struct { diff --git a/pkg/api/schemas/chat_stream.go b/pkg/api/schema/chat_stream.go similarity index 99% rename from pkg/api/schemas/chat_stream.go rename to pkg/api/schema/chat_stream.go index f7cf8b27..ee1cd228 100644 --- a/pkg/api/schemas/chat_stream.go +++ b/pkg/api/schema/chat_stream.go @@ -1,4 +1,4 @@ -package schemas +package schema import "time" diff --git a/pkg/api/schemas/chat_test.go b/pkg/api/schema/chat_test.go similarity index 99% rename from pkg/api/schemas/chat_test.go rename to pkg/api/schema/chat_test.go index 9b5ce407..9d77da62 100644 --- a/pkg/api/schemas/chat_test.go +++ b/pkg/api/schema/chat_test.go @@ -1,4 +1,4 @@ -package schemas +package schema import ( "testing" diff --git a/pkg/api/schemas/embed.go b/pkg/api/schema/embed.go similarity index 86% rename from pkg/api/schemas/embed.go rename to pkg/api/schema/embed.go index 16fd70c7..5698d330 100644 --- a/pkg/api/schemas/embed.go +++ b/pkg/api/schema/embed.go @@ -1,4 +1,4 @@ -package schemas +package schema type EmbedRequest struct { // TODO: implement diff --git a/pkg/api/schemas/errors.go b/pkg/api/schema/errors.go similarity index 99% rename from pkg/api/schemas/errors.go rename to pkg/api/schema/errors.go index 2765f93e..0eecf0b5 100644 --- a/pkg/api/schemas/errors.go +++ b/pkg/api/schema/errors.go @@ -1,4 +1,4 @@ -package schemas +package schema import ( "fmt" diff --git a/pkg/api/schemas/health_checks.go b/pkg/api/schema/health_checks.go similarity index 79% rename from pkg/api/schemas/health_checks.go rename to pkg/api/schema/health_checks.go index 6078e769..896e00c5 100644 --- a/pkg/api/schemas/health_checks.go +++ b/pkg/api/schema/health_checks.go @@ -1,4 +1,4 @@ -package schemas +package schema type HealthSchema struct { Healthy bool `json:"healthy"` diff --git a/pkg/api/schemas/pool.go b/pkg/api/schema/pool.go similarity index 97% rename from pkg/api/schemas/pool.go rename to pkg/api/schema/pool.go index dcd9ccf8..4b5c38ba 100755 --- a/pkg/api/schemas/pool.go +++ b/pkg/api/schema/pool.go @@ -1,4 +1,4 @@ -package schemas +package schema import ( "sync" diff --git a/pkg/api/schemas/routers.go b/pkg/api/schema/routers.go similarity index 95% rename from pkg/api/schemas/routers.go rename to pkg/api/schema/routers.go index 9111a319..18dcee02 100644 --- a/pkg/api/schemas/routers.go +++ b/pkg/api/schema/routers.go @@ -1,4 +1,4 @@ -package schemas +package schema // RouterListSchema returns list of active configured routers. // diff --git a/pkg/clients/stream.go b/pkg/clients/stream.go index 913bbddc..4ab55fb0 100644 --- a/pkg/clients/stream.go +++ b/pkg/clients/stream.go @@ -1,21 +1,19 @@ package clients -import ( - "github.com/EinStack/glide/pkg/api/schemas" -) +import "github.com/EinStack/glide/pkg/api/schema" type ChatStream interface { Open() error - Recv() (*schemas.ChatStreamChunk, error) + Recv() (*schema.ChatStreamChunk, error) Close() error } type ChatStreamResult struct { - chunk *schemas.ChatStreamChunk + chunk *schema.ChatStreamChunk err error } -func (r *ChatStreamResult) Chunk() *schemas.ChatStreamChunk { +func (r *ChatStreamResult) Chunk() *schema.ChatStreamChunk { return r.chunk } @@ -23,7 +21,7 @@ func (r *ChatStreamResult) Error() error { return r.err } -func NewChatStreamResult(chunk *schemas.ChatStreamChunk, err error) *ChatStreamResult { +func NewChatStreamResult(chunk *schema.ChatStreamChunk, err error) *ChatStreamResult { return &ChatStreamResult{ chunk: chunk, err: err, diff --git a/pkg/extmodel/lang.go b/pkg/extmodel/lang.go index 0c29870b..7c95282a 100644 --- a/pkg/extmodel/lang.go +++ b/pkg/extmodel/lang.go @@ -5,6 +5,8 @@ import ( "io" "time" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/provider" "github.com/EinStack/glide/pkg/clients" @@ -13,16 +15,14 @@ import ( "github.com/EinStack/glide/pkg/config/fields" "github.com/EinStack/glide/pkg/router/latency" - - "github.com/EinStack/glide/pkg/api/schemas" ) type LangModel interface { Interface Provider() string ModelName() string - Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) - ChatStream(ctx context.Context, params *schemas.ChatParams) (<-chan *clients.ChatStreamResult, error) + Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) + ChatStream(ctx context.Context, params *schema.ChatParams) (<-chan *clients.ChatStreamResult, error) } // LanguageModel wraps provider client and expend it with health & latency tracking @@ -79,7 +79,7 @@ func (m LanguageModel) ChatStreamLatency() *latency.MovingAverage { return m.chatStreamLatency } -func (m *LanguageModel) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (m *LanguageModel) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { startedAt := time.Now() resp, err := m.client.Chat(ctx, params) @@ -98,7 +98,7 @@ func (m *LanguageModel) Chat(ctx context.Context, params *schemas.ChatParams) (* return resp, err } -func (m *LanguageModel) ChatStream(ctx context.Context, params *schemas.ChatParams) (<-chan *clients.ChatStreamResult, error) { +func (m *LanguageModel) ChatStream(ctx context.Context, params *schema.ChatParams) (<-chan *clients.ChatStreamResult, error) { stream, err := m.client.ChatStream(ctx, params) if err != nil { m.healthTracker.TrackErr(err) diff --git a/pkg/provider/anthropic/chat.go b/pkg/provider/anthropic/chat.go index c45efb76..bb0559ad 100644 --- a/pkg/provider/anthropic/chat.go +++ b/pkg/provider/anthropic/chat.go @@ -9,27 +9,28 @@ import ( "net/http" "time" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) // ChatRequest is an Anthropic-specific request schema type ChatRequest struct { - Model string `json:"model"` - Messages []schemas.ChatMessage `json:"messages"` - System string `json:"system,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Stream bool `json:"stream,omitempty"` - Metadata *string `json:"metadata,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` + Model string `json:"model"` + Messages []schema.ChatMessage `json:"messages"` + System string `json:"system,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Stream bool `json:"stream,omitempty"` + Metadata *string `json:"metadata,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { r.Messages = params.Messages } @@ -51,7 +52,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { // Chat sends a chat request to the specified anthropic model. // // Ref: https://docs.anthropic.com/claude/reference/messages_post -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate @@ -67,7 +68,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -130,19 +131,19 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche usage := anthropicResponse.Usage // Map response to ChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: anthropicResponse.ID, Created: int(time.Now().UTC().Unix()), // not provided by anthropic Provider: ProviderID, ModelName: anthropicResponse.Model, Cached: false, - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{}, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: completion.Type, Content: completion.Text, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: usage.InputTokens, ResponseTokens: usage.OutputTokens, TotalTokens: usage.InputTokens + usage.OutputTokens, diff --git a/pkg/provider/anthropic/chat_stream.go b/pkg/provider/anthropic/chat_stream.go index dbb0b8ff..1a9f88a4 100644 --- a/pkg/provider/anthropic/chat_stream.go +++ b/pkg/provider/anthropic/chat_stream.go @@ -3,15 +3,15 @@ package anthropic import ( "context" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" ) func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) { - return nil, clients2.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schema.ChatParams) (clients.ChatStream, error) { + return nil, clients.ErrChatStreamNotImplemented } diff --git a/pkg/provider/anthropic/client_test.go b/pkg/provider/anthropic/client_test.go index 70977bb0..2fe33334 100644 --- a/pkg/provider/anthropic/client_test.go +++ b/pkg/provider/anthropic/client_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -56,7 +56,7 @@ func TestAnthropicClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} @@ -86,7 +86,7 @@ func TestAnthropicClient_BadChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} diff --git a/pkg/provider/azureopenai/chat.go b/pkg/provider/azureopenai/chat.go index 86aab1f2..d2f1200e 100644 --- a/pkg/provider/azureopenai/chat.go +++ b/pkg/provider/azureopenai/chat.go @@ -8,12 +8,12 @@ import ( "io" "net/http" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/provider/openai" - "github.com/EinStack/glide/pkg/api/schemas" - "go.uber.org/zap" ) @@ -38,7 +38,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified azure openai model. -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template @@ -54,7 +54,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -110,19 +110,19 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to UnifiedChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: chatCompletion.ID, Created: chatCompletion.Created, Provider: providerName, ModelName: chatCompletion.ModelName, Cached: false, - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{}, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: modelChoice.Message.Role, Content: modelChoice.Message.Content, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: chatCompletion.Usage.PromptTokens, ResponseTokens: chatCompletion.Usage.CompletionTokens, TotalTokens: chatCompletion.Usage.TotalTokens, diff --git a/pkg/provider/azureopenai/chat_stream.go b/pkg/provider/azureopenai/chat_stream.go index 8e12c8e3..f75fae4c 100644 --- a/pkg/provider/azureopenai/chat_stream.go +++ b/pkg/provider/azureopenai/chat_stream.go @@ -8,6 +8,8 @@ import ( "io" "net/http" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -17,8 +19,6 @@ import ( "github.com/r3labs/sse/v2" "go.uber.org/zap" - - "github.com/EinStack/glide/pkg/api/schemas" ) // TODO: Think about reducing the number of copy-pasted code btw OpenAI and Azure providers @@ -73,7 +73,7 @@ func (s *ChatStream) Open() error { } // Recv receives a chat stream chunk from the ChatStream and returns a ChatStreamChunk object. -func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { +func (s *ChatStream) Recv() (*schema.ChatStreamChunk, error) { var completionChunk ChatCompletionChunk for { @@ -130,16 +130,16 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { responseChunk := completionChunk.Choices[0] // TODO: use objectpool here - return &schemas.ChatStreamChunk{ + return &schema.ChatStreamChunk{ Cached: false, Provider: providerName, ModelName: completionChunk.ModelName, - ModelResponse: schemas.ModelChunkResponse{ - Metadata: &schemas.Metadata{ + ModelResponse: schema.ModelChunkResponse{ + Metadata: &schema.Metadata{ "response_id": completionChunk.ID, "system_fingerprint": completionChunk.SystemFingerprint, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: responseChunk.Delta.Role, Content: responseChunk.Delta.Content, }, @@ -161,7 +161,7 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Client) ChatStream(ctx context.Context, params *schema.ChatParams) (clients.ChatStream, error) { // Create a new chat request httpRequest, err := c.makeStreamReq(ctx, params) if err != nil { @@ -177,7 +177,7 @@ func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (cl ), nil } -func (c *Client) makeStreamReq(ctx context.Context, params *schemas.ChatParams) (*http.Request, error) { +func (c *Client) makeStreamReq(ctx context.Context, params *schema.ChatParams) (*http.Request, error) { chatReq := *c.chatRequestTemplate chatReq.ApplyParams(params) diff --git a/pkg/provider/azureopenai/chat_stream_test.go b/pkg/provider/azureopenai/chat_stream_test.go index 39a5b93e..f056d599 100644 --- a/pkg/provider/azureopenai/chat_stream_test.go +++ b/pkg/provider/azureopenai/chat_stream_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + clients2 "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -71,7 +71,7 @@ func TestAzureOpenAIClient_ChatStreamRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -139,7 +139,7 @@ func TestAzureOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the biggest animal?", }}} diff --git a/pkg/provider/azureopenai/client_test.go b/pkg/provider/azureopenai/client_test.go index 5c390114..accca38d 100644 --- a/pkg/provider/azureopenai/client_test.go +++ b/pkg/provider/azureopenai/client_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -55,7 +55,7 @@ func TestAzureOpenAIClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -88,7 +88,7 @@ func TestAzureOpenAIClient_ChatError(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} @@ -115,7 +115,7 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the dealio?", }}} diff --git a/pkg/provider/azureopenai/schemas.go b/pkg/provider/azureopenai/schemas.go index 5940648c..2ce12eb5 100644 --- a/pkg/provider/azureopenai/schemas.go +++ b/pkg/provider/azureopenai/schemas.go @@ -1,27 +1,27 @@ package azureopenai -import "github.com/EinStack/glide/pkg/api/schemas" +import "github.com/EinStack/glide/pkg/api/schema" // ChatRequest is an Azure openai-specific request schema type ChatRequest struct { - Messages []schemas.ChatMessage `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - FrequencyPenalty int `json:"frequency_penalty,omitempty"` - PresencePenalty int `json:"presence_penalty,omitempty"` - LogitBias *map[int]float64 `json:"logit_bias,omitempty"` - User *string `json:"user,omitempty"` - Seed *int `json:"seed,omitempty"` - Tools []string `json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` - ResponseFormat interface{} `json:"response_format,omitempty"` + Messages []schema.ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` + LogitBias *map[int]float64 `json:"logit_bias,omitempty"` + User *string `json:"user,omitempty"` + Seed *int `json:"seed,omitempty"` + Tools []string `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { r.Messages = params.Messages } @@ -38,10 +38,10 @@ type ChatCompletion struct { } type Choice struct { - Index int `json:"index"` - Message schemas.ChatMessage `json:"message"` - Logprobs interface{} `json:"logprobs"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Message schema.ChatMessage `json:"message"` + Logprobs interface{} `json:"logprobs"` + FinishReason string `json:"finish_reason"` } type Usage struct { @@ -62,7 +62,7 @@ type ChatCompletionChunk struct { } type StreamChoice struct { - Index int `json:"index"` - Delta schemas.ChatMessage `json:"delta"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Delta schema.ChatMessage `json:"delta"` + FinishReason string `json:"finish_reason"` } diff --git a/pkg/provider/bedrock/chat.go b/pkg/provider/bedrock/chat.go index 658c1769..cd51027b 100644 --- a/pkg/provider/bedrock/chat.go +++ b/pkg/provider/bedrock/chat.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" "go.uber.org/zap" @@ -22,7 +22,7 @@ type ChatRequest struct { TextGenerationConfig TextGenerationConfig `json:"textGenerationConfig"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { // message history not yet supported for AWS models // TODO: do something about lack of message history. Maybe just concatenate all messages? // in any case, this is not a way to go to ignore message history @@ -51,7 +51,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified bedrock model. -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template @@ -65,7 +65,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { rawPayload, err := json.Marshal(payload) if err != nil { return nil, fmt.Errorf("unable to marshal chat request payload: %w", err) @@ -96,18 +96,18 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return nil, ErrEmptyResponse } - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: uuid.NewString(), Created: int(time.Now().Unix()), Provider: providerName, ModelName: c.config.ModelName, Cached: false, - ModelResponse: schemas.ModelResponse{ - Message: schemas.ChatMessage{ + ModelResponse: schema.ModelResponse{ + Message: schema.ChatMessage{ Role: "assistant", Content: modelResult.OutputText, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ // TODO: what would happen if there is a few responses? We need to sum that up PromptTokens: modelResult.TokenCount, ResponseTokens: -1, diff --git a/pkg/provider/bedrock/chat_stream.go b/pkg/provider/bedrock/chat_stream.go index 57413043..6bb87905 100644 --- a/pkg/provider/bedrock/chat_stream.go +++ b/pkg/provider/bedrock/chat_stream.go @@ -3,15 +3,15 @@ package bedrock import ( "context" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" ) func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) { - return nil, clients2.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schema.ChatParams) (clients.ChatStream, error) { + return nil, clients.ErrChatStreamNotImplemented } diff --git a/pkg/provider/bedrock/client_test.go b/pkg/provider/bedrock/client_test.go index e99f8d9c..f6081966 100644 --- a/pkg/provider/bedrock/client_test.go +++ b/pkg/provider/bedrock/client_test.go @@ -11,9 +11,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -61,7 +61,7 @@ func TestBedrockClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the biggest animal?", }}} diff --git a/pkg/provider/cohere/chat.go b/pkg/provider/cohere/chat.go index 754d8537..7b2ebbb9 100644 --- a/pkg/provider/cohere/chat.go +++ b/pkg/provider/cohere/chat.go @@ -9,9 +9,9 @@ import ( "net/http" "time" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "go.uber.org/zap" ) @@ -30,7 +30,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified cohere model. -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate @@ -44,7 +44,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -115,22 +115,22 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to ChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: cohereCompletion.ResponseID, Created: int(time.Now().UTC().Unix()), // Cohere doesn't provide this Provider: ProviderID, ModelName: c.config.ModelName, Cached: false, - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{ "generationId": cohereCompletion.GenerationID, "responseId": cohereCompletion.ResponseID, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: "assistant", Content: cohereCompletion.Text, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: cohereCompletion.TokenCount.PromptTokens, ResponseTokens: cohereCompletion.TokenCount.ResponseTokens, TotalTokens: cohereCompletion.TokenCount.TotalTokens, diff --git a/pkg/provider/cohere/chat_stream.go b/pkg/provider/cohere/chat_stream.go index 51bd8045..46b07598 100644 --- a/pkg/provider/cohere/chat_stream.go +++ b/pkg/provider/cohere/chat_stream.go @@ -8,13 +8,13 @@ import ( "io" "net/http" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" "go.uber.org/zap" - - "github.com/EinStack/glide/pkg/api/schemas" ) // SupportedEventType Cohere has other types too: @@ -83,7 +83,7 @@ func (s *ChatStream) Open() error { return nil } -func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { +func (s *ChatStream) Recv() (*schema.ChatStreamChunk, error) { if s.streamFinished { return nil, io.EOF } @@ -135,16 +135,16 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { s.streamFinished = true // TODO: use objectpool here - return &schemas.ChatStreamChunk{ + return &schema.ChatStreamChunk{ Cached: false, Provider: ProviderID, ModelName: s.modelName, - ModelResponse: schemas.ModelChunkResponse{ - Metadata: &schemas.Metadata{ + ModelResponse: schema.ModelChunkResponse{ + Metadata: &schema.Metadata{ "generation_id": s.generationID, "response_id": responseChunk.Response.ResponseID, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: "model", Content: responseChunk.Text, }, @@ -154,15 +154,15 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { } // TODO: use objectpool here - return &schemas.ChatStreamChunk{ + return &schema.ChatStreamChunk{ Cached: false, Provider: ProviderID, ModelName: s.modelName, - ModelResponse: schemas.ModelChunkResponse{ - Metadata: &schemas.Metadata{ + ModelResponse: schema.ModelChunkResponse{ + Metadata: &schema.Metadata{ "generation_id": s.generationID, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: "model", Content: responseChunk.Text, }, @@ -183,7 +183,7 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Client) ChatStream(ctx context.Context, params *schema.ChatParams) (clients.ChatStream, error) { // Create a new chat request httpRequest, err := c.makeStreamReq(ctx, params) if err != nil { @@ -200,7 +200,7 @@ func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (cl ), nil } -func (c *Client) makeStreamReq(ctx context.Context, params *schemas.ChatParams) (*http.Request, error) { +func (c *Client) makeStreamReq(ctx context.Context, params *schema.ChatParams) (*http.Request, error) { // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate chatReq.ApplyParams(params) diff --git a/pkg/provider/cohere/chat_stream_test.go b/pkg/provider/cohere/chat_stream_test.go index 82060f84..e40eed14 100644 --- a/pkg/provider/cohere/chat_stream_test.go +++ b/pkg/provider/cohere/chat_stream_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -71,7 +71,7 @@ func TestCohere_ChatStreamRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -138,7 +138,7 @@ func TestCohere_ChatStreamRequestInterrupted(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} diff --git a/pkg/provider/cohere/client_test.go b/pkg/provider/cohere/client_test.go index bb4f99e4..721ceda7 100644 --- a/pkg/provider/cohere/client_test.go +++ b/pkg/provider/cohere/client_test.go @@ -11,9 +11,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -55,7 +55,7 @@ func TestCohereClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} diff --git a/pkg/provider/cohere/finish_reason.go b/pkg/provider/cohere/finish_reason.go index 139498e6..4d156875 100644 --- a/pkg/provider/cohere/finish_reason.go +++ b/pkg/provider/cohere/finish_reason.go @@ -3,9 +3,10 @@ package cohere import ( "strings" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) @@ -27,27 +28,27 @@ type FinishReasonMapper struct { tel *telemetry.Telemetry } -func (m *FinishReasonMapper) Map(finishReason *string) *schemas.FinishReason { +func (m *FinishReasonMapper) Map(finishReason *string) *schema.FinishReason { if finishReason == nil || len(*finishReason) == 0 { return nil } - var reason *schemas.FinishReason + var reason *schema.FinishReason switch strings.ToLower(*finishReason) { case CompleteReason: - reason = &schemas.ReasonComplete + reason = &schema.ReasonComplete case MaxTokensReason: - reason = &schemas.ReasonMaxTokens + reason = &schema.ReasonMaxTokens case FilteredReason: - reason = &schemas.ReasonContentFiltered + reason = &schema.ReasonContentFiltered default: m.tel.Logger.Warn( "Unknown finish reason, other is going to used", zap.String("unknown_reason", *finishReason), ) - reason = &schemas.ReasonOther + reason = &schema.ReasonOther } return reason diff --git a/pkg/provider/cohere/schemas.go b/pkg/provider/cohere/schemas.go index 9dc9bb09..c224ec0e 100644 --- a/pkg/provider/cohere/schemas.go +++ b/pkg/provider/cohere/schemas.go @@ -1,6 +1,6 @@ package cohere -import "github.com/EinStack/glide/pkg/api/schemas" +import "github.com/EinStack/glide/pkg/api/schema" // Cohere Chat Response type ChatCompletion struct { @@ -90,25 +90,25 @@ type FinalResponse struct { // ChatRequest is a request to complete a chat completion // Ref: https://docs.cohere.com/reference/chat type ChatRequest struct { - Model string `json:"model"` - Message string `json:"message"` - ChatHistory []schemas.ChatMessage `json:"chat_history"` - Temperature float64 `json:"temperature,omitempty"` - Preamble string `json:"preamble,omitempty"` - PromptTruncation *string `json:"prompt_truncation,omitempty"` - Connectors []string `json:"connectors,omitempty"` - SearchQueriesOnly bool `json:"search_queries_only,omitempty"` - Stream bool `json:"stream,omitempty"` - Seed *int `json:"seed,omitempty"` - MaxTokens *int `json:"max_tokens,omitempty"` - K int `json:"k"` - P float32 `json:"p"` - FrequencyPenalty float32 `json:"frequency_penalty"` - PresencePenalty float32 `json:"presence_penalty"` - StopSequences []string `json:"stop_sequences"` + Model string `json:"model"` + Message string `json:"message"` + ChatHistory []schema.ChatMessage `json:"chat_history"` + Temperature float64 `json:"temperature,omitempty"` + Preamble string `json:"preamble,omitempty"` + PromptTruncation *string `json:"prompt_truncation,omitempty"` + Connectors []string `json:"connectors,omitempty"` + SearchQueriesOnly bool `json:"search_queries_only,omitempty"` + Stream bool `json:"stream,omitempty"` + Seed *int `json:"seed,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + K int `json:"k"` + P float32 `json:"p"` + FrequencyPenalty float32 `json:"frequency_penalty"` + PresencePenalty float32 `json:"presence_penalty"` + StopSequences []string `json:"stop_sequences"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { message := params.Messages[len(params.Messages)-1] messageHistory := params.Messages[:len(params.Messages)-1] diff --git a/pkg/provider/interface.go b/pkg/provider/interface.go index b2e4ffbd..1171f486 100644 --- a/pkg/provider/interface.go +++ b/pkg/provider/interface.go @@ -4,7 +4,8 @@ import ( "context" "errors" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" ) @@ -22,13 +23,13 @@ type ModelProvider interface { type LangProvider interface { ModelProvider SupportChatStream() bool - Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) - ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) + Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) + ChatStream(ctx context.Context, params *schema.ChatParams) (clients.ChatStream, error) } // EmbeddingProvider defines an interface a provider should fulfill to be able to generate embeddings type EmbeddingProvider interface { ModelProvider SupportEmbedding() bool - Embed(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) + Embed(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) } diff --git a/pkg/provider/octoml/chat.go b/pkg/provider/octoml/chat.go index 9b2237f3..3648cd95 100644 --- a/pkg/provider/octoml/chat.go +++ b/pkg/provider/octoml/chat.go @@ -8,27 +8,27 @@ import ( "io" "net/http" - "github.com/EinStack/glide/pkg/provider/openai" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/provider/openai" "go.uber.org/zap" ) // ChatRequest is an octoml-specific request schema type ChatRequest struct { - Model string `json:"model"` - Messages []schemas.ChatMessage `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - FrequencyPenalty int `json:"frequency_penalty,omitempty"` - PresencePenalty int `json:"presence_penalty,omitempty"` + Model string `json:"model"` + Messages []schema.ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { // TODO(185): set other params r.Messages = params.Messages } @@ -47,7 +47,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified octoml model. -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template @@ -63,7 +63,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -119,21 +119,21 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to UnifiedChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: completion.ID, Created: completion.Created, Provider: providerName, ModelName: completion.ModelName, Cached: false, - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{ "system_fingerprint": completion.SystemFingerprint, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: modelChoice.Message.Role, Content: modelChoice.Message.Content, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: completion.Usage.PromptTokens, ResponseTokens: completion.Usage.CompletionTokens, TotalTokens: completion.Usage.TotalTokens, diff --git a/pkg/provider/octoml/chat_stream.go b/pkg/provider/octoml/chat_stream.go index 7b8a1766..22ead76a 100644 --- a/pkg/provider/octoml/chat_stream.go +++ b/pkg/provider/octoml/chat_stream.go @@ -3,15 +3,15 @@ package octoml import ( "context" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" ) func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) { - return nil, clients2.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schema.ChatParams) (clients.ChatStream, error) { + return nil, clients.ErrChatStreamNotImplemented } diff --git a/pkg/provider/octoml/client_test.go b/pkg/provider/octoml/client_test.go index 128fd1f0..fcc266c1 100644 --- a/pkg/provider/octoml/client_test.go +++ b/pkg/provider/octoml/client_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -55,7 +55,7 @@ func TestOctoMLClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} @@ -88,7 +88,7 @@ func TestOctoMLClient_Chat_Error(t *testing.T) { require.NoError(t, err) // Create a chat request - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} @@ -120,7 +120,7 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) { require.NoError(t, err) // Create a chat request payload - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the dealeo?", }}} diff --git a/pkg/provider/ollama/chat.go b/pkg/provider/ollama/chat.go index 42ee1f99..463899c8 100644 --- a/pkg/provider/ollama/chat.go +++ b/pkg/provider/ollama/chat.go @@ -9,38 +9,38 @@ import ( "net/http" "time" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" "github.com/google/uuid" - "github.com/EinStack/glide/pkg/api/schemas" - "go.uber.org/zap" ) // ChatRequest is an ollama-specific request schema type ChatRequest struct { - Model string `json:"model"` - Messages []schemas.ChatMessage `json:"messages"` - Microstat int `json:"microstat,omitempty"` - MicrostatEta float64 `json:"microstat_eta,omitempty"` - MicrostatTau float64 `json:"microstat_tau,omitempty"` - NumCtx int `json:"num_ctx,omitempty"` - NumGqa int `json:"num_gqa,omitempty"` - NumGpu int `json:"num_gpu,omitempty"` - NumThread int `json:"num_thread,omitempty"` - RepeatLastN int `json:"repeat_last_n,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - Seed int `json:"seed,omitempty"` - StopWords []string `json:"stop,omitempty"` - Tfsz float64 `json:"tfs_z,omitempty"` - NumPredict int `json:"num_predict,omitempty"` - TopK int `json:"top_k,omitempty"` - TopP float64 `json:"top_p,omitempty"` - Stream bool `json:"stream"` + Model string `json:"model"` + Messages []schema.ChatMessage `json:"messages"` + Microstat int `json:"microstat,omitempty"` + MicrostatEta float64 `json:"microstat_eta,omitempty"` + MicrostatTau float64 `json:"microstat_tau,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` + NumGqa int `json:"num_gqa,omitempty"` + NumGpu int `json:"num_gpu,omitempty"` + NumThread int `json:"num_thread,omitempty"` + RepeatLastN int `json:"repeat_last_n,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + Seed int `json:"seed,omitempty"` + StopWords []string `json:"stop,omitempty"` + Tfsz float64 `json:"tfs_z,omitempty"` + NumPredict int `json:"num_predict,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Stream bool `json:"stream"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { // TODO(185): set other params r.Messages = params.Messages } @@ -68,7 +68,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified ollama model. -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template @@ -84,7 +84,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { //nolint:cyclop +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { //nolint:cyclop // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -164,18 +164,18 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to UnifiedChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: uuid.NewString(), Created: int(time.Now().Unix()), Provider: providerName, ModelName: ollamaCompletion.Model, Cached: false, - ModelResponse: schemas.ModelResponse{ - Message: schemas.ChatMessage{ + ModelResponse: schema.ModelResponse{ + Message: schema.ChatMessage{ Role: ollamaCompletion.Message.Role, Content: ollamaCompletion.Message.Content, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: ollamaCompletion.EvalCount, ResponseTokens: ollamaCompletion.EvalCount, TotalTokens: ollamaCompletion.EvalCount, diff --git a/pkg/provider/ollama/chat_stream.go b/pkg/provider/ollama/chat_stream.go index 15d220e9..d43f88c0 100644 --- a/pkg/provider/ollama/chat_stream.go +++ b/pkg/provider/ollama/chat_stream.go @@ -3,15 +3,15 @@ package ollama import ( "context" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" ) func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) { - return nil, clients2.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schema.ChatParams) (clients.ChatStream, error) { + return nil, clients.ErrChatStreamNotImplemented } diff --git a/pkg/provider/ollama/client_test.go b/pkg/provider/ollama/client_test.go index 1c9dad49..e371c39d 100644 --- a/pkg/provider/ollama/client_test.go +++ b/pkg/provider/ollama/client_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -56,7 +56,7 @@ func TestOllamaClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the biggest animal?", }}} @@ -84,7 +84,7 @@ func TestOllamaClient_ChatRequest_Non200Response(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -121,7 +121,7 @@ func TestOllamaClient_ChatRequest_SuccessfulResponse(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} diff --git a/pkg/provider/openai/chat.go b/pkg/provider/openai/chat.go index 86bce6f1..be2bcbf3 100644 --- a/pkg/provider/openai/chat.go +++ b/pkg/provider/openai/chat.go @@ -8,9 +8,10 @@ import ( "io" "net/http" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) @@ -36,7 +37,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified OpenAI model. -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template @@ -52,7 +53,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -123,21 +124,21 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to ChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: chatCompletion.ID, Created: chatCompletion.Created, Provider: ProviderID, ModelName: chatCompletion.ModelName, Cached: false, - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{ "system_fingerprint": chatCompletion.SystemFingerprint, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: modelChoice.Message.Role, Content: modelChoice.Message.Content, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: chatCompletion.Usage.PromptTokens, ResponseTokens: chatCompletion.Usage.CompletionTokens, TotalTokens: chatCompletion.Usage.TotalTokens, diff --git a/pkg/provider/openai/chat_stream.go b/pkg/provider/openai/chat_stream.go index 0e4a341e..9d50295b 100644 --- a/pkg/provider/openai/chat_stream.go +++ b/pkg/provider/openai/chat_stream.go @@ -8,12 +8,12 @@ import ( "io" "net/http" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" "github.com/r3labs/sse/v2" "go.uber.org/zap" - - "github.com/EinStack/glide/pkg/api/schemas" ) var StreamDoneMarker = []byte("[DONE]") @@ -66,7 +66,7 @@ func (s *ChatStream) Open() error { return nil } -func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { +func (s *ChatStream) Recv() (*schema.ChatStreamChunk, error) { var completionChunk ChatCompletionChunk for { @@ -115,17 +115,17 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { responseChunk := completionChunk.Choices[0] // TODO: use objectpool here - return &schemas.ChatStreamChunk{ + return &schema.ChatStreamChunk{ Cached: false, Provider: ProviderID, ModelName: completionChunk.ModelName, - ModelResponse: schemas.ModelChunkResponse{ - Metadata: &schemas.Metadata{ + ModelResponse: schema.ModelChunkResponse{ + Metadata: &schema.Metadata{ "response_id": completionChunk.ID, "system_fingerprint": completionChunk.SystemFingerprint, "generated_at": completionChunk.Created, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: "assistant", // doesn't present in all chunks Content: responseChunk.Delta.Content, }, @@ -147,7 +147,7 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Client) ChatStream(ctx context.Context, params *schema.ChatParams) (clients.ChatStream, error) { // Create a new chat request httpRequest, err := c.makeStreamReq(ctx, params) if err != nil { @@ -163,7 +163,7 @@ func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (cl ), nil } -func (c *Client) makeStreamReq(ctx context.Context, params *schemas.ChatParams) (*http.Request, error) { +func (c *Client) makeStreamReq(ctx context.Context, params *schema.ChatParams) (*http.Request, error) { // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template chatReq.ApplyParams(params) diff --git a/pkg/provider/openai/chat_stream_test.go b/pkg/provider/openai/chat_stream_test.go index 6928e6f0..2934df3f 100644 --- a/pkg/provider/openai/chat_stream_test.go +++ b/pkg/provider/openai/chat_stream_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + clients2 "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -71,7 +71,7 @@ func TestOpenAIClient_ChatStreamRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -139,7 +139,7 @@ func TestOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} diff --git a/pkg/provider/openai/chat_test.go b/pkg/provider/openai/chat_test.go index 0aae4d0e..65dde4f6 100644 --- a/pkg/provider/openai/chat_test.go +++ b/pkg/provider/openai/chat_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + clients2 "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -56,7 +56,7 @@ func TestOpenAIClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -85,7 +85,7 @@ func TestOpenAIClient_RateLimit(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} diff --git a/pkg/provider/openai/embed.go b/pkg/provider/openai/embed.go index 69f9aa27..ba054adc 100644 --- a/pkg/provider/openai/embed.go +++ b/pkg/provider/openai/embed.go @@ -3,11 +3,11 @@ package openai import ( "context" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" ) // Embed sends an embedding request to the specified OpenAI model. -func (c *Client) Embed(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Embed(_ context.Context, _ *schema.ChatParams) (*schema.ChatResponse, error) { // TODO: implement return nil, nil } diff --git a/pkg/provider/openai/finish_reasons.go b/pkg/provider/openai/finish_reasons.go index 28b5f675..65196946 100644 --- a/pkg/provider/openai/finish_reasons.go +++ b/pkg/provider/openai/finish_reasons.go @@ -1,9 +1,9 @@ package openai import ( + "github.com/EinStack/glide/pkg/api/schema" "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) @@ -25,27 +25,27 @@ type FinishReasonMapper struct { tel *telemetry.Telemetry } -func (m *FinishReasonMapper) Map(finishReason string) *schemas.FinishReason { +func (m *FinishReasonMapper) Map(finishReason string) *schema.FinishReason { if len(finishReason) == 0 { return nil } - var reason *schemas.FinishReason + var reason *schema.FinishReason switch finishReason { case CompleteReason: - reason = &schemas.ReasonComplete + reason = &schema.ReasonComplete case MaxTokensReason: - reason = &schemas.ReasonMaxTokens + reason = &schema.ReasonMaxTokens case FilteredReason: - reason = &schemas.ReasonContentFiltered + reason = &schema.ReasonContentFiltered default: m.tel.Logger.Warn( "Unknown finish reason, other is going to used", zap.String("unknown_reason", finishReason), ) - reason = &schemas.ReasonOther + reason = &schema.ReasonOther } return reason diff --git a/pkg/provider/openai/schemas.go b/pkg/provider/openai/schemas.go index bde0ba81..31af6f9c 100644 --- a/pkg/provider/openai/schemas.go +++ b/pkg/provider/openai/schemas.go @@ -1,28 +1,28 @@ package openai -import "github.com/EinStack/glide/pkg/api/schemas" +import "github.com/EinStack/glide/pkg/api/schema" // ChatRequest is an OpenAI-specific request schema type ChatRequest struct { - Model string `json:"model"` - Messages []schemas.ChatMessage `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - FrequencyPenalty int `json:"frequency_penalty,omitempty"` - PresencePenalty int `json:"presence_penalty,omitempty"` - LogitBias *map[int]float64 `json:"logit_bias,omitempty"` - User *string `json:"user,omitempty"` - Seed *int `json:"seed,omitempty"` - Tools []string `json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` - ResponseFormat interface{} `json:"response_format,omitempty"` + Model string `json:"model"` + Messages []schema.ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` + LogitBias *map[int]float64 `json:"logit_bias,omitempty"` + User *string `json:"user,omitempty"` + Seed *int `json:"seed,omitempty"` + Tools []string `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { // TODO(185): set other params r.Messages = params.Messages } @@ -40,10 +40,10 @@ type ChatCompletion struct { } type Choice struct { - Index int `json:"index"` - Message schemas.ChatMessage `json:"message"` - Logprobs interface{} `json:"logprobs"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Message schema.ChatMessage `json:"message"` + Logprobs interface{} `json:"logprobs"` + FinishReason string `json:"finish_reason"` } type Usage struct { @@ -64,8 +64,8 @@ type ChatCompletionChunk struct { } type StreamChoice struct { - Index int `json:"index"` - Delta schemas.ChatMessage `json:"delta"` - Logprobs interface{} `json:"logprobs"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Delta schema.ChatMessage `json:"delta"` + Logprobs interface{} `json:"logprobs"` + FinishReason string `json:"finish_reason"` } diff --git a/pkg/provider/testing.go b/pkg/provider/testing.go index cef49cdb..f7bc64f0 100644 --- a/pkg/provider/testing.go +++ b/pkg/provider/testing.go @@ -4,7 +4,8 @@ import ( "context" "io" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/config/fields" "github.com/EinStack/glide/pkg/telemetry" @@ -37,24 +38,24 @@ type RespMock struct { Err error } -func (m *RespMock) Resp() *schemas.ChatResponse { - return &schemas.ChatResponse{ +func (m *RespMock) Resp() *schema.ChatResponse { + return &schema.ChatResponse{ ID: "rsp0001", - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{ "ID": "0001", }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Content: m.Msg, }, }, } } -func (m *RespMock) RespChunk() *schemas.ChatStreamChunk { - return &schemas.ChatStreamChunk{ - ModelResponse: schemas.ModelChunkResponse{ - Message: schemas.ChatMessage{ +func (m *RespMock) RespChunk() *schema.ChatStreamChunk { + return &schema.ChatStreamChunk{ + ModelResponse: schema.ModelChunkResponse{ + Message: schema.ChatMessage{ Content: m.Msg, }, }, @@ -97,7 +98,7 @@ func (m *RespStreamMock) Open() error { return nil } -func (m *RespStreamMock) Recv() (*schemas.ChatStreamChunk, error) { +func (m *RespStreamMock) Recv() (*schema.ChatStreamChunk, error) { if m.Chunks != nil && m.idx >= len(*m.Chunks) { return nil, io.EOF } @@ -154,7 +155,7 @@ func (c *Mock) SupportChatStream() bool { return c.supportStreaming } -func (c *Mock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Mock) Chat(_ context.Context, _ *schema.ChatParams) (*schema.ChatResponse, error) { if c.chatResps == nil { return nil, clients.ErrProviderUnavailable } @@ -171,7 +172,7 @@ func (c *Mock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResp return response.Resp(), nil } -func (c *Mock) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Mock) ChatStream(_ context.Context, _ *schema.ChatParams) (clients.ChatStream, error) { if c.chatStreams == nil || c.idx >= len(*c.chatStreams) { return nil, clients.ErrProviderUnavailable } diff --git a/pkg/router/embed_router.go b/pkg/router/embed_router.go index 5d62b522..b2ad8a59 100644 --- a/pkg/router/embed_router.go +++ b/pkg/router/embed_router.go @@ -3,7 +3,8 @@ package router import ( "context" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/telemetry" ) @@ -20,6 +21,6 @@ func NewEmbedRouter(_ *EmbedRouterConfig, _ *telemetry.Telemetry) (*EmbedRouter, return &EmbedRouter{}, nil } -func (r *EmbedRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) { +func (r *EmbedRouter) Embed(ctx context.Context, req *schema.EmbedRequest) (*schema.EmbedResponse, error) { // TODO: implement } diff --git a/pkg/router/lang_router.go b/pkg/router/lang_router.go index 04c64e9e..ec2d9113 100644 --- a/pkg/router/lang_router.go +++ b/pkg/router/lang_router.go @@ -4,9 +4,10 @@ import ( "context" "errors" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/extmodel" - "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/router/routing" "github.com/EinStack/glide/pkg/telemetry" @@ -59,7 +60,7 @@ func (r *LangRouter) ID() ID { return r.routerID } -func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error) { +func (r *LangRouter) Chat(ctx context.Context, req *schema.ChatRequest) (*schema.ChatResponse, error) { if len(r.chatModels) == 0 { return nil, ErrNoModels } @@ -112,22 +113,22 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem // if we reach this part, then we are in trouble r.logger.Error("No model was available to handle chat request") - return nil, &schemas.ErrNoModelAvailable + return nil, &schema.ErrNoModelAvailable } func (r *LangRouter) ChatStream( ctx context.Context, - req *schemas.ChatStreamRequest, - respC chan<- *schemas.ChatStreamMessage, + req *schema.ChatStreamRequest, + respC chan<- *schema.ChatStreamMessage, ) { if len(r.chatStreamModels) == 0 { - respC <- schemas.NewChatStreamError( + respC <- schema.NewChatStreamError( req.ID, r.routerID, - schemas.NoModelConfigured, + schema.NoModelConfigured, ErrNoModels.Error(), req.Metadata, - &schemas.ReasonError, + &schema.ReasonError, ) return @@ -175,10 +176,10 @@ func (r *LangRouter) ChatStream( // It's challenging to hide an error in case of streaming chat as consumer apps // may have already used all chunks we streamed this far (e.g. showed them to their users like OpenAI UI does), // so we cannot easily restart that process from scratch - respC <- schemas.NewChatStreamError( + respC <- schema.NewChatStreamError( req.ID, r.routerID, - schemas.ModelUnavailable, + schema.ModelUnavailable, err.Error(), req.Metadata, nil, @@ -189,7 +190,7 @@ func (r *LangRouter) ChatStream( chunk := chunkResult.Chunk() - respC <- schemas.NewChatStreamChunk( + respC <- schema.NewChatStreamChunk( req.ID, r.routerID, req.Metadata, @@ -207,10 +208,10 @@ func (r *LangRouter) ChatStream( err := retryIterator.WaitNext(ctx) if err != nil { // something has cancelled the context - respC <- schemas.NewChatStreamError( + respC <- schema.NewChatStreamError( req.ID, r.routerID, - schemas.UnknownError, + schema.UnknownError, err.Error(), req.Metadata, nil, @@ -226,12 +227,12 @@ func (r *LangRouter) ChatStream( "Try to configure more fallback models to avoid this", ) - respC <- schemas.NewChatStreamError( + respC <- schema.NewChatStreamError( req.ID, r.routerID, - schemas.ErrNoModelAvailable.Name, - schemas.ErrNoModelAvailable.Message, + schema.ErrNoModelAvailable.Name, + schema.ErrNoModelAvailable.Message, req.Metadata, - &schemas.ReasonError, + &schema.ReasonError, ) } diff --git a/pkg/router/lang_router_test.go b/pkg/router/lang_router_test.go index 92eb5d4b..68c31cab 100644 --- a/pkg/router/lang_router_test.go +++ b/pkg/router/lang_router_test.go @@ -13,7 +13,6 @@ import ( "github.com/EinStack/glide/pkg/resiliency/health" "github.com/EinStack/glide/pkg/resiliency/retry" - "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/router/latency" "github.com/EinStack/glide/pkg/router/routing" "github.com/EinStack/glide/pkg/telemetry" @@ -56,7 +55,7 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) { } ctx := context.Background() - req := schemas.NewChatFromStr("tell me a dad joke") + req := schema.NewChatFromStr("tell me a dad joke") for i := 0; i < 2; i++ { resp, err := router.Chat(ctx, req) @@ -73,14 +72,14 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { langModels := []*extmodel.LanguageModel{ extmodel.NewLangModel( "first", - provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "3"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Msg: "3"}}), budget, *latConfig, 1, ), extmodel.NewLangModel( "second", - provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "4"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Msg: "4"}}), budget, *latConfig, 1, @@ -113,7 +112,7 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { } ctx := context.Background() - req := schemas.NewChatFromStr("tell me a dad joke") + req := schema.NewChatFromStr("tell me a dad joke") for _, modelID := range expectedModels { resp, err := router.Chat(ctx, req) @@ -130,14 +129,14 @@ func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { langModels := []*extmodel.LanguageModel{ extmodel.NewLangModel( "first", - provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "2"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Msg: "2"}}), budget, *latConfig, 1, ), extmodel.NewLangModel( "second", - provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "1"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Msg: "1"}}), budget, *latConfig, 1, @@ -160,7 +159,7 @@ func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { logger: telemetry.NewLoggerMock(), } - resp, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke")) + resp, err := router.Chat(context.Background(), schema.NewChatFromStr("tell me a dad joke")) require.NoError(t, err) require.Equal(t, "first", resp.ModelID) @@ -204,7 +203,7 @@ func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) { } for i := 0; i < 2; i++ { - resp, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke")) + resp, err := router.Chat(context.Background(), schema.NewChatFromStr("tell me a dad joke")) require.NoError(t, err) require.Equal(t, "second", resp.ModelID) @@ -218,14 +217,14 @@ func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) { langModels := []*extmodel.LanguageModel{ extmodel.NewLangModel( "first", - provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Err: &schema.ErrNoModelAvailable}}), budget, *latConfig, 1, ), extmodel.NewLangModel( "second", - provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Err: &schema.ErrNoModelAvailable}}), budget, *latConfig, 1, @@ -248,7 +247,7 @@ func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) { logger: telemetry.NewLoggerMock(), } - _, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke")) + _, err := router.Chat(context.Background(), schema.NewChatFromStr("tell me a dad joke")) require.Error(t, err) } @@ -305,8 +304,8 @@ func TestLangRouter_ChatStream(t *testing.T) { } ctx := context.Background() - req := schemas.NewChatStreamFromStr("tell me a dad joke") - respC := make(chan *schemas.ChatStreamMessage) + req := schema.NewChatStreamFromStr("tell me a dad joke") + respC := make(chan *schema.ChatStreamMessage) defer close(respC) @@ -374,8 +373,8 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) { } ctx := context.Background() - req := schemas.NewChatStreamFromStr("tell me a dad joke") - respC := make(chan *schemas.ChatStreamMessage) + req := schema.NewChatStreamFromStr("tell me a dad joke") + respC := make(chan *schema.ChatStreamMessage) defer close(respC) @@ -442,10 +441,10 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { logger: telemetry.NewLoggerMock(), } - respC := make(chan *schemas.ChatStreamMessage) + respC := make(chan *schema.ChatStreamMessage) defer close(respC) - go router.ChatStream(context.Background(), schemas.NewChatStreamFromStr("tell me a dad joke"), respC) + go router.ChatStream(context.Background(), schema.NewChatStreamFromStr("tell me a dad joke"), respC) errs := make([]string, 0, 3) @@ -457,5 +456,5 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { errs = append(errs, result.Error.Name) } - require.Equal(t, []string{schemas.ModelUnavailable, schemas.ModelUnavailable, schemas.AllModelsUnavailable}, errs) + require.Equal(t, []string{schema.ModelUnavailable, schema.ModelUnavailable, schema.AllModelsUnavailable}, errs) } diff --git a/pkg/router/manager.go b/pkg/router/manager.go index b30afbcf..b0d2fd69 100644 --- a/pkg/router/manager.go +++ b/pkg/router/manager.go @@ -1,7 +1,7 @@ package router import ( - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" "github.com/EinStack/glide/pkg/telemetry" ) @@ -45,5 +45,5 @@ func (r *Manager) GetLangRouter(routerID string) (*LangRouter, error) { return router, nil } - return nil, &schemas.ErrRouterNotFound + return nil, &schema.ErrRouterNotFound } From e2817c063f519994035f23e92ad618fedecbe41c Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Mon, 12 Aug 2024 22:48:58 +0300 Subject: [PATCH 18/18] #67: Renamed the health schema --- pkg/api/http/handlers.go | 2 ++ pkg/api/schema/{health_checks.go => health.go} | 0 2 files changed, 2 insertions(+) rename pkg/api/schema/{health_checks.go => health.go} (100%) diff --git a/pkg/api/http/handlers.go b/pkg/api/http/handlers.go index b85e07ad..3db789eb 100644 --- a/pkg/api/http/handlers.go +++ b/pkg/api/http/handlers.go @@ -4,6 +4,8 @@ import ( "context" "sync" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/router" "github.com/EinStack/glide/pkg/telemetry" diff --git a/pkg/api/schema/health_checks.go b/pkg/api/schema/health.go similarity index 100% rename from pkg/api/schema/health_checks.go rename to pkg/api/schema/health.go