@@ -90,11 +90,9 @@ def __init__(
9090
9191 logger .debug ("config=<%s> | initializing" , self .config )
9292
93- client_args = client_args or {}
93+ self . client_args = client_args or {}
9494 if api_key :
95- client_args ["api_key" ] = api_key
96-
97- self .client = mistralai .Mistral (** client_args )
95+ self .client_args ["api_key" ] = api_key
9896
9997 @override
10098 def update_config (self , ** model_config : Unpack [MistralConfig ]) -> None : # type: ignore
@@ -421,67 +419,70 @@ async def stream(
421419 logger .debug ("got response from model" )
422420 if not self .config .get ("stream" , True ):
423421 # Use non-streaming API
424- response = await self .client .chat .complete_async (** request )
425- for event in self ._handle_non_streaming_response (response ):
426- yield self .format_chunk (event )
422+ async with mistralai .Mistral (** self .client_args ) as client :
423+ response = await client .chat .complete_async (** request )
424+ for event in self ._handle_non_streaming_response (response ):
425+ yield self .format_chunk (event )
426+
427427 return
428428
429429 # Use the streaming API
430- stream_response = await self .client .chat .stream_async (** request )
430+ async with mistralai .Mistral (** self .client_args ) as client :
431+ stream_response = await client .chat .stream_async (** request )
431432
432- yield self .format_chunk ({"chunk_type" : "message_start" })
433+ yield self .format_chunk ({"chunk_type" : "message_start" })
433434
434- content_started = False
435- tool_calls : dict [str , list [Any ]] = {}
436- accumulated_text = ""
435+ content_started = False
436+ tool_calls : dict [str , list [Any ]] = {}
437+ accumulated_text = ""
437438
438- async for chunk in stream_response :
439- if hasattr (chunk , "data" ) and hasattr (chunk .data , "choices" ) and chunk .data .choices :
440- choice = chunk .data .choices [0 ]
439+ async for chunk in stream_response :
440+ if hasattr (chunk , "data" ) and hasattr (chunk .data , "choices" ) and chunk .data .choices :
441+ choice = chunk .data .choices [0 ]
441442
442- if hasattr (choice , "delta" ):
443- delta = choice .delta
443+ if hasattr (choice , "delta" ):
444+ delta = choice .delta
444445
445- if hasattr (delta , "content" ) and delta .content :
446- if not content_started :
447- yield self .format_chunk ({"chunk_type" : "content_start" , "data_type" : "text" })
448- content_started = True
446+ if hasattr (delta , "content" ) and delta .content :
447+ if not content_started :
448+ yield self .format_chunk ({"chunk_type" : "content_start" , "data_type" : "text" })
449+ content_started = True
449450
450- yield self .format_chunk (
451- {"chunk_type" : "content_delta" , "data_type" : "text" , "data" : delta .content }
452- )
453- accumulated_text += delta .content
451+ yield self .format_chunk (
452+ {"chunk_type" : "content_delta" , "data_type" : "text" , "data" : delta .content }
453+ )
454+ accumulated_text += delta .content
454455
455- if hasattr (delta , "tool_calls" ) and delta .tool_calls :
456- for tool_call in delta .tool_calls :
457- tool_id = tool_call .id
458- tool_calls .setdefault (tool_id , []).append (tool_call )
456+ if hasattr (delta , "tool_calls" ) and delta .tool_calls :
457+ for tool_call in delta .tool_calls :
458+ tool_id = tool_call .id
459+ tool_calls .setdefault (tool_id , []).append (tool_call )
459460
460- if hasattr (choice , "finish_reason" ) and choice .finish_reason :
461- if content_started :
462- yield self .format_chunk ({"chunk_type" : "content_stop" , "data_type" : "text" })
461+ if hasattr (choice , "finish_reason" ) and choice .finish_reason :
462+ if content_started :
463+ yield self .format_chunk ({"chunk_type" : "content_stop" , "data_type" : "text" })
463464
464- for tool_deltas in tool_calls .values ():
465- yield self .format_chunk (
466- {"chunk_type" : "content_start" , "data_type" : "tool" , "data" : tool_deltas [0 ]}
467- )
465+ for tool_deltas in tool_calls .values ():
466+ yield self .format_chunk (
467+ {"chunk_type" : "content_start" , "data_type" : "tool" , "data" : tool_deltas [0 ]}
468+ )
468469
469- for tool_delta in tool_deltas :
470- if hasattr (tool_delta .function , "arguments" ):
471- yield self .format_chunk (
472- {
473- "chunk_type" : "content_delta" ,
474- "data_type" : "tool" ,
475- "data" : tool_delta .function .arguments ,
476- }
477- )
470+ for tool_delta in tool_deltas :
471+ if hasattr (tool_delta .function , "arguments" ):
472+ yield self .format_chunk (
473+ {
474+ "chunk_type" : "content_delta" ,
475+ "data_type" : "tool" ,
476+ "data" : tool_delta .function .arguments ,
477+ }
478+ )
478479
479- yield self .format_chunk ({"chunk_type" : "content_stop" , "data_type" : "tool" })
480+ yield self .format_chunk ({"chunk_type" : "content_stop" , "data_type" : "tool" })
480481
481- yield self .format_chunk ({"chunk_type" : "message_stop" , "data" : choice .finish_reason })
482+ yield self .format_chunk ({"chunk_type" : "message_stop" , "data" : choice .finish_reason })
482483
483- if hasattr (chunk , "usage" ):
484- yield self .format_chunk ({"chunk_type" : "metadata" , "data" : chunk .usage })
484+ if hasattr (chunk , "usage" ):
485+ yield self .format_chunk ({"chunk_type" : "metadata" , "data" : chunk .usage })
485486
486487 except Exception as e :
487488 if "rate" in str (e ).lower () or "429" in str (e ):
@@ -518,7 +519,8 @@ async def structured_output(
518519 formatted_request ["tool_choice" ] = "any"
519520 formatted_request ["parallel_tool_calls" ] = False
520521
521- response = await self .client .chat .complete_async (** formatted_request )
522+ async with mistralai .Mistral (** self .client_args ) as client :
523+ response = await client .chat .complete_async (** formatted_request )
522524
523525 if response .choices and response .choices [0 ].message .tool_calls :
524526 tool_call = response .choices [0 ].message .tool_calls [0 ]
0 commit comments