From 35f72addd45839a4e681f46e911ea05a3f0c21e7 Mon Sep 17 00:00:00 2001 From: siddhantwaghjale Date: Wed, 25 Jun 2025 01:00:49 -0700 Subject: [PATCH 1/3] feat: add Mistral model support to strands --- pyproject.toml | 9 +- src/strands/models/mistral.py | 498 +++++++++++++++++++++++++++ tests-integ/test_model_mistral.py | 157 +++++++++ tests/strands/models/test_mistral.py | 475 +++++++++++++++++++++++++ 4 files changed, 1136 insertions(+), 3 deletions(-) create mode 100644 src/strands/models/mistral.py create mode 100644 tests-integ/test_model_mistral.py create mode 100644 tests/strands/models/test_mistral.py diff --git a/pyproject.toml b/pyproject.toml index b17dcfb21..6244b89bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,9 @@ litellm = [ llamaapi = [ "llama-api-client>=0.1.0,<1.0.0", ] +mistral = [ + "mistralai>=1.8.2", +] ollama = [ "ollama>=0.4.8,<1.0.0", ] @@ -92,7 +95,7 @@ a2a = [ source = "vcs" [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel","mistral"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", @@ -116,7 +119,7 @@ lint-fix = [ ] [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel","mistral"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", @@ -132,7 +135,7 @@ extra-args = [ [tool.hatch.envs.dev] dev-mode = true -features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel"] +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel","mistral"] [tool.hatch.envs.a2a] dev-mode = true diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py new file mode 100644 index 000000000..024a39f24 --- /dev/null +++ b/src/strands/models/mistral.py @@ -0,0 +1,498 @@ +"""Mistral API model provider. + +- Docs: https://docs.mistral.ai/ +""" + +import base64 +import json +import logging +from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypeVar, Union + +from mistralai import Mistral +from pydantic import BaseModel +from typing_extensions import TypedDict, Unpack, override + +from ..types.content import ContentBlock, Messages +from ..types.exceptions import ModelThrottledException +from ..types.models import Model +from ..types.streaming import StopReason, StreamEvent +from ..types.tools import ToolResult, ToolSpec, ToolUse + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class MistralModel(Model): + """Mistral API model provider implementation. + + The implementation handles Mistral-specific features such as: + + - Chat and text completions + - Streaming responses + - Tool/function calling + - System prompts + """ + + class MistralConfig(TypedDict, total=False): + """Configuration parameters for Mistral models. + + Attributes: + model_id: Mistral model ID (e.g., "mistral-large-latest", "mistral-medium-latest"). + max_tokens: Maximum number of tokens to generate in the response. + temperature: Controls randomness in generation (0.0 to 1.0). + top_p: Controls diversity via nucleus sampling. + streaming: Whether to enable streaming responses. 
+ """ + + model_id: str + max_tokens: Optional[int] + temperature: Optional[float] + top_p: Optional[float] + streaming: Optional[bool] + + def __init__( + self, + api_key: Optional[str] = None, + *, + client_args: Optional[dict[str, Any]] = None, + **model_config: Unpack[MistralConfig], + ) -> None: + """Initialize provider instance. + + Args: + api_key: Mistral API key. If not provided, will use MISTRAL_API_KEY env var. + client_args: Additional arguments for the Mistral client. + **model_config: Configuration options for the Mistral model. + """ + self.config = MistralModel.MistralConfig(**model_config) + + # Set default streaming to True if not specified + if "streaming" not in self.config: + self.config["streaming"] = True + + logger.debug("config=<%s> | initializing", self.config) + + client_args = client_args or {} + if api_key: + client_args["api_key"] = api_key + + self.client = Mistral(**client_args) + + @override + def update_config(self, **model_config: Unpack[MistralConfig]) -> None: # type: ignore + """Update the Mistral Model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + self.config.update(model_config) + + @override + def get_config(self) -> MistralConfig: + """Get the Mistral model configuration. + + Returns: + The Mistral model configuration. + """ + return self.config + + def _format_request_message_content(self, content: ContentBlock) -> Union[str, Dict[str, Any]]: + """Format a Mistral content block. + + Args: + content: Message content. + + Returns: + Mistral formatted content. + + Raises: + TypeError: If the content block type cannot be converted to a Mistral-compatible format. + """ + if "text" in content: + return content["text"] + + if "image" in content: + image_data = content["image"] + + if "source" in image_data: + image_bytes = image_data["source"]["bytes"] + base64_data = base64.b64encode(image_bytes).decode("utf-8") + format_value = image_data.get("format", "jpeg") + media_type = f"image/{format_value}" + return {"type": "image_url", "image_url": f"data:{media_type};base64,{base64_data}"} + + # if "url" in image_data: + # return {"type": "image_url", "image_url": image_data["url"]} + + raise TypeError("content_type= | unsupported image format") + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]: + """Format a Mistral tool call. + + Args: + tool_use: Tool use requested by the model. + + Returns: + Mistral formatted tool call. + """ + return { + "function": { + "name": tool_use["name"], + "arguments": json.dumps(tool_use["input"]), + }, + "id": tool_use["toolUseId"], + "type": "function", + } + + def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]: + """Format a Mistral tool message. + + Args: + tool_result: Tool result collected from a tool execution. + + Returns: + Mistral formatted tool message. 
+ """ + content_parts: List[str] = [] + for content in tool_result["content"]: + if "json" in content: + content_parts.append(json.dumps(content["json"])) + elif "text" in content: + content_parts.append(content["text"]) + + return { + "role": "tool", + "name": tool_result["toolUseId"].split("_")[0] + if "_" in tool_result["toolUseId"] + else tool_result["toolUseId"], + "content": "\n".join(content_parts), + "tool_call_id": tool_result["toolUseId"], + } + + def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format a Mistral compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + + Returns: + A Mistral compatible messages array. + """ + formatted_messages: list[dict[str, Any]] = [] + + if system_prompt: + formatted_messages.append({"role": "system", "content": system_prompt}) + + for message in messages: + role = message["role"] + contents = message["content"] + + text_contents: List[str] = [] + tool_calls: List[Dict[str, Any]] = [] + tool_messages: List[Dict[str, Any]] = [] + + for content in contents: + if "text" in content: + formatted_content = self._format_request_message_content(content) + if isinstance(formatted_content, str): + text_contents.append(formatted_content) + elif "toolUse" in content: + tool_calls.append(self._format_request_message_tool_call(content["toolUse"])) + elif "toolResult" in content: + tool_messages.append(self._format_request_tool_message(content["toolResult"])) + + if text_contents or tool_calls: + formatted_message: Dict[str, Any] = { + "role": role, + "content": " ".join(text_contents) if text_contents else "", + } + + if tool_calls: + formatted_message["tool_calls"] = tool_calls + + formatted_messages.append(formatted_message) + + formatted_messages.extend(tool_messages) + + return formatted_messages + + @override + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> dict[str, Any]: + """Format a Mistral chat streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + A Mistral chat streaming request. + + Raises: + TypeError: If a message contains a content block type that cannot be converted to a Mistral-compatible + format. + """ + request: Dict[str, Any] = { + "model": self.config["model_id"], + "messages": self._format_request_messages(messages, system_prompt), + } + + if "max_tokens" in self.config: + request["max_tokens"] = self.config["max_tokens"] + if "temperature" in self.config: + request["temperature"] = self.config["temperature"] + if "top_p" in self.config: + request["top_p"] = self.config["top_p"] + + if tool_specs: + request["tools"] = [ + { + "type": "function", + "function": { + "name": tool_spec["name"], + "description": tool_spec["description"], + "parameters": tool_spec["inputSchema"]["json"], + }, + } + for tool_spec in tool_specs + ] + + return request + + @override + def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format the Mistral response events into standardized message chunks. + + Args: + event: A response event from the Mistral model. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. 
+ """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "text": + return {"contentBlockStart": {"start": {}}} + + tool_call = event["data"] + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": tool_call.function.name, + "toolUseId": tool_call.id, + } + } + } + } + + case "content_delta": + if event["data_type"] == "text": + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"]}}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + reason: StopReason + if event["data"] == "tool_calls": + reason = "tool_use" + elif event["data"] == "length": + reason = "max_tokens" + else: + reason = "end_turn" + + return {"messageStop": {"stopReason": reason}} + + case "metadata": + usage = event["data"] + return { + "metadata": { + "usage": { + "inputTokens": usage.prompt_tokens, + "outputTokens": usage.completion_tokens, + "totalTokens": usage.total_tokens, + }, + "metrics": { + "latencyMs": event.get("latency_ms", 0), + }, + }, + } + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type") + + def _handle_non_streaming_response(self, response: Any) -> Iterable[dict[str, Any]]: + """Handle non-streaming response from Mistral API. + + Args: + response: The non-streaming response from Mistral. + + Yields: + Formatted events that match the streaming format. + """ + yield {"chunk_type": "message_start"} + + content_started = False + + if response.choices and response.choices[0].message: + message = response.choices[0].message + + if hasattr(message, "content") and message.content: + if not content_started: + yield {"chunk_type": "content_start", "data_type": "text"} + content_started = True + + yield {"chunk_type": "content_delta", "data_type": "text", "data": message.content} + + yield {"chunk_type": "content_stop"} + + if hasattr(message, "tool_calls") and message.tool_calls: + for tool_call in message.tool_calls: + yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_call} + + if hasattr(tool_call.function, "arguments"): + yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_call.function.arguments} + + yield {"chunk_type": "content_stop"} + + finish_reason = response.choices[0].finish_reason if response.choices[0].finish_reason else "stop" + yield {"chunk_type": "message_stop", "data": finish_reason} + + if hasattr(response, "usage") and response.usage: + yield {"chunk_type": "metadata", "data": response.usage} + + @override + def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: + """Send the request to the Mistral model and get the streaming response. + + Args: + request: The formatted request to send to the Mistral model. + + Returns: + An iterable of response events from the Mistral model. + + Raises: + ModelThrottledException: When the model service is throttling requests. 
+ """ + try: + if self.config.get("streaming", True) is False: + # Use non-streaming API + response = self.client.chat.complete(**request) + yield from self._handle_non_streaming_response(response) + return + + # Use the streaming API + stream_response = self.client.chat.stream(**request) + + yield {"chunk_type": "message_start"} + + content_started = False + current_tool_calls: Dict[str, Dict[str, str]] = {} + accumulated_text = "" + + for chunk in stream_response: + if hasattr(chunk, "data") and hasattr(chunk.data, "choices") and chunk.data.choices: + choice = chunk.data.choices[0] + + if hasattr(choice, "delta"): + delta = choice.delta + + if hasattr(delta, "content") and delta.content: + if not content_started: + yield {"chunk_type": "content_start", "data_type": "text"} + content_started = True + + yield {"chunk_type": "content_delta", "data_type": "text", "data": delta.content} + accumulated_text += delta.content + + if hasattr(delta, "tool_calls") and delta.tool_calls: + for tool_call in delta.tool_calls: + tool_id = tool_call.id + + if tool_id not in current_tool_calls: + yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_call} + current_tool_calls[tool_id] = {"name": tool_call.function.name, "arguments": ""} + + if hasattr(tool_call.function, "arguments"): + current_tool_calls[tool_id]["arguments"] += tool_call.function.arguments + yield { + "chunk_type": "content_delta", + "data_type": "tool", + "data": tool_call.function.arguments, + } + + if hasattr(choice, "finish_reason") and choice.finish_reason: + if content_started: + yield {"chunk_type": "content_stop", "data_type": "text"} + + for _ in current_tool_calls: + yield {"chunk_type": "content_stop", "data_type": "tool"} + + yield {"chunk_type": "message_stop", "data": choice.finish_reason} + + if hasattr(chunk, "usage"): + yield {"chunk_type": "metadata", "data": chunk.usage} + + except Exception as e: + if "rate" in str(e).lower() or "429" in str(e): + raise ModelThrottledException(str(e)) from e + raise + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model: The output model to use for the agent. + prompt: The prompt messages to use for the agent. + callback_handler: Optional callback handler for processing events. + + Returns: + An instance of the output model with the generated data. + + Raises: + ValueError: If the response cannot be parsed into the output model. 
+ """ + tool_spec: ToolSpec = { + "name": f"extract_{output_model.__name__.lower()}", + "description": f"Extract structured data in the format of {output_model.__name__}", + "inputSchema": {"json": output_model.model_json_schema()}, + } + + formatted_request = self.format_request(messages=prompt, tool_specs=[tool_spec]) + + formatted_request["tool_choice"] = "any" + formatted_request["parallel_tool_calls"] = False + + response = self.client.chat.complete(**formatted_request) + + if response.choices and response.choices[0].message.tool_calls: + tool_call = response.choices[0].message.tool_calls[0] + try: + # Handle both string and dict arguments + if isinstance(tool_call.function.arguments, str): + arguments = json.loads(tool_call.function.arguments) + else: + arguments = tool_call.function.arguments + return output_model(**arguments) + except (json.JSONDecodeError, TypeError, ValueError) as e: + raise ValueError(f"Failed to parse tool call arguments into model: {e}") from e + + raise ValueError("No tool calls found in response") diff --git a/tests-integ/test_model_mistral.py b/tests-integ/test_model_mistral.py new file mode 100644 index 000000000..d52afb896 --- /dev/null +++ b/tests-integ/test_model_mistral.py @@ -0,0 +1,157 @@ +import os + +import pytest +from pydantic import BaseModel + +import strands +from strands import Agent +from strands.models.mistral import MistralModel + + +@pytest.fixture +def streaming_model(): + return MistralModel( + model_id="mistral-medium-latest", + api_key=os.getenv("MISTRAL_API_KEY"), + streaming=True, + temperature=0.7, + max_tokens=1000, + top_p=0.9, + ) + + +@pytest.fixture +def non_streaming_model(): + return MistralModel( + model_id="mistral-medium-latest", + api_key=os.getenv("MISTRAL_API_KEY"), + streaming=False, + temperature=0.7, + max_tokens=1000, + top_p=0.9, + ) + + +@pytest.fixture +def system_prompt(): + return "You are an AI assistant that provides helpful and accurate information." 
+ + +@pytest.fixture +def calculator_tool(): + @strands.tool + def calculator(expression: str) -> float: + """Calculate the result of a mathematical expression.""" + return eval(expression) + + return calculator + + +@pytest.fixture +def weather_tools(): + @strands.tool + def tool_time() -> str: + """Get the current time.""" + return "12:00" + + @strands.tool + def tool_weather() -> str: + """Get the current weather.""" + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def streaming_agent(streaming_model): + return Agent(model=streaming_model) + + +@pytest.fixture +def non_streaming_agent(non_streaming_model): + return Agent(model=non_streaming_model) + + +@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing") +def test_streaming_agent_basic(streaming_agent): + """Test basic streaming agent functionality.""" + result = streaming_agent("Tell me about Agentic AI in one sentence.") + + assert len(str(result)) > 0 + assert hasattr(result, "message") + assert "content" in result.message + + +@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing") +def test_non_streaming_agent_basic(non_streaming_agent): + """Test basic non-streaming agent functionality.""" + result = non_streaming_agent("Tell me about Agentic AI in one sentence.") + + assert len(str(result)) > 0 + assert hasattr(result, "message") + assert "content" in result.message + + +@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing") +def test_tool_use_streaming(streaming_model): + """Test tool use with streaming model.""" + + @strands.tool + def calculator(expression: str) -> float: + """Calculate the result of a mathematical expression.""" + return eval(expression) + + agent = Agent(model=streaming_model, tools=[calculator]) + result = agent("What is the square root of 1764") + + # Verify the result contains the calculation + text_content = str(result).lower() + assert "42" in text_content + + +@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing") +def test_tool_use_non_streaming(non_streaming_model): + """Test tool use with non-streaming model.""" + + @strands.tool + def calculator(expression: str) -> float: + """Calculate the result of a mathematical expression.""" + return eval(expression) + + agent = Agent(model=non_streaming_model, tools=[calculator], load_tools_from_directory=False) + result = agent("What is the square root of 1764") + + text_content = str(result).lower() + assert "42" in text_content + + +@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing") +def test_structured_output_streaming(streaming_model): + """Test structured output with streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=streaming_model) + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" + + +@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing") +def test_structured_output_non_streaming(non_streaming_model): + """Test structured output with non-streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=non_streaming_model) + result = 
agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py new file mode 100644 index 000000000..b194f0cce --- /dev/null +++ b/tests/strands/models/test_mistral.py @@ -0,0 +1,475 @@ +import unittest.mock + +import pytest + +import strands +from strands.models.mistral import MistralModel +from strands.types.exceptions import ModelThrottledException + + +@pytest.fixture +def mistral_client(): + with unittest.mock.patch.object(strands.models.mistral, "Mistral") as mock_client_cls: + yield mock_client_cls.return_value + + +@pytest.fixture +def model_id(): + return "mistral-large-latest" + + +@pytest.fixture +def max_tokens(): + return 100 + + +@pytest.fixture +def model(mistral_client, model_id, max_tokens): + _ = mistral_client + + return MistralModel(model_id=model_id, max_tokens=max_tokens) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "test"}]}] + + +@pytest.fixture +def system_prompt(): + return "You are a helpful assistant" + + +def test__init__model_configs(mistral_client, model_id, max_tokens): + _ = mistral_client + + model = MistralModel(model_id=model_id, max_tokens=max_tokens, temperature=0.7) + + tru_temperature = model.get_config().get("temperature") + exp_temperature = 0.7 + + assert tru_temperature == exp_temperature + + +def test_update_config(model, model_id): + model.update_config(model_id=model_id) + + tru_model_id = model.get_config().get("model_id") + exp_model_id = model_id + + assert tru_model_id == exp_model_id + + +def test_format_request_default(model, messages, model_id): + tru_request = model.format_request(messages) + exp_request = { + "model": model_id, + "messages": [{"role": "user", "content": "test"}], + "max_tokens": 100, + } + + assert tru_request == exp_request + + +def test_format_request_with_temperature(model, messages, model_id): + model.update_config(temperature=0.8) + + tru_request = model.format_request(messages) + exp_request = { + "model": model_id, + "messages": [{"role": "user", "content": "test"}], + "max_tokens": 100, + "temperature": 0.8, + } + + assert tru_request == exp_request + + +def test_format_request_with_system_prompt(model, messages, model_id, system_prompt): + tru_request = model.format_request(messages, system_prompt=system_prompt) + exp_request = { + "model": model_id, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": "test"}, + ], + "max_tokens": 100, + } + + assert tru_request == exp_request + + +def test_format_request_with_tool_use(model, model_id): + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "calc_123", + "name": "calculator", + "input": {"expression": "2+2"}, + }, + }, + ], + }, + ] + + tru_request = model.format_request(messages) + exp_request = { + "model": model_id, + "messages": [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "calculator", + "arguments": '{"expression": "2+2"}', + }, + "id": "calc_123", + "type": "function", + } + ], + } + ], + "max_tokens": 100, + } + + assert tru_request == exp_request + + +def test_format_request_with_tool_result(model, model_id): + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "calc_123", + "status": "success", + "content": [{"text": "4"}, {"json": {"result": 
4}}], + } + } + ], + } + ] + + tru_request = model.format_request(messages) + exp_request = { + "model": model_id, + "messages": [ + { + "role": "tool", + "name": "calc", + "content": '4\n{"result": 4}', + "tool_call_id": "calc_123", + } + ], + "max_tokens": 100, + } + + assert tru_request == exp_request + + +def test_format_request_with_tool_specs(model, messages, model_id): + tool_specs = [ + { + "name": "calculator", + "description": "Calculate mathematical expressions", + "inputSchema": { + "json": { + "type": "object", + "properties": {"expression": {"type": "string"}}, + "required": ["expression"], + } + }, + } + ] + + tru_request = model.format_request(messages, tool_specs) + exp_request = { + "model": model_id, + "messages": [{"role": "user", "content": "test"}], + "max_tokens": 100, + "tools": [ + { + "type": "function", + "function": { + "name": "calculator", + "description": "Calculate mathematical expressions", + "parameters": { + "type": "object", + "properties": {"expression": {"type": "string"}}, + "required": ["expression"], + }, + }, + } + ], + } + + assert tru_request == exp_request + + +def test_format_request_with_all_optional_params(model, messages, model_id): + model.update_config( + temperature=0.7, + top_p=0.9, + ) + + tool_specs = [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object"}}, + } + ] + + tru_request = model.format_request(messages, tool_specs) + exp_request = { + "model": model_id, + "messages": [{"role": "user", "content": "test"}], + "max_tokens": 100, + "temperature": 0.7, + "top_p": 0.9, + "tools": [ + { + "type": "function", + "function": { + "name": "test_tool", + "description": "A test tool", + "parameters": {"type": "object"}, + }, + } + ], + } + + assert tru_request == exp_request + + +def test_format_chunk_message_start(model): + event = {"chunk_type": "message_start"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"messageStart": {"role": "assistant"}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_start_text(model): + event = {"chunk_type": "content_start", "data_type": "text"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"contentBlockStart": {"start": {}}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_start_tool(model): + mock_tool_call = unittest.mock.Mock() + mock_tool_call.function.name = "calculator" + mock_tool_call.id = "calc_123" + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "calc_123"}}}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_delta_text(model): + event = {"chunk_type": "content_delta", "data_type": "text", "data": "Hello"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_delta_tool(model): + event = { + "chunk_type": "content_delta", + "data_type": "tool", + "data": '{"expression": "2+2"}', + } + + tru_chunk = model.format_chunk(event) + exp_chunk = {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_stop(model): + event = {"chunk_type": "content_stop"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"contentBlockStop": {}} + + assert tru_chunk == exp_chunk + + +def 
test_format_chunk_message_stop_end_turn(model): + event = {"chunk_type": "message_stop", "data": "stop"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"messageStop": {"stopReason": "end_turn"}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_message_stop_tool_use(model): + event = {"chunk_type": "message_stop", "data": "tool_calls"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"messageStop": {"stopReason": "tool_use"}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_message_stop_max_tokens(model): + event = {"chunk_type": "message_stop", "data": "length"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"messageStop": {"stopReason": "max_tokens"}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_metadata(model): + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + + event = { + "chunk_type": "metadata", + "data": mock_usage, + "latency_ms": 250, + } + + tru_chunk = model.format_chunk(event) + exp_chunk = { + "metadata": { + "usage": { + "inputTokens": 100, + "outputTokens": 50, + "totalTokens": 150, + }, + "metrics": { + "latencyMs": 250, + }, + }, + } + + assert tru_chunk == exp_chunk + + +def test_format_chunk_metadata_no_latency(model): + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + + event = { + "chunk_type": "metadata", + "data": mock_usage, + } + + tru_chunk = model.format_chunk(event) + exp_chunk = { + "metadata": { + "usage": { + "inputTokens": 100, + "outputTokens": 50, + "totalTokens": 150, + }, + "metrics": { + "latencyMs": 0, + }, + }, + } + + assert tru_chunk == exp_chunk + + +def test_format_chunk_unknown(model): + event = {"chunk_type": "unknown"} + + with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): + model.format_chunk(event) + + +def test_stream_rate_limit_error(mistral_client, model): + mistral_client.chat.stream.side_effect = Exception("rate limit exceeded (429)") + + with pytest.raises(ModelThrottledException, match="rate limit exceeded"): + list(model.stream({})) + + +def test_stream_other_error(mistral_client, model): + mistral_client.chat.stream.side_effect = Exception("some other error") + + with pytest.raises(Exception, match="some other error"): + list(model.stream({})) + + +def test_structured_output_success(mistral_client, model): + from pydantic import BaseModel + + class TestModel(BaseModel): + name: str + age: int + + # Mock successful response + mock_response = unittest.mock.Mock() + mock_response.choices = [unittest.mock.Mock()] + mock_response.choices[0].message.tool_calls = [unittest.mock.Mock()] + mock_response.choices[0].message.tool_calls[0].function.arguments = '{"name": "John", "age": 30}' + + mistral_client.chat.complete.return_value = mock_response + + prompt = [{"role": "user", "content": [{"text": "Extract data"}]}] + result = model.structured_output(TestModel, prompt) + + assert isinstance(result, TestModel) + assert result.name == "John" + assert result.age == 30 + + +def test_structured_output_no_tool_calls(mistral_client, model): + from pydantic import BaseModel + + class TestModel(BaseModel): + name: str + + mock_response = unittest.mock.Mock() + mock_response.choices = [unittest.mock.Mock()] + mock_response.choices[0].message.tool_calls = None + + mistral_client.chat.complete.return_value = mock_response + + prompt = [{"role": "user", "content": [{"text": "Extract data"}]}] + + 
with pytest.raises(ValueError, match="No tool calls found in response"): + model.structured_output(TestModel, prompt) + + +def test_structured_output_invalid_json(mistral_client, model): + from pydantic import BaseModel + + class TestModel(BaseModel): + name: str + + mock_response = unittest.mock.Mock() + mock_response.choices = [unittest.mock.Mock()] + mock_response.choices[0].message.tool_calls = [unittest.mock.Mock()] + mock_response.choices[0].message.tool_calls[0].function.arguments = "invalid json" + + mistral_client.chat.complete.return_value = mock_response + + prompt = [{"role": "user", "content": [{"text": "Extract data"}]}] + + with pytest.raises(ValueError, match="Failed to parse tool call arguments into model"): + model.structured_output(TestModel, prompt) From 89b3b2eaa72bd3db118ad4af945e075e073acb1f Mon Sep 17 00:00:00 2001 From: siddhantwaghjale Date: Thu, 26 Jun 2025 21:33:31 -0700 Subject: [PATCH 2/3] fix: testcase and better error handling --- src/strands/models/mistral.py | 34 ++++++--- tests-integ/test_model_mistral.py | 4 +- tests/strands/models/test_mistral.py | 105 +++++++++++++++++---------- 3 files changed, 92 insertions(+), 51 deletions(-) diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 024a39f24..627bc6ad4 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -42,14 +42,14 @@ class MistralConfig(TypedDict, total=False): max_tokens: Maximum number of tokens to generate in the response. temperature: Controls randomness in generation (0.0 to 1.0). top_p: Controls diversity via nucleus sampling. - streaming: Whether to enable streaming responses. + stream: Whether to enable streaming responses. """ model_id: str max_tokens: Optional[int] temperature: Optional[float] top_p: Optional[float] - streaming: Optional[bool] + stream: Optional[bool] def __init__( self, @@ -65,11 +65,28 @@ def __init__( client_args: Additional arguments for the Mistral client. **model_config: Configuration options for the Mistral model. """ + if "temperature" in model_config and model_config["temperature"] is not None: + temp = model_config["temperature"] + if not 0.0 <= temp <= 1.0: + raise ValueError(f"temperature must be between 0.0 and 1.0, got {temp}") + # Warn if temperature is above recommended range + if temp > 0.7: + logger.warning( + "temperature=%s is above the recommended range (0.0-0.7). 
" + "High values may produce unpredictable results.", + temp, + ) + + if "top_p" in model_config and model_config["top_p"] is not None: + top_p = model_config["top_p"] + if not 0.0 <= top_p <= 1.0: + raise ValueError(f"top_p must be between 0.0 and 1.0, got {top_p}") + self.config = MistralModel.MistralConfig(**model_config) - # Set default streaming to True if not specified - if "streaming" not in self.config: - self.config["streaming"] = True + # Set default stream to True if not specified + if "stream" not in self.config: + self.config["stream"] = True logger.debug("config=<%s> | initializing", self.config) @@ -122,9 +139,6 @@ def _format_request_message_content(self, content: ContentBlock) -> Union[str, D media_type = f"image/{format_value}" return {"type": "image_url", "image_url": f"data:{media_type};base64,{base64_data}"} - # if "url" in image_data: - # return {"type": "image_url", "image_url": image_data["url"]} - raise TypeError("content_type= | unsupported image format") raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") @@ -249,6 +263,8 @@ def format_request( request["temperature"] = self.config["temperature"] if "top_p" in self.config: request["top_p"] = self.config["top_p"] + if "stream" in self.config: + request["stream"] = self.config["stream"] if tool_specs: request["tools"] = [ @@ -390,7 +406,7 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: ModelThrottledException: When the model service is throttling requests. """ try: - if self.config.get("streaming", True) is False: + if self.config.get("stream", True) is False: # Use non-streaming API response = self.client.chat.complete(**request) yield from self._handle_non_streaming_response(response) diff --git a/tests-integ/test_model_mistral.py b/tests-integ/test_model_mistral.py index d52afb896..f2664f7fd 100644 --- a/tests-integ/test_model_mistral.py +++ b/tests-integ/test_model_mistral.py @@ -13,7 +13,7 @@ def streaming_model(): return MistralModel( model_id="mistral-medium-latest", api_key=os.getenv("MISTRAL_API_KEY"), - streaming=True, + stream=True, temperature=0.7, max_tokens=1000, top_p=0.9, @@ -25,7 +25,7 @@ def non_streaming_model(): return MistralModel( model_id="mistral-medium-latest", api_key=os.getenv("MISTRAL_API_KEY"), - streaming=False, + stream=False, temperature=0.7, max_tokens=1000, top_p=0.9, diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index b194f0cce..d52b6eb6c 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -35,6 +35,24 @@ def messages(): return [{"role": "user", "content": [{"text": "test"}]}] +@pytest.fixture +def tool_use_messages(): + return [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "calc_123", + "name": "calculator", + "input": {"expression": "2+2"}, + }, + }, + ], + } + ] + + @pytest.fixture def system_prompt(): return "You are a helpful assistant" @@ -45,48 +63,50 @@ def test__init__model_configs(mistral_client, model_id, max_tokens): model = MistralModel(model_id=model_id, max_tokens=max_tokens, temperature=0.7) - tru_temperature = model.get_config().get("temperature") + actual_temperature = model.get_config().get("temperature") exp_temperature = 0.7 - assert tru_temperature == exp_temperature + assert actual_temperature == exp_temperature def test_update_config(model, model_id): model.update_config(model_id=model_id) - tru_model_id = model.get_config().get("model_id") + actual_model_id = 
model.get_config().get("model_id") exp_model_id = model_id - assert tru_model_id == exp_model_id + assert actual_model_id == exp_model_id def test_format_request_default(model, messages, model_id): - tru_request = model.format_request(messages) + actual_request = model.format_request(messages) exp_request = { "model": model_id, "messages": [{"role": "user", "content": "test"}], "max_tokens": 100, + "stream": True, } - assert tru_request == exp_request + assert actual_request == exp_request def test_format_request_with_temperature(model, messages, model_id): model.update_config(temperature=0.8) - tru_request = model.format_request(messages) + actual_request = model.format_request(messages) exp_request = { "model": model_id, "messages": [{"role": "user", "content": "test"}], "max_tokens": 100, "temperature": 0.8, + "stream": True, } - assert tru_request == exp_request + assert actual_request == exp_request def test_format_request_with_system_prompt(model, messages, model_id, system_prompt): - tru_request = model.format_request(messages, system_prompt=system_prompt) + actual_request = model.format_request(messages, system_prompt=system_prompt) exp_request = { "model": model_id, "messages": [ @@ -94,9 +114,10 @@ def test_format_request_with_system_prompt(model, messages, model_id, system_pro {"role": "user", "content": "test"}, ], "max_tokens": 100, + "stream": True, } - assert tru_request == exp_request + assert actual_request == exp_request def test_format_request_with_tool_use(model, model_id): @@ -115,7 +136,7 @@ def test_format_request_with_tool_use(model, model_id): }, ] - tru_request = model.format_request(messages) + actual_request = model.format_request(messages) exp_request = { "model": model_id, "messages": [ @@ -135,9 +156,10 @@ def test_format_request_with_tool_use(model, model_id): } ], "max_tokens": 100, + "stream": True, } - assert tru_request == exp_request + assert actual_request == exp_request def test_format_request_with_tool_result(model, model_id): @@ -156,7 +178,7 @@ def test_format_request_with_tool_result(model, model_id): } ] - tru_request = model.format_request(messages) + actual_request = model.format_request(messages) exp_request = { "model": model_id, "messages": [ @@ -168,9 +190,10 @@ def test_format_request_with_tool_result(model, model_id): } ], "max_tokens": 100, + "stream": True, } - assert tru_request == exp_request + assert actual_request == exp_request def test_format_request_with_tool_specs(model, messages, model_id): @@ -188,11 +211,12 @@ def test_format_request_with_tool_specs(model, messages, model_id): } ] - tru_request = model.format_request(messages, tool_specs) + actual_request = model.format_request(messages, tool_specs) exp_request = { "model": model_id, "messages": [{"role": "user", "content": "test"}], "max_tokens": 100, + "stream": True, "tools": [ { "type": "function", @@ -209,7 +233,7 @@ def test_format_request_with_tool_specs(model, messages, model_id): ], } - assert tru_request == exp_request + assert actual_request == exp_request def test_format_request_with_all_optional_params(model, messages, model_id): @@ -226,13 +250,14 @@ def test_format_request_with_all_optional_params(model, messages, model_id): } ] - tru_request = model.format_request(messages, tool_specs) + actual_request = model.format_request(messages, tool_specs) exp_request = { "model": model_id, "messages": [{"role": "user", "content": "test"}], "max_tokens": 100, "temperature": 0.7, "top_p": 0.9, + "stream": True, "tools": [ { "type": "function", @@ -245,25 +270,25 @@ def 
test_format_request_with_all_optional_params(model, messages, model_id): ], } - assert tru_request == exp_request + assert actual_request == exp_request def test_format_chunk_message_start(model): event = {"chunk_type": "message_start"} - tru_chunk = model.format_chunk(event) + actual_chunk = model.format_chunk(event) exp_chunk = {"messageStart": {"role": "assistant"}} - assert tru_chunk == exp_chunk + assert actual_chunk == exp_chunk def test_format_chunk_content_start_text(model): event = {"chunk_type": "content_start", "data_type": "text"} - tru_chunk = model.format_chunk(event) + actual_chunk = model.format_chunk(event) exp_chunk = {"contentBlockStart": {"start": {}}} - assert tru_chunk == exp_chunk + assert actual_chunk == exp_chunk def test_format_chunk_content_start_tool(model): @@ -273,19 +298,19 @@ def test_format_chunk_content_start_tool(model): event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call} - tru_chunk = model.format_chunk(event) + actual_chunk = model.format_chunk(event) exp_chunk = {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "calc_123"}}}} - assert tru_chunk == exp_chunk + assert actual_chunk == exp_chunk def test_format_chunk_content_delta_text(model): event = {"chunk_type": "content_delta", "data_type": "text", "data": "Hello"} - tru_chunk = model.format_chunk(event) + actual_chunk = model.format_chunk(event) exp_chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}} - assert tru_chunk == exp_chunk + assert actual_chunk == exp_chunk def test_format_chunk_content_delta_tool(model): @@ -295,46 +320,46 @@ def test_format_chunk_content_delta_tool(model): "data": '{"expression": "2+2"}', } - tru_chunk = model.format_chunk(event) + actual_chunk = model.format_chunk(event) exp_chunk = {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}} - assert tru_chunk == exp_chunk + assert actual_chunk == exp_chunk def test_format_chunk_content_stop(model): event = {"chunk_type": "content_stop"} - tru_chunk = model.format_chunk(event) + actual_chunk = model.format_chunk(event) exp_chunk = {"contentBlockStop": {}} - assert tru_chunk == exp_chunk + assert actual_chunk == exp_chunk def test_format_chunk_message_stop_end_turn(model): event = {"chunk_type": "message_stop", "data": "stop"} - tru_chunk = model.format_chunk(event) + actual_chunk = model.format_chunk(event) exp_chunk = {"messageStop": {"stopReason": "end_turn"}} - assert tru_chunk == exp_chunk + assert actual_chunk == exp_chunk def test_format_chunk_message_stop_tool_use(model): event = {"chunk_type": "message_stop", "data": "tool_calls"} - tru_chunk = model.format_chunk(event) + actual_chunk = model.format_chunk(event) exp_chunk = {"messageStop": {"stopReason": "tool_use"}} - assert tru_chunk == exp_chunk + assert actual_chunk == exp_chunk def test_format_chunk_message_stop_max_tokens(model): event = {"chunk_type": "message_stop", "data": "length"} - tru_chunk = model.format_chunk(event) + actual_chunk = model.format_chunk(event) exp_chunk = {"messageStop": {"stopReason": "max_tokens"}} - assert tru_chunk == exp_chunk + assert actual_chunk == exp_chunk def test_format_chunk_metadata(model): @@ -349,7 +374,7 @@ def test_format_chunk_metadata(model): "latency_ms": 250, } - tru_chunk = model.format_chunk(event) + actual_chunk = model.format_chunk(event) exp_chunk = { "metadata": { "usage": { @@ -363,7 +388,7 @@ def test_format_chunk_metadata(model): }, } - assert tru_chunk == exp_chunk + assert actual_chunk == exp_chunk def 
test_format_chunk_metadata_no_latency(model): @@ -377,7 +402,7 @@ def test_format_chunk_metadata_no_latency(model): "data": mock_usage, } - tru_chunk = model.format_chunk(event) + actual_chunk = model.format_chunk(event) exp_chunk = { "metadata": { "usage": { @@ -391,7 +416,7 @@ def test_format_chunk_metadata_no_latency(model): }, } - assert tru_chunk == exp_chunk + assert actual_chunk == exp_chunk def test_format_chunk_unknown(model): From fdefc7104c7cc82ee7d932b9970d5026b5897f6a Mon Sep 17 00:00:00 2001 From: siddhantwaghjale Date: Thu, 26 Jun 2025 22:13:09 -0700 Subject: [PATCH 3/3] fix: Minor doc update --- src/strands/models/mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 627bc6ad4..0997637fd 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -1,4 +1,4 @@ -"""Mistral API model provider. +"""Mistral AI model provider. - Docs: https://docs.mistral.ai/ """
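
Reviewer note: a minimal end-to-end usage sketch of the provider this series adds, written against the same Agent API the integration tests above exercise. It assumes the `mistral` extra is installed and a valid MISTRAL_API_KEY is in the environment; the model id and sampling values are illustrative, and the config key is `stream` (renamed from `streaming` in the second commit).

    import os

    from pydantic import BaseModel

    from strands import Agent
    from strands.models.mistral import MistralModel


    class Weather(BaseModel):
        time: str
        weather: str


    # Mirrors the integration-test setup: streaming provider with explicit sampling params.
    model = MistralModel(
        api_key=os.getenv("MISTRAL_API_KEY"),
        model_id="mistral-medium-latest",
        stream=True,
        max_tokens=1000,
        temperature=0.7,  # validated to be within 0.0-1.0 by the second commit
        top_p=0.9,
    )

    agent = Agent(model=model)

    # Plain chat completion, streamed through format_request/format_chunk.
    print(agent("Tell me about Agentic AI in one sentence."))

    # Structured output routes through Mistral tool calling (see structured_output()).
    print(agent.structured_output(Weather, "The time is 12:00 and the weather is sunny"))

Setting stream=False on the model switches stream() to the non-streaming chat.complete path, which the _handle_non_streaming_response helper re-emits in the same chunk format.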