diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 4afc8e3dc..a95b0d027 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -370,6 +370,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + *, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 8c9716a4f..98c5c65b2 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -571,6 +571,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + *, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 17ededa14..005eed3df 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -114,6 +114,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + *, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 4e801026c..013cd2c7d 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -330,6 +330,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + *, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index 25d42a6c8..22a3a3873 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -513,6 +513,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + *, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 90cd1b5d8..b6459d63f 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -397,6 +397,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + *, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: diff --git a/src/strands/models/model.py b/src/strands/models/model.py index 7a8b4d4cc..7f178660a 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -70,6 +70,7 @@ def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + *, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncIterable[StreamEvent]: diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index c29772215..574b24200 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -287,6 +287,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + *, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index a41d478ae..7af81be84 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -357,6 +357,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + *, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index f635acce2..d1447732e 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -292,6 +292,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + *, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index 07119a21a..a54fc44c3 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -355,6 +355,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + *, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index 175358578..4a9b80364 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -94,10 +94,36 @@ async def test_stream(model, messages, tool_specs, system_prompt, alist): @pytest.mark.asyncio -async def test_structured_output(model, alist): +async def test_structured_output(model, messages, system_prompt, alist): response = model.structured_output(Person, prompt=messages, system_prompt=system_prompt) events = await alist(response) tru_output = events[-1]["output"] exp_output = Person(name="test", age=20) assert tru_output == exp_output + + +@pytest.mark.asyncio +async def test_stream_without_tool_choice_parameter(messages, alist): + """Test that model implementations without tool_choice parameter are still valid.""" + class LegacyModel(SAModel): + def update_config(self, **model_config): + return model_config + + def get_config(self): + return + + async def structured_output(self, output_model, prompt=None, system_prompt=None, **kwargs): + yield {"output": output_model(name="test", age=20)} + + async def stream(self, messages, tool_specs=None, system_prompt=None): + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockDelta": {"delta": {"text": "Legacy model works"}}} + yield {"messageStop": {"stopReason": "end_turn"}} + + model = LegacyModel() + response = model.stream(messages) + events = await alist(response) + + assert len(events) == 3 + assert events[1]["contentBlockDelta"]["delta"]["text"] == "Legacy model works"