Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/strands/models/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,15 +204,16 @@ async def structured_output(
Yields:
Model events with the last being the structured output.
"""
if not supports_response_schema(self.get_config()["model_id"]):
raise ValueError("Model does not support response_format")

response = await litellm.acompletion(
**self.client_args,
model=self.get_config()["model_id"],
messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
response_format=output_model,
)

if not supports_response_schema(self.get_config()["model_id"]):
raise ValueError("Model does not support response_format")
if len(response.choices) > 1:
raise ValueError("Multiple choices found in the response.")

Expand Down
12 changes: 12 additions & 0 deletions tests/strands/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,18 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c
assert tru_result == exp_result


@pytest.mark.asyncio
async def test_structured_output_unsupported_model(litellm_acompletion, model, test_output_model_cls):
messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]

with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=False):
with pytest.raises(ValueError, match="Model does not support response_format"):
stream = model.structured_output(test_output_model_cls, messages)
await stream.__anext__()

litellm_acompletion.assert_not_called()


def test_config_validation_warns_on_unknown_keys(litellm_acompletion, captured_warnings):
"""Test that unknown config keys emit a warning."""
LiteLLMModel(client_args={"api_key": "test"}, model_id="test-model", invalid_param="test")
Expand Down
18 changes: 12 additions & 6 deletions tests_integ/models/test_model_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,31 @@ def weather():
class Weather(pydantic.BaseModel):
"""Extracts the time and weather from the user's message with the exact strings."""

time: str
weather: str
time: str = pydantic.Field(description="The time in HH:MM format (e.g., '12:00', '09:30')")
weather: str = pydantic.Field(description="The weather condition (e.g., 'sunny', 'rainy', 'cloudy')")

return Weather(time="12:00", weather="sunny")


@pytest.fixture
def yellow_color():
class Color(pydantic.BaseModel):
"""Describes a color."""
"""Describes a color with its basic name.

name: str
Used to extract and normalize color names from text or images.
The color name should be a simple, common color like 'red', 'blue', 'yellow', etc.
"""

@pydantic.field_validator("name", mode="after")
simple_color_name: str = pydantic.Field(
description="The basic color name (e.g., 'red', 'blue', 'yellow', 'green', 'orange', 'purple')"
)

@pydantic.field_validator("simple_color_name", mode="after")
@classmethod
def lower(_, value):
return value.lower()

return Color(name="yellow")
return Color(simple_color_name="yellow")


def test_agent_invoke(agent):
Expand Down
Loading