Skip to content

Commit 3741427

Browse files
committed
feat: remove sessionType, combine session creation branches
1 parent 2157597 commit 3741427

File tree

8 files changed

+75
-154
lines changed

8 files changed

+75
-154
lines changed

src/strands/session/file_session_manager.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from .. import _identifier
1111
from ..types.exceptions import SessionException
12-
from ..types.session import Session, SessionAgent, SessionMessage, SessionType
12+
from ..types.session import Session, SessionAgent, SessionMessage
1313
from .repository_session_manager import RepositorySessionManager
1414
from .session_repository import SessionRepository
1515

@@ -45,8 +45,6 @@ def __init__(
4545
self,
4646
session_id: str,
4747
storage_dir: Optional[str] = None,
48-
*,
49-
session_type: SessionType = SessionType.AGENT,
5048
**kwargs: Any,
5149
):
5250
"""Initialize FileSession with filesystem storage.
@@ -55,13 +53,12 @@ def __init__(
5553
session_id: ID for the session.
5654
ID is not allowed to contain path separators (e.g., a/b).
5755
storage_dir: Directory for local filesystem storage (defaults to temp dir).
58-
session_type: single agent or multiagent.
5956
**kwargs: Additional keyword arguments for future extensibility.
6057
"""
6158
self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions")
6259
os.makedirs(self.storage_dir, exist_ok=True)
6360

64-
super().__init__(session_id=session_id, session_repository=self, session_type=session_type)
61+
super().__init__(session_id=session_id, session_repository=self)
6562

6663
def _get_session_path(self, session_id: str) -> str:
6764
"""Get session directory path.
@@ -133,10 +130,9 @@ def create_session(self, session: Session, **kwargs: Any) -> Session:
133130

134131
# Create directory structure
135132
os.makedirs(session_dir, exist_ok=True)
136-
if self.session_type == SessionType.AGENT:
137-
os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True)
138-
else:
139-
os.makedirs(os.path.join(session_dir, "multi_agents"), exist_ok=True)
133+
os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True)
134+
os.makedirs(os.path.join(session_dir, "multi_agents"), exist_ok=True)
135+
140136
# Write session file
141137
session_file = os.path.join(session_dir, "session.json")
142138
session_dict = session.to_dict()

src/strands/session/repository_session_manager.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ def __init__(
2929
self,
3030
session_id: str,
3131
session_repository: SessionRepository,
32-
*,
33-
session_type: SessionType = SessionType.AGENT,
3432
**kwargs: Any,
3533
):
3634
"""Initialize the RepositorySessionManager.
@@ -42,27 +40,22 @@ def __init__(
4240
session_id: ID to use for the session. A new session with this id will be created if it does
4341
not exist in the repository yet
4442
session_repository: Underlying session repository to use to store the sessions state.
45-
session_type: single agent or multiagent.
4643
**kwargs: Additional keyword arguments for future extensibility.
4744
4845
"""
49-
super().__init__(session_type=session_type)
50-
5146
self.session_repository = session_repository
5247
self.session_id = session_id
5348
session = session_repository.read_session(session_id)
5449
# Create a session if it does not exist yet
5550
if session is None:
5651
logger.debug("session_id=<%s> | session not found, creating new session", self.session_id)
57-
session = Session(session_id=session_id, session_type=session_type)
52+
session = Session(session_id=session_id, session_type=SessionType.AGENT)
5853
session_repository.create_session(session)
5954

6055
self.session = session
61-
self.session_type = session.session_type
6256

6357
# Keep track of the latest message of each agent in case we need to redact it.
64-
if self.session_type == SessionType.AGENT:
65-
self._latest_agent_message: dict[str, Optional[SessionMessage]] = {}
58+
self._latest_agent_message: dict[str, Optional[SessionMessage]] = {}
6659

6760
def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None:
6861
"""Append a message to the agent's session.

src/strands/session/s3_session_manager.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from .. import _identifier
1212
from ..types.exceptions import SessionException
13-
from ..types.session import Session, SessionAgent, SessionMessage, SessionType
13+
from ..types.session import Session, SessionAgent, SessionMessage
1414
from .repository_session_manager import RepositorySessionManager
1515
from .session_repository import SessionRepository
1616

@@ -50,8 +50,6 @@ def __init__(
5050
boto_session: Optional[boto3.Session] = None,
5151
boto_client_config: Optional[BotocoreConfig] = None,
5252
region_name: Optional[str] = None,
53-
*,
54-
session_type: SessionType = SessionType.AGENT,
5553
**kwargs: Any,
5654
):
5755
"""Initialize S3SessionManager with S3 storage.
@@ -64,7 +62,6 @@ def __init__(
6462
boto_session: Optional boto3 session
6563
boto_client_config: Optional boto3 client configuration
6664
region_name: AWS region for S3 storage
67-
session_type: single agent or multiagent.
6865
**kwargs: Additional keyword arguments for future extensibility.
6966
"""
7067
self.bucket = bucket
@@ -85,7 +82,7 @@ def __init__(
8582
client_config = BotocoreConfig(user_agent_extra="strands-agents")
8683

8784
self.client = session.client(service_name="s3", config=client_config)
88-
super().__init__(session_id=session_id, session_type=session_type, session_repository=self)
85+
super().__init__(session_id=session_id, session_repository=self)
8986

9087
def _get_session_path(self, session_id: str) -> str:
9188
"""Get session S3 prefix.

src/strands/session/session_manager.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent
1313
from ..hooks.registry import HookProvider, HookRegistry
1414
from ..types.content import Message
15-
from ..types.session import SessionType
1615

1716
if TYPE_CHECKING:
1817
from ..agent.agent import Agent
@@ -30,37 +29,23 @@ class SessionManager(HookProvider, ABC):
3029
for an agent, and should be persisted in the session.
3130
"""
3231

33-
def __init__(self, session_type: SessionType = SessionType.AGENT) -> None:
34-
"""Initialize SessionManager with session type.
35-
36-
Args:
37-
session_type: Type of session (AGENT or MULTI_AGENT)
38-
"""
39-
self.session_type: SessionType = session_type
40-
4132
def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
4233
"""Register hooks for persisting the agent to the session."""
43-
if not hasattr(self, "session_type"):
44-
self.session_type = SessionType.AGENT
45-
logger.debug("Session type not set, defaulting to AGENT")
46-
47-
if self.session_type == SessionType.MULTI_AGENT:
48-
registry.add_callback(MultiAgentInitializedEvent, lambda event: self.initialize_multi_agent(event.source))
49-
registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source))
50-
registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source))
34+
# After the normal Agent initialization behavior, call the session initialize function to restore the agent
35+
registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent))
5136

