From adc33bf94ae084779a7d249cd325d821413fa659 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Wed, 2 Jul 2025 13:24:21 +0000 Subject: [PATCH] refactor: Consolidate agent state unit tests --- .../mocked_model_provider.py | 0 tests/strands/agent/test_agent_state.py | 34 ++++++++++++++++++ .../strands/mocked_model_provider/__init__.py | 0 .../test_agent_state_updates.py | 36 ------------------- 4 files changed, 34 insertions(+), 36 deletions(-) rename tests/{strands/mocked_model_provider => fixtures}/mocked_model_provider.py (100%) delete mode 100644 tests/strands/mocked_model_provider/__init__.py delete mode 100644 tests/strands/mocked_model_provider/test_agent_state_updates.py diff --git a/tests/strands/mocked_model_provider/mocked_model_provider.py b/tests/fixtures/mocked_model_provider.py similarity index 100% rename from tests/strands/mocked_model_provider/mocked_model_provider.py rename to tests/fixtures/mocked_model_provider.py diff --git a/tests/strands/agent/test_agent_state.py b/tests/strands/agent/test_agent_state.py index 1921b0065..bc2321a56 100644 --- a/tests/strands/agent/test_agent_state.py +++ b/tests/strands/agent/test_agent_state.py @@ -2,7 +2,11 @@ import pytest +from strands import Agent, tool from strands.agent.state import AgentState +from strands.types.content import Messages + +from ...fixtures.mocked_model_provider import MockedModelProvider def test_set_and_get(): @@ -109,3 +113,33 @@ def test_initial_state(): assert state.get("key1") == "value1" assert state.get("key2") == "value2" assert state.get() == initial + + +def test_agent_state_update_from_tool(): + @tool + def update_state(agent: Agent): + agent.state.set("hello", "world") + agent.state.set("foo", "baz") + + agent_messages: Messages = [ + { + "role": "assistant", + "content": [{"toolUse": {"name": "update_state", "toolUseId": "123", "input": {}}}], + }, + {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, + ] + mocked_model_provider = MockedModelProvider(agent_messages) + + agent = Agent( + model=mocked_model_provider, + tools=[update_state], + state={"foo": "bar"}, + ) + + assert agent.state.get("hello") is None + assert agent.state.get("foo") == "bar" + + agent("Invoke Mocked!") + + assert agent.state.get("hello") == "world" + assert agent.state.get("foo") == "baz" diff --git a/tests/strands/mocked_model_provider/__init__.py b/tests/strands/mocked_model_provider/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/strands/mocked_model_provider/test_agent_state_updates.py b/tests/strands/mocked_model_provider/test_agent_state_updates.py deleted file mode 100644 index c15c61961..000000000 --- a/tests/strands/mocked_model_provider/test_agent_state_updates.py +++ /dev/null @@ -1,36 +0,0 @@ -from strands.agent.agent import Agent -from strands.tools.decorator import tool -from strands.types.content import Messages - -from .mocked_model_provider import MockedModelProvider - - -@tool -def update_state(agent: Agent): - agent.state.set("hello", "world") - agent.state.set("foo", "baz") - - -def test_agent_state_update_from_tool(): - agent_messages: Messages = [ - { - "role": "assistant", - "content": [{"toolUse": {"name": "update_state", "toolUseId": "123", "input": {}}}], - }, - {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, - ] - mocked_model_provider = MockedModelProvider(agent_messages) - - agent = Agent( - model=mocked_model_provider, - tools=[update_state], - state={"foo": "bar"}, - ) - - assert agent.state.get("hello") is None - assert agent.state.get("foo") == "bar" - - agent("Invoke Mocked!") - - assert agent.state.get("hello") == "world" - assert agent.state.get("foo") == "baz"