From 88561c26e7a71aee0687ce6ab40f55c3808fa112 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Fri, 22 Aug 2025 16:15:04 -0400 Subject: [PATCH 1/3] fix: prevent path traversal for message_id in file_session_manager --- src/strands/_identifier.py | 1 + src/strands/session/file_session_manager.py | 9 ++++++++- .../strands/session/test_file_session_manager.py | 15 +++++++++++++++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/strands/_identifier.py b/src/strands/_identifier.py index e8b12635c..e02d83473 100644 --- a/src/strands/_identifier.py +++ b/src/strands/_identifier.py @@ -9,6 +9,7 @@ class Identifier(enum.Enum): AGENT = "agent" SESSION = "session" + MESSAGE = "message" def validate(id_: str, type_: Identifier) -> str: diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 9df86e17a..20d9f74e0 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -86,9 +86,16 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> message_id: Index of the message Returns: The filename for the message + + Raises: + ValueError: If message id contains path separators. """ + # Validate message_id to prevent path traversal + message_id_str = str(message_id) + message_id_str = _identifier.validate(message_id_str, _identifier.Identifier.MESSAGE) + agent_path = self._get_agent_path(session_id, agent_id) - return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json") + return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id_str}.json") def _read_file(self, path: str) -> dict[str, Any]: """Read JSON file.""" diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index a89222b7e..a0236c420 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -390,3 +390,18 @@ def test__get_session_path_invalid_session_id(session_id, file_manager): def test__get_agent_path_invalid_agent_id(agent_id, file_manager): with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"): file_manager._get_agent_path("session1", agent_id) + + +@pytest.mark.parametrize( + "message_id", + [ + "../../../secret", + "../../attack", + "../escape", + "path/traversal", + ], +) +def test__get_message_path_invalid_message_id(message_id, file_manager): + """Test that message_id with path traversal sequences raises ValueError.""" + with pytest.raises(ValueError, match=f"message_id={message_id} | id cannot contain path separators"): + file_manager._get_message_path("session1", "agent1", message_id) From 22b2af86024d01fba56ad939b068dd7e0f311450 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Mon, 25 Aug 2025 09:55:22 -0400 Subject: [PATCH 2/3] fix: prevent path traversal for message_id in session managers --- src/strands/_identifier.py | 1 - src/strands/session/file_session_manager.py | 11 +++++------ src/strands/session/s3_session_manager.py | 7 ++++++- .../session/test_file_session_manager.py | 9 ++++++--- .../strands/session/test_s3_session_manager.py | 18 ++++++++++++++++++ 5 files changed, 35 insertions(+), 11 deletions(-) diff --git a/src/strands/_identifier.py b/src/strands/_identifier.py index e02d83473..e8b12635c 100644 --- a/src/strands/_identifier.py +++ b/src/strands/_identifier.py @@ -9,7 +9,6 @@ class Identifier(enum.Enum): AGENT = "agent" SESSION = "session" - MESSAGE = "message" def validate(id_: str, type_: Identifier) -> str: diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 20d9f74e0..14e71d07c 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -86,16 +86,15 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> message_id: Index of the message Returns: The filename for the message - + Raises: - ValueError: If message id contains path separators. + ValueError: If message_id is not an integer. """ - # Validate message_id to prevent path traversal - message_id_str = str(message_id) - message_id_str = _identifier.validate(message_id_str, _identifier.Identifier.MESSAGE) + if not isinstance(message_id, int): + raise ValueError(f"message_id=<{message_id}> | message id must be an integer") agent_path = self._get_agent_path(session_id, agent_id) - return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id_str}.json") + return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json") def _read_file(self, path: str) -> dict[str, Any]: """Read JSON file.""" diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index d15e6e3bd..da1735e35 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -113,11 +113,16 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> session_id: ID of the session agent_id: ID of the agent message_id: Index of the message - **kwargs: Additional keyword arguments for future extensibility. Returns: The key for the message + + Raises: + ValueError: If message_id is not an integer. """ + if not isinstance(message_id, int): + raise ValueError(f"message_id=<{message_id}> | message id must be an integer") + agent_path = self._get_agent_path(session_id, agent_id) return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json" diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index a0236c420..6937806ee 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -396,12 +396,15 @@ def test__get_agent_path_invalid_agent_id(agent_id, file_manager): "message_id", [ "../../../secret", - "../../attack", + "../../attack", "../escape", "path/traversal", + "not_an_int", + None, + [], ], ) def test__get_message_path_invalid_message_id(message_id, file_manager): - """Test that message_id with path traversal sequences raises ValueError.""" - with pytest.raises(ValueError, match=f"message_id={message_id} | id cannot contain path separators"): + """Test that message_id that is not an integer raises ValueError.""" + with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"): file_manager._get_message_path("session1", "agent1", message_id) diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index 71bff3050..b4384d9fc 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -356,3 +356,21 @@ def test__get_session_path_invalid_session_id(session_id, s3_manager): def test__get_agent_path_invalid_agent_id(agent_id, s3_manager): with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"): s3_manager._get_agent_path("session1", agent_id) + + +@pytest.mark.parametrize( + "message_id", + [ + "../../../secret", + "../../attack", + "../escape", + "path/traversal", + "not_an_int", + None, + [], + ], +) +def test__get_message_path_invalid_message_id(message_id, s3_manager): + """Test that message_id that is not an integer raises ValueError.""" + with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"): + s3_manager._get_message_path("session1", "agent1", message_id) From d93ea7f0d3ed186e03c4a98a898503a9fc1d7b6d Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Mon, 25 Aug 2025 09:59:31 -0400 Subject: [PATCH 3/3] fix: prevent path traversal for message_id in session managers --- tests/strands/session/test_file_session_manager.py | 4 ++-- tests/strands/session/test_s3_session_manager.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index 6937806ee..036591924 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -224,14 +224,14 @@ def test_read_messages_with_new_agent(file_manager, sample_session, sample_agent file_manager.create_session(sample_session) file_manager.create_agent(sample_session.session_id, sample_agent) - result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999) assert result is None def test_read_nonexistent_message(file_manager, sample_session, sample_agent): """Test reading a message that doesnt exist.""" - result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999) assert result is None diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index b4384d9fc..50fb303f7 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -251,7 +251,7 @@ def test_read_nonexistent_message(s3_manager, sample_session, sample_agent, samp s3_manager.create_agent(sample_session.session_id, sample_agent) # Read message - result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999) assert result is None