diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 54e7a58ec..9c31ec4d5 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -543,6 +543,19 @@ async def _run_loop( # Execute the event loop cycle with retry logic for context limits events = self._execute_event_loop_cycle(invocation_state) async for event in events: + # Signal from the model provider that the message sent by the user should be redacted, + # likely due to a guardrail. + if ( + event.get("callback") + and event["callback"].get("event") + and event["callback"]["event"].get("redactContent") + and event["callback"]["event"]["redactContent"].get("redactUserContentMessage") + ): + self.messages[-1]["content"] = [ + {"text": event["callback"]["event"]["redactContent"]["redactUserContentMessage"]} + ] + if self._session_manager: + self._session_manager.redact_latest_message(self.messages[-1], self) yield event finally: diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index fff0fd6f4..f9a2686ef 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -221,17 +221,13 @@ def handle_message_stop(event: MessageStopEvent) -> StopReason: return event["stopReason"] -def handle_redact_content(event: RedactContentEvent, messages: Messages, state: dict[str, Any]) -> None: +def handle_redact_content(event: RedactContentEvent, state: dict[str, Any]) -> None: """Handles redacting content from the input or output. Args: event: Redact Content Event. - messages: Agent messages. state: The current state of message processing. """ - if event.get("redactUserContentMessage") is not None: - messages[-1]["content"] = [{"text": event["redactUserContentMessage"]}] # type: ignore - if event.get("redactAssistantContentMessage") is not None: state["message"]["content"] = [{"text": event["redactAssistantContentMessage"]}] @@ -251,15 +247,11 @@ def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]: return usage, metrics -async def process_stream( - chunks: AsyncIterable[StreamEvent], - messages: Messages, -) -> AsyncGenerator[dict[str, Any], None]: +async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[dict[str, Any], None]: """Processes the response stream from the API, constructing the final message and extracting usage metrics. Args: chunks: The chunks of the response stream from the model. - messages: The agents messages. Returns: The reason for stopping, the constructed message, and the usage metrics. @@ -295,7 +287,7 @@ async def process_stream( elif "metadata" in chunk: usage, metrics = extract_usage_metrics(chunk["metadata"]) elif "redactContent" in chunk: - handle_redact_content(chunk["redactContent"], messages, state) + handle_redact_content(chunk["redactContent"], state) yield {"stop": (stop_reason, state["message"], usage, metrics)} @@ -323,5 +315,5 @@ async def stream_messages( chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt) - async for event in process_stream(chunks, messages): + async for event in process_stream(chunks): yield event diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index dae05394e..936f799d7 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -407,7 +407,7 @@ async def structured_output( tool_spec = convert_pydantic_to_tool_spec(output_model) response = self.stream(messages=prompt, tool_specs=[tool_spec], **kwargs) - async for event in process_stream(response, prompt): + async for event in process_stream(response): yield event stop_reason, messages, _, _ = event["stop"] diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 1463b280b..ce76a246a 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -577,7 +577,7 @@ async def structured_output( tool_spec = convert_pydantic_to_tool_spec(output_model) response = self.stream(messages=prompt, tool_specs=[tool_spec], **kwargs) - async for event in streaming.process_stream(response, prompt): + async for event in streaming.process_stream(response): yield event stop_reason, messages, _, _ = event["stop"] diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index fd31d9671..534afab34 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -1,6 +1,7 @@ """Repository session manager implementation.""" import logging +from typing import Optional from ..agent.agent import Agent from ..agent.state import AgentState @@ -50,6 +51,9 @@ def __init__( # Keep track of the initialized agent id's so that two agents in a session cannot share an id self._initialized_agent_ids: set[str] = set() + # Keep track of the latest message stored in the session in case we need to redact its content. + self._latest_message: Optional[SessionMessage] = None + def append_message(self, message: Message, agent: Agent) -> None: """Append a message to the agent's session. @@ -57,8 +61,20 @@ def append_message(self, message: Message, agent: Agent) -> None: message: Message to add to the agent in the session agent: Agent to append the message to """ - session_message = SessionMessage.from_message(message) - self.session_repository.create_message(self.session_id, agent.agent_id, session_message) + self._latest_message = SessionMessage.from_message(message) + self.session_repository.create_message(self.session_id, agent.agent_id, self._latest_message) + + def redact_latest_message(self, redact_message: Message, agent: Agent) -> None: + """Redact the latest message appended to the session. + + Args: + redact_message: New message to use that contains the redact content + agent: Agent to apply the message redaction to + """ + if self._latest_message is None: + raise SessionException("No message to redact.") + self._latest_message.redact_message = redact_message + return self.session_repository.update_message(self.session_id, agent.agent_id, self._latest_message) def sync_agent(self, agent: Agent) -> None: """Serialize and update the agent into the session repository. diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index 6f071f929..3e1d986de 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -26,6 +26,15 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent)) registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent)) + @abstractmethod + def redact_latest_message(self, redact_message: Message, agent: "Agent") -> None: + """Redact the message most recently appended to the agent in the session. + + Args: + redact_message: New message to use that contains the redact content + agent: Agent to apply the message redaction to + """ + @abstractmethod def append_message(self, message: Message, agent: "Agent") -> None: """Append a message to the agent's session. diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 50d82b368..dc0761f48 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum -from typing import Any, Dict, cast +from typing import Any, Dict, Optional, cast from uuid import uuid4 from ..agent.agent import Agent @@ -54,9 +54,18 @@ def decode_bytes_values(obj: Any) -> Any: @dataclass class SessionMessage: - """Message within a SessionAgent.""" + """Message within a SessionAgent. + + Attributes: + message: Message content + redact_message: If the original message is redacted, this is the new content to use + message_id: Unique id for a message + created_at: ISO format timestamp for when this message was created + updated_at: ISO format timestamp for when this message was last updated + """ message: Message + redact_message: Optional[Message] = None message_id: str = field(default_factory=lambda: str(uuid4())) created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) @@ -73,8 +82,14 @@ def from_message(cls, message: Message) -> "SessionMessage": ) def to_message(self) -> Message: - """Convert SessionMessage back to a Message, decoding any bytes values.""" - return cast(Message, decode_bytes_values(self.message)) + """Convert SessionMessage back to a Message, decoding any bytes values. + + If the message was redacted, return the redact content instead. + """ + if self.redact_message is not None: + return cast(Message, decode_bytes_values(self.redact_message)) + else: + return cast(Message, decode_bytes_values(self.message)) @classmethod def from_dict(cls, env: dict[str, Any]) -> "SessionMessage": diff --git a/tests/fixtures/mocked_model_provider.py b/tests/fixtures/mocked_model_provider.py index b951d3abe..e4cb5fe93 100644 --- a/tests/fixtures/mocked_model_provider.py +++ b/tests/fixtures/mocked_model_provider.py @@ -1,5 +1,5 @@ import json -from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypeVar +from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypedDict, TypeVar, Union from pydantic import BaseModel @@ -12,6 +12,11 @@ T = TypeVar("T", bound=BaseModel) +class RedactionMessage(TypedDict): + redactedUserContent: str + redactedAssistantContent: str + + class MockedModelProvider(Model): """A mock implementation of the Model interface for testing purposes. @@ -20,7 +25,7 @@ class MockedModelProvider(Model): to stream mock responses as events. """ - def __init__(self, agent_responses: Messages): + def __init__(self, agent_responses: list[Union[Message, RedactionMessage]]): self.agent_responses = agent_responses self.index = 0 @@ -54,27 +59,36 @@ async def stream( self.index += 1 - def map_agent_message_to_events(self, agent_message: Message) -> Iterable[dict[str, Any]]: + def map_agent_message_to_events(self, agent_message: Union[Message, RedactionMessage]) -> Iterable[dict[str, Any]]: stop_reason: StopReason = "end_turn" yield {"messageStart": {"role": "assistant"}} - for content in agent_message["content"]: - if "text" in content: - yield {"contentBlockStart": {"start": {}}} - yield {"contentBlockDelta": {"delta": {"text": content["text"]}}} - yield {"contentBlockStop": {}} - if "toolUse" in content: - stop_reason = "tool_use" - yield { - "contentBlockStart": { - "start": { - "toolUse": { - "name": content["toolUse"]["name"], - "toolUseId": content["toolUse"]["toolUseId"], + if agent_message.get("redactedAssistantContent"): + yield {"redactContent": {"redactUserContentMessage": agent_message["redactedUserContent"]}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": agent_message["redactedAssistantContent"]}}} + yield {"contentBlockStop": {}} + stop_reason = "guardrail_intervened" + else: + for content in agent_message["content"]: + if "text" in content: + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": content["text"]}}} + yield {"contentBlockStop": {}} + if "toolUse" in content: + stop_reason = "tool_use" + yield { + "contentBlockStart": { + "start": { + "toolUse": { + "name": content["toolUse"]["name"], + "toolUseId": content["toolUse"]["toolUseId"], + } } } } - } - yield {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(content["toolUse"]["input"])}}}} - yield {"contentBlockStop": {}} + yield { + "contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(content["toolUse"]["input"])}}} + } + yield {"contentBlockStop": {}} yield {"messageStop": {"stopReason": stop_reason}} diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index c5453c5fe..c8d60a345 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -4,6 +4,7 @@ import os import textwrap import unittest.mock +from uuid import uuid4 import pytest from pydantic import BaseModel @@ -1425,3 +1426,61 @@ def test_agent_restored_from_session_management(): agent = Agent(session_manager=session_manager) assert agent.state.get("foo") == "bar" + + +def test_agent_redacts_input_on_triggered_guardrail(): + mocked_model = MockedModelProvider( + [{"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"}] + ) + + agent = Agent( + model=mocked_model, + system_prompt="You are a helpful assistant.", + callback_handler=None, + ) + + response1 = agent("CACTUS") + + assert response1.stop_reason == "guardrail_intervened" + assert agent.messages[0]["content"][0]["text"] == "BLOCKED!" + + +def test_agent_restored_from_session_management_with_redacted_input(): + mocked_model = MockedModelProvider( + [{"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"}] + ) + + test_session_id = str(uuid4()) + mocked_session_repository = MockedSessionRepository() + session_manager = RepositorySessionManager(session_id=test_session_id, session_repository=mocked_session_repository) + + agent = Agent( + model=mocked_model, + system_prompt="You are a helpful assistant.", + callback_handler=None, + session_manager=session_manager, + ) + + assert mocked_session_repository.read_agent(test_session_id, agent.agent_id) is not None + + response1 = agent("CACTUS") + + assert response1.stop_reason == "guardrail_intervened" + assert agent.messages[0]["content"][0]["text"] == "BLOCKED!" + user_input_session_message = mocked_session_repository.list_messages(test_session_id, agent.agent_id)[0] + # Assert persisted message is equal to the redacted message in the agent + assert user_input_session_message.to_message() == agent.messages[0] + + # Restore an agent from the session, confirm input is still redacted + session_manager_2 = RepositorySessionManager( + session_id=test_session_id, session_repository=mocked_session_repository + ) + agent_2 = Agent( + model=mocked_model, + system_prompt="You are a helpful assistant.", + callback_handler=None, + session_manager=session_manager_2, + ) + + # Assert that the restored agent redacted message is equal to the original agent + assert agent.messages[0] == agent_2.messages[0] diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 80d6a5ef6..921fd91de 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -528,8 +528,7 @@ def test_extract_usage_metrics(): ) @pytest.mark.asyncio async def test_process_stream(response, exp_events, agenerator, alist): - messages = [{"role": "user", "content": [{"text": "Some input!"}]}] - stream = strands.event_loop.streaming.process_stream(agenerator(response), messages) + stream = strands.event_loop.streaming.process_stream(agenerator(response)) tru_events = await alist(stream) assert tru_events == exp_events diff --git a/tests_integ/test_bedrock_guardrails.py b/tests_integ/test_bedrock_guardrails.py index bf0be7068..4683918cb 100644 --- a/tests_integ/test_bedrock_guardrails.py +++ b/tests_integ/test_bedrock_guardrails.py @@ -1,15 +1,25 @@ +import tempfile import time +from uuid import uuid4 import boto3 import pytest from strands import Agent from strands.models.bedrock import BedrockModel +from strands.session.file_session_manager import FileSessionManager BLOCKED_INPUT = "BLOCKED_INPUT" BLOCKED_OUTPUT = "BLOCKED_OUTPUT" +@pytest.fixture +def temp_dir(): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield temp_dir + + @pytest.fixture(scope="module") def boto_session(): return boto3.Session(region_name="us-east-1") @@ -158,3 +168,44 @@ def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processi assert REDACT_MESSAGE in str(response1) assert response2.stop_reason != "guardrail_intervened" assert REDACT_MESSAGE not in str(response2) + + +def test_guardrail_input_intervention_properly_redacts_in_session(boto_session, bedrock_guardrail, temp_dir): + bedrock_model = BedrockModel( + guardrail_id=bedrock_guardrail, + guardrail_version="DRAFT", + boto_session=boto_session, + guardrail_redact_input_message="BLOCKED!", + ) + + test_session_id = str(uuid4()) + session_manager = FileSessionManager(session_id=test_session_id) + + agent = Agent( + model=bedrock_model, + system_prompt="You are a helpful assistant.", + callback_handler=None, + session_manager=session_manager, + ) + + assert session_manager.read_agent(test_session_id, agent.agent_id) is not None + + response1 = agent("CACTUS") + + assert response1.stop_reason == "guardrail_intervened" + assert agent.messages[0]["content"][0]["text"] == "BLOCKED!" + user_input_session_message = session_manager.list_messages(test_session_id, agent.agent_id)[0] + # Assert persisted message is equal to the redacted message in the agent + assert user_input_session_message.to_message() == agent.messages[0] + + # Restore an agent from the session, confirm input is still redacted + session_manager_2 = FileSessionManager(session_id=test_session_id) + agent_2 = Agent( + model=bedrock_model, + system_prompt="You are a helpful assistant.", + callback_handler=None, + session_manager=session_manager_2, + ) + + # Assert that the restored agent redacted message is equal to the original agent + assert agent.messages[0] == agent_2.messages[0] diff --git a/tests_integ/test_session.py b/tests_integ/test_session.py index fbfd54384..6efbc2c8f 100644 --- a/tests_integ/test_session.py +++ b/tests_integ/test_session.py @@ -11,12 +11,7 @@ from strands.session.file_session_manager import FileSessionManager from strands.session.s3_session_manager import S3SessionManager - -@pytest.fixture -def yellow_img(pytestconfig): - path = pytestconfig.rootdir / "tests_integ/yellow.png" - with open(path, "rb") as fp: - return fp.read() +# yellow_img imported from conftest @pytest.fixture