From 9b39d736e43ba8be1a3a575b9bb3610c9107f63e Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Mon, 25 Aug 2025 13:20:03 -0400 Subject: [PATCH] fix: Add AgentInput TypeAlias --- CONTRIBUTING.md | 2 +- src/strands/agent/agent.py | 32 ++++++++++++------- src/strands/session/file_session_manager.py | 2 +- src/strands/session/s3_session_manager.py | 4 +-- tests/strands/agent/test_agent.py | 2 +- .../session/test_file_session_manager.py | 2 +- .../session/test_s3_session_manager.py | 2 +- 7 files changed, 28 insertions(+), 18 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index add4825fd..93970ed64 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -49,7 +49,7 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as Alternatively, install development dependencies in a manually created virtual environment: ```bash - pip install -e ".[dev]" && pip install -e ".[litellm]" + pip install -e ".[all]" ``` diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 654b8edce..66099cb1d 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -14,7 +14,19 @@ import logging import random from concurrent.futures import ThreadPoolExecutor -from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + Callable, + Mapping, + Optional, + Type, + TypeAlias, + TypeVar, + Union, + cast, +) from opentelemetry import trace as trace_api from pydantic import BaseModel @@ -55,6 +67,8 @@ # TypeVar for generic structured output T = TypeVar("T", bound=BaseModel) +AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None + # Sentinel class and object to distinguish between explicit None and default parameter value class _DefaultCallbackHandlerSentinel: @@ -361,7 +375,7 @@ def tool_names(self) -> list[str]: all_tools = self.tool_registry.get_all_tools_config() return list(all_tools.keys()) - def __call__(self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any) -> AgentResult: + def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: """Process a natural language prompt through the agent's event loop. This method implements the conversational interface with multiple input patterns: @@ -394,9 +408,7 @@ def execute() -> AgentResult: future = executor.submit(execute) return future.result() - async def invoke_async( - self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any - ) -> AgentResult: + async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: """Process a natural language prompt through the agent's event loop. This method implements the conversational interface with multiple input patterns: @@ -427,7 +439,7 @@ async def invoke_async( return cast(AgentResult, event["result"]) - def structured_output(self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None) -> T: + def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. If you pass in a prompt, it will be used temporarily without adding it to the conversation history. @@ -456,9 +468,7 @@ def execute() -> T: future = executor.submit(execute) return future.result() - async def structured_output_async( - self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None - ) -> T: + async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. If you pass in a prompt, it will be used temporarily without adding it to the conversation history. @@ -517,7 +527,7 @@ async def structured_output_async( async def stream_async( self, - prompt: str | list[ContentBlock] | Messages | None = None, + prompt: AgentInput = None, **kwargs: Any, ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -657,7 +667,7 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A async for event in events: yield event - def _convert_prompt_to_messages(self, prompt: str | list[ContentBlock] | Messages | None) -> Messages: + def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: messages: Messages | None = None if prompt is not None: if isinstance(prompt, str): diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 14e71d07c..491f7ad60 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -92,7 +92,7 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> """ if not isinstance(message_id, int): raise ValueError(f"message_id=<{message_id}> | message id must be an integer") - + agent_path = self._get_agent_path(session_id, agent_id) return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json") diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index da1735e35..c6ce28d80 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -116,13 +116,13 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> Returns: The key for the message - + Raises: ValueError: If message_id is not an integer. """ if not isinstance(message_id, int): raise ValueError(f"message_id=<{message_id}> | message id must be an integer") - + agent_path = self._get_agent_path(session_id, agent_id) return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 01d8f977e..67ea5940a 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1830,6 +1830,7 @@ def test_agent_with_list_of_message_and_content_block(): with pytest.raises(ValueError, match="Input prompt must be of type: `str | list[Contentblock] | Messages | None`."): agent([{"role": "user", "content": [{"text": "hello"}]}, {"text", "hello"}]) + def test_agent_tool_call_parameter_filtering_integration(mock_randint): """Test that tool calls properly filter parameters in message recording.""" mock_randint.return_value = 42 @@ -1861,4 +1862,3 @@ def test_tool(action: str) -> str: assert '"action": "test_value"' in tool_call_text assert '"agent"' not in tool_call_text assert '"extra_param"' not in tool_call_text - diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index 036591924..f124ddf58 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -396,7 +396,7 @@ def test__get_agent_path_invalid_agent_id(agent_id, file_manager): "message_id", [ "../../../secret", - "../../attack", + "../../attack", "../escape", "path/traversal", "not_an_int", diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index 50fb303f7..c4d6a0154 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -362,7 +362,7 @@ def test__get_agent_path_invalid_agent_id(agent_id, s3_manager): "message_id", [ "../../../secret", - "../../attack", + "../../attack", "../escape", "path/traversal", "not_an_int",