Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
metrics = _parse_metrics(data.get("accumulated_metrics", {}))

multiagent_result = cls(
status=Status(data.get("status", Status.PENDING.value)),
status=Status(data["status"]),
results=results,
accumulated_usage=usage,
accumulated_metrics=metrics,
Expand All @@ -164,8 +164,13 @@ class MultiAgentBase(ABC):

This class integrates with existing Strands Agent instances and provides
multi-agent orchestration capabilities.

Attributes:
id: Unique MultiAgent id for session management,etc.
"""

id: str

@abstractmethod
async def invoke_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
Expand Down
52 changes: 49 additions & 3 deletions src/strands/session/file_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,23 @@
import os
import shutil
import tempfile
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast

from .. import _identifier
from ..types.exceptions import SessionException
from ..types.session import Session, SessionAgent, SessionMessage
from .repository_session_manager import RepositorySessionManager
from .session_repository import SessionRepository

if TYPE_CHECKING:
from ..multiagent.base import MultiAgentBase

logger = logging.getLogger(__name__)

SESSION_PREFIX = "session_"
AGENT_PREFIX = "agent_"
MESSAGE_PREFIX = "message_"
MULTI_AGENT_PREFIX = "multi_agent_"


class FileSessionManager(RepositorySessionManager, SessionRepository):
Expand All @@ -37,7 +41,12 @@ class FileSessionManager(RepositorySessionManager, SessionRepository):
```
"""

def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: Any):
def __init__(
self,
session_id: str,
storage_dir: Optional[str] = None,
**kwargs: Any,
):
"""Initialize FileSession with filesystem storage.

Args:
Expand Down Expand Up @@ -107,8 +116,11 @@ def _read_file(self, path: str) -> dict[str, Any]:
def _write_file(self, path: str, data: dict[str, Any]) -> None:
"""Write JSON file."""
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
# This automic write ensure the completeness of session files in both single agent/ multi agents
tmp = f"{path}.tmp"
with open(tmp, "w", encoding="utf-8", newline="\n") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
os.replace(tmp, path)

def create_session(self, session: Session, **kwargs: Any) -> Session:
"""Create a new session."""
Expand All @@ -119,6 +131,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session:
# Create directory structure
os.makedirs(session_dir, exist_ok=True)
os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True)
os.makedirs(os.path.join(session_dir, "multi_agents"), exist_ok=True)

# Write session file
session_file = os.path.join(session_dir, "session.json")
Expand Down Expand Up @@ -239,3 +252,36 @@ def list_messages(
messages.append(SessionMessage.from_dict(message_data))

return messages

def _get_multi_agent_path(self, session_id: str, multi_agent_id: str) -> str:
"""Get multi-agent state file path."""
session_path = self._get_session_path(session_id)
multi_agent_id = _identifier.validate(multi_agent_id, _identifier.Identifier.AGENT)
return os.path.join(session_path, "multi_agents", f"{MULTI_AGENT_PREFIX}{multi_agent_id}")

def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None:
"""Create a new multiagent state in the session."""
multi_agent_id = multi_agent.id
multi_agent_dir = self._get_multi_agent_path(session_id, multi_agent_id)
os.makedirs(multi_agent_dir, exist_ok=True)

multi_agent_file = os.path.join(multi_agent_dir, "multi_agent.json")
session_data = multi_agent.serialize_state()
self._write_file(multi_agent_file, session_data)

def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]:
"""Read multi-agent state from filesystem."""
multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent_id), "multi_agent.json")
if not os.path.exists(multi_agent_file):
return None
return self._read_file(multi_agent_file)

def update_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None:
"""Update multi-agent state from filesystem."""
multi_agent_state = multi_agent.serialize_state()
previous_multi_agent_state = self.read_multi_agent(session_id=session_id, multi_agent_id=multi_agent.id)
if previous_multi_agent_state is None:
raise SessionException(f"MultiAgent state {multi_agent.id} in session {session_id} does not exist")

multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent.id), "multi_agent.json")
self._write_file(multi_agent_file, multi_agent_state)
31 changes: 30 additions & 1 deletion src/strands/session/repository_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,20 @@

if TYPE_CHECKING:
from ..agent.agent import Agent
from ..multiagent.base import MultiAgentBase

logger = logging.getLogger(__name__)


class RepositorySessionManager(SessionManager):
"""Session manager for persisting agents in a SessionRepository."""

def __init__(self, session_id: str, session_repository: SessionRepository, **kwargs: Any):
def __init__(
self,
session_id: str,
session_repository: SessionRepository,
**kwargs: Any,
):
"""Initialize the RepositorySessionManager.

