diff --git a/frontend/src/services/aiService.ts b/frontend/src/services/aiService.ts index 74c8e41..1b9c2cf 100644 --- a/frontend/src/services/aiService.ts +++ b/frontend/src/services/aiService.ts @@ -57,7 +57,6 @@ export const aiService = { provider: config.provider, endpoint: config.endpoint, model: config.model, - api_key: config.apiKey, max_tokens: config.maxTokens, timeout: formatTimeout(config.timeout) } @@ -67,7 +66,7 @@ export const aiService = { message: string provider: string error?: string - }>('test_connection', payload) + }>('test_connection', payload, { apiKey: config.apiKey }) return { success: toBoolean(result.success), @@ -130,12 +129,11 @@ export const aiService = { include_explanation: request.includeExplanation, provider: request.provider, endpoint: request.endpoint, - api_key: request.apiKey, max_tokens: request.maxTokens, timeout: formatTimeout(request.timeout), database_type: request.databaseDialect }) - }) + }, { apiKey: request.apiKey }) console.log('📥 [aiService] Received backend result', { hasContent: !!result.content, @@ -214,12 +212,11 @@ export const aiService = { provider: config.provider, endpoint: config.endpoint, model: config.model, - api_key: config.apiKey, max_tokens: config.maxTokens, timeout: formatTimeout(config.timeout), database_type: config.databaseDialect } - }) + }, { apiKey: config.apiKey }) } } @@ -239,7 +236,7 @@ function formatTimeout(timeout: number | undefined): string { * is designed for database queries and transforms the request format. * The AI plugin expects: {type: 'ai', key: 'operation', sql: 'params_json'} */ -async function callAPI(key: string, data: any): Promise { +async function callAPI(key: string, data: any, options: { apiKey?: string } = {}): Promise { const requestBody = { type: 'ai', key, @@ -254,12 +251,18 @@ async function callAPI(key: string, data: any): Promise { }) try { + const headers: Record = { + 'Content-Type': 'application/json', + 'X-Store-Name': API_STORE + } + + if (options.apiKey) { + headers['X-Auth'] = `Bearer ${options.apiKey}` + } + const response = await fetch(API_BASE, { method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'X-Store-Name': API_STORE - }, + headers, body: JSON.stringify(requestBody) }) diff --git a/frontend/tests/services/aiService.spec.ts b/frontend/tests/services/aiService.spec.ts index ced0313..ea5ad50 100644 --- a/frontend/tests/services/aiService.spec.ts +++ b/frontend/tests/services/aiService.spec.ts @@ -41,6 +41,7 @@ describe('aiService', () => { }) const payload = JSON.parse(body.sql) expect(payload.config).toContain('timeout') + expect(payload.config).not.toContain('api_key') return createFetchResponse({ data: [ @@ -68,6 +69,38 @@ describe('aiService', () => { expect(response.meta).toEqual({ confidence: 0.9, model: 'demo' }) }) + it('sends API key through authorization header only', async () => { + const apiKey = 'sk-secure' + fetchMock.mockImplementationOnce(async (_url: FetchArgs[0], options: FetchArgs[1]) => { + const headers = options?.headers as Record + expect(headers['X-Auth']).toBe(`Bearer ${apiKey}`) + + const body = JSON.parse(String(options?.body)) + const payload = JSON.parse(body.sql) + expect(payload.config).not.toContain('api_key') + + return createFetchResponse({ + data: [ + { key: 'success', value: true }, + { key: 'content', value: 'sql:SELECT 1;' }, + { key: 'meta', value: '{}' } + ] + }) + }) + + await aiService.generateSQL({ + provider: 'openai', + endpoint: 'https://api.openai.com', + apiKey, + model: 'gpt-5', + prompt: 'SELECT 1;', + timeout: 30, + maxTokens: 256, + includeExplanation: false, + databaseDialect: 'postgresql' + }) + }) + it('parses health check response when backend returns boolean healthy flag', async () => { fetchMock.mockResolvedValueOnce( createFetchResponse({ diff --git a/pkg/ai/engine.go b/pkg/ai/engine.go index c73ce26..9c2dace 100644 --- a/pkg/ai/engine.go +++ b/pkg/ai/engine.go @@ -64,6 +64,7 @@ type GenerateSQLRequest struct { NaturalLanguage string `json:"natural_language"` DatabaseType string `json:"database_type"` Context map[string]string `json:"context,omitempty"` + RuntimeAPIKey string `json:"-"` } // GenerateSQLResponse represents an AI SQL generation response @@ -113,7 +114,11 @@ func NewEngine(cfg config.AIConfig) (Engine, error) { engine, err := newEngineFromManager(manager, cfg) if err != nil { - _ = manager.Close() + if closeErr := manager.Close(); closeErr != nil { + logging.Logger.Warn("Failed to close AI manager after initialization error", + "provider", cfg.DefaultService, + "error", closeErr) + } return nil, err } return engine, nil @@ -163,7 +168,11 @@ func NewEngineWithManager(manager *Manager, cfg config.AIConfig) (Engine, error) engine, err := newEngineFromManager(manager, cfg) if err != nil { - _ = manager.Close() + if closeErr := manager.Close(); closeErr != nil { + logging.Logger.Warn("Failed to close AI manager after initialization error", + "provider", cfg.DefaultService, + "error", closeErr) + } return nil, err } return engine, nil @@ -206,6 +215,10 @@ func (e *aiEngine) GenerateSQL(ctx context.Context, req *GenerateSQLRequest) (*G MaxTokens: defaultMaxTokens, } + if req.RuntimeAPIKey != "" { + options.APIKey = req.RuntimeAPIKey + } + // Add context if provided and extract preferred_model and runtime config var runtimeConfig map[string]interface{} if len(req.Context) > 0 { @@ -236,7 +249,11 @@ func (e *aiEngine) GenerateSQL(ctx context.Context, req *GenerateSQLRequest) (*G options.Provider = provider } if apiKey, ok := runtimeConfig["api_key"].(string); ok && apiKey != "" { - options.APIKey = apiKey + if options.APIKey == "" { + options.APIKey = apiKey + } else if options.APIKey != apiKey { + logging.Logger.Warn("Runtime config API key differs from secured metadata; using secured value") + } } if endpoint, ok := runtimeConfig["endpoint"].(string); ok && endpoint != "" { options.Endpoint = endpoint @@ -313,6 +330,9 @@ func (e *aiEngine) Close() { e.generator.Close() } if e.manager != nil { - _ = e.manager.Close() + if err := e.manager.Close(); err != nil { + logging.Logger.Warn("Failed to close AI manager during engine shutdown", + "error", err) + } } } diff --git a/pkg/ai/generator.go b/pkg/ai/generator.go index 6804a92..e78d47e 100644 --- a/pkg/ai/generator.go +++ b/pkg/ai/generator.go @@ -17,6 +17,7 @@ limitations under the License. package ai import ( + "bytes" "context" "crypto/sha256" "encoding/hex" @@ -38,10 +39,15 @@ type SQLGenerator struct { sqlDialects map[string]SQLDialect config config.AIConfig capabilities *SQLCapabilities - runtimeClients map[string]interfaces.AIClient + runtimeClients map[string]*runtimeClientEntry runtimeMu sync.RWMutex } +type runtimeClientEntry struct { + client interfaces.AIClient + apiKeyFingerprint []byte +} + // Table represents a database table structure type Table struct { Name string `json:"name"` @@ -142,7 +148,7 @@ func NewSQLGenerator(aiClient interfaces.AIClient, config config.AIConfig) (*SQL aiClient: aiClient, config: config, sqlDialects: make(map[string]SQLDialect), - runtimeClients: make(map[string]interfaces.AIClient), + runtimeClients: make(map[string]*runtimeClientEntry), } // Initialize SQL dialects @@ -630,23 +636,38 @@ func runtimeClientKey(options *GenerateOptions) string { hasher.Write([]byte(options.Endpoint)) hasher.Write([]byte("|")) hasher.Write([]byte(options.Model)) - hasher.Write([]byte("|")) - hasher.Write([]byte(options.APIKey)) return hex.EncodeToString(hasher.Sum(nil)) } +func runtimeAPIKeyFingerprint(apiKey string) []byte { + if apiKey == "" { + return nil + } + sum := sha256.Sum256([]byte(apiKey)) + fingerprint := make([]byte, len(sum)) + copy(fingerprint, sum[:]) + return fingerprint +} + func (g *SQLGenerator) getOrCreateRuntimeClient(options *GenerateOptions) (interfaces.AIClient, bool, error) { key := runtimeClientKey(options) + fingerprint := runtimeAPIKeyFingerprint(options.APIKey) + g.runtimeMu.RLock() - if client, ok := g.runtimeClients[key]; ok { - g.runtimeMu.RUnlock() - return client, true, nil + if entry, ok := g.runtimeClients[key]; ok { + if bytes.Equal(entry.apiKeyFingerprint, fingerprint) { + client := entry.client + g.runtimeMu.RUnlock() + return client, true, nil + } } g.runtimeMu.RUnlock() runtimeConfig := map[string]any{ "provider": options.Provider, - "api_key": options.APIKey, + } + if options.APIKey != "" { + runtimeConfig["api_key"] = options.APIKey } if options.Endpoint != "" { runtimeConfig["base_url"] = options.Endpoint @@ -664,14 +685,38 @@ func (g *SQLGenerator) getOrCreateRuntimeClient(options *GenerateOptions) (inter } g.runtimeMu.Lock() - if existing, ok := g.runtimeClients[key]; ok { - g.runtimeMu.Unlock() - _ = client.Close() - return existing, true, nil + var ( + existingEntry *runtimeClientEntry + exists bool + ) + if existingEntry, exists = g.runtimeClients[key]; exists { + if bytes.Equal(existingEntry.apiKeyFingerprint, fingerprint) { + g.runtimeMu.Unlock() + if err := client.Close(); err != nil { + logging.Logger.Warn("Failed to close redundant runtime client", + "provider", options.Provider, + "endpoint", options.Endpoint, + "error", err) + } + return existingEntry.client, true, nil + } + } + + g.runtimeClients[key] = &runtimeClientEntry{ + client: client, + apiKeyFingerprint: fingerprint, } - g.runtimeClients[key] = client g.runtimeMu.Unlock() + if exists && existingEntry != nil && existingEntry.client != nil { + if err := existingEntry.client.Close(); err != nil { + logging.Logger.Warn("Failed to close stale runtime client", + "provider", options.Provider, + "endpoint", options.Endpoint, + "error", err) + } + } + return client, false, nil } @@ -679,8 +724,16 @@ func (g *SQLGenerator) getOrCreateRuntimeClient(options *GenerateOptions) (inter func (g *SQLGenerator) Close() { g.runtimeMu.Lock() defer g.runtimeMu.Unlock() - for key, client := range g.runtimeClients { - _ = client.Close() + for key, entry := range g.runtimeClients { + if entry == nil || entry.client == nil { + delete(g.runtimeClients, key) + continue + } + if err := entry.client.Close(); err != nil { + logging.Logger.Warn("Failed to close runtime client during generator shutdown", + "key", key, + "error", err) + } delete(g.runtimeClients, key) } } diff --git a/pkg/ai/generator_test.go b/pkg/ai/generator_test.go index ab0e129..7e632e4 100644 --- a/pkg/ai/generator_test.go +++ b/pkg/ai/generator_test.go @@ -3,13 +3,12 @@ package ai import ( "testing" - "github.com/linuxsuren/atest-ext-ai/pkg/interfaces" "github.com/stretchr/testify/require" ) func TestRuntimeClientReuseAndClose(t *testing.T) { generator := &SQLGenerator{ - runtimeClients: make(map[string]interfaces.AIClient), + runtimeClients: make(map[string]*runtimeClientEntry), } options := &GenerateOptions{ diff --git a/pkg/ai/manager.go b/pkg/ai/manager.go index da911d7..6f8e1be 100644 --- a/pkg/ai/manager.go +++ b/pkg/ai/manager.go @@ -291,7 +291,11 @@ func (m *Manager) AddClient(ctx context.Context, name string, svc config.AIServi // Close old client if exists if oldClient, exists := m.clients[name]; exists { - _ = oldClient.Close() + if err := oldClient.Close(); err != nil { + logging.Logger.Warn("Failed to close existing AI client", + "client", name, + "error", err) + } } m.clients[name] = client @@ -312,7 +316,11 @@ func (m *Manager) RemoveClient(name string) error { return fmt.Errorf("%w: %s", ErrClientNotFound, name) } - _ = client.Close() + if err := client.Close(); err != nil { + logging.Logger.Warn("Failed to close AI client", + "client", name, + "error", err) + } delete(m.clients, name) return nil } @@ -353,7 +361,11 @@ func (m *Manager) DiscoverProviders(ctx context.Context) ([]*ProviderInfo, error } providers = append(providers, provider) - _ = client.Close() + if err := client.Close(); err != nil { + logging.Logger.Warn("Failed to close discovery client", + "provider", provider.Name, + "error", err) + } } } @@ -407,7 +419,13 @@ func (m *Manager) TestConnection(ctx context.Context, cfg *universal.Config) (*C Error: err.Error(), }, nil } - defer func() { _ = client.Close() }() + defer func() { + if err := client.Close(); err != nil { + logging.Logger.Warn("Failed to close test connection client", + "provider", cfg.Provider, + "error", err) + } + }() health, err := client.HealthCheck(ctx) if err != nil { diff --git a/pkg/ai/providers/universal/client.go b/pkg/ai/providers/universal/client.go index 498b6a7..dcff09a 100644 --- a/pkg/ai/providers/universal/client.go +++ b/pkg/ai/providers/universal/client.go @@ -24,6 +24,7 @@ import ( "net/http" "strings" "sync" + "sync/atomic" "time" "github.com/linuxsuren/atest-ext-ai/pkg/interfaces" @@ -33,19 +34,43 @@ import ( // Global HTTP client pool for connection reuse across providers // Using sync.Map for concurrent-safe access without explicit locking on read var ( - httpClientPool = &sync.Map{} // key: provider name (string), value: *http.Client + httpClientPool = &sync.Map{} // key: provider name (string), value: *pooledHTTPClient httpClientMu sync.Mutex // Mutex for client creation to prevent duplicate creation ) +type pooledHTTPClient struct { + provider string + client *http.Client + transport *http.Transport + refs atomic.Int32 +} + +func (p *pooledHTTPClient) retain() { + p.refs.Add(1) +} + +func (p *pooledHTTPClient) release() { + if p.refs.Add(-1) != 0 { + return + } + + if p.transport != nil { + p.transport.CloseIdleConnections() + } + httpClientPool.Delete(p.provider) +} + // getOrCreateHTTPClient retrieves an existing HTTP client from the pool or creates a new one // This implements connection pooling to improve performance and resource utilization // Based on Go net/http best practices for Transport configuration -func getOrCreateHTTPClient(provider string, timeout time.Duration) *http.Client { +func getOrCreateHTTPClient(provider string, timeout time.Duration) *pooledHTTPClient { // Try to get existing client from pool (fast path, no locking) if client, ok := httpClientPool.Load(provider); ok { + entry := client.(*pooledHTTPClient) + entry.retain() logging.Logger.Debug("Reusing HTTP client from pool", "provider", provider) - return client.(*http.Client) + return entry } // Client not found, need to create (slow path with locking) @@ -54,9 +79,11 @@ func getOrCreateHTTPClient(provider string, timeout time.Duration) *http.Client // Double-check: another goroutine might have created the client while we waited for the lock if client, ok := httpClientPool.Load(provider); ok { + entry := client.(*pooledHTTPClient) + entry.retain() logging.Logger.Debug("HTTP client created by another goroutine", "provider", provider) - return client.(*http.Client) + return entry } // Create new HTTP client with optimized transport settings @@ -65,25 +92,33 @@ func getOrCreateHTTPClient(provider string, timeout time.Duration) *http.Client // - MaxIdleConnsPerHost: Maximum idle connections per host (important for AI APIs) // - IdleConnTimeout: How long idle connections remain in the pool // - DisableCompression: Disabled for better compatibility with AI APIs + transport := &http.Transport{ + MaxIdleConns: 100, // Total pool size across all hosts + MaxIdleConnsPerHost: 10, // Per-host idle connection limit (AI APIs typically use 1 host) + IdleConnTimeout: 90 * time.Second, // Keep idle connections for 90s + DisableCompression: false, // Enable compression for better bandwidth utilization + // Additional recommended settings for production use: + MaxConnsPerHost: 0, // No limit on active connections (0 = unlimited) + ResponseHeaderTimeout: 30 * time.Second, // Timeout for reading response headers + ExpectContinueTimeout: 1 * time.Second, // Timeout for 100-Continue handshake + ForceAttemptHTTP2: true, // Enable HTTP/2 when available + DisableKeepAlives: false, // Enable keep-alives for connection reuse + TLSHandshakeTimeout: 10 * time.Second, // Timeout for TLS handshake + } + client := &http.Client{ - Timeout: timeout, - Transport: &http.Transport{ - MaxIdleConns: 100, // Total pool size across all hosts - MaxIdleConnsPerHost: 10, // Per-host idle connection limit (AI APIs typically use 1 host) - IdleConnTimeout: 90 * time.Second, // Keep idle connections for 90s - DisableCompression: false, // Enable compression for better bandwidth utilization - // Additional recommended settings for production use: - MaxConnsPerHost: 0, // No limit on active connections (0 = unlimited) - ResponseHeaderTimeout: 30 * time.Second, // Timeout for reading response headers - ExpectContinueTimeout: 1 * time.Second, // Timeout for 100-Continue handshake - ForceAttemptHTTP2: true, // Enable HTTP/2 when available - DisableKeepAlives: false, // Enable keep-alives for connection reuse - TLSHandshakeTimeout: 10 * time.Second, // Timeout for TLS handshake - }, + Timeout: timeout, + Transport: transport, } + entry := &pooledHTTPClient{ + client: client, + transport: transport, + } + entry.retain() + // Store in pool for reuse - httpClientPool.Store(provider, client) + httpClientPool.Store(provider, entry) logging.Logger.Info("Created new HTTP client with connection pooling", "provider", provider, @@ -92,13 +127,14 @@ func getOrCreateHTTPClient(provider string, timeout time.Duration) *http.Client "max_idle_conns_per_host", 10, "idle_conn_timeout", "90s") - return client + return entry } // Client implements a universal OpenAI-compatible API client. type Client struct { config *Config httpClient *http.Client + poolEntry *pooledHTTPClient strategy ProviderStrategy // Strategy pattern to handle provider-specific logic } @@ -175,12 +211,13 @@ func NewUniversalClient(config *Config) (*Client, error) { // Create HTTP client using connection pool for better performance // This reuses connections across requests to the same provider - httpClient := getOrCreateHTTPClient(config.Provider, config.Timeout) + pooledClient := getOrCreateHTTPClient(config.Provider, config.Timeout) client := &Client{ config: config, strategy: strategy, - httpClient: httpClient, + httpClient: pooledClient.client, + poolEntry: pooledClient, } logging.Logger.Debug("Universal client created", @@ -339,7 +376,10 @@ func (c *Client) HealthCheck(ctx context.Context) (*interfaces.HealthStatus, err // Close releases any resources held by the client func (c *Client) Close() error { - // No persistent connections to close + if c.poolEntry != nil { + c.poolEntry.release() + c.poolEntry = nil + } return nil } diff --git a/pkg/plugin/service.go b/pkg/plugin/service.go index d82bc58..5571d0d 100644 --- a/pkg/plugin/service.go +++ b/pkg/plugin/service.go @@ -21,6 +21,7 @@ import ( _ "embed" "encoding/json" "fmt" + "sort" "strings" "time" @@ -33,6 +34,7 @@ import ( "github.com/linuxsuren/atest-ext-ai/pkg/logging" "github.com/linuxsuren/atest-ext-ai/pkg/metrics" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -87,6 +89,124 @@ func contextError(ctx context.Context) error { return nil } +type contextKey string + +const apiKeyContextKey contextKey = "ai-plugin-runtime-api-key" + +func withAPIKey(ctx context.Context, apiKey string) context.Context { + if ctx == nil || apiKey == "" { + return ctx + } + return context.WithValue(ctx, apiKeyContextKey, apiKey) +} + +func apiKeyFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + if value, ok := ctx.Value(apiKeyContextKey).(string); ok { + return value + } + return "" +} + +func extractAPIKeyFromMetadata(ctx context.Context) string { + if ctx == nil { + return "" + } + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "" + } + + for key, values := range md { + switch strings.ToLower(key) { + case "auth", "x-auth", "x-ai-api-key", "authorization": + for _, raw := range values { + if normalized := normalizeAPIKeyValue(raw); normalized != "" { + return normalized + } + } + } + } + return "" +} + +func normalizeAPIKeyValue(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "" + } + if strings.HasPrefix(strings.ToLower(trimmed), "bearer ") { + return strings.TrimSpace(trimmed[7:]) + } + return trimmed +} + +func formatInitErrors(filter func(InitializationError) bool) string { + if len(initErrors) == 0 { + return "" + } + + var builder strings.Builder + for _, initErr := range initErrors { + if filter != nil && !filter(initErr) { + continue + } + if builder.Len() == 0 { + builder.WriteString(" Initialization errors:") + } + builder.WriteString(fmt.Sprintf("\n- %s: %s", initErr.Component, initErr.Reason)) + if len(initErr.Details) > 0 { + keys := make([]string, 0, len(initErr.Details)) + for key := range initErr.Details { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + builder.WriteString(fmt.Sprintf("\n %s: %s", key, initErr.Details[key])) + } + } + } + + return builder.String() +} + +func (s *AIPluginService) requireEngineAvailable(operation, baseMessage, fallback string) error { + if s.aiEngine != nil { + return nil + } + + logging.Logger.Error(operation) + errMsg := baseMessage + if details := formatInitErrors(nil); details != "" { + errMsg += details + } else if fallback != "" { + errMsg += " " + fallback + } + return status.Error(codes.FailedPrecondition, errMsg) +} + +const managerFallbackMessage = "Please check AI service configuration." + +func (s *AIPluginService) requireManagerAvailable(operation, baseMessage string) error { + if s.aiManager != nil { + return nil + } + + logging.Logger.Error(operation) + details := formatInitErrors(func(initErr InitializationError) bool { + return initErr.Component == "AI Manager" + }) + errMsg := baseMessage + if details != "" { + errMsg += details + } else { + errMsg += " " + managerFallbackMessage + } + return status.Error(codes.FailedPrecondition, errMsg) +} + func normalizeDatabaseType(value string) string { dbType := strings.ToLower(strings.TrimSpace(value)) switch dbType { @@ -300,6 +420,8 @@ func (s *AIPluginService) Query(ctx context.Context, req *server.DataQuery) (*se "key", req.Key, "sql_length", len(req.Sql)) + ctx = withAPIKey(ctx, extractAPIKeyFromMetadata(ctx)) + // Accept both empty type (for backward compatibility) and explicit "ai" type // The main project doesn't always send the type field if req.Type != "" && req.Type != "ai" { @@ -310,165 +432,51 @@ func (s *AIPluginService) Query(ctx context.Context, req *server.DataQuery) (*se // Handle new AI interface standard switch req.Key { case "generate": - // Check AI engine availability for generation requests - if s.aiEngine == nil { - logging.Logger.Error("AI generation requested but AI engine is not available") - - // Build enhanced error message with initialization details - errMsg := "AI generation service is currently unavailable." - - // Add specific initialization error information if available - if len(initErrors) > 0 { - errMsg += " Initialization errors:" - for _, initErr := range initErrors { - errMsg += fmt.Sprintf("\n- %s: %s", initErr.Component, initErr.Reason) - // Add relevant details - if len(initErr.Details) > 0 { - for key, value := range initErr.Details { - errMsg += fmt.Sprintf("\n %s: %s", key, value) - } - } - } - } else { - errMsg += " Please check AI provider configuration and connectivity." - } - - return nil, status.Error(codes.FailedPrecondition, errMsg) + if err := s.requireEngineAvailable( + "AI generation requested but AI engine is not available", + "AI generation service is currently unavailable.", + "Please check AI provider configuration and connectivity."); err != nil { + return nil, err } return s.handleAIGenerate(ctx, req) case "capabilities": return s.handleAICapabilities(ctx, req) case "providers": - // Check AI manager availability for provider operations - if s.aiManager == nil { - logging.Logger.Error("Provider discovery requested but AI manager is not available") - - // Build enhanced error message with initialization details - errMsg := "AI provider discovery is currently unavailable." - if len(initErrors) > 0 { - errMsg += " Initialization errors:" - for _, initErr := range initErrors { - if initErr.Component == "AI Manager" { - errMsg += fmt.Sprintf("\n- %s: %s", initErr.Component, initErr.Reason) - if len(initErr.Details) > 0 { - for key, value := range initErr.Details { - errMsg += fmt.Sprintf("\n %s: %s", key, value) - } - } - } - } - } else { - errMsg += " Please check AI service configuration." - } - - return nil, status.Error(codes.FailedPrecondition, errMsg) + if err := s.requireManagerAvailable( + "Provider discovery requested but AI manager is not available", + "AI provider discovery is currently unavailable."); err != nil { + return nil, err } return s.handleGetProviders(ctx, req) case "models": - // Check AI manager availability for model operations - if s.aiManager == nil { - logging.Logger.Error("Model listing requested but AI manager is not available") - - // Build enhanced error message with initialization details - errMsg := "AI model listing is currently unavailable." - if len(initErrors) > 0 { - errMsg += " Initialization errors:" - for _, initErr := range initErrors { - if initErr.Component == "AI Manager" { - errMsg += fmt.Sprintf("\n- %s: %s", initErr.Component, initErr.Reason) - if len(initErr.Details) > 0 { - for key, value := range initErr.Details { - errMsg += fmt.Sprintf("\n %s: %s", key, value) - } - } - } - } - } else { - errMsg += " Please check AI service configuration." - } - - return nil, status.Error(codes.FailedPrecondition, errMsg) + if err := s.requireManagerAvailable( + "Model listing requested but AI manager is not available", + "AI model listing is currently unavailable."); err != nil { + return nil, err } return s.handleGetModels(ctx, req) case "test_connection": - // Connection testing can work even without initialized services - if s.aiManager == nil { - logging.Logger.Error("Connection test requested but AI manager is not available") - - // Build enhanced error message with initialization details - errMsg := "AI connection testing is currently unavailable." - if len(initErrors) > 0 { - errMsg += " Initialization errors:" - for _, initErr := range initErrors { - if initErr.Component == "AI Manager" { - errMsg += fmt.Sprintf("\n- %s: %s", initErr.Component, initErr.Reason) - if len(initErr.Details) > 0 { - for key, value := range initErr.Details { - errMsg += fmt.Sprintf("\n %s: %s", key, value) - } - } - } - } - } else { - errMsg += " Please check AI service configuration." - } - - return nil, status.Error(codes.FailedPrecondition, errMsg) + if err := s.requireManagerAvailable( + "Connection test requested but AI manager is not available", + "AI connection testing is currently unavailable."); err != nil { + return nil, err } return s.handleTestConnection(ctx, req) case "health_check": return s.handleHealthCheck(ctx, req) case "update_config": - if s.aiManager == nil { - logging.Logger.Error("Config update requested but AI manager is not available") - - // Build enhanced error message with initialization details - errMsg := "AI configuration update is currently unavailable." - if len(initErrors) > 0 { - errMsg += " Initialization errors:" - for _, initErr := range initErrors { - if initErr.Component == "AI Manager" { - errMsg += fmt.Sprintf("\n- %s: %s", initErr.Component, initErr.Reason) - if len(initErr.Details) > 0 { - for key, value := range initErr.Details { - errMsg += fmt.Sprintf("\n %s: %s", key, value) - } - } - } - } - } else { - errMsg += " Please check AI service configuration." - } - - return nil, status.Error(codes.FailedPrecondition, errMsg) + if err := s.requireManagerAvailable( + "Config update requested but AI manager is not available", + "AI configuration update is currently unavailable."); err != nil { + return nil, err } return s.handleUpdateConfig(ctx, req) default: - // Backward compatibility: support legacy natural language queries - // Check AI engine availability for legacy queries - if s.aiEngine == nil { - logging.Logger.Error("AI query requested but AI engine is not available") - - // Build enhanced error message with initialization details - errMsg := "AI service is currently unavailable." - - // Add specific initialization error information if available - if len(initErrors) > 0 { - errMsg += " Initialization errors:" - for _, initErr := range initErrors { - errMsg += fmt.Sprintf("\n- %s: %s", initErr.Component, initErr.Reason) - // Add relevant details - if len(initErr.Details) > 0 { - for key, value := range initErr.Details { - errMsg += fmt.Sprintf("\n %s: %s", key, value) - } - } - } - } else { - errMsg += " Please check AI provider configuration and connectivity." - } - - return nil, status.Error(codes.FailedPrecondition, errMsg) + if err := s.requireEngineAvailable( + "AI query requested but AI engine is not available", + "AI service is currently unavailable.", + "Please check AI provider configuration and connectivity."); err != nil { + return nil, err } return s.handleLegacyQuery(ctx, req) } @@ -733,6 +741,8 @@ func (s *AIPluginService) handleAIGenerate(ctx context.Context, req *server.Data "prompt_length", len(params.Prompt), "has_config", params.Config != "") + apiKey := apiKeyFromContext(ctx) + // Generate using AI engine context := map[string]string{} if params.Model != "" { @@ -751,6 +761,7 @@ func (s *AIPluginService) handleAIGenerate(ctx context.Context, req *server.Data NaturalLanguage: params.Prompt, DatabaseType: databaseType, Context: context, + RuntimeAPIKey: apiKey, }) if err != nil { metrics.RecordRequest("generate", provider, "error") @@ -1277,6 +1288,12 @@ func (s *AIPluginService) handleTestConnection(ctx context.Context, req *server. config.Provider = "ollama" } + if config.APIKey == "" { + if apiKey := apiKeyFromContext(ctx); apiKey != "" { + config.APIKey = apiKey + } + } + // Log configuration for debugging (mask API key) apiKeyDisplay := "***masked***" if config.APIKey != "" && len(config.APIKey) > 4 { @@ -1308,7 +1325,7 @@ func (s *AIPluginService) handleTestConnection(ctx context.Context, req *server. } // handleUpdateConfig updates the configuration for a provider -func (s *AIPluginService) handleUpdateConfig(_ context.Context, req *server.DataQuery) (*server.DataQueryResult, error) { +func (s *AIPluginService) handleUpdateConfig(ctx context.Context, req *server.DataQuery) (*server.DataQueryResult, error) { logging.Logger.Debug("Handling update config request", "sql_length", len(req.Sql)) // Parse update request from SQL field @@ -1345,6 +1362,12 @@ func (s *AIPluginService) handleUpdateConfig(_ context.Context, req *server.Data return nil, apperrors.ToGRPCError(apperrors.ErrInvalidRequest) } + if updateReq.Config.APIKey == "" { + if apiKey := apiKeyFromContext(ctx); apiKey != "" { + updateReq.Config.APIKey = apiKey + } + } + // Map "local" to "ollama" for backward compatibility if updateReq.Provider == "local" { updateReq.Provider = "ollama" @@ -1395,7 +1418,11 @@ func (s *AIPluginService) handleUpdateConfig(_ context.Context, req *server.Data logging.Logger.Error("Failed to rebuild AI engine", "provider", updateReq.Provider, "error", err) - _ = manager.Close() + if closeErr := manager.Close(); closeErr != nil { + logging.Logger.Warn("Failed to close AI manager after rebuild error", + "provider", updateReq.Provider, + "error", closeErr) + } return nil, apperrors.ToGRPCErrorf(apperrors.ErrInvalidConfig, "failed to rebuild AI engine: %v", err) } diff --git a/pkg/plugin/service_test.go b/pkg/plugin/service_test.go index 86aaee6..7cdfb4b 100644 --- a/pkg/plugin/service_test.go +++ b/pkg/plugin/service_test.go @@ -369,7 +369,6 @@ func TestHandleUpdateConfigRefreshesEngine(t *testing.T) { "provider": "ollama", "endpoint": "http://localhost:11439", "model": "test-model", - "api_key": "", "max_tokens": 1337, }, }