52-
else:
53-
# After the normal Agent initialization behavior, call the session initialize function to restore the agent
54-
registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent))
37+
# For each message appended to the Agents messages, store that message in the session
38+
registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent))
5539

56-
# For each message appended to the Agents messages, store that message in the session
57-
registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent))
40+
# Sync the agent into the session for each message in case the agent state was updated
41+
registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent))
5842

59-
# Sync the agent into the session for each message in case the agent state was updated
60-
registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent))
43+
# After an agent was invoked, sync it with the session to capture any conversation manager state updates
44+
registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent))
6145

62-
# After an agent was invoked, sync it with the session to capture any conversation manager state updates
63-
registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent))
46+
registry.add_callback(MultiAgentInitializedEvent, lambda event: self.initialize_multi_agent(event.source))
47+
registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source))
48+
registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source))
6449

6550
@abstractmethod
6651
def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None:

src/strands/types/session.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717
class SessionType(str, Enum):
1818
"""Enumeration of session types.
1919
20-
As sessions are expanded to support new usecases like multi-agent patterns,
20+
As sessions are expanded to support new use cases like multi-agent patterns,
2121
new types will be added here.
2222
"""
2323

2424
AGENT = "AGENT"
25-
MULTI_AGENT = "MULTI_AGENT"
2625

2726

2827
def encode_bytes_values(obj: Any) -> Any:

tests/strands/session/test_file_session_manager.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,10 @@ def mock_multi_agent():
6363
return mock
6464

6565

66-
@pytest.fixture
67-
def multi_agent_session():
68-
"""Create sample multi-agent session for testing."""
69-
return Session(session_id="test-session", session_type=SessionType.MULTI_AGENT)
70-
71-
7266
@pytest.fixture
7367
def multi_agent_manager(temp_dir):
74-
"""Create FileSessionManager with multi-agent session type."""
75-
return FileSessionManager(session_id="test", storage_dir=temp_dir, session_type=SessionType.MULTI_AGENT)
68+
"""Create FileSessionManager."""
69+
return FileSessionManager(session_id="test", storage_dir=temp_dir)
7670

7771

7872
def test_create_session(file_manager, sample_session):
@@ -432,14 +426,14 @@ def test__get_message_path_invalid_message_id(message_id, file_manager):
432426
file_manager._get_message_path("session1", "agent1", message_id)
433427

434428

435-
def test_create_multi_agent(multi_agent_manager, multi_agent_session, mock_multi_agent):
429+
def test_create_multi_agent(multi_agent_manager, sample_session, mock_multi_agent):
436430
"""Test creating multi-agent state."""
437-
multi_agent_manager.create_session(multi_agent_session)
438-
multi_agent_manager.create_multi_agent(multi_agent_session.session_id, mock_multi_agent)
431+
multi_agent_manager.create_session(sample_session)
432+
multi_agent_manager.create_multi_agent(sample_session.session_id, mock_multi_agent)
439433

