diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 300600a4e..151b423d1 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -90,11 +90,9 @@ def __init__( logger.debug("config=<%s> | initializing", self.config) - client_args = client_args or {} + self.client_args = client_args or {} if api_key: - client_args["api_key"] = api_key - - self.client = mistralai.Mistral(**client_args) + self.client_args["api_key"] = api_key @override def update_config(self, **model_config: Unpack[MistralConfig]) -> None: # type: ignore @@ -421,67 +419,70 @@ async def stream( logger.debug("got response from model") if not self.config.get("stream", True): # Use non-streaming API - response = await self.client.chat.complete_async(**request) - for event in self._handle_non_streaming_response(response): - yield self.format_chunk(event) + async with mistralai.Mistral(**self.client_args) as client: + response = await client.chat.complete_async(**request) + for event in self._handle_non_streaming_response(response): + yield self.format_chunk(event) + return # Use the streaming API - stream_response = await self.client.chat.stream_async(**request) + async with mistralai.Mistral(**self.client_args) as client: + stream_response = await client.chat.stream_async(**request) - yield self.format_chunk({"chunk_type": "message_start"}) + yield self.format_chunk({"chunk_type": "message_start"}) - content_started = False - tool_calls: dict[str, list[Any]] = {} - accumulated_text = "" + content_started = False + tool_calls: dict[str, list[Any]] = {} + accumulated_text = "" - async for chunk in stream_response: - if hasattr(chunk, "data") and hasattr(chunk.data, "choices") and chunk.data.choices: - choice = chunk.data.choices[0] + async for chunk in stream_response: + if hasattr(chunk, "data") and hasattr(chunk.data, "choices") and chunk.data.choices: + choice = chunk.data.choices[0] - if hasattr(choice, "delta"): - delta = choice.delta + if hasattr(choice, "delta"): + delta = choice.delta - if hasattr(delta, "content") and delta.content: - if not content_started: - yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) - content_started = True + if hasattr(delta, "content") and delta.content: + if not content_started: + yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) + content_started = True - yield self.format_chunk( - {"chunk_type": "content_delta", "data_type": "text", "data": delta.content} - ) - accumulated_text += delta.content + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": "text", "data": delta.content} + ) + accumulated_text += delta.content - if hasattr(delta, "tool_calls") and delta.tool_calls: - for tool_call in delta.tool_calls: - tool_id = tool_call.id - tool_calls.setdefault(tool_id, []).append(tool_call) + if hasattr(delta, "tool_calls") and delta.tool_calls: + for tool_call in delta.tool_calls: + tool_id = tool_call.id + tool_calls.setdefault(tool_id, []).append(tool_call) - if hasattr(choice, "finish_reason") and choice.finish_reason: - if content_started: - yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + if hasattr(choice, "finish_reason") and choice.finish_reason: + if content_started: + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) - for tool_deltas in tool_calls.values(): - yield self.format_chunk( - {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]} - ) + for tool_deltas in tool_calls.values(): + yield self.format_chunk( + {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]} + ) - for tool_delta in tool_deltas: - if hasattr(tool_delta.function, "arguments"): - yield self.format_chunk( - { - "chunk_type": "content_delta", - "data_type": "tool", - "data": tool_delta.function.arguments, - } - ) + for tool_delta in tool_deltas: + if hasattr(tool_delta.function, "arguments"): + yield self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": tool_delta.function.arguments, + } + ) - yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) - yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason}) + yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason}) - if hasattr(chunk, "usage"): - yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage}) + if hasattr(chunk, "usage"): + yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage}) except Exception as e: if "rate" in str(e).lower() or "429" in str(e): @@ -518,7 +519,8 @@ async def structured_output( formatted_request["tool_choice"] = "any" formatted_request["parallel_tool_calls"] = False - response = await self.client.chat.complete_async(**formatted_request) + async with mistralai.Mistral(**self.client_args) as client: + response = await client.chat.complete_async(**formatted_request) if response.choices and response.choices[0].message.tool_calls: tool_call = response.choices[0].message.tool_calls[0] diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index 06ea32d2b..2a78024f2 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -11,7 +11,9 @@ @pytest.fixture def mistral_client(): with unittest.mock.patch.object(strands.models.mistral.mistralai, "Mistral") as mock_client_cls: - yield mock_client_cls.return_value + mock_client = unittest.mock.AsyncMock() + mock_client_cls.return_value.__aenter__.return_value = mock_client + yield mock_client @pytest.fixture @@ -25,9 +27,7 @@ def max_tokens(): @pytest.fixture -def model(mistral_client, model_id, max_tokens): - _ = mistral_client - +def model(model_id, max_tokens): return MistralModel(model_id=model_id, max_tokens=max_tokens)