From d0f627644b4e3be74b237400f3c57b8a2bfa925a Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Wed, 18 Jun 2025 22:10:11 +0000 Subject: [PATCH 1/6] feat: Add Agent State --- src/strands/agent/agent.py | 8 ++ src/strands/agent/state.py | 96 +++++++++++++++ tests/strands/agent/test_agent_state.py | 111 ++++++++++++++++++ .../strands/mocked_model_provider/__init__.py | 0 .../mocked_model_provider.py | 73 ++++++++++++ .../test_agent_state_updates.py | 29 +++++ 6 files changed, 317 insertions(+) create mode 100644 src/strands/agent/state.py create mode 100644 tests/strands/agent/test_agent_state.py create mode 100644 tests/strands/mocked_model_provider/__init__.py create mode 100644 tests/strands/mocked_model_provider/mocked_model_provider.py create mode 100644 tests/strands/mocked_model_provider/test_agent_state_updates.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index fd857707c..f178ee52a 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -41,6 +41,7 @@ ConversationManager, SlidingWindowConversationManager, ) +from .state import AgentState logger = logging.getLogger(__name__) @@ -223,6 +224,7 @@ def __init__( *, name: Optional[str] = None, description: Optional[str] = None, + state: Optional[AgentState] = None, ): """Initialize the Agent with the specified configuration. @@ -259,6 +261,8 @@ def __init__( Defaults to None. description: description of what the Agent does Defaults to None. + state: stateful information for the agent + Defaults to an empty AgentState object. Raises: ValueError: If max_parallel_tools is less than 1. @@ -319,6 +323,10 @@ def __init__( # Initialize tracer instance (no-op if not configured) self.tracer = get_tracer() self.trace_span: Optional[trace.Span] = None + + # Initialize agent state management + self.state = state or AgentState() + self.tool_caller = Agent.ToolCaller(self) self.name = name self.description = description diff --git a/src/strands/agent/state.py b/src/strands/agent/state.py new file mode 100644 index 000000000..e2fbf425a --- /dev/null +++ b/src/strands/agent/state.py @@ -0,0 +1,96 @@ +"""Agent state management.""" + +import json +from typing import Any, Dict, Optional + + +class AgentState: + """Represents an Agent's stateful information outside of context provided to a model. + + Provides a key-value store for agent state with JSON serialization validation and persistence support. + Key features: + - JSON serialization validation on assignment + - Get/set/delete operations + """ + + def __init__(self, initial_state: Optional[Dict[str, Dict[str, Any]]] = None): + """Initialize AgentState with default and SDK namespaces.""" + self._state: Dict[str, Dict[str, Any]] + if initial_state: + self._validate_json_serializable(initial_state) + self._state = initial_state.copy() + else: + self._state = {} + + def set(self, key: str, value: Any) -> None: + """Set a value in the state. + + Args: + key: The key to store the value under + value: The value to store (must be JSON serializable) + + Raises: + ValueError: If key is invalid, or if value is not JSON serializable + """ + self._validate_key(key) + self._validate_json_serializable(value) + + self._state[key] = value + + def get(self, key: Optional[str] = None) -> Any: + """Get a value or entire state. + + Args: + key: The key to retrieve (if None, returns entire state object) + + Returns: + The stored value, entire state dict, or None if not found + """ + if key is None: + return self._state.copy() + else: + # Return specific key + return self._state.get(key) + + def delete(self, key: str) -> None: + """Delete a specific key from the state. + + Args: + key: The key to delete + """ + self._validate_key(key) + + self._state.pop(key, None) + + def _validate_key(self, key: str) -> None: + """Validate that a key is valid. + + Args: + key: The key to validate + + Raises: + ValueError: If key is invalid + """ + if key is None: + raise ValueError("Key cannot be None") + if not isinstance(key, str): + raise ValueError("Key must be a string") + if not key.strip(): + raise ValueError("Key cannot be empty") + + def _validate_json_serializable(self, value: Any) -> None: + """Validate that a value is JSON serializable. + + Args: + value: The value to validate + + Raises: + ValueError: If value is not JSON serializable + """ + try: + json.dumps(value) + except (TypeError, ValueError) as e: + raise ValueError( + f"Value is not JSON serializable: {type(value).__name__}. " + f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed." + ) from e diff --git a/tests/strands/agent/test_agent_state.py b/tests/strands/agent/test_agent_state.py new file mode 100644 index 000000000..1921b0065 --- /dev/null +++ b/tests/strands/agent/test_agent_state.py @@ -0,0 +1,111 @@ +"""Tests for AgentState class.""" + +import pytest + +from strands.agent.state import AgentState + + +def test_set_and_get(): + """Test basic set and get operations.""" + state = AgentState() + state.set("key", "value") + assert state.get("key") == "value" + + +def test_get_nonexistent_key(): + """Test getting nonexistent key returns None.""" + state = AgentState() + assert state.get("nonexistent") is None + + +def test_get_entire_state(): + """Test getting entire state when no key specified.""" + state = AgentState() + state.set("key1", "value1") + state.set("key2", "value2") + + result = state.get() + assert result == {"key1": "value1", "key2": "value2"} + + +def test_initialize_and_get_entire_state(): + """Test getting entire state when no key specified.""" + state = AgentState({"key1": "value1", "key2": "value2"}) + + result = state.get() + assert result == {"key1": "value1", "key2": "value2"} + + +def test_initialize_with_error(): + with pytest.raises(ValueError, match="not JSON serializable"): + AgentState({"object", object()}) + + +def test_delete(): + """Test deleting keys.""" + state = AgentState() + state.set("key1", "value1") + state.set("key2", "value2") + + state.delete("key1") + + assert state.get("key1") is None + assert state.get("key2") == "value2" + + +def test_delete_nonexistent_key(): + """Test deleting nonexistent key doesn't raise error.""" + state = AgentState() + state.delete("nonexistent") # Should not raise + + +def test_json_serializable_values(): + """Test that only JSON-serializable values are accepted.""" + state = AgentState() + + # Valid JSON types + state.set("string", "test") + state.set("int", 42) + state.set("bool", True) + state.set("list", [1, 2, 3]) + state.set("dict", {"nested": "value"}) + state.set("null", None) + + # Invalid JSON types should raise ValueError + with pytest.raises(ValueError, match="not JSON serializable"): + state.set("function", lambda x: x) + + with pytest.raises(ValueError, match="not JSON serializable"): + state.set("object", object()) + + +def test_key_validation(): + """Test key validation for set and delete operations.""" + state = AgentState() + + # Invalid keys for set + with pytest.raises(ValueError, match="Key cannot be None"): + state.set(None, "value") + + with pytest.raises(ValueError, match="Key cannot be empty"): + state.set("", "value") + + with pytest.raises(ValueError, match="Key must be a string"): + state.set(123, "value") + + # Invalid keys for delete + with pytest.raises(ValueError, match="Key cannot be None"): + state.delete(None) + + with pytest.raises(ValueError, match="Key cannot be empty"): + state.delete("") + + +def test_initial_state(): + """Test initialization with initial state.""" + initial = {"key1": "value1", "key2": "value2"} + state = AgentState(initial_state=initial) + + assert state.get("key1") == "value1" + assert state.get("key2") == "value2" + assert state.get() == initial diff --git a/tests/strands/mocked_model_provider/__init__.py b/tests/strands/mocked_model_provider/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/mocked_model_provider/mocked_model_provider.py b/tests/strands/mocked_model_provider/mocked_model_provider.py new file mode 100644 index 000000000..f89d56202 --- /dev/null +++ b/tests/strands/mocked_model_provider/mocked_model_provider.py @@ -0,0 +1,73 @@ +import json +from typing import Any, Callable, Iterable, Optional, Type, TypeVar + +from pydantic import BaseModel + +from strands.types.content import Message, Messages +from strands.types.event_loop import StopReason +from strands.types.models.model import Model +from strands.types.streaming import StreamEvent +from strands.types.tools import ToolSpec + +T = TypeVar("T", bound=BaseModel) + + +class MockedModelProvider(Model): + """A mock implementation of the Model interface for testing purposes. + + This class simulates a model provider by returning pre-defined agent responses + in sequence. It implements the Model interface methods and provides functionality + to stream mock responses as events. + """ + + def __init__(self, agent_responses: Messages): + self.agent_responses = agent_responses + self.index = 0 + + def format_chunk(self, event: Any) -> StreamEvent: + return event + + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> Any: + return None + + def get_config(self) -> Any: + pass + + def update_config(self, **model_config: Any) -> None: + pass + + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + pass + + def stream(self, request: Any) -> Iterable[Any]: + yield from self.map_agent_message_to_events(self.agent_responses[self.index]) + self.index += 1 + + def map_agent_message_to_events(self, agent_message: Message) -> Iterable[dict[str, Any]]: + stop_reason: StopReason = "end_turn" + yield {"messageStart": {"role": "assistant"}} + for content in agent_message["content"]: + if "text" in content: + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": content["text"]}}} + yield {"contentBlockStop": {}} + if "toolUse" in content: + stop_reason = "tool_use" + yield { + "contentBlockStart": { + "start": { + "toolUse": { + "name": content["toolUse"]["name"], + "toolUseId": content["toolUse"]["toolUseId"], + } + } + } + } + yield {"contentBlockDelta": {"delta": {"tool_use": {"input": json.dumps(content["toolUse"]["input"])}}}} + yield {"contentBlockStop": {}} + + yield {"messageStop": {"stopReason": stop_reason}} diff --git a/tests/strands/mocked_model_provider/test_agent_state_updates.py b/tests/strands/mocked_model_provider/test_agent_state_updates.py new file mode 100644 index 000000000..34750db1e --- /dev/null +++ b/tests/strands/mocked_model_provider/test_agent_state_updates.py @@ -0,0 +1,29 @@ +from strands.agent.agent import Agent +from strands.tools.decorator import tool +from strands.types.content import Messages + +from .mocked_model_provider import MockedModelProvider + + +@tool +def update_state(agent: Agent): + agent.state.set("hello", "world") + + +def test_agent_state_update_from_tool(): + agent_messages: Messages = [ + { + "role": "assistant", + "content": [{"toolUse": {"name": "update_state", "toolUseId": "123", "input": {}}}], + }, + {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, + ] + mocked_model_provider = MockedModelProvider(agent_messages) + + agent = Agent(model=mocked_model_provider, tools=[update_state]) + + assert agent.state.get("hello") is None + + agent("Invoke Mocked!") + + assert agent.state.get("hello") == "world" From 49f0f49949a3ab61837ce3ecd06ea5c49cb29e58 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Wed, 25 Jun 2025 17:14:13 -0400 Subject: [PATCH 2/6] Update state.py --- src/strands/agent/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/agent/state.py b/src/strands/agent/state.py index e2fbf425a..87b11feef 100644 --- a/src/strands/agent/state.py +++ b/src/strands/agent/state.py @@ -14,7 +14,7 @@ class AgentState: """ def __init__(self, initial_state: Optional[Dict[str, Dict[str, Any]]] = None): - """Initialize AgentState with default and SDK namespaces.""" + """Initialize AgentState.""" self._state: Dict[str, Dict[str, Any]] if initial_state: self._validate_json_serializable(initial_state) From 8ac6cd44920e770adf2c522e2acd29cc27a06ac9 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Thu, 26 Jun 2025 23:04:28 +0000 Subject: [PATCH 3/6] Allow dict input for state --- src/strands/agent/agent.py | 16 +++++++++++++--- tests/strands/agent/test_agent.py | 12 ++++++++++++ .../test_agent_state_updates.py | 9 ++++++++- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f178ee52a..1cfa2b731 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -224,7 +224,7 @@ def __init__( *, name: Optional[str] = None, description: Optional[str] = None, - state: Optional[AgentState] = None, + state: Optional[Union[AgentState, dict]] = None, ): """Initialize the Agent with the specified configuration. @@ -261,7 +261,7 @@ def __init__( Defaults to None. description: description of what the Agent does Defaults to None. - state: stateful information for the agent + state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict. Defaults to an empty AgentState object. Raises: @@ -325,7 +325,17 @@ def __init__( self.trace_span: Optional[trace.Span] = None # Initialize agent state management - self.state = state or AgentState() + if state is not None: + if isinstance(state, dict): + self.state = AgentState(state) + elif isinstance(state, AgentState): + print("HERE!") + print(type(state)) + self.state = state + else: + raise ValueError("state must be an AgentState object or a dict") + else: + self.state = AgentState() self.tool_caller = Agent.ToolCaller(self) self.name = name diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index c813a1a91..692e7c018 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1314,3 +1314,15 @@ def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_ kwargs = mock_event_loop_cycle.call_args[1] assert "event_loop_parent_span" in kwargs assert kwargs["event_loop_parent_span"] == mock_span + + +def test_non_dict_throws_error(): + with pytest.raises(ValueError, match="state must be an AgentState object or a dict"): + agent = Agent(state={"object", object()}) + print(agent.state) + + +def test_non_json_serializable_state_throws_error(): + with pytest.raises(ValueError, match="Value is not JSON serializable"): + agent = Agent(state={"object": object()}) + print(agent.state) diff --git a/tests/strands/mocked_model_provider/test_agent_state_updates.py b/tests/strands/mocked_model_provider/test_agent_state_updates.py index 34750db1e..c15c61961 100644 --- a/tests/strands/mocked_model_provider/test_agent_state_updates.py +++ b/tests/strands/mocked_model_provider/test_agent_state_updates.py @@ -8,6 +8,7 @@ @tool def update_state(agent: Agent): agent.state.set("hello", "world") + agent.state.set("foo", "baz") def test_agent_state_update_from_tool(): @@ -20,10 +21,16 @@ def test_agent_state_update_from_tool(): ] mocked_model_provider = MockedModelProvider(agent_messages) - agent = Agent(model=mocked_model_provider, tools=[update_state]) + agent = Agent( + model=mocked_model_provider, + tools=[update_state], + state={"foo": "bar"}, + ) assert agent.state.get("hello") is None + assert agent.state.get("foo") == "bar" agent("Invoke Mocked!") assert agent.state.get("hello") == "world" + assert agent.state.get("foo") == "baz" From 9862adb76ac267b7e3b2c124e3849defc1dcc716 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Fri, 27 Jun 2025 11:15:56 -0400 Subject: [PATCH 4/6] Update src/strands/agent/agent.py Co-authored-by: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> --- src/strands/agent/agent.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 1cfa2b731..ac5ddf40a 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -329,8 +329,6 @@ def __init__( if isinstance(state, dict): self.state = AgentState(state) elif isinstance(state, AgentState): - print("HERE!") - print(type(state)) self.state = state else: raise ValueError("state must be an AgentState object or a dict") From 56f53f37be1ff1a56c4802fd74c9716015b0dcee Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Fri, 27 Jun 2025 15:45:01 +0000 Subject: [PATCH 5/6] fix: deepcopy AgentState --- src/strands/agent/state.py | 11 +++++----- tests/strands/agent/test_agent.py | 35 +++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/src/strands/agent/state.py b/src/strands/agent/state.py index 87b11feef..36120b8ff 100644 --- a/src/strands/agent/state.py +++ b/src/strands/agent/state.py @@ -1,5 +1,6 @@ """Agent state management.""" +import copy import json from typing import Any, Dict, Optional @@ -13,12 +14,12 @@ class AgentState: - Get/set/delete operations """ - def __init__(self, initial_state: Optional[Dict[str, Dict[str, Any]]] = None): + def __init__(self, initial_state: Optional[Dict[str, Any]] = None): """Initialize AgentState.""" self._state: Dict[str, Dict[str, Any]] if initial_state: self._validate_json_serializable(initial_state) - self._state = initial_state.copy() + self._state = copy.deepcopy(initial_state) else: self._state = {} @@ -35,7 +36,7 @@ def set(self, key: str, value: Any) -> None: self._validate_key(key) self._validate_json_serializable(value) - self._state[key] = value + self._state[key] = copy.deepcopy(value) def get(self, key: Optional[str] = None) -> Any: """Get a value or entire state. @@ -47,10 +48,10 @@ def get(self, key: Optional[str] = None) -> Any: The stored value, entire state dict, or None if not found """ if key is None: - return self._state.copy() + return copy.deepcopy(self._state) else: # Return specific key - return self._state.get(key) + return copy.deepcopy(self._state.get(key)) def delete(self, key: str) -> None: """Delete a specific key from the state. diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 692e7c018..65c3b2bc7 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1,5 +1,6 @@ import copy import importlib +import json import os import textwrap import threading @@ -1326,3 +1327,37 @@ def test_non_json_serializable_state_throws_error(): with pytest.raises(ValueError, match="Value is not JSON serializable"): agent = Agent(state={"object": object()}) print(agent.state) + + +def test_agent_state_breaks_dict_reference(): + ref_dict = {"hello": "world"} + agent = Agent(state=ref_dict) + ref_dict["hello"] = object() + + json.dumps(agent.state.get()) + + +def test_agent_state_breaks_deep_dict_reference(): + ref_dict = {"world": "!"} + init_dict = {"hello": ref_dict} + agent = Agent(state=init_dict) + ref_dict["world"] = object() + + json.dumps(agent.state.get()) + + +def test_agent_state_set_breaks_dict_reference(): + agent = Agent() + ref_dict = {"hello": "world"} + agent.state.set("hello", ref_dict) + ref_dict["hello"] = object() + + json.dumps(agent.state.get()) + + +def test_agent_state_get_breaks_deep_dict_reference(): + agent = Agent(state={"hello": {"world": "!"}}) + ref_state = agent.state.get() + ref_state["hello"]["world"] = object() + + json.dumps(agent.state.get()) From 4bfddbbe6adce0f5cfac0519a7d6d15fe269a513 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Fri, 27 Jun 2025 11:57:53 -0400 Subject: [PATCH 6/6] Update test_agent.py with comments --- tests/strands/agent/test_agent.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 65c3b2bc7..aeb149378 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1332,8 +1332,11 @@ def test_non_json_serializable_state_throws_error(): def test_agent_state_breaks_dict_reference(): ref_dict = {"hello": "world"} agent = Agent(state=ref_dict) + + # Make sure shallow object references do not affect state maintained by AgentState ref_dict["hello"] = object() + # This will fail if AgentState reflects the updated reference json.dumps(agent.state.get()) @@ -1341,23 +1344,29 @@ def test_agent_state_breaks_deep_dict_reference(): ref_dict = {"world": "!"} init_dict = {"hello": ref_dict} agent = Agent(state=init_dict) + # Make sure deep reference changes do not affect state mained by AgentState ref_dict["world"] = object() + # This will fail if AgentState reflects the updated reference json.dumps(agent.state.get()) def test_agent_state_set_breaks_dict_reference(): agent = Agent() ref_dict = {"hello": "world"} + # Set should copy the input, and not maintain the reference to the original object agent.state.set("hello", ref_dict) ref_dict["hello"] = object() + # This will fail if AgentState reflects the updated reference json.dumps(agent.state.get()) def test_agent_state_get_breaks_deep_dict_reference(): agent = Agent(state={"hello": {"world": "!"}}) + # Get should not return a reference to the internal state ref_state = agent.state.get() ref_state["hello"]["world"] = object() + # This will fail if AgentState reflects the updated reference json.dumps(agent.state.get())