diff --git a/src/strands/tools/structured_output.py b/src/strands/tools/structured_output.py index 6f2739d88..2c5922925 100644 --- a/src/strands/tools/structured_output.py +++ b/src/strands/tools/structured_output.py @@ -27,16 +27,16 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: "properties": {}, } - # Add title if present if "title" in schema: flattened["title"] = schema["title"] - # Add description from schema if present, or use model docstring if "description" in schema and schema["description"]: flattened["description"] = schema["description"] # Process properties required_props: list[str] = [] + if "properties" not in schema and "$ref" in schema: + raise ValueError("Circular reference detected and not supported.") if "properties" in schema: required_props = [] for prop_name, prop_value in schema["properties"].items(): @@ -76,9 +76,6 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: if len(required_props) > 0: flattened["required"] = required_props - else: - raise ValueError("Circular reference detected and not supported") - return flattened @@ -325,21 +322,7 @@ def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> continue field_type = field.annotation - - # Handle Optional types - is_optional = False - if ( - field_type is not None - and hasattr(field_type, "__origin__") - and field_type.__origin__ is Union - and hasattr(field_type, "__args__") - ): - # Look for Optional[BaseModel] - for arg in field_type.__args__: - if arg is type(None): - is_optional = True - elif isinstance(arg, type) and issubclass(arg, BaseModel): - field_type = arg + is_optional = not field.is_required() # If this is a BaseModel field, expand its properties with full details if isinstance(field_type, type) and issubclass(field_type, BaseModel): diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py index 97b68a34c..fe9b55334 100644 --- a/tests/strands/tools/test_structured_output.py +++ b/tests/strands/tools/test_structured_output.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import List, Literal, Optional import pytest from pydantic import BaseModel, Field @@ -157,6 +157,7 @@ def test_convert_pydantic_to_tool_spec_multiple_same_type(): "user2": { "type": ["object", "null"], "description": "The second user", + "title": "UserWithPlanet", "properties": { "name": {"description": "The name of the user", "title": "Name", "type": "string"}, "age": { @@ -208,6 +209,85 @@ class NodeWithCircularRef(BaseModel): convert_pydantic_to_tool_spec(NodeWithCircularRef) +def test_convert_pydantic_with_circular_required_dependency(): + """Test that the tool handles circular dependencies gracefully.""" + + class NodeWithCircularRef(BaseModel): + """A node with a circular reference to itself.""" + + name: str = Field(description="The name of the node") + parent: "NodeWithCircularRef" + + with pytest.raises(ValueError, match="Circular reference detected and not supported"): + convert_pydantic_to_tool_spec(NodeWithCircularRef) + + +def test_convert_pydantic_with_circular_optional_dependency(): + """Test that the tool handles circular dependencies gracefully.""" + + class NodeWithCircularRef(BaseModel): + """A node with a circular reference to itself.""" + + name: str = Field(description="The name of the node") + parent: Optional["NodeWithCircularRef"] = None + + with pytest.raises(ValueError, match="Circular reference detected and not supported"): + convert_pydantic_to_tool_spec(NodeWithCircularRef) + + +def test_convert_pydantic_with_circular_optional_dependenc_not_using_optional_typing(): + """Test that the tool handles circular dependencies gracefully.""" + + class NodeWithCircularRef(BaseModel): + """A node with a circular reference to itself.""" + + name: str = Field(description="The name of the node") + parent: "NodeWithCircularRef" = None + + with pytest.raises(ValueError, match="Circular reference detected and not supported"): + convert_pydantic_to_tool_spec(NodeWithCircularRef) + + +def test_conversion_works_with_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501 + class Family(BaseModel): + ages: List[str] = Field(default_factory=list) + names: List[str] = Field(default_factory=list) + + converted_output = convert_pydantic_to_tool_spec(Family) + expected_output = { + "name": "Family", + "description": "Family structured output tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "ages": { + "items": {"type": "string"}, + "title": "Ages", + "type": ["array", "null"], + }, + "names": { + "items": {"type": "string"}, + "title": "Names", + "type": ["array", "null"], + }, + }, + "title": "Family", + } + }, + } + assert converted_output == expected_output + + +def test_marks_fields_as_optional_for_model_w_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501 + class Family(BaseModel): + ages: List[str] = Field(default_factory=list) + names: List[str] = Field(default_factory=list) + + converted_output = convert_pydantic_to_tool_spec(Family) + assert "null" in converted_output["inputSchema"]["json"]["properties"]["ages"]["type"] + + def test_convert_pydantic_with_custom_description(): """Test that custom descriptions override model docstrings."""