440434
# Verify file created
441435
multi_agent_file = os.path.join(
442-
multi_agent_manager._get_multi_agent_path(multi_agent_session.session_id, mock_multi_agent.id),
436+
multi_agent_manager._get_multi_agent_path(sample_session.session_id, mock_multi_agent.id),
443437
"multi_agent.json",
444438
)
445439
assert os.path.exists(multi_agent_file)
@@ -451,58 +445,58 @@ def test_create_multi_agent(multi_agent_manager, multi_agent_session, mock_multi
451445
assert data["state"] == mock_multi_agent.state
452446

453447

454-
def test_read_multi_agent(multi_agent_manager, multi_agent_session, mock_multi_agent):
448+
def test_read_multi_agent(multi_agent_manager, sample_session, mock_multi_agent):
455449
"""Test reading multi-agent state."""
456450
# Create session and multi-agent
457-
multi_agent_manager.create_session(multi_agent_session)
458-
multi_agent_manager.create_multi_agent(multi_agent_session.session_id, mock_multi_agent)
451+
multi_agent_manager.create_session(sample_session)
452+
multi_agent_manager.create_multi_agent(sample_session.session_id, mock_multi_agent)
459453

460454
# Read multi-agent
461-
result = multi_agent_manager.read_multi_agent(multi_agent_session.session_id, mock_multi_agent.id)
455+
result = multi_agent_manager.read_multi_agent(sample_session.session_id, mock_multi_agent.id)
462456

463457
assert result["id"] == mock_multi_agent.id
464458
assert result["state"] == mock_multi_agent.state
465459

466460

467-
def test_read_nonexistent_multi_agent(multi_agent_manager, multi_agent_session):
461+
def test_read_nonexistent_multi_agent(multi_agent_manager, sample_session):
468462
"""Test reading multi-agent state that doesn't exist."""
469-
result = multi_agent_manager.read_multi_agent(multi_agent_session.session_id, "nonexistent")
463+
result = multi_agent_manager.read_multi_agent(sample_session.session_id, "nonexistent")
470464
assert result is None
471465

472466

473-
def test_update_multi_agent(multi_agent_manager, multi_agent_session, mock_multi_agent):
467+
def test_update_multi_agent(multi_agent_manager, sample_session, mock_multi_agent):
474468
"""Test updating multi-agent state."""
475469
# Create session and multi-agent
476-
multi_agent_manager.create_session(multi_agent_session)
477-
multi_agent_manager.create_multi_agent(multi_agent_session.session_id, mock_multi_agent)
470+
multi_agent_manager.create_session(sample_session)
471+
multi_agent_manager.create_multi_agent(sample_session.session_id, mock_multi_agent)
478472

479473
updated_mock = Mock()
480474
updated_mock.id = mock_multi_agent.id
481475
updated_mock.serialize_state.return_value = {"id": mock_multi_agent.id, "state": {"updated": "value"}}
482-
multi_agent_manager.update_multi_agent(multi_agent_session.session_id, updated_mock)
476+
multi_agent_manager.update_multi_agent(sample_session.session_id, updated_mock)
483477

484478
# Verify update
485-
result = multi_agent_manager.read_multi_agent(multi_agent_session.session_id, mock_multi_agent.id)
479+
result = multi_agent_manager.read_multi_agent(sample_session.session_id, mock_multi_agent.id)
486480
assert result["state"] == {"updated": "value"}
487481

488482

489-
def test_update_nonexistent_multi_agent(multi_agent_manager, multi_agent_session):
483+
def test_update_nonexistent_multi_agent(multi_agent_manager, sample_session):
490484
"""Test updating multi-agent state that doesn't exist."""
491485
# Create session
492-
multi_agent_manager.create_session(multi_agent_session)
486+
multi_agent_manager.create_session(sample_session)
493487

494488
nonexistent_mock = Mock()
495489
nonexistent_mock.id = "nonexistent"
496490
with pytest.raises(SessionException):
497-
multi_agent_manager.update_multi_agent(multi_agent_session.session_id, nonexistent_mock)
491+
multi_agent_manager.update_multi_agent(sample_session.session_id, nonexistent_mock)
498492

499493

500-
def test_create_session_multi_agent_directory_structure(multi_agent_manager, multi_agent_session):
494+
def test_create_session_multi_agent_directory_structure(multi_agent_manager, sample_session):
501495
"""Test multi-agent session creates correct directory structure."""
502-
multi_agent_manager.create_session(multi_agent_session)
496+
multi_agent_manager.create_session(sample_session)
503497

504498
# Verify directory structure
505-
session_dir = multi_agent_manager._get_session_path(multi_agent_session.session_id)
499+
session_dir = multi_agent_manager._get_session_path(sample_session.session_id)
506500
multi_agents_dir = os.path.join(session_dir, "multi_agents")
507501

508502
assert os.path.exists(session_dir)

tests/strands/session/test_repository_session_manager.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,6 @@ def mock_multi_agent():
4444
return mock
4545

4646

47-
@pytest.fixture
48-
def multi_agent_session_manager(mock_repository):
49-
"""Create a multi-agent session manager."""
50-
return RepositorySessionManager(
51-
session_id="test-multi-session", session_repository=mock_repository, session_type=SessionType.MULTI_AGENT
52-
)
53-
54-
5547
def test_init_creates_session_if_not_exists(mock_repository):
5648
"""Test that init creates a session if it doesn't exist."""
5749
# Session doesn't exist yet
@@ -200,55 +192,44 @@ def test_append_message(session_manager):
200192
assert messages[0].message["content"][0]["text"] == "Hello"
201193

202194

203-
def test_init_multi_agent_session_type(mock_repository):
204-
"""Test creating session manager with multi-agent type."""
205-
manager = RepositorySessionManager(
206-
session_id="multi-session", session_repository=mock_repository, session_type=SessionType.MULTI_AGENT
207-
)
208-
209-
assert manager.session_type == SessionType.MULTI_AGENT
210-
session = mock_repository.read_session("multi-session")
211-
assert session.session_type == SessionType.MULTI_AGENT
212-
213-
214-
def test_sync_multi_agent(multi_agent_session_manager, mock_multi_agent):
195+
def test_sync_multi_agent(session_manager, mock_multi_agent):
215196
"""Test syncing multi-agent state."""
216197
# Create multi-agent first
217-
multi_agent_session_manager.session_repository.create_multi_agent("test-multi-session", mock_multi_agent)
198+
session_manager.session_repository.create_multi_agent("test-session", mock_multi_agent)
218199

219200
# Sync multi-agent
220-
multi_agent_session_manager.sync_multi_agent(mock_multi_agent)
201+
session_manager.sync_multi_agent(mock_multi_agent)
221202

222203
# Verify repository update_multi_agent was called
223-
state = multi_agent_session_manager.session_repository.read_multi_agent("test-multi-session", mock_multi_agent.id)
204+
state = session_manager.session_repository.read_multi_agent("test-session", mock_multi_agent.id)
224205
assert state["id"] == "test-multi-agent"
225206
assert state["state"] == {"key": "value"}
226207

227208

228-
def test_initialize_multi_agent_new(multi_agent_session_manager, mock_multi_agent):
209+
def test_initialize_multi_agent_new(session_manager, mock_multi_agent):
229210
"""Test initializing new multi-agent state."""
230-
multi_agent_session_manager.initialize_multi_agent(mock_multi_agent)
211+
session_manager.initialize_multi_agent(mock_multi_agent)
231212

232213
# Verify multi-agent was created
233-
state = multi_agent_session_manager.session_repository.read_multi_agent("test-multi-session", mock_multi_agent.id)
214+
state = session_manager.session_repository.read_multi_agent("test-session", mock_multi_agent.id)
234215
assert state["id"] == "test-multi-agent"
235216
assert state["state"] == {"key": "value"}
236217

237218

238-
def test_initialize_multi_agent_existing(multi_agent_session_manager, mock_multi_agent):
219+
def test_initialize_multi_agent_existing(session_manager, mock_multi_agent):
239220
"""Test initializing existing multi-agent state."""
240221
# Create existing state first
241-
multi_agent_session_manager.session_repository.create_multi_agent("test-multi-session", mock_multi_agent)
222+
session_manager.session_repository.create_multi_agent("test-session", mock_multi_agent)
242223

243224
# Create a mock with updated state for the update call
244225
updated_mock = Mock()
245226
updated_mock.id = "test-multi-agent"
246227
existing_state = {"id": "test-multi-agent", "state": {"restored": "data"}}
247228
updated_mock.serialize_state.return_value = existing_state
248-
multi_agent_session_manager.session_repository.update_multi_agent("test-multi-session", updated_mock)
229+
session_manager.session_repository.update_multi_agent("test-session", updated_mock)
249230

250231
# Initialize multi-agent
251-
multi_agent_session_manager.initialize_multi_agent(mock_multi_agent)
232+
session_manager.initialize_multi_agent(mock_multi_agent)
252233

253234
# Verify deserialize_state was called with existing state
254235
mock_multi_agent.deserialize_state.assert_called_once_with(existing_state)

0 commit comments

Comments
 (0)