diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 9aee260b1..081193b10 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -201,10 +201,6 @@ def _validate_node_executor( if executor._session_manager is not None: raise ValueError("Session persistence is not supported for Graph agents yet.") - # Check for callbacks - if executor.hooks.has_callbacks(): - raise ValueError("Agent callbacks are not supported for Graph agents yet.") - class GraphBuilder: """Builder pattern for constructing graphs.""" diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index a96c92de8..d730d5156 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -318,10 +318,6 @@ def _validate_swarm(self, nodes: list[Agent]) -> None: if node._session_manager is not None: raise ValueError("Session persistence is not supported for Swarm agents yet.") - # Check for callbacks - if node.hooks.has_callbacks(): - raise ValueError("Agent callbacks are not supported for Swarm agents yet.") - def _inject_swarm_tools(self) -> None: """Add swarm coordination tools to each agent.""" # Create tool functions with proper closures diff --git a/tests/fixtures/mock_hook_provider.py b/tests/fixtures/mock_hook_provider.py index 8d7e93253..6bf7b8c77 100644 --- a/tests/fixtures/mock_hook_provider.py +++ b/tests/fixtures/mock_hook_provider.py @@ -1,13 +1,44 @@ -from typing import Iterator, Tuple, Type +from typing import Iterator, Literal, Tuple, Type -from strands.hooks import HookEvent, HookProvider, HookRegistry +from strands import Agent +from strands.experimental.hooks import ( + AfterModelInvocationEvent, + AfterToolInvocationEvent, + BeforeModelInvocationEvent, + BeforeToolInvocationEvent, +) +from strands.hooks import ( + AfterInvocationEvent, + AgentInitializedEvent, + BeforeInvocationEvent, + HookEvent, + HookProvider, + HookRegistry, + MessageAddedEvent, +) class MockHookProvider(HookProvider): - def __init__(self, event_types: list[Type]): + def __init__(self, event_types: list[Type] | Literal["all"]): + if event_types == "all": + event_types = [ + AgentInitializedEvent, + BeforeInvocationEvent, + AfterInvocationEvent, + AfterToolInvocationEvent, + BeforeToolInvocationEvent, + BeforeModelInvocationEvent, + AfterModelInvocationEvent, + MessageAddedEvent, + ] + self.events_received = [] self.events_types = event_types + @property + def event_types_received(self): + return [type(event) for event in self.events_received] + def get_events(self) -> Tuple[int, Iterator[HookEvent]]: return len(self.events_received), iter(self.events_received) @@ -17,3 +48,11 @@ def register_hooks(self, registry: HookRegistry) -> None: def add_event(self, event: HookEvent) -> None: self.events_received.append(event) + + def extract_for(self, agent: Agent) -> "MockHookProvider": + """Extracts a hook provider for the given agent, including the events that were fired for that agent. + + Convenience method when sharing a hook provider between multiple agents.""" + child_provider = MockHookProvider(self.events_types) + child_provider.events_received = [event for event in self.events_received if event.agent == agent] + return child_provider diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index c60361da8..9977c54cd 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -873,15 +873,6 @@ class TestHookProvider(HookProvider): def register_hooks(self, registry, **kwargs): registry.add_callback(AgentInitializedEvent, lambda e: None) - agent_with_hooks = create_mock_agent("agent_with_hooks") - agent_with_hooks._session_manager = None - agent_with_hooks.hooks = HookRegistry() - agent_with_hooks.hooks.add_hook(TestHookProvider()) - - builder = GraphBuilder() - with pytest.raises(ValueError, match="Agent callbacks are not supported for Graph agents yet"): - builder.add_node(agent_with_hooks) - # Test validation in Graph constructor (when nodes are passed directly) # Test with session manager in Graph constructor node_with_session = GraphNode("node_with_session", agent_with_session) @@ -892,15 +883,6 @@ def register_hooks(self, registry, **kwargs): entry_points=set(), ) - # Test with callbacks in Graph constructor - node_with_hooks = GraphNode("node_with_hooks", agent_with_hooks) - with pytest.raises(ValueError, match="Agent callbacks are not supported for Graph agents yet"): - Graph( - nodes={"node_with_hooks": node_with_hooks}, - edges=set(), - entry_points=set(), - ) - @pytest.mark.asyncio async def test_controlled_cyclic_execution(): diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 91b677fa4..74f89241f 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -5,8 +5,7 @@ from strands.agent import Agent, AgentResult from strands.agent.state import AgentState -from strands.hooks import AgentInitializedEvent -from strands.hooks.registry import HookProvider, HookRegistry +from strands.hooks.registry import HookRegistry from strands.multiagent.base import Status from strands.multiagent.swarm import SharedContext, Swarm, SwarmNode, SwarmResult, SwarmState from strands.session.session_manager import SessionManager @@ -470,16 +469,3 @@ def test_swarm_validate_unsupported_features(): with pytest.raises(ValueError, match="Session persistence is not supported for Swarm agents yet"): Swarm([agent_with_session]) - - # Test with callbacks (should fail) - class TestHookProvider(HookProvider): - def register_hooks(self, registry, **kwargs): - registry.add_callback(AgentInitializedEvent, lambda e: None) - - agent_with_hooks = create_mock_agent("agent_with_hooks") - agent_with_hooks._session_manager = None - agent_with_hooks.hooks = HookRegistry() - agent_with_hooks.hooks.add_hook(TestHookProvider()) - - with pytest.raises(ValueError, match="Agent callbacks are not supported for Swarm agents yet"): - Swarm([agent_with_hooks]) diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index e1f3a2f3f..bc9b0ea8b 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -1,8 +1,11 @@ import pytest from strands import Agent, tool +from strands.experimental.hooks import AfterModelInvocationEvent, BeforeModelInvocationEvent +from strands.hooks import AfterInvocationEvent, AgentInitializedEvent, BeforeInvocationEvent, MessageAddedEvent from strands.multiagent.graph import GraphBuilder from strands.types.content import ContentBlock +from tests.fixtures.mock_hook_provider import MockHookProvider @tool @@ -18,49 +21,59 @@ def multiply_numbers(x: int, y: int) -> int: @pytest.fixture -def math_agent(): +def hook_provider(): + return MockHookProvider("all") + + +@pytest.fixture +def math_agent(hook_provider): """Create an agent specialized in mathematical operations.""" return Agent( model="us.amazon.nova-pro-v1:0", system_prompt="You are a mathematical assistant. Always provide clear, step-by-step calculations.", + hooks=[hook_provider], tools=[calculate_sum, multiply_numbers], ) @pytest.fixture -def analysis_agent(): +def analysis_agent(hook_provider): """Create an agent specialized in data analysis.""" return Agent( model="us.amazon.nova-pro-v1:0", + hooks=[hook_provider], system_prompt="You are a data analysis expert. Provide insights and interpretations of numerical results.", ) @pytest.fixture -def summary_agent(): +def summary_agent(hook_provider): """Create an agent specialized in summarization.""" return Agent( model="us.amazon.nova-lite-v1:0", + hooks=[hook_provider], system_prompt="You are a summarization expert. Create concise, clear summaries of complex information.", ) @pytest.fixture -def validation_agent(): +def validation_agent(hook_provider): """Create an agent specialized in validation.""" return Agent( model="us.amazon.nova-pro-v1:0", + hooks=[hook_provider], system_prompt="You are a validation expert. Check results for accuracy and completeness.", ) @pytest.fixture -def image_analysis_agent(): +def image_analysis_agent(hook_provider): """Create an agent specialized in image analysis.""" return Agent( + hooks=[hook_provider], system_prompt=( "You are an image analysis expert. Describe what you see in images and provide detailed analysis." - ) + ), ) @@ -149,7 +162,7 @@ def proceed_to_second_summary(state): @pytest.mark.asyncio -async def test_graph_execution_with_image(image_analysis_agent, summary_agent, yellow_img): +async def test_graph_execution_with_image(image_analysis_agent, summary_agent, yellow_img, hook_provider): """Test graph execution with multi-modal image input.""" builder = GraphBuilder() @@ -186,3 +199,16 @@ async def test_graph_execution_with_image(image_analysis_agent, summary_agent, y # Verify both nodes completed assert "image_analyzer" in result.results assert "summarizer" in result.results + + expected_hook_events = [ + AgentInitializedEvent, + BeforeInvocationEvent, + MessageAddedEvent, + BeforeModelInvocationEvent, + AfterModelInvocationEvent, + MessageAddedEvent, + AfterInvocationEvent, + ] + + assert hook_provider.extract_for(image_analysis_agent).event_types_received == expected_hook_events + assert hook_provider.extract_for(summary_agent).event_types_received == expected_hook_events diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index 6fe5700aa..76860f687 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -1,8 +1,16 @@ import pytest from strands import Agent, tool +from strands.experimental.hooks import ( + AfterModelInvocationEvent, + AfterToolInvocationEvent, + BeforeModelInvocationEvent, + BeforeToolInvocationEvent, +) +from strands.hooks import AfterInvocationEvent, BeforeInvocationEvent, MessageAddedEvent from strands.multiagent.swarm import Swarm from strands.types.content import ContentBlock +from tests.fixtures.mock_hook_provider import MockHookProvider @tool @@ -22,7 +30,12 @@ def calculate(expression: str) -> str: @pytest.fixture -def researcher_agent(): +def hook_provider(): + return MockHookProvider("all") + + +@pytest.fixture +def researcher_agent(hook_provider): """Create an agent specialized in research.""" return Agent( name="researcher", @@ -30,12 +43,13 @@ def researcher_agent(): "You are a research specialist who excels at finding information. When you need to perform calculations or" " format documents, hand off to the appropriate specialist." ), + hooks=[hook_provider], tools=[web_search], ) @pytest.fixture -def analyst_agent(): +def analyst_agent(hook_provider): """Create an agent specialized in data analysis.""" return Agent( name="analyst", @@ -43,15 +57,17 @@ def analyst_agent(): "You are a data analyst who excels at calculations and numerical analysis. When you need" " research or document formatting, hand off to the appropriate specialist." ), + hooks=[hook_provider], tools=[calculate], ) @pytest.fixture -def writer_agent(): +def writer_agent(hook_provider): """Create an agent specialized in writing and formatting.""" return Agent( name="writer", + hooks=[hook_provider], system_prompt=( "You are a professional writer who excels at formatting and presenting information. When you need research" " or calculations, hand off to the appropriate specialist." @@ -59,7 +75,7 @@ def writer_agent(): ) -def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent): +def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent, hook_provider): """Test swarm execution with string input.""" # Create the swarm swarm = Swarm([researcher_agent, analyst_agent, writer_agent]) @@ -82,6 +98,16 @@ def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_age # Verify agent history - at least one agent should have been used assert len(result.node_history) > 0 + # Just ensure that hooks are emitted; actual content is not verified + researcher_hooks = hook_provider.extract_for(researcher_agent).event_types_received + assert BeforeInvocationEvent in researcher_hooks + assert MessageAddedEvent in researcher_hooks + assert BeforeModelInvocationEvent in researcher_hooks + assert BeforeToolInvocationEvent in researcher_hooks + assert AfterToolInvocationEvent in researcher_hooks + assert AfterModelInvocationEvent in researcher_hooks + assert AfterInvocationEvent in researcher_hooks + @pytest.mark.asyncio async def test_swarm_execution_with_image(researcher_agent, analyst_agent, writer_agent, yellow_img):