diff --git a/pyproject.toml b/pyproject.toml
index a0be0ddc6..f55e2fa47 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -60,7 +60,7 @@ dev = [
     "pytest-cov>=6.0.0,<7.0.0",
     "pytest-asyncio>=1.0.0,<1.2.0",
     "pytest-xdist>=3.0.0,<4.0.0",
-    "ruff>=0.12.0,<0.13.0",
+    "ruff>=0.4.4,<0.5.0",
 ]
 docs = [
     "sphinx>=5.0.0,<6.0.0",
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 1e64f5adb..5c647b4ed 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -441,7 +441,7 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR
     def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T:
         """This method allows you to get structured output from the agent.
 
-        If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
+        If you pass in a prompt, it will be added to the conversation history along with the structured output result.
         If you don't pass in a prompt, it will use only the existing conversation history to respond.
 
         For smaller models, you may want to use the optional prompt to add additional instructions to explicitly
@@ -470,7 +470,7 @@ def execute() -> T:
     async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T:
         """This method allows you to get structured output from the agent.
 
-        If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
+        If you pass in a prompt, it will be added to the conversation history along with the structured output result.
         If you don't pass in a prompt, it will use only the existing conversation history to respond.
 
         For smaller models, you may want to use the optional prompt to add additional instructions to explicitly
@@ -479,7 +479,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
         Args:
             output_model: The output model (a JSON schema written as a Pydantic BaseModel) that the agent will use
                 when responding.
-            prompt: The prompt to use for the agent (will not be added to conversation history).
+            prompt: The prompt to use for the agent (will be added to conversation history).
 
         Raises:
             ValueError: If no conversation history or prompt is provided.
@@ -492,7 +492,13 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
                 if not self.messages and not prompt:
                     raise ValueError("No conversation history or prompt provided")
 
-                temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt)
+                # Add prompt to conversation history if provided
+                if prompt:
+                    prompt_messages = self._convert_prompt_to_messages(prompt)
+                    for message in prompt_messages:
+                        self._append_message(message)
+
+                temp_messages: Messages = self.messages
 
                 structured_output_span.set_attributes(
                     {
@@ -502,16 +508,16 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
                         "gen_ai.operation.name": "execute_structured_output",
                     }
                 )
-                if self.system_prompt:
-                    structured_output_span.add_event(
-                        "gen_ai.system.message",
-                        attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
-                    )
                 for message in temp_messages:
                     structured_output_span.add_event(
                         f"gen_ai.{message['role']}.message",
                         attributes={"role": message["role"], "content": serialize(message["content"])},
                     )
+                if self.system_prompt:
+                    structured_output_span.add_event(
+                        "gen_ai.system.message",
+                        attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
+                    )
                 events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
                 async for event in events:
                     if "callback" in event:
@@ -519,7 +525,16 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
                 structured_output_span.add_event(
                     "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
                 )
-                return event["output"]
+
+                # Add structured output result to conversation history
+                result = event["output"]
+                assistant_message = {
+                    "role": "assistant",
+                    "content": [{"text": f"Structured output ({output_model.__name__}): {result.model_dump_json()}"}]
+                }
+                self._append_message(assistant_message)
+
+                return result
             finally:
                 self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
 
diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py
index 081193b10..5fbee7133 100644
--- a/src/strands/multiagent/graph.py
+++ b/src/strands/multiagent/graph.py
@@ -389,7 +389,7 @@ def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult
         """Invoke the graph synchronously."""
 
         def execute() -> GraphResult:
-            return asyncio.run(self.invoke_async(task))
+            return asyncio.run(self.invoke_async(task, **kwargs))
 
         with ThreadPoolExecutor() as executor:
             future = executor.submit(execute)
@@ -420,7 +420,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> G
             self.node_timeout or "None",
         )
 
-        await self._execute_graph()
+        await self._execute_graph(kwargs)
 
         # Set final status based on execution results
         if self.state.failed_nodes:
@@ -450,7 +450,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
             # Validate Agent-specific constraints for each node
             _validate_node_executor(node.executor)
 
-    async def _execute_graph(self) -> None:
+    async def _execute_graph(self, invocation_state: dict[str, Any]) -> None:
         """Unified execution flow with conditional routing."""
         ready_nodes = list(self.entry_points)
 
