diff --git a/optillm/__init__.py b/optillm/__init__.py
index 0c622287..4ae9b7ee 100644
--- a/optillm/__init__.py
+++ b/optillm/__init__.py
@@ -1,5 +1,5 @@
 # Version information
-__version__ = "0.2.7"
+__version__ = "0.2.8"
 
 # Import from server module
 from .server import (
diff --git a/optillm/plugins/proxy/client.py b/optillm/plugins/proxy/client.py
index 696ef41a..f26bf151 100644
--- a/optillm/plugins/proxy/client.py
+++ b/optillm/plugins/proxy/client.py
@@ -48,11 +48,11 @@ def client(self):
                 max_retries=0  # Disable client retries - we handle them
             )
         elif 'generativelanguage.googleapis.com' in self.base_url:
-            # Google AI client - create custom client to avoid "models/" prefix
-            from optillm.plugins.proxy.google_client import GoogleAIClient
-            self._client = GoogleAIClient(
+            # Google AI with standard OpenAI-compatible client
+            self._client = OpenAI(
                 api_key=self.api_key,
-                base_url=self.base_url
+                base_url=self.base_url,
+                max_retries=0  # Disable client retries - we handle them
             )
         else:
             # Standard OpenAI-compatible client
@@ -165,6 +165,7 @@ def __init__(self, proxy_client):
     class _Completions:
         def __init__(self, proxy_client):
             self.proxy_client = proxy_client
+            self._system_message_support_cache = {}
 
         def _filter_kwargs(self, kwargs: dict) -> dict:
             """Filter out OptiLLM-specific parameters that shouldn't be sent to providers"""
@@ -175,6 +176,73 @@ def _filter_kwargs(self, kwargs: dict) -> dict:
             }
             return {k: v for k, v in kwargs.items() if k not in optillm_params}
 
+        def _test_system_message_support(self, provider, model: str) -> bool:
+            """Test if a model supports system messages"""
+            cache_key = f"{provider.name}:{model}"
+
+            if cache_key in self._system_message_support_cache:
+                return self._system_message_support_cache[cache_key]
+
+            try:
+                test_response = provider.client.chat.completions.create(
+                    model=model,
+                    messages=[
+                        {"role": "system", "content": "test"},
+                        {"role": "user", "content": "hi"}
+                    ],
+                    max_tokens=1,
+                    temperature=0
+                )
+                self._system_message_support_cache[cache_key] = True
+                return True
+            except Exception as e:
+                error_msg = str(e).lower()
+                if any(pattern in error_msg for pattern in [
+                    "developer instruction", "system message", "not enabled", "not supported"
+                ]):
+                    logger.info(f"Provider {provider.name} model {model} does not support system messages")
+                    self._system_message_support_cache[cache_key] = False
+                    return False
+                # Other errors - assume it supports system messages
+                self._system_message_support_cache[cache_key] = True
+                return True
+
+        def _format_messages_for_provider(self, provider, model: str, messages: list) -> list:
+            """Format messages based on provider's system message support"""
+            # Check if there's a system message
+            has_system = any(msg.get("role") == "system" for msg in messages)
+
+            if not has_system:
+                return messages
+
+            # Test system message support
+            supports_system = self._test_system_message_support(provider, model)
+
+            if supports_system:
+                return messages
+
+            # Merge system message into first user message
+            formatted_messages = []
+            system_content = None
+
+            for msg in messages:
+                if msg.get("role") == "system":
+                    system_content = msg.get("content", "")
+                elif msg.get("role") == "user":
+                    if system_content:
+                        # Merge system message with user message
+                        formatted_messages.append({
+                            "role": "user",
+                            "content": f"Instructions: {system_content}\n\nUser: {msg.get('content', '')}"
+                        })
+                        system_content = None
+                    else:
+                        formatted_messages.append(msg)
+                else:
+                    formatted_messages.append(msg)
+
+            return formatted_messages
+
         def _make_request_with_timeout(self, provider, request_kwargs):
             """Make a request with timeout handling"""
             # The OpenAI client now supports timeout natively
