Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import random
from concurrent.futures import ThreadPoolExecutor
from threading import Thread
from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union
from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union, cast
from uuid import uuid4

from opentelemetry import trace
Expand Down Expand Up @@ -423,7 +423,12 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None)
messages.append({"role": "user", "content": [{"text": prompt}]})

# get the structured output from the model
return self.model.structured_output(output_model, messages, self.callback_handler)
events = self.model.structured_output(output_model, messages)
for event in events:
if "callback" in event:
self.callback_handler(**cast(dict, event["callback"]))

return event["output"]

async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
"""Process a natural language prompt and yield events as an async iterator.
Expand Down
21 changes: 10 additions & 11 deletions src/strands/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
import json
import logging
import mimetypes
from typing import Any, Callable, Iterable, Optional, Type, TypedDict, TypeVar, cast
from typing import Any, Generator, Iterable, Optional, Type, TypedDict, TypeVar, Union, cast

import anthropic
from pydantic import BaseModel
from typing_extensions import Required, Unpack, override

from ..event_loop.streaming import process_stream
from ..handlers.callback_handler import PrintingCallbackHandler
from ..tools import convert_pydantic_to_tool_spec
from ..types.content import ContentBlock, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
Expand Down Expand Up @@ -378,24 +377,24 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:

@override
def structured_output(
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
) -> T:
self, output_model: Type[T], prompt: Messages
) -> Generator[dict[str, Union[T, Any]], None, None]:
"""Get structured output from the model.

Args:
output_model(Type[BaseModel]): The output model to use for the agent.
prompt(Messages): The prompt messages to use for the agent.
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.

Yields:
Model events with the last being the structured output.
"""
callback_handler = callback_handler or PrintingCallbackHandler()
tool_spec = convert_pydantic_to_tool_spec(output_model)

response = self.converse(messages=prompt, tool_specs=[tool_spec])
for event in process_stream(response, prompt):
if "callback" in event:
callback_handler(**event["callback"])
else:
stop_reason, messages, _, _ = event["stop"]
yield event

stop_reason, messages, _, _ = event["stop"]

if stop_reason != "tool_use":
raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
Expand All @@ -413,4 +412,4 @@ def structured_output(
if output_response is None:
raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")

return output_model(**output_response)
yield {"output": output_model(**output_response)}
21 changes: 10 additions & 11 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
import logging
import os
from typing import Any, Callable, Iterable, List, Literal, Optional, Type, TypeVar, cast
from typing import Any, Generator, Iterable, List, Literal, Optional, Type, TypeVar, Union, cast

import boto3
from botocore.config import Config as BotocoreConfig
Expand All @@ -15,7 +15,6 @@
from typing_extensions import TypedDict, Unpack, override

from ..event_loop.streaming import process_stream
from ..handlers.callback_handler import PrintingCallbackHandler
from ..tools import convert_pydantic_to_tool_spec
from ..types.content import Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
Expand Down Expand Up @@ -521,24 +520,24 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool:

@override
def structured_output(
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
) -> T:
self, output_model: Type[T], prompt: Messages
) -> Generator[dict[str, Union[T, Any]], None, None]:
"""Get structured output from the model.

Args:
output_model(Type[BaseModel]): The output model to use for the agent.
prompt(Messages): The prompt messages to use for the agent.
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.

Yields:
Model events with the last being the structured output.
"""
callback_handler = callback_handler or PrintingCallbackHandler()
tool_spec = convert_pydantic_to_tool_spec(output_model)

response = self.converse(messages=prompt, tool_specs=[tool_spec])
for event in process_stream(response, prompt):
if "callback" in event:
callback_handler(**event["callback"])
else:
stop_reason, messages, _, _ = event["stop"]
yield event

stop_reason, messages, _, _ = event["stop"]

if stop_reason != "tool_use":
raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
Expand All @@ -556,4 +555,4 @@ def structured_output(
if output_response is None:
raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")

return output_model(**output_response)
yield {"output": output_model(**output_response)}
12 changes: 7 additions & 5 deletions src/strands/models/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import json
import logging
from typing import Any, Callable, Optional, Type, TypedDict, TypeVar, cast
from typing import Any, Generator, Optional, Type, TypedDict, TypeVar, Union, cast

import litellm
from litellm.utils import supports_response_schema
Expand Down Expand Up @@ -105,15 +105,16 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]

@override
def structured_output(
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
) -> T:
self, output_model: Type[T], prompt: Messages
) -> Generator[dict[str, Union[T, Any]], None, None]:
"""Get structured output from the model.

Args:
output_model(Type[BaseModel]): The output model to use for the agent.
prompt(Messages): The prompt messages to use for the agent.
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.

Yields:
Model events with the last being the structured output.
"""
# The LiteLLM `Client` inits with Chat().
# Chat() inits with self.completions
Expand All @@ -136,7 +137,8 @@ def structured_output(
# Parse the tool call content as JSON
tool_call_data = json.loads(choice.message.content)
# Instantiate the output model with the parsed data
return output_model(**tool_call_data)
yield {"output": output_model(**tool_call_data)}
return
except (json.JSONDecodeError, TypeError, ValueError) as e:
raise ValueError(f"Failed to parse or load content into model: {e}") from e

Expand Down
10 changes: 6 additions & 4 deletions src/strands/models/llamaapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json
import logging
import mimetypes
from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast
from typing import Any, Generator, Iterable, Optional, Type, TypeVar, Union, cast

import llama_api_client
from llama_api_client import LlamaAPIClient
Expand Down Expand Up @@ -390,14 +390,16 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:

@override
def structured_output(
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
) -> T:
self, output_model: Type[T], prompt: Messages
) -> Generator[dict[str, Union[T, Any]], None, None]:
"""Get structured output from the model.

