From e183907319c01bc37fd13f40f852e93f5cb83bc6 Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Tue, 20 May 2025 10:53:01 -0500 Subject: [PATCH 01/18] feat: add structured output support using Pydantic models - Add method to Agent class for handling structured outputs - Create structured_output.py utility for converting Pydantic models to tool specs - Improve error handling when extracting model_id from configuration - Add integration tests to validate structured output functionality --- src/strands/agent/agent.py | 82 ++++- src/strands/tools/structured_output.py | 419 +++++++++++++++++++++++++ tests-integ/test_with_output.py | 49 +++ 3 files changed, 547 insertions(+), 3 deletions(-) create mode 100644 src/strands/tools/structured_output.py create mode 100644 tests-integ/test_with_output.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 6a9489809..fd6e63c15 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -16,10 +16,11 @@ import random from concurrent.futures import ThreadPoolExecutor from threading import Thread -from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Union +from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, Union from uuid import uuid4 from opentelemetry import trace +from pydantic import BaseModel from ..event_loop.event_loop import event_loop_cycle from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler @@ -328,7 +329,15 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None + # Safely get model_id if available + model_id = None + try: + config = getattr(self.model, "config", None) + if isinstance(config, dict): + model_id = config.get("model_id") + except Exception: + # Ignore any errors accessing model configuration + pass self.trace_span = self.tracer.start_agent_span( prompt=prompt, @@ -353,6 +362,73 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: # Re-raise the exception to preserve original behavior raise + def with_output(self, prompt: str, output_model: Type[BaseModel]) -> BaseModel: + """Set the output model for the agent. + + Args: + prompt: The prompt to use for the agent. + output_model: The output model to use for the agent. 
+
+        Returns: the loaded basemodel
+        """
+        from ..tools.structured_output import convert_pydantic_to_bedrock_tool
+
+        # Convert the pydantic basemodel to a tool spec
+        tool_spec = convert_pydantic_to_bedrock_tool(output_model)
+
+        # Create a dynamic tool name to avoid collisions
+        tool_name = f"generate_{output_model.__name__}"
+        tool_spec["toolSpec"]["name"] = tool_name
+
+        # Register the tool with the tool registry
+        # We need a special type of tool that just passes through the input
+        from ..tools.tools import PythonAgentTool
+
+        # Create a passthrough callback that just returns the input
+        # with the signature expected by PythonAgentTool
+        from ..types.tools import ToolResult, ToolUse
+
+        def output_callback(
+            tool_use: ToolUse, model: Any = None, messages: Optional[dict[str, Any]] = None, **kwargs: Any
+        ) -> ToolResult:
+            # Return the ToolResult explicitly typed
+            result: ToolResult = {
+                "toolUseId": tool_use["toolUseId"],
+                "status": "success",
+                "content": [{"text": "Output generated successfully"}],
+            }
+            return result
+
+        # Register the tool
+        from ..types.tools import ToolResult, ToolUse
+
+        tool = PythonAgentTool(tool_name=tool_name, tool_spec=tool_spec["toolSpec"], callback=output_callback)
+        self.tool_registry.register_tool(tool)
+
+        # Call the model with the tool and get the response
+        # This will run the model and invoke the tool
+        self(prompt)
+
+        # Extract the tool input from the message
+        # Find the first toolUse in the conversation history
+        tool_input = None
+        for message in self.messages:
+            if message.get("role") == "assistant":
+                for content in message.get("content", []):
+                    if isinstance(content, dict) and "toolUse" in content:
+                        tool_use = content["toolUse"]
+                        if tool_use.get("name") == tool_name:
+                            tool_input = tool_use.get("input", {})
+                            break
+            if tool_input:
+                break
+
+        # Create the output model from the tool input and return it
+        if not tool_input:
+            raise ValueError(f"Model did not generate a valid {output_model.__name__}")
+
+        return output_model(**tool_input)
+
     async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
         """Process a natural language prompt and yield events as an async iterator.
 
@@ -367,7 +443,7 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
 
     Returns:
         An async iterator that yields events. Each event is a dictionary containing
-        invocation about the current state of processing, such as:
+        information about the current state of processing, such as:
         - data: Text content being generated
         - complete: Whether this is the final chunk
         - current_tool_use: Information about tools being executed
diff --git a/src/strands/tools/structured_output.py b/src/strands/tools/structured_output.py
new file mode 100644
index 000000000..8a6d97668
--- /dev/null
+++ b/src/strands/tools/structured_output.py
@@ -0,0 +1,419 @@
+"""Tools for converting Pydantic models to Bedrock tools."""
+
+from typing import Any, Dict, Optional, Type, Union
+
+from pydantic import BaseModel
+
+
+def flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
+    """Flattens a JSON schema by removing $defs and resolving $ref references.
+
+    Handles required vs optional fields properly.
+ + Args: + schema: The JSON schema to flatten + + Returns: + Flattened JSON schema + """ + # Extract required fields list + required_fields = schema.get("required", []) + + # Initialize the flattened schema with basic properties + flattened = { + "type": schema.get("type", "object"), + "properties": {}, + } + + # Add title if present + if "title" in schema: + flattened["title"] = schema["title"] + + # Add description from schema if present, or use model docstring + if "description" in schema and schema["description"]: + flattened["description"] = schema["description"] + + # Process properties + if "properties" in schema: + required_props = [] + for prop_name, prop_value in schema["properties"].items(): + # Process the property and add to flattened properties + is_required = prop_name in required_fields + + # If the property already has nested properties (expanded), preserve them + if "properties" in prop_value: + # This is an expanded nested schema, preserve its structure + processed_prop = { + "type": prop_value.get("type", "object"), + "description": prop_value.get("description", ""), + "properties": {}, + } + + # Process each nested property + for nested_prop_name, nested_prop_value in prop_value["properties"].items(): + processed_prop["properties"][nested_prop_name] = nested_prop_value + + # Copy required fields if present + if "required" in prop_value: + processed_prop["required"] = prop_value["required"] + else: + # Process as normal + processed_prop = process_property(prop_value, schema.get("$defs", {}), is_required) + + flattened["properties"][prop_name] = processed_prop + + # Track which properties are actually required after processing + if is_required and "null" not in str(processed_prop.get("type", "")): + required_props.append(prop_name) + + # Add required fields if any (only those that are truly required after processing) + if required_props: + flattened["required"] = required_props + + return flattened + + +def process_property( + prop: Dict[str, Any], + defs: Dict[str, Any], + is_required: bool = False, + fully_expand: bool = True, +) -> Dict[str, Any]: + """Process a property in a schema, resolving any references. + + Args: + prop: The property to process + defs: The definitions dictionary for resolving references + is_required: Whether this property is required + fully_expand: Whether to fully expand nested properties + + Returns: + Processed property + """ + result = {} + is_nullable = False + + # Handle anyOf for optional fields (like Optional[Type]) + if "anyOf" in prop: + # Check if this is an Optional[...] 
case (one null, one type) + null_type = False + non_null_type = None + + for option in prop["anyOf"]: + if option.get("type") == "null": + null_type = True + is_nullable = True + elif "$ref" in option: + ref_path = option["$ref"].split("/")[-1] + if ref_path in defs: + non_null_type = process_schema_object(defs[ref_path], defs, fully_expand) + else: + non_null_type = option + + if null_type and non_null_type: + # For Optional fields, we mark as nullable but copy all properties from the non-null option + result = non_null_type.copy() if isinstance(non_null_type, dict) else {} + + # For type, ensure it includes "null" + if "type" in result and isinstance(result["type"], str): + result["type"] = [result["type"], "null"] + elif "type" in result and isinstance(result["type"], list) and "null" not in result["type"]: + result["type"].append("null") + elif "type" not in result: + # Default to object type if not specified + result["type"] = ["object", "null"] + + # Copy description if available in the property + if "description" in prop: + result["description"] = prop["description"] + + return result + + # Handle direct references + elif "$ref" in prop: + # Resolve reference + ref_path = prop["$ref"].split("/")[-1] + if ref_path in defs: + ref_dict = defs[ref_path] + # Process the referenced object to get a complete schema + result = process_schema_object(ref_dict, defs, fully_expand) + + # Copy description if available in the property (overrides ref description) + if "description" in prop: + result["description"] = prop["description"] + + # If not required, mark as nullable + if not is_required: + if "type" in result and isinstance(result["type"], str): + result["type"] = [result["type"], "null"] + elif "type" in result and isinstance(result["type"], list) and "null" not in result["type"]: + result["type"].append("null") + + return result + + # For regular fields, copy all properties + for key, value in prop.items(): + if key not in ["$ref", "anyOf"]: + if isinstance(value, dict): + result[key] = process_nested_dict(value, defs) + elif key == "type" and not is_required and not is_nullable: + # For non-required fields, ensure type is a list with "null" + if isinstance(value, str): + result[key] = [value, "null"] + elif isinstance(value, list) and "null" not in value: + result[key] = value + ["null"] + else: + result[key] = value + else: + result[key] = value + + return result + + +def process_schema_object( + schema_obj: Dict[str, Any], defs: Dict[str, Any], fully_expand: bool = True +) -> Dict[str, Any]: + """Process a schema object, typically from $defs, to resolve all nested properties. 
+ + Args: + schema_obj: The schema object to process + defs: The definitions dictionary for resolving references + fully_expand: Whether to fully expand nested properties + + Returns: + Processed schema object with all properties resolved + """ + result = {} + + # Copy basic attributes + for key, value in schema_obj.items(): + if key != "properties" and key != "required" and key != "$defs": + result[key] = value + + # Process properties if present + if "properties" in schema_obj: + result["properties"] = {} + required_props = [] + + # Get required fields list + required_fields = schema_obj.get("required", []) + + for prop_name, prop_value in schema_obj["properties"].items(): + # Process each property + is_required = prop_name in required_fields + processed = process_property(prop_value, defs, is_required, fully_expand) + result["properties"][prop_name] = processed + + # Track which properties are actually required after processing + if is_required and "null" not in str(processed.get("type", "")): + required_props.append(prop_name) + + # Add required fields if any + if required_props: + result["required"] = required_props + + return result + + +def process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, Any]: + """Recursively processes nested dictionaries and resolves $ref references. + + Args: + d: The dictionary to process + defs: The definitions dictionary for resolving references + + Returns: + Processed dictionary + """ + result: Dict[str, Any] = {} + + # Handle direct reference + if "$ref" in d: + ref_path = d["$ref"].split("/")[-1] + if ref_path in defs: + ref_dict = defs[ref_path] + # Recursively process the referenced object + return process_schema_object(ref_dict, defs) + + # Process each key-value pair + for key, value in d.items(): + if key == "$ref": + # Already handled above + continue + elif isinstance(value, dict): + result[key] = process_nested_dict(value, defs) + elif isinstance(value, list): + # Process lists (like for enum values) + result[key] = [process_nested_dict(item, defs) if isinstance(item, dict) else item for item in value] + else: + result[key] = value + + return result + + +def convert_pydantic_to_bedrock_tool( + model: Type[BaseModel], + description: Optional[str] = None, +) -> Dict[str, Any]: + """Converts a Pydantic model to a tool description for the Amazon Bedrock Converse API. + + Handles optional vs. required fields, resolves $refs, and uses docstrings. 
+ + Args: + model: The Pydantic model class to convert + description: Optional description of the tool's purpose + disable_normalize: If True, skip normalize_schema to preserve nested structure + + Returns: + Dict containing the Bedrock tool specification + """ + name = model.__name__ + + # Get the JSON schema + input_schema = model.model_json_schema() + + # Get model docstring for description if not provided + model_description = description + if not model_description and model.__doc__: + model_description = model.__doc__.strip() + + # Process all referenced models to ensure proper docstrings + # This step is important for gathering descriptions from referenced models + process_referenced_models(input_schema, model) + + # Now, let's fully expand the nested models with all their properties + expand_nested_properties(input_schema, model) + + # Flatten the schema + flattened_schema = flatten_schema(input_schema) + + # Apply normalize_schema to ensure it's in the correct format + # Unless disabled to preserve nested structure + final_schema = flattened_schema + + # Construct the tool specification + tool = { + "toolSpec": { + "name": name, + "description": model_description or f"{name} Tool", + "inputSchema": {"json": final_schema}, + } + } + + return tool + + +def expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> None: + """Expand the properties of nested models in the schema to include their full structure. + + This updates the schema in place. + + Args: + schema: The JSON schema to process + model: The Pydantic model class + """ + # First, process the properties at this level + if "properties" not in schema: + return + + # Create a modified copy of the properties to avoid modifying while iterating + for prop_name, prop_info in list(schema["properties"].items()): + field = model.model_fields.get(prop_name) + if not field: + continue + + field_type = field.annotation + + # Handle Optional types + is_optional = False + if ( + field_type is not None + and hasattr(field_type, "__origin__") + and field_type.__origin__ is Union + and hasattr(field_type, "__args__") + ): + # Look for Optional[BaseModel] + for arg in field_type.__args__: + if arg is type(None): + is_optional = True + elif isinstance(arg, type) and issubclass(arg, BaseModel): + field_type = arg + + # If this is a BaseModel field, expand its properties with full details + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + # Get the nested model's schema with all its properties + nested_model_schema = field_type.model_json_schema() + + # Create a properly expanded nested object + expanded_object = { + "type": ["object", "null"] if is_optional else "object", + "description": prop_info.get("description", field.description or f"The {prop_name}"), + "properties": {}, + } + + # Copy all properties from the nested schema + if "properties" in nested_model_schema: + expanded_object["properties"] = nested_model_schema["properties"] + + # Copy required fields + if "required" in nested_model_schema: + expanded_object["required"] = nested_model_schema["required"] + + # Replace the original property with this expanded version + schema["properties"][prop_name] = expanded_object + + +def process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) -> None: + """Process referenced models to ensure their docstrings are included. + + This updates the schema in place. 
+
+    Args:
+        schema: The JSON schema to process
+        model: The Pydantic model class
+    """
+    # Process $defs to add docstrings from the referenced models
+    if "$defs" in schema:
+        # Look through model fields to find referenced models
+        for _, field in model.model_fields.items():
+            field_type = field.annotation
+
+            # Handle Optional types - with null checks
+            if field_type is not None and hasattr(field_type, "__origin__"):
+                origin = field_type.__origin__
+                if origin is Union and hasattr(field_type, "__args__"):
+                    # Find the non-None type in the Union (for Optional fields)
+                    for arg in field_type.__args__:
+                        if arg is not type(None):
+                            field_type = arg
+                            break
+
+            # Check if this is a BaseModel subclass
+            if isinstance(field_type, type) and issubclass(field_type, BaseModel):
+                # Update $defs with this model's information
+                ref_name = field_type.__name__
+                if ref_name in schema.get("$defs", {}):
+                    ref_def = schema["$defs"][ref_name]
+
+                    # Add docstring as description if available
+                    if field_type.__doc__ and not ref_def.get("description"):
+                        ref_def["description"] = field_type.__doc__.strip()
+
+                    # Recursively process properties in the referenced model
+                    process_properties(ref_def, field_type)
+
+
+def process_properties(schema_def: Dict[str, Any], model: Type[BaseModel]) -> None:
+    """Process properties in a schema definition to add descriptions from field metadata.
+
+    Args:
+        schema_def: The schema definition to update
+        model: The model class that defines the schema
+    """
+    if "properties" in schema_def:
+        for prop_name, prop_info in schema_def["properties"].items():
+            field = model.model_fields.get(prop_name)
+
+            # Add field description if available and not already set
+            if field and field.description and not prop_info.get("description"):
+                prop_info["description"] = field.description
diff --git a/tests-integ/test_with_output.py b/tests-integ/test_with_output.py
new file mode 100644
index 000000000..606eafc3b
--- /dev/null
+++ b/tests-integ/test_with_output.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+"""
+Test script for structured output with Pydantic models
+"""
+
+# import logging
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+from strands import Agent
+
+# logging.getLogger("strands").setLevel(logging.DEBUG)
+# logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()])
+
+prompt = "Jane Smith is 28 years old and lives at 123 Main St, Boston, MA 02108."
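For context on what the converter produces here: feeding a model like the `Person` class defined just below through `convert_pydantic_to_bedrock_tool` yields a Converse-style tool specification. A rough sketch, abridged — the exact property entries follow Pydantic's `model_json_schema()` output, and the wrapping `toolSpec` shape is the one built in this commit:

    from pydantic import BaseModel, Field
    from strands.tools.structured_output import convert_pydantic_to_bedrock_tool

    class Person(BaseModel):
        """A simple model for testing structured data extraction."""

        name: str = Field(description="The person's full name")
        age: int = Field(description="The person's age in years", ge=18)

    tool = convert_pydantic_to_bedrock_tool(Person)
    # tool is roughly:
    # {"toolSpec": {"name": "Person",
    #               "description": "A simple model for testing structured data extraction.",
    #               "inputSchema": {"json": {"type": "object",
    #                                        "properties": {"name": {"type": "string", ...},
    #                                                       "age": {"type": "integer", ...}},
    #                                        "required": ["name", "age"]}}}}

`Agent.with_output` then renames this tool to `generate_Person` before registering it, which avoids collisions with user-registered tools.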
+ + +class Person(BaseModel): + """A simple model for testing structured data extraction.""" + + name: str = Field(description="The person's full name") + age: int = Field(description="The person's age in years", ge=18) + + +class Address(BaseModel): + """Address information for testing nested data structures.""" + + street: str = Field(description="Street address") + city: str = Field(description="City name") + zip_code: str = Field(description="Postal code", alias="zipCode") + + +class PersonWithAddress(BaseModel): + """A person with an address.""" + + person: Person = Field(description="The person's information") + address: Optional[Address] = Field(description="The person's address") + + +# Initialize agent with function tools +print("\n===== Input Prompt =====\n") +print(prompt) + +print("\n===== Running the agent =====\n") +result = Agent().with_output(prompt, PersonWithAddress) + +print("\n===== Output Result =====\n") +print(result.model_dump_json(indent=2)) From 03942ae2eea6a031ce97be93a74a10400924967d Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Tue, 20 May 2025 11:11:06 -0500 Subject: [PATCH 02/18] fix: import cleanups and unused vars --- src/strands/agent/agent.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index fd6e63c15..5df6ab4e3 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -30,11 +30,12 @@ from ..telemetry.tracer import get_tracer from ..tools.registry import ToolRegistry from ..tools.thread_pool_executor import ThreadPoolExecutorWrapper +from ..tools.tools import PythonAgentTool from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException from ..types.models import Model -from ..types.tools import ToolConfig +from ..types.tools import ToolConfig, ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult from .conversation_manager import ( @@ -382,14 +383,15 @@ def with_output(self, prompt: str, output_model: Type[BaseModel]) -> BaseModel: # Register the tool with the tool registry # We need a special type of tool that just passes through the input - from ..tools.tools import PythonAgentTool # Create a passthrough callback that just returns the input # with the signature expected by PythonAgentTool - from ..types.tools import ToolResult, ToolUse def output_callback( - tool_use: ToolUse, model: Any = None, messages: Optional[dict[str, Any]] = None, **kwargs: Any + tool_use: ToolUse, + model: Any = None, # noqa: ANN401 + messages: Optional[dict[str, Any]] = None, # noqa: ANN401 + **kwargs: Any, ) -> ToolResult: # Return the ToolResult explicitly typed result: ToolResult = { @@ -399,9 +401,6 @@ def output_callback( } return result - # Register the tool - from ..types.tools import ToolResult, ToolUse - tool = PythonAgentTool(tool_name=tool_name, tool_spec=tool_spec["toolSpec"], callback=output_callback) self.tool_registry.register_tool(tool) From 510def66a133727a84ef05ab33c0981b95689bb0 Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Thu, 5 Jun 2025 09:02:28 -0500 Subject: [PATCH 03/18] feat: wip adding `structured_output` methods --- src/strands/agent/agent.py | 139 +++++++++--------- src/strands/models/anthropic.py | 13 +- src/strands/models/bedrock.py | 13 +- src/strands/models/llamaapi.py | 13 +- src/strands/models/ollama.py | 13 +- src/strands/types/models/model.py | 20 ++- 
src/strands/types/models/openai.py | 13 +- ...th_output.py => test_structured_output.py} | 0 tests/strands/types/models/test_model.py | 17 +++ 9 files changed, 167 insertions(+), 74 deletions(-) rename tests-integ/{test_with_output.py => test_structured_output.py} (100%) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index bc2b5c71e..9f4c994c5 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -16,11 +16,10 @@ import random from concurrent.futures import ThreadPoolExecutor from threading import Thread -from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, Union +from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Union from uuid import uuid4 from opentelemetry import trace -from pydantic import BaseModel from ..event_loop.event_loop import event_loop_cycle from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler @@ -30,12 +29,11 @@ from ..telemetry.tracer import get_tracer from ..tools.registry import ToolRegistry from ..tools.thread_pool_executor import ThreadPoolExecutorWrapper -from ..tools.tools import PythonAgentTool from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException from ..types.models import Model -from ..types.tools import ToolConfig, ToolResult, ToolUse +from ..types.tools import ToolConfig from ..types.traces import AttributeValue from .agent_result import AgentResult from .conversation_manager import ( @@ -368,70 +366,75 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: # Re-raise the exception to preserve original behavior raise - def with_output(self, prompt: str, output_model: Type[BaseModel]) -> BaseModel: - """Set the output model for the agent. - - Args: - prompt: The prompt to use for the agent. - output_model: The output model to use for the agent. 
- - Returns: the loaded basemodel - """ - from ..tools.structured_output import convert_pydantic_to_bedrock_tool - - # Convert the pydantic basemodel to a tool spec - tool_spec = convert_pydantic_to_bedrock_tool(output_model) - - # Create a dynamic tool name to avoid collisions - tool_name = f"generate_{output_model.__name__}" - tool_spec["toolSpec"]["name"] = tool_name - - # Register the tool with the tool registry - # We need a special type of tool that just passes through the input - - # Create a passthrough callback that just returns the input - # with the signature expected by PythonAgentTool - - def output_callback( - tool_use: ToolUse, - model: Any = None, # noqa: ANN401 - messages: Optional[dict[str, Any]] = None, # noqa: ANN401 - **kwargs: Any, - ) -> ToolResult: - # Return the ToolResult explicitly typed - result: ToolResult = { - "toolUseId": tool_use["toolUseId"], - "status": "success", - "content": [{"text": "Output generated successfully"}], - } - return result - - tool = PythonAgentTool(tool_name=tool_name, tool_spec=tool_spec["toolSpec"], callback=output_callback) - self.tool_registry.register_tool(tool) - - # Call the model with the tool and get the response - # This will run the model and invoke the tool - self(prompt) - - # Extract the tool input from the message - # Find the first toolUse in the conversation history - tool_input = None - for message in self.messages: - if message.get("role") == "assistant": - for content in message.get("content", []): - if isinstance(content, dict) and "toolUse" in content: - tool_use = content["toolUse"] - if tool_use.get("name") == tool_name: - tool_input = tool_use.get("input", {}) - break - if tool_input: - break - - # Create the output model from the tool input and return it - if not tool_input: - raise ValueError(f"Model did not generate a valid {output_model.__name__}") - - return output_model(**tool_input) + # TODO: implement + # def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str]) -> BaseModel: + # """Get structured output from the Agent's current context. + + # Args: + # output_model(Type[BaseModel]): The output model the agent will use when responding. + # prompt(Optional[str]): The prompt to use for the agent. + + # Returns: + # The loaded basemodel. 
+ + # Raises: + # ValidationException: The response format from the large language model does not match the output_model + # """ + # from ..tools.structured_output import convert_pydantic_to_bedrock_tool + + # # Convert the pydantic basemodel to a tool spec + # tool_spec = convert_pydantic_to_bedrock_tool(output_model) + + # # Create a dynamic tool name to avoid collisions + # tool_name = f"generate_{output_model.__name__}" + # tool_spec["toolSpec"]["name"] = tool_name + + # # Register the tool with the tool registry + # # We need a special type of tool that just passes through the input + + # # Create a passthrough callback that just returns the input + # # with the signature expected by PythonAgentTool + + # def output_callback( + # tool_use: ToolUse, + # model: Any = None, # noqa: ANN401 + # messages: Optional[dict[str, Any]] = None, # noqa: ANN401 + # **kwargs: Any, + # ) -> ToolResult: + # # Return the ToolResult explicitly typed + # result: ToolResult = { + # "toolUseId": tool_use["toolUseId"], + # "status": "success", + # "content": [{"text": "Output generated successfully"}], + # } + # return result + + # tool = PythonAgentTool(tool_name=tool_name, tool_spec=tool_spec["toolSpec"], callback=output_callback) + # self.tool_registry.register_tool(tool) + + # # Call the model with the tool and get the response + # # This will run the model and invoke the tool + # self(prompt) + + # # Extract the tool input from the message + # # Find the first toolUse in the conversation history + # tool_input = None + # for message in self.messages: + # if message.get("role") == "assistant": + # for content in message.get("content", []): + # if isinstance(content, dict) and "toolUse" in content: + # tool_use = content["toolUse"] + # if tool_use.get("name") == tool_name: + # tool_input = tool_use.get("input", {}) + # break + # if tool_input: + # break + + # # Create the output model from the tool input and return it + # if not tool_input: + # raise ValueError(f"Model did not generate a valid {output_model.__name__}") + + # return output_model(**tool_input) async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 57394e2c1..9b301f925 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -7,9 +7,10 @@ import json import logging import mimetypes -from typing import Any, Iterable, Optional, TypedDict, cast +from typing import Any, Iterable, Optional, Type, TypedDict, cast import anthropic +from pydantic import BaseModel from typing_extensions import Required, Unpack, override from ..types.content import ContentBlock, Messages @@ -369,3 +370,13 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: raise ContextWindowOverflowException(str(error)) from error raise error + + @override + def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Optional[str]): The prompt to use for the agent. Defaults to None. 
+ """ + return output_model() diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 9bbcca7d0..4def93000 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -6,11 +6,12 @@ import json import logging import os -from typing import Any, Iterable, List, Literal, Optional, cast +from typing import Any, Iterable, List, Literal, Optional, Type, cast import boto3 from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override from ..types.content import Messages @@ -477,3 +478,13 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: return self._find_detected_and_blocked_policy(item) # Otherwise return False return False + + @override + def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Optional[str]): The prompt to use for the agent. Defaults to None. + """ + return output_model() diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 583db2f26..668aaa283 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -8,10 +8,11 @@ import json import logging import mimetypes -from typing import Any, Iterable, Optional, cast +from typing import Any, Iterable, Optional, Type, cast import llama_api_client from llama_api_client import LlamaAPIClient +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages @@ -384,3 +385,13 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: # we may have a metrics event here if metrics_event: yield {"chunk_type": "metadata", "data": metrics_event} + + @override + def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Optional[str]): The prompt to use for the agent. Defaults to None. + """ + return output_model() diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 7ed12216a..28d9b4762 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -5,9 +5,10 @@ import json import logging -from typing import Any, Iterable, Optional, cast +from typing import Any, Iterable, Optional, Type, cast from ollama import Client as OllamaClient +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages @@ -310,3 +311,13 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: yield {"chunk_type": "content_stop", "data_type": "text"} yield {"chunk_type": "message_stop", "data": "tool_use" if tool_requested else event.done_reason} yield {"chunk_type": "metadata", "data": event} + + @override + def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Optional[str]): The prompt to use for the agent. Defaults to None. 
+ """ + return output_model() diff --git a/src/strands/types/models/model.py b/src/strands/types/models/model.py index 23e746023..885a0d112 100644 --- a/src/strands/types/models/model.py +++ b/src/strands/types/models/model.py @@ -2,7 +2,9 @@ import abc import logging -from typing import Any, Iterable, Optional +from typing import Any, Iterable, Optional, Type + +from pydantic import BaseModel from ..content import Messages from ..streaming import StreamEvent @@ -38,6 +40,22 @@ def get_config(self) -> Any: """ pass + @abc.abstractmethod + # pragma: no cover + def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Optional[str]): The prompt to use for the agent. Defaults to None. + + Returns: + The structured output as a serialized instance of the output model. + + Raises: + ValidationException: The response format from the model does not match the output_model + """ + @abc.abstractmethod # pragma: no cover def format_request( diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index 96f758d5c..ff6dff5fd 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -11,8 +11,9 @@ import json import logging import mimetypes -from typing import Any, Optional, cast +from typing import Any, Optional, Type, cast +from pydantic import BaseModel from typing_extensions import override from ..content import ContentBlock, Messages @@ -262,3 +263,13 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: case _: raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + + @override + def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Optional[str]): The prompt to use for the agent. Defaults to None. 
+ """ + return output_model() diff --git a/tests-integ/test_with_output.py b/tests-integ/test_structured_output.py similarity index 100% rename from tests-integ/test_with_output.py rename to tests-integ/test_structured_output.py diff --git a/tests/strands/types/models/test_model.py b/tests/strands/types/models/test_model.py index f2797fe5b..03690733a 100644 --- a/tests/strands/types/models/test_model.py +++ b/tests/strands/types/models/test_model.py @@ -1,8 +1,16 @@ +from typing import Type + import pytest +from pydantic import BaseModel from strands.types.models import Model as SAModel +class Person(BaseModel): + name: str + age: int + + class TestModel(SAModel): def update_config(self, **model_config): return model_config @@ -10,6 +18,9 @@ def update_config(self, **model_config): def get_config(self): return + def structured_output(self, output_model: Type[BaseModel]) -> BaseModel: + return output_model(name="test", age=20) + def format_request(self, messages, tool_specs, system_prompt): return { "messages": messages, @@ -79,3 +90,9 @@ def test_converse(model, messages, tool_specs, system_prompt): }, ] assert tru_events == exp_events + + +def test_structured_output(model): + response = model.structured_output(Person) + + assert response == Person(name="test", age=20) From c3ffbceccb064702d1e69cc734c6732a2ea862a3 Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Thu, 5 Jun 2025 13:53:34 -0500 Subject: [PATCH 04/18] feat: wip added structured output to bedrock and anthropic --- src/strands/agent/agent.py | 18 ++++++++++++++- src/strands/event_loop/event_loop.py | 2 +- src/strands/models/anthropic.py | 31 +++++++++++++++++++++----- src/strands/models/bedrock.py | 25 +++++++++++++++++++-- src/strands/models/llamaapi.py | 4 ++-- src/strands/models/ollama.py | 4 ++-- src/strands/tools/__init__.py | 2 ++ src/strands/tools/structured_output.py | 23 ++++++++----------- src/strands/types/models/model.py | 3 ++- src/strands/types/models/openai.py | 4 ++-- tests/strands/models/test_anthropic.py | 7 +++--- 11 files changed, 90 insertions(+), 33 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 9f4c994c5..d780ef89d 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -16,10 +16,11 @@ import random from concurrent.futures import ThreadPoolExecutor from threading import Thread -from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Union +from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, Union from uuid import uuid4 from opentelemetry import trace +from pydantic import BaseModel from ..event_loop.event_loop import event_loop_cycle from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler @@ -366,6 +367,21 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: # Re-raise the exception to preserve original behavior raise + def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel: + """Get structured output from the Agent's current context. + + Args: + output_model(Type[BaseModel]): The output model the agent will use when responding. + prompt(Optional[str]): The prompt to use for the agent. 
+ """ + messages = self.messages + # add the prompt as the last message + if prompt: + messages.append({"role": "user", "content": [{"text": prompt}]}) + + # get the structured output from the model + return self.model.structured_output(output_model, messages) + # TODO: implement # def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str]) -> BaseModel: # """Get structured output from the Agent's current context. diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 23d7bd0f3..3819ae4f5 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -13,6 +13,7 @@ from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, cast +from ..event_loop.streaming import stream_messages from ..telemetry.metrics import EventLoopMetrics, Trace from ..telemetry.tracer import get_tracer from ..tools.executor import run_tools, validate_and_prepare_tools @@ -24,7 +25,6 @@ from ..types.tools import ToolConfig, ToolHandler, ToolResult, ToolUse from .error_handler import handle_input_too_long_error, handle_throttling_error from .message_processor import clean_orphaned_empty_tool_uses -from .streaming import stream_messages logger = logging.getLogger(__name__) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 9b301f925..e53f2277c 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -13,6 +13,9 @@ from pydantic import BaseModel from typing_extensions import Required, Unpack, override +from ..event_loop.streaming import process_stream +from ..handlers.callback_handler import PrintingCallbackHandler +from ..tools import convert_pydantic_to_bedrock_tool from ..types.content import ContentBlock, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.models import Model @@ -340,7 +343,7 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: raise RuntimeError(f"event_type=<{event['type']} | unknown type") @override - def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: + def stream(self, request: dict[str, Any]) -> Iterable[StreamEvent]: """Send the request to the Anthropic model and get the streaming response. Args: @@ -357,10 +360,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: with self.client.messages.stream(**request) as stream: for event in stream: if event.type in AnthropicModel.EVENT_TYPES: - yield event.dict() + yield self.format_chunk(event.dict()) usage = event.message.usage # type: ignore - yield {"type": "metadata", "usage": usage.dict()} + yield self.format_chunk({"type": "metadata", "usage": usage.dict()}) except anthropic.RateLimitError as error: raise ModelThrottledException(str(error)) from error @@ -372,11 +375,29 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: raise error @override - def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel: + def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. prompt(Optional[str]): The prompt to use for the agent. Defaults to None. 
""" - return output_model() + tool_spec = convert_pydantic_to_bedrock_tool(output_model) + + response = self.stream(self.format_request(messages=prompt, tool_specs=[tool_spec])) + # process the stream and get the tool use input + results = process_stream(response, callback_handler=PrintingCallbackHandler(), messages=prompt) + + # if not stop reason toolUse + if ( + results[0] != "tool_use" + or "toolUse" not in results[1]["content"][0] + or results[1]["content"][0]["toolUse"]["name"] != tool_spec["name"] + ): + raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") + + # get the tool use input + tool_use_input = results[1]["content"][0]["toolUse"]["input"] + + # return the tool use input as a json object + return output_model(**json.loads(tool_use_input)) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4def93000..4f6be374e 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -14,6 +14,9 @@ from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override +from ..event_loop.streaming import process_stream +from ..handlers.callback_handler import PrintingCallbackHandler +from ..tools import convert_pydantic_to_bedrock_tool from ..types.content import Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.models import Model @@ -480,11 +483,29 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: return False @override - def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel: + def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. prompt(Optional[str]): The prompt to use for the agent. Defaults to None. """ - return output_model() + tool_spec = convert_pydantic_to_bedrock_tool(output_model) + + response = self.stream(self.format_request(messages=prompt, tool_specs=[tool_spec])) + # process the stream and get the tool use input + results = process_stream(response, callback_handler=PrintingCallbackHandler(), messages=prompt) + + # if not stop reason toolUse + if ( + results[0] != "tool_use" + or "toolUse" not in results[1]["content"][0] + or results[1]["content"][0]["toolUse"]["name"] != tool_spec["name"] + ): + raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") + + # get the tool use input + tool_use_input = results[1]["content"][0]["toolUse"]["input"] + + # return the tool use input as a json object + return output_model(**json.loads(tool_use_input)) diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 668aaa283..8ff072c12 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -387,11 +387,11 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: yield {"chunk_type": "metadata", "data": metrics_event} @override - def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel: + def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. - prompt(Optional[str]): The prompt to use for the agent. Defaults to None. + prompt(Messages): The prompt to use for the agent. 
""" return output_model() diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 28d9b4762..7e275798d 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -313,11 +313,11 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: yield {"chunk_type": "metadata", "data": event} @override - def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel: + def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. - prompt(Optional[str]): The prompt to use for the agent. Defaults to None. + prompt(Messages): The prompt to use for the agent. """ return output_model() diff --git a/src/strands/tools/__init__.py b/src/strands/tools/__init__.py index b3ee15669..97b1e8678 100644 --- a/src/strands/tools/__init__.py +++ b/src/strands/tools/__init__.py @@ -4,6 +4,7 @@ """ from .decorator import tool +from .structured_output import convert_pydantic_to_bedrock_tool from .thread_pool_executor import ThreadPoolExecutorWrapper from .tools import FunctionTool, InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec @@ -15,4 +16,5 @@ "normalize_schema", "normalize_tool_spec", "ThreadPoolExecutorWrapper", + "convert_pydantic_to_bedrock_tool", ] diff --git a/src/strands/tools/structured_output.py b/src/strands/tools/structured_output.py index 8a6d97668..35cef868c 100644 --- a/src/strands/tools/structured_output.py +++ b/src/strands/tools/structured_output.py @@ -4,6 +4,8 @@ from pydantic import BaseModel +from ..types.tools import ToolSpec + def flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: """Flattens a JSON schema by removing $defs and resolving $ref references. @@ -254,7 +256,7 @@ def process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, An def convert_pydantic_to_bedrock_tool( model: Type[BaseModel], description: Optional[str] = None, -) -> Dict[str, Any]: +) -> ToolSpec: """Converts a Pydantic model to a tool description for the Amazon Bedrock Converse API. Handles optional vs. required fields, resolves $refs, and uses docstrings. 
@@ -262,10 +264,9 @@ def convert_pydantic_to_bedrock_tool( Args: model: The Pydantic model class to convert description: Optional description of the tool's purpose - disable_normalize: If True, skip normalize_schema to preserve nested structure Returns: - Dict containing the Bedrock tool specification + ToolSpec: Dict containing the Bedrock tool specification """ name = model.__name__ @@ -287,20 +288,14 @@ def convert_pydantic_to_bedrock_tool( # Flatten the schema flattened_schema = flatten_schema(input_schema) - # Apply normalize_schema to ensure it's in the correct format - # Unless disabled to preserve nested structure final_schema = flattened_schema # Construct the tool specification - tool = { - "toolSpec": { - "name": name, - "description": model_description or f"{name} Tool", - "inputSchema": {"json": final_schema}, - } - } - - return tool + return ToolSpec( + name=name, + description=model_description or f"{name} structured output tool", + inputSchema={"json": final_schema}, + ) def expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> None: diff --git a/src/strands/types/models/model.py b/src/strands/types/models/model.py index 885a0d112..4f962ae92 100644 --- a/src/strands/types/models/model.py +++ b/src/strands/types/models/model.py @@ -42,7 +42,7 @@ def get_config(self) -> Any: @abc.abstractmethod # pragma: no cover - def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel: + def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel: """Get structured output from the model. Args: @@ -55,6 +55,7 @@ def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] Raises: ValidationException: The response format from the model does not match the output_model """ + pass @abc.abstractmethod # pragma: no cover diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index ff6dff5fd..49f27b024 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -265,11 +265,11 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") @override - def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel: + def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. - prompt(Optional[str]): The prompt to use for the agent. Defaults to None. + prompt(Messages): The prompt to use for the agent. 
""" return output_model() diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 9421650e8..f88c9e3fa 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -618,7 +618,8 @@ def test_stream(anthropic_client, model): mock_event_1 = unittest.mock.Mock(type="message_start", dict=lambda: {"type": "message_start"}) mock_event_2 = unittest.mock.Mock(type="unknown") mock_event_3 = unittest.mock.Mock( - type="metadata", message=unittest.mock.Mock(usage=unittest.mock.Mock(dict=lambda: {"input_tokens": 1})) + type="metadata", + message=unittest.mock.Mock(usage=unittest.mock.Mock(dict=lambda: {"input_tokens": 1, "output_tokens": 2})), ) mock_stream = unittest.mock.MagicMock() @@ -630,8 +631,8 @@ def test_stream(anthropic_client, model): tru_events = list(response) exp_events = [ - {"type": "message_start"}, - {"type": "metadata", "usage": {"input_tokens": 1}}, + {"messageStart": {"role": "assistant"}}, + {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}}, ] assert tru_events == exp_events From dce0a81c70b34a5480269098aea881d6fb2ecbe4 Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Sat, 7 Jun 2025 10:12:14 -0500 Subject: [PATCH 05/18] feat: litellm structured output and some integ tests --- src/strands/agent/agent.py | 3 +++ src/strands/models/bedrock.py | 26 +++++++++++++--------- src/strands/models/litellm.py | 37 +++++++++++++++++++++++++++++-- tests-integ/test_model_bedrock.py | 31 ++++++++++++++++++++++++++ tests-integ/test_model_litellm.py | 14 ++++++++++++ 5 files changed, 99 insertions(+), 12 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index d780ef89d..1816ba9f0 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -375,6 +375,9 @@ def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] prompt(Optional[str]): The prompt to use for the agent. """ messages = self.messages + if not messages and not prompt: + raise ValueError("No conversation history or prompt provided") + # add the prompt as the last message if prompt: messages.append({"role": "user", "content": [{"text": prompt}]}) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4f6be374e..928611e1f 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -496,16 +496,22 @@ def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> # process the stream and get the tool use input results = process_stream(response, callback_handler=PrintingCallbackHandler(), messages=prompt) - # if not stop reason toolUse - if ( - results[0] != "tool_use" - or "toolUse" not in results[1]["content"][0] - or results[1]["content"][0]["toolUse"]["name"] != tool_spec["name"] - ): + stop_reason, messages, _, _, _ = results + + if stop_reason != "tool_use": raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") - # get the tool use input - tool_use_input = results[1]["content"][0]["toolUse"]["input"] + content = messages["content"] + output_response: dict[str, Any] | None = None + for block in content: + # if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip. + # if the tool use name never matches, raise an error. 
+ if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]: + output_response = block["toolUse"]["input"] + else: + continue + + if output_response is None: + raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") - # return the tool use input as a json object - return output_model(**json.loads(tool_use_input)) + return output_model(**output_response) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 62f16d319..bc995b19a 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -3,13 +3,15 @@ - Docs: https://docs.litellm.ai/ """ +import json import logging -from typing import Any, Optional, TypedDict, cast +from typing import Any, Optional, Type, TypedDict, cast import litellm +from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import ContentBlock +from ..types.content import ContentBlock, Messages from .openai import OpenAIModel logger = logging.getLogger(__name__) @@ -97,3 +99,34 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] } return super().format_request_message_content(content) + + @override + def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Optional[str]): The prompt to use for the agent. Defaults to None. + """ + # The LiteLLM `Client` inits with Chat(). + # Chat() inits with self.completions + # completions() has a method `create()` which wraps the real completion API of Litellm + response = self.client.chat.completions.create( + model=self.get_config()["model_id"], + messages=super().format_request(prompt)["messages"], + response_format=output_model, + ) + + # Find the first choice with tool_calls + for choice in response.choices: + if choice.finish_reason == "tool_calls": + try: + # Parse the tool call content as JSON + tool_call_data = json.loads(choice.message.content) + # Instantiate the output model with the parsed data + return output_model(**tool_call_data) + except (json.JSONDecodeError, TypeError, ValueError) as e: + raise ValueError(f"Failed to parse or load content into model: {e}") from e + + # If no tool_calls found, raise an error + raise ValueError("No tool_calls found in response") diff --git a/tests-integ/test_model_bedrock.py b/tests-integ/test_model_bedrock.py index a6a29aa94..5378a9b20 100644 --- a/tests-integ/test_model_bedrock.py +++ b/tests-integ/test_model_bedrock.py @@ -1,4 +1,5 @@ import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -118,3 +119,33 @@ def calculator(expression: str) -> float: agent("What is 123 + 456?") assert tool_was_called + + +def test_structured_output_streaming(streaming_model): + """Test structured output with streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=streaming_model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" + + +def test_structured_output_non_streaming(non_streaming_model): + """Test structured output with non-streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=non_streaming_model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + 
assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_litellm.py b/tests-integ/test_model_litellm.py index f1afb61fa..94121f3bc 100644 --- a/tests-integ/test_model_litellm.py +++ b/tests-integ/test_model_litellm.py @@ -1,4 +1,5 @@ import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -33,3 +34,16 @@ def test_agent(agent): text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny"]) + + +def test_structured_output(model): + class Weather(BaseModel): + time: str + weather: str + + agent_no_tools = Agent(model=model) + + result = agent_no_tools.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" From 5262dfcb3222517bf6b445008ec229c869bce0af Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Sun, 8 Jun 2025 16:12:48 -0500 Subject: [PATCH 06/18] feat: all structured outputs working, tbd llama api --- src/strands/models/anthropic.py | 34 ++++++++++-------- src/strands/models/bedrock.py | 2 +- src/strands/models/litellm.py | 2 ++ src/strands/models/llamaapi.py | 5 ++- src/strands/models/ollama.py | 12 ++++++- src/strands/models/openai.py | 31 +++++++++++++++- tests-integ/test_model_anthropic.py | 14 ++++++++ tests-integ/test_model_ollama.py | 33 +++++++++++++++++ tests-integ/test_model_openai.py | 16 +++++++++ tests-integ/test_structured_output.py | 49 -------------------------- tests/strands/models/test_anthropic.py | 26 +++++++++++--- 11 files changed, 152 insertions(+), 72 deletions(-) create mode 100644 tests-integ/test_model_ollama.py delete mode 100644 tests-integ/test_structured_output.py diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index e53f2277c..8caa407c4 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -343,7 +343,7 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: raise RuntimeError(f"event_type=<{event['type']} | unknown type") @override - def stream(self, request: dict[str, Any]) -> Iterable[StreamEvent]: + def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: """Send the request to the Anthropic model and get the streaming response. 
Args:
@@ -360,10 +360,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[StreamEvent]:
             with self.client.messages.stream(**request) as stream:
                 for event in stream:
                     if event.type in AnthropicModel.EVENT_TYPES:
-                        yield self.format_chunk(event.dict())
+                        yield event.model_dump()
 
                 usage = event.message.usage  # type: ignore
-                yield self.format_chunk({"type": "metadata", "usage": usage.dict()})
+                yield {"type": "metadata", "usage": usage.model_dump()}
 
         except anthropic.RateLimitError as error:
             raise ModelThrottledException(str(error)) from error
@@ -384,20 +384,26 @@ def structured_output(self, output_model: Type[BaseModel], prompt: Messages) ->
         """
         tool_spec = convert_pydantic_to_bedrock_tool(output_model)
 
-        response = self.stream(self.format_request(messages=prompt, tool_specs=[tool_spec]))
+        response = self.converse(messages=prompt, tool_specs=[tool_spec])
         # process the stream and get the tool use input
         results = process_stream(response, callback_handler=PrintingCallbackHandler(), messages=prompt)
 
-        # if not stop reason toolUse
-        if (
-            results[0] != "tool_use"
-            or "toolUse" not in results[1]["content"][0]
-            or results[1]["content"][0]["toolUse"]["name"] != tool_spec["name"]
-        ):
+        stop_reason, messages, _, _, _ = results
+
+        if stop_reason != "tool_use":
             raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
 
-        # get the tool use input
-        tool_use_input = results[1]["content"][0]["toolUse"]["input"]
+        content = messages["content"]
+        output_response: dict[str, Any] | None = None
+        for block in content:
+            # Skip blocks that are not tool uses or whose name does not match the tool spec;
+            # if no block ever matches, output_response stays None and we raise below.
+            if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]:
+                output_response = block["toolUse"]["input"]
+            else:
+                continue
+
+        if output_response is None:
+            raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
 
-        # return the tool use input as a json object
-        return output_model(**json.loads(tool_use_input))
+        return output_model(**output_response)
diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index 928611e1f..f2060e1f1 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -492,7 +492,7 @@ def structured_output(self, output_model: Type[BaseModel], prompt: Messages) ->
         """
         tool_spec = convert_pydantic_to_bedrock_tool(output_model)
 
-        response = self.stream(self.format_request(messages=prompt, tool_specs=[tool_spec]))
+        response = self.converse(messages=prompt, tool_specs=[tool_spec])
         # process the stream and get the tool use input
         results = process_stream(response, callback_handler=PrintingCallbackHandler(), messages=prompt)
 
diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index bc995b19a..073a7b612 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -107,6 +107,8 @@ def structured_output(self, output_model: Type[BaseModel], prompt: Messages) ->
         Args:
             output_model(Type[BaseModel]): The output model to use for the agent.
             prompt(Optional[str]): The prompt to use for the agent. Defaults to None.
+
+        TODO: add checks: https://docs.litellm.ai/docs/completion/json_mode#check-model-support
         """
         # The LiteLLM `Client` inits with Chat().
# Chat() inits with self.completions diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 8ff072c12..76b72f6d5 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -393,5 +393,8 @@ def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> Args: output_model(Type[BaseModel]): The output model to use for the agent. prompt(Messages): The prompt to use for the agent. + + Raises: + NotImplementedError: Structured output is not currently supported for LlamaAPI models. """ - return output_model() + raise NotImplementedError("Structured output is not currently supported for LlamaAPI models") diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 7e275798d..da8331baa 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -320,4 +320,14 @@ def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> output_model(Type[BaseModel]): The output model to use for the agent. prompt(Messages): The prompt to use for the agent. """ - return output_model() + formatted_request = self.format_request(messages=prompt) + formatted_request["format"] = output_model.model_json_schema() + formatted_request["stream"] = False + response = self.client.chat(**formatted_request) + print(response) + + try: + content = response.message.content.strip() + return output_model.model_validate_json(content) + except Exception as e: + raise ValueError(f"Failed to parse or load content into model: {e}") from e diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 764cb8519..26db6a62d 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -4,11 +4,14 @@ """ import logging -from typing import Any, Iterable, Optional, Protocol, TypedDict, cast +from typing import Any, Iterable, Optional, Protocol, Type, TypedDict, cast import openai +from openai.types.chat.parsed_chat_completion import ParsedChatCompletion +from pydantic import BaseModel from typing_extensions import Unpack, override +from ..types.content import Messages from ..types.models import OpenAIModel as SAOpenAIModel logger = logging.getLogger(__name__) @@ -122,3 +125,29 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: _ = event yield {"chunk_type": "metadata", "data": event.usage} + + @override + def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt to use for the agent. 
+        """
+        response: ParsedChatCompletion = self.client.beta.chat.completions.parse(  # type: ignore
+            model=self.get_config()["model_id"],
+            messages=super().format_request(prompt)["messages"],
+            response_format=output_model,
+        )
+
+        parsed: BaseModel | None = None
+        # Find the first choice whose message was parsed into the output model
+        for choice in response.choices:
+            if isinstance(choice.message.parsed, output_model):
+                parsed = choice.message.parsed
+                break
+
+        if parsed:
+            return parsed
+        else:
+            raise ValueError("No valid parsed output was found in the OpenAI response.")
diff --git a/tests-integ/test_model_anthropic.py b/tests-integ/test_model_anthropic.py
index 1b0412c94..95bfceb56 100644
--- a/tests-integ/test_model_anthropic.py
+++ b/tests-integ/test_model_anthropic.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+from pydantic import BaseModel
 
 import strands
 from strands import Agent
@@ -47,3 +48,16 @@ def test_agent(agent):
 
     text = result.message["content"][0]["text"].lower()
     assert all(string in text for string in ["12:00", "sunny", "&"])
+
+
+@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
+def test_structured_output(model):
+    class Weather(BaseModel):
+        time: str
+        weather: str
+
+    agent = Agent(model=model)
+    result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
+    assert isinstance(result, Weather)
+    assert result.time == "12:00"
+    assert result.weather == "sunny"
diff --git a/tests-integ/test_model_ollama.py b/tests-integ/test_model_ollama.py
new file mode 100644
index 000000000..29ece1544
--- /dev/null
+++ b/tests-integ/test_model_ollama.py
@@ -0,0 +1,33 @@
+import pytest
+from pydantic import BaseModel
+
+from strands import Agent
+from strands.models.ollama import OllamaModel
+
+
+@pytest.fixture
+def model():
+    return OllamaModel(host="http://localhost:11434", model_id="llama3.1:8b")
+
+
+@pytest.fixture
+def agent(model):
+    return Agent(model=model)
+
+
+def test_agent(agent):
+    result = agent("Say 'hello world' with no other text")
+    assert result.message["content"][0]["text"].lower() == "hello world"
+
+
+def test_structured_output(agent):
+    class Weather(BaseModel):
+        """Extract the time and weather from the response with the exact strings."""
+
+        time: str
+        weather: str
+
+    result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
+    assert isinstance(result, Weather)
+    assert result.time == "12:00"
+    assert result.weather == "sunny"
diff --git a/tests-integ/test_model_openai.py b/tests-integ/test_model_openai.py
index c9046ad5d..2ebe3ac86 100644
--- a/tests-integ/test_model_openai.py
+++ b/tests-integ/test_model_openai.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+from pydantic import BaseModel
 
 import strands
 from strands import Agent
@@ -44,3 +45,18 @@ def test_agent(agent):
 
     text = result.message["content"][0]["text"].lower()
     assert all(string in text for string in ["12:00", "sunny"])
+
+
+def test_structured_output(model):
+    class Weather(BaseModel):
+        """Extracts the time and weather from the user's message with the exact strings."""
+
+        time: str
+        weather: str
+
+    agent = Agent(model=model)
+
+    result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
+    assert isinstance(result, Weather)
+    assert result.time == "12:00"
+    assert result.weather == "sunny"
diff --git a/tests-integ/test_structured_output.py b/tests-integ/test_structured_output.py
deleted file mode 100644
index 606eafc3b..000000000
--- 
a/tests-integ/test_structured_output.py +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for function-based tools -""" - -# import logging -from typing import Optional - -from pydantic import BaseModel, Field - -from strands import Agent - -# logging.getLogger("strands").setLevel(logging.DEBUG) -# logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()]) - -prompt = "Jane Smith is 28 years old and lives at 123 Main St, Boston, MA 02108." - - -class Person(BaseModel): - """A simple model for testing structured data extraction.""" - - name: str = Field(description="The person's full name") - age: int = Field(description="The person's age in years", ge=18) - - -class Address(BaseModel): - """Address information for testing nested data structures.""" - - street: str = Field(description="Street address") - city: str = Field(description="City name") - zip_code: str = Field(description="Postal code", alias="zipCode") - - -class PersonWithAddress(BaseModel): - """A person with an address.""" - - person: Person = Field(description="The person's information") - address: Optional[Address] = Field(description="The person's address") - - -# Initialize agent with function tools -print("\n===== Input Prompt =====\n") -print(prompt) - -print("\n===== Running the agent =====\n") -result = Agent().with_output(prompt, PersonWithAddress) - -print("\n===== Output Result =====\n") -print(result.model_dump_json(indent=2)) diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index f88c9e3fa..a0cfc4d4a 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -615,11 +615,24 @@ def test_format_chunk_unknown(model): def test_stream(anthropic_client, model): - mock_event_1 = unittest.mock.Mock(type="message_start", dict=lambda: {"type": "message_start"}) - mock_event_2 = unittest.mock.Mock(type="unknown") + mock_event_1 = unittest.mock.Mock( + type="message_start", + dict=lambda: {"type": "message_start"}, + model_dump=lambda: {"type": "message_start"}, + ) + mock_event_2 = unittest.mock.Mock( + type="unknown", + dict=lambda: {"type": "unknown"}, + model_dump=lambda: {"type": "unknown"}, + ) mock_event_3 = unittest.mock.Mock( type="metadata", - message=unittest.mock.Mock(usage=unittest.mock.Mock(dict=lambda: {"input_tokens": 1, "output_tokens": 2})), + message=unittest.mock.Mock( + usage=unittest.mock.Mock( + dict=lambda: {"input_tokens": 1, "output_tokens": 2}, + model_dump=lambda: {"input_tokens": 1, "output_tokens": 2}, + ) + ), ) mock_stream = unittest.mock.MagicMock() @@ -631,8 +644,11 @@ def test_stream(anthropic_client, model): tru_events = list(response) exp_events = [ - {"messageStart": {"role": "assistant"}}, - {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}}, + {"type": "message_start"}, + { + "type": "metadata", + "usage": {"input_tokens": 1, "output_tokens": 2}, + }, ] assert tru_events == exp_events From 23df2c6a191b37b0f951f0f8570d5c572ab971f0 Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Sun, 8 Jun 2025 16:28:06 -0500 Subject: [PATCH 07/18] feat: updated docstring --- src/strands/agent/agent.py | 12 ++++++++++-- src/strands/telemetry/tracer.py | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 1816ba9f0..d98090ce1 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -368,10 +368,18 
@@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
         raise
 
     def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel:
-        """Get structured output from the Agent's current context.
+        """Get structured output from the agent.
+
+        If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
+        If you don't pass in a prompt, it will use only the conversation history to respond.
+        If no conversation history exists and no prompt is provided, an error will be raised.
+
+        For smaller models, you may want to use the optional prompt to explicitly instruct the model to output the
+        structured data.
 
         Args:
-            output_model(Type[BaseModel]): The output model the agent will use when responding.
+            output_model(Type[BaseModel]): The output model (a JSON schema written as a Pydantic BaseModel)
+                that the agent will use when responding.
             prompt(Optional[str]): The prompt to use for the agent.
         """
         messages = self.messages
diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py
index 9f731996e..ead2e49cc 100644
--- a/src/strands/telemetry/tracer.py
+++ b/src/strands/telemetry/tracer.py
@@ -15,7 +15,7 @@
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
 
 # See https://github.com/open-telemetry/opentelemetry-python/issues/4615 for the type ignore
-from opentelemetry.sdk.resources import Resource  # type: ignore[attr-defined]
+from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor
 from opentelemetry.trace import StatusCode

From cc78b6f8b6ac5f8f51442ae52c4e874961a143a2 Mon Sep 17 00:00:00 2001
From: laithalsaadoon
Date: Sun, 8 Jun 2025 17:22:04 -0500
Subject: [PATCH 08/18] fix: otel ci dep issue

---
 src/strands/telemetry/tracer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py
index ead2e49cc..9f731996e 100644
--- a/src/strands/telemetry/tracer.py
+++ b/src/strands/telemetry/tracer.py
@@ -15,7 +15,7 @@
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
 
 # See https://github.com/open-telemetry/opentelemetry-python/issues/4615 for the type ignore
-from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.resources import Resource  # type: ignore[attr-defined]
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor
 from opentelemetry.trace import StatusCode

From e8ef60056d5ef5dc9de5f4dfc840bbd6b9535186 Mon Sep 17 00:00:00 2001
From: laithalsaadoon
Date: Mon, 9 Jun 2025 08:55:38 -0500
Subject: [PATCH 09/18] fix: remove unnecessary changes and comments

---
 src/strands/agent/agent.py                    | 72 +-------------------
 src/strands/event_loop/event_loop.py          |  2 +-
 tests/strands/tools/test_structured_output.py |  0
 3 files changed, 2 insertions(+), 72 deletions(-)
 create mode 100644 tests/strands/tools/test_structured_output.py

diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index d98090ce1..170fec39a 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -393,76 +393,6 @@ def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str]
         # get the structured output from the model
         return 
self.model.structured_output(output_model, messages) - # TODO: implement - # def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str]) -> BaseModel: - # """Get structured output from the Agent's current context. - - # Args: - # output_model(Type[BaseModel]): The output model the agent will use when responding. - # prompt(Optional[str]): The prompt to use for the agent. - - # Returns: - # The loaded basemodel. - - # Raises: - # ValidationException: The response format from the large language model does not match the output_model - # """ - # from ..tools.structured_output import convert_pydantic_to_bedrock_tool - - # # Convert the pydantic basemodel to a tool spec - # tool_spec = convert_pydantic_to_bedrock_tool(output_model) - - # # Create a dynamic tool name to avoid collisions - # tool_name = f"generate_{output_model.__name__}" - # tool_spec["toolSpec"]["name"] = tool_name - - # # Register the tool with the tool registry - # # We need a special type of tool that just passes through the input - - # # Create a passthrough callback that just returns the input - # # with the signature expected by PythonAgentTool - - # def output_callback( - # tool_use: ToolUse, - # model: Any = None, # noqa: ANN401 - # messages: Optional[dict[str, Any]] = None, # noqa: ANN401 - # **kwargs: Any, - # ) -> ToolResult: - # # Return the ToolResult explicitly typed - # result: ToolResult = { - # "toolUseId": tool_use["toolUseId"], - # "status": "success", - # "content": [{"text": "Output generated successfully"}], - # } - # return result - - # tool = PythonAgentTool(tool_name=tool_name, tool_spec=tool_spec["toolSpec"], callback=output_callback) - # self.tool_registry.register_tool(tool) - - # # Call the model with the tool and get the response - # # This will run the model and invoke the tool - # self(prompt) - - # # Extract the tool input from the message - # # Find the first toolUse in the conversation history - # tool_input = None - # for message in self.messages: - # if message.get("role") == "assistant": - # for content in message.get("content", []): - # if isinstance(content, dict) and "toolUse" in content: - # tool_use = content["toolUse"] - # if tool_use.get("name") == tool_name: - # tool_input = tool_use.get("input", {}) - # break - # if tool_input: - # break - - # # Create the output model from the tool input and return it - # if not tool_input: - # raise ValueError(f"Model did not generate a valid {output_model.__name__}") - - # return output_model(**tool_input) - async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -477,7 +407,7 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: Returns: An async iterator that yields events. 
Each event is a dictionary containing - invocation about the current state of processing, such as: + information about the current state of processing, such as: - data: Text content being generated - complete: Whether this is the final chunk - current_tool_use: Information about tools being executed diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 3819ae4f5..23d7bd0f3 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -13,7 +13,6 @@ from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, cast -from ..event_loop.streaming import stream_messages from ..telemetry.metrics import EventLoopMetrics, Trace from ..telemetry.tracer import get_tracer from ..tools.executor import run_tools, validate_and_prepare_tools @@ -25,6 +24,7 @@ from ..types.tools import ToolConfig, ToolHandler, ToolResult, ToolUse from .error_handler import handle_input_too_long_error, handle_throttling_error from .message_processor import clean_orphaned_empty_tool_uses +from .streaming import stream_messages logger = logging.getLogger(__name__) diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py new file mode 100644 index 000000000..e69de29bb From 6eeeaa80e938d8bde5ac7627691f9443d1cea83e Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Mon, 9 Jun 2025 09:00:41 -0500 Subject: [PATCH 10/18] feat: basic test WIP --- tests/strands/tools/test_structured_output.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py index e69de29bb..5ede0340b 100644 --- a/tests/strands/tools/test_structured_output.py +++ b/tests/strands/tools/test_structured_output.py @@ -0,0 +1,24 @@ +import pytest +from pydantic import BaseModel, Field + +from strands.tools.structured_output import convert_pydantic_to_bedrock_tool + + +class User(BaseModel): + """A user of the system.""" + + name: str = Field(description="The name of the user") + age: int = Field(description="The age of the user") + email: str = Field(description="The email of the user", default="") + + +@pytest.fixture +def user_model(): + return User + + +def test_convert_pydantic_to_bedrock_tool(user_model): + tool_spec = convert_pydantic_to_bedrock_tool(user_model) + + assert tool_spec is not None + print(tool_spec) From 51f1f1d8ecbeae2b22570ee06363f54857690eb8 Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Mon, 9 Jun 2025 09:28:17 -0500 Subject: [PATCH 11/18] feat: better test coverage --- tests/strands/agent/test_agent.py | 24 ++++ tests/strands/tools/test_structured_output.py | 126 +++++++++++++++++- 2 files changed, 147 insertions(+), 3 deletions(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 0ea20b642..254cdf4ca 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -7,6 +7,7 @@ from time import sleep import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -717,6 +718,29 @@ def test_agent_callback_handler_custom_handler_used(): assert agent.callback_handler is custom_handler +# mock the User(name='Jane Doe', age=30, email='jane@doe.com') +class User(BaseModel): + """A user of the system.""" + + name: str + age: int + email: str + + +def test_agent_method_structured_output(agent): + # Mock the structured_output method on the model + expected_user = User(name="Jane Doe", age=30, 
email="jane@doe.com") + agent.model.structured_output = unittest.mock.Mock(return_value=expected_user) + + prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + + result = agent.structured_output(User, prompt) + assert result == expected_user + + # Verify the model's structured_output was called with correct arguments + agent.model.structured_output.assert_called_once_with(User, [{"role": "user", "content": [{"text": prompt}]}]) + + @pytest.mark.asyncio async def test_stream_async_returns_all_events(mock_event_loop_cycle): agent = Agent() diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py index 5ede0340b..8520866c7 100644 --- a/tests/strands/tools/test_structured_output.py +++ b/tests/strands/tools/test_structured_output.py @@ -1,7 +1,10 @@ +from typing import Literal + import pytest from pydantic import BaseModel, Field from strands.tools.structured_output import convert_pydantic_to_bedrock_tool +from strands.types.tools import ToolSpec class User(BaseModel): @@ -12,13 +15,130 @@ class User(BaseModel): email: str = Field(description="The email of the user", default="") +class Employment(BaseModel): + """An employment of the user.""" + + company: str = Field(description="The company of the user") + title: Literal[ + "CEO", + "CTO", + "CFO", + "CMO", + "COO", + "VP", + "Director", + "Manager", + "Other", + ] = Field(description="The title of the user", default="Other") + + +# for a nested, more complex test +class UserWithEmployment(User): + """A user of the system with employment.""" + + employment: Employment = Field(description="The employment of the user") + + @pytest.fixture def user_model(): return User -def test_convert_pydantic_to_bedrock_tool(user_model): +@pytest.fixture +def user_with_employment_model(): + return UserWithEmployment + + +@pytest.fixture +def basic_user_tool_spec(): + return { + "name": "User", + "description": "A user of the system.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": {"description": "The age of the user", "title": "Age", "type": "integer"}, + "email": { + "default": "", + "description": "The email of the user", + "title": "Email", + "type": ["string", "null"], + }, + }, + "title": "User", + "description": "A user of the system.", + "required": ["name", "age"], + } + }, + } + + +@pytest.fixture +def complex_user_tool_spec_json(): + return { + "name": "UserWithEmployment", + "description": "A user of the system with employment.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": {"description": "The age of the user", "title": "Age", "type": "integer"}, + "email": { + "default": "", + "description": "The email of the user", + "title": "Email", + "type": ["string", "null"], + }, + "employment": { + "type": "object", + "description": "The employment of the user", + "properties": { + "company": {"description": "The company of the user", "title": "Company", "type": "string"}, + "title": { + "default": "Other", + "description": "The title of the user", + "enum": ["CEO", "CTO", "CFO", "CMO", "COO", "VP", "Director", "Manager", "Other"], + "title": "Title", + "type": "string", + }, + }, + "required": ["company"], + }, + }, + "title": "UserWithEmployment", + "description": "A user of the system with employment.", + "required": ["name", "age", "employment"], + } + }, 
+ } + + +@pytest.fixture +def user_with_employment_tool_spec(): + return { + "name": "UserWithEmployment", + "description": "A user of the system with employment.", + } + + +def test_convert_pydantic_to_bedrock_tool_basic( + user_model, + basic_user_tool_spec, +): tool_spec = convert_pydantic_to_bedrock_tool(user_model) - assert tool_spec is not None - print(tool_spec) + assert tool_spec == basic_user_tool_spec + assert ToolSpec(**tool_spec) == ToolSpec(**basic_user_tool_spec) + + +def test_convert_pydantic_to_bedrock_tool_complex( + user_with_employment_model, + complex_user_tool_spec_json, +): + tool_spec = convert_pydantic_to_bedrock_tool(user_with_employment_model) + + assert tool_spec == complex_user_tool_spec_json + assert ToolSpec(**tool_spec) == ToolSpec(**complex_user_tool_spec_json) From d5bef961bc3542311f59d4a77a83347515197c78 Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Mon, 9 Jun 2025 09:30:27 -0500 Subject: [PATCH 12/18] fix: remove unused fixture --- tests/strands/tools/test_structured_output.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py index 8520866c7..d3814223e 100644 --- a/tests/strands/tools/test_structured_output.py +++ b/tests/strands/tools/test_structured_output.py @@ -116,14 +116,6 @@ def complex_user_tool_spec_json(): } -@pytest.fixture -def user_with_employment_tool_spec(): - return { - "name": "UserWithEmployment", - "description": "A user of the system with employment.", - } - - def test_convert_pydantic_to_bedrock_tool_basic( user_model, basic_user_tool_spec, From c66fa326cd809ef67331dd347544b30171fa1dac Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Fri, 13 Jun 2025 17:07:18 -0500 Subject: [PATCH 13/18] fix: resolve some comments --- src/strands/agent/agent.py | 9 +++++--- src/strands/models/anthropic.py | 23 ++++++++++++------- src/strands/models/bedrock.py | 19 ++++++++++----- src/strands/models/litellm.py | 11 ++++++--- src/strands/models/llamaapi.py | 11 ++++++--- src/strands/models/ollama.py | 12 ++++++---- src/strands/models/openai.py | 13 +++++++---- src/strands/tools/__init__.py | 4 ++-- src/strands/tools/structured_output.py | 6 ++--- src/strands/types/models/model.py | 11 ++++++--- src/strands/types/models/openai.py | 9 ++++++-- tests/strands/agent/test_agent.py | 4 +++- tests/strands/tools/test_structured_output.py | 10 ++++---- 13 files changed, 95 insertions(+), 47 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 170fec39a..389d0559e 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -16,7 +16,7 @@ import random from concurrent.futures import ThreadPoolExecutor from threading import Thread -from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, Union +from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union from uuid import uuid4 from opentelemetry import trace @@ -44,6 +44,9 @@ logger = logging.getLogger(__name__) +# TypeVar for generic structured output +T = TypeVar("T", bound=BaseModel) + # Sentinel class and object to distinguish between explicit None and default parameter value class _DefaultCallbackHandlerSentinel: @@ -367,7 +370,7 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: # Re-raise the exception to preserve original behavior raise - def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel: + def 
structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
         """Get structured output from the agent.
 
         If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
@@ -391,7 +394,7 @@ def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str]
             messages.append({"role": "user", "content": [{"text": prompt}]})
 
         # get the structured output from the model
-        return self.model.structured_output(output_model, messages)
+        return self.model.structured_output(output_model, messages, self.callback_handler)
 
     async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
         """Process a natural language prompt and yield events as an async iterator.
diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py
index 8caa407c4..ab427e53d 100644
--- a/src/strands/models/anthropic.py
+++ b/src/strands/models/anthropic.py
@@ -7,7 +7,7 @@
 import json
 import logging
 import mimetypes
-from typing import Any, Iterable, Optional, Type, TypedDict, cast
+from typing import Any, Callable, Iterable, Optional, Type, TypedDict, TypeVar, cast
 
 import anthropic
 from pydantic import BaseModel
@@ -15,7 +15,7 @@
 
 from ..event_loop.streaming import process_stream
 from ..handlers.callback_handler import PrintingCallbackHandler
-from ..tools import convert_pydantic_to_bedrock_tool
+from ..tools import convert_pydantic_to_tool_spec
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from ..types.models import Model
@@ -24,6 +24,8 @@
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T", bound=BaseModel)
+
 
 class AnthropicModel(Model):
     """Anthropic model provider implementation."""
@@ -375,23 +377,28 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
             raise error
 
     @override
-    def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel:
+    def structured_output(
+        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
+    ) -> T:
         """Get structured output from the model.
 
         Args:
             output_model(Type[BaseModel]): The output model to use for the agent.
-            prompt(Optional[str]): The prompt to use for the agent. Defaults to None.
+            prompt(Messages): The prompt messages to use for the agent.
+            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. 
""" - tool_spec = convert_pydantic_to_bedrock_tool(output_model) + tool_spec = convert_pydantic_to_tool_spec(output_model) response = self.converse(messages=prompt, tool_specs=[tool_spec]) # process the stream and get the tool use input - results = process_stream(response, callback_handler=PrintingCallbackHandler(), messages=prompt) + results = process_stream( + response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt + ) stop_reason, messages, _, _, _ = results if stop_reason != "tool_use": - raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") + raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") content = messages["content"] output_response: dict[str, Any] | None = None @@ -404,6 +411,6 @@ def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> continue if output_response is None: - raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") + raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") return output_model(**output_response) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index f2060e1f1..ac1c4a381 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -6,7 +6,7 @@ import json import logging import os -from typing import Any, Iterable, List, Literal, Optional, Type, cast +from typing import Any, Callable, Iterable, List, Literal, Optional, Type, TypeVar, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -16,7 +16,7 @@ from ..event_loop.streaming import process_stream from ..handlers.callback_handler import PrintingCallbackHandler -from ..tools import convert_pydantic_to_bedrock_tool +from ..tools import convert_pydantic_to_tool_spec from ..types.content import Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.models import Model @@ -33,6 +33,8 @@ "too many total text bytes", ] +T = TypeVar("T", bound=BaseModel) + class BedrockModel(Model): """AWS Bedrock model provider implementation. @@ -483,18 +485,23 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: return False @override - def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel: + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. - prompt(Optional[str]): The prompt to use for the agent. Defaults to None. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. 
""" - tool_spec = convert_pydantic_to_bedrock_tool(output_model) + tool_spec = convert_pydantic_to_tool_spec(output_model) response = self.converse(messages=prompt, tool_specs=[tool_spec]) # process the stream and get the tool use input - results = process_stream(response, callback_handler=PrintingCallbackHandler(), messages=prompt) + results = process_stream( + response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt + ) stop_reason, messages, _, _, _ = results diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 073a7b612..b91a7cbea 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -5,7 +5,7 @@ import json import logging -from typing import Any, Optional, Type, TypedDict, cast +from typing import Any, Callable, Optional, Type, TypedDict, TypeVar, cast import litellm from pydantic import BaseModel @@ -16,6 +16,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class LiteLLMModel(OpenAIModel): """LiteLLM model provider implementation.""" @@ -101,12 +103,15 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] return super().format_request_message_content(content) @override - def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel: + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. - prompt(Optional[str]): The prompt to use for the agent. Defaults to None. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. TODO: add checks: https://docs.litellm.ai/docs/completion/json_mode#check-model-support """ diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 76b72f6d5..46e5ec696 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -8,7 +8,7 @@ import json import logging import mimetypes -from typing import Any, Iterable, Optional, Type, cast +from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast import llama_api_client from llama_api_client import LlamaAPIClient @@ -23,6 +23,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class LlamaAPIModel(Model): """Llama API model provider implementation.""" @@ -387,12 +389,15 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: yield {"chunk_type": "metadata", "data": metrics_event} @override - def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel: + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. - prompt(Messages): The prompt to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. Raises: NotImplementedError: Structured output is not currently supported for LlamaAPI models. 
diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index da8331baa..b062fe14d 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -5,7 +5,7 @@ import json import logging -from typing import Any, Iterable, Optional, Type, cast +from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast from ollama import Client as OllamaClient from pydantic import BaseModel @@ -18,6 +18,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class OllamaModel(Model): """Ollama model provider implementation. @@ -313,18 +315,20 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: yield {"chunk_type": "metadata", "data": event} @override - def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel: + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. - prompt(Messages): The prompt to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. """ formatted_request = self.format_request(messages=prompt) formatted_request["format"] = output_model.model_json_schema() formatted_request["stream"] = False response = self.client.chat(**formatted_request) - print(response) try: content = response.message.content.strip() diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index c888cd806..6c0a76d39 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -4,7 +4,7 @@ """ import logging -from typing import Any, Iterable, Optional, Protocol, Type, TypedDict, cast +from typing import Any, Callable, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, cast import openai from openai.types.chat.parsed_chat_completion import ParsedChatCompletion @@ -16,6 +16,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class Client(Protocol): """Protocol defining the OpenAI-compatible interface for the underlying provider client.""" @@ -130,12 +132,15 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: yield {"chunk_type": "metadata", "data": event.usage} @override - def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel: + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. - prompt(Messages): The prompt to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. 
""" response: ParsedChatCompletion = self.client.beta.chat.completions.parse( # type: ignore model=self.get_config()["model_id"], @@ -143,7 +148,7 @@ def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> response_format=output_model, ) - parsed: BaseModel | None = None + parsed: T | None = None # Find the first choice with tool_calls for choice in response.choices: if isinstance(choice.message.parsed, output_model): diff --git a/src/strands/tools/__init__.py b/src/strands/tools/__init__.py index 97b1e8678..12979015e 100644 --- a/src/strands/tools/__init__.py +++ b/src/strands/tools/__init__.py @@ -4,7 +4,7 @@ """ from .decorator import tool -from .structured_output import convert_pydantic_to_bedrock_tool +from .structured_output import convert_pydantic_to_tool_spec from .thread_pool_executor import ThreadPoolExecutorWrapper from .tools import FunctionTool, InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec @@ -16,5 +16,5 @@ "normalize_schema", "normalize_tool_spec", "ThreadPoolExecutorWrapper", - "convert_pydantic_to_bedrock_tool", + "convert_pydantic_to_tool_spec", ] diff --git a/src/strands/tools/structured_output.py b/src/strands/tools/structured_output.py index 35cef868c..f19cf6a47 100644 --- a/src/strands/tools/structured_output.py +++ b/src/strands/tools/structured_output.py @@ -7,7 +7,7 @@ from ..types.tools import ToolSpec -def flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: +def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: """Flattens a JSON schema by removing $defs and resolving $ref references. Handles required vs optional fields properly. @@ -253,7 +253,7 @@ def process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, An return result -def convert_pydantic_to_bedrock_tool( +def convert_pydantic_to_tool_spec( model: Type[BaseModel], description: Optional[str] = None, ) -> ToolSpec: @@ -286,7 +286,7 @@ def convert_pydantic_to_bedrock_tool( expand_nested_properties(input_schema, model) # Flatten the schema - flattened_schema = flatten_schema(input_schema) + flattened_schema = _flatten_schema(input_schema) final_schema = flattened_schema diff --git a/src/strands/types/models/model.py b/src/strands/types/models/model.py index 4f962ae92..071c8a511 100644 --- a/src/strands/types/models/model.py +++ b/src/strands/types/models/model.py @@ -2,7 +2,7 @@ import abc import logging -from typing import Any, Iterable, Optional, Type +from typing import Any, Callable, Iterable, Optional, Type, TypeVar from pydantic import BaseModel @@ -12,6 +12,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class Model(abc.ABC): """Abstract base class for AI model implementations. @@ -42,12 +44,15 @@ def get_config(self) -> Any: @abc.abstractmethod # pragma: no cover - def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel: + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. - prompt(Optional[str]): The prompt to use for the agent. Defaults to None. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. Returns: The structured output as a serialized instance of the output model. 
diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index 49f27b024..4f182c0d9 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -11,7 +11,7 @@ import json import logging import mimetypes -from typing import Any, Optional, Type, cast +from typing import Any, Callable, Optional, Type, TypeVar, cast from pydantic import BaseModel from typing_extensions import override @@ -23,6 +23,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class OpenAIModel(Model, abc.ABC): """Base OpenAI model provider implementation. @@ -265,11 +267,14 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") @override - def structured_output(self, output_model: Type[BaseModel], prompt: Messages) -> BaseModel: + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. prompt(Messages): The prompt to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. """ return output_model() diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 254cdf4ca..06fe28be4 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -738,7 +738,9 @@ def test_agent_method_structured_output(agent): assert result == expected_user # Verify the model's structured_output was called with correct arguments - agent.model.structured_output.assert_called_once_with(User, [{"role": "user", "content": [{"text": prompt}]}]) + agent.model.structured_output.assert_called_once_with( + User, [{"role": "user", "content": [{"text": prompt}]}], agent.callback_handler + ) @pytest.mark.asyncio diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py index d3814223e..de655a141 100644 --- a/tests/strands/tools/test_structured_output.py +++ b/tests/strands/tools/test_structured_output.py @@ -3,7 +3,7 @@ import pytest from pydantic import BaseModel, Field -from strands.tools.structured_output import convert_pydantic_to_bedrock_tool +from strands.tools.structured_output import convert_pydantic_to_tool_spec from strands.types.tools import ToolSpec @@ -116,21 +116,21 @@ def complex_user_tool_spec_json(): } -def test_convert_pydantic_to_bedrock_tool_basic( +def test_convert_pydantic_to_tool_spec_basic( user_model, basic_user_tool_spec, ): - tool_spec = convert_pydantic_to_bedrock_tool(user_model) + tool_spec = convert_pydantic_to_tool_spec(user_model) assert tool_spec == basic_user_tool_spec assert ToolSpec(**tool_spec) == ToolSpec(**basic_user_tool_spec) -def test_convert_pydantic_to_bedrock_tool_complex( +def test_convert_pydantic_to_tool_spec_complex( user_with_employment_model, complex_user_tool_spec_json, ): - tool_spec = convert_pydantic_to_bedrock_tool(user_with_employment_model) + tool_spec = convert_pydantic_to_tool_spec(user_with_employment_model) assert tool_spec == complex_user_tool_spec_json assert ToolSpec(**tool_spec) == ToolSpec(**complex_user_tool_spec_json) From 422bc25285e95241a91ad16dc054185fb382cd53 Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Fri, 13 Jun 2025 17:36:25 -0500 Subject: [PATCH 14/18] fix: inline basemodel classes --- tests/strands/tools/test_structured_output.py | 124 
+++++++++--------- 1 file changed, 59 insertions(+), 65 deletions(-) diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py index de655a141..c686ee2cc 100644 --- a/tests/strands/tools/test_structured_output.py +++ b/tests/strands/tools/test_structured_output.py @@ -1,57 +1,20 @@ -from typing import Literal +from typing import Literal, Optional -import pytest from pydantic import BaseModel, Field from strands.tools.structured_output import convert_pydantic_to_tool_spec from strands.types.tools import ToolSpec -class User(BaseModel): - """A user of the system.""" +def test_convert_pydantic_to_tool_spec_basic(): + class User(BaseModel): + """A user of the system.""" - name: str = Field(description="The name of the user") - age: int = Field(description="The age of the user") - email: str = Field(description="The email of the user", default="") + name: str = Field(description="The name of the user") + age: int = Field(description="The age of the user") + email: str = Field(description="The email of the user", default="") - -class Employment(BaseModel): - """An employment of the user.""" - - company: str = Field(description="The company of the user") - title: Literal[ - "CEO", - "CTO", - "CFO", - "CMO", - "COO", - "VP", - "Director", - "Manager", - "Other", - ] = Field(description="The title of the user", default="Other") - - -# for a nested, more complex test -class UserWithEmployment(User): - """A user of the system with employment.""" - - employment: Employment = Field(description="The employment of the user") - - -@pytest.fixture -def user_model(): - return User - - -@pytest.fixture -def user_with_employment_model(): - return UserWithEmployment - - -@pytest.fixture -def basic_user_tool_spec(): - return { + basic_user_tool_spec = { "name": "User", "description": "A user of the system.", "inputSchema": { @@ -73,11 +36,44 @@ def basic_user_tool_spec(): } }, } + tool_spec = convert_pydantic_to_tool_spec(User) + + assert tool_spec == basic_user_tool_spec + assert ToolSpec(**tool_spec) == ToolSpec(**basic_user_tool_spec) + + +def test_convert_pydantic_to_tool_spec_complex(): + class User(BaseModel): + """A user of the system.""" + + name: str = Field(description="The name of the user") + age: int = Field(description="The age of the user") + email: str = Field(description="The email of the user", default="") + + class Employment(BaseModel): + """An employment of the user.""" + + company: str = Field(description="The company of the user") + title: Literal[ + "CEO", + "CTO", + "CFO", + "CMO", + "COO", + "VP", + "Director", + "Manager", + "Other", + ] = Field(description="The title of the user", default="Other") + + class UserWithEmployment(User): + """A user of the system with employment.""" + employment: Employment = Field(description="The employment of the user") + employment_2: Optional[Employment] = Field(description="A part time employment of the user") -@pytest.fixture -def complex_user_tool_spec_json(): - return { + tool_spec = convert_pydantic_to_tool_spec(UserWithEmployment) + complex_user_tool_spec_json = { "name": "UserWithEmployment", "description": "A user of the system with employment.", "inputSchema": { @@ -107,6 +103,21 @@ def complex_user_tool_spec_json(): }, "required": ["company"], }, + "employment_2": { + "type": ["object", "null"], + "description": "A part time employment of the user", + "properties": { + "company": {"description": "The company of the user", "title": "Company", "type": "string"}, + "title": { + "default": 
"Other", + "description": "The title of the user", + "enum": ["CEO", "CTO", "CFO", "CMO", "COO", "VP", "Director", "Manager", "Other"], + "title": "Title", + "type": "string", + }, + }, + "required": ["company"], + }, }, "title": "UserWithEmployment", "description": "A user of the system with employment.", @@ -115,22 +126,5 @@ def complex_user_tool_spec_json(): }, } - -def test_convert_pydantic_to_tool_spec_basic( - user_model, - basic_user_tool_spec, -): - tool_spec = convert_pydantic_to_tool_spec(user_model) - - assert tool_spec == basic_user_tool_spec - assert ToolSpec(**tool_spec) == ToolSpec(**basic_user_tool_spec) - - -def test_convert_pydantic_to_tool_spec_complex( - user_with_employment_model, - complex_user_tool_spec_json, -): - tool_spec = convert_pydantic_to_tool_spec(user_with_employment_model) - assert tool_spec == complex_user_tool_spec_json assert ToolSpec(**tool_spec) == ToolSpec(**complex_user_tool_spec_json) From eabf07585671a45e64db3f1c0fbd1d575447455a Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Tue, 17 Jun 2025 12:09:43 -0500 Subject: [PATCH 15/18] feat: update litellm, add checks --- pyproject.toml | 167 +++++++++++++++---------------- src/strands/models/litellm.py | 7 +- src/strands/models/llamaapi.py | 14 ++- src/strands/models/openai.py | 3 + tests-integ/test_model_ollama.py | 13 ++- tests-integ/test_model_openai.py | 4 + 6 files changed, 116 insertions(+), 92 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bd3097327..b0a7f1f88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,10 +8,8 @@ dynamic = ["version"] description = "A model-driven approach to building AI agents in just a few lines of code" readme = "README.md" requires-python = ">=3.10" -license = {text = "Apache-2.0"} -authors = [ - {name = "AWS", email = "opensource@amazon.com"}, -] +license = { text = "Apache-2.0" } +authors = [{ name = "AWS", email = "opensource@amazon.com" }] classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -47,9 +45,7 @@ Documentation = "https://strandsagents.com" packages = ["src/strands"] [project.optional-dependencies] -anthropic = [ - "anthropic>=0.21.0,<1.0.0", -] +anthropic = ["anthropic>=0.21.0,<1.0.0"] dev = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", @@ -66,18 +62,10 @@ docs = [ "sphinx-rtd-theme>=1.0.0,<2.0.0", "sphinx-autodoc-typehints>=1.12.0,<2.0.0", ] -litellm = [ - "litellm>=1.69.0,<2.0.0", -] -llamaapi = [ - "llama-api-client>=0.1.0,<1.0.0", -] -ollama = [ - "ollama>=0.4.8,<1.0.0", -] -openai = [ - "openai>=1.68.0,<2.0.0", -] +litellm = ["litellm>=1.72.6,<2.0.0"] +llamaapi = ["llama-api-client>=0.1.0,<1.0.0"] +ollama = ["ollama>=0.4.8,<1.0.0"] +openai = ["openai>=1.68.0,<2.0.0"] [tool.hatch.version] # Tells Hatch to use your version control system (git) to determine the version. 
@@ -86,25 +74,16 @@ source = "vcs" [tool.hatch.envs.hatch-static-analysis] features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] dependencies = [ - "mypy>=1.15.0,<2.0.0", - "ruff>=0.11.6,<0.12.0", - "strands-agents @ {root:uri}" + "mypy>=1.15.0,<2.0.0", + "ruff>=0.11.6,<0.12.0", + "strands-agents @ {root:uri}", ] [tool.hatch.envs.hatch-static-analysis.scripts] -format-check = [ - "ruff format --check" -] -format-fix = [ - "ruff format" -] -lint-check = [ - "ruff check", - "mypy -p src" -] -lint-fix = [ - "ruff check --fix" -] +format-check = ["ruff format --check"] +format-fix = ["ruff format"] +lint-check = ["ruff check", "mypy -p src"] +lint-fix = ["ruff check --fix"] [tool.hatch.envs.hatch-test] features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] @@ -115,28 +94,21 @@ extra-dependencies = [ "pytest-cov>=4.1.0,<5.0.0", "pytest-xdist>=3.0.0,<4.0.0", ] -extra-args = [ - "-n", - "auto", - "-vv", -] +extra-args = ["-n", "auto", "-vv"] [tool.hatch.envs.dev] dev-mode = true features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama"] - [[tool.hatch.envs.hatch-test.matrix]] python = ["3.13", "3.12", "3.11", "3.10"] [tool.hatch.envs.hatch-test.scripts] -run = [ - "pytest{env:HATCH_TEST_ARGS:} {args}" -] +run = ["pytest{env:HATCH_TEST_ARGS:} {args}"] run-cov = [ - "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}" + "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}", ] cov-combine = [] @@ -145,26 +117,14 @@ cov-report = [] [tool.hatch.envs.default.scripts] list = [ - "echo 'Scripts commands available for default env:'; hatch env show --json | jq --raw-output '.default.scripts | keys[]'" -] -format = [ - "hatch fmt --formatter", -] -test-format = [ - "hatch fmt --formatter --check", -] -lint = [ - "hatch fmt --linter" -] -test-lint = [ - "hatch fmt --linter --check" -] -test = [ - "hatch test --cover --cov-report html --cov-report xml {args}" -] -test-integ = [ - "hatch test tests-integ {args}" + "echo 'Scripts commands available for default env:'; hatch env show --json | jq --raw-output '.default.scripts | keys[]'", ] +format = ["hatch fmt --formatter"] +test-format = ["hatch fmt --formatter --check"] +lint = ["hatch fmt --linter"] +test-lint = ["hatch fmt --linter --check"] +test = ["hatch test --cover --cov-report html --cov-report xml {args}"] +test-integ = ["hatch test tests-integ {args}"] [tool.mypy] @@ -189,17 +149,22 @@ ignore_missing_imports = true [tool.ruff] line-length = 120 -include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests-integ/**/*.py"] +include = [ + "examples/**/*.py", + "src/**/*.py", + "tests/**/*.py", + "tests-integ/**/*.py", +] [tool.ruff.lint] select = [ - "B", # flake8-bugbear - "D", # pydocstyle - "E", # pycodestyle - "F", # pyflakes - "G", # logging format - "I", # isort - "LOG", # logging + "B", # flake8-bugbear + "D", # pydocstyle + "E", # pycodestyle + "F", # pyflakes + "G", # logging format + "I", # isort + "LOG", # logging ] [tool.ruff.lint.per-file-ignores] @@ -209,9 +174,7 @@ select = [ convention = "google" [tool.pytest.ini_options] -testpaths = [ - "tests" -] +testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" [tool.coverage.run] @@ -234,19 +197,47 @@ output = "build/coverage/coverage.xml" name = "cz_conventional_commits" tag_format = "v$version" bump_message = "chore(release): bump version $current_version -> $new_version" -version_files = [ - "pyproject.toml:version", -] +version_files = ["pyproject.toml:version"] update_changelog_on_bump = 
true style = [ - ["qmark", "fg:#ff9d00 bold"], - ["question", "bold"], - ["answer", "fg:#ff9d00 bold"], - ["pointer", "fg:#ff9d00 bold"], - ["highlighted", "fg:#ff9d00 bold"], - ["selected", "fg:#cc5454"], - ["separator", "fg:#cc5454"], - ["instruction", ""], - ["text", ""], - ["disabled", "fg:#858585 italic"] + [ + "qmark", + "fg:#ff9d00 bold", + ], + [ + "question", + "bold", + ], + [ + "answer", + "fg:#ff9d00 bold", + ], + [ + "pointer", + "fg:#ff9d00 bold", + ], + [ + "highlighted", + "fg:#ff9d00 bold", + ], + [ + "selected", + "fg:#cc5454", + ], + [ + "separator", + "fg:#cc5454", + ], + [ + "instruction", + "", + ], + [ + "text", + "", + ], + [ + "disabled", + "fg:#858585 italic", + ], ] diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index b91a7cbea..661381863 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -8,6 +8,7 @@ from typing import Any, Callable, Optional, Type, TypedDict, TypeVar, cast import litellm +from litellm.utils import supports_response_schema from pydantic import BaseModel from typing_extensions import Unpack, override @@ -113,7 +114,6 @@ def structured_output( prompt(Messages): The prompt messages to use for the agent. callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. - TODO: add checks: https://docs.litellm.ai/docs/completion/json_mode#check-model-support """ # The LiteLLM `Client` inits with Chat(). # Chat() inits with self.completions @@ -124,6 +124,11 @@ def structured_output( response_format=output_model, ) + if not supports_response_schema(self.get_config()["model_id"]): + raise ValueError("Model does not support response_format") + if len(response.choices) > 1: + raise ValueError("Multiple choices found in the response.") + # Find the first choice with tool_calls for choice in response.choices: if choice.finish_reason == "tool_calls": diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 46e5ec696..755e07ad9 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -402,4 +402,16 @@ def structured_output( Raises: NotImplementedError: Structured output is not currently supported for LlamaAPI models. 
""" - raise NotImplementedError("Structured output is not currently supported for LlamaAPI models") + # response_format: ResponseFormat = { + # "type": "json_schema", + # "json_schema": { + # "name": output_model.__name__, + # "schema": output_model.model_json_schema(), + # }, + # } + # response = self.client.chat.completions.create( + # model=self.config["model_id"], + # messages=self.format_request(prompt)["messages"], + # response_format=response_format, + # ) + raise NotImplementedError("Strands sdk-python does not implement this in the Llama API Preview.") diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 6c0a76d39..783ce3794 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -150,6 +150,9 @@ def structured_output( parsed: T | None = None # Find the first choice with tool_calls + if len(response.choices) > 1: + raise ValueError("Multiple choices found in the OpenAI response.") + for choice in response.choices: if isinstance(choice.message.parsed, output_model): parsed = choice.message.parsed diff --git a/tests-integ/test_model_ollama.py b/tests-integ/test_model_ollama.py index 29ece1544..74000943d 100644 --- a/tests-integ/test_model_ollama.py +++ b/tests-integ/test_model_ollama.py @@ -1,4 +1,5 @@ import pytest +import requests from pydantic import BaseModel from strands import Agent @@ -7,7 +8,7 @@ @pytest.fixture def model(): - return OllamaModel(host="http://localhost:11434", model_id="llama3.1:8b") + return OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b") @pytest.fixture @@ -15,11 +16,19 @@ def agent(model): return Agent(model=model) +@pytest.mark.skipif( + not requests.get("http://localhost:11434/api/health").ok, + reason="Local Ollama endpoint not available at localhost:11434", +) def test_agent(agent): result = agent("Say 'hello world' with no other text") - assert result.message["content"][0]["text"].lower() == "hello world" + assert isinstance(result, str) +@pytest.mark.skipif( + not requests.get("http://localhost:11434/api/health").ok, + reason="Local Ollama endpoint not available at localhost:11434", +) def test_structured_output(agent): class Weather(BaseModel): """Extract the time and weather from the response with the exact strings.""" diff --git a/tests-integ/test_model_openai.py b/tests-integ/test_model_openai.py index 2ebe3ac86..b0790ba01 100644 --- a/tests-integ/test_model_openai.py +++ b/tests-integ/test_model_openai.py @@ -47,6 +47,10 @@ def test_agent(agent): assert all(string in text for string in ["12:00", "sunny"]) +@pytest.mark.skipif( + "OPENAI_API_KEY" not in os.environ, + reason="OPENAI_API_KEY environment variable missing", +) def test_structured_output(model): class Weather(BaseModel): """Extracts the time and weather from the user's message with the exact strings.""" From 885d3acc7e40f37b50056d23ea7b537cb2036974 Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Tue, 17 Jun 2025 12:17:24 -0500 Subject: [PATCH 16/18] fix: autoformatting issue --- pyproject.toml | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7ca8fc0f6..33cd3b06b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,12 +60,6 @@ docs = [ "sphinx-rtd-theme>=1.0.0,<2.0.0", "sphinx-autodoc-typehints>=1.12.0,<2.0.0", ] -<<<<<<< feature/structured-output -litellm = ["litellm>=1.72.6,<2.0.0"] -llamaapi = ["llama-api-client>=0.1.0,<1.0.0"] -ollama = ["ollama>=0.4.8,<1.0.0"] -openai = ["openai>=1.68.0,<2.0.0"] -======= litellm = [ "litellm>=1.69.0,<2.0.0", ] @@ -88,7 +82,6 @@ 
a2a = [ "fastapi>=0.115.12", "starlette>=0.46.2", ] ->>>>>>> main [tool.hatch.version] # Tells Hatch to use your version control system (git) to determine the version. From 73084912f812020b789b00bf80158e8582bbef6d Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Tue, 17 Jun 2025 18:49:45 -0500 Subject: [PATCH 17/18] feat: resolves comments --- src/strands/tools/structured_output.py | 63 ++-- tests/strands/tools/test_structured_output.py | 269 ++++++++++-------- 2 files changed, 187 insertions(+), 145 deletions(-) diff --git a/src/strands/tools/structured_output.py b/src/strands/tools/structured_output.py index f19cf6a47..5421cdc69 100644 --- a/src/strands/tools/structured_output.py +++ b/src/strands/tools/structured_output.py @@ -36,6 +36,7 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: flattened["description"] = schema["description"] # Process properties + required_props: list[str] = [] if "properties" in schema: required_props = [] for prop_name, prop_value in schema["properties"].items(): @@ -60,7 +61,7 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: processed_prop["required"] = prop_value["required"] else: # Process as normal - processed_prop = process_property(prop_value, schema.get("$defs", {}), is_required) + processed_prop = _process_property(prop_value, schema.get("$defs", {}), is_required) flattened["properties"][prop_name] = processed_prop @@ -69,13 +70,17 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: required_props.append(prop_name) # Add required fields if any (only those that are truly required after processing) - if required_props: + # Check if required props are empty, if so, raise an error because it means there is a circular reference + + if len(required_props) > 0: flattened["required"] = required_props + else: + raise ValueError("Circular reference detected and not supported") return flattened -def process_property( +def _process_property( prop: Dict[str, Any], defs: Dict[str, Any], is_required: bool = False, @@ -108,7 +113,10 @@ def process_property( elif "$ref" in option: ref_path = option["$ref"].split("/")[-1] if ref_path in defs: - non_null_type = process_schema_object(defs[ref_path], defs, fully_expand) + non_null_type = _process_schema_object(defs[ref_path], defs, fully_expand) + else: + # Handle missing reference path gracefully + raise ValueError(f"Missing reference: {ref_path}") else: non_null_type = option @@ -138,26 +146,16 @@ def process_property( if ref_path in defs: ref_dict = defs[ref_path] # Process the referenced object to get a complete schema - result = process_schema_object(ref_dict, defs, fully_expand) - - # Copy description if available in the property (overrides ref description) - if "description" in prop: - result["description"] = prop["description"] - - # If not required, mark as nullable - if not is_required: - if "type" in result and isinstance(result["type"], str): - result["type"] = [result["type"], "null"] - elif "type" in result and isinstance(result["type"], list) and "null" not in result["type"]: - result["type"].append("null") - - return result + result = _process_schema_object(ref_dict, defs, fully_expand) + else: + # Handle missing reference path gracefully + raise ValueError(f"Missing reference: {ref_path}") # For regular fields, copy all properties for key, value in prop.items(): if key not in ["$ref", "anyOf"]: if isinstance(value, dict): - result[key] = process_nested_dict(value, defs) + result[key] = _process_nested_dict(value, defs) elif key == "type" and not 
is_required and not is_nullable: # For non-required fields, ensure type is a list with "null" if isinstance(value, str): @@ -172,7 +170,7 @@ def process_property( return result -def process_schema_object( +def _process_schema_object( schema_obj: Dict[str, Any], defs: Dict[str, Any], fully_expand: bool = True ) -> Dict[str, Any]: """Process a schema object, typically from $defs, to resolve all nested properties. @@ -203,7 +201,7 @@ def process_schema_object( for prop_name, prop_value in schema_obj["properties"].items(): # Process each property is_required = prop_name in required_fields - processed = process_property(prop_value, defs, is_required, fully_expand) + processed = _process_property(prop_value, defs, is_required, fully_expand) result["properties"][prop_name] = processed # Track which properties are actually required after processing @@ -217,7 +215,7 @@ def process_schema_object( return result -def process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, Any]: +def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, Any]: """Recursively processes nested dictionaries and resolves $ref references. Args: @@ -235,7 +233,10 @@ def process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, An if ref_path in defs: ref_dict = defs[ref_path] # Recursively process the referenced object - return process_schema_object(ref_dict, defs) + return _process_schema_object(ref_dict, defs) + else: + # Handle missing reference path gracefully + raise ValueError(f"Missing reference: {ref_path}") # Process each key-value pair for key, value in d.items(): @@ -243,10 +244,10 @@ def process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, An # Already handled above continue elif isinstance(value, dict): - result[key] = process_nested_dict(value, defs) + result[key] = _process_nested_dict(value, defs) elif isinstance(value, list): # Process lists (like for enum values) - result[key] = [process_nested_dict(item, defs) if isinstance(item, dict) else item for item in value] + result[key] = [_process_nested_dict(item, defs) if isinstance(item, dict) else item for item in value] else: result[key] = value @@ -280,10 +281,10 @@ def convert_pydantic_to_tool_spec( # Process all referenced models to ensure proper docstrings # This step is important for gathering descriptions from referenced models - process_referenced_models(input_schema, model) + _process_referenced_models(input_schema, model) # Now, let's fully expand the nested models with all their properties - expand_nested_properties(input_schema, model) + _expand_nested_properties(input_schema, model) # Flatten the schema flattened_schema = _flatten_schema(input_schema) @@ -298,7 +299,7 @@ def convert_pydantic_to_tool_spec( ) -def expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> None: +def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> None: """Expand the properties of nested models in the schema to include their full structure. This updates the schema in place. @@ -358,7 +359,7 @@ def expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> schema["properties"][prop_name] = expanded_object -def process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) -> None: +def _process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) -> None: """Process referenced models to ensure their docstrings are included. This updates the schema in place. 
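For orientation between these hunks: the $ref handling above exists because Pydantic v2 does not inline nested models — model_json_schema() places them under $defs and points at them with $ref. A minimal sketch with illustrative models (the exact $ref layout varies slightly across Pydantic v2 releases):

    from pydantic import BaseModel, Field

    class Address(BaseModel):
        """A mailing address."""

        city: str = Field(description="The city")

    class Person(BaseModel):
        """A person with an address."""

        name: str = Field(description="The name")
        address: Address = Field(description="Where they live")

    schema = Person.model_json_schema()
    # Nested models are emitted indirectly:
    #   schema["$defs"]["Address"]      -> the full Address object schema
    #   schema["properties"]["address"] -> points into $defs via "$ref"
    # The helpers in this module resolve those references so the final tool
    # spec is self-contained, raising on any $ref that cannot be found.
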
@@ -395,10 +396,10 @@ def process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) -> ref_def["description"] = field_type.__doc__.strip() # Recursively process properties in the referenced model - process_properties(ref_def, field_type) + _process_properties(ref_def, field_type) -def process_properties(schema_def: Dict[str, Any], model: Type[BaseModel]) -> None: +def _process_properties(schema_def: Dict[str, Any], model: Type[BaseModel]) -> None: """Process properties in a schema definition to add descriptions from field metadata. Args: diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py index c686ee2cc..e5b6a7d6c 100644 --- a/tests/strands/tools/test_structured_output.py +++ b/tests/strands/tools/test_structured_output.py @@ -1,130 +1,171 @@ from typing import Literal, Optional +import pytest from pydantic import BaseModel, Field from strands.tools.structured_output import convert_pydantic_to_tool_spec from strands.types.tools import ToolSpec -def test_convert_pydantic_to_tool_spec_basic(): - class User(BaseModel): - """A user of the system.""" +# Basic test model +class User(BaseModel): + """User model with name and age.""" - name: str = Field(description="The name of the user") - age: int = Field(description="The age of the user") - email: str = Field(description="The email of the user", default="") - - basic_user_tool_spec = { - "name": "User", - "description": "A user of the system.", - "inputSchema": { - "json": { - "type": "object", - "properties": { - "name": {"description": "The name of the user", "title": "Name", "type": "string"}, - "age": {"description": "The age of the user", "title": "Age", "type": "integer"}, - "email": { - "default": "", - "description": "The email of the user", - "title": "Email", - "type": ["string", "null"], - }, - }, - "title": "User", - "description": "A user of the system.", - "required": ["name", "age"], - } - }, - } + name: str = Field(description="The name of the user") + age: int = Field(description="The age of the user", ge=18, le=100) + + +# Test model with inheritance and literals +class UserWithPlanet(User): + """User with planet.""" + + planet: Literal["Earth", "Mars"] = Field(description="The planet") + + +# Test model with multiple same type fields and optional field +class TwoUsersWithPlanet(BaseModel): + """Two users model with planet.""" + + user1: UserWithPlanet = Field(description="The first user") + user2: Optional[UserWithPlanet] = Field(description="The second user", default=None) + + +# Test model with list of same type fields +class ListOfUsersWithPlanet(BaseModel): + """List of users model with planet.""" + + users: list[UserWithPlanet] = Field(description="The users", min_length=2, max_length=3) + + +def test_convert_pydantic_to_tool_spec_basic(): tool_spec = convert_pydantic_to_tool_spec(User) - assert tool_spec == basic_user_tool_spec - assert ToolSpec(**tool_spec) == ToolSpec(**basic_user_tool_spec) + # Check basic structure + assert tool_spec["name"] == "User" + assert tool_spec["description"] == "User model with name and age." 
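+    # The flattened spec nests the JSON schema under inputSchema.json, the shape ToolSpec expects.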
+ assert "inputSchema" in tool_spec + assert "json" in tool_spec["inputSchema"] + + # Check schema properties + schema = tool_spec["inputSchema"]["json"] + assert schema["type"] == "object" + assert schema["title"] == "User" + + # Check field properties + assert "name" in schema["properties"] + assert schema["properties"]["name"]["description"] == "The name of the user" + assert schema["properties"]["name"]["type"] == "string" + + assert "age" in schema["properties"] + assert schema["properties"]["age"]["description"] == "The age of the user" + assert schema["properties"]["age"]["type"] == "integer" + + # Check required fields + assert "required" in schema + assert "name" in schema["required"] + assert "age" in schema["required"] + + # check validation + assert schema["properties"]["age"]["minimum"] == 18 + assert schema["properties"]["age"]["maximum"] == 100 + + # Verify we can construct a valid ToolSpec + tool_spec_obj = ToolSpec(**tool_spec) + assert tool_spec_obj is not None def test_convert_pydantic_to_tool_spec_complex(): - class User(BaseModel): - """A user of the system.""" + tool_spec = convert_pydantic_to_tool_spec(ListOfUsersWithPlanet) + + # Assert expected properties are present in the tool spec + assert tool_spec["name"] == "ListOfUsersWithPlanet" + assert tool_spec["description"] == "List of users model with planet." + assert "inputSchema" in tool_spec + assert "json" in tool_spec["inputSchema"] + + # Check the schema properties + schema = tool_spec["inputSchema"]["json"] + assert schema["type"] == "object" + assert "users" in schema["properties"] + assert schema["properties"]["users"]["type"] == "array" + assert schema["properties"]["users"]["items"]["type"] == "object" + assert schema["properties"]["users"]["items"]["properties"]["name"]["type"] == "string" + assert schema["properties"]["users"]["items"]["properties"]["age"]["type"] == "integer" + assert schema["properties"]["users"]["items"]["properties"]["planet"]["type"] == "string" + assert schema["properties"]["users"]["items"]["properties"]["planet"]["enum"] == ["Earth", "Mars"] + # Check the list field properties + assert schema["properties"]["users"]["minItems"] == 2 + assert schema["properties"]["users"]["maxItems"] == 3 + + # Verify the required fields + assert "required" in schema + assert "users" in schema["required"] + + # Verify we can construct a valid ToolSpec + tool_spec_obj = ToolSpec(**tool_spec) + assert tool_spec_obj is not None + + +def test_convert_pydantic_to_tool_spec_multiple_same_type(): + tool_spec = convert_pydantic_to_tool_spec(TwoUsersWithPlanet) + + # Verify the schema structure + assert tool_spec["name"] == "TwoUsersWithPlanet" + assert "user1" in tool_spec["inputSchema"]["json"]["properties"] + assert "user2" in tool_spec["inputSchema"]["json"]["properties"] + + # Verify both employment fields have the same structure + primary = tool_spec["inputSchema"]["json"]["properties"]["user1"] + secondary = tool_spec["inputSchema"]["json"]["properties"]["user2"] + + assert primary["type"] == "object" + assert secondary["type"] == ["object", "null"] + + assert "properties" in primary + assert "name" in primary["properties"] + assert "age" in primary["properties"] + assert "planet" in primary["properties"] + + assert "properties" in secondary + assert "name" in secondary["properties"] + assert "age" in secondary["properties"] + assert "planet" in secondary["properties"] + + +def test_convert_pydantic_with_missing_refs(): + """Test that the tool handles missing $refs gracefully.""" + # This test checks 
that our error handling for missing $refs works correctly + # by testing with a model that has circular references + + class NodeWithCircularRef(BaseModel): + """A node with a circular reference to itself.""" + + name: str = Field(description="The name of the node") + parent: Optional["NodeWithCircularRef"] = Field(None, description="Parent node") + children: list["NodeWithCircularRef"] = Field(default_factory=list, description="Child nodes") + + # This forward reference normally causes issues with schema generation + # but our error handling should prevent errors + with pytest.raises(ValueError, match="Circular reference detected and not supported"): + convert_pydantic_to_tool_spec(NodeWithCircularRef) + + +def test_convert_pydantic_with_custom_description(): + """Test that custom descriptions override model docstrings.""" + + # Test with custom description + custom_description = "Custom tool description for user model" + tool_spec = convert_pydantic_to_tool_spec(User, description=custom_description) + + assert tool_spec["description"] == custom_description + + +def test_convert_pydantic_with_empty_docstring(): + """Test that empty docstrings use default description.""" + + class EmptyDocUser(BaseModel): name: str = Field(description="The name of the user") - age: int = Field(description="The age of the user") - email: str = Field(description="The email of the user", default="") - - class Employment(BaseModel): - """An employment of the user.""" - - company: str = Field(description="The company of the user") - title: Literal[ - "CEO", - "CTO", - "CFO", - "CMO", - "COO", - "VP", - "Director", - "Manager", - "Other", - ] = Field(description="The title of the user", default="Other") - - class UserWithEmployment(User): - """A user of the system with employment.""" - - employment: Employment = Field(description="The employment of the user") - employment_2: Optional[Employment] = Field(description="A part time employment of the user") - - tool_spec = convert_pydantic_to_tool_spec(UserWithEmployment) - complex_user_tool_spec_json = { - "name": "UserWithEmployment", - "description": "A user of the system with employment.", - "inputSchema": { - "json": { - "type": "object", - "properties": { - "name": {"description": "The name of the user", "title": "Name", "type": "string"}, - "age": {"description": "The age of the user", "title": "Age", "type": "integer"}, - "email": { - "default": "", - "description": "The email of the user", - "title": "Email", - "type": ["string", "null"], - }, - "employment": { - "type": "object", - "description": "The employment of the user", - "properties": { - "company": {"description": "The company of the user", "title": "Company", "type": "string"}, - "title": { - "default": "Other", - "description": "The title of the user", - "enum": ["CEO", "CTO", "CFO", "CMO", "COO", "VP", "Director", "Manager", "Other"], - "title": "Title", - "type": "string", - }, - }, - "required": ["company"], - }, - "employment_2": { - "type": ["object", "null"], - "description": "A part time employment of the user", - "properties": { - "company": {"description": "The company of the user", "title": "Company", "type": "string"}, - "title": { - "default": "Other", - "description": "The title of the user", - "enum": ["CEO", "CTO", "CFO", "CMO", "COO", "VP", "Director", "Manager", "Other"], - "title": "Title", - "type": "string", - }, - }, - "required": ["company"], - }, - }, - "title": "UserWithEmployment", - "description": "A user of the system with employment.", - "required": ["name", "age", 
"employment"], - } - }, - } - - assert tool_spec == complex_user_tool_spec_json - assert ToolSpec(**tool_spec) == ToolSpec(**complex_user_tool_spec_json) + + tool_spec = convert_pydantic_to_tool_spec(EmptyDocUser) + assert tool_spec["description"] == "EmptyDocUser structured output tool" From 0216bccb20993d359fe28c0d773f4f97d75799c0 Mon Sep 17 00:00:00 2001 From: laithalsaadoon Date: Wed, 18 Jun 2025 14:51:49 -0500 Subject: [PATCH 18/18] fix: ollama skip tests, pyproject whitespace diffs --- pyproject.toml | 131 ++++++------ tests-integ/test_model_ollama.py | 25 ++- tests/strands/tools/test_structured_output.py | 201 +++++++++++------- 3 files changed, 204 insertions(+), 153 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 06bd9fdbc..e0cc25785 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,8 +8,10 @@ dynamic = ["version"] description = "A model-driven approach to building AI agents in just a few lines of code" readme = "README.md" requires-python = ">=3.10" -license = { text = "Apache-2.0" } -authors = [{ name = "AWS", email = "opensource@amazon.com" }] +license = {text = "Apache-2.0"} +authors = [ + {name = "AWS", email = "opensource@amazon.com"}, +] classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -26,7 +28,7 @@ classifiers = [ dependencies = [ "boto3>=1.26.0,<2.0.0", "botocore>=1.29.0,<2.0.0", - "docstring_parser>=0.15,<0.16.0", + "docstring_parser>=0.15,<1.0", "mcp>=1.8.0,<2.0.0", "pydantic>=2.0.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", @@ -44,7 +46,9 @@ Documentation = "https://strandsagents.com" packages = ["src/strands"] [project.optional-dependencies] -anthropic = ["anthropic>=0.21.0,<1.0.0"] +anthropic = [ + "anthropic>=0.21.0,<1.0.0", +] dev = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", @@ -61,7 +65,7 @@ docs = [ "sphinx-autodoc-typehints>=1.12.0,<2.0.0", ] litellm = [ - "litellm>=1.69.0,<2.0.0", + "litellm>=1.72.6,<2.0.0", ] llamaapi = [ "llama-api-client>=0.1.0,<1.0.0", @@ -90,16 +94,25 @@ source = "vcs" [tool.hatch.envs.hatch-static-analysis] features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] dependencies = [ - "mypy>=1.15.0,<2.0.0", - "ruff>=0.11.6,<0.12.0", - "strands-agents @ {root:uri}", + "mypy>=1.15.0,<2.0.0", + "ruff>=0.11.6,<0.12.0", + "strands-agents @ {root:uri}" ] [tool.hatch.envs.hatch-static-analysis.scripts] -format-check = ["ruff format --check"] -format-fix = ["ruff format"] -lint-check = ["ruff check", "mypy -p src"] -lint-fix = ["ruff check --fix"] +format-check = [ + "ruff format --check" +] +format-fix = [ + "ruff format" +] +lint-check = [ + "ruff check", + "mypy -p src" +] +lint-fix = [ + "ruff check --fix" +] [tool.hatch.envs.hatch-test] features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] @@ -110,7 +123,11 @@ extra-dependencies = [ "pytest-cov>=4.1.0,<5.0.0", "pytest-xdist>=3.0.0,<4.0.0", ] -extra-args = ["-n", "auto", "-vv"] +extra-args = [ + "-n", + "auto", + "-vv", +] [tool.hatch.envs.dev] dev-mode = true @@ -120,14 +137,17 @@ features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel"] dev-mode = true features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"] + [[tool.hatch.envs.hatch-test.matrix]] python = ["3.13", "3.12", "3.11", "3.10"] [tool.hatch.envs.hatch-test.scripts] -run = ["pytest{env:HATCH_TEST_ARGS:} {args}"] +run = [ + "pytest{env:HATCH_TEST_ARGS:} {args}" +] run-cov = [ - "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}", + 
"pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}" ] cov-combine = [] @@ -184,22 +204,17 @@ ignore_missing_imports = true [tool.ruff] line-length = 120 -include = [ - "examples/**/*.py", - "src/**/*.py", - "tests/**/*.py", - "tests-integ/**/*.py", -] +include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests-integ/**/*.py"] [tool.ruff.lint] select = [ - "B", # flake8-bugbear - "D", # pydocstyle - "E", # pycodestyle - "F", # pyflakes - "G", # logging format - "I", # isort - "LOG", # logging + "B", # flake8-bugbear + "D", # pydocstyle + "E", # pycodestyle + "F", # pyflakes + "G", # logging format + "I", # isort + "LOG", # logging ] [tool.ruff.lint.per-file-ignores] @@ -209,7 +224,9 @@ select = [ convention = "google" [tool.pytest.ini_options] -testpaths = ["tests"] +testpaths = [ + "tests" +] asyncio_default_fixture_loop_scope = "function" [tool.coverage.run] @@ -232,47 +249,19 @@ output = "build/coverage/coverage.xml" name = "cz_conventional_commits" tag_format = "v$version" bump_message = "chore(release): bump version $current_version -> $new_version" -version_files = ["pyproject.toml:version"] +version_files = [ + "pyproject.toml:version", +] update_changelog_on_bump = true style = [ - [ - "qmark", - "fg:#ff9d00 bold", - ], - [ - "question", - "bold", - ], - [ - "answer", - "fg:#ff9d00 bold", - ], - [ - "pointer", - "fg:#ff9d00 bold", - ], - [ - "highlighted", - "fg:#ff9d00 bold", - ], - [ - "selected", - "fg:#cc5454", - ], - [ - "separator", - "fg:#cc5454", - ], - [ - "instruction", - "", - ], - [ - "text", - "", - ], - [ - "disabled", - "fg:#858585 italic", - ], -] + ["qmark", "fg:#ff9d00 bold"], + ["question", "bold"], + ["answer", "fg:#ff9d00 bold"], + ["pointer", "fg:#ff9d00 bold"], + ["highlighted", "fg:#ff9d00 bold"], + ["selected", "fg:#cc5454"], + ["separator", "fg:#cc5454"], + ["instruction", ""], + ["text", ""], + ["disabled", "fg:#858585 italic"] +] \ No newline at end of file diff --git a/tests-integ/test_model_ollama.py b/tests-integ/test_model_ollama.py index 74000943d..38b46821d 100644 --- a/tests-integ/test_model_ollama.py +++ b/tests-integ/test_model_ollama.py @@ -6,6 +6,13 @@ from strands.models.ollama import OllamaModel +def is_server_available() -> bool: + try: + return requests.get("http://localhost:11434").ok + except requests.exceptions.ConnectionError: + return False + + @pytest.fixture def model(): return OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b") @@ -16,22 +23,20 @@ def agent(model): return Agent(model=model) -@pytest.mark.skipif( - not requests.get("http://localhost:11434/api/health").ok, - reason="Local Ollama endpoint not available at localhost:11434", -) +@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434") def test_agent(agent): result = agent("Say 'hello world' with no other text") - assert isinstance(result, str) + assert isinstance(result.message["content"][0]["text"], str) -@pytest.mark.skipif( - not requests.get("http://localhost:11434/api/health").ok, - reason="Local Ollama endpoint not available at localhost:11434", -) +@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434") def test_structured_output(agent): class Weather(BaseModel): - """Extract the time and weather from the response with the exact strings.""" + """Extract the time and weather. + + Time format: HH:MM + Weather: sunny, cloudy, rainy, etc. 
+ """ time: str weather: str diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py index e5b6a7d6c..2e354b831 100644 --- a/tests/strands/tools/test_structured_output.py +++ b/tests/strands/tools/test_structured_output.py @@ -40,67 +40,83 @@ class ListOfUsersWithPlanet(BaseModel): def test_convert_pydantic_to_tool_spec_basic(): tool_spec = convert_pydantic_to_tool_spec(User) - # Check basic structure - assert tool_spec["name"] == "User" - assert tool_spec["description"] == "User model with name and age." - assert "inputSchema" in tool_spec - assert "json" in tool_spec["inputSchema"] - - # Check schema properties - schema = tool_spec["inputSchema"]["json"] - assert schema["type"] == "object" - assert schema["title"] == "User" - - # Check field properties - assert "name" in schema["properties"] - assert schema["properties"]["name"]["description"] == "The name of the user" - assert schema["properties"]["name"]["type"] == "string" - - assert "age" in schema["properties"] - assert schema["properties"]["age"]["description"] == "The age of the user" - assert schema["properties"]["age"]["type"] == "integer" - - # Check required fields - assert "required" in schema - assert "name" in schema["required"] - assert "age" in schema["required"] - - # check validation - assert schema["properties"]["age"]["minimum"] == 18 - assert schema["properties"]["age"]["maximum"] == 100 + expected_spec = { + "name": "User", + "description": "User model with name and age.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + }, + "title": "User", + "description": "User model with name and age.", + "required": ["name", "age"], + } + }, + } # Verify we can construct a valid ToolSpec tool_spec_obj = ToolSpec(**tool_spec) assert tool_spec_obj is not None + assert tool_spec == expected_spec def test_convert_pydantic_to_tool_spec_complex(): tool_spec = convert_pydantic_to_tool_spec(ListOfUsersWithPlanet) - # Assert expected properties are present in the tool spec - assert tool_spec["name"] == "ListOfUsersWithPlanet" - assert tool_spec["description"] == "List of users model with planet." 
- assert "inputSchema" in tool_spec - assert "json" in tool_spec["inputSchema"] - - # Check the schema properties - schema = tool_spec["inputSchema"]["json"] - assert schema["type"] == "object" - assert "users" in schema["properties"] - assert schema["properties"]["users"]["type"] == "array" - assert schema["properties"]["users"]["items"]["type"] == "object" - assert schema["properties"]["users"]["items"]["properties"]["name"]["type"] == "string" - assert schema["properties"]["users"]["items"]["properties"]["age"]["type"] == "integer" - assert schema["properties"]["users"]["items"]["properties"]["planet"]["type"] == "string" - assert schema["properties"]["users"]["items"]["properties"]["planet"]["enum"] == ["Earth", "Mars"] - - # Check the list field properties - assert schema["properties"]["users"]["minItems"] == 2 - assert schema["properties"]["users"]["maxItems"] == 3 - - # Verify the required fields - assert "required" in schema - assert "users" in schema["required"] + expected_spec = { + "name": "ListOfUsersWithPlanet", + "description": "List of users model with planet.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "users": { + "description": "The users", + "items": { + "description": "User with planet.", + "title": "UserWithPlanet", + "type": "object", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + "planet": { + "description": "The planet", + "enum": ["Earth", "Mars"], + "title": "Planet", + "type": "string", + }, + }, + "required": ["name", "age", "planet"], + }, + "maxItems": 3, + "minItems": 2, + "title": "Users", + "type": "array", + } + }, + "title": "ListOfUsersWithPlanet", + "description": "List of users model with planet.", + "required": ["users"], + } + }, + } + + assert tool_spec == expected_spec # Verify we can construct a valid ToolSpec tool_spec_obj = ToolSpec(**tool_spec) @@ -110,27 +126,68 @@ def test_convert_pydantic_to_tool_spec_complex(): def test_convert_pydantic_to_tool_spec_multiple_same_type(): tool_spec = convert_pydantic_to_tool_spec(TwoUsersWithPlanet) - # Verify the schema structure - assert tool_spec["name"] == "TwoUsersWithPlanet" - assert "user1" in tool_spec["inputSchema"]["json"]["properties"] - assert "user2" in tool_spec["inputSchema"]["json"]["properties"] + expected_spec = { + "name": "TwoUsersWithPlanet", + "description": "Two users model with planet.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "user1": { + "type": "object", + "description": "The first user", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + "planet": { + "description": "The planet", + "enum": ["Earth", "Mars"], + "title": "Planet", + "type": "string", + }, + }, + "required": ["name", "age", "planet"], + }, + "user2": { + "type": ["object", "null"], + "description": "The second user", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + "planet": { + "description": "The planet", + "enum": ["Earth", "Mars"], + "title": "Planet", + "type": "string", + }, + }, + "required": ["name", 
"age", "planet"], + }, + }, + "title": "TwoUsersWithPlanet", + "description": "Two users model with planet.", + "required": ["user1"], + } + }, + } + + assert tool_spec == expected_spec - # Verify both employment fields have the same structure - primary = tool_spec["inputSchema"]["json"]["properties"]["user1"] - secondary = tool_spec["inputSchema"]["json"]["properties"]["user2"] - - assert primary["type"] == "object" - assert secondary["type"] == ["object", "null"] - - assert "properties" in primary - assert "name" in primary["properties"] - assert "age" in primary["properties"] - assert "planet" in primary["properties"] - - assert "properties" in secondary - assert "name" in secondary["properties"] - assert "age" in secondary["properties"] - assert "planet" in secondary["properties"] + # Verify we can construct a valid ToolSpec + tool_spec_obj = ToolSpec(**tool_spec) + assert tool_spec_obj is not None def test_convert_pydantic_with_missing_refs():