@@ -470,7 +470,7 @@ async def _execute_graph(self) -> None:
 
             # Execute current batch of ready nodes concurrently
             tasks = [
-                asyncio.create_task(self._execute_node(node))
+                asyncio.create_task(self._execute_node(node, invocation_state))
                 for node in current_batch
                 if node not in self.state.completed_nodes
             ]
@@ -515,7 +515,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode) -> bool:
             )
             return False
 
-    async def _execute_node(self, node: GraphNode) -> None:
+    async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> None:
         """Execute a single node with error handling and timeout protection."""
         # Reset the node's state if reset_on_revisit is enabled and it's being revisited
         if self.reset_on_revisit and node in self.state.completed_nodes:
@@ -538,11 +538,11 @@ async def _execute_node(self, node: GraphNode) -> None:
             if isinstance(node.executor, MultiAgentBase):
                 if self.node_timeout is not None:
                     multi_agent_result = await asyncio.wait_for(
-                        node.executor.invoke_async(node_input),
+                        node.executor.invoke_async(node_input, **invocation_state),
                         timeout=self.node_timeout,
                     )
                 else:
-                    multi_agent_result = await node.executor.invoke_async(node_input)
+                    multi_agent_result = await node.executor.invoke_async(node_input, **invocation_state)
 
                 # Create NodeResult with MultiAgentResult directly
                 node_result = NodeResult(
@@ -557,11 +557,11 @@ async def _execute_node(self, node: GraphNode) -> None:
             elif isinstance(node.executor, Agent):
                 if self.node_timeout is not None:
                     agent_response = await asyncio.wait_for(
-                        node.executor.invoke_async(node_input),
+                        node.executor.invoke_async(node_input, **invocation_state),
                         timeout=self.node_timeout,
                     )
                 else:
-                    agent_response = await node.executor.invoke_async(node_input)
+                    agent_response = await node.executor.invoke_async(node_input, **invocation_state)
 
                 # Extract metrics from agent response
                 usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py
index d730d5156..6e518215f 100644
--- a/src/strands/multiagent/swarm.py
+++ b/src/strands/multiagent/swarm.py
@@ -241,7 +241,7 @@ def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult
         """Invoke the swarm synchronously."""
 
         def execute() -> SwarmResult:
-            return asyncio.run(self.invoke_async(task))
+            return asyncio.run(self.invoke_async(task, **kwargs))
 
         with ThreadPoolExecutor() as executor:
             future = executor.submit(execute)
@@ -272,7 +272,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> S
                 self.execution_timeout,
             )
 
-            await self._execute_swarm()
+            await self._execute_swarm(kwargs)
         except Exception:
             logger.exception("swarm execution failed")
             self.state.completion_status = Status.FAILED
@@ -483,7 +483,7 @@ def _build_node_input(self, target_node: SwarmNode) -> str:
 
         return context_text
 
-    async def _execute_swarm(self) -> None:
+    async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:
         """Shared execution logic used by execute_async."""
         try:
             # Main execution loop
@@ -522,7 +522,7 @@ async def _execute_swarm(self) -> None:
                 # TODO: Implement cancellation token to stop _execute_node from continuing
                 try:
                     await asyncio.wait_for(
-                        self._execute_node(current_node, self.state.task),
+                        self._execute_node(current_node, self.state.task, invocation_state),
                         timeout=self.node_timeout,
                     )
 
@@ -563,7 +563,9 @@ async def _execute_swarm(self) -> None:
                 f"{elapsed_time:.2f}",
             )
 
-    async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) -> AgentResult:
+    async def _execute_node(
+        self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any]
+    ) -> AgentResult:
         """Execute swarm node."""
         start_time = time.time()
         node_name = node.node_id
@@ -583,7 +585,7 @@ async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) -
         # Execute node
         result = None
         node.reset_executor_state()
-        result = await node.executor.invoke_async(node_input)
+        result = await node.executor.invoke_async(node_input, **invocation_state)
 
         execution_time = round((time.time() - start_time) * 1000)
 
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index a8561abe4..30245dc23 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -18,7 +18,6 @@
 from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
 from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel
 from strands.session.repository_session_manager import RepositorySessionManager