Args:
output_model(Type[BaseModel]): The output model to use for the agent.
prompt(Messages): The prompt messages to use for the agent.
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.

Yields:
Model events with the last being the structured output.

Raises:
NotImplementedError: Structured output is not currently supported for LlamaAPI models.
Expand Down
12 changes: 7 additions & 5 deletions src/strands/models/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import json
import logging
from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast
from typing import Any, Generator, Iterable, Optional, Type, TypeVar, Union, cast

from ollama import Client as OllamaClient
from pydantic import BaseModel
Expand Down Expand Up @@ -316,14 +316,16 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:

@override
def structured_output(
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
) -> T:
self, output_model: Type[T], prompt: Messages
) -> Generator[dict[str, Union[T, Any]], None, None]:
"""Get structured output from the model.

Args:
output_model(Type[BaseModel]): The output model to use for the agent.
prompt(Messages): The prompt messages to use for the agent.
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.

Yields:
Model events with the last being the structured output.
"""
formatted_request = self.format_request(messages=prompt)
formatted_request["format"] = output_model.model_json_schema()
Expand All @@ -332,6 +334,6 @@ def structured_output(

try:
content = response.message.content.strip()
return output_model.model_validate_json(content)
yield {"output": output_model.model_validate_json(content)}
except Exception as e:
raise ValueError(f"Failed to parse or load content into model: {e}") from e
12 changes: 7 additions & 5 deletions src/strands/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import logging
from typing import Any, Callable, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, cast
from typing import Any, Generator, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast

import openai
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
Expand Down Expand Up @@ -133,14 +133,16 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:

@override
def structured_output(
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
) -> T:
self, output_model: Type[T], prompt: Messages
) -> Generator[dict[str, Union[T, Any]], None, None]:
"""Get structured output from the model.

Args:
output_model(Type[BaseModel]): The output model to use for the agent.
prompt(Messages): The prompt messages to use for the agent.
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.

Yields:
Model events with the last being the structured output.
"""
response: ParsedChatCompletion = self.client.beta.chat.completions.parse( # type: ignore
model=self.get_config()["model_id"],
Expand All @@ -159,6 +161,6 @@ def structured_output(
break

if parsed:
return parsed
yield {"output": parsed}
else:
raise ValueError("No valid tool use or tool use input was found in the OpenAI response.")
11 changes: 5 additions & 6 deletions src/strands/types/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import abc
import logging
from typing import Any, Callable, Iterable, Optional, Type, TypeVar
from typing import Any, Generator, Iterable, Optional, Type, TypeVar, Union

from pydantic import BaseModel

Expand Down Expand Up @@ -45,17 +45,16 @@ def get_config(self) -> Any:
@abc.abstractmethod
# pragma: no cover
def structured_output(
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
) -> T:
self, output_model: Type[T], prompt: Messages
) -> Generator[dict[str, Union[T, Any]], None, None]:
"""Get structured output from the model.

Args:
output_model(Type[BaseModel]): The output model to use for the agent.
prompt(Messages): The prompt messages to use for the agent.
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.

Returns:
The structured output as a serialized instance of the output model.
Yields:
Model events with the last being the structured output.

Raises:
ValidationException: The response format from the model does not match the output_model
Expand Down
12 changes: 7 additions & 5 deletions src/strands/types/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import json
import logging
import mimetypes
from typing import Any, Callable, Optional, Type, TypeVar, cast
from typing import Any, Generator, Optional, Type, TypeVar, Union, cast

from pydantic import BaseModel
from typing_extensions import override
Expand Down Expand Up @@ -295,13 +295,15 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:

@override
def structured_output(
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
) -> T:
self, output_model: Type[T], prompt: Messages
) -> Generator[dict[str, Union[T, Any]], None, None]:
"""Get structured output from the model.

Args:
output_model(Type[BaseModel]): The output model to use for the agent.
prompt(Messages): The prompt to use for the agent.
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.

Yields:
Model events with the last being the structured output.
"""
return output_model()
yield {"output": output_model()}
6 changes: 2 additions & 4 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,17 +898,15 @@ class User(BaseModel):
def test_agent_method_structured_output(agent):
# Mock the structured_output method on the model
expected_user = User(name="Jane Doe", age=30, email="[email protected]")
agent.model.structured_output = unittest.mock.Mock(return_value=expected_user)
agent.model.structured_output = unittest.mock.Mock(return_value=[{"output": expected_user}])

prompt = "Jane Doe is 30 years old and her email is [email protected]"

result = agent.structured_output(User, prompt)
assert result == expected_user

# Verify the model's structured_output was called with correct arguments
agent.model.structured_output.assert_called_once_with(
User, [{"role": "user", "content": [{"text": prompt}]}], agent.callback_handler
)
agent.model.structured_output.assert_called_once_with(User, [{"role": "user", "content": [{"text": prompt}]}])


@pytest.mark.asyncio
Expand Down
Loading