Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,6 @@ def _validate_node_executor(
if executor._session_manager is not None:
raise ValueError("Session persistence is not supported for Graph agents yet.")

# Check for callbacks
if executor.hooks.has_callbacks():
raise ValueError("Agent callbacks are not supported for Graph agents yet.")


class GraphBuilder:
"""Builder pattern for constructing graphs."""
Expand Down
4 changes: 0 additions & 4 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,6 @@ def _validate_swarm(self, nodes: list[Agent]) -> None:
if node._session_manager is not None:
raise ValueError("Session persistence is not supported for Swarm agents yet.")

# Check for callbacks
if node.hooks.has_callbacks():
raise ValueError("Agent callbacks are not supported for Swarm agents yet.")

def _inject_swarm_tools(self) -> None:
"""Add swarm coordination tools to each agent."""
# Create tool functions with proper closures
Expand Down
45 changes: 42 additions & 3 deletions tests/fixtures/mock_hook_provider.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,44 @@
from typing import Iterator, Tuple, Type
from typing import Iterator, Literal, Tuple, Type

from strands.hooks import HookEvent, HookProvider, HookRegistry
from strands import Agent
from strands.experimental.hooks import (
AfterModelInvocationEvent,
AfterToolInvocationEvent,
BeforeModelInvocationEvent,
BeforeToolInvocationEvent,
)
from strands.hooks import (
AfterInvocationEvent,
AgentInitializedEvent,
BeforeInvocationEvent,
HookEvent,
HookProvider,
HookRegistry,
MessageAddedEvent,
)


class MockHookProvider(HookProvider):
def __init__(self, event_types: list[Type]):
def __init__(self, event_types: list[Type] | Literal["all"]):
if event_types == "all":
event_types = [
AgentInitializedEvent,
BeforeInvocationEvent,
AfterInvocationEvent,
AfterToolInvocationEvent,
BeforeToolInvocationEvent,
BeforeModelInvocationEvent,
AfterModelInvocationEvent,
MessageAddedEvent,
]

self.events_received = []
self.events_types = event_types

@property
def event_types_received(self):
return [type(event) for event in self.events_received]

def get_events(self) -> Tuple[int, Iterator[HookEvent]]:
return len(self.events_received), iter(self.events_received)

Expand All @@ -17,3 +48,11 @@ def register_hooks(self, registry: HookRegistry) -> None:

def add_event(self, event: HookEvent) -> None:
self.events_received.append(event)

def extract_for(self, agent: Agent) -> "MockHookProvider":
"""Extracts a hook provider for the given agent, including the events that were fired for that agent.

Convenience method when sharing a hook provider between multiple agents."""
child_provider = MockHookProvider(self.events_types)
child_provider.events_received = [event for event in self.events_received if event.agent == agent]
return child_provider
18 changes: 0 additions & 18 deletions tests/strands/multiagent/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,15 +873,6 @@ class TestHookProvider(HookProvider):
def register_hooks(self, registry, **kwargs):
registry.add_callback(AgentInitializedEvent, lambda e: None)

agent_with_hooks = create_mock_agent("agent_with_hooks")
agent_with_hooks._session_manager = None
agent_with_hooks.hooks = HookRegistry()
agent_with_hooks.hooks.add_hook(TestHookProvider())

builder = GraphBuilder()
with pytest.raises(ValueError, match="Agent callbacks are not supported for Graph agents yet"):
builder.add_node(agent_with_hooks)

# Test validation in Graph constructor (when nodes are passed directly)
# Test with session manager in Graph constructor
node_with_session = GraphNode("node_with_session", agent_with_session)
Expand All @@ -892,15 +883,6 @@ def register_hooks(self, registry, **kwargs):
entry_points=set(),
)

# Test with callbacks in Graph constructor
node_with_hooks = GraphNode("node_with_hooks", agent_with_hooks)
with pytest.raises(ValueError, match="Agent callbacks are not supported for Graph agents yet"):
Graph(
nodes={"node_with_hooks": node_with_hooks},
edges=set(),
entry_points=set(),
)


@pytest.mark.asyncio
async def test_controlled_cyclic_execution():
Expand Down
16 changes: 1 addition & 15 deletions tests/strands/multiagent/test_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

from strands.agent import Agent, AgentResult
from strands.agent.state import AgentState
from strands.hooks import AgentInitializedEvent
from strands.hooks.registry import HookProvider, HookRegistry
from strands.hooks.registry import HookRegistry
from strands.multiagent.base import Status
from strands.multiagent.swarm import SharedContext, Swarm, SwarmNode, SwarmResult, SwarmState
from strands.session.session_manager import SessionManager
Expand Down Expand Up @@ -470,16 +469,3 @@ def test_swarm_validate_unsupported_features():

with pytest.raises(ValueError, match="Session persistence is not supported for Swarm agents yet"):
Swarm([agent_with_session])

# Test with callbacks (should fail)
class TestHookProvider(HookProvider):
def register_hooks(self, registry, **kwargs):
registry.add_callback(AgentInitializedEvent, lambda e: None)

agent_with_hooks = create_mock_agent("agent_with_hooks")
agent_with_hooks._session_manager = None
agent_with_hooks.hooks = HookRegistry()
agent_with_hooks.hooks.add_hook(TestHookProvider())

with pytest.raises(ValueError, match="Agent callbacks are not supported for Swarm agents yet"):
Swarm([agent_with_hooks])
40 changes: 33 additions & 7 deletions tests_integ/test_multiagent_graph.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pytest

from strands import Agent, tool
from strands.experimental.hooks import AfterModelInvocationEvent, BeforeModelInvocationEvent
from strands.hooks import AfterInvocationEvent, AgentInitializedEvent, BeforeInvocationEvent, MessageAddedEvent
from strands.multiagent.graph import GraphBuilder
from strands.types.content import ContentBlock
from tests.fixtures.mock_hook_provider import MockHookProvider


