From b10bf6918bb52e85a22c7c1e8e51cb5ecb2aaaea Mon Sep 17 00:00:00 2001 From: KarielHalling Date: Wed, 5 Nov 2025 22:58:08 +0800 Subject: [PATCH] fix(ai): secure runtime credentials across stack Secure runtime API key handling across front/back-end, fix client leaks Rationale A. Prevent API keys from being exposed in request payloads and logs B. Remove key material from runtime cache hashing and reuse logic C. Ensure runtime clients and HTTP transports are closed when unused Changes A. Frontend now injects API keys via X-Auth header only and extends tests B. Backend propagates metadata API keys, hardens availability errors, and warns on stale clients C. Universal client pool adds reference counting with proper Close housekeeping Impact A. Aligns with security redline by avoiding key exposure and leaking sockets B. Backward compatible for existing UI/API consumers C. Low risk; added logs aid debugging and stale clients close gracefully Test A. go test ./... B. npm run test -- --run Refs A. Security issue 8, Resource leak issue 5, Error swallowing issue 7 --- frontend/src/services/aiService.ts | 25 +- frontend/tests/services/aiService.spec.ts | 33 +++ pkg/ai/engine.go | 28 +- pkg/ai/generator.go | 83 ++++-- pkg/ai/generator_test.go | 3 +- pkg/ai/manager.go | 26 +- pkg/ai/providers/universal/client.go | 86 ++++-- pkg/plugin/service.go | 311 ++++++++++++---------- pkg/plugin/service_test.go | 1 - 9 files changed, 394 insertions(+), 202 deletions(-) diff --git a/frontend/src/services/aiService.ts b/frontend/src/services/aiService.ts index 74c8e41..1b9c2cf 100644 --- a/frontend/src/services/aiService.ts +++ b/frontend/src/services/aiService.ts @@ -57,7 +57,6 @@ export const aiService = { provider: config.provider, endpoint: config.endpoint, model: config.model, - api_key: config.apiKey, max_tokens: config.maxTokens, timeout: formatTimeout(config.timeout) } @@ -67,7 +66,7 @@ export const aiService = { message: string provider: string error?: string - }>('test_connection', payload) + }>('test_connection', payload, { apiKey: config.apiKey }) return { success: toBoolean(result.success), @@ -130,12 +129,11 @@ export const aiService = { include_explanation: request.includeExplanation, provider: request.provider, endpoint: request.endpoint, - api_key: request.apiKey, max_tokens: request.maxTokens, timeout: formatTimeout(request.timeout), database_type: request.databaseDialect }) - }) + }, { apiKey: request.apiKey }) console.log('📥 [aiService] Received backend result', { hasContent: !!result.content, @@ -214,12 +212,11 @@ export const aiService = { provider: config.provider, endpoint: config.endpoint, model: config.model, - api_key: config.apiKey, max_tokens: config.maxTokens, timeout: formatTimeout(config.timeout), database_type: config.databaseDialect } - }) + }, { apiKey: config.apiKey }) } } @@ -239,7 +236,7 @@ function formatTimeout(timeout: number | undefined): string { * is designed for database queries and transforms the request format. * The AI plugin expects: {type: 'ai', key: 'operation', sql: 'params_json'} */ -async function callAPI(key: string, data: any): Promise { +async function callAPI(key: string, data: any, options: { apiKey?: string } = {}): Promise { const requestBody = { type: 'ai', key, @@ -254,12 +251,18 @@ async function callAPI(key: string, data: any): Promise { }) try { + const headers: Record = { + 'Content-Type': 'application/json', + 'X-Store-Name': API_STORE + } + + if (options.apiKey) { + headers['X-Auth'] = `Bearer ${options.apiKey}` + } + const response = await fetch(API_BASE, { method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'X-Store-Name': API_STORE - }, + headers, body: JSON.stringify(requestBody) }) diff --git a/frontend/tests/services/aiService.spec.ts b/frontend/tests/services/aiService.spec.ts index ced0313..ea5ad50 100644 --- a/frontend/tests/services/aiService.spec.ts +++ b/frontend/tests/services/aiService.spec.ts @@ -41,6 +41,7 @@ describe('aiService', () => { }) const payload = JSON.parse(body.sql) expect(payload.config).toContain('timeout') + expect(payload.config).not.toContain('api_key') return createFetchResponse({ data: [ @@ -68,6 +69,38 @@ describe('aiService', () => { expect(response.meta).toEqual({ confidence: 0.9, model: 'demo' }) }) + it('sends API key through authorization header only', async () => { + const apiKey = 'sk-secure' + fetchMock.mockImplementationOnce(async (_url: FetchArgs[0], options: FetchArgs[1]) => { + const headers = options?.headers as Record + expect(headers['X-Auth']).toBe(`Bearer ${apiKey}`) + + const body = JSON.parse(String(options?.body)) + const payload = JSON.parse(body.sql) + expect(payload.config).not.toContain('api_key') + + return createFetchResponse({ + data: [ + { key: 'success', value: true }, + { key: 'content', value: 'sql:SELECT 1;' }, + { key: 'meta', value: '{}' } + ] + }) + }) + + await aiService.generateSQL({ + provider: 'openai', + endpoint: 'https://api.openai.com', + apiKey, + model: 'gpt-5', + prompt: 'SELECT 1;', + timeout: 30, + maxTokens: 256, + includeExplanation: false, + databaseDialect: 'postgresql' + }) + }) + it('parses health check response when backend returns boolean healthy flag', async () => { fetchMock.mockResolvedValueOnce( createFetchResponse({ diff --git a/pkg/ai/engine.go b/pkg/ai/engine.go index c73ce26..9c2dace 100644 --- a/pkg/ai/engine.go +++ b/pkg/ai/engine.go @@ -64,6 +64,7 @@ type GenerateSQLRequest struct { NaturalLanguage string `json:"natural_language"` DatabaseType string `json:"database_type"` Context map[string]string `json:"context,omitempty"` + RuntimeAPIKey string `json:"-"` } // GenerateSQLResponse represents an AI SQL generation response @@ -113,7 +114,11 @@ func NewEngine(cfg config.AIConfig) (Engine, error) { engine, err := newEngineFromManager(manager, cfg) if err != nil { - _ = manager.Close() + if closeErr := manager.Close(); closeErr != nil { + logging.Logger.Warn("Failed to close AI manager after initialization error", + "provider", cfg.DefaultService, + "error", closeErr) + } return nil, err } return engine, nil @@ -163,7 +168,11 @@ func NewEngineWithManager(manager *Manager, cfg config.AIConfig) (Engine, error) engine, err := newEngineFromManager(manager, cfg) if err != nil { - _ = manager.Close() + if closeErr := manager.Close(); closeErr != nil { + logging.Logger.Warn("Failed to close AI manager after initialization error", + "provider", cfg.DefaultService, + "error", closeErr) + } return nil, err } return engine, nil @@ -206,6 +215,10 @@ func (e *aiEngine) GenerateSQL(ctx context.Context, req *GenerateSQLRequest) (*G MaxTokens: defaultMaxTokens, } + if req.RuntimeAPIKey != "" { + options.APIKey = req.RuntimeAPIKey + } + // Add context if provided and extract preferred_model and runtime config var runtimeConfig map[string]interface{} if len(req.Context) > 0 { @@ -236,7 +249,11 @@ func (e *aiEngine) GenerateSQL(ctx context.Context, req *GenerateSQLRequest) (*G options.Provider = provider } if apiKey, ok := runtimeConfig["api_key"].(string); ok && apiKey != "" { - options.APIKey = apiKey + if options.APIKey == "" { + options.APIKey = apiKey + } else if options.APIKey != apiKey { + logging.Logger.Warn("Runtime config API key differs from secured metadata; using secured value") + } } if endpoint, ok := runtimeConfig["endpoint"].(string); ok && endpoint != "" { options.Endpoint = endpoint @@ -313,6 +330,9 @@ func (e *aiEngine) Close() { e.generator.Close() } if e.manager != nil { - _ = e.manager.Close() + if err := e.manager.Close(); err != nil { + logging.Logger.Warn("Failed to close AI manager during engine shutdown", + "error", err) + } } } diff --git a/pkg/ai/generator.go b/pkg/ai/generator.go index 6804a92..e78d47e 100644 --- a/pkg/ai/generator.go +++ b/pkg/ai/generator.go @@ -17,6 +17,7 @@ limitations under the License. package ai import ( + "bytes" "context" "crypto/sha256" "encoding/hex" @@ -38,10 +39,15 @@ type SQLGenerator struct { sqlDialects map[string]SQLDialect config config.AIConfig capabilities *SQLCapabilities - runtimeClients map[string]interfaces.AIClient + runtimeClients map[string]*runtimeClientEntry runtimeMu sync.RWMutex } +type runtimeClientEntry struct { + client interfaces.AIClient + apiKeyFingerprint []byte +} + // Table represents a database table structure type Table struct { Name string `json:"name"` @@ -142,7 +148,7 @@ func NewSQLGenerator(aiClient interfaces.AIClient, config config.AIConfig) (*SQL aiClient: aiClient, config: config, sqlDialects: make(map[string]SQLDialect), - runtimeClients: make(map[string]interfaces.AIClient), + runtimeClients: make(map[string]*runtimeClientEntry), } // Initialize SQL dialects @@ -630,23 +636,38 @@ func runtimeClientKey(options *GenerateOptions) string { hasher.Write([]byte(options.Endpoint)) hasher.Write([]byte("|")) hasher.Write([]byte(options.Model)) - hasher.Write([]byte("|")) - hasher.Write([]byte(options.APIKey)) return hex.EncodeToString(hasher.Sum(nil)) } +func runtimeAPIKeyFingerprint(apiKey string) []byte { + if apiKey == "" { + return nil + } + sum := sha256.Sum256([]byte(apiKey)) + fingerprint := make([]byte, len(sum)) + copy(fingerprint, sum[:]) + return fingerprint +} + func (g *SQLGenerator) getOrCreateRuntimeClient(options *GenerateOptions) (interfaces.AIClient, bool, error) { key := runtimeClientKey(options) + fingerprint := runtimeAPIKeyFingerprint(options.APIKey) + g.runtimeMu.RLock() - if client, ok := g.runtimeClients[key]; ok { - g.runtimeMu.RUnlock() - return client, true, nil + if entry, ok := g.runtimeClients[key]; ok { + if bytes.Equal(entry.apiKeyFingerprint, fingerprint) { + client := entry.client + g.runtimeMu.RUnlock() + return client, true, nil + } } g.runtimeMu.RUnlock() runtimeConfig := map[string]any{ "provider": options.Provider, - "api_key": options.APIKey, + } + if options.APIKey != "" { + runtimeConfig["api_key"] = options.APIKey } if options.Endpoint != "" { runtimeConfig["base_url"] = options.Endpoint @@ -664,14 +685,38 @@ func (g *SQLGenerator) getOrCreateRuntimeClient(options *GenerateOptions) (inter } g.runtimeMu.Lock() - if existing, ok := g.runtimeClients[key]; ok { - g.runtimeMu.Unlock() - _ = client.Close() - return existing, true, nil + var ( + existingEntry *runtimeClientEntry + exists bool + ) + if existingEntry, exists = g.runtimeClients[key]; exists { + if bytes.Equal(existingEntry.apiKeyFingerprint, fingerprint) { + g.runtimeMu.Unlock() + if err := client.Close(); err != nil { + logging.Logger.Warn("Failed to close redundant runtime client", + "provider", options.Provider, + "endpoint", options.Endpoint, + "error", err) + } + return existingEntry.client, true, nil + } + } + + g.runtimeClients[key] = &runtimeClientEntry{ + client: client, + apiKeyFingerprint: fingerprint, } - g.runtimeClients[key] = client g.runtimeMu.Unlock() + if exists && existingEntry != nil && existingEntry.client != nil { + if err := existingEntry.client.Close(); err != nil { + logging.Logger.Warn("Failed to close stale runtime client", + "provider", options.Provider, + "endpoint", options.Endpoint, + "error", err) + } + } + return client, false, nil } @@ -679,8 +724,16 @@ func (g *SQLGenerator) getOrCreateRuntimeClient(options *GenerateOptions) (inter func (g *SQLGenerator) Close() { g.runtimeMu.Lock() defer g.runtimeMu.Unlock() - for key, client := range g.runtimeClients { - _ = client.Close() + for key, entry := range g.runtimeClients { + if entry == nil || entry.client == nil { + delete(g.runtimeClients, key) + continue + } + if err := entry.client.Close(); err != nil { + logging.Logger.Warn("Failed to close runtime client during generator shutdown", + "key", key, + "error", err) + } delete(g.runtimeClients, key) } } diff --git a/pkg/ai/generator_test.go b/pkg/ai/generator_test.go index ab0e129..7e632e4 100644 --- a/pkg/ai/generator_test.go +++ b/pkg/ai/generator_test.go @@ -3,13 +3,12 @@ package ai import ( "testing" - "github.com/linuxsuren/atest-ext-ai/pkg/interfaces" "github.com/stretchr/testify/require" ) func TestRuntimeClientReuseAndClose(t *testing.T) { generator := &SQLGenerator{ - runtimeClients: make(map[string]interfaces.AIClient), + runtimeClients: make(map[string]*runtimeClientEntry), } options := &GenerateOptions{ diff --git a/pkg/ai/manager.go b/pkg/ai/manager.go index da911d7..6f8e1be 100644 --- a/pkg/ai/manager.go +++ b/pkg/ai/manager.go @@ -291,7 +291,11 @@ func (m *Manager) AddClient(ctx context.Context, name string, svc config.AIServi // Close old client if exists if oldClient, exists := m.clients[name]; exists { - _ = oldClient.Close() + if err := oldClient.Close(); err != nil { + logging.Logger.Warn("Failed to close existing AI client", + "client", name, + "error", err) + } } m.clients[name] = client @@ -312,7 +316,11 @@ func (m *Manager) RemoveClient(name string) error { return fmt.Errorf("%w: %s", ErrClientNotFound, name) } - _ = client.Close() + if err := client.Close(); err != nil { + logging.Logger.Warn("Failed to close AI client", + "client", name, + "error", err) + } delete(m.clients, name) return nil } @@ -353,7 +361,11 @@ func (m *Manager) DiscoverProviders(ctx context.Context) ([]*ProviderInfo, error } providers = append(providers, provider) - _ = client.Close() + if err := client.Close(); err != nil { + logging.Logger.Warn("Failed to close discovery client", + "provider", provider.Name, + "error", err) + } } } @@ -407,7 +419,13 @@ func (m *Manager) TestConnection(ctx context.Context, cfg *universal.Config) (*C Error: err.Error(), }, nil } - defer func() { _ = client.Close() }() + defer func() { + if err := client.Close(); err != nil { + logging.Logger.Warn("Failed to close test connection client", + "provider", cfg.Provider, + "error", err) + } + }() health, err := client.HealthCheck(ctx) if err != nil { diff --git a/pkg/ai/providers/universal/client.go b/pkg/ai/providers/universal/client.go index 498b6a7..dcff09a 100644 --- a/pkg/ai/providers/universal/client.go +++ b/pkg/ai/providers/universal/client.go @@ -24,6 +24,7 @@ import ( "net/http" "strings" "sync" + "sync/atomic" "time" "github.com/linuxsuren/atest-ext-ai/pkg/interfaces" @@ -33,19 +34,43 @@ import ( // Global HTTP client pool for connection reuse across providers // Using sync.Map for concurrent-safe access without explicit locking on read var ( - httpClientPool = &sync.Map{} // key: provider name (string), value: *http.Client + httpClientPool = &sync.Map{} // key: provider name (string), value: *pooledHTTPClient httpClientMu sync.Mutex // Mutex for client creation to prevent duplicate creation ) +type pooledHTTPClient struct { + provider string + client *http.Client + transport *http.Transport + refs atomic.Int32 +} + +func (p *pooledHTTPClient) retain() { + p.refs.Add(1) +} + +func (p *pooledHTTPClient) release() { + if p.refs.Add(-1) != 0 { + return + } + + if p.transport != nil { + p.transport.CloseIdleConnections() + } + httpClientPool.Delete(p.provider) +} + // getOrCreateHTTPClient retrieves an existing HTTP client from the pool or creates a new one // This implements connection pooling to improve performance and resource utilization // Based on Go net/http best practices for Transport configuration -func getOrCreateHTTPClient(provider string, timeout time.Duration) *http.Client { +func getOrCreateHTTPClient(provider string, timeout time.Duration) *pooledHTTPClient { // Try to get existing client from pool (fast path, no locking) if client, ok := httpClientPool.Load(provider); ok { + entry := client.(*pooledHTTPClient) + entry.retain() logging.Logger.Debug("Reusing HTTP client from pool", "provider", provider) - return client.(*http.Client) + return entry } // Client not found, need to create (slow path with locking) @@ -54,9 +79,11 @@ func getOrCreateHTTPClient(provider string, timeout time.Duration) *http.Client // Double-check: another goroutine might have created the client while we waited for the lock if client, ok := httpClientPool.Load(provider); ok { + entry := client.(*pooledHTTPClient) + entry.retain() logging.Logger.Debug("HTTP client created by another goroutine", "provider", provider) - return client.(*http.Client) + return entry } // Create new HTTP client with optimized transport settings @@ -65,25 +92,33 @@ func getOrCreateHTTPClient(provider string, timeout time.Duration) *http.Client // - MaxIdleConnsPerHost: Maximum idle connections per host (important for AI APIs) // - IdleConnTimeout: How long idle connections remain in the pool // - DisableCompression: Disabled for better compatibility with AI APIs + transport := &http.Transport{ + MaxIdleConns: 100, // Total pool size across all hosts + MaxIdleConnsPerHost: 10, // Per-host idle connection limit (AI APIs typically use 1 host) + IdleConnTimeout: 90 * time.Second, // Keep idle connections for 90s + DisableCompression: false, // Enable compression for better bandwidth utilization + // Additional recommended settings for production use: + MaxConnsPerHost: 0, // No limit on active connections (0 = unlimited) + ResponseHeaderTimeout: 30 * time.Second, // Timeout for reading response headers + ExpectContinueTimeout: 1 * time.Second, // Timeout for 100-Continue handshake + ForceAttemptHTTP2: true, // Enable HTTP/2 when available + DisableKeepAlives: false, // Enable keep-alives for connection reuse + TLSHandshakeTimeout: 10 * time.Second, // Timeout for TLS handshake + } + client := &http.Client{ - Timeout: timeout, - Transport: &http.Transport{ - MaxIdleConns: 100, // Total pool size across all hosts - MaxIdleConnsPerHost: 10, // Per-host idle connection limit (AI APIs typically use 1 host) - IdleConnTimeout: 90 * time.Second, // Keep idle connections for 90s - DisableCompression: false, // Enable compression for better bandwidth utilization - // Additional recommended settings for production use: - MaxConnsPerHost: 0, // No limit on active connections (0 = unlimited) - ResponseHeaderTimeout: 30 * time.Second, // Timeout for reading response headers - ExpectContinueTimeout: 1 * time.Second, // Timeout for 100-Continue handshake - ForceAttemptHTTP2: true, // Enable HTTP/2 when available - DisableKeepAlives: false, // Enable keep-alives for connection reuse - TLSHandshakeTimeout: 10 * time.Second, // Timeout for TLS handshake - }, + Timeout: timeout, + Transport: transport, } + entry := &pooledHTTPClient{ + client: client, + transport: transport, + } + entry.retain() + // Store in pool for reuse - httpClientPool.Store(provider, client) + httpClientPool.Store(provider, entry) logging.Logger.Info("Created new HTTP client with connection pooling", "provider", provider, @@ -92,13 +127,14 @@ func getOrCreateHTTPClient(provider string, timeout time.Duration) *http.Client "max_idle_conns_per_host", 10, "idle_conn_timeout", "90s") - return client + return entry } // Client implements a universal OpenAI-compatible API client. type Client struct { config *Config httpClient *http.Client + poolEntry *pooledHTTPClient strategy ProviderStrategy // Strategy pattern to handle provider-specific logic } @@ -175,12 +211,13 @@ func NewUniversalClient(config *Config) (*Client, error) { // Create HTTP client using connection pool for better performance // This reuses connections across requests to the same provider - httpClient := getOrCreateHTTPClient(config.Provider, config.Timeout) + pooledClient := getOrCreateHTTPClient(config.Provider, config.Timeout) client := &Client{ config: config, strategy: strategy, - httpClient: httpClient, + httpClient: pooledClient.client, + poolEntry: pooledClient, } logging.Logger.Debug("Universal client created", @@ -339,7 +376,10 @@ func (c *Client) HealthCheck(ctx context.Context) (*interfaces.HealthStatus, err // Close releases any resources held by the client func (c *Client) Close() error { - // No persistent connections to close + if c.poolEntry != nil { + c.poolEntry.release() + c.poolEntry = nil + } return nil } diff --git a/pkg/plugin/service.go b/pkg/plugin/service.go index d82bc58..5571d0d 100644 --- a/pkg/plugin/service.go +++ b/pkg/plugin/service.go @@ -21,6 +21,7 @@ import ( _ "embed" "encoding/json" "fmt" + "sort" "strings" "time" @@ -33,6 +34,7 @@ import ( "github.com/linuxsuren/atest-ext-ai/pkg/logging" "github.com/linuxsuren/atest-ext-ai/pkg/metrics" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -87,6 +89,124 @@ func contextError(ctx context.Context) error { return nil } +type contextKey string + +const apiKeyContextKey contextKey = "ai-plugin-runtime-api-key" + +func withAPIKey(ctx context.Context, apiKey string) context.Context { + if ctx == nil || apiKey == "" { + return ctx + } + return context.WithValue(ctx, apiKeyContextKey, apiKey) +} + +func apiKeyFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + if value, ok := ctx.Value(apiKeyContextKey).(string); ok { + return value + } + return "" +} + +func extractAPIKeyFromMetadata(ctx context.Context) string { + if ctx == nil { + return "" + } + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "" + } + + for key, values := range md { + switch strings.ToLower(key) { + case "auth", "x-auth", "x-ai-api-key", "authorization": + for _, raw := range values { + if normalized := normalizeAPIKeyValue(raw); normalized != "" { + return normalized + } + } + } + } + return "" +} + +func normalizeAPIKeyValue(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "" + } + if strings.HasPrefix(strings.ToLower(trimmed), "bearer ") { + return strings.TrimSpace(trimmed[7:]) + } + return trimmed +} + +func formatInitErrors(filter func(InitializationError) bool) string { + if len(initErrors) == 0 { + return "" + } + + var builder strings.Builder + for _, initErr := range initErrors { + if filter != nil && !filter(initErr) { + continue + } + if builder.Len() == 0 { + builder.WriteString(" Initialization errors:") + } + builder.WriteString(fmt.Sprintf("\n- %s: %s", initErr.Component, initErr.Reason)) + if len(initErr.Details) > 0 { + keys := make([]string, 0, len(initErr.Details)) + for key := range initErr.Details { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + builder.WriteString(fmt.Sprintf("\n %s: %s", key, initErr.Details[key])) + } + } + } + + return builder.String() +} + +func (s *AIPluginService) requireEngineAvailable(operation, baseMessage, fallback string) error { + if s.aiEngine != nil { + return nil + } + + logging.Logger.Error(operation) + errMsg := baseMessage + if details := formatInitErrors(nil); details != "" { + errMsg += details + } else if fallback != "" { + errMsg += " " + fallback + } + return status.Error(codes.FailedPrecondition, errMsg) +} + +const managerFallbackMessage = "Please check AI service configuration." + +func (s *AIPluginService) requireManagerAvailable(operation, baseMessage string) error { + if s.aiManager != nil { + return nil + } + + logging.Logger.Error(operation) + details := formatInitErrors(func(initErr InitializationError) bool { + return initErr.Component == "AI Manager" + }) + errMsg := baseMessage + if details != "" { + errMsg += details + } else { + errMsg += " " + managerFallbackMessage + } + return status.Error(codes.FailedPrecondition, errMsg) +} + func normalizeDatabaseType(value string) string { dbType := strings.ToLower(strings.TrimSpace(value)) switch dbType { @@ -300,6 +420,8 @@ func (s *AIPluginService) Query(ctx context.Context, req *server.DataQuery) (*se "key", req.Key, "sql_length", len(req.Sql)) + ctx = withAPIKey(ctx, extractAPIKeyFromMetadata(ctx)) + // Accept both empty type (for backward compatibility) and explicit "ai" type // The main project doesn't always send the type field if req.Type != "" && req.Type != "ai" { @@ -310,165 +432,51 @@ func (s *AIPluginService) Query(ctx context.Context, req *server.DataQuery) (*se // Handle new AI interface standard switch req.Key { case "generate": - // Check AI engine availability for generation requests - if s.aiEngine == nil { - logging.Logger.Error("AI generation requested but AI engine is not available") - - // Build enhanced error message with initialization details - errMsg := "AI generation service is currently unavailable." - - // Add specific initialization error information if available - if len(initErrors) > 0 { - errMsg += " Initialization errors:" - for _, initErr := range initErrors { - errMsg += fmt.Sprintf("\n- %s: %s", initErr.Component, initErr.Reason) - // Add relevant details - if len(initErr.Details) > 0 { - for key, value := range initErr.Details { - errMsg += fmt.Sprintf("\n %s: %s", key, value) - } - } - } - } else { - errMsg += " Please check AI provider configuration and connectivity." - } - - return nil, status.Error(codes.FailedPrecondition, errMsg) + if err := s.requireEngineAvailable( + "AI generation requested but AI engine is not available", + "AI generation service is currently unavailable.", + "Please check AI provider configuration and connectivity."); err != nil { + return nil, err } return s.handleAIGenerate(ctx, req) case "capabilities": return s.handleAICapabilities(ctx, req) case "providers": - // Check AI manager availability for provider operations - if s.aiManager == nil { - logging.Logger.Error("Provider discovery requested but AI manager is not available") - - // Build enhanced error message with initialization details - errMsg := "AI provider discovery is currently unavailable." - if len(initErrors) > 0 { - errMsg += " Initialization errors:" - for _, initErr := range initErrors { - if initErr.Component == "AI Manager" { - errMsg += fmt.Sprintf("\n- %s: %s", initErr.Component, initErr.Reason) - if len(initErr.Details) > 0 { - for key, value := range initErr.Details { - errMsg += fmt.Sprintf("\n %s: %s", key, value) - } - } - } - } - } else { - errMsg += " Please check AI service configuration." - } - - return nil, status.Error(codes.FailedPrecondition, errMsg) + if err := s.requireManagerAvailable( + "Provider discovery requested but AI manager is not available", + "AI provider discovery is currently unavailable."); err != nil { + return nil, err } return s.handleGetProviders(ctx, req) case "models": - // Check AI manager availability for model operations - if s.aiManager == nil { - logging.Logger.Error("Model listing requested but AI manager is not available") - - // Build enhanced error message with initialization details - errMsg := "AI model listing is currently unavailable." - if len(initErrors) > 0 { - errMsg += " Initialization errors:" - for _, initErr := range initErrors { - if initErr.Component == "AI Manager" { - errMsg += fmt.Sprintf("\n- %s: %s", initErr.Component, initErr.Reason) - if len(initErr.Details) > 0 { - for key, value := range initErr.Details { - errMsg += fmt.Sprintf("\n %s: %s", key, value) - } - } - } - } - } else { - errMsg += " Please check AI service configuration." - } - - return nil, status.Error(codes.FailedPrecondition, errMsg) + if err := s.requireManagerAvailable( + "Model listing requested but AI manager is not available", + "AI model listing is currently unavailable."); err != nil { + return nil, err } return s.handleGetModels(ctx, req) case "test_connection": - // Connection testing can work even without initialized services - if s.aiManager == nil { - logging.Logger.Error("Connection test requested but AI manager is not available") - - // Build enhanced error message with initialization details - errMsg := "AI connection testing is currently unavailable." - if len(initErrors) > 0 { - errMsg += " Initialization errors:" - for _, initErr := range initErrors { - if initErr.Component == "AI Manager" { - errMsg += fmt.Sprintf("\n- %s: %s", initErr.Component, initErr.Reason) - if len(initErr.Details) > 0 { - for key, value := range initErr.Details { - errMsg += fmt.Sprintf("\n %s: %s", key, value) - } - } - } - } - } else { - errMsg += " Please check AI service configuration." - } - - return nil, status.Error(codes.FailedPrecondition, errMsg) + if err := s.requireManagerAvailable( + "Connection test requested but AI manager is not available", + "AI connection testing is currently unavailable."); err != nil { + return nil, err } return s.handleTestConnection(ctx, req) case "health_check": return s.handleHealthCheck(ctx, req) case "update_config": - if s.aiManager == nil { - logging.Logger.Error("Config update requested but AI manager is not available") - - // Build enhanced error message with initialization details - errMsg := "AI configuration update is currently unavailable." - if len(initErrors) > 0 { - errMsg += " Initialization errors:" - for _, initErr := range initErrors { - if initErr.Component == "AI Manager" { - errMsg += fmt.Sprintf("\n- %s: %s", initErr.Component, initErr.Reason) - if len(initErr.Details) > 0 { - for key, value := range initErr.Details { - errMsg += fmt.Sprintf("\n %s: %s", key, value) - } - } - } - } - } else { - errMsg += " Please check AI service configuration." - } - - return nil, status.Error(codes.FailedPrecondition, errMsg) + if err := s.requireManagerAvailable( + "Config update requested but AI manager is not available", + "AI configuration update is currently unavailable."); err != nil { + return nil, err } return s.handleUpdateConfig(ctx, req) default: - // Backward compatibility: support legacy natural language queries - // Check AI engine availability for legacy queries - if s.aiEngine == nil { - logging.Logger.Error("AI query requested but AI engine is not available") - - // Build enhanced error message with initialization details - errMsg := "AI service is currently unavailable." - - // Add specific initialization error information if available - if len(initErrors) > 0 { - errMsg += " Initialization errors:" - for _, initErr := range initErrors { - errMsg += fmt.Sprintf("\n- %s: %s", initErr.Component, initErr.Reason) - // Add relevant details - if len(initErr.Details) > 0 { - for key, value := range initErr.Details { - errMsg += fmt.Sprintf("\n %s: %s", key, value) - } - } - } - } else { - errMsg += " Please check AI provider configuration and connectivity." - } - - return nil, status.Error(codes.FailedPrecondition, errMsg) + if err := s.requireEngineAvailable( + "AI query requested but AI engine is not available", + "AI service is currently unavailable.", + "Please check AI provider configuration and connectivity."); err != nil { + return nil, err } return s.handleLegacyQuery(ctx, req) } @@ -733,6 +741,8 @@ func (s *AIPluginService) handleAIGenerate(ctx context.Context, req *server.Data "prompt_length", len(params.Prompt), "has_config", params.Config != "") + apiKey := apiKeyFromContext(ctx) + // Generate using AI engine context := map[string]string{} if params.Model != "" { @@ -751,6 +761,7 @@ func (s *AIPluginService) handleAIGenerate(ctx context.Context, req *server.Data NaturalLanguage: params.Prompt, DatabaseType: databaseType, Context: context, + RuntimeAPIKey: apiKey, }) if err != nil { metrics.RecordRequest("generate", provider, "error") @@ -1277,6 +1288,12 @@ func (s *AIPluginService) handleTestConnection(ctx context.Context, req *server. config.Provider = "ollama" } + if config.APIKey == "" { + if apiKey := apiKeyFromContext(ctx); apiKey != "" { + config.APIKey = apiKey + } + } + // Log configuration for debugging (mask API key) apiKeyDisplay := "***masked***" if config.APIKey != "" && len(config.APIKey) > 4 { @@ -1308,7 +1325,7 @@ func (s *AIPluginService) handleTestConnection(ctx context.Context, req *server. } // handleUpdateConfig updates the configuration for a provider -func (s *AIPluginService) handleUpdateConfig(_ context.Context, req *server.DataQuery) (*server.DataQueryResult, error) { +func (s *AIPluginService) handleUpdateConfig(ctx context.Context, req *server.DataQuery) (*server.DataQueryResult, error) { logging.Logger.Debug("Handling update config request", "sql_length", len(req.Sql)) // Parse update request from SQL field @@ -1345,6 +1362,12 @@ func (s *AIPluginService) handleUpdateConfig(_ context.Context, req *server.Data return nil, apperrors.ToGRPCError(apperrors.ErrInvalidRequest) } + if updateReq.Config.APIKey == "" { + if apiKey := apiKeyFromContext(ctx); apiKey != "" { + updateReq.Config.APIKey = apiKey + } + } + // Map "local" to "ollama" for backward compatibility if updateReq.Provider == "local" { updateReq.Provider = "ollama" @@ -1395,7 +1418,11 @@ func (s *AIPluginService) handleUpdateConfig(_ context.Context, req *server.Data logging.Logger.Error("Failed to rebuild AI engine", "provider", updateReq.Provider, "error", err) - _ = manager.Close() + if closeErr := manager.Close(); closeErr != nil { + logging.Logger.Warn("Failed to close AI manager after rebuild error", + "provider", updateReq.Provider, + "error", closeErr) + } return nil, apperrors.ToGRPCErrorf(apperrors.ErrInvalidConfig, "failed to rebuild AI engine: %v", err) } diff --git a/pkg/plugin/service_test.go b/pkg/plugin/service_test.go index 86aaee6..7cdfb4b 100644 --- a/pkg/plugin/service_test.go +++ b/pkg/plugin/service_test.go @@ -369,7 +369,6 @@ func TestHandleUpdateConfigRefreshesEngine(t *testing.T) { "provider": "ollama", "endpoint": "http://localhost:11439", "model": "test-model", - "api_key": "", "max_tokens": 1337, }, }