-from strands.telemetry.tracer import serialize
 from strands.types._events import EventLoopStopEvent, ModelStreamEvent
 from strands.types.content import Messages
 from strands.types.exceptions import ContextWindowOverflowException, EventLoopException
@@ -986,6 +985,9 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator):
 
     agent.tracer = mock_strands_tracer
     agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))
+    agent.hooks = unittest.mock.MagicMock()
+    agent.hooks.invoke_callbacks = unittest.mock.Mock()
+    agent.callback_handler = unittest.mock.Mock()
 
     prompt = "Jane Doe is 30 years old and her email is jane@doe.com"
 
@@ -996,12 +998,31 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator):
 
     exp_result = user
     assert tru_result == exp_result
 
-    # Verify conversation history is not polluted
-    assert len(agent.messages) == initial_message_count
+    # Verify conversation history is updated with prompt and structured output
+    assert len(agent.messages) == initial_message_count + 2
+
+    # Verify the prompt was added to conversation history
+    user_message_added = any(
+        msg['role'] == 'user' and prompt in msg['content'][0]['text']
+        for msg in agent.messages
+    )
+    assert user_message_added, "User prompt should be added to conversation history"
+
+    # Verify the structured output was added to conversation history
+    assistant_message_added = any(
+        msg['role'] == 'assistant' and 'Structured output (User):' in msg['content'][0]['text']
+        for msg in agent.messages
+    )
+    assert assistant_message_added, "Structured output should be added to conversation history"
 
-    # Verify the model was called with temporary messages array
+    # Verify the model was called with all messages (including the added prompt)
     agent.model.structured_output.assert_called_once_with(
-        type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt
+        type(user),
+        [
+            {"role": "user", "content": [{"text": prompt}]},
+            {"role": "assistant", "content": [{"text": f"Structured output (User): {user.model_dump_json()}"}]}
+        ],
+        system_prompt=system_prompt
     )
 
     mock_span.set_attributes.assert_called_once_with(
@@ -1013,23 +1034,15 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator):
         }
     )
 
-    # ensure correct otel event messages are emitted
-    act_event_names = mock_span.add_event.call_args_list
-    exp_event_names = [
-        unittest.mock.call(
-            "gen_ai.system.message", attributes={"role": "system", "content": serialize([{"text": system_prompt}])}
-        ),
-        unittest.mock.call(
-            "gen_ai.user.message",
-            attributes={
-                "role": "user",
-                "content": '[{"text": "Jane Doe is 30 years old and her email is jane@doe.com"}]',
-            },
-        ),
-        unittest.mock.call("gen_ai.choice", attributes={"message": json.dumps(user.model_dump())}),
-    ]
+    mock_span.add_event.assert_any_call(
+        "gen_ai.user.message",
+        attributes={"role": "user", "content": '[{"text": "Jane Doe is 30 years old and her email is jane@doe.com"}]'},
+    )
 
-    assert act_event_names == exp_event_names
+    mock_span.add_event.assert_called_with(
+        "gen_ai.choice",
+        attributes={"message": json.dumps(user.model_dump())},
+    )
 
 
 def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, agenerator):
@@ -1061,12 +1074,31 @@ def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, a
 
     exp_result = user
     assert tru_result == exp_result
 
-    # Verify conversation history is not polluted
-    assert len(agent.messages) == initial_message_count
+    # Verify conversation history is updated with prompt and structured output
+    assert len(agent.messages) == initial_message_count + 2
+
+    # Verify the multi-modal prompt was added to conversation history
+    user_message_added = any(
+        msg['role'] == 'user' and 'Please describe the user in this image' in msg['content'][0]['text']
+        for msg in agent.messages
+    )
+    assert user_message_added, "Multi-modal user prompt should be added to conversation history"
+
+    # Verify the structured output was added to conversation history
+    assistant_message_added = any(
+        msg['role'] == 'assistant' and 'Structured output (User):' in msg['content'][0]['text']
+        for msg in agent.messages
+    )
+    assert assistant_message_added, "Structured output should be added to conversation history"
 
