Skip to content

Commit 8805021

Browse files
fix(test): litellm structured_output test with more descriptive model (#871)
1 parent 4b29edc commit 8805021

File tree

3 files changed

+27
-8
lines changed

3 files changed

+27
-8
lines changed

src/strands/models/litellm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,15 +204,16 @@ async def structured_output(
204204
Yields:
205205
Model events with the last being the structured output.
206206
"""
207+
if not supports_response_schema(self.get_config()["model_id"]):
208+
raise ValueError("Model does not support response_format")
209+
207210
response = await litellm.acompletion(
208211
**self.client_args,
209212
model=self.get_config()["model_id"],
210213
messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
211214
response_format=output_model,
212215
)
213216

214-
if not supports_response_schema(self.get_config()["model_id"]):
215-
raise ValueError("Model does not support response_format")
216217
if len(response.choices) > 1:
217218
raise ValueError("Multiple choices found in the response.")
218219

tests/strands/models/test_litellm.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,18 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c
289289
assert tru_result == exp_result
290290

291291

292+
@pytest.mark.asyncio
293+
async def test_structured_output_unsupported_model(litellm_acompletion, model, test_output_model_cls):
294+
messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
295+
296+
with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=False):
297+
with pytest.raises(ValueError, match="Model does not support response_format"):
298+
stream = model.structured_output(test_output_model_cls, messages)
299+
await stream.__anext__()
300+
301+
litellm_acompletion.assert_not_called()
302+
303+
292304
def test_config_validation_warns_on_unknown_keys(litellm_acompletion, captured_warnings):
293305
"""Test that unknown config keys emit a warning."""
294306
LiteLLMModel(client_args={"api_key": "test"}, model_id="test-model", invalid_param="test")

tests_integ/models/test_model_litellm.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,25 +34,31 @@ def weather():
3434
class Weather(pydantic.BaseModel):
3535
"""Extracts the time and weather from the user's message with the exact strings."""
3636

37-
time: str
38-
weather: str
37+
time: str = pydantic.Field(description="The time in HH:MM format (e.g., '12:00', '09:30')")
38+
weather: str = pydantic.Field(description="The weather condition (e.g., 'sunny', 'rainy', 'cloudy')")
3939

4040
return Weather(time="12:00", weather="sunny")
4141

4242

4343
@pytest.fixture
4444
def yellow_color():
4545
class Color(pydantic.BaseModel):
46-
"""Describes a color."""
46+
"""Describes a color with its basic name.
4747
48-
name: str
48+
Used to extract and normalize color names from text or images.
49+
The color name should be a simple, common color like 'red', 'blue', 'yellow', etc.
50+
"""
4951

50-
@pydantic.field_validator("name", mode="after")
52+
simple_color_name: str = pydantic.Field(
53+
description="The basic color name (e.g., 'red', 'blue', 'yellow', 'green', 'orange', 'purple')"
54+
)
55+
56+
@pydantic.field_validator("simple_color_name", mode="after")
5157
@classmethod
5258
def lower(_, value):
5359
return value.lower()
5460

55-
return Color(name="yellow")
61+
return Color(simple_color_name="yellow")
5662

5763

5864
def test_agent_invoke(agent):

0 commit comments

Comments
 (0)