diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index ab3c6d143..54e7a58ec 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -15,7 +15,6 @@ import random from concurrent.futures import ThreadPoolExecutor from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast -from uuid import uuid4 from opentelemetry import trace from pydantic import BaseModel @@ -32,6 +31,7 @@ ) from ..models.bedrock import BedrockModel from ..models.model import Model +from ..session.session_manager import SessionManager from ..telemetry.metrics import EventLoopMetrics from ..telemetry.tracer import get_tracer from ..tools.registry import ToolRegistry @@ -62,6 +62,7 @@ class _DefaultCallbackHandlerSentinel: _DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel() _DEFAULT_AGENT_NAME = "Strands Agents" +_DEFAULT_AGENT_ID = "default" class Agent: @@ -207,6 +208,7 @@ def __init__( description: Optional[str] = None, state: Optional[Union[AgentState, dict]] = None, hooks: Optional[list[HookProvider]] = None, + session_manager: Optional[SessionManager] = None, ): """Initialize the Agent with the specified configuration. @@ -237,22 +239,24 @@ def __init__( load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. Defaults to False. trace_attributes: Custom trace attributes to apply to the agent's trace span. - agent_id: Optional ID for the agent, useful for multi-agent scenarios. - If None, a UUID is generated. + agent_id: Optional ID for the agent, useful for session management and multi-agent scenarios. + Defaults to "default". name: name of the Agent - Defaults to None. + Defaults to "Strands Agents". description: description of what the Agent does Defaults to None. state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict. Defaults to an empty AgentState object. hooks: hooks to be added to the agent hook registry Defaults to None. + session_manager: Manager for handling agent sessions including conversation history and state. + If provided, enables session-based persistence and state management. """ self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model self.messages = messages if messages is not None else [] self.system_prompt = system_prompt - self.agent_id = agent_id or str(uuid4()) + self.agent_id = agent_id or _DEFAULT_AGENT_ID self.name = name or _DEFAULT_AGENT_NAME self.description = description @@ -312,6 +316,12 @@ def __init__( self.tool_caller = Agent.ToolCaller(self) self.hooks = HookRegistry() + + # Initialize session management functionality + self._session_manager = session_manager + if self._session_manager: + self.hooks.add_hook(self._session_manager) + if hooks: for hook in hooks: self.hooks.add_hook(hook) diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py new file mode 100644 index 000000000..2748f2e20 --- /dev/null +++ b/src/strands/session/file_session_manager.py @@ -0,0 +1,226 @@ +"""File-based session manager for local filesystem storage.""" + +import json +import logging +import os +import shutil +import tempfile +from dataclasses import asdict +from typing import Any, Optional, cast + +from ..types.exceptions import SessionException +from ..types.session import Session, SessionAgent, SessionMessage +from .repository_session_manager import RepositorySessionManager +from .session_repository import SessionRepository + +logger = logging.getLogger(__name__) + +SESSION_PREFIX = "session_" +AGENT_PREFIX = "agent_" +MESSAGE_PREFIX = "message_" + + +class FileSessionManager(RepositorySessionManager, SessionRepository): + """File-based session manager for local filesystem storage. + + Creates the following filesystem structure for the session storage: + // + └── session_/ + ├── session.json # Session metadata + └── agents/ + └── agent_/ + ├── agent.json # Agent metadata + └── messages/ + ├── message__.json + └── message__.json + + """ + + def __init__(self, session_id: str, storage_dir: Optional[str] = None): + """Initialize FileSession with filesystem storage. + + Args: + session_id: ID for the session + storage_dir: Directory for local filesystem storage (defaults to temp dir) + """ + self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions") + os.makedirs(self.storage_dir, exist_ok=True) + + super().__init__(session_id=session_id, session_repository=self) + + def _get_session_path(self, session_id: str) -> str: + """Get session directory path.""" + return os.path.join(self.storage_dir, f"{SESSION_PREFIX}{session_id}") + + def _get_agent_path(self, session_id: str, agent_id: str) -> str: + """Get agent directory path.""" + session_path = self._get_session_path(session_id) + return os.path.join(session_path, "agents", f"{AGENT_PREFIX}{agent_id}") + + def _get_message_path(self, session_id: str, agent_id: str, message_id: str, timestamp: str) -> str: + """Get message file path. + + Args: + session_id: ID of the session + agent_id: ID of the agent + message_id: ID of the message + timestamp: ISO format timestamp to include in filename for sorting + Returns: + The filename for the message + """ + agent_path = self._get_agent_path(session_id, agent_id) + # Use timestamp for sortable filenames + # Replace colons and periods in ISO format with underscores for filesystem compatibility + filename_timestamp = timestamp.replace(":", "_").replace(".", "_") + return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{filename_timestamp}_{message_id}.json") + + def _read_file(self, path: str) -> dict[str, Any]: + """Read JSON file.""" + try: + with open(path, "r", encoding="utf-8") as f: + return cast(dict[str, Any], json.load(f)) + except json.JSONDecodeError as e: + raise SessionException(f"Invalid JSON in file {path}: {str(e)}") from e + + def _write_file(self, path: str, data: dict[str, Any]) -> None: + """Write JSON file.""" + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + + def create_session(self, session: Session) -> Session: + """Create a new session.""" + session_dir = self._get_session_path(session.session_id) + if os.path.exists(session_dir): + raise SessionException(f"Session {session.session_id} already exists") + + # Create directory structure + os.makedirs(session_dir, exist_ok=True) + os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True) + + # Write session file + session_file = os.path.join(session_dir, "session.json") + session_dict = asdict(session) + self._write_file(session_file, session_dict) + + return session + + def read_session(self, session_id: str) -> Optional[Session]: + """Read session data.""" + session_file = os.path.join(self._get_session_path(session_id), "session.json") + if not os.path.exists(session_file): + return None + + session_data = self._read_file(session_file) + return Session.from_dict(session_data) + + def create_agent(self, session_id: str, session_agent: SessionAgent) -> None: + """Create a new agent in the session.""" + agent_id = session_agent.agent_id + + agent_dir = self._get_agent_path(session_id, agent_id) + os.makedirs(agent_dir, exist_ok=True) + os.makedirs(os.path.join(agent_dir, "messages"), exist_ok=True) + + agent_file = os.path.join(agent_dir, "agent.json") + session_data = asdict(session_agent) + self._write_file(agent_file, session_data) + + def delete_session(self, session_id: str) -> None: + """Delete session and all associated data.""" + session_dir = self._get_session_path(session_id) + if not os.path.exists(session_dir): + raise SessionException(f"Session {session_id} does not exist") + + shutil.rmtree(session_dir) + + def read_agent(self, session_id: str, agent_id: str) -> Optional[SessionAgent]: + """Read agent data.""" + agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json") + if not os.path.exists(agent_file): + return None + + agent_data = self._read_file(agent_file) + return SessionAgent.from_dict(agent_data) + + def update_agent(self, session_id: str, session_agent: SessionAgent) -> None: + """Update agent data.""" + agent_id = session_agent.agent_id + previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id) + if previous_agent is None: + raise SessionException(f"Agent {agent_id} in session {session_id} does not exist") + + session_agent.created_at = previous_agent.created_at + agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json") + self._write_file(agent_file, asdict(session_agent)) + + def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None: + """Create a new message for the agent.""" + message_file = self._get_message_path( + session_id, + agent_id, + session_message.message_id, + session_message.created_at, + ) + session_dict = asdict(session_message) + self._write_file(message_file, session_dict) + + def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]: + """Read message data.""" + # Get the messages directory + messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages") + if not os.path.exists(messages_dir): + return None + + # List files in messages directory, and check if the filename ends with the message id + for filename in os.listdir(messages_dir): + if filename.endswith(f"{message_id}.json"): + file_path = os.path.join(messages_dir, filename) + message_data = self._read_file(file_path) + return SessionMessage.from_dict(message_data) + + return None + + def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None: + """Update message data.""" + message_id = session_message.message_id + previous_message = self.read_message(session_id=session_id, agent_id=agent_id, message_id=message_id) + if previous_message is None: + raise SessionException(f"Message {message_id} does not exist") + + # Preserve the original created_at timestamp + session_message.created_at = previous_message.created_at + message_file = self._get_message_path(session_id, agent_id, message_id, session_message.created_at) + self._write_file(message_file, asdict(session_message)) + + def list_messages( + self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0 + ) -> list[SessionMessage]: + """List messages for an agent with pagination.""" + messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages") + if not os.path.exists(messages_dir): + raise SessionException(f"Messages directory missing from agent: {agent_id} in session {session_id}") + + # Read all message files + message_files: list[str] = [] + for filename in os.listdir(messages_dir): + if filename.startswith(MESSAGE_PREFIX) and filename.endswith(".json"): + message_files.append(filename) + + # Sort filenames - the timestamp in the file's name will sort chronologically + message_files.sort() + + # Apply pagination to filenames + if limit is not None: + message_files = message_files[offset : offset + limit] + else: + message_files = message_files[offset:] + + # Load only the message files + messages: list[SessionMessage] = [] + for filename in message_files: + file_path = os.path.join(messages_dir, filename) + message_data = self._read_file(file_path) + messages.append(SessionMessage.from_dict(message_data)) + + return messages diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py new file mode 100644 index 000000000..fd31d9671 --- /dev/null +++ b/src/strands/session/repository_session_manager.py @@ -0,0 +1,108 @@ +"""Repository session manager implementation.""" + +import logging + +from ..agent.agent import Agent +from ..agent.state import AgentState +from ..types.content import Message +from ..types.exceptions import SessionException +from ..types.session import ( + Session, + SessionAgent, + SessionMessage, + SessionType, +) +from .session_manager import SessionManager +from .session_repository import SessionRepository + +logger = logging.getLogger(__name__) + + +class RepositorySessionManager(SessionManager): + """Session manager for persisting agents in a SessionRepository.""" + + def __init__( + self, + session_id: str, + session_repository: SessionRepository, + ): + """Initialize the RepositorySessionManager. + + If no session with the specified session_id exists yet, it will be created + in the session_repository. + + Args: + session_id: ID to use for the session. A new session with this id will be created if it does + not exist in the reposiory yet + session_repository: Underlying session repository to use to store the sessions state. + """ + self.session_repository = session_repository + self.session_id = session_id + session = session_repository.read_session(session_id) + # Create a session if it does not exist yet + if session is None: + logger.debug("session_id=<%s> | session not found, creating new session", self.session_id) + session = Session(session_id=session_id, session_type=SessionType.AGENT) + session_repository.create_session(session) + + self.session = session + + # Keep track of the initialized agent id's so that two agents in a session cannot share an id + self._initialized_agent_ids: set[str] = set() + + def append_message(self, message: Message, agent: Agent) -> None: + """Append a message to the agent's session. + + Args: + message: Message to add to the agent in the session + agent: Agent to append the message to + """ + session_message = SessionMessage.from_message(message) + self.session_repository.create_message(self.session_id, agent.agent_id, session_message) + + def sync_agent(self, agent: Agent) -> None: + """Serialize and update the agent into the session repository. + + Args: + agent: Agent to sync to the session. + """ + self.session_repository.update_agent( + self.session_id, + SessionAgent.from_agent(agent), + ) + + def initialize(self, agent: Agent) -> None: + """Initialize an agent with a session. + + Args: + agent: Agent to initialize from the session + """ + if agent.agent_id in self._initialized_agent_ids: + raise SessionException("The `agent_id` of an agent must be unique in a session.") + self._initialized_agent_ids.add(agent.agent_id) + + session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) + + if session_agent is None: + logger.debug( + "agent_id=<%s> | session_id=<%s> | creating agent", + agent.agent_id, + self.session_id, + ) + + session_agent = SessionAgent.from_agent(agent) + self.session_repository.create_agent(self.session_id, session_agent) + for message in agent.messages: + session_message = SessionMessage.from_message(message) + self.session_repository.create_message(self.session_id, agent.agent_id, session_message) + else: + logger.debug( + "agent_id=<%s> | session_id=<%s> | restoring agent", + agent.agent_id, + self.session_id, + ) + agent.messages = [ + session_message.to_message() + for session_message in self.session_repository.list_messages(self.session_id, agent.agent_id) + ] + agent.state = AgentState(session_agent.state) diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py new file mode 100644 index 000000000..af14c5386 --- /dev/null +++ b/src/strands/session/s3_session_manager.py @@ -0,0 +1,280 @@ +"""S3-based session manager for cloud storage.""" + +import json +import logging +from dataclasses import asdict +from typing import Any, Dict, List, Optional, cast + +import boto3 +from botocore.config import Config as BotocoreConfig +from botocore.exceptions import ClientError + +from ..types.exceptions import SessionException +from ..types.session import Session, SessionAgent, SessionMessage +from .repository_session_manager import RepositorySessionManager +from .session_repository import SessionRepository + +logger = logging.getLogger(__name__) + +SESSION_PREFIX = "session_" +AGENT_PREFIX = "agent_" +MESSAGE_PREFIX = "message_" + + +class S3SessionManager(RepositorySessionManager, SessionRepository): + """S3-based session manager for cloud storage. + + Creates the following filesystem structure for the session storage: + // + └── session_/ + ├── session.json # Session metadata + └── agents/ + └── agent_/ + ├── agent.json # Agent metadata + └── messages/ + ├── message__.json + └── message__.json + + """ + + def __init__( + self, + session_id: str, + bucket: str, + prefix: str = "", + boto_session: Optional[boto3.Session] = None, + boto_client_config: Optional[BotocoreConfig] = None, + region_name: Optional[str] = None, + ): + """Initialize S3SessionManager with S3 storage. + + Args: + session_id: ID for the session + bucket: S3 bucket name (required) + prefix: S3 key prefix for storage organization + boto_session: Optional boto3 session + boto_client_config: Optional boto3 client configuration + region_name: AWS region for S3 storage + """ + self.bucket = bucket + self.prefix = prefix + + session = boto_session or boto3.Session(region_name=region_name) + + # Add strands-agents to the request user agent + if boto_client_config: + existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) + # Append 'strands-agents' to existing user_agent_extra or set it if not present + if existing_user_agent: + new_user_agent = f"{existing_user_agent} strands-agents" + else: + new_user_agent = "strands-agents" + client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent)) + else: + client_config = BotocoreConfig(user_agent_extra="strands-agents") + + self.client = session.client(service_name="s3", config=client_config) + super().__init__(session_id=session_id, session_repository=self) + + def _get_session_path(self, session_id: str) -> str: + """Get session S3 prefix.""" + return f"{self.prefix}{SESSION_PREFIX}{session_id}/" + + def _get_agent_path(self, session_id: str, agent_id: str) -> str: + """Get agent S3 prefix.""" + session_path = self._get_session_path(session_id) + return f"{session_path}agents/{AGENT_PREFIX}{agent_id}/" + + def _get_message_path(self, session_id: str, agent_id: str, message_id: str, timestamp: str) -> str: + """Get message S3 key. + + Args: + session_id: ID of the session + agent_id: ID of the agent + message_id: ID of the message + timestamp: ISO format timestamp to include in key for sorting + Returns: + The key for the message + """ + agent_path = self._get_agent_path(session_id, agent_id) + # Use timestamp for sortable keys + # Replace colons and periods in ISO format with underscores for filesystem compatibility + filename_timestamp = timestamp.replace(":", "_").replace(".", "_") + return f"{agent_path}messages/{MESSAGE_PREFIX}{filename_timestamp}_{message_id}.json" + + def _read_s3_object(self, key: str) -> Optional[Dict[str, Any]]: + """Read JSON object from S3.""" + try: + response = self.client.get_object(Bucket=self.bucket, Key=key) + content = response["Body"].read().decode("utf-8") + return cast(dict[str, Any], json.loads(content)) + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchKey": + return None + else: + raise SessionException(f"S3 error reading {key}: {e}") from e + except json.JSONDecodeError as e: + raise SessionException(f"Invalid JSON in S3 object {key}: {e}") from e + + def _write_s3_object(self, key: str, data: Dict[str, Any]) -> None: + """Write JSON object to S3.""" + try: + content = json.dumps(data, indent=2, ensure_ascii=False) + self.client.put_object( + Bucket=self.bucket, Key=key, Body=content.encode("utf-8"), ContentType="application/json" + ) + except ClientError as e: + raise SessionException(f"Failed to write S3 object {key}: {e}") from e + + def create_session(self, session: Session) -> Session: + """Create a new session in S3.""" + session_key = f"{self._get_session_path(session.session_id)}session.json" + + # Check if session already exists + try: + self.client.head_object(Bucket=self.bucket, Key=session_key) + raise SessionException(f"Session {session.session_id} already exists") + except ClientError as e: + if e.response["Error"]["Code"] != "404": + raise SessionException(f"S3 error checking session existence: {e}") from e + + # Write session object + session_dict = asdict(session) + self._write_s3_object(session_key, session_dict) + return session + + def read_session(self, session_id: str) -> Optional[Session]: + """Read session data from S3.""" + session_key = f"{self._get_session_path(session_id)}session.json" + session_data = self._read_s3_object(session_key) + if session_data is None: + return None + return Session.from_dict(session_data) + + def delete_session(self, session_id: str) -> None: + """Delete session and all associated data from S3.""" + session_prefix = self._get_session_path(session_id) + try: + paginator = self.client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=self.bucket, Prefix=session_prefix) + + objects_to_delete = [] + for page in pages: + if "Contents" in page: + objects_to_delete.extend([{"Key": obj["Key"]} for obj in page["Contents"]]) + + if not objects_to_delete: + raise SessionException(f"Session {session_id} does not exist") + + # Delete objects in batches + for i in range(0, len(objects_to_delete), 1000): + batch = objects_to_delete[i : i + 1000] + self.client.delete_objects(Bucket=self.bucket, Delete={"Objects": batch}) + + except ClientError as e: + raise SessionException(f"S3 error deleting session {session_id}: {e}") from e + + def create_agent(self, session_id: str, session_agent: SessionAgent) -> None: + """Create a new agent in S3.""" + agent_id = session_agent.agent_id + agent_dict = asdict(session_agent) + agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json" + self._write_s3_object(agent_key, agent_dict) + + def read_agent(self, session_id: str, agent_id: str) -> Optional[SessionAgent]: + """Read agent data from S3.""" + agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json" + agent_data = self._read_s3_object(agent_key) + if agent_data is None: + return None + return SessionAgent.from_dict(agent_data) + + def update_agent(self, session_id: str, session_agent: SessionAgent) -> None: + """Update agent data in S3.""" + agent_id = session_agent.agent_id + previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id) + if previous_agent is None: + raise SessionException(f"Agent {agent_id} in session {session_id} does not exist") + + # Preserve creation timestamp + session_agent.created_at = previous_agent.created_at + agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json" + self._write_s3_object(agent_key, asdict(session_agent)) + + def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None: + """Create a new message in S3.""" + message_id = session_message.message_id + message_dict = asdict(session_message) + message_key = self._get_message_path(session_id, agent_id, message_id, session_message.created_at) + self._write_s3_object(message_key, message_dict) + + def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]: + """Read message data from S3.""" + # Get the messages prefix + messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/" + try: + paginator = self.client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=self.bucket, Prefix=messages_prefix) + + for page in pages: + if "Contents" in page: + for obj in page["Contents"]: + if obj["Key"].endswith(f"{message_id}.json"): + message_data = self._read_s3_object(obj["Key"]) + if message_data: + return SessionMessage.from_dict(message_data) + + return None + + except ClientError as e: + raise SessionException(f"S3 error reading message: {e}") from e + + def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None: + """Update message data in S3.""" + message_id = session_message.message_id + previous_message = self.read_message(session_id=session_id, agent_id=agent_id, message_id=message_id) + if previous_message is None: + raise SessionException(f"Message {message_id} does not exist") + + # Preserve creation timestamp + session_message.created_at = previous_message.created_at + message_key = self._get_message_path(session_id, agent_id, message_id, session_message.created_at) + self._write_s3_object(message_key, asdict(session_message)) + + def list_messages( + self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0 + ) -> List[SessionMessage]: + """List messages for an agent with pagination from S3.""" + messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/" + try: + paginator = self.client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=self.bucket, Prefix=messages_prefix) + + # Collect all message keys first + message_keys = [] + for page in pages: + if "Contents" in page: + for obj in page["Contents"]: + if obj["Key"].endswith(".json") and MESSAGE_PREFIX in obj["Key"]: + message_keys.append(obj["Key"]) + + # Sort keys - timestamp prefixed keys will sort chronologically + message_keys.sort() + + # Apply pagination to keys before loading content + if limit is not None: + message_keys = message_keys[offset : offset + limit] + else: + message_keys = message_keys[offset:] + + # Load only the required message objects + messages: List[SessionMessage] = [] + for key in message_keys: + message_data = self._read_s3_object(key) + if message_data: + messages.append(SessionMessage.from_dict(message_data)) + + return messages + + except ClientError as e: + raise SessionException(f"S3 error reading messages: {e}") from e diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py new file mode 100644 index 000000000..6f071f929 --- /dev/null +++ b/src/strands/session/session_manager.py @@ -0,0 +1,52 @@ +"""Session manager interface for agent session management.""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from ..hooks.events import AgentInitializedEvent, MessageAddedEvent +from ..hooks.registry import HookProvider, HookRegistry +from ..types.content import Message + +if TYPE_CHECKING: + from ..agent.agent import Agent + + +class SessionManager(HookProvider, ABC): + """Abstract interface for managing sessions. + + A session manager is in charge of persisting the conversation and state of an agent across its interaction. + Changes made to the agents conversation, state, or other attributes should be persisted immediately after + they are changed. The different methods introduced in this class are called at important lifecycle events + for an agent, and should be persisted in the session. + """ + + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register hooks for persisting the agent to the session.""" + registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) + registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent)) + registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent)) + + @abstractmethod + def append_message(self, message: Message, agent: "Agent") -> None: + """Append a message to the agent's session. + + Args: + message: Message to add to the agent in the session + agent: Agent to append the message to + """ + + @abstractmethod + def sync_agent(self, agent: "Agent") -> None: + """Serialize and sync the agent with the session storage. + + Args: + agent: Agent who should be synchronized with the session storage + """ + + @abstractmethod + def initialize(self, agent: "Agent") -> None: + """Initialize an agent with a session. + + Args: + agent: Agent to initialize + """ diff --git a/src/strands/session/session_repository.py b/src/strands/session/session_repository.py new file mode 100644 index 000000000..b9735e05f --- /dev/null +++ b/src/strands/session/session_repository.py @@ -0,0 +1,51 @@ +"""Session repository interface for agent session management.""" + +from abc import ABC, abstractmethod +from typing import Optional + +from ..types.session import Session, SessionAgent, SessionMessage + + +class SessionRepository(ABC): + """Abstract repository for creating, reading, and updating Sessions, AgentSessions, and AgentMessages.""" + + @abstractmethod + def create_session(self, session: Session) -> Session: + """Create a new Session.""" + + @abstractmethod + def read_session(self, session_id: str) -> Optional[Session]: + """Read a Session.""" + + @abstractmethod + def create_agent(self, session_id: str, session_agent: SessionAgent) -> None: + """Create a new Agent in a Session.""" + + @abstractmethod + def read_agent(self, session_id: str, agent_id: str) -> Optional[SessionAgent]: + """Read an Agent.""" + + @abstractmethod + def update_agent(self, session_id: str, session_agent: SessionAgent) -> None: + """Update an Agent.""" + + @abstractmethod + def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None: + """Create a new Message for the Agent.""" + + @abstractmethod + def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]: + """Read a Message.""" + + @abstractmethod + def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None: + """Update a Message. + + A message is usually only updated when some content is redacted due to a guardrail. + """ + + @abstractmethod + def list_messages( + self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0 + ) -> list[SessionMessage]: + """List Messages from an Agent with pagination.""" diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 1ffeba4ec..4bd3fd88e 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -52,3 +52,9 @@ def __init__(self, message: str) -> None: super().__init__(message) pass + + +class SessionException(Exception): + """Exception raised when session operations fail.""" + + pass diff --git a/src/strands/types/session.py b/src/strands/types/session.py new file mode 100644 index 000000000..50d82b368 --- /dev/null +++ b/src/strands/types/session.py @@ -0,0 +1,122 @@ +"""Data models for session management.""" + +import base64 +import inspect +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Dict, cast +from uuid import uuid4 + +from ..agent.agent import Agent +from .content import Message + + +class SessionType(str, Enum): + """Enumeration of session types. + + As sessions are expanded to support new usecases like multi-agent patterns, + new types will be added here. + """ + + AGENT = "AGENT" + + +def encode_bytes_values(obj: Any) -> Any: + """Recursively encode any bytes values in an object to base64. + + Handles dictionaries, lists, and nested structures. + """ + if isinstance(obj, bytes): + return {"__bytes_encoded__": True, "data": base64.b64encode(obj).decode()} + elif isinstance(obj, dict): + return {k: encode_bytes_values(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [encode_bytes_values(item) for item in obj] + else: + return obj + + +def decode_bytes_values(obj: Any) -> Any: + """Recursively decode any base64-encoded bytes values in an object. + + Handles dictionaries, lists, and nested structures. + """ + if isinstance(obj, dict): + if obj.get("__bytes_encoded__") is True and "data" in obj: + return base64.b64decode(obj["data"]) + return {k: decode_bytes_values(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [decode_bytes_values(item) for item in obj] + else: + return obj + + +@dataclass +class SessionMessage: + """Message within a SessionAgent.""" + + message: Message + message_id: str = field(default_factory=lambda: str(uuid4())) + created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + + @classmethod + def from_message(cls, message: Message) -> "SessionMessage": + """Convert from a Message, base64 encoding bytes values.""" + bytes_encoded_dict = encode_bytes_values(message) + return cls( + message=bytes_encoded_dict, + message_id=str(uuid4()), + created_at=datetime.now(timezone.utc).isoformat(), + updated_at=datetime.now(timezone.utc).isoformat(), + ) + + def to_message(self) -> Message: + """Convert SessionMessage back to a Message, decoding any bytes values.""" + return cast(Message, decode_bytes_values(self.message)) + + @classmethod + def from_dict(cls, env: dict[str, Any]) -> "SessionMessage": + """Initialize a SessionMessage from a dictionary, ignoring keys that are not class parameters.""" + return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters}) + + +@dataclass +class SessionAgent: + """Agent that belongs to a Session.""" + + agent_id: str + state: Dict[str, Any] + created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + + @classmethod + def from_agent(cls, agent: Agent) -> "SessionAgent": + """Convert an Agent to a SessionAgent.""" + if agent.agent_id is None: + raise ValueError("agent_id needs to be defined.") + return cls( + agent_id=agent.agent_id, + state=agent.state.get(), + ) + + @classmethod + def from_dict(cls, env: dict[str, Any]) -> "SessionAgent": + """Initialize a SessionAgent from a dictionary, ignoring keys that are not calss parameters.""" + return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters}) + + +@dataclass +class Session: + """Session data model.""" + + session_id: str + session_type: SessionType + created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + + @classmethod + def from_dict(cls, env: dict[str, Any]) -> "Session": + """Initialize a Session from a dictionary, ignoring keys that are not calss parameters.""" + return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters}) diff --git a/tests/fixtures/mock_session_repository.py b/tests/fixtures/mock_session_repository.py new file mode 100644 index 000000000..8e25691d0 --- /dev/null +++ b/tests/fixtures/mock_session_repository.py @@ -0,0 +1,98 @@ +from strands.session.session_repository import SessionRepository +from strands.types.exceptions import SessionException + + +class MockedSessionRepository(SessionRepository): + """Mock repository for testing.""" + + def __init__(self): + """Initialize with empty storage.""" + self.sessions = {} + self.agents = {} + self.messages = {} + + def create_session(self, session): + """Create a session.""" + session_id = session.session_id + if session_id in self.sessions: + raise SessionException(f"Session {session_id} already exists") + self.sessions[session_id] = session + self.agents[session_id] = {} + self.messages[session_id] = {} + return session + + def read_session(self, session_id): + """Read a session.""" + return self.sessions.get(session_id) + + def create_agent(self, session_id, session_agent): + """Create an agent.""" + agent_id = session_agent.agent_id + if session_id not in self.sessions: + raise SessionException(f"Session {session_id} does not exist") + if agent_id in self.agents.get(session_id, {}): + raise SessionException(f"Agent {agent_id} already exists in session {session_id}") + self.agents.setdefault(session_id, {})[agent_id] = session_agent + self.messages.setdefault(session_id, {}).setdefault(agent_id, []) + return session_agent + + def read_agent(self, session_id, agent_id): + """Read an agent.""" + if session_id not in self.sessions: + return None + return self.agents.get(session_id, {}).get(agent_id) + + def update_agent(self, session_id, session_agent): + """Update an agent.""" + agent_id = session_agent.agent_id + if session_id not in self.sessions: + raise SessionException(f"Session {session_id} does not exist") + if agent_id not in self.agents.get(session_id, {}): + raise SessionException(f"Agent {agent_id} does not exist in session {session_id}") + self.agents[session_id][agent_id] = session_agent + + def create_message(self, session_id, agent_id, session_message): + """Create a message.""" + if session_id not in self.sessions: + raise SessionException(f"Session {session_id} does not exist") + if agent_id not in self.agents.get(session_id, {}): + raise SessionException(f"Agent {agent_id} does not exist in session {session_id}") + self.messages.setdefault(session_id, {}).setdefault(agent_id, []).append(session_message) + + def read_message(self, session_id, agent_id, message_id): + """Read a message.""" + if session_id not in self.sessions: + return None + if agent_id not in self.agents.get(session_id, {}): + return None + for message in self.messages.get(session_id, {}).get(agent_id, []): + if message.message_id == message_id: + return message + return None + + def update_message(self, session_id, agent_id, session_message): + """Update a message.""" + message_id = session_message.message_id + if session_id not in self.sessions: + raise SessionException(f"Session {session_id} does not exist") + if agent_id not in self.agents.get(session_id, {}): + raise SessionException(f"Agent {agent_id} does not exist in session {session_id}") + + for i, message in enumerate(self.messages.get(session_id, {}).get(agent_id, [])): + if message.message_id == message_id: + self.messages[session_id][agent_id][i] = session_message + return + + raise SessionException(f"Message {message_id} does not exist") + + def list_messages(self, session_id, agent_id, limit=None, offset=0): + """List messages.""" + if session_id not in self.sessions: + return [] + if agent_id not in self.agents.get(session_id, {}): + return [] + + messages = self.messages.get(session_id, {}).get(agent_id, []) + if limit is not None: + return messages[offset : offset + limit] + return messages[offset:] diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 988e08919..c5453c5fe 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -13,10 +13,15 @@ from strands.agent import AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager +from strands.agent.state import AgentState from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel +from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import Messages from strands.types.exceptions import ContextWindowOverflowException, EventLoopException +from strands.types.session import Session, SessionAgent, SessionType +from tests.fixtures.mock_session_repository import MockedSessionRepository +from tests.fixtures.mocked_model_provider import MockedModelProvider @pytest.fixture @@ -636,7 +641,6 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): ) agent("test") - callback_handler.assert_has_calls( [ unittest.mock.call(init_event_loop=True), @@ -1338,6 +1342,11 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) +def test_agent_init_with_state_object(): + agent = Agent(state=AgentState({"foo": "bar"})) + assert agent.state.get("foo") == "bar" + + def test_non_dict_throws_error(): with pytest.raises(ValueError, match="state must be an AgentState object or a dict"): agent = Agent(state={"object", object()}) @@ -1391,3 +1400,28 @@ def test_agent_state_get_breaks_deep_dict_reference(): # This will fail if AgentState reflects the updated reference json.dumps(agent.state.get()) + + +def test_agent_session_management(): + mock_session_repository = MockedSessionRepository() + session_manager = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}]) + agent = Agent(session_manager=session_manager, model=model) + agent("Hello!") + + +def test_agent_restored_from_session_management(): + mock_session_repository = MockedSessionRepository() + mock_session_repository.create_session(Session(session_id="123", session_type=SessionType.AGENT)) + mock_session_repository.create_agent( + "123", + SessionAgent( + agent_id="default", + state={"foo": "bar"}, + ), + ) + session_manager = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) + + agent = Agent(session_manager=session_manager) + + assert agent.state.get("foo") == "bar" diff --git a/tests/strands/session/__init__.py b/tests/strands/session/__init__.py new file mode 100644 index 000000000..601ac7006 --- /dev/null +++ b/tests/strands/session/__init__.py @@ -0,0 +1 @@ +"""Tests for session management.""" diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py new file mode 100644 index 000000000..3153c611d --- /dev/null +++ b/tests/strands/session/test_file_session_manager.py @@ -0,0 +1,358 @@ +"""Tests for FileSessionManager.""" + +import json +import os +import tempfile +from unittest.mock import patch + +import pytest + +from strands.session.file_session_manager import FileSessionManager +from strands.types.content import ContentBlock +from strands.types.exceptions import SessionException +from strands.types.session import Session, SessionAgent, SessionMessage, SessionType + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield temp_dir + + +@pytest.fixture +def file_manager(temp_dir): + """Create FileSessionManager for testing.""" + return FileSessionManager(session_id="test", storage_dir=temp_dir) + + +@pytest.fixture +def sample_session(): + """Create sample session for testing.""" + return Session(session_id="test-session", session_type=SessionType.AGENT) + + +@pytest.fixture +def sample_agent(): + """Create sample agent for testing.""" + return SessionAgent( + agent_id="test-agent", + state={"key": "value"}, + ) + + +@pytest.fixture +def sample_message(): + """Create sample message for testing.""" + return SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text="Hello world")], + } + ) + + +class TestFileSessionManagerSessionOperations: + """Tests for session operations.""" + + def test_create_session(self, file_manager, sample_session): + """Test creating a session.""" + file_manager.create_session(sample_session) + + # Verify directory structure created + session_path = file_manager._get_session_path(sample_session.session_id) + assert os.path.exists(session_path) + + # Verify session file created + session_file = os.path.join(session_path, "session.json") + assert os.path.exists(session_file) + + # Verify content + with open(session_file, "r") as f: + data = json.load(f) + assert data["session_id"] == sample_session.session_id + assert data["session_type"] == sample_session.session_type + + def test_read_session(self, file_manager, sample_session): + """Test reading an existing session.""" + # Create session first + file_manager.create_session(sample_session) + + # Read it back + result = file_manager.read_session(sample_session.session_id) + + assert result.session_id == sample_session.session_id + assert result.session_type == sample_session.session_type + + def test_read_nonexistent_session(self, file_manager): + """Test reading a session that doesn't exist.""" + result = file_manager.read_session("nonexistent-session") + assert result is None + + def test_delete_session(self, file_manager, sample_session): + """Test deleting a session.""" + # Create session first + file_manager.create_session(sample_session) + session_path = file_manager._get_session_path(sample_session.session_id) + assert os.path.exists(session_path) + + # Delete session + file_manager.delete_session(sample_session.session_id) + + # Verify deletion + assert not os.path.exists(session_path) + + def test_delete_nonexistent_session(self, file_manager): + """Test deleting a session that doesn't exist.""" + # Should raise an error according to the implementation + with pytest.raises(SessionException, match="does not exist"): + file_manager.delete_session("nonexistent-session") + + +class TestFileSessionManagerAgentOperations: + """Tests for agent operations.""" + + def test_create_agent(self, file_manager, sample_session, sample_agent): + """Test creating an agent in a session.""" + # Create session first + file_manager.create_session(sample_session) + + # Create agent + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Verify directory structure + agent_path = file_manager._get_agent_path(sample_session.session_id, sample_agent.agent_id) + assert os.path.exists(agent_path) + + # Verify agent file + agent_file = os.path.join(agent_path, "agent.json") + assert os.path.exists(agent_file) + + # Verify content + with open(agent_file, "r") as f: + data = json.load(f) + assert data["agent_id"] == sample_agent.agent_id + assert data["state"] == sample_agent.state + + def test_read_agent(self, file_manager, sample_session, sample_agent): + """Test reading an agent from a session.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Read agent + result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + + assert result.agent_id == sample_agent.agent_id + assert result.state == sample_agent.state + + def test_read_nonexistent_agent(self, file_manager, sample_session): + """Test reading an agent that doesn't exist.""" + result = file_manager.read_agent(sample_session.session_id, "nonexistent_agent") + assert result is None + + def test_update_agent(self, file_manager, sample_session, sample_agent): + """Test updating an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Update agent + sample_agent.state = {"updated": "value"} + file_manager.update_agent(sample_session.session_id, sample_agent) + + # Verify update + result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + assert result.state == {"updated": "value"} + + def test_update_nonexistent_agent(self, file_manager, sample_session, sample_agent): + """Test updating an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + + # Update agent + with pytest.raises(SessionException): + file_manager.update_agent(sample_session.session_id, sample_agent) + + +class TestFileSessionManagerMessageOperations: + """Tests for message operations.""" + + def test_create_message(self, file_manager, sample_session, sample_agent, sample_message): + """Test creating a message for an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Create message + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Verify message file + message_path = file_manager._get_message_path( + sample_session.session_id, sample_agent.agent_id, sample_message.message_id, sample_message.created_at + ) + assert os.path.exists(message_path) + + # Verify content + with open(message_path, "r") as f: + data = json.load(f) + assert data["message_id"] == sample_message.message_id + + def test_read_message(self, file_manager, sample_session, sample_agent, sample_message): + """Test reading a message.""" + # Create session, agent, and message + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Create multiple messages when reading + sample_message.message_id = sample_message.message_id + "_2" + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Read message + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + + assert result.message_id == sample_message.message_id + assert result.message["role"] == sample_message.message["role"] + assert result.message["content"] == sample_message.message["content"] + + def test_read_messages_with_new_agent(self, file_manager, sample_session, sample_agent): + """Test reading a message with with a new agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + + assert result is None + + def test_read_nonexistent_message(self, file_manager, sample_session, sample_agent): + """Test reading a message that doesnt exist.""" + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + assert result is None + + def test_list_messages_all(self, file_manager, sample_session, sample_agent): + """Test listing all messages for an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + messages = [] + for i in range(5): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + } + ) + messages.append(message) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List all messages + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert len(result) == 5 + + def test_list_messages_with_limit(self, file_manager, sample_session, sample_agent): + """Test listing messages with limit.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + for i in range(10): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + } + ) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List with limit + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, limit=3) + + assert len(result) == 3 + + def test_list_messages_with_offset(self, file_manager, sample_session, sample_agent): + """Test listing messages with offset.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + for i in range(10): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + } + ) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List with offset + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, offset=5) + + assert len(result) == 5 + + def test_list_messages_with_new_agent(self, file_manager, sample_session, sample_agent): + """Test listing messages with new agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert len(result) == 0 + + def test_update_message(self, file_manager, sample_session, sample_agent, sample_message): + """Test updating a message.""" + # Create session, agent, and message + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Update message + sample_message.message["content"] = [ContentBlock(text="Updated content")] + file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Verify update + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + assert result.message["content"][0]["text"] == "Updated content" + + def test_update_nonexistent_message(self, file_manager, sample_session, sample_agent, sample_message): + """Test updating a message.""" + # Create session, agent, and message + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Update nonexistent message + with pytest.raises(SessionException): + file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + +class TestFileSessionManagerErrorHandling: + """Tests for error handling scenarios.""" + + def test_corrupted_json_file(self, file_manager, temp_dir): + """Test handling of corrupted JSON files.""" + # Create a corrupted session file + session_path = os.path.join(temp_dir, "session_test") + os.makedirs(session_path, exist_ok=True) + session_file = os.path.join(session_path, "session.json") + + with open(session_file, "w") as f: + f.write("invalid json content") + + # Should raise SessionException + with pytest.raises(SessionException, match="Invalid JSON"): + file_manager._read_file(session_file) + + def test_permission_error_handling(self, file_manager): + """Test handling of permission errors.""" + with patch("builtins.open", side_effect=PermissionError("Access denied")): + session = Session(session_id="test", session_type=SessionType.AGENT) + + with pytest.raises(SessionException): + file_manager.create_session(session) diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py new file mode 100644 index 000000000..10901d30b --- /dev/null +++ b/tests/strands/session/test_repository_session_manager.py @@ -0,0 +1,136 @@ +"""Tests for AgentSessionManager.""" + +import pytest + +from strands.agent.agent import Agent +from strands.session.repository_session_manager import RepositorySessionManager +from strands.types.content import ContentBlock +from strands.types.exceptions import SessionException +from strands.types.session import Session, SessionAgent, SessionMessage, SessionType +from tests.fixtures.mock_session_repository import MockedSessionRepository + + +@pytest.fixture +def mock_repository(): + """Create a mock repository.""" + return MockedSessionRepository() + + +@pytest.fixture +def session_manager(mock_repository): + """Create a session manager with mock repository.""" + return RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + +@pytest.fixture +def agent(): + """Create a mock agent.""" + return Agent(messages=[{"role": "user", "content": [{"text": "Hello!"}]}]) + + +def test_init_creates_session_if_not_exists(mock_repository): + """Test that init creates a session if it doesn't exist.""" + # Session doesn't exist yet + assert mock_repository.read_session("test-session") is None + + # Creating manager should create session + RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Verify session created + session = mock_repository.read_session("test-session") + assert session is not None + assert session.session_id == "test-session" + assert session.session_type == SessionType.AGENT + + +def test_init_uses_existing_session(mock_repository): + """Test that init uses existing session if it exists.""" + # Create session first + session = Session(session_id="test-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + + # Creating manager should use existing session + manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Verify session used + assert manager.session == session + + +def test_initialize_with_existing_agent_id(session_manager, agent): + """Test initializing an agent with existing agent_id.""" + # Set agent ID + agent.agent_id = "custom-agent" + + # Initialize agent + session_manager.initialize(agent) + + # Verify agent created in repository + agent_data = session_manager.session_repository.read_agent("test-session", "custom-agent") + assert agent_data is not None + assert agent_data.agent_id == "custom-agent" + + +def test_initialize_multiple_agents_without_id(session_manager, agent): + """Test initializing multiple agents with same ID.""" + # First agent initialization works + agent.agent_id = "custom-agent" + session_manager.initialize(agent) + + # Second agent with no set agent_id should fail + agent2 = Agent(agent_id="custom-agent") + + with pytest.raises(SessionException, match="The `agent_id` of an agent must be unique in a session."): + session_manager.initialize(agent2) + + +def test_initialize_restores_existing_agent(session_manager, agent): + """Test that initializing an existing agent restores its state.""" + # Set agent ID + agent.agent_id = "existing-agent" + + # Create agent in repository first + session_agent = SessionAgent(agent_id="existing-agent", state={"key": "value"}) + session_manager.session_repository.create_agent("test-session", session_agent) + + # Create some messages + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text="Hello")], + } + ) + session_manager.session_repository.create_message("test-session", "existing-agent", message) + + # Initialize agent + session_manager.initialize(agent) + + # Verify agent state restored + assert agent.state.get("key") == "value" + assert len(agent.messages) == 1 + assert agent.messages[0]["role"] == "user" + assert agent.messages[0]["content"][0]["text"] == "Hello" + + +def test_append_message(session_manager, agent): + """Test appending a message to an agent's session.""" + # Set agent ID + agent.agent_id = "test-agent" + + # Create agent in repository + session_agent = SessionAgent( + agent_id="test-agent", + state={}, + ) + session_manager.session_repository.create_agent("test-session", session_agent) + + # Create message + message = {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + + # Append message + session_manager.append_message(message, agent) + + # Verify message created in repository + messages = session_manager.session_repository.list_messages("test-session", "test-agent") + assert len(messages) == 1 + assert messages[0].message["role"] == "user" + assert messages[0].message["content"][0]["text"] == "Hello" diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py new file mode 100644 index 000000000..ffc05e53e --- /dev/null +++ b/tests/strands/session/test_s3_session_manager.py @@ -0,0 +1,331 @@ +"""Tests for S3SessionManager.""" + +import json + +import boto3 +import pytest +from botocore.config import Config as BotocoreConfig +from botocore.exceptions import ClientError +from moto import mock_aws + +from strands.session.s3_session_manager import S3SessionManager +from strands.types.content import ContentBlock +from strands.types.exceptions import SessionException +from strands.types.session import Session, SessionAgent, SessionMessage, SessionType + + +@pytest.fixture +def mocked_aws(): + """ + Mock all AWS interactions + Requires you to create your own boto3 clients + """ + with mock_aws(): + yield + + +@pytest.fixture(scope="function") +def s3_bucket(mocked_aws): + """S3 bucket name for testing.""" + # Create the bucket + s3_client = boto3.client("s3", region_name="us-west-2") + s3_client.create_bucket(Bucket="test-session-bucket", CreateBucketConfiguration={"LocationConstraint": "us-west-2"}) + return "test-session-bucket" + + +@pytest.fixture +def s3_manager(mocked_aws, s3_bucket): + """Create S3SessionManager with mocked S3.""" + yield S3SessionManager(session_id="test", bucket=s3_bucket, prefix="sessions/", region_name="us-west-2") + + +@pytest.fixture +def sample_session(): + """Create sample session for testing.""" + return Session( + session_id="test-session-123", + session_type=SessionType.AGENT, + ) + + +@pytest.fixture +def sample_agent(): + """Create sample agent for testing.""" + return SessionAgent( + agent_id="test-agent-456", + state={"key": "value"}, + ) + + +@pytest.fixture +def sample_message(): + """Create sample message for testing.""" + return SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text="test_message")], + } + ) + + +def test_init_s3_session_manager(mocked_aws, s3_bucket): + session_manager = S3SessionManager(session_id="test", bucket=s3_bucket) + assert "strands-agents" in session_manager.client.meta.config.user_agent_extra + + +def test_init_s3_session_manager_with_config(mocked_aws, s3_bucket): + session_manager = S3SessionManager(session_id="test", bucket=s3_bucket, boto_client_config=BotocoreConfig()) + assert "strands-agents" in session_manager.client.meta.config.user_agent_extra + + +def test_init_s3_session_manager_with_existing_user_agent(mocked_aws, s3_bucket): + session_manager = S3SessionManager( + session_id="test", bucket=s3_bucket, boto_client_config=BotocoreConfig(user_agent_extra="test") + ) + assert "strands-agents" in session_manager.client.meta.config.user_agent_extra + + +def test_create_session(s3_manager, sample_session): + """Test creating a session in S3.""" + result = s3_manager.create_session(sample_session) + + assert result == sample_session + + # Verify S3 object created + key = f"{s3_manager._get_session_path(sample_session.session_id)}session.json" + response = s3_manager.client.get_object(Bucket=s3_manager.bucket, Key=key) + data = json.loads(response["Body"].read().decode("utf-8")) + + assert data["session_id"] == sample_session.session_id + assert data["session_type"] == sample_session.session_type + + +def test_create_session_already_exists(s3_manager, sample_session): + """Test creating a session in S3.""" + s3_manager.create_session(sample_session) + + with pytest.raises(SessionException): + s3_manager.create_session(sample_session) + + +def test_read_session(s3_manager, sample_session): + """Test reading a session from S3.""" + # Create session first + s3_manager.create_session(sample_session) + + # Read it back + result = s3_manager.read_session(sample_session.session_id) + + assert result.session_id == sample_session.session_id + assert result.session_type == sample_session.session_type + + +def test_read_nonexistent_session(s3_manager): + """Test reading a session that doesn't exist in S3.""" + with mock_aws(): + result = s3_manager.read_session("nonexistent-session") + assert result is None + + +def test_delete_session(s3_manager, sample_session): + """Test deleting a session from S3.""" + # Create session first + s3_manager.create_session(sample_session) + + # Verify session exists + key = f"{s3_manager._get_session_path(sample_session.session_id)}session.json" + s3_manager.client.head_object(Bucket=s3_manager.bucket, Key=key) + + # Delete session + s3_manager.delete_session(sample_session.session_id) + + # Verify deletion + with pytest.raises(ClientError) as excinfo: + s3_manager.client.head_object(Bucket=s3_manager.bucket, Key=key) + assert excinfo.value.response["Error"]["Code"] == "404" + + +def test_create_agent(s3_manager, sample_session, sample_agent): + """Test creating an agent in S3.""" + # Create session first + s3_manager.create_session(sample_session) + + # Create agent + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Verify S3 object created + key = f"{s3_manager._get_agent_path(sample_session.session_id, sample_agent.agent_id)}agent.json" + response = s3_manager.client.get_object(Bucket=s3_manager.bucket, Key=key) + data = json.loads(response["Body"].read().decode("utf-8")) + + assert data["agent_id"] == sample_agent.agent_id + assert data["state"] == sample_agent.state + + +def test_read_agent(s3_manager, sample_session, sample_agent): + """Test reading an agent from S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Read agent + result = s3_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + + assert result.agent_id == sample_agent.agent_id + assert result.state == sample_agent.state + + +def test_read_nonexistent_agent(s3_manager, sample_session, sample_agent): + """Test reading an agent from S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + # Read agent + result = s3_manager.read_agent(sample_session.session_id, "nonexistent_agent") + + assert result is None + + +def test_update_agent(s3_manager, sample_session, sample_agent): + """Test updating an agent in S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Update agent + sample_agent.state = {"updated": "value"} + s3_manager.update_agent(sample_session.session_id, sample_agent) + + # Verify update + result = s3_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + assert result.state == {"updated": "value"} + + +def test_update_nonexistent_agent(s3_manager, sample_session, sample_agent): + """Test updating an agent in S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + + with pytest.raises(SessionException): + s3_manager.update_agent(sample_session.session_id, sample_agent) + + +def test_create_message(s3_manager, sample_session, sample_agent, sample_message): + """Test creating a message in S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Create message + s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Verify S3 object created + key = s3_manager._get_message_path( + sample_session.session_id, sample_agent.agent_id, sample_message.message_id, sample_message.created_at + ) + response = s3_manager.client.get_object(Bucket=s3_manager.bucket, Key=key) + data = json.loads(response["Body"].read().decode("utf-8")) + + assert data["message_id"] == sample_message.message_id + + +def test_read_message(s3_manager, sample_session, sample_agent, sample_message): + """Test reading a message from S3.""" + # Create session, agent, and message + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Read message + result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + + assert result.message_id == sample_message.message_id + assert result.message["role"] == sample_message.message["role"] + assert result.message["content"] == sample_message.message["content"] + + +def test_read_nonexistent_message(s3_manager, sample_session, sample_agent, sample_message): + """Test reading a message from S3.""" + # Create session, agent, and message + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Read message + result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + + assert result is None + + +def test_list_messages_all(s3_manager, sample_session, sample_agent): + """Test listing all messages from S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + messages = [] + for i in range(5): + message = SessionMessage( + { + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + } + ) + messages.append(message) + s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List all messages + result = s3_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert len(result) == 5 + + +def test_list_messages_with_pagination(s3_manager, sample_session, sample_agent): + """Test listing messages with pagination in S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + for _ in range(10): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text="test_message")], + } + ) + s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List with limit + result = s3_manager.list_messages(sample_session.session_id, sample_agent.agent_id, limit=3) + assert len(result) == 3 + + # List with offset + result = s3_manager.list_messages(sample_session.session_id, sample_agent.agent_id, offset=5) + assert len(result) == 5 + + +def test_update_message(s3_manager, sample_session, sample_agent, sample_message): + """Test updating a message in S3.""" + # Create session, agent, and message + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Update message + sample_message.message["content"] = [ContentBlock(text="Updated content")] + s3_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Verify update + result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + assert result.message["content"][0]["text"] == "Updated content" + + +def test_update_nonexistent_message(s3_manager, sample_session, sample_agent, sample_message): + """Test updating a message in S3.""" + # Create session, agent, and message + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Update message + with pytest.raises(SessionException): + s3_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) diff --git a/tests/strands/types/test_session.py b/tests/strands/types/test_session.py new file mode 100644 index 000000000..596fe337c --- /dev/null +++ b/tests/strands/types/test_session.py @@ -0,0 +1,91 @@ +import json +from dataclasses import asdict +from uuid import uuid4 + +from strands.types.session import ( + Session, + SessionAgent, + SessionMessage, + SessionType, + decode_bytes_values, + encode_bytes_values, +) + + +def test_session_json_serializable(): + session = Session(session_id=str(uuid4()), session_type=SessionType.AGENT) + # json dumps will fail if its not json serializable + session_json_string = json.dumps(asdict(session)) + loaded_session = Session.from_dict(json.loads(session_json_string)) + assert loaded_session is not None + + +def test_agent_json_serializable(): + agent = SessionAgent(agent_id=str(uuid4()), state={"foo": "bar"}) + # json dumps will fail if its not json serializable + agent_json_string = json.dumps(asdict(agent)) + loaded_agent = SessionAgent.from_dict(json.loads(agent_json_string)) + assert loaded_agent is not None + + +def test_message_json_serializable(): + message = SessionMessage(message={"role": "user", "content": [{"text": "Hello!"}]}) + # json dumps will fail if its not json serializable + message_json_string = json.dumps(asdict(message)) + loaded_message = SessionMessage.from_dict(json.loads(message_json_string)) + assert loaded_message is not None + + +def test_bytes_encoding_decoding(): + # Test simple bytes + test_bytes = b"Hello, world!" + encoded = encode_bytes_values(test_bytes) + assert isinstance(encoded, dict) + assert encoded["__bytes_encoded__"] is True + decoded = decode_bytes_values(encoded) + assert decoded == test_bytes + + # Test nested structure with bytes + test_data = { + "text": "Hello", + "binary": b"Binary data", + "nested": {"more_binary": b"More binary data", "list_with_binary": [b"Item 1", "Text item", b"Item 3"]}, + } + + encoded = encode_bytes_values(test_data) + # Verify it's JSON serializable + json_str = json.dumps(encoded) + # Deserialize and decode + decoded = decode_bytes_values(json.loads(json_str)) + + # Verify the decoded data matches the original + assert decoded["text"] == test_data["text"] + assert decoded["binary"] == test_data["binary"] + assert decoded["nested"]["more_binary"] == test_data["nested"]["more_binary"] + assert decoded["nested"]["list_with_binary"][0] == test_data["nested"]["list_with_binary"][0] + assert decoded["nested"]["list_with_binary"][1] == test_data["nested"]["list_with_binary"][1] + assert decoded["nested"]["list_with_binary"][2] == test_data["nested"]["list_with_binary"][2] + + +def test_session_message_with_bytes(): + # Create a message with bytes content + message = { + "role": "user", + "content": [{"text": "Here is some binary data"}, {"binary_data": b"This is binary data"}], + } + + # Create a SessionMessage + session_message = SessionMessage.from_message(message) + + # Verify it's JSON serializable + message_json_string = json.dumps(asdict(session_message)) + + # Load it back + loaded_message = SessionMessage.from_dict(json.loads(message_json_string)) + + # Convert back to original message and verify + original_message = loaded_message.to_message() + + assert original_message["role"] == message["role"] + assert original_message["content"][0]["text"] == message["content"][0]["text"] + assert original_message["content"][1]["binary_data"] == message["content"][1]["binary_data"] diff --git a/tests_integ/test_session.py b/tests_integ/test_session.py new file mode 100644 index 000000000..fbfd54384 --- /dev/null +++ b/tests_integ/test_session.py @@ -0,0 +1,123 @@ +"""Integration tests for session management.""" + +import tempfile +from uuid import uuid4 + +import boto3 +import pytest +from botocore.client import ClientError + +from strands import Agent +from strands.session.file_session_manager import FileSessionManager +from strands.session.s3_session_manager import S3SessionManager + + +@pytest.fixture +def yellow_img(pytestconfig): + path = pytestconfig.rootdir / "tests_integ/yellow.png" + with open(path, "rb") as fp: + return fp.read() + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield temp_dir + + +@pytest.fixture +def bucket_name(): + bucket_name = f"test-strands-session-bucket-{boto3.client('sts').get_caller_identity()['Account']}" + s3_client = boto3.resource("s3", region_name="us-west-2") + try: + s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": "us-west-2"}) + except ClientError as e: + if "BucketAlreadyOwnedByYou" not in str(e): + raise e + yield bucket_name + + +def test_agent_with_file_session(temp_dir): + # Set up the session manager and add an agent + test_session_id = str(uuid4()) + # Create a session + session_manager = FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) + try: + agent = Agent(session_manager=session_manager) + agent("Hello!") + assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 + + # After agent is persisted and run, restore the agent and run it again + session_manager_2 = FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) + agent_2 = Agent(session_manager=session_manager_2) + assert len(agent_2.messages) == 2 + agent_2("Hello!") + assert len(agent_2.messages) == 4 + assert len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) == 4 + finally: + # Delete the session + session_manager.delete_session(test_session_id) + assert session_manager.read_session(test_session_id) is None + + +def test_agent_with_file_session_with_image(temp_dir, yellow_img): + test_session_id = str(uuid4()) + # Create a session + session_manager = FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) + try: + agent = Agent(session_manager=session_manager) + agent([{"image": {"format": "png", "source": {"bytes": yellow_img}}}]) + assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 + + # After agent is persisted and run, restore the agent and run it again + session_manager_2 = FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) + agent_2 = Agent(session_manager=session_manager_2) + assert len(agent_2.messages) == 2 + agent_2("Hello!") + assert len(agent_2.messages) == 4 + assert len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) == 4 + finally: + # Delete the session + session_manager.delete_session(test_session_id) + assert session_manager.read_session(test_session_id) is None + + +def test_agent_with_s3_session(bucket_name): + test_session_id = str(uuid4()) + session_manager = S3SessionManager(session_id=test_session_id, bucket=bucket_name, region_name="us-west-2") + try: + agent = Agent(session_manager=session_manager) + agent("Hello!") + assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 + + # After agent is persisted and run, restore the agent and run it again + session_manager_2 = S3SessionManager(session_id=test_session_id, bucket=bucket_name, region_name="us-west-2") + agent_2 = Agent(session_manager=session_manager_2) + assert len(agent_2.messages) == 2 + agent_2("Hello!") + assert len(agent_2.messages) == 4 + assert len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) == 4 + finally: + session_manager.delete_session(test_session_id) + assert session_manager.read_session(test_session_id) is None + + +def test_agent_with_s3_session_with_image(yellow_img, bucket_name): + test_session_id = str(uuid4()) + session_manager = S3SessionManager(session_id=test_session_id, bucket=bucket_name, region_name="us-west-2") + try: + agent = Agent(session_manager=session_manager) + agent([{"image": {"format": "png", "source": {"bytes": yellow_img}}}]) + assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 + + # After agent is persisted and run, restore the agent and run it again + session_manager_2 = S3SessionManager(session_id=test_session_id, bucket=bucket_name, region_name="us-west-2") + agent_2 = Agent(session_manager=session_manager_2) + assert len(agent_2.messages) == 2 + agent_2("Hello!") + assert len(agent_2.messages) == 4 + assert len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) == 4 + finally: + session_manager.delete_session(test_session_id) + assert session_manager.read_session(test_session_id) is None