diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py
index 0997637fd..2726dd348 100644
--- a/src/strands/models/mistral.py
+++ b/src/strands/models/mistral.py
@@ -6,7 +6,7 @@
 import base64
 import json
 import logging
-from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypeVar, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, TypeVar, Union
 
 from mistralai import Mistral
 from pydantic import BaseModel
@@ -472,7 +472,7 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
     @override
     def structured_output(
         self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
-    ) -> T:
+    ) -> Generator[dict[str, Union[T, Any]], None, None]:
         """Get structured output from the model.
 
         Args:
@@ -507,7 +507,8 @@ def structured_output(
                     arguments = json.loads(tool_call.function.arguments)
                 else:
                     arguments = tool_call.function.arguments
-                return output_model(**arguments)
+                yield {"output": output_model(**arguments)}
+                return
             except (json.JSONDecodeError, TypeError, ValueError) as e:
                 raise ValueError(f"Failed to parse tool call arguments into model: {e}") from e
 
diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py
index d52b6eb6c..1b1f02764 100644
--- a/tests/strands/models/test_mistral.py
+++ b/tests/strands/models/test_mistral.py
@@ -1,5 +1,6 @@
 import unittest.mock
 
+import pydantic
 import pytest
 
 import strands
@@ -58,6 +59,15 @@ def system_prompt():
     return "You are a helpful assistant"
 
 
+@pytest.fixture
+def test_output_model_cls():
+    class TestOutputModel(pydantic.BaseModel):
+        name: str
+        age: int
+
+    return TestOutputModel
+
+
 def test__init__model_configs(mistral_client, model_id, max_tokens):
     _ = mistral_client
 
@@ -440,14 +450,9 @@ def test_stream_other_error(mistral_client, model):
         list(model.stream({}))
 
 
-def test_structured_output_success(mistral_client, model):
-    from pydantic import BaseModel
-
-    class TestModel(BaseModel):
-        name: str
-        age: int
+def test_structured_output_success(mistral_client, model, test_output_model_cls):
+    messages = [{"role": "user", "content": [{"text": "Extract data"}]}]
 
-    # Mock successful response
     mock_response = unittest.mock.Mock()
     mock_response.choices = [unittest.mock.Mock()]
     mock_response.choices[0].message.tool_calls = [unittest.mock.Mock()]
@@ -455,20 +460,14 @@ class TestModel(BaseModel):
 
     mistral_client.chat.complete.return_value = mock_response
 
-    prompt = [{"role": "user", "content": [{"text": "Extract data"}]}]
-    result = model.structured_output(TestModel, prompt)
-
-    assert isinstance(result, TestModel)
-    assert result.name == "John"
-    assert result.age == 30
-
+    stream = model.structured_output(test_output_model_cls, messages)
 
-def test_structured_output_no_tool_calls(mistral_client, model):
-    from pydantic import BaseModel
+    tru_result = list(stream)[-1]
+    exp_result = {"output": test_output_model_cls(name="John", age=30)}
+    assert tru_result == exp_result
 
-    class TestModel(BaseModel):
-        name: str
 
+def test_structured_output_no_tool_calls(mistral_client, model, test_output_model_cls):
     mock_response = unittest.mock.Mock()
     mock_response.choices = [unittest.mock.Mock()]
     mock_response.choices[0].message.tool_calls = None
@@ -478,15 +477,11 @@ class TestModel(BaseModel):
     prompt = [{"role": "user", "content": [{"text": "Extract data"}]}]
 
     with pytest.raises(ValueError, match="No tool calls found in response"):
-        model.structured_output(TestModel, prompt)
-
+        stream = model.structured_output(test_output_model_cls, prompt)
+        next(stream)
 
-def test_structured_output_invalid_json(mistral_client, model):
-    from pydantic import BaseModel
-
-    class TestModel(BaseModel):
-        name: str
 
+def test_structured_output_invalid_json(mistral_client, model, test_output_model_cls):
     mock_response = unittest.mock.Mock()
     mock_response.choices = [unittest.mock.Mock()]
     mock_response.choices[0].message.tool_calls = [unittest.mock.Mock()]
@@ -497,4 +492,5 @@ class TestModel(BaseModel):
     prompt = [{"role": "user", "content": [{"text": "Extract data"}]}]
 
     with pytest.raises(ValueError, match="Failed to parse tool call arguments into model"):
-        model.structured_output(TestModel, prompt)
+        stream = model.structured_output(test_output_model_cls, prompt)
+        next(stream)