Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as

Alternatively, install development dependencies in a manually created virtual environment:
```bash
pip install -e ".[dev]" && pip install -e ".[litellm]"
pip install -e ".[all]"
```


Expand Down
32 changes: 21 additions & 11 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,19 @@
import logging
import random
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Mapping,
Optional,
Type,
TypeAlias,
TypeVar,
Union,
cast,
)

from opentelemetry import trace as trace_api
from pydantic import BaseModel
Expand Down Expand Up @@ -55,6 +67,8 @@
# TypeVar for generic structured output
T = TypeVar("T", bound=BaseModel)

AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None


# Sentinel class and object to distinguish between explicit None and default parameter value
class _DefaultCallbackHandlerSentinel:
Expand Down Expand Up @@ -361,7 +375,7 @@ def tool_names(self) -> list[str]:
all_tools = self.tool_registry.get_all_tools_config()
return list(all_tools.keys())

def __call__(self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any) -> AgentResult:
def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.

This method implements the conversational interface with multiple input patterns:
Expand Down Expand Up @@ -394,9 +408,7 @@ def execute() -> AgentResult:
future = executor.submit(execute)
return future.result()

async def invoke_async(
self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any
) -> AgentResult:
async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.

This method implements the conversational interface with multiple input patterns:
Expand Down Expand Up @@ -427,7 +439,7 @@ async def invoke_async(

return cast(AgentResult, event["result"])

def structured_output(self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None) -> T:
def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T:
"""This method allows you to get structured output from the agent.

If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
Expand Down Expand Up @@ -456,9 +468,7 @@ def execute() -> T:
future = executor.submit(execute)
return future.result()

async def structured_output_async(
self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None
) -> T:
async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T:
"""This method allows you to get structured output from the agent.

If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
Expand Down Expand Up @@ -517,7 +527,7 @@ async def structured_output_async(

async def stream_async(
self,
prompt: str | list[ContentBlock] | Messages | None = None,
prompt: AgentInput = None,
**kwargs: Any,
) -> AsyncIterator[Any]:
"""Process a natural language prompt and yield events as an async iterator.
Expand Down Expand Up @@ -657,7 +667,7 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A
async for event in events:
yield event

def _convert_prompt_to_messages(self, prompt: str | list[ContentBlock] | Messages | None) -> Messages:
def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
messages: Messages | None = None
if prompt is not None:
if isinstance(prompt, str):
Expand Down
2 changes: 1 addition & 1 deletion src/strands/session/file_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) ->
"""
if not isinstance(message_id, int):
raise ValueError(f"message_id=<{message_id}> | message id must be an integer")

agent_path = self._get_agent_path(session_id, agent_id)
return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json")

Expand Down
4 changes: 2 additions & 2 deletions src/strands/session/s3_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) ->

Returns:
The key for the message

Raises:
ValueError: If message_id is not an integer.
"""
if not isinstance(message_id, int):
raise ValueError(f"message_id=<{message_id}> | message id must be an integer")

agent_path = self._get_agent_path(session_id, agent_id)
return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json"

Expand Down
2 changes: 1 addition & 1 deletion tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1830,6 +1830,7 @@ def test_agent_with_list_of_message_and_content_block():
with pytest.raises(ValueError, match="Input prompt must be of type: `str | list[Contentblock] | Messages | None`."):
agent([{"role": "user", "content": [{"text": "hello"}]}, {"text", "hello"}])


def test_agent_tool_call_parameter_filtering_integration(mock_randint):
"""Test that tool calls properly filter parameters in message recording."""
mock_randint.return_value = 42
Expand Down Expand Up @@ -1861,4 +1862,3 @@ def test_tool(action: str) -> str:
assert '"action": "test_value"' in tool_call_text
assert '"agent"' not in tool_call_text
assert '"extra_param"' not in tool_call_text

2 changes: 1 addition & 1 deletion tests/strands/session/test_file_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def test__get_agent_path_invalid_agent_id(agent_id, file_manager):
"message_id",
[
"../../../secret",
"../../attack",
"../../attack",
"../escape",
"path/traversal",
"not_an_int",
Expand Down
2 changes: 1 addition & 1 deletion tests/strands/session/test_s3_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def test__get_agent_path_invalid_agent_id(agent_id, s3_manager):
"message_id",
[
"../../../secret",
"../../attack",
"../../attack",
"../escape",
"path/traversal",
"not_an_int",
Expand Down
Loading