diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 69578cb5d..03d7de9b4 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -84,15 +84,35 @@ class MultiAgentBase(ABC): """ @abstractmethod - async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult: - """Invoke asynchronously.""" + async def invoke_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> MultiAgentResult: + """Invoke asynchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Additional keyword arguments passed to underlying agents. + """ raise NotImplementedError("invoke_async not implemented") - def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult: - """Invoke synchronously.""" + def __call__( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> MultiAgentResult: + """Invoke synchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Additional keyword arguments passed to underlying agents. + """ + if invocation_state is None: + invocation_state = {} def execute() -> MultiAgentResult: - return asyncio.run(self.invoke_async(task, **kwargs)) + return asyncio.run(self.invoke_async(task, invocation_state, **kwargs)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index d2838396d..738dc4d4c 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -385,18 +385,42 @@ def __init__( self.state = GraphState() self.tracer = get_tracer() - def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult: - """Invoke the graph synchronously.""" + def __call__( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> GraphResult: + """Invoke the graph synchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Keyword arguments allowing backward compatible future changes. + """ + if invocation_state is None: + invocation_state = {} def execute() -> GraphResult: - return asyncio.run(self.invoke_async(task)) + return asyncio.run(self.invoke_async(task, invocation_state)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) return future.result() - async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult: - """Invoke the graph asynchronously.""" + async def invoke_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> GraphResult: + """Invoke the graph asynchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues - a new empty dict + is created if None is provided. + **kwargs: Keyword arguments allowing backward compatible future changes. + """ + if invocation_state is None: + invocation_state = {} + logger.debug("task=<%s> | starting graph execution", task) # Initialize state @@ -420,7 +444,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> G self.node_timeout or "None", ) - await self._execute_graph() + await self._execute_graph(invocation_state) # Set final status based on execution results if self.state.failed_nodes: @@ -450,7 +474,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) - async def _execute_graph(self) -> None: + async def _execute_graph(self, invocation_state: dict[str, Any]) -> None: """Unified execution flow with conditional routing.""" ready_nodes = list(self.entry_points) @@ -469,7 +493,7 @@ async def _execute_graph(self) -> None: ready_nodes.clear() # Execute current batch of ready nodes concurrently - tasks = [asyncio.create_task(self._execute_node(node)) for node in current_batch] + tasks = [asyncio.create_task(self._execute_node(node, invocation_state)) for node in current_batch] for task in tasks: await task @@ -506,7 +530,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ ) return False - async def _execute_node(self, node: GraphNode) -> None: + async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> None: """Execute a single node with error handling and timeout protection.""" # Reset the node's state if reset_on_revisit is enabled and it's being revisited if self.reset_on_revisit and node in self.state.completed_nodes: @@ -529,11 +553,11 @@ async def _execute_node(self, node: GraphNode) -> None: if isinstance(node.executor, MultiAgentBase): if self.node_timeout is not None: multi_agent_result = await asyncio.wait_for( - node.executor.invoke_async(node_input), + node.executor.invoke_async(node_input, invocation_state), timeout=self.node_timeout, ) else: - multi_agent_result = await node.executor.invoke_async(node_input) + multi_agent_result = await node.executor.invoke_async(node_input, invocation_state) # Create NodeResult with MultiAgentResult directly node_result = NodeResult( @@ -548,11 +572,11 @@ async def _execute_node(self, node: GraphNode) -> None: elif isinstance(node.executor, Agent): if self.node_timeout is not None: agent_response = await asyncio.wait_for( - node.executor.invoke_async(node_input), + node.executor.invoke_async(node_input, **invocation_state), timeout=self.node_timeout, ) else: - agent_response = await node.executor.invoke_async(node_input) + agent_response = await node.executor.invoke_async(node_input, **invocation_state) # Extract metrics from agent response usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index d730d5156..1c2302c28 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -237,18 +237,42 @@ def __init__( self._setup_swarm(nodes) self._inject_swarm_tools() - def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult: - """Invoke the swarm synchronously.""" + def __call__( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> SwarmResult: + """Invoke the swarm synchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Keyword arguments allowing backward compatible future changes. + """ + if invocation_state is None: + invocation_state = {} def execute() -> SwarmResult: - return asyncio.run(self.invoke_async(task)) + return asyncio.run(self.invoke_async(task, invocation_state)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) return future.result() - async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult: - """Invoke the swarm asynchronously.""" + async def invoke_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> SwarmResult: + """Invoke the swarm asynchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues - a new empty dict + is created if None is provided. + **kwargs: Keyword arguments allowing backward compatible future changes. + """ + if invocation_state is None: + invocation_state = {} + logger.debug("starting swarm execution") # Initialize swarm state with configuration @@ -272,7 +296,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> S self.execution_timeout, ) - await self._execute_swarm() + await self._execute_swarm(invocation_state) except Exception: logger.exception("swarm execution failed") self.state.completion_status = Status.FAILED @@ -483,7 +507,7 @@ def _build_node_input(self, target_node: SwarmNode) -> str: return context_text - async def _execute_swarm(self) -> None: + async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: """Shared execution logic used by execute_async.""" try: # Main execution loop @@ -522,7 +546,7 @@ async def _execute_swarm(self) -> None: # TODO: Implement cancellation token to stop _execute_node from continuing try: await asyncio.wait_for( - self._execute_node(current_node, self.state.task), + self._execute_node(current_node, self.state.task, invocation_state), timeout=self.node_timeout, ) @@ -563,7 +587,9 @@ async def _execute_swarm(self) -> None: f"{elapsed_time:.2f}", ) - async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) -> AgentResult: + async def _execute_node( + self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any] + ) -> AgentResult: """Execute swarm node.""" start_time = time.time() node_name = node.node_id @@ -583,7 +609,8 @@ async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) - # Execute node result = None node.reset_executor_state() - result = await node.executor.invoke_async(node_input) + # Unpacking since this is the agent class. Other executors should not unpack + result = await node.executor.invoke_async(node_input, **invocation_state) execution_time = round((time.time() - start_time) * 1000) diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 395d9275c..d21aa6e14 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -155,7 +155,7 @@ def __init__(self): self.received_task = None self.received_kwargs = None - async def invoke_async(self, task, **kwargs): + async def invoke_async(self, task, invocation_state, **kwargs): self.invoke_async_called = True self.received_task = task self.received_kwargs = kwargs diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 1a598847d..8097d944e 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1285,3 +1285,55 @@ def multi_loop_condition(state: GraphState) -> bool: assert result.status == Status.COMPLETED assert len(result.execution_order) >= 2 assert multi_agent.invoke_async.call_count >= 2 + + +@pytest.mark.asyncio +async def test_graph_kwargs_passing_agent(mock_strands_tracer, mock_use_span): + """Test that kwargs are passed through to underlying Agent nodes.""" + kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) + + builder = GraphBuilder() + builder.add_node(kwargs_agent, "kwargs_node") + graph = builder.build() + + test_invocation_state = {"custom_param": "test_value", "another_param": 42} + result = await graph.invoke_async("Test kwargs passing", test_invocation_state) + + kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing"}], **test_invocation_state) + assert result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_graph_kwargs_passing_multiagent(mock_strands_tracer, mock_use_span): + """Test that kwargs are passed through to underlying MultiAgentBase nodes.""" + kwargs_multiagent = create_mock_multi_agent("kwargs_multiagent", "MultiAgent response with kwargs") + kwargs_multiagent.invoke_async = Mock(side_effect=kwargs_multiagent.invoke_async) + + builder = GraphBuilder() + builder.add_node(kwargs_multiagent, "multiagent_node") + graph = builder.build() + + test_invocation_state = {"custom_param": "test_value", "another_param": 42} + result = await graph.invoke_async("Test kwargs passing to multiagent", test_invocation_state) + + kwargs_multiagent.invoke_async.assert_called_once_with( + [{"text": "Test kwargs passing to multiagent"}], test_invocation_state + ) + assert result.status == Status.COMPLETED + + +def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span): + """Test that kwargs are passed through to underlying nodes in sync execution.""" + kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) + + builder = GraphBuilder() + builder.add_node(kwargs_agent, "kwargs_node") + graph = builder.build() + + test_invocation_state = {"custom_param": "test_value", "another_param": 42} + result = graph("Test kwargs passing sync", test_invocation_state) + + kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing sync"}], **test_invocation_state) + assert result.status == Status.COMPLETED diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 74f89241f..be463c7fd 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -469,3 +469,32 @@ def test_swarm_validate_unsupported_features(): with pytest.raises(ValueError, match="Session persistence is not supported for Swarm agents yet"): Swarm([agent_with_session]) + + +@pytest.mark.asyncio +async def test_swarm_kwargs_passing(mock_strands_tracer, mock_use_span): + """Test that kwargs are passed through to underlying agents.""" + kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) + + swarm = Swarm(nodes=[kwargs_agent]) + + test_kwargs = {"custom_param": "test_value", "another_param": 42} + result = await swarm.invoke_async("Test kwargs passing", test_kwargs) + + assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs + assert result.status == Status.COMPLETED + + +def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span): + """Test that kwargs are passed through to underlying agents in sync execution.""" + kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) + + swarm = Swarm(nodes=[kwargs_agent]) + + test_kwargs = {"custom_param": "test_value", "another_param": 42} + result = swarm("Test kwargs passing sync", test_kwargs) + + assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs + assert result.status == Status.COMPLETED