-    # Verify the model was called with temporary messages array
+    # Verify the model was called with all messages (including the added prompt)
     agent.model.structured_output.assert_called_once_with(
-        type(user), [{"role": "user", "content": prompt}], system_prompt=system_prompt
+        type(user),
+        [
+            {"role": "user", "content": prompt},
+            {"role": "assistant", "content": [{"text": f"Structured output (User): {user.model_dump_json()}"}]}
+        ],
+        system_prompt=system_prompt
     )
 
     mock_span.add_event.assert_called_with(
@@ -1078,6 +1110,9 @@ def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, a
 
 @pytest.mark.asyncio
 async def test_agent_structured_output_in_async_context(agent, user, agenerator):
     agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))
+    agent.hooks = unittest.mock.MagicMock()
+    agent.hooks.invoke_callbacks = unittest.mock.Mock()
+    agent.callback_handler = unittest.mock.Mock()
 
     prompt = "Jane Doe is 30 years old and her email is jane@doe.com"
 
@@ -1088,13 +1123,30 @@ async def test_agent_structured_output_in_async_context(agent, user, agenerator)
 
     exp_result = user
     assert tru_result == exp_result
 
-    # Verify conversation history is not polluted
-    assert len(agent.messages) == initial_message_count
+    # Verify conversation history is updated with prompt and structured output
+    assert len(agent.messages) == initial_message_count + 2
+
+    # Verify the prompt was added to conversation history
+    user_message_added = any(
+        msg['role'] == 'user' and prompt in msg['content'][0]['text']
+        for msg in agent.messages
+    )
+    assert user_message_added, "User prompt should be added to conversation history"
+
+    # Verify the structured output was added to conversation history
+    assistant_message_added = any(
+        msg['role'] == 'assistant' and 'Structured output (User):' in msg['content'][0]['text']
+        for msg in agent.messages
+    )
+    assert assistant_message_added, "Structured output should be added to conversation history"
 
 
 def test_agent_structured_output_without_prompt(agent, system_prompt, user, agenerator):
     """Test that structured_output works with existing conversation history and no new prompt."""
     agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))
+    agent.hooks = unittest.mock.MagicMock()
+    agent.hooks.invoke_callbacks = unittest.mock.Mock()
+    agent.callback_handler = unittest.mock.Mock()
 
     # Add some existing messages to the agent
     existing_messages = [
@@ -1109,17 +1161,27 @@ def test_agent_structured_output_without_prompt(agent, system_prompt, user, agen
 
     exp_result = user
     assert tru_result == exp_result
 
-    # Verify conversation history is unchanged
-    assert len(agent.messages) == initial_message_count
-    assert agent.messages == existing_messages
+    # Verify conversation history is updated with structured output only (no prompt added)
+    assert len(agent.messages) == initial_message_count + 1
+
+    # Verify the structured output was added to conversation history
+    assistant_message_added = any(
+        msg['role'] == 'assistant' and 'Structured output (User):' in msg['content'][0]['text']
+        for msg in agent.messages
+    )
+    assert assistant_message_added, "Structured output should be added to conversation history"
 
-    # Verify the model was called with existing messages only
-    agent.model.structured_output.assert_called_once_with(type(user), existing_messages, system_prompt=system_prompt)
+    # Verify the model was called with existing messages plus the added structured output
+    expected_messages = existing_messages + [{"role": "assistant", "content": [{"text": f"Structured output (User): {user.model_dump_json()}"}]}]
+    agent.model.structured_output.assert_called_once_with(type(user), expected_messages, system_prompt=system_prompt)
 
 
 @pytest.mark.asyncio
 async def test_agent_structured_output_async(agent, system_prompt, user, agenerator):
     agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))
+    agent.hooks = unittest.mock.MagicMock()
+    agent.hooks.invoke_callbacks = unittest.mock.Mock()
+    agent.callback_handler = unittest.mock.Mock()
 
     prompt = "Jane Doe is 30 years old and her email is jane@doe.com"
 
@@ -1130,12 +1192,31 @@ async def test_agent_structured_output_async(agent, system_prompt, user, agenera
 
     exp_result = user
     assert tru_result == exp_result
 
-    # Verify conversation history is not polluted
-    assert len(agent.messages) == initial_message_count
+    # Verify conversation history is updated with prompt and structured output
+    assert len(agent.messages) == initial_message_count + 2
+
+    # Verify the prompt was added to conversation history
+    user_message_added = any(
+        msg['role'] == 'user' and prompt in msg['content'][0]['text']
+        for msg in agent.messages
+    )
+    assert user_message_added, "User prompt should be added to conversation history"
+
+    # Verify the structured output was added to conversation history
+    assistant_message_added = any(
+        msg['role'] == 'assistant' and 'Structured output (User):' in msg['content'][0]['text']
+        for msg in agent.messages
+    )
+    assert assistant_message_added, "Structured output should be added to conversation history"
 
-    # Verify the model was called with temporary messages array
+    # Verify the model was called with all messages (including the added prompt)
     agent.model.structured_output.assert_called_once_with(
-        type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt
+        type(user),
+        [
+            {"role": "user", "content": [{"text": prompt}]},
+            {"role": "assistant", "content": [{"text": f"Structured output (User): {user.model_dump_json()}"}]}
+        ],
+        system_prompt=system_prompt
     )
 
diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py
index 9ab008ca2..c203078a8 100644
--- a/tests/strands/agent/test_agent_hooks.py
+++ b/tests/strands/agent/test_agent_hooks.py
@@ -267,12 +267,14 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):
 
     length, events = hook_provider.get_events()
 
-    assert length == 2
+    assert length == 4  # BeforeInvocationEvent, MessageAddedEvent (prompt), MessageAddedEvent (output), AfterInvocationEvent
 
     assert next(events) == BeforeInvocationEvent(agent=agent)
+    assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0])  # Prompt added
+    assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1])  # Output added
     assert next(events) == AfterInvocationEvent(agent=agent)
 
-    assert len(agent.messages) == 0  # no new messages added
+    assert len(agent.messages) == 2  # prompt and structured output added
 
 
 @pytest.mark.asyncio
@@ -284,9 +286,11 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a
 
     length, events = hook_provider.get_events()
 
-    assert length == 2
+    assert length == 4  # BeforeInvocationEvent, MessageAddedEvent (prompt), MessageAddedEvent (output), AfterInvocationEvent
 
     assert next(events) == BeforeInvocationEvent(agent=agent)
+    assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0])  # Prompt added
+    assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1])  # Output added
     assert next(events) == AfterInvocationEvent(agent=agent)
 
-    assert len(agent.messages) == 0  # no new messages added
+    assert len(agent.messages) == 2  # prompt and structured output added
diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py
index 9977c54cd..e00048b49 100644
--- a/tests/strands/multiagent/test_graph.py
+++ b/tests/strands/multiagent/test_graph.py
@@ -1087,3 +1087,89 @@ async def test_state_reset_only_with_cycles_enabled():
 
     # With reset_on_revisit enabled, reset should be called
     mock_reset.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_graph_kwargs_passing_agent(mock_strands_tracer, mock_use_span):
+    """Test that kwargs are passed through to underlying Agent nodes."""
+    # Create a mock agent that captures kwargs
+    kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs")
+
+    async def capture_kwargs(*args, **kwargs):
+        # Store kwargs for verification
+        capture_kwargs.captured_kwargs = kwargs
+        return kwargs_agent.return_value
+
+    kwargs_agent.invoke_async = MagicMock(side_effect=capture_kwargs)
+
+    # Create graph
+    builder = GraphBuilder()
+    builder.add_node(kwargs_agent, "kwargs_node")
+    graph = builder.build()
+
+    # Execute with custom kwargs
+    test_kwargs = {"custom_param": "test_value", "another_param": 42}
+    result = await graph.invoke_async("Test kwargs passing", **test_kwargs)
+
+    # Verify kwargs were passed to agent
+    assert hasattr(capture_kwargs, "captured_kwargs")
+    assert capture_kwargs.captured_kwargs == test_kwargs
+    assert result.status == Status.COMPLETED
+
+
+@pytest.mark.asyncio
+async def test_graph_kwargs_passing_multiagent(mock_strands_tracer, mock_use_span):
+    """Test that kwargs are passed through to underlying MultiAgentBase nodes."""
+    # Create a mock MultiAgentBase that captures kwargs
+    kwargs_multiagent = create_mock_multi_agent("kwargs_multiagent", "MultiAgent response with kwargs")
+
+    # Store the original return value
+    original_result = kwargs_multiagent.invoke_async.return_value
+
+    async def capture_kwargs(*args, **kwargs):
+        # Store kwargs for verification
+        capture_kwargs.captured_kwargs = kwargs
+        return original_result
+
+    kwargs_multiagent.invoke_async = AsyncMock(side_effect=capture_kwargs)
+
+    # Create graph
+    builder = GraphBuilder()
+    builder.add_node(kwargs_multiagent, "multiagent_node")
+    graph = builder.build()
+
+    # Execute with custom kwargs
+    test_kwargs = {"custom_param": "test_value", "another_param": 42}
+    result = await graph.invoke_async("Test kwargs passing to multiagent", **test_kwargs)
+
+    # Verify kwargs were passed to multiagent
+    assert hasattr(capture_kwargs, "captured_kwargs")
+    assert capture_kwargs.captured_kwargs == test_kwargs
+    assert result.status == Status.COMPLETED
+
+
+def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span):
+    """Test that kwargs are passed through to underlying nodes in sync execution."""
+    # Create a mock agent that captures kwargs
+    kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs")
+
+    async def capture_kwargs(*args, **kwargs):
+        # Store kwargs for verification
+        capture_kwargs.captured_kwargs = kwargs
+        return kwargs_agent.return_value
+
+    kwargs_agent.invoke_async = MagicMock(side_effect=capture_kwargs)
+
+    # Create graph
+    builder = GraphBuilder()
+    builder.add_node(kwargs_agent, "kwargs_node")
+    graph = builder.build()
+
+    # Execute with custom kwargs
+    test_kwargs = {"custom_param": "test_value", "another_param": 42}
+    result = graph("Test kwargs passing sync", **test_kwargs)
+
+    # Verify kwargs were passed to agent
+    assert hasattr(capture_kwargs, "captured_kwargs")
+    assert capture_kwargs.captured_kwargs == test_kwargs
+    assert result.status == Status.COMPLETED
diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py
index 74f89241f..d4653b1e2 100644
--- a/tests/strands/multiagent/test_swarm.py
+++ b/tests/strands/multiagent/test_swarm.py
@@ -469,3 +469,54 @@ def test_swarm_validate_unsupported_features():
 
     with pytest.raises(ValueError, match="Session persistence is not supported for Swarm agents yet"):
         Swarm([agent_with_session])
+
+
+@pytest.mark.asyncio
+async def test_swarm_kwargs_passing(mock_strands_tracer, mock_use_span):
+    """Test that kwargs are passed through to underlying agents."""
+    # Create a mock agent that captures kwargs
+    kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs")
+
+    async def capture_kwargs(*args, **kwargs):
+        # Store kwargs for verification
+        capture_kwargs.captured_kwargs = kwargs
+        return kwargs_agent.return_value
+
+    kwargs_agent.invoke_async = MagicMock(side_effect=capture_kwargs)
+
+    # Create swarm
+    swarm = Swarm(nodes=[kwargs_agent])
+
+    # Execute with custom kwargs
+    test_kwargs = {"custom_param": "test_value", "another_param": 42}
+    result = await swarm.invoke_async("Test kwargs passing", **test_kwargs)
+
+    # Verify kwargs were passed to agent
+    assert hasattr(capture_kwargs, "captured_kwargs")
+    assert capture_kwargs.captured_kwargs == test_kwargs
+    assert result.status == Status.COMPLETED
+
+
+def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span):
+    """Test that kwargs are passed through to underlying agents in sync execution."""
+    # Create a mock agent that captures kwargs
+    kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs")
+
+    async def capture_kwargs(*args, **kwargs):
+        # Store kwargs for verification
+        capture_kwargs.captured_kwargs = kwargs
+        return kwargs_agent.return_value
+
+    kwargs_agent.invoke_async = MagicMock(side_effect=capture_kwargs)
+
+    # Create swarm
+    swarm = Swarm(nodes=[kwargs_agent])
+
+    # Execute with custom kwargs
+    test_kwargs = {"custom_param": "test_value", "another_param": 42}
+    result = swarm("Test kwargs passing sync", **test_kwargs)
+
+    # Verify kwargs were passed to agent
+    assert hasattr(capture_kwargs, "captured_kwargs")
+    assert capture_kwargs.captured_kwargs == test_kwargs
+    assert result.status == Status.COMPLETED