
Commit 492b853

Merge pull request #246 from codelion/fix-proxy-clients
Fix proxy clients
2 parents 0465bf5 + b0407a0 · commit 492b853

File tree

5 files changed: +169, -106 lines

optillm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # Version information
-__version__ = "0.2.7"
+__version__ = "0.2.8"

 # Import from server module
 from .server import (

optillm/plugins/proxy/client.py

Lines changed: 81 additions & 6 deletions
@@ -48,11 +48,11 @@ def client(self):
                max_retries=0  # Disable client retries - we handle them
            )
        elif 'generativelanguage.googleapis.com' in self.base_url:
-            # Google AI client - create custom client to avoid "models/" prefix
-            from optillm.plugins.proxy.google_client import GoogleAIClient
-            self._client = GoogleAIClient(
+            # Google AI with standard OpenAI-compatible client
+            self._client = OpenAI(
                api_key=self.api_key,
-                base_url=self.base_url
+                base_url=self.base_url,
+                max_retries=0  # Disable client retries - we handle them
            )
        else:
            # Standard OpenAI-compatible client
@@ -165,6 +165,7 @@ def __init__(self, proxy_client):
class _Completions:
    def __init__(self, proxy_client):
        self.proxy_client = proxy_client
+        self._system_message_support_cache = {}

    def _filter_kwargs(self, kwargs: dict) -> dict:
        """Filter out OptiLLM-specific parameters that shouldn't be sent to providers"""
@@ -175,6 +176,73 @@ def _filter_kwargs(self, kwargs: dict) -> dict:
        }
        return {k: v for k, v in kwargs.items() if k not in optillm_params}

+    def _test_system_message_support(self, provider, model: str) -> bool:
+        """Test if a model supports system messages"""
+        cache_key = f"{provider.name}:{model}"
+
+        if cache_key in self._system_message_support_cache:
+            return self._system_message_support_cache[cache_key]
+
+        try:
+            test_response = provider.client.chat.completions.create(
+                model=model,
+                messages=[
+                    {"role": "system", "content": "test"},
+                    {"role": "user", "content": "hi"}
+                ],
+                max_tokens=1,
+                temperature=0
+            )
+            self._system_message_support_cache[cache_key] = True
+            return True
+        except Exception as e:
+            error_msg = str(e).lower()
+            if any(pattern in error_msg for pattern in [
+                "developer instruction", "system message", "not enabled", "not supported"
+            ]):
+                logger.info(f"Provider {provider.name} model {model} does not support system messages")
+                self._system_message_support_cache[cache_key] = False
+                return False
+            # Other errors - assume it supports system messages
+            self._system_message_support_cache[cache_key] = True
+            return True
+
+    def _format_messages_for_provider(self, provider, model: str, messages: list) -> list:
+        """Format messages based on provider's system message support"""
+        # Check if there's a system message
+        has_system = any(msg.get("role") == "system" for msg in messages)
+
+        if not has_system:
+            return messages
+
+        # Test system message support
+        supports_system = self._test_system_message_support(provider, model)
+
+        if supports_system:
+            return messages
+
+        # Merge system message into first user message
+        formatted_messages = []
+        system_content = None
+
+        for msg in messages:
+            if msg.get("role") == "system":
+                system_content = msg.get("content", "")
+            elif msg.get("role") == "user":
+                if system_content:
+                    # Merge system message with user message
+                    formatted_messages.append({
+                        "role": "user",
+                        "content": f"Instructions: {system_content}\n\nUser: {msg.get('content', '')}"
+                    })
+                    system_content = None
+                else:
+                    formatted_messages.append(msg)
+            else:
+                formatted_messages.append(msg)
+
+        return formatted_messages
+
    def _make_request_with_timeout(self, provider, request_kwargs):
        """Make a request with timeout handling"""
        # The OpenAI client now supports timeout natively
@@ -232,7 +300,14 @@ def create(self, **kwargs):
            try:
                # Map model name if needed and filter out OptiLLM-specific parameters
                request_kwargs = self._filter_kwargs(kwargs.copy())
-                request_kwargs['model'] = provider.map_model(model)
+                mapped_model = provider.map_model(model)
+                request_kwargs['model'] = mapped_model
+
+                # Format messages based on provider's system message support
+                if 'messages' in request_kwargs:
+                    request_kwargs['messages'] = self._format_messages_for_provider(
+                        provider, mapped_model, request_kwargs['messages']
+                    )

                # Add timeout to client if supported
                request_kwargs['timeout'] = self.proxy_client.request_timeout
@@ -279,7 +354,7 @@ def create(self, **kwargs):
        if self.proxy_client.fallback_client:
            logger.warning("All proxy providers failed, using fallback client")
            try:
-                fallback_kwargs = self._filter_kwargs(kwargs)
+                fallback_kwargs = self._filter_kwargs(kwargs.copy())
                fallback_kwargs['timeout'] = self.proxy_client.request_timeout
                return self.proxy_client.fallback_client.chat.completions.create(**fallback_kwargs)
            except Exception as e:
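
The fallback added here can be reproduced in isolation. Below is a minimal standalone sketch of the merge step that _format_messages_for_provider performs when a provider rejects system messages; the function name and sample messages are illustrative, and only the "Instructions: ... / User: ..." template is taken from the diff above.

# Standalone sketch of the system-message fallback (illustrative names).
def merge_system_into_user(messages: list) -> list:
    merged = []
    system_content = None
    for msg in messages:
        if msg.get("role") == "system":
            system_content = msg.get("content", "")
        elif msg.get("role") == "user" and system_content:
            # Fold the system prompt into the first user turn, as in the diff
            merged.append({
                "role": "user",
                "content": f"Instructions: {system_content}\n\nUser: {msg.get('content', '')}",
            })
            system_content = None
        else:
            merged.append(msg)
    return merged

# Example: a provider without system-message support still sees the instructions.
print(merge_system_into_user([
    {"role": "system", "content": "Answer briefly."},
    {"role": "user", "content": "What does the proxy plugin do?"},
]))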

optillm/plugins/proxy/google_client.py

Lines changed: 0 additions & 92 deletions
This file was deleted.
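
The custom GoogleAIClient wrapper is no longer needed because the proxy now reaches Gemini through Google's OpenAI-compatible endpoint with the stock OpenAI client (see the client.py hunk above). A minimal sketch of that setup, assuming Google's OpenAI-compatible base URL, a GEMINI_API_KEY environment variable, and an illustrative model name (none of these come from this commit):

import os
from openai import OpenAI

# Sketch only: base URL, env var, and model name are assumptions, not taken from this diff.
client = OpenAI(
    api_key=os.environ["GEMINI_API_KEY"],
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
    max_retries=0,  # mirrors the proxy: retries are handled by the proxy itself
)
response = client.chat.completions.create(
    model="gemini-2.0-flash",  # illustrative model name
    messages=[{"role": "user", "content": "hi"}],
    max_tokens=1,
)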

optillm/plugins/proxy_plugin.py

Lines changed: 86 additions & 6 deletions
@@ -5,7 +5,8 @@
 with health monitoring, failover, and support for wrapping other approaches.
 """
 import logging
-from typing import Tuple, Optional
+import threading
+from typing import Tuple, Optional, Dict
 from optillm.plugins.proxy.config import ProxyConfig
 from optillm.plugins.proxy.client import ProxyClient
 from optillm.plugins.proxy.approach_handler import ApproachHandler
@@ -21,6 +22,78 @@
 # Global proxy client cache to maintain state between requests
 _proxy_client_cache = {}

+# Global cache for system message support per provider-model combination
+_system_message_support_cache: Dict[str, bool] = {}
+_cache_lock = threading.RLock()
+
+def _test_system_message_support(proxy_client, model: str) -> bool:
+    """
+    Test if a model supports system messages by making a minimal test request.
+    Returns True if supported, False otherwise.
+    """
+    try:
+        # Try a minimal system message request
+        test_response = proxy_client.chat.completions.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": "test"},
+                {"role": "user", "content": "hi"}
+            ],
+            max_tokens=1,  # Minimal token generation
+            temperature=0
+        )
+        return True
+    except Exception as e:
+        error_msg = str(e).lower()
+        # Check for known system message rejection patterns
+        if any(pattern in error_msg for pattern in [
+            "developer instruction",
+            "system message",
+            "not enabled",
+            "not supported"
+        ]):
+            logger.info(f"Model {model} does not support system messages: {str(e)[:100]}")
+            return False
+        else:
+            # If it's a different error, assume system messages are supported
+            # but something else went wrong (rate limit, timeout, etc.)
+            logger.debug(f"System message test failed for {model}, assuming supported: {str(e)[:100]}")
+            return True
+
+def _get_system_message_support(proxy_client, model: str) -> bool:
+    """
+    Get cached system message support status, testing if not cached.
+    Thread-safe with locking.
+    """
+    # Create a unique cache key based on model and base_url
+    cache_key = f"{getattr(proxy_client, '_base_identifier', 'default')}:{model}"
+
+    with _cache_lock:
+        if cache_key not in _system_message_support_cache:
+            logger.debug(f"Testing system message support for {model}")
+            _system_message_support_cache[cache_key] = _test_system_message_support(proxy_client, model)
+
+        return _system_message_support_cache[cache_key]
+
+def _format_messages_for_model(system_prompt: str, initial_query: str,
+                               supports_system_messages: bool) -> list:
+    """
+    Format messages based on whether the model supports system messages.
+    """
+    if supports_system_messages:
+        return [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": initial_query}
+        ]
+    else:
+        # Merge system prompt into user message
+        if system_prompt.strip():
+            combined_message = f"{system_prompt}\n\nUser: {initial_query}"
+        else:
+            combined_message = initial_query
+
+        return [{"role": "user", "content": combined_message}]
+
 def run(system_prompt: str, initial_query: str, client, model: str,
         request_config: dict = None) -> Tuple[str, int]:
     """
@@ -119,14 +192,21 @@ def run(system_prompt: str, initial_query: str, client, model: str,
        logger.info(f"Proxy routing approach/plugin: {potential_approach}")
        return result

-    # Direct proxy execution
+    # Direct proxy execution with dynamic system message support detection
    logger.info(f"Direct proxy routing for model: {model}")
+
+    # Test and cache system message support for this model
+    supports_system_messages = _get_system_message_support(proxy_client, model)
+
+    # Format messages based on system message support
+    messages = _format_messages_for_model(system_prompt, initial_query, supports_system_messages)
+
+    if not supports_system_messages:
+        logger.info(f"Using fallback message formatting for {model} (no system message support)")
+
    response = proxy_client.chat.completions.create(
        model=model,
-        messages=[
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": initial_query}
-        ],
+        messages=messages,
        **(request_config or {})
    )

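
The new module-level cache is keyed per client/model and guarded by a threading.RLock, so concurrent requests probe a given model at most once. A minimal standalone sketch of that caching pattern (function and variable names are illustrative, not the plugin's API):

import threading

_support_cache = {}
_cache_lock = threading.RLock()

def get_cached_support(key: str, probe) -> bool:
    # probe is any callable returning True/False, e.g. a one-token test request
    with _cache_lock:
        if key not in _support_cache:
            _support_cache[key] = probe()
        return _support_cache[key]

# The probe runs only on the first call for a key; later calls hit the cache.
print(get_cached_support("default:some-model", lambda: False))  # probes -> False
print(get_cached_support("default:some-model", lambda: True))   # cached -> still False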

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "optillm"
-version = "0.2.7"
+version = "0.2.8"
 description = "An optimizing inference proxy for LLMs."
 readme = "README.md"
 license = "Apache-2.0"
