@@ -48,11 +48,11 @@ def client(self):
4848 max_retries = 0 # Disable client retries - we handle them
4949 )
5050 elif 'generativelanguage.googleapis.com' in self .base_url :
51- # Google AI client - create custom client to avoid "models/" prefix
52- from optillm .plugins .proxy .google_client import GoogleAIClient
53- self ._client = GoogleAIClient (
51+ # Google AI with standard OpenAI-compatible client
52+ self ._client = OpenAI (
5453 api_key = self .api_key ,
55- base_url = self .base_url
54+ base_url = self .base_url ,
55+ max_retries = 0 # Disable client retries - we handle them
5656 )
5757 else :
5858 # Standard OpenAI-compatible client
@@ -165,6 +165,7 @@ def __init__(self, proxy_client):
165165 class _Completions :
def __init__(self, proxy_client):
    """Bind this completions facade to its owning proxy client.

    Args:
        proxy_client: The proxy client this object issues requests for.
    """
    self.proxy_client = proxy_client
    # Memoized probe results, keyed by "<provider name>:<model>" and holding
    # True/False for whether that pair accepted a system message.
    self._system_message_support_cache = {}
168169
169170 def _filter_kwargs (self , kwargs : dict ) -> dict :
170171 """Filter out OptiLLM-specific parameters that shouldn't be sent to providers"""
@@ -175,6 +176,73 @@ def _filter_kwargs(self, kwargs: dict) -> dict:
175176 }
176177 return {k : v for k , v in kwargs .items () if k not in optillm_params }
177178
179+ def _test_system_message_support (self , provider , model : str ) -> bool :
180+ """Test if a model supports system messages"""
181+ cache_key = f"{ provider .name } :{ model } "
182+
183+ if cache_key in self ._system_message_support_cache :
184+ return self ._system_message_support_cache [cache_key ]
185+
186+ try :
187+ test_response = provider .client .chat .completions .create (
188+ model = model ,
189+ messages = [
190+ {"role" : "system" , "content" : "test" },
191+ {"role" : "user" , "content" : "hi" }
192+ ],
193+ max_tokens = 1 ,
194+ temperature = 0
195+ )
196+ self ._system_message_support_cache [cache_key ] = True
197+ return True
198+ except Exception as e :
199+ error_msg = str (e ).lower ()
200+ if any (pattern in error_msg for pattern in [
201+ "developer instruction" , "system message" , "not enabled" , "not supported"
202+ ]):
203+ logger .info (f"Provider { provider .name } model { model } does not support system messages" )
204+ self ._system_message_support_cache [cache_key ] = False
205+ return False
206+ # Other errors - assume it supports system messages
207+ self ._system_message_support_cache [cache_key ] = True
208+ return True
209+
210+ def _format_messages_for_provider (self , provider , model : str , messages : list ) -> list :
211+ """Format messages based on provider's system message support"""
212+ # Check if there's a system message
213+ has_system = any (msg .get ("role" ) == "system" for msg in messages )
214+
215+ if not has_system :
216+ return messages
217+
218+ # Test system message support
219+ supports_system = self ._test_system_message_support (provider , model )
220+
221+ if supports_system :
222+ return messages
223+
224+ # Merge system message into first user message
225+ formatted_messages = []
226+ system_content = None
227+
228+ for msg in messages :
229+ if msg .get ("role" ) == "system" :
230+ system_content = msg .get ("content" , "" )
231+ elif msg .get ("role" ) == "user" :
232+ if system_content :
233+ # Merge system message with user message
234+ formatted_messages .append ({
235+ "role" : "user" ,
236+ "content" : f"Instructions: { system_content } \n \n User: { msg .get ('content' , '' )} "
237+ })
238+ system_content = None
239+ else :
240+ formatted_messages .append (msg )
241+ else :
242+ formatted_messages .append (msg )
243+
244+ return formatted_messages
245+
178246 def _make_request_with_timeout (self , provider , request_kwargs ):
179247 """Make a request with timeout handling"""
180248 # The OpenAI client now supports timeout natively
@@ -232,7 +300,14 @@ def create(self, **kwargs):
232300 try :
233301 # Map model name if needed and filter out OptiLLM-specific parameters
234302 request_kwargs = self ._filter_kwargs (kwargs .copy ())
235- request_kwargs ['model' ] = provider .map_model (model )
303+ mapped_model = provider .map_model (model )
304+ request_kwargs ['model' ] = mapped_model
305+
306+ # Format messages based on provider's system message support
307+ if 'messages' in request_kwargs :
308+ request_kwargs ['messages' ] = self ._format_messages_for_provider (
309+ provider , mapped_model , request_kwargs ['messages' ]
310+ )
236311
237312 # Add timeout to client if supported
238313 request_kwargs ['timeout' ] = self .proxy_client .request_timeout
@@ -279,7 +354,7 @@ def create(self, **kwargs):
279354 if self .proxy_client .fallback_client :
280355 logger .warning ("All proxy providers failed, using fallback client" )
281356 try :
282- fallback_kwargs = self ._filter_kwargs (kwargs )
357+ fallback_kwargs = self ._filter_kwargs (kwargs . copy () )
283358 fallback_kwargs ['timeout' ] = self .proxy_client .request_timeout
284359 return self .proxy_client .fallback_client .chat .completions .create (** fallback_kwargs )
285360 except Exception as e :
0 commit comments