diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 081193b10..d2838396d 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -469,41 +469,32 @@ async def _execute_graph(self) -> None: ready_nodes.clear() # Execute current batch of ready nodes concurrently - tasks = [ - asyncio.create_task(self._execute_node(node)) - for node in current_batch - if node not in self.state.completed_nodes - ] + tasks = [asyncio.create_task(self._execute_node(node)) for node in current_batch] for task in tasks: await task # Find newly ready nodes after batch execution - ready_nodes.extend(self._find_newly_ready_nodes()) + # We add all nodes in current batch as completed batch, + # because a failure would throw exception and code would not make it here + ready_nodes.extend(self._find_newly_ready_nodes(current_batch)) - def _find_newly_ready_nodes(self) -> list["GraphNode"]: + def _find_newly_ready_nodes(self, completed_batch: list["GraphNode"]) -> list["GraphNode"]: """Find nodes that became ready after the last execution.""" newly_ready = [] for _node_id, node in self.nodes.items(): - if ( - node not in self.state.completed_nodes - and node not in self.state.failed_nodes - and self._is_node_ready_with_conditions(node) - ): + if self._is_node_ready_with_conditions(node, completed_batch): newly_ready.append(node) return newly_ready - def _is_node_ready_with_conditions(self, node: GraphNode) -> bool: + def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list["GraphNode"]) -> bool: """Check if a node is ready considering conditional edges.""" # Get incoming edges to this node incoming_edges = [edge for edge in self.edges if edge.to_node == node] - if not incoming_edges: - return node in self.entry_points - # Check if at least one incoming edge condition is satisfied for edge in incoming_edges: - if edge.from_node in self.state.completed_nodes: + if edge.from_node in completed_batch: if edge.should_traverse(self.state): logger.debug( "from=<%s>, to=<%s> | edge ready via satisfied condition", edge.from_node.node_id, node.node_id diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 9977c54cd..1a598847d 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1,6 +1,6 @@ import asyncio import time -from unittest.mock import AsyncMock, MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, call, patch import pytest @@ -318,7 +318,7 @@ async def test_graph_edge_cases(mock_strands_tracer, mock_use_span): @pytest.mark.asyncio async def test_cyclic_graph_execution(mock_strands_tracer, mock_use_span): - """Test execution of a graph with cycles.""" + """Test execution of a graph with cycles and proper exit conditions.""" # Create mock agents with state tracking agent_a = create_mock_agent("agent_a", "Agent A response") agent_b = create_mock_agent("agent_b", "Agent B response") @@ -332,16 +332,33 @@ async def test_cyclic_graph_execution(mock_strands_tracer, mock_use_span): # Create a spy to track reset calls reset_spy = MagicMock() - # Create a graph with a cycle: A -> B -> C -> A + # Create conditions for controlled cycling + def a_to_b_condition(state: GraphState) -> bool: + # A can trigger B if B hasn't been executed yet + b_count = sum(1 for node in state.execution_order if node.node_id == "b") + return b_count == 0 + + def b_to_c_condition(state: GraphState) -> bool: + # B can always trigger C (unconditional) + return True + + def c_to_a_condition(state: GraphState) -> bool: + # C can trigger A only if A has been executed less than 2 times + a_count = sum(1 for node in state.execution_order if node.node_id == "a") + return a_count < 2 + + # Create a graph with conditional cycle: A -> B -> C -> A (with conditions) builder = GraphBuilder() builder.add_node(agent_a, "a") builder.add_node(agent_b, "b") builder.add_node(agent_c, "c") - builder.add_edge("a", "b") - builder.add_edge("b", "c") - builder.add_edge("c", "a") # Creates cycle + builder.add_edge("a", "b", condition=a_to_b_condition) # A -> B only if B not executed + builder.add_edge("b", "c", condition=b_to_c_condition) # B -> C always + builder.add_edge("c", "a", condition=c_to_a_condition) # C -> A only if A executed < 2 times builder.set_entry_point("a") - builder.reset_on_revisit() # Enable state reset on revisit + builder.reset_on_revisit(True) # Enable state reset on revisit + builder.set_max_node_executions(10) # Safety limit + builder.set_execution_timeout(30.0) # Safety timeout # Patch the reset_executor_state method to track calls original_reset = GraphNode.reset_executor_state @@ -353,51 +370,29 @@ def spy_reset(self): with patch.object(GraphNode, "reset_executor_state", spy_reset): graph = builder.build() - # Set a maximum iteration limit to prevent infinite loops - # but ensure we go through the cycle at least twice - # This value is used in the LimitedGraph class below - - # Execute the graph with a task that will cause it to cycle + # Execute the graph with controlled cycling result = await graph.invoke_async("Test cyclic graph execution") # Verify that the graph executed successfully assert result.status == Status.COMPLETED - # Verify that each agent was called at least once - agent_a.invoke_async.assert_called() - agent_b.invoke_async.assert_called() - agent_c.invoke_async.assert_called() - - # Verify that the execution order includes all nodes - assert len(result.execution_order) >= 3 - assert any(node.node_id == "a" for node in result.execution_order) - assert any(node.node_id == "b" for node in result.execution_order) - assert any(node.node_id == "c" for node in result.execution_order) - - # Verify that node state was reset during cyclic execution - # If we have more than 3 nodes in execution_order, at least one node was revisited - if len(result.execution_order) > 3: - # Check that reset_executor_state was called for revisited nodes - reset_spy.assert_called() - - # Count occurrences of each node in execution order - node_counts = {} - for node in result.execution_order: - node_counts[node.node_id] = node_counts.get(node.node_id, 0) + 1 - - # At least one node should appear multiple times - assert any(count > 1 for count in node_counts.values()), "No node was revisited in the cycle" - - # For each node that appears multiple times, verify reset was called - for node_id, count in node_counts.items(): - if count > 1: - # Check that reset was called at least (count-1) times for this node - reset_calls = sum(1 for call in reset_spy.call_args_list if call[0][0] == node_id) - assert reset_calls >= count - 1, ( - f"Node {node_id} appeared {count} times but reset was called {reset_calls} times" - ) - - # Verify all nodes were completed + # Expected execution order: a -> b -> c -> a (4 total executions) + # A executes twice (initial + after c), B executes once, C executes once + assert len(result.execution_order) == 4 + + # Verify execution order + execution_ids = [node.node_id for node in result.execution_order] + assert execution_ids == ["a", "b", "c", "a"] + + # Verify that each agent was called the expected number of times + assert agent_a.invoke_async.call_count == 2 # A executes twice + assert agent_b.invoke_async.call_count == 1 # B executes once + assert agent_c.invoke_async.call_count == 1 # C executes once + + # Verify that node state was reset for the revisited node (A) + assert reset_spy.call_args_list == [call("a")] # Only A should be reset (when revisited) + + # Verify all nodes were completed (final state) assert result.completed_nodes == 3 @@ -423,8 +418,6 @@ def test_graph_builder_validation(): builder.add_node(same_agent, "node2") # Same agent instance, different node_id # Test duplicate node instances in Graph.__init__ - from strands.multiagent.graph import Graph, GraphNode - duplicate_agent = create_mock_agent("duplicate_agent") node1 = GraphNode("node1", duplicate_agent) node2 = GraphNode("node2", duplicate_agent) # Same agent instance @@ -566,7 +559,9 @@ async def test_graph_execution_limits(mock_strands_tracer, mock_use_span): assert result.status == Status.FAILED # Should fail due to limit assert len(result.execution_order) == 2 # Should stop at 2 executions - # Test execution timeout by manipulating start time (like Swarm does) + +@pytest.mark.asyncio +async def test_graph_execution_limits_with_cyclic_graph(mock_strands_tracer, mock_use_span): timeout_agent_a = create_mock_agent("timeout_agent_a", "Response A") timeout_agent_b = create_mock_agent("timeout_agent_b", "Response B") @@ -581,16 +576,28 @@ async def test_graph_execution_limits(mock_strands_tracer, mock_use_span): # Enable reset_on_revisit so the cycle can continue graph = builder.reset_on_revisit(True).set_execution_timeout(5.0).set_max_node_executions(100).build() - # Manipulate the start time to simulate timeout (like Swarm does) - result = await graph.invoke_async("Test execution timeout") - # Manually set start time to simulate timeout condition - graph.state.start_time = time.time() - 10 # Set start time to 10 seconds ago + # Execute the cyclic graph - should hit one of the limits + result = await graph.invoke_async("Test execution limits") - # Check the timeout logic directly - should_continue, reason = graph.state.should_continue(max_node_executions=100, execution_timeout=5.0) + # Should fail due to hitting a limit (either timeout or max executions) + assert result.status == Status.FAILED + # Should have executed many nodes (hitting the limit) + assert len(result.execution_order) >= 50 # Should execute many times before hitting limit + + # Test timeout logic directly (without execution) + test_state = GraphState() + test_state.start_time = time.time() - 10 # Set start time to 10 seconds ago + should_continue, reason = test_state.should_continue(max_node_executions=100, execution_timeout=5.0) assert should_continue is False assert "Execution timed out" in reason + # Test max executions logic directly (without execution) + test_state2 = GraphState() + test_state2.execution_order = [None] * 101 # Simulate 101 executions + should_continue2, reason2 = test_state2.should_continue(max_node_executions=100, execution_timeout=5.0) + assert should_continue2 is False + assert "Max node executions reached" in reason2 + # builder = GraphBuilder() # builder.add_node(slow_agent, "slow") # graph = (builder.set_max_node_executions(1000) # High limit to avoid hitting this @@ -1062,9 +1069,7 @@ async def test_state_reset_only_with_cycles_enabled(): graph = builder.build() # Mock the _execute_node method to test conditional reset logic - import unittest.mock - - with unittest.mock.patch.object(node, "reset_executor_state") as mock_reset: + with patch.object(node, "reset_executor_state") as mock_reset: # Simulate the conditional logic from _execute_node if graph.reset_on_revisit and node in state.completed_nodes: node.reset_executor_state() @@ -1079,7 +1084,7 @@ async def test_state_reset_only_with_cycles_enabled(): builder.reset_on_revisit() graph = builder.build() - with unittest.mock.patch.object(node, "reset_executor_state") as mock_reset: + with patch.object(node, "reset_executor_state") as mock_reset: # Simulate the conditional logic from _execute_node if graph.reset_on_revisit and node in state.completed_nodes: node.reset_executor_state() @@ -1087,3 +1092,196 @@ async def test_state_reset_only_with_cycles_enabled(): # With reset_on_revisit enabled, reset should be called mock_reset.assert_called_once() + + +@pytest.mark.asyncio +async def test_self_loop_functionality(mock_strands_tracer, mock_use_span): + """Test comprehensive self-loop functionality including conditions and reset behavior.""" + # Test basic self-loop with execution counting + self_loop_agent = create_mock_agent("self_loop_agent", "Self loop response") + self_loop_agent.invoke_async = Mock(side_effect=self_loop_agent.invoke_async) + + def loop_condition(state: GraphState) -> bool: + return len(state.execution_order) < 3 + + builder = GraphBuilder() + builder.add_node(self_loop_agent, "self_loop") + builder.add_edge("self_loop", "self_loop", condition=loop_condition) + builder.set_entry_point("self_loop") + builder.reset_on_revisit(True) + builder.set_max_node_executions(10) + builder.set_execution_timeout(30.0) + + graph = builder.build() + result = await graph.invoke_async("Test self loop") + + # Verify basic self-loop functionality + assert result.status == Status.COMPLETED + assert self_loop_agent.invoke_async.call_count == 3 + assert len(result.execution_order) == 3 + assert all(node.node_id == "self_loop" for node in result.execution_order) + + +@pytest.mark.asyncio +async def test_self_loop_functionality_without_reset(mock_strands_tracer, mock_use_span): + loop_agent_no_reset = create_mock_agent("loop_agent", "Loop without reset") + + can_only_be_called_twice: Mock = Mock(side_effect=lambda state: can_only_be_called_twice.call_count <= 2) + + builder = GraphBuilder() + builder.add_node(loop_agent_no_reset, "loop_node") + builder.add_edge("loop_node", "loop_node", condition=can_only_be_called_twice) + builder.set_entry_point("loop_node") + builder.reset_on_revisit(False) # Disable state reset + builder.set_max_node_executions(10) + + graph = builder.build() + result = await graph.invoke_async("Test self loop without reset") + + assert result.status == Status.COMPLETED + assert len(result.execution_order) == 2 + + mock_strands_tracer.start_multiagent_span.assert_called() + mock_use_span.assert_called() + + +@pytest.mark.asyncio +async def test_complex_self_loop(mock_strands_tracer, mock_use_span): + """Test complex self-loop scenarios including multi-node graphs and multiple self-loops.""" + start_agent = create_mock_agent("start_agent", "Start") + loop_agent = create_mock_agent("loop_agent", "Loop") + end_agent = create_mock_agent("end_agent", "End") + + def loop_condition(state: GraphState) -> bool: + loop_count = sum(1 for node in state.execution_order if node.node_id == "loop_node") + return loop_count < 2 + + def end_condition(state: GraphState) -> bool: + loop_count = sum(1 for node in state.execution_order if node.node_id == "loop_node") + return loop_count >= 2 + + builder = GraphBuilder() + builder.add_node(start_agent, "start_node") + builder.add_node(loop_agent, "loop_node") + builder.add_node(end_agent, "end_node") + builder.add_edge("start_node", "loop_node") + builder.add_edge("loop_node", "loop_node", condition=loop_condition) + builder.add_edge("loop_node", "end_node", condition=end_condition) + builder.set_entry_point("start_node") + builder.reset_on_revisit(True) + builder.set_max_node_executions(10) + + graph = builder.build() + result = await graph.invoke_async("Test complex graph with self loops") + + assert result.status == Status.COMPLETED + assert len(result.execution_order) == 4 # start -> loop -> loop -> end + assert [node.node_id for node in result.execution_order] == ["start_node", "loop_node", "loop_node", "end_node"] + assert start_agent.invoke_async.call_count == 1 + assert loop_agent.invoke_async.call_count == 2 + assert end_agent.invoke_async.call_count == 1 + + +@pytest.mark.asyncio +async def test_multiple_nodes_with_self_loops(mock_strands_tracer, mock_use_span): + agent_a = create_mock_agent("agent_a", "Agent A") + agent_b = create_mock_agent("agent_b", "Agent B") + + def condition_a(state: GraphState) -> bool: + return sum(1 for node in state.execution_order if node.node_id == "a") < 2 + + def condition_b(state: GraphState) -> bool: + return sum(1 for node in state.execution_order if node.node_id == "b") < 2 + + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_edge("a", "a", condition=condition_a) + builder.add_edge("b", "b", condition=condition_b) + builder.add_edge("a", "b") + builder.set_entry_point("a") + builder.reset_on_revisit(True) + builder.set_max_node_executions(15) + + graph = builder.build() + result = await graph.invoke_async("Test multiple self loops") + + assert result.status == Status.COMPLETED + assert len(result.execution_order) == 4 # a -> a -> b -> b + assert agent_a.invoke_async.call_count == 2 + assert agent_b.invoke_async.call_count == 2 + + mock_strands_tracer.start_multiagent_span.assert_called() + mock_use_span.assert_called() + + +@pytest.mark.asyncio +async def test_self_loop_state_reset(): + """Test self-loop edge cases including state reset, failure handling, and infinite loop prevention.""" + agent = create_mock_agent("stateful_agent", "Stateful response") + agent.state = AgentState() + + def loop_condition(state: GraphState) -> bool: + return len(state.execution_order) < 3 + + builder = GraphBuilder() + node = builder.add_node(agent, "stateful_node") + builder.add_edge("stateful_node", "stateful_node", condition=loop_condition) + builder.set_entry_point("stateful_node") + builder.reset_on_revisit(True) + builder.set_max_node_executions(10) + + node.reset_executor_state = Mock(wraps=node.reset_executor_state) + + graph = builder.build() + result = await graph.invoke_async("Test state reset") + + assert result.status == Status.COMPLETED + assert len(result.execution_order) == 3 + assert node.reset_executor_state.call_count >= 2 # Reset called for revisits + + +@pytest.mark.asyncio +async def test_infinite_loop_prevention(): + infinite_agent = create_mock_agent("infinite_agent", "Infinite loop") + + def always_true_condition(state: GraphState) -> bool: + return True + + builder = GraphBuilder() + builder.add_node(infinite_agent, "infinite_node") + builder.add_edge("infinite_node", "infinite_node", condition=always_true_condition) + builder.set_entry_point("infinite_node") + builder.reset_on_revisit(True) + builder.set_max_node_executions(5) + + graph = builder.build() + result = await graph.invoke_async("Test infinite loop prevention") + + assert result.status == Status.FAILED + assert len(result.execution_order) == 5 + + +@pytest.mark.asyncio +async def test_infinite_loop_prevention_self_loops(): + multi_agent = create_mock_multi_agent("multi_agent", "Multi-agent response") + loop_count = 0 + + def multi_loop_condition(state: GraphState) -> bool: + nonlocal loop_count + loop_count += 1 + return loop_count <= 2 + + builder = GraphBuilder() + builder.add_node(multi_agent, "multi_node") + builder.add_edge("multi_node", "multi_node", condition=multi_loop_condition) + builder.set_entry_point("multi_node") + builder.reset_on_revisit(True) + builder.set_max_node_executions(10) + + graph = builder.build() + result = await graph.invoke_async("Test multi-agent self loop") + + assert result.status == Status.COMPLETED + assert len(result.execution_order) >= 2 + assert multi_agent.invoke_async.call_count >= 2