Merged
src/strands/models/mistral.py (102 changes: 52 additions & 50 deletions)

@@ -90,11 +90,9 @@ def __init__(

         logger.debug("config=<%s> | initializing", self.config)

-        client_args = client_args or {}
+        self.client_args = client_args or {}
         if api_key:
-            client_args["api_key"] = api_key
-
-        self.client = mistralai.Mistral(**client_args)
+            self.client_args["api_key"] = api_key

     @override
     def update_config(self, **model_config: Unpack[MistralConfig]) -> None:  # type: ignore
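
A note on the hunk above: rather than building one `mistralai.Mistral` client in `__init__`, the model now keeps only the client arguments and opens a short-lived client per request. A minimal sketch of the pattern, with an illustrative class name; `client_args`, `api_key`, and the `Mistral` calls mirror the diff:

```python
import mistralai


class LazyClientModel:
    """Illustrative stand-in for the deferred-client pattern in this PR."""

    def __init__(self, api_key=None, client_args=None):
        # Store arguments rather than a live client, so nothing network- or
        # event-loop-related is created at construction time.
        self.client_args = client_args or {}
        if api_key:
            self.client_args["api_key"] = api_key

    async def complete(self, **request):
        # Open a fresh client per call; the context manager closes it.
        async with mistralai.Mistral(**self.client_args) as client:
            return await client.chat.complete_async(**request)
```

A common motivation for this shape is that each call gets a client created and closed on the event loop that runs it, so connections cannot leak across loops or outlive the call.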
@@ -421,67 +419,70 @@ async def stream(
             logger.debug("got response from model")
             if not self.config.get("stream", True):
                 # Use non-streaming API
-                response = await self.client.chat.complete_async(**request)
-                for event in self._handle_non_streaming_response(response):
-                    yield self.format_chunk(event)
+                async with mistralai.Mistral(**self.client_args) as client:
+                    response = await client.chat.complete_async(**request)
+                    for event in self._handle_non_streaming_response(response):
+                        yield self.format_chunk(event)

                 return

             # Use the streaming API
-            stream_response = await self.client.chat.stream_async(**request)
+            async with mistralai.Mistral(**self.client_args) as client:
+                stream_response = await client.chat.stream_async(**request)

-            yield self.format_chunk({"chunk_type": "message_start"})
+                yield self.format_chunk({"chunk_type": "message_start"})

-            content_started = False
-            tool_calls: dict[str, list[Any]] = {}
-            accumulated_text = ""
+                content_started = False
+                tool_calls: dict[str, list[Any]] = {}
+                accumulated_text = ""

-            async for chunk in stream_response:
-                if hasattr(chunk, "data") and hasattr(chunk.data, "choices") and chunk.data.choices:
-                    choice = chunk.data.choices[0]
+                async for chunk in stream_response:
+                    if hasattr(chunk, "data") and hasattr(chunk.data, "choices") and chunk.data.choices:
+                        choice = chunk.data.choices[0]

-                    if hasattr(choice, "delta"):
-                        delta = choice.delta
+                        if hasattr(choice, "delta"):
+                            delta = choice.delta

-                        if hasattr(delta, "content") and delta.content:
-                            if not content_started:
-                                yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
-                                content_started = True
+                            if hasattr(delta, "content") and delta.content:
+                                if not content_started:
+                                    yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+                                    content_started = True

-                            yield self.format_chunk(
-                                {"chunk_type": "content_delta", "data_type": "text", "data": delta.content}
-                            )
-                            accumulated_text += delta.content
+                                yield self.format_chunk(
+                                    {"chunk_type": "content_delta", "data_type": "text", "data": delta.content}
+                                )
+                                accumulated_text += delta.content

-                        if hasattr(delta, "tool_calls") and delta.tool_calls:
-                            for tool_call in delta.tool_calls:
-                                tool_id = tool_call.id
-                                tool_calls.setdefault(tool_id, []).append(tool_call)
+                            if hasattr(delta, "tool_calls") and delta.tool_calls:
+                                for tool_call in delta.tool_calls:
+                                    tool_id = tool_call.id
+                                    tool_calls.setdefault(tool_id, []).append(tool_call)

-                    if hasattr(choice, "finish_reason") and choice.finish_reason:
-                        if content_started:
-                            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+                        if hasattr(choice, "finish_reason") and choice.finish_reason:
+                            if content_started:
+                                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

-                        for tool_deltas in tool_calls.values():
-                            yield self.format_chunk(
-                                {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
-                            )
+                            for tool_deltas in tool_calls.values():
+                                yield self.format_chunk(
+                                    {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
+                                )

-                            for tool_delta in tool_deltas:
-                                if hasattr(tool_delta.function, "arguments"):
-                                    yield self.format_chunk(
-                                        {
-                                            "chunk_type": "content_delta",
-                                            "data_type": "tool",
-                                            "data": tool_delta.function.arguments,
-                                        }
-                                    )
+                                for tool_delta in tool_deltas:
+                                    if hasattr(tool_delta.function, "arguments"):
+                                        yield self.format_chunk(
+                                            {
+                                                "chunk_type": "content_delta",
+                                                "data_type": "tool",
+                                                "data": tool_delta.function.arguments,
+                                            }
+                                        )

-                            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+                                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})

-                        yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
+                            yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})

-                if hasattr(chunk, "usage"):
-                    yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage})
+                    if hasattr(chunk, "usage"):
+                        yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage})

         except Exception as e:
             if "rate" in str(e).lower() or "429" in str(e):
@@ -518,7 +519,8 @@ async def structured_output(
             formatted_request["tool_choice"] = "any"
             formatted_request["parallel_tool_calls"] = False

-            response = await self.client.chat.complete_async(**formatted_request)
+            async with mistralai.Mistral(**self.client_args) as client:
+                response = await client.chat.complete_async(**formatted_request)

             if response.choices and response.choices[0].message.tool_calls:
                 tool_call = response.choices[0].message.tool_calls[0]
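
`structured_output` takes the same per-request client approach. A condensed sketch of the branch above; the request keys and the forced single tool call mirror the diff, while the surrounding validation is omitted:

```python
import mistralai


async def structured_output_sketch(client_args: dict, formatted_request: dict):
    # Force exactly one tool call so the answer comes back as structured
    # tool arguments instead of free-form text.
    formatted_request["tool_choice"] = "any"
    formatted_request["parallel_tool_calls"] = False

    async with mistralai.Mistral(**client_args) as client:
        response = await client.chat.complete_async(**formatted_request)

    if response.choices and response.choices[0].message.tool_calls:
        return response.choices[0].message.tool_calls[0]
    return None  # illustrative; the diff does not show the no-tool-call path
```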
tests/strands/models/test_mistral.py (8 changes: 4 additions & 4 deletions)

@@ -11,7 +11,9 @@
 @pytest.fixture
 def mistral_client():
     with unittest.mock.patch.object(strands.models.mistral.mistralai, "Mistral") as mock_client_cls:
-        yield mock_client_cls.return_value
+        mock_client = unittest.mock.AsyncMock()
+        mock_client_cls.return_value.__aenter__.return_value = mock_client
+        yield mock_client


@pytest.fixture
@@ -25,9 +27,7 @@ def max_tokens():


@pytest.fixture
-def model(mistral_client, model_id, max_tokens):
-    _ = mistral_client
-
+def model(model_id, max_tokens):
     return MistralModel(model_id=model_id, max_tokens=max_tokens)
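
The fixture change above follows from the new `async with` usage: patching the `Mistral` class is no longer enough, because the code under test binds `client` to whatever `__aenter__` returns, not to the constructed instance itself. A standalone demonstration of that mock wiring (names are illustrative; requires Python 3.8+ for async magic-method support in `MagicMock`):

```python
import asyncio
import unittest.mock


async def demo():
    mock_cls = unittest.mock.MagicMock()
    mock_client = unittest.mock.AsyncMock()
    # `async with mock_cls(...) as client` binds `client` to the value
    # returned by __aenter__, so the fake client must be attached there.
    mock_cls.return_value.__aenter__.return_value = mock_client

    async with mock_cls() as client:
        assert client is mock_client


asyncio.run(demo())
```

This is also why the `model` fixture no longer depends on `mistral_client`: since `__init__` no longer constructs a client, the patch is only needed in tests that actually exercise a request.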