If no session with the specified session_id exists yet, it will be created
Expand Down Expand Up @@ -152,3 +158,26 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None:

# Restore the agents messages array including the optional prepend messages
agent.messages = prepend_messages + [session_message.to_message() for session_message in session_messages]

def sync_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None:
"""Serialize and update the multi-agent state into the session repository.

Args:
source: Multi-agent source object to sync to the session.
**kwargs: Additional keyword arguments for future extensibility.
"""
self.session_repository.update_multi_agent(self.session_id, source)

def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None:
"""Initialize multi-agent state from the session repository.

Args:
source: Multi-agent source object to restore state into
**kwargs: Additional keyword arguments for future extensibility.
"""
state = self.session_repository.read_multi_agent(self.session_id, source.id, **kwargs)
if state is None:
self.session_repository.create_multi_agent(self.session_id, source, **kwargs)
else:
logger.debug("session_id=<%s> | restoring multi-agent state", self.session_id)
source.deserialize_state(state)
34 changes: 33 additions & 1 deletion src/strands/session/s3_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
import logging
from typing import Any, Dict, List, Optional, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast

import boto3
from botocore.config import Config as BotocoreConfig
Expand All @@ -14,11 +14,15 @@
from .repository_session_manager import RepositorySessionManager
from .session_repository import SessionRepository

if TYPE_CHECKING:
from ..multiagent.base import MultiAgentBase

logger = logging.getLogger(__name__)

SESSION_PREFIX = "session_"
AGENT_PREFIX = "agent_"
MESSAGE_PREFIX = "message_"
MULTI_AGENT_PREFIX = "multi_agent_"


class S3SessionManager(RepositorySessionManager, SessionRepository):
Expand Down Expand Up @@ -294,3 +298,31 @@ def list_messages(

except ClientError as e:
raise SessionException(f"S3 error reading messages: {e}") from e

def _get_multi_agent_path(self, session_id: str, multi_agent_id: str) -> str:
"""Get multi-agent S3 prefix."""
session_path = self._get_session_path(session_id)
multi_agent_id = _identifier.validate(multi_agent_id, _identifier.Identifier.AGENT)
return f"{session_path}multi_agents/{MULTI_AGENT_PREFIX}{multi_agent_id}/"

def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None:
"""Create a new multiagent state in S3."""
multi_agent_id = multi_agent.id
multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json"
session_data = multi_agent.serialize_state()
self._write_s3_object(multi_agent_key, session_data)

def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]:
"""Read multi-agent state from S3."""
multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json"
return self._read_s3_object(multi_agent_key)

def update_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None:
"""Update multi-agent state in S3."""
multi_agent_state = multi_agent.serialize_state()
previous_multi_agent_state = self.read_multi_agent(session_id=session_id, multi_agent_id=multi_agent.id)
if previous_multi_agent_state is None:
raise SessionException(f"MultiAgent state {multi_agent.id} in session {session_id} does not exist")

multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent.id)}multi_agent.json"
self._write_s3_object(multi_agent_key, multi_agent_state)
43 changes: 43 additions & 0 deletions src/strands/session/session_manager.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
"""Session manager interface for agent session management."""

import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

from ..experimental.hooks.multiagent.events import (
AfterMultiAgentInvocationEvent,
AfterNodeCallEvent,
MultiAgentInitializedEvent,
)
from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent
from ..hooks.registry import HookProvider, HookRegistry
from ..types.content import Message

if TYPE_CHECKING:
from ..agent.agent import Agent
from ..multiagent.base import MultiAgentBase

logger = logging.getLogger(__name__)


class SessionManager(HookProvider, ABC):
Expand All @@ -34,6 +43,10 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
# After an agent was invoked, sync it with the session to capture any conversation manager state updates
registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent))