@tool
Expand All @@ -18,49 +21,59 @@ def multiply_numbers(x: int, y: int) -> int:


@pytest.fixture
def math_agent():
def hook_provider():
return MockHookProvider("all")


@pytest.fixture
def math_agent(hook_provider):
"""Create an agent specialized in mathematical operations."""
return Agent(
model="us.amazon.nova-pro-v1:0",
system_prompt="You are a mathematical assistant. Always provide clear, step-by-step calculations.",
hooks=[hook_provider],
tools=[calculate_sum, multiply_numbers],
)


@pytest.fixture
def analysis_agent():
def analysis_agent(hook_provider):
"""Create an agent specialized in data analysis."""
return Agent(
model="us.amazon.nova-pro-v1:0",
hooks=[hook_provider],
system_prompt="You are a data analysis expert. Provide insights and interpretations of numerical results.",
)


@pytest.fixture
def summary_agent():
def summary_agent(hook_provider):
"""Create an agent specialized in summarization."""
return Agent(
model="us.amazon.nova-lite-v1:0",
hooks=[hook_provider],
system_prompt="You are a summarization expert. Create concise, clear summaries of complex information.",
)


@pytest.fixture
def validation_agent():
def validation_agent(hook_provider):
"""Create an agent specialized in validation."""
return Agent(
model="us.amazon.nova-pro-v1:0",
hooks=[hook_provider],
system_prompt="You are a validation expert. Check results for accuracy and completeness.",
)


@pytest.fixture
def image_analysis_agent():
def image_analysis_agent(hook_provider):
"""Create an agent specialized in image analysis."""
return Agent(
hooks=[hook_provider],
system_prompt=(
"You are an image analysis expert. Describe what you see in images and provide detailed analysis."
)
),
)


Expand Down Expand Up @@ -149,7 +162,7 @@ def proceed_to_second_summary(state):


@pytest.mark.asyncio
async def test_graph_execution_with_image(image_analysis_agent, summary_agent, yellow_img):
async def test_graph_execution_with_image(image_analysis_agent, summary_agent, yellow_img, hook_provider):
"""Test graph execution with multi-modal image input."""
builder = GraphBuilder()

Expand Down Expand Up @@ -186,3 +199,16 @@ async def test_graph_execution_with_image(image_analysis_agent, summary_agent, y
# Verify both nodes completed
assert "image_analyzer" in result.results
assert "summarizer" in result.results

expected_hook_events = [
AgentInitializedEvent,
BeforeInvocationEvent,
MessageAddedEvent,
BeforeModelInvocationEvent,
AfterModelInvocationEvent,
MessageAddedEvent,
AfterInvocationEvent,
]

assert hook_provider.extract_for(image_analysis_agent).event_types_received == expected_hook_events
assert hook_provider.extract_for(summary_agent).event_types_received == expected_hook_events
34 changes: 30 additions & 4 deletions tests_integ/test_multiagent_swarm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import pytest

from strands import Agent, tool
from strands.experimental.hooks import (
AfterModelInvocationEvent,
AfterToolInvocationEvent,
BeforeModelInvocationEvent,
BeforeToolInvocationEvent,
)
from strands.hooks import AfterInvocationEvent, BeforeInvocationEvent, MessageAddedEvent
from strands.multiagent.swarm import Swarm
from strands.types.content import ContentBlock
from tests.fixtures.mock_hook_provider import MockHookProvider


@tool
Expand All @@ -22,44 +30,52 @@ def calculate(expression: str) -> str:


@pytest.fixture
def researcher_agent():
def hook_provider():
return MockHookProvider("all")


@pytest.fixture
def researcher_agent(hook_provider):
"""Create an agent specialized in research."""
return Agent(
name="researcher",
system_prompt=(
"You are a research specialist who excels at finding information. When you need to perform calculations or"
" format documents, hand off to the appropriate specialist."
),
hooks=[hook_provider],
tools=[web_search],
)


@pytest.fixture
def analyst_agent():
def analyst_agent(hook_provider):
"""Create an agent specialized in data analysis."""
return Agent(
name="analyst",
system_prompt=(
"You are a data analyst who excels at calculations and numerical analysis. When you need"
" research or document formatting, hand off to the appropriate specialist."
),
hooks=[hook_provider],
tools=[calculate],
)


@pytest.fixture
def writer_agent():
def writer_agent(hook_provider):
"""Create an agent specialized in writing and formatting."""
return Agent(
name="writer",
hooks=[hook_provider],
system_prompt=(
"You are a professional writer who excels at formatting and presenting information. When you need research"
" or calculations, hand off to the appropriate specialist."
),
)


def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent):
def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent, hook_provider):
"""Test swarm execution with string input."""
# Create the swarm
swarm = Swarm([researcher_agent, analyst_agent, writer_agent])
Expand All @@ -82,6 +98,16 @@ def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_age
# Verify agent history - at least one agent should have been used
assert len(result.node_history) > 0

# Just ensure that hooks are emitted; actual content is not verified
researcher_hooks = hook_provider.extract_for(researcher_agent).event_types_received
assert BeforeInvocationEvent in researcher_hooks
assert MessageAddedEvent in researcher_hooks
assert BeforeModelInvocationEvent in researcher_hooks
assert BeforeToolInvocationEvent in researcher_hooks
assert AfterToolInvocationEvent in researcher_hooks
assert AfterModelInvocationEvent in researcher_hooks
assert AfterInvocationEvent in researcher_hooks


@pytest.mark.asyncio
async def test_swarm_execution_with_image(researcher_agent, analyst_agent, writer_agent, yellow_img):
Expand Down
Loading