Skip to content

Commit b0407a0

Browse files
committed
Update client.py
1 parent e958509 commit b0407a0

File tree

1 file changed

+77
-2
lines changed

1 file changed

+77
-2
lines changed

optillm/plugins/proxy/client.py

Lines changed: 77 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -165,6 +165,7 @@ def __init__(self, proxy_client):
165165
class _Completions:
166166
def __init__(self, proxy_client):
    """Bind this completions helper to its owning proxy client.

    Args:
        proxy_client: The proxy client object this helper serves; stored
            unchanged on the instance.
    """
    # Results of system-message support probes, keyed by
    # "<provider name>:<model>". Starts empty for every new instance.
    self._system_message_support_cache = {}
    self.proxy_client = proxy_client
168169

169170
def _filter_kwargs(self, kwargs: dict) -> dict:
170171
"""Filter out OptiLLM-specific parameters that shouldn't be sent to providers"""
@@ -175,6 +176,73 @@ def _filter_kwargs(self, kwargs: dict) -> dict:
175176
}
176177
return {k: v for k, v in kwargs.items() if k not in optillm_params}
177178

179+
def _test_system_message_support(self, provider, model: str) -> bool:
    """Probe whether ``model`` on ``provider`` accepts a system-role message.

    Sends a minimal chat completion (one system message, one user message,
    a single output token) through ``provider.client`` and classifies the
    outcome. Conclusive answers are memoised in
    ``self._system_message_support_cache`` under ``"<provider.name>:<model>"``.

    Args:
        provider: Object exposing ``name`` and an OpenAI-style ``client``.
        model: Model identifier sent to the provider.

    Returns:
        ``True`` when the probe succeeds or fails for a reason unrelated to
        system messages; ``False`` when the error text shows that
        system/developer instructions are rejected.
    """
    cache_key = f"{provider.name}:{model}"

    # Only True/False are ever stored, so None unambiguously means "unknown".
    cached = self._system_message_support_cache.get(cache_key)
    if cached is not None:
        return cached

    try:
        provider.client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "test"},
                {"role": "user", "content": "hi"},
            ],
            max_tokens=1,
            temperature=0,
            # Same timeout as real requests, so a stalled provider cannot
            # block the caller for longer than a normal call would.
            timeout=self.proxy_client.request_timeout,
        )
    except Exception as e:
        error_msg = str(e).lower()
        explicit = any(
            pattern in error_msg
            for pattern in ("developer instruction", "system message")
        )
        generic = any(
            pattern in error_msg for pattern in ("not enabled", "not supported")
        )
        # "not supported" / "not enabled" on their own also appear in errors
        # about the probe's own parameters (e.g. max_tokens or temperature),
        # so they only count when the message is about the system/developer
        # role.
        mentions_role = "system" in error_msg or "developer" in error_msg

        if explicit or (generic and mentions_role):
            logger.info(f"Provider {provider.name} model {model} does not support system messages")
            self._system_message_support_cache[cache_key] = False
            return False

        # Inconclusive failure (timeout, rate limit, auth, ...): assume
        # support for this call, but do not cache the guess so that a later
        # call can probe again.
        return True

    self._system_message_support_cache[cache_key] = True
    return True
209+
210+
def _format_messages_for_provider(self, provider, model: str, messages: list) -> list:
211+
"""Format messages based on provider's system message support"""
212+
# Check if there's a system message
213+
has_system = any(msg.get("role") == "system" for msg in messages)
214+
215+
if not has_system:
216+
return messages
217+
218+
# Test system message support
219+
supports_system = self._test_system_message_support(provider, model)
220+
221+
if supports_system:
222+
return messages
223+
224+
# Merge system message into first user message
225+
formatted_messages = []
226+
system_content = None
227+
228+
for msg in messages:
229+
if msg.get("role") == "system":
230+
system_content = msg.get("content", "")
231+
elif msg.get("role") == "user":
232+
if system_content:
233+
# Merge system message with user message
234+
formatted_messages.append({
235+
"role": "user",
236+
"content": f"Instructions: {system_content}\n\nUser: {msg.get('content', '')}"
237+
})
238+
system_content = None
239+
else:
240+
formatted_messages.append(msg)
241+
else:
242+
formatted_messages.append(msg)
243+
244+
return formatted_messages
245+
178246
def _make_request_with_timeout(self, provider, request_kwargs):
179247
"""Make a request with timeout handling"""
180248
# The OpenAI client now supports timeout natively
@@ -232,7 +300,14 @@ def create(self, **kwargs):
232300
try:
233301
# Map model name if needed and filter out OptiLLM-specific parameters
234302
request_kwargs = self._filter_kwargs(kwargs.copy())
235-
request_kwargs['model'] = provider.map_model(model)
303+
mapped_model = provider.map_model(model)
304+
request_kwargs['model'] = mapped_model
305+
306+
# Format messages based on provider's system message support
307+
if 'messages' in request_kwargs:
308+
request_kwargs['messages'] = self._format_messages_for_provider(
309+
provider, mapped_model, request_kwargs['messages']
310+
)
236311

237312
# Add timeout to client if supported
238313
request_kwargs['timeout'] = self.proxy_client.request_timeout
@@ -279,7 +354,7 @@ def create(self, **kwargs):
279354
if self.proxy_client.fallback_client:
280355
logger.warning("All proxy providers failed, using fallback client")
281356
try:
282-
fallback_kwargs = self._filter_kwargs(kwargs)
357+
fallback_kwargs = self._filter_kwargs(kwargs.copy())
283358
fallback_kwargs['timeout'] = self.proxy_client.request_timeout
284359
return self.proxy_client.fallback_client.chat.completions.create(**fallback_kwargs)
285360
except Exception as e:

0 commit comments

Comments
 (0)