diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index fd857707c..ac5ddf40a 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -41,6 +41,7 @@ ConversationManager, SlidingWindowConversationManager, ) +from .state import AgentState logger = logging.getLogger(__name__) @@ -223,6 +224,7 @@ def __init__( *, name: Optional[str] = None, description: Optional[str] = None, + state: Optional[Union[AgentState, dict]] = None, ): """Initialize the Agent with the specified configuration. @@ -259,6 +261,8 @@ def __init__( Defaults to None. description: description of what the Agent does Defaults to None. + state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict. + Defaults to an empty AgentState object. Raises: ValueError: If max_parallel_tools is less than 1. @@ -319,6 +323,18 @@ def __init__( # Initialize tracer instance (no-op if not configured) self.tracer = get_tracer() self.trace_span: Optional[trace.Span] = None + + # Initialize agent state management + if state is not None: + if isinstance(state, dict): + self.state = AgentState(state) + elif isinstance(state, AgentState): + self.state = state + else: + raise ValueError("state must be an AgentState object or a dict") + else: + self.state = AgentState() + self.tool_caller = Agent.ToolCaller(self) self.name = name self.description = description diff --git a/src/strands/agent/state.py b/src/strands/agent/state.py new file mode 100644 index 000000000..36120b8ff --- /dev/null +++ b/src/strands/agent/state.py @@ -0,0 +1,97 @@ +"""Agent state management.""" + +import copy +import json +from typing import Any, Dict, Optional + + +class AgentState: + """Represents an Agent's stateful information outside of context provided to a model. + + Provides a key-value store for agent state with JSON serialization validation and persistence support. + Key features: + - JSON serialization validation on assignment + - Get/set/delete operations + """ + + def __init__(self, initial_state: Optional[Dict[str, Any]] = None): + """Initialize AgentState.""" + self._state: Dict[str, Dict[str, Any]] + if initial_state: + self._validate_json_serializable(initial_state) + self._state = copy.deepcopy(initial_state) + else: + self._state = {} + + def set(self, key: str, value: Any) -> None: + """Set a value in the state. + + Args: + key: The key to store the value under + value: The value to store (must be JSON serializable) + + Raises: + ValueError: If key is invalid, or if value is not JSON serializable + """ + self._validate_key(key) + self._validate_json_serializable(value) + + self._state[key] = copy.deepcopy(value) + + def get(self, key: Optional[str] = None) -> Any: + """Get a value or entire state. + + Args: + key: The key to retrieve (if None, returns entire state object) + + Returns: + The stored value, entire state dict, or None if not found + """ + if key is None: + return copy.deepcopy(self._state) + else: + # Return specific key + return copy.deepcopy(self._state.get(key)) + + def delete(self, key: str) -> None: + """Delete a specific key from the state. + + Args: + key: The key to delete + """ + self._validate_key(key) + + self._state.pop(key, None) + + def _validate_key(self, key: str) -> None: + """Validate that a key is valid. + + Args: + key: The key to validate + + Raises: + ValueError: If key is invalid + """ + if key is None: + raise ValueError("Key cannot be None") + if not isinstance(key, str): + raise ValueError("Key must be a string") + if not key.strip(): + raise ValueError("Key cannot be empty") + + def _validate_json_serializable(self, value: Any) -> None: + """Validate that a value is JSON serializable. + + Args: + value: The value to validate + + Raises: + ValueError: If value is not JSON serializable + """ + try: + json.dumps(value) + except (TypeError, ValueError) as e: + raise ValueError( + f"Value is not JSON serializable: {type(value).__name__}. " + f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed." + ) from e diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index c813a1a91..aeb149378 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1,5 +1,6 @@ import copy import importlib +import json import os import textwrap import threading @@ -1314,3 +1315,58 @@ def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_ kwargs = mock_event_loop_cycle.call_args[1] assert "event_loop_parent_span" in kwargs assert kwargs["event_loop_parent_span"] == mock_span + + +def test_non_dict_throws_error(): + with pytest.raises(ValueError, match="state must be an AgentState object or a dict"): + agent = Agent(state={"object", object()}) + print(agent.state) + + +def test_non_json_serializable_state_throws_error(): + with pytest.raises(ValueError, match="Value is not JSON serializable"): + agent = Agent(state={"object": object()}) + print(agent.state) + + +def test_agent_state_breaks_dict_reference(): + ref_dict = {"hello": "world"} + agent = Agent(state=ref_dict) + + # Make sure shallow object references do not affect state maintained by AgentState + ref_dict["hello"] = object() + + # This will fail if AgentState reflects the updated reference + json.dumps(agent.state.get()) + + +def test_agent_state_breaks_deep_dict_reference(): + ref_dict = {"world": "!"} + init_dict = {"hello": ref_dict} + agent = Agent(state=init_dict) + # Make sure deep reference changes do not affect state mained by AgentState + ref_dict["world"] = object() + + # This will fail if AgentState reflects the updated reference + json.dumps(agent.state.get()) + + +def test_agent_state_set_breaks_dict_reference(): + agent = Agent() + ref_dict = {"hello": "world"} + # Set should copy the input, and not maintain the reference to the original object + agent.state.set("hello", ref_dict) + ref_dict["hello"] = object() + + # This will fail if AgentState reflects the updated reference + json.dumps(agent.state.get()) + + +def test_agent_state_get_breaks_deep_dict_reference(): + agent = Agent(state={"hello": {"world": "!"}}) + # Get should not return a reference to the internal state + ref_state = agent.state.get() + ref_state["hello"]["world"] = object() + + # This will fail if AgentState reflects the updated reference + json.dumps(agent.state.get()) diff --git a/tests/strands/agent/test_agent_state.py b/tests/strands/agent/test_agent_state.py new file mode 100644 index 000000000..1921b0065 --- /dev/null +++ b/tests/strands/agent/test_agent_state.py @@ -0,0 +1,111 @@ +"""Tests for AgentState class.""" + +import pytest + +from strands.agent.state import AgentState + + +def test_set_and_get(): + """Test basic set and get operations.""" + state = AgentState() + state.set("key", "value") + assert state.get("key") == "value" + + +def test_get_nonexistent_key(): + """Test getting nonexistent key returns None.""" + state = AgentState() + assert state.get("nonexistent") is None + + +def test_get_entire_state(): + """Test getting entire state when no key specified.""" + state = AgentState() + state.set("key1", "value1") + state.set("key2", "value2") + + result = state.get() + assert result == {"key1": "value1", "key2": "value2"} + + +def test_initialize_and_get_entire_state(): + """Test getting entire state when no key specified.""" + state = AgentState({"key1": "value1", "key2": "value2"}) + + result = state.get() + assert result == {"key1": "value1", "key2": "value2"} + + +def test_initialize_with_error(): + with pytest.raises(ValueError, match="not JSON serializable"): + AgentState({"object", object()}) + + +def test_delete(): + """Test deleting keys.""" + state = AgentState() + state.set("key1", "value1") + state.set("key2", "value2") + + state.delete("key1") + + assert state.get("key1") is None + assert state.get("key2") == "value2" + + +def test_delete_nonexistent_key(): + """Test deleting nonexistent key doesn't raise error.""" + state = AgentState() + state.delete("nonexistent") # Should not raise + + +def test_json_serializable_values(): + """Test that only JSON-serializable values are accepted.""" + state = AgentState() + + # Valid JSON types + state.set("string", "test") + state.set("int", 42) + state.set("bool", True) + state.set("list", [1, 2, 3]) + state.set("dict", {"nested": "value"}) + state.set("null", None) + + # Invalid JSON types should raise ValueError + with pytest.raises(ValueError, match="not JSON serializable"): + state.set("function", lambda x: x) + + with pytest.raises(ValueError, match="not JSON serializable"): + state.set("object", object()) + + +def test_key_validation(): + """Test key validation for set and delete operations.""" + state = AgentState() + + # Invalid keys for set + with pytest.raises(ValueError, match="Key cannot be None"): + state.set(None, "value") + + with pytest.raises(ValueError, match="Key cannot be empty"): + state.set("", "value") + + with pytest.raises(ValueError, match="Key must be a string"): + state.set(123, "value") + + # Invalid keys for delete + with pytest.raises(ValueError, match="Key cannot be None"): + state.delete(None) + + with pytest.raises(ValueError, match="Key cannot be empty"): + state.delete("") + + +def test_initial_state(): + """Test initialization with initial state.""" + initial = {"key1": "value1", "key2": "value2"} + state = AgentState(initial_state=initial) + + assert state.get("key1") == "value1" + assert state.get("key2") == "value2" + assert state.get() == initial diff --git a/tests/strands/mocked_model_provider/__init__.py b/tests/strands/mocked_model_provider/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/mocked_model_provider/mocked_model_provider.py b/tests/strands/mocked_model_provider/mocked_model_provider.py new file mode 100644 index 000000000..f89d56202 --- /dev/null +++ b/tests/strands/mocked_model_provider/mocked_model_provider.py @@ -0,0 +1,73 @@ +import json +from typing import Any, Callable, Iterable, Optional, Type, TypeVar + +from pydantic import BaseModel + +from strands.types.content import Message, Messages +from strands.types.event_loop import StopReason +from strands.types.models.model import Model +from strands.types.streaming import StreamEvent +from strands.types.tools import ToolSpec + +T = TypeVar("T", bound=BaseModel) + + +class MockedModelProvider(Model): + """A mock implementation of the Model interface for testing purposes. + + This class simulates a model provider by returning pre-defined agent responses + in sequence. It implements the Model interface methods and provides functionality + to stream mock responses as events. + """ + + def __init__(self, agent_responses: Messages): + self.agent_responses = agent_responses + self.index = 0 + + def format_chunk(self, event: Any) -> StreamEvent: + return event + + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> Any: + return None + + def get_config(self) -> Any: + pass + + def update_config(self, **model_config: Any) -> None: + pass + + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + pass + + def stream(self, request: Any) -> Iterable[Any]: + yield from self.map_agent_message_to_events(self.agent_responses[self.index]) + self.index += 1 + + def map_agent_message_to_events(self, agent_message: Message) -> Iterable[dict[str, Any]]: + stop_reason: StopReason = "end_turn" + yield {"messageStart": {"role": "assistant"}} + for content in agent_message["content"]: + if "text" in content: + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": content["text"]}}} + yield {"contentBlockStop": {}} + if "toolUse" in content: + stop_reason = "tool_use" + yield { + "contentBlockStart": { + "start": { + "toolUse": { + "name": content["toolUse"]["name"], + "toolUseId": content["toolUse"]["toolUseId"], + } + } + } + } + yield {"contentBlockDelta": {"delta": {"tool_use": {"input": json.dumps(content["toolUse"]["input"])}}}} + yield {"contentBlockStop": {}} + + yield {"messageStop": {"stopReason": stop_reason}} diff --git a/tests/strands/mocked_model_provider/test_agent_state_updates.py b/tests/strands/mocked_model_provider/test_agent_state_updates.py new file mode 100644 index 000000000..c15c61961 --- /dev/null +++ b/tests/strands/mocked_model_provider/test_agent_state_updates.py @@ -0,0 +1,36 @@ +from strands.agent.agent import Agent +from strands.tools.decorator import tool +from strands.types.content import Messages + +from .mocked_model_provider import MockedModelProvider + + +@tool +def update_state(agent: Agent): + agent.state.set("hello", "world") + agent.state.set("foo", "baz") + + +def test_agent_state_update_from_tool(): + agent_messages: Messages = [ + { + "role": "assistant", + "content": [{"toolUse": {"name": "update_state", "toolUseId": "123", "input": {}}}], + }, + {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, + ] + mocked_model_provider = MockedModelProvider(agent_messages) + + agent = Agent( + model=mocked_model_provider, + tools=[update_state], + state={"foo": "bar"}, + ) + + assert agent.state.get("hello") is None + assert agent.state.get("foo") == "bar" + + agent("Invoke Mocked!") + + assert agent.state.get("hello") == "world" + assert agent.state.get("foo") == "baz"