registry.add_callback(MultiAgentInitializedEvent, lambda event: self.initialize_multi_agent(event.source))
registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source))
registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source))

@abstractmethod
def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None:
"""Redact the message most recently appended to the agent in the session.
Expand Down Expand Up @@ -71,3 +84,33 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None:
agent: Agent to initialize
**kwargs: Additional keyword arguments for future extensibility.
"""

def sync_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None:
"""Serialize and sync multi-agent with the session storage.

Args:
source: Multi-agent source object to persist
**kwargs: Additional keyword arguments for future extensibility.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support multi-agent persistence "
"(sync_multi_agent). Provide an implementation or use a "
"SessionManager with session_type=SessionType.MULTI_AGENT."
)

def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None:
"""Read multi-agent state from persistent storage.

Args:
**kwargs: Additional keyword arguments for future extensibility.
source: Multi-agent state to initialize.

Returns:
Multi-agent state dictionary or empty dict if not found.

"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support multi-agent persistence "
"(initialize_multi_agent). Provide an implementation or use a "
"SessionManager with session_type=SessionType.MULTI_AGENT."
)
17 changes: 16 additions & 1 deletion src/strands/session/session_repository.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Session repository interface for agent session management."""

from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

from ..types.session import Session, SessionAgent, SessionMessage

if TYPE_CHECKING:
from ..multiagent import MultiAgentBase


class SessionRepository(ABC):
"""Abstract repository for creating, reading, and updating Sessions, AgentSessions, and AgentMessages."""
Expand Down Expand Up @@ -49,3 +52,15 @@ def list_messages(
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any
) -> list[SessionMessage]:
"""List Messages from an Agent with pagination."""

def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None:
"""Create a new MultiAgent state for the Session."""
raise NotImplementedError("MultiAgent is not implemented for this repository")

def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]:
"""Read the MultiAgent state for the Session."""
raise NotImplementedError("MultiAgent is not implemented for this repository")

def update_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None:
"""Update the MultiAgent state for the Session."""
raise NotImplementedError("MultiAgent is not implemented for this repository")
2 changes: 1 addition & 1 deletion src/strands/types/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class SessionType(str, Enum):
"""Enumeration of session types.

As sessions are expanded to support new usecases like multi-agent patterns,
As sessions are expanded to support new use cases like multi-agent patterns,
new types will be added here.
"""

Expand Down
26 changes: 26 additions & 0 deletions tests/fixtures/mock_session_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def __init__(self):
self.sessions = {}
self.agents = {}
self.messages = {}
self.multi_agents = {}

def create_session(self, session) -> None:
"""Create a session."""
Expand All @@ -20,6 +21,7 @@ def create_session(self, session) -> None:
self.sessions[session_id] = session
self.agents[session_id] = {}
self.messages[session_id] = {}
self.multi_agents[session_id] = {}

def read_session(self, session_id) -> SessionAgent:
"""Read a session."""
Expand Down Expand Up @@ -95,3 +97,27 @@ def list_messages(self, session_id, agent_id, limit=None, offset=0) -> list[Sess
if limit is not None:
return sorted_messages[offset : offset + limit]
return sorted_messages[offset:]

def create_multi_agent(self, session_id, multi_agent, **kwargs) -> None:
"""Create multi-agent state."""
multi_agent_id = multi_agent.id
if session_id not in self.sessions:
raise SessionException(f"Session {session_id} does not exist")
state = multi_agent.serialize_state()
self.multi_agents.setdefault(session_id, {})[multi_agent_id] = state

def read_multi_agent(self, session_id, multi_agent_id, **kwargs):
"""Read multi-agent state."""
if session_id not in self.sessions:
return None
return self.multi_agents.get(session_id, {}).get(multi_agent_id)

def update_multi_agent(self, session_id, multi_agent, **kwargs) -> None:
"""Update multi-agent state."""
multi_agent_id = multi_agent.id
if session_id not in self.sessions:
raise SessionException(f"Session {session_id} does not exist")
if multi_agent_id not in self.multi_agents.get(session_id, {}):
raise SessionException(f"MultiAgent {multi_agent} does not exist in session {session_id}")
state = multi_agent.serialize_state()
self.multi_agents[session_id][multi_agent_id] = state
Loading
Loading