@@ -232,7 +300,14 @@ def create(self, **kwargs):
                 try:
                     # Map model name if needed and filter out OptiLLM-specific parameters
                     request_kwargs = self._filter_kwargs(kwargs.copy())
-                    request_kwargs['model'] = provider.map_model(model)
+                    mapped_model = provider.map_model(model)
+                    request_kwargs['model'] = mapped_model
+
+                    # Format messages based on provider's system message support
+                    if 'messages' in request_kwargs:
+                        request_kwargs['messages'] = self._format_messages_for_provider(
+                            provider, mapped_model, request_kwargs['messages']
+                        )
 
                     # Add timeout to client if supported
                     request_kwargs['timeout'] = self.proxy_client.request_timeout
@@ -279,7 +354,7 @@ def create(self, **kwargs):
             if self.proxy_client.fallback_client:
                 logger.warning("All proxy providers failed, using fallback client")
                 try:
-                    fallback_kwargs = self._filter_kwargs(kwargs)
+                    fallback_kwargs = self._filter_kwargs(kwargs.copy())
                     fallback_kwargs['timeout'] = self.proxy_client.request_timeout
                     return self.proxy_client.fallback_client.chat.completions.create(**fallback_kwargs)
                 except Exception as e:
diff --git a/optillm/plugins/proxy/google_client.py b/optillm/plugins/proxy/google_client.py
deleted file mode 100644
index 73378b6e..00000000
--- a/optillm/plugins/proxy/google_client.py
+++ /dev/null
@@ -1,92 +0,0 @@
-"""
-Custom Google AI client that doesn't add "models/" prefix to model names
-"""
-import requests
-import json
-from typing import Dict, List, Any
-
-
-class GoogleAIClient:
-    """Custom client for Google AI that bypasses OpenAI client's model name prefix behavior"""
-
-    def __init__(self, api_key: str, base_url: str):
-        self.api_key = api_key
-        self.base_url = base_url.rstrip('/')
-        self.chat = self.Chat(self)
-        self.models = self.Models(self)
-
-    class Chat:
-        def __init__(self, client):
-            self.client = client
-            self.completions = self.Completions(client)
-
-        class Completions:
-            def __init__(self, client):
-                self.client = client
-
-            def create(self, model: str, messages: List[Dict[str, str]], **kwargs) -> Any:
-                """Create chat completion without adding models/ prefix to model name"""
-                url = f"{self.client.base_url}/chat/completions"
-
-                headers = {
-                    "Content-Type": "application/json",
-                    "Authorization": f"Bearer {self.client.api_key}"
-                }
-
-                # Build request data - use model name directly without "models/" prefix
-                data = {
-                    "model": model,  # Use exactly as provided - no prefix!
- "messages": messages, - **kwargs - } - - # Make direct HTTP request to bypass OpenAI client behavior - response = requests.post(url, headers=headers, json=data, timeout=kwargs.get('timeout', 30)) - - if response.status_code != 200: - error_text = response.text - raise Exception(f"HTTP {response.status_code}: {error_text}") - - # Parse response and return OpenAI-compatible object - result = response.json() - - # Create a simple object that has the attributes expected by the proxy - class CompletionResponse: - def __init__(self, data): - self._data = data - self.choices = data.get('choices', []) - self.usage = data.get('usage', {}) - self.model = data.get('model', model) - - def model_dump(self): - return self._data - - def __getitem__(self, key): - return self._data[key] - - def get(self, key, default=None): - return self._data.get(key, default) - - return CompletionResponse(result) - - class Models: - def __init__(self, client): - self.client = client - - def list(self): - """Simple models list for health checking""" - url = f"{self.client.base_url}/models" - headers = { - "Authorization": f"Bearer {self.client.api_key}" - } - - try: - response = requests.get(url, headers=headers, timeout=5) - if response.status_code == 200: - return response.json() - else: - # Return a mock response if health check fails - return {"data": [{"id": "gemma-3-4b-it"}]} - except: - # Return a mock response if health check fails - return {"data": [{"id": "gemma-3-4b-it"}]} \ No newline at end of file diff --git a/optillm/plugins/proxy_plugin.py b/optillm/plugins/proxy_plugin.py index 7a472eca..9e8a6c6a 100644 --- a/optillm/plugins/proxy_plugin.py +++ b/optillm/plugins/proxy_plugin.py @@ -5,7 +5,8 @@ with health monitoring, failover, and support for wrapping other approaches. """ import logging -from typing import Tuple, Optional +import threading +from typing import Tuple, Optional, Dict from optillm.plugins.proxy.config import ProxyConfig from optillm.plugins.proxy.client import ProxyClient from optillm.plugins.proxy.approach_handler import ApproachHandler @@ -21,6 +22,78 @@ # Global proxy client cache to maintain state between requests _proxy_client_cache = {} +# Global cache for system message support per provider-model combination +_system_message_support_cache: Dict[str, bool] = {} +_cache_lock = threading.RLock() + +def _test_system_message_support(proxy_client, model: str) -> bool: + """ + Test if a model supports system messages by making a minimal test request. + Returns True if supported, False otherwise. + """ + try: + # Try a minimal system message request + test_response = proxy_client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": "test"}, + {"role": "user", "content": "hi"} + ], + max_tokens=1, # Minimal token generation + temperature=0 + ) + return True + except Exception as e: + error_msg = str(e).lower() + # Check for known system message rejection patterns + if any(pattern in error_msg for pattern in [ + "developer instruction", + "system message", + "not enabled", + "not supported" + ]): + logger.info(f"Model {model} does not support system messages: {str(e)[:100]}") + return False + else: + # If it's a different error, assume system messages are supported + # but something else went wrong (rate limit, timeout, etc.) 
+ logger.debug(f"System message test failed for {model}, assuming supported: {str(e)[:100]}") + return True + +def _get_system_message_support(proxy_client, model: str) -> bool: + """ + Get cached system message support status, testing if not cached. + Thread-safe with locking. + """ + # Create a unique cache key based on model and base_url + cache_key = f"{getattr(proxy_client, '_base_identifier', 'default')}:{model}" + + with _cache_lock: + if cache_key not in _system_message_support_cache: + logger.debug(f"Testing system message support for {model}") + _system_message_support_cache[cache_key] = _test_system_message_support(proxy_client, model) + + return _system_message_support_cache[cache_key] + +def _format_messages_for_model(system_prompt: str, initial_query: str, + supports_system_messages: bool) -> list: + """ + Format messages based on whether the model supports system messages. + """ + if supports_system_messages: + return [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": initial_query} + ] + else: + # Merge system prompt into user message + if system_prompt.strip(): + combined_message = f"{system_prompt}\n\nUser: {initial_query}" + else: + combined_message = initial_query + + return [{"role": "user", "content": combined_message}] + def run(system_prompt: str, initial_query: str, client, model: str, request_config: dict = None) -> Tuple[str, int]: """ @@ -119,14 +192,21 @@ def run(system_prompt: str, initial_query: str, client, model: str, logger.info(f"Proxy routing approach/plugin: {potential_approach}") return result - # Direct proxy execution + # Direct proxy execution with dynamic system message support detection logger.info(f"Direct proxy routing for model: {model}") + + # Test and cache system message support for this model + supports_system_messages = _get_system_message_support(proxy_client, model) + + # Format messages based on system message support + messages = _format_messages_for_model(system_prompt, initial_query, supports_system_messages) + + if not supports_system_messages: + logger.info(f"Using fallback message formatting for {model} (no system message support)") + response = proxy_client.chat.completions.create( model=model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": initial_query} - ], + messages=messages, **(request_config or {}) ) diff --git a/pyproject.toml b/pyproject.toml index ed3ca978..23ae88d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "optillm" -version = "0.2.7" +version = "0.2.8" description = "An optimizing inference proxy for LLMs." readme = "README.md" license = "Apache-2.0"