diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 111509e3a..2022142c6 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -403,8 +403,8 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T: """This method allows you to get structured output from the agent. - If you pass in a prompt, it will be added to the conversation history and the agent will respond to it. - If you don't pass in a prompt, it will use only the conversation history to respond. + If you pass in a prompt, it will be used temporarily without adding it to the conversation history. + If you don't pass in a prompt, it will use only the existing conversation history to respond. For smaller models, you may want to use the optional prompt to add additional instructions to explicitly instruct the model to output the structured data. @@ -412,7 +412,7 @@ def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, l Args: output_model: The output model (a JSON schema written as a Pydantic BaseModel) that the agent will use when responding. - prompt: The prompt to use for the agent. + prompt: The prompt to use for the agent (will not be added to conversation history). Raises: ValueError: If no conversation history or prompt is provided. @@ -430,8 +430,8 @@ async def structured_output_async( ) -> T: """This method allows you to get structured output from the agent. - If you pass in a prompt, it will be added to the conversation history and the agent will respond to it. - If you don't pass in a prompt, it will use only the conversation history to respond. + If you pass in a prompt, it will be used temporarily without adding it to the conversation history. + If you don't pass in a prompt, it will use only the existing conversation history to respond. For smaller models, you may want to use the optional prompt to add additional instructions to explicitly instruct the model to output the structured data. @@ -439,7 +439,7 @@ async def structured_output_async( Args: output_model: The output model (a JSON schema written as a Pydantic BaseModel) that the agent will use when responding. - prompt: The prompt to use for the agent. + prompt: The prompt to use for the agent (will not be added to conversation history). Raises: ValueError: If no conversation history or prompt is provided. @@ -450,12 +450,14 @@ async def structured_output_async( if not self.messages and not prompt: raise ValueError("No conversation history or prompt provided") - # add the prompt as the last message + # Create temporary messages array if prompt is provided if prompt: content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt - self._append_message({"role": "user", "content": content}) + temp_messages = self.messages + [{"role": "user", "content": content}] + else: + temp_messages = self.messages - events = self.model.structured_output(output_model, self.messages, system_prompt=self.system_prompt) + events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt) async for event in events: if "callback" in event: self.callback_handler(**cast(dict, event["callback"])) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 4e310dace..c27243dfe 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -984,10 +984,17 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator): prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + # Store initial message count + initial_message_count = len(agent.messages) + tru_result = agent.structured_output(type(user), prompt) exp_result = user assert tru_result == exp_result + # Verify conversation history is not polluted + assert len(agent.messages) == initial_message_count + + # Verify the model was called with temporary messages array agent.model.structured_output.assert_called_once_with( type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt ) @@ -1008,10 +1015,17 @@ def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, a }, ] + # Store initial message count + initial_message_count = len(agent.messages) + tru_result = agent.structured_output(type(user), prompt) exp_result = user assert tru_result == exp_result + # Verify conversation history is not polluted + assert len(agent.messages) == initial_message_count + + # Verify the model was called with temporary messages array agent.model.structured_output.assert_called_once_with( type(user), [{"role": "user", "content": prompt}], system_prompt=system_prompt ) @@ -1023,10 +1037,41 @@ async def test_agent_structured_output_in_async_context(agent, user, agenerator) prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + # Store initial message count + initial_message_count = len(agent.messages) + tru_result = await agent.structured_output_async(type(user), prompt) exp_result = user assert tru_result == exp_result + # Verify conversation history is not polluted + assert len(agent.messages) == initial_message_count + + +def test_agent_structured_output_without_prompt(agent, system_prompt, user, agenerator): + """Test that structured_output works with existing conversation history and no new prompt.""" + agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) + + # Add some existing messages to the agent + existing_messages = [ + {"role": "user", "content": [{"text": "Jane Doe is 30 years old"}]}, + {"role": "assistant", "content": [{"text": "I understand."}]}, + ] + agent.messages.extend(existing_messages) + + initial_message_count = len(agent.messages) + + tru_result = agent.structured_output(type(user)) # No prompt provided + exp_result = user + assert tru_result == exp_result + + # Verify conversation history is unchanged + assert len(agent.messages) == initial_message_count + assert agent.messages == existing_messages + + # Verify the model was called with existing messages only + agent.model.structured_output.assert_called_once_with(type(user), existing_messages, system_prompt=system_prompt) + @pytest.mark.asyncio async def test_agent_structured_output_async(agent, system_prompt, user, agenerator): @@ -1034,10 +1079,17 @@ async def test_agent_structured_output_async(agent, system_prompt, user, agenera prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + # Store initial message count + initial_message_count = len(agent.messages) + tru_result = agent.structured_output(type(user), prompt) exp_result = user assert tru_result == exp_result + # Verify conversation history is not polluted + assert len(agent.messages) == initial_message_count + + # Verify the model was called with temporary messages array agent.model.structured_output.assert_called_once_with( type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt ) diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index cd89fbc7a..9ab008ca2 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -267,13 +267,12 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): length, events = hook_provider.get_events() - assert length == 3 + assert length == 2 assert next(events) == BeforeInvocationEvent(agent=agent) - assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0]) assert next(events) == AfterInvocationEvent(agent=agent) - assert len(agent.messages) == 1 + assert len(agent.messages) == 0 # no new messages added @pytest.mark.asyncio @@ -285,10 +284,9 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a length, events = hook_provider.get_events() - assert length == 3 + assert length == 2 assert next(events) == BeforeInvocationEvent(agent=agent) - assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0]) assert next(events) == AfterInvocationEvent(agent=agent) - assert len(agent.messages) == 1 + assert len(agent.messages) == 0 # no new messages added