diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 6bcc1359e..17ededa14 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -204,6 +204,9 @@ async def structured_output( Yields: Model events with the last being the structured output. """ + if not supports_response_schema(self.get_config()["model_id"]): + raise ValueError("Model does not support response_format") + response = await litellm.acompletion( **self.client_args, model=self.get_config()["model_id"], @@ -211,8 +214,6 @@ async def structured_output( response_format=output_model, ) - if not supports_response_schema(self.get_config()["model_id"]): - raise ValueError("Model does not support response_format") if len(response.choices) > 1: raise ValueError("Multiple choices found in the response.") diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index f345ba003..bc81fc819 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -289,6 +289,18 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c assert tru_result == exp_result +@pytest.mark.asyncio +async def test_structured_output_unsupported_model(litellm_acompletion, model, test_output_model_cls): + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + + with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=False): + with pytest.raises(ValueError, match="Model does not support response_format"): + stream = model.structured_output(test_output_model_cls, messages) + await stream.__anext__() + + litellm_acompletion.assert_not_called() + + def test_config_validation_warns_on_unknown_keys(litellm_acompletion, captured_warnings): """Test that unknown config keys emit a warning.""" LiteLLMModel(client_args={"api_key": "test"}, model_id="test-model", invalid_param="test") diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index efdd6a5ed..6cfdd3038 100644 --- a/tests_integ/models/test_model_litellm.py +++ b/tests_integ/models/test_model_litellm.py @@ -34,8 +34,8 @@ def weather(): class Weather(pydantic.BaseModel): """Extracts the time and weather from the user's message with the exact strings.""" - time: str - weather: str + time: str = pydantic.Field(description="The time in HH:MM format (e.g., '12:00', '09:30')") + weather: str = pydantic.Field(description="The weather condition (e.g., 'sunny', 'rainy', 'cloudy')") return Weather(time="12:00", weather="sunny") @@ -43,16 +43,22 @@ class Weather(pydantic.BaseModel): @pytest.fixture def yellow_color(): class Color(pydantic.BaseModel): - """Describes a color.""" + """Describes a color with its basic name. - name: str + Used to extract and normalize color names from text or images. + The color name should be a simple, common color like 'red', 'blue', 'yellow', etc. + """ - @pydantic.field_validator("name", mode="after") + simple_color_name: str = pydantic.Field( + description="The basic color name (e.g., 'red', 'blue', 'yellow', 'green', 'orange', 'purple')" + ) + + @pydantic.field_validator("simple_color_name", mode="after") @classmethod def lower(_, value): return value.lower() - return Color(name="yellow") + return Color(simple_color_name="yellow") def test_agent_invoke(agent):