|
1 | | -from unittest.mock import patch |
2 | 1 | from uuid import uuid4 |
3 | 2 |
|
4 | 3 | import pytest |
5 | 4 |
|
6 | 5 | from strands import Agent, tool |
| 6 | +from strands.experimental.hooks.multiagent import BeforeNodeCallEvent |
7 | 7 | from strands.hooks import ( |
8 | 8 | AfterInvocationEvent, |
9 | 9 | AfterModelCallEvent, |
|
13 | 13 | BeforeToolCallEvent, |
14 | 14 | MessageAddedEvent, |
15 | 15 | ) |
16 | | -from strands.multiagent.base import Status |
17 | 16 | from strands.multiagent.swarm import Swarm |
18 | 17 | from strands.session.file_session_manager import FileSessionManager |
19 | 18 | from strands.types.content import ContentBlock |
@@ -82,6 +81,38 @@ def writer_agent(hook_provider): |
82 | 81 | ) |
83 | 82 |
|
84 | 83 |
|
| 84 | +@pytest.fixture |
| 85 | +def exit_hook(): |
| 86 | + class ExitHook: |
| 87 | + def __init__(self): |
| 88 | + self.should_exit = True |
| 89 | + |
| 90 | + def register_hooks(self, registry): |
| 91 | + registry.add_callback(BeforeNodeCallEvent, self.exit_before_analyst) |
| 92 | + |
| 93 | + def exit_before_analyst(self, event): |
| 94 | + if event.node_id == "analyst" and self.should_exit: |
| 95 | + raise SystemExit("Controlled exit before analyst") |
| 96 | + |
| 97 | + return ExitHook() |
| 98 | + |
| 99 | + |
| 100 | +@pytest.fixture |
| 101 | +def verify_hook(): |
| 102 | + class VerifyHook: |
| 103 | + def __init__(self): |
| 104 | + self.first_node = None |
| 105 | + |
| 106 | + def register_hooks(self, registry): |
| 107 | + registry.add_callback(BeforeNodeCallEvent, self.capture_first_node) |
| 108 | + |
| 109 | + def capture_first_node(self, event): |
| 110 | + if self.first_node is None: |
| 111 | + self.first_node = event.node_id |
| 112 | + |
| 113 | + return VerifyHook() |
| 114 | + |
| 115 | + |
85 | 116 | def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent, hook_provider): |
86 | 117 | """Test swarm execution with string input.""" |
87 | 118 | # Create the swarm |
@@ -326,53 +357,39 @@ async def test_swarm_get_agent_results_flattening(): |
326 | 357 | assert agent_results[0].message is not None |
327 | 358 |
|
328 | 359 |
|
329 | | -@pytest.mark.asyncio |
330 | | -async def test_swarm_interrupt_and_resume(researcher_agent, analyst_agent, writer_agent): |
331 | | - """Test swarm interruption after analyst_agent and resume functionality.""" |
332 | | - session_id = str(uuid4()) |
333 | | - |
334 | | - # Create session manager |
335 | | - session_manager = FileSessionManager(session_id=session_id) |
336 | | - |
337 | | - # Create swarm with session manager |
338 | | - swarm = Swarm([researcher_agent, analyst_agent, writer_agent], session_manager=session_manager) |
339 | | - |
340 | | - # Mock analyst_agent's _invoke method to fail |
341 | | - async def failing_invoke(*args, **kwargs): |
342 | | - raise Exception("Simulated failure in analyst") |
343 | | - yield # This line is never reached, but makes it an async generator |
344 | | - |
345 | | - with patch.object(analyst_agent, "stream_async", side_effect=failing_invoke): |
346 | | - # First execution - should fail at analyst |
347 | | - result = await swarm.invoke_async("Research AI trends and create a brief report") |
348 | | - try: |
349 | | - assert result.status == Status.FAILED |
350 | | - except Exception as e: |
351 | | - assert "Simulated failure in analyst" in str(e) |
352 | | - |
353 | | - # Verify partial execution was persisted |
354 | | - persisted_state = session_manager.read_multi_agent(session_id, swarm.id) |
355 | | - assert persisted_state is not None |
356 | | - assert persisted_state["type"] == "swarm" |
357 | | - assert persisted_state["status"] == "failed" |
358 | | - assert len(persisted_state["node_history"]) == 1 # At least researcher executed |
359 | | - |
360 | | - # Track execution count before resume |
361 | | - initial_execution_count = len(persisted_state["node_history"]) |
362 | | - |
363 | | - # Execute swarm again - should automatically resume from saved state |
364 | | - result = await swarm.invoke_async("Research AI trends and create a brief report") |
| 360 | +def test_swarm_resume_from_executing_state(tmpdir, exit_hook, verify_hook): |
| 361 | + """Test swarm resuming from EXECUTING state using BeforeNodeCallEvent hook.""" |
| 362 | + session_id = f"swarm_resume_{uuid4()}" |
365 | 363 |
|
366 | | - # Verify successful completion |
367 | | - assert result.status == Status.COMPLETED |
368 | | - assert len(result.results) > 0 |
| 364 | + # First execution - exit before second node |
| 365 | + session_manager = FileSessionManager(session_id=session_id, storage_dir=tmpdir) |
| 366 | + researcher = Agent(name="researcher", system_prompt="you are a researcher.") |
| 367 | + analyst = Agent(name="analyst", system_prompt="you are an analyst.") |
| 368 | + writer = Agent(name="writer", system_prompt="you are a writer.") |
369 | 369 |
|
370 | | - assert len(result.node_history) >= initial_execution_count + 1 |
| 370 | + swarm = Swarm([researcher, analyst, writer], session_manager=session_manager, hooks=[exit_hook]) |
371 | 371 |
|
372 | | - node_names = [node.node_id for node in result.node_history] |
373 | | - assert "researcher" in node_names |
374 | | - # Either analyst or writer (or both) should have executed to complete the task |
375 | | - assert "analyst" in node_names or "writer" in node_names |
| 372 | + try: |
| 373 | + swarm("write AI trends and calculate growth in 100 words") |
| 374 | + except SystemExit as e: |
| 375 | + assert "Controlled exit before analyst" in str(e) |
376 | 376 |
|
377 | | - # Clean up |
378 | | - session_manager.delete_session(session_id) |
| 377 | + # Verify state was persisted with EXECUTING status and next node |
| 378 | + persisted_state = session_manager.read_multi_agent(session_id, swarm.id) |
| 379 | + assert persisted_state["status"] == "executing" |
| 380 | + assert len(persisted_state["node_history"]) == 1 |
| 381 | + assert persisted_state["node_history"][0] == "researcher" |
| 382 | + assert persisted_state["next_nodes_to_execute"] == ["analyst"] |
| 383 | + |
| 384 | + exit_hook.should_exit = False |
| 385 | + researcher2 = Agent(name="researcher", system_prompt="you are a researcher.") |
| 386 | + analyst2 = Agent(name="analyst", system_prompt="you are an analyst.") |
| 387 | + writer2 = Agent(name="writer", system_prompt="you are a writer.") |
| 388 | + new_swarm = Swarm([researcher2, analyst2, writer2], session_manager=session_manager, hooks=[verify_hook]) |
| 389 | + result = new_swarm("write AI trends and calculate growth in 100 words") |
| 390 | + |
| 391 | + # Verify swarm behavior - should resume from analyst, not restart |
| 392 | + assert result.status.value == "completed" |
| 393 | + assert verify_hook.first_node == "analyst" |
| 394 | + node_ids = [n.node_id for n in result.node_history] |
| 395 | + assert "analyst" in node_ids |
0 commit comments