Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 3 additions & 20 deletions src/strands/tools/structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
"properties": {},
}

# Add title if present
if "title" in schema:
flattened["title"] = schema["title"]

# Add description from schema if present, or use model docstring
if "description" in schema and schema["description"]:
flattened["description"] = schema["description"]

# Process properties
required_props: list[str] = []
if "properties" not in schema and "$ref" in schema:
raise ValueError("Circular reference detected and not supported.")
if "properties" in schema:
required_props = []
for prop_name, prop_value in schema["properties"].items():
Expand Down Expand Up @@ -76,9 +76,6 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]:

if len(required_props) > 0:
flattened["required"] = required_props
else:
raise ValueError("Circular reference detected and not supported")

return flattened


Expand Down Expand Up @@ -325,21 +322,7 @@ def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) ->
continue

field_type = field.annotation

# Handle Optional types
is_optional = False
if (
field_type is not None
and hasattr(field_type, "__origin__")
and field_type.__origin__ is Union
and hasattr(field_type, "__args__")
):
# Look for Optional[BaseModel]
for arg in field_type.__args__:
if arg is type(None):
is_optional = True
elif isinstance(arg, type) and issubclass(arg, BaseModel):
field_type = arg
is_optional = not field.is_required()

# If this is a BaseModel field, expand its properties with full details
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
Expand Down
82 changes: 81 additions & 1 deletion tests/strands/tools/test_structured_output.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal, Optional
from typing import List, Literal, Optional

import pytest
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -157,6 +157,7 @@ def test_convert_pydantic_to_tool_spec_multiple_same_type():
"user2": {
"type": ["object", "null"],
"description": "The second user",
"title": "UserWithPlanet",
"properties": {
"name": {"description": "The name of the user", "title": "Name", "type": "string"},
"age": {
Expand Down Expand Up @@ -208,6 +209,85 @@ class NodeWithCircularRef(BaseModel):
convert_pydantic_to_tool_spec(NodeWithCircularRef)


def test_convert_pydantic_with_circular_required_dependency():
"""Test that the tool handles circular dependencies gracefully."""

class NodeWithCircularRef(BaseModel):
"""A node with a circular reference to itself."""

name: str = Field(description="The name of the node")
parent: "NodeWithCircularRef"

with pytest.raises(ValueError, match="Circular reference detected and not supported"):
convert_pydantic_to_tool_spec(NodeWithCircularRef)


def test_convert_pydantic_with_circular_optional_dependency():
"""Test that the tool handles circular dependencies gracefully."""

class NodeWithCircularRef(BaseModel):
"""A node with a circular reference to itself."""

name: str = Field(description="The name of the node")
parent: Optional["NodeWithCircularRef"] = None

with pytest.raises(ValueError, match="Circular reference detected and not supported"):
convert_pydantic_to_tool_spec(NodeWithCircularRef)


def test_convert_pydantic_with_circular_optional_dependenc_not_using_optional_typing():
"""Test that the tool handles circular dependencies gracefully."""

class NodeWithCircularRef(BaseModel):
"""A node with a circular reference to itself."""

name: str = Field(description="The name of the node")
parent: "NodeWithCircularRef" = None

with pytest.raises(ValueError, match="Circular reference detected and not supported"):
convert_pydantic_to_tool_spec(NodeWithCircularRef)


def test_conversion_works_with_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501
class Family(BaseModel):
ages: List[str] = Field(default_factory=list)
names: List[str] = Field(default_factory=list)

converted_output = convert_pydantic_to_tool_spec(Family)
expected_output = {
"name": "Family",
"description": "Family structured output tool",
"inputSchema": {
"json": {
"type": "object",
"properties": {
"ages": {
"items": {"type": "string"},
"title": "Ages",
"type": ["array", "null"],
},
"names": {
"items": {"type": "string"},
"title": "Names",
"type": ["array", "null"],
},
},
"title": "Family",
}
},
}
assert converted_output == expected_output


def test_marks_fields_as_optional_for_model_w_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501
class Family(BaseModel):
ages: List[str] = Field(default_factory=list)
names: List[str] = Field(default_factory=list)

converted_output = convert_pydantic_to_tool_spec(Family)
assert "null" in converted_output["inputSchema"]["json"]["properties"]["ages"]["type"]


def test_convert_pydantic_with_custom_description():
"""Test that custom descriptions override model docstrings."""

Expand Down
Loading