From f8d4750a54b682c19a2cc220aa735bf3fae610b1 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 8 Sep 2025 11:59:35 -0400 Subject: [PATCH 1/7] feat(multiagent): allow callers of swarm and graph to pass kwargs to executors --- src/strands/multiagent/graph.py | 18 +++--- src/strands/multiagent/swarm.py | 14 +++-- tests/strands/multiagent/test_graph.py | 86 ++++++++++++++++++++++++++ tests/strands/multiagent/test_swarm.py | 51 +++++++++++++++ 4 files changed, 154 insertions(+), 15 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index d2838396d..cc2cc435b 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -389,7 +389,7 @@ def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult """Invoke the graph synchronously.""" def execute() -> GraphResult: - return asyncio.run(self.invoke_async(task)) + return asyncio.run(self.invoke_async(task, **kwargs)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) @@ -420,7 +420,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> G self.node_timeout or "None", ) - await self._execute_graph() + await self._execute_graph(kwargs) # Set final status based on execution results if self.state.failed_nodes: @@ -450,7 +450,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) - async def _execute_graph(self) -> None: + async def _execute_graph(self, invocation_state: dict[str, Any]) -> None: """Unified execution flow with conditional routing.""" ready_nodes = list(self.entry_points) @@ -469,7 +469,7 @@ async def _execute_graph(self) -> None: ready_nodes.clear() # Execute current batch of ready nodes concurrently - tasks = [asyncio.create_task(self._execute_node(node)) for node in current_batch] + tasks = [asyncio.create_task(self._execute_node(node, invocation_state)) for node in current_batch] for task in tasks: await task @@ -506,7 +506,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ ) return False - async def _execute_node(self, node: GraphNode) -> None: + async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> None: """Execute a single node with error handling and timeout protection.""" # Reset the node's state if reset_on_revisit is enabled and it's being revisited if self.reset_on_revisit and node in self.state.completed_nodes: @@ -529,11 +529,11 @@ async def _execute_node(self, node: GraphNode) -> None: if isinstance(node.executor, MultiAgentBase): if self.node_timeout is not None: multi_agent_result = await asyncio.wait_for( - node.executor.invoke_async(node_input), + node.executor.invoke_async(node_input, **invocation_state), timeout=self.node_timeout, ) else: - multi_agent_result = await node.executor.invoke_async(node_input) + multi_agent_result = await node.executor.invoke_async(node_input, **invocation_state) # Create NodeResult with MultiAgentResult directly node_result = NodeResult( @@ -548,11 +548,11 @@ async def _execute_node(self, node: GraphNode) -> None: elif isinstance(node.executor, Agent): if self.node_timeout is not None: agent_response = await asyncio.wait_for( - node.executor.invoke_async(node_input), + node.executor.invoke_async(node_input, **invocation_state), timeout=self.node_timeout, ) else: - agent_response = await node.executor.invoke_async(node_input) + agent_response = await node.executor.invoke_async(node_input, **invocation_state) # Extract metrics from agent response usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index d730d5156..6e518215f 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -241,7 +241,7 @@ def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult """Invoke the swarm synchronously.""" def execute() -> SwarmResult: - return asyncio.run(self.invoke_async(task)) + return asyncio.run(self.invoke_async(task, **kwargs)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) @@ -272,7 +272,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> S self.execution_timeout, ) - await self._execute_swarm() + await self._execute_swarm(kwargs) except Exception: logger.exception("swarm execution failed") self.state.completion_status = Status.FAILED @@ -483,7 +483,7 @@ def _build_node_input(self, target_node: SwarmNode) -> str: return context_text - async def _execute_swarm(self) -> None: + async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: """Shared execution logic used by execute_async.""" try: # Main execution loop @@ -522,7 +522,7 @@ async def _execute_swarm(self) -> None: # TODO: Implement cancellation token to stop _execute_node from continuing try: await asyncio.wait_for( - self._execute_node(current_node, self.state.task), + self._execute_node(current_node, self.state.task, invocation_state), timeout=self.node_timeout, ) @@ -563,7 +563,9 @@ async def _execute_swarm(self) -> None: f"{elapsed_time:.2f}", ) - async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) -> AgentResult: + async def _execute_node( + self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any] + ) -> AgentResult: """Execute swarm node.""" start_time = time.time() node_name = node.node_id @@ -583,7 +585,7 @@ async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) - # Execute node result = None node.reset_executor_state() - result = await node.executor.invoke_async(node_input) + result = await node.executor.invoke_async(node_input, **invocation_state) execution_time = round((time.time() - start_time) * 1000) diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 1a598847d..71235b3a0 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1285,3 +1285,89 @@ def multi_loop_condition(state: GraphState) -> bool: assert result.status == Status.COMPLETED assert len(result.execution_order) >= 2 assert multi_agent.invoke_async.call_count >= 2 + + +@pytest.mark.asyncio +async def test_graph_kwargs_passing_agent(mock_strands_tracer, mock_use_span): + """Test that kwargs are passed through to underlying Agent nodes.""" + # Create a mock agent that captures kwargs + kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + + async def capture_kwargs(*args, **kwargs): + # Store kwargs for verification + capture_kwargs.captured_kwargs = kwargs + return kwargs_agent.return_value + + kwargs_agent.invoke_async = MagicMock(side_effect=capture_kwargs) + + # Create graph + builder = GraphBuilder() + builder.add_node(kwargs_agent, "kwargs_node") + graph = builder.build() + + # Execute with custom kwargs + test_kwargs = {"custom_param": "test_value", "another_param": 42} + result = await graph.invoke_async("Test kwargs passing", **test_kwargs) + + # Verify kwargs were passed to agent + assert hasattr(capture_kwargs, "captured_kwargs") + assert capture_kwargs.captured_kwargs == test_kwargs + assert result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_graph_kwargs_passing_multiagent(mock_strands_tracer, mock_use_span): + """Test that kwargs are passed through to underlying MultiAgentBase nodes.""" + # Create a mock MultiAgentBase that captures kwargs + kwargs_multiagent = create_mock_multi_agent("kwargs_multiagent", "MultiAgent response with kwargs") + + # Store the original return value + original_result = kwargs_multiagent.invoke_async.return_value + + async def capture_kwargs(*args, **kwargs): + # Store kwargs for verification + capture_kwargs.captured_kwargs = kwargs + return original_result + + kwargs_multiagent.invoke_async = AsyncMock(side_effect=capture_kwargs) + + # Create graph + builder = GraphBuilder() + builder.add_node(kwargs_multiagent, "multiagent_node") + graph = builder.build() + + # Execute with custom kwargs + test_kwargs = {"custom_param": "test_value", "another_param": 42} + result = await graph.invoke_async("Test kwargs passing to multiagent", **test_kwargs) + + # Verify kwargs were passed to multiagent + assert hasattr(capture_kwargs, "captured_kwargs") + assert capture_kwargs.captured_kwargs == test_kwargs + assert result.status == Status.COMPLETED + + +def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span): + """Test that kwargs are passed through to underlying nodes in sync execution.""" + # Create a mock agent that captures kwargs + kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + + async def capture_kwargs(*args, **kwargs): + # Store kwargs for verification + capture_kwargs.captured_kwargs = kwargs + return kwargs_agent.return_value + + kwargs_agent.invoke_async = MagicMock(side_effect=capture_kwargs) + + # Create graph + builder = GraphBuilder() + builder.add_node(kwargs_agent, "kwargs_node") + graph = builder.build() + + # Execute with custom kwargs + test_kwargs = {"custom_param": "test_value", "another_param": 42} + result = graph("Test kwargs passing sync", **test_kwargs) + + # Verify kwargs were passed to agent + assert hasattr(capture_kwargs, "captured_kwargs") + assert capture_kwargs.captured_kwargs == test_kwargs + assert result.status == Status.COMPLETED diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 74f89241f..d4653b1e2 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -469,3 +469,54 @@ def test_swarm_validate_unsupported_features(): with pytest.raises(ValueError, match="Session persistence is not supported for Swarm agents yet"): Swarm([agent_with_session]) + + +@pytest.mark.asyncio +async def test_swarm_kwargs_passing(mock_strands_tracer, mock_use_span): + """Test that kwargs are passed through to underlying agents.""" + # Create a mock agent that captures kwargs + kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + + async def capture_kwargs(*args, **kwargs): + # Store kwargs for verification + capture_kwargs.captured_kwargs = kwargs + return kwargs_agent.return_value + + kwargs_agent.invoke_async = MagicMock(side_effect=capture_kwargs) + + # Create swarm + swarm = Swarm(nodes=[kwargs_agent]) + + # Execute with custom kwargs + test_kwargs = {"custom_param": "test_value", "another_param": 42} + result = await swarm.invoke_async("Test kwargs passing", **test_kwargs) + + # Verify kwargs were passed to agent + assert hasattr(capture_kwargs, "captured_kwargs") + assert capture_kwargs.captured_kwargs == test_kwargs + assert result.status == Status.COMPLETED + + +def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span): + """Test that kwargs are passed through to underlying agents in sync execution.""" + # Create a mock agent that captures kwargs + kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + + async def capture_kwargs(*args, **kwargs): + # Store kwargs for verification + capture_kwargs.captured_kwargs = kwargs + return kwargs_agent.return_value + + kwargs_agent.invoke_async = MagicMock(side_effect=capture_kwargs) + + # Create swarm + swarm = Swarm(nodes=[kwargs_agent]) + + # Execute with custom kwargs + test_kwargs = {"custom_param": "test_value", "another_param": 42} + result = swarm("Test kwargs passing sync", **test_kwargs) + + # Verify kwargs were passed to agent + assert hasattr(capture_kwargs, "captured_kwargs") + assert capture_kwargs.captured_kwargs == test_kwargs + assert result.status == Status.COMPLETED From a4941bc13cfffbab116162ac8ad44b66cec66950 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 8 Sep 2025 14:16:19 -0400 Subject: [PATCH 2/7] fix: adjust unit tests --- tests/strands/multiagent/test_graph.py | 50 +++++--------------------- tests/strands/multiagent/test_swarm.py | 30 +++------------- 2 files changed, 12 insertions(+), 68 deletions(-) diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 71235b3a0..0116cf84e 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1290,84 +1290,50 @@ def multi_loop_condition(state: GraphState) -> bool: @pytest.mark.asyncio async def test_graph_kwargs_passing_agent(mock_strands_tracer, mock_use_span): """Test that kwargs are passed through to underlying Agent nodes.""" - # Create a mock agent that captures kwargs kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) - async def capture_kwargs(*args, **kwargs): - # Store kwargs for verification - capture_kwargs.captured_kwargs = kwargs - return kwargs_agent.return_value - - kwargs_agent.invoke_async = MagicMock(side_effect=capture_kwargs) - - # Create graph builder = GraphBuilder() builder.add_node(kwargs_agent, "kwargs_node") graph = builder.build() - # Execute with custom kwargs test_kwargs = {"custom_param": "test_value", "another_param": 42} result = await graph.invoke_async("Test kwargs passing", **test_kwargs) - # Verify kwargs were passed to agent - assert hasattr(capture_kwargs, "captured_kwargs") - assert capture_kwargs.captured_kwargs == test_kwargs + kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing"}], **test_kwargs) assert result.status == Status.COMPLETED @pytest.mark.asyncio async def test_graph_kwargs_passing_multiagent(mock_strands_tracer, mock_use_span): """Test that kwargs are passed through to underlying MultiAgentBase nodes.""" - # Create a mock MultiAgentBase that captures kwargs kwargs_multiagent = create_mock_multi_agent("kwargs_multiagent", "MultiAgent response with kwargs") + kwargs_multiagent.invoke_async = Mock(side_effect=kwargs_multiagent.invoke_async) - # Store the original return value - original_result = kwargs_multiagent.invoke_async.return_value - - async def capture_kwargs(*args, **kwargs): - # Store kwargs for verification - capture_kwargs.captured_kwargs = kwargs - return original_result - - kwargs_multiagent.invoke_async = AsyncMock(side_effect=capture_kwargs) - - # Create graph builder = GraphBuilder() builder.add_node(kwargs_multiagent, "multiagent_node") graph = builder.build() - # Execute with custom kwargs test_kwargs = {"custom_param": "test_value", "another_param": 42} result = await graph.invoke_async("Test kwargs passing to multiagent", **test_kwargs) - # Verify kwargs were passed to multiagent - assert hasattr(capture_kwargs, "captured_kwargs") - assert capture_kwargs.captured_kwargs == test_kwargs + kwargs_multiagent.invoke_async.assert_called_once_with( + [{"text": "Test kwargs passing to multiagent"}], **test_kwargs + ) assert result.status == Status.COMPLETED def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span): """Test that kwargs are passed through to underlying nodes in sync execution.""" - # Create a mock agent that captures kwargs kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) - async def capture_kwargs(*args, **kwargs): - # Store kwargs for verification - capture_kwargs.captured_kwargs = kwargs - return kwargs_agent.return_value - - kwargs_agent.invoke_async = MagicMock(side_effect=capture_kwargs) - - # Create graph builder = GraphBuilder() builder.add_node(kwargs_agent, "kwargs_node") graph = builder.build() - # Execute with custom kwargs test_kwargs = {"custom_param": "test_value", "another_param": 42} result = graph("Test kwargs passing sync", **test_kwargs) - # Verify kwargs were passed to agent - assert hasattr(capture_kwargs, "captured_kwargs") - assert capture_kwargs.captured_kwargs == test_kwargs + kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing sync"}], **test_kwargs) assert result.status == Status.COMPLETED diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index d4653b1e2..506d45e1a 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -474,49 +474,27 @@ def test_swarm_validate_unsupported_features(): @pytest.mark.asyncio async def test_swarm_kwargs_passing(mock_strands_tracer, mock_use_span): """Test that kwargs are passed through to underlying agents.""" - # Create a mock agent that captures kwargs kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) - async def capture_kwargs(*args, **kwargs): - # Store kwargs for verification - capture_kwargs.captured_kwargs = kwargs - return kwargs_agent.return_value - - kwargs_agent.invoke_async = MagicMock(side_effect=capture_kwargs) - - # Create swarm swarm = Swarm(nodes=[kwargs_agent]) - # Execute with custom kwargs test_kwargs = {"custom_param": "test_value", "another_param": 42} result = await swarm.invoke_async("Test kwargs passing", **test_kwargs) - # Verify kwargs were passed to agent - assert hasattr(capture_kwargs, "captured_kwargs") - assert capture_kwargs.captured_kwargs == test_kwargs + assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs assert result.status == Status.COMPLETED def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span): """Test that kwargs are passed through to underlying agents in sync execution.""" - # Create a mock agent that captures kwargs kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) - async def capture_kwargs(*args, **kwargs): - # Store kwargs for verification - capture_kwargs.captured_kwargs = kwargs - return kwargs_agent.return_value - - kwargs_agent.invoke_async = MagicMock(side_effect=capture_kwargs) - - # Create swarm swarm = Swarm(nodes=[kwargs_agent]) - # Execute with custom kwargs test_kwargs = {"custom_param": "test_value", "another_param": 42} result = swarm("Test kwargs passing sync", **test_kwargs) - # Verify kwargs were passed to agent - assert hasattr(capture_kwargs, "captured_kwargs") - assert capture_kwargs.captured_kwargs == test_kwargs + assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs assert result.status == Status.COMPLETED From f9846172d57831cf87c3e394a62a4a6570ef16b1 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 8 Sep 2025 15:09:21 -0400 Subject: [PATCH 3/7] fix: refactor to pass invocation_state rather thn passing kwargs --- src/strands/multiagent/base.py | 30 +++++++++++++++--- src/strands/multiagent/graph.py | 42 +++++++++++++++++++++----- src/strands/multiagent/swarm.py | 38 +++++++++++++++++++---- tests/strands/multiagent/test_base.py | 2 +- tests/strands/multiagent/test_graph.py | 20 ++++++------ tests/strands/multiagent/test_swarm.py | 4 +-- 6 files changed, 103 insertions(+), 33 deletions(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 69578cb5d..03d7de9b4 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -84,15 +84,35 @@ class MultiAgentBase(ABC): """ @abstractmethod - async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult: - """Invoke asynchronously.""" + async def invoke_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> MultiAgentResult: + """Invoke asynchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Additional keyword arguments passed to underlying agents. + """ raise NotImplementedError("invoke_async not implemented") - def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult: - """Invoke synchronously.""" + def __call__( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> MultiAgentResult: + """Invoke synchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Additional keyword arguments passed to underlying agents. + """ + if invocation_state is None: + invocation_state = {} def execute() -> MultiAgentResult: - return asyncio.run(self.invoke_async(task, **kwargs)) + return asyncio.run(self.invoke_async(task, invocation_state, **kwargs)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index cc2cc435b..ff6ee2b3c 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -385,18 +385,42 @@ def __init__( self.state = GraphState() self.tracer = get_tracer() - def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult: - """Invoke the graph synchronously.""" + def __call__( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> GraphResult: + """Invoke the graph synchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Additional keyword arguments passed to underlying agents. + """ + if invocation_state is None: + invocation_state = {} def execute() -> GraphResult: - return asyncio.run(self.invoke_async(task, **kwargs)) + return asyncio.run(self.invoke_async(task, invocation_state, **kwargs)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) return future.result() - async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult: - """Invoke the graph asynchronously.""" + async def invoke_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> GraphResult: + """Invoke the graph asynchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues - a new empty dict + is created if None is provided. + **kwargs: Additional keyword arguments passed to underlying agents. + """ + if invocation_state is None: + invocation_state = {} + logger.debug("task=<%s> | starting graph execution", task) # Initialize state @@ -420,7 +444,9 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> G self.node_timeout or "None", ) - await self._execute_graph(kwargs) + # Merge kwargs into invocation_state for internal execution + merged_state = {**invocation_state, **kwargs} + await self._execute_graph(merged_state) # Set final status based on execution results if self.state.failed_nodes: @@ -529,11 +555,11 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if isinstance(node.executor, MultiAgentBase): if self.node_timeout is not None: multi_agent_result = await asyncio.wait_for( - node.executor.invoke_async(node_input, **invocation_state), + node.executor.invoke_async(node_input, invocation_state), timeout=self.node_timeout, ) else: - multi_agent_result = await node.executor.invoke_async(node_input, **invocation_state) + multi_agent_result = await node.executor.invoke_async(node_input, invocation_state) # Create NodeResult with MultiAgentResult directly node_result = NodeResult( diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 6e518215f..d8cf4a7a1 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -237,18 +237,42 @@ def __init__( self._setup_swarm(nodes) self._inject_swarm_tools() - def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult: - """Invoke the swarm synchronously.""" + def __call__( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> SwarmResult: + """Invoke the swarm synchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Additional keyword arguments passed to underlying agents. + """ + if invocation_state is None: + invocation_state = {} def execute() -> SwarmResult: - return asyncio.run(self.invoke_async(task, **kwargs)) + return asyncio.run(self.invoke_async(task, invocation_state, **kwargs)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) return future.result() - async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult: - """Invoke the swarm asynchronously.""" + async def invoke_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> SwarmResult: + """Invoke the swarm asynchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues - a new empty dict + is created if None is provided. + **kwargs: Additional keyword arguments passed to underlying agents. + """ + if invocation_state is None: + invocation_state = {} + logger.debug("starting swarm execution") # Initialize swarm state with configuration @@ -272,7 +296,9 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> S self.execution_timeout, ) - await self._execute_swarm(kwargs) + # Merge kwargs into invocation_state for internal execution + merged_state = {**invocation_state, **kwargs} + await self._execute_swarm(merged_state) except Exception: logger.exception("swarm execution failed") self.state.completion_status = Status.FAILED diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 395d9275c..d21aa6e14 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -155,7 +155,7 @@ def __init__(self): self.received_task = None self.received_kwargs = None - async def invoke_async(self, task, **kwargs): + async def invoke_async(self, task, invocation_state, **kwargs): self.invoke_async_called = True self.received_task = task self.received_kwargs = kwargs diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 0116cf84e..3ac4a8ff0 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1297,10 +1297,10 @@ async def test_graph_kwargs_passing_agent(mock_strands_tracer, mock_use_span): builder.add_node(kwargs_agent, "kwargs_node") graph = builder.build() - test_kwargs = {"custom_param": "test_value", "another_param": 42} - result = await graph.invoke_async("Test kwargs passing", **test_kwargs) + test_invocation_state = {"custom_param": "test_value", "another_param": 42} + result = await graph.invoke_async("Test kwargs passing", test_invocation_state) - kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing"}], **test_kwargs) + kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing"}], **test_invocation_state) assert result.status == Status.COMPLETED @@ -1314,12 +1314,10 @@ async def test_graph_kwargs_passing_multiagent(mock_strands_tracer, mock_use_spa builder.add_node(kwargs_multiagent, "multiagent_node") graph = builder.build() - test_kwargs = {"custom_param": "test_value", "another_param": 42} - result = await graph.invoke_async("Test kwargs passing to multiagent", **test_kwargs) + test_invocation_state = {"custom_param": "test_value", "another_param": 42} + result = await graph.invoke_async("Test kwargs passing to multiagent", test_invocation_state) - kwargs_multiagent.invoke_async.assert_called_once_with( - [{"text": "Test kwargs passing to multiagent"}], **test_kwargs - ) + kwargs_multiagent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing to multiagent"}], test_invocation_state) assert result.status == Status.COMPLETED @@ -1332,8 +1330,8 @@ def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span): builder.add_node(kwargs_agent, "kwargs_node") graph = builder.build() - test_kwargs = {"custom_param": "test_value", "another_param": 42} - result = graph("Test kwargs passing sync", **test_kwargs) + test_invocation_state = {"custom_param": "test_value", "another_param": 42} + result = graph("Test kwargs passing sync", test_invocation_state) - kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing sync"}], **test_kwargs) + kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing sync"}], **test_invocation_state) assert result.status == Status.COMPLETED diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 506d45e1a..be463c7fd 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -480,7 +480,7 @@ async def test_swarm_kwargs_passing(mock_strands_tracer, mock_use_span): swarm = Swarm(nodes=[kwargs_agent]) test_kwargs = {"custom_param": "test_value", "another_param": 42} - result = await swarm.invoke_async("Test kwargs passing", **test_kwargs) + result = await swarm.invoke_async("Test kwargs passing", test_kwargs) assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs assert result.status == Status.COMPLETED @@ -494,7 +494,7 @@ def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span): swarm = Swarm(nodes=[kwargs_agent]) test_kwargs = {"custom_param": "test_value", "another_param": 42} - result = swarm("Test kwargs passing sync", **test_kwargs) + result = swarm("Test kwargs passing sync", test_kwargs) assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs assert result.status == Status.COMPLETED From 937324207d1386ddbf16ee985a22767c2570fa91 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 8 Sep 2025 15:17:49 -0400 Subject: [PATCH 4/7] fix: update comments --- src/strands/multiagent/swarm.py | 8 +++----- tests/strands/multiagent/test_graph.py | 4 +++- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index d8cf4a7a1..441833ea5 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -246,7 +246,7 @@ def __call__( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. - **kwargs: Additional keyword arguments passed to underlying agents. + **kwargs: Keyword arguments allowing backward compatible future changes. """ if invocation_state is None: invocation_state = {} @@ -268,7 +268,7 @@ async def invoke_async( invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues - a new empty dict is created if None is provided. - **kwargs: Additional keyword arguments passed to underlying agents. + **kwargs: Keyword arguments allowing backward compatible future changes. """ if invocation_state is None: invocation_state = {} @@ -296,9 +296,7 @@ async def invoke_async( self.execution_timeout, ) - # Merge kwargs into invocation_state for internal execution - merged_state = {**invocation_state, **kwargs} - await self._execute_swarm(merged_state) + await self._execute_swarm(invocation_state) except Exception: logger.exception("swarm execution failed") self.state.completion_status = Status.FAILED diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 3ac4a8ff0..8097d944e 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1317,7 +1317,9 @@ async def test_graph_kwargs_passing_multiagent(mock_strands_tracer, mock_use_spa test_invocation_state = {"custom_param": "test_value", "another_param": 42} result = await graph.invoke_async("Test kwargs passing to multiagent", test_invocation_state) - kwargs_multiagent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing to multiagent"}], test_invocation_state) + kwargs_multiagent.invoke_async.assert_called_once_with( + [{"text": "Test kwargs passing to multiagent"}], test_invocation_state + ) assert result.status == Status.COMPLETED From 81a7906c9acb86b3e145deb361b921783616ddba Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 8 Sep 2025 15:57:08 -0400 Subject: [PATCH 5/7] fix: removal of merged_state --- src/strands/multiagent/graph.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index ff6ee2b3c..62c354d4f 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -394,7 +394,7 @@ def __call__( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. - **kwargs: Additional keyword arguments passed to underlying agents. + **kwargs: Keyword arguments allowing backward compatible future changes. """ if invocation_state is None: invocation_state = {} @@ -416,7 +416,7 @@ async def invoke_async( invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues - a new empty dict is created if None is provided. - **kwargs: Additional keyword arguments passed to underlying agents. + **kwargs: Keyword arguments allowing backward compatible future changes. """ if invocation_state is None: invocation_state = {} @@ -444,9 +444,7 @@ async def invoke_async( self.node_timeout or "None", ) - # Merge kwargs into invocation_state for internal execution - merged_state = {**invocation_state, **kwargs} - await self._execute_graph(merged_state) + await self._execute_graph(invocation_state) # Set final status based on execution results if self.state.failed_nodes: From b71e689475852029939570ee1aff29dafe6dd228 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 8 Sep 2025 16:08:35 -0400 Subject: [PATCH 6/7] Update src/strands/multiagent/swarm.py Co-authored-by: Nick Clegg --- src/strands/multiagent/swarm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 441833ea5..f649b75e9 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -609,6 +609,7 @@ async def _execute_node( # Execute node result = None node.reset_executor_state() + # Unpacking since this is the agent class. Other executors should not unpack result = await node.executor.invoke_async(node_input, **invocation_state) execution_time = round((time.time() - start_time) * 1000) From 3503bd7e4229fbc4fe2b75de6b61ad8b78665217 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 8 Sep 2025 16:09:58 -0400 Subject: [PATCH 7/7] remove kwargs passing --- src/strands/multiagent/graph.py | 2 +- src/strands/multiagent/swarm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 62c354d4f..738dc4d4c 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -400,7 +400,7 @@ def __call__( invocation_state = {} def execute() -> GraphResult: - return asyncio.run(self.invoke_async(task, invocation_state, **kwargs)) + return asyncio.run(self.invoke_async(task, invocation_state)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index f649b75e9..1c2302c28 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -252,7 +252,7 @@ def __call__( invocation_state = {} def execute() -> SwarmResult: - return asyncio.run(self.invoke_async(task, invocation_state, **kwargs)) + return asyncio.run(self.invoke_async(task, invocation_state)) with ThreadPoolExecutor() as executor: future = executor.submit(execute)