Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/ref/extensions/memory/sqlalchemy_session.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# `SQLAlchemySession`

::: agents.extensions.memory.sqlalchemy_session.SQLAlchemySession
4 changes: 3 additions & 1 deletion docs/sessions.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,8 @@ Use meaningful session IDs that help you organize conversations:

- Use in-memory SQLite (`SQLiteSession("session_id")`) for temporary conversations
- Use file-based SQLite (`SQLiteSession("session_id", "path/to/db.sqlite")`) for persistent conversations
- Consider implementing custom session backends for production systems (Redis, PostgreSQL, etc.)
- Use SQLAlchemy-powered sessions (`SQLAlchemySession("session_id", engine=engine, create_tables=True)`) for production systems with existing databases supported by SQLAlchemy
- Consider implementing custom session backends for other production systems (Redis, Django, etc.) for more advanced use cases

### Session management

Expand Down Expand Up @@ -376,3 +377,4 @@ For detailed API documentation, see:

- [`Session`][agents.memory.Session] - Protocol interface
- [`SQLiteSession`][agents.memory.SQLiteSession] - SQLite implementation
- [`SQLAlchemySession`][agents.extensions.memory.sqlalchemy_session.SQLAlchemySession] - SQLAlchemy-powered implementation
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ plugins:
- ref/extensions/handoff_filters.md
- ref/extensions/handoff_prompt.md
- ref/extensions/litellm.md
- ref/extensions/memory/sqlalchemy_session.md

- locale: ja
name: 日本語
Expand Down
76 changes: 45 additions & 31 deletions src/agents/extensions/memory/sqlalchemy_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,19 @@ def __init__(
create_tables: bool = False,
sessions_table: str = "agent_sessions",
messages_table: str = "agent_messages",
): # noqa: D401 – short description on the class-level docstring
"""Create a new session.

Parameters
----------
session_id
Unique identifier for the conversation.
engine
A pre-configured SQLAlchemy *async* engine. The engine **must** be
created with an async driver (``postgresql+asyncpg://``,
``mysql+aiomysql://`` or ``sqlite+aiosqlite://``).
create_tables
Whether to automatically create the required tables & indexes.
Defaults to *False* for production use. Set to *True* for development
and testing when migrations aren't used.
sessions_table, messages_table
Override default table names if needed.
):
"""Initializes a new SQLAlchemySession.

Args:
session_id (str): Unique identifier for the conversation.
engine (AsyncEngine): A pre-configured SQLAlchemy async engine. The engine
must be created with an async driver (e.g., 'postgresql+asyncpg://',
'mysql+aiomysql://', or 'sqlite+aiosqlite://').
create_tables (bool, optional): Whether to automatically create the required
tables and indexes. Defaults to False for production use. Set to True for
development and testing when migrations aren't used.
sessions_table (str, optional): Override the default table name for sessions if needed.
messages_table (str, optional): Override the default table name for messages if needed.
"""
self.session_id = session_id
self._engine = engine
Expand Down Expand Up @@ -132,9 +128,7 @@ def __init__(
)

# Async session factory
self._session_factory = async_sessionmaker(
self._engine, expire_on_commit=False
)
self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)

self._create_tables = create_tables

Expand All @@ -152,16 +146,16 @@ def from_url(
) -> SQLAlchemySession:
"""Create a session from a database URL string.

Parameters
----------
session_id
Conversation ID.
url
Any SQLAlchemy async URL – e.g. ``"postgresql+asyncpg://user:pass@host/db"``.
engine_kwargs
Additional kwargs forwarded to :pyfunc:`sqlalchemy.ext.asyncio.create_async_engine`.
kwargs
Forwarded to the main constructor (``create_tables``, custom table names, …).
Args:
session_id (str): Conversation ID.
url (str): Any SQLAlchemy async URL, e.g. "postgresql+asyncpg://user:pass@host/db".
engine_kwargs (dict[str, Any] | None): Additional keyword arguments forwarded to
sqlalchemy.ext.asyncio.create_async_engine.
**kwargs: Additional keyword arguments forwarded to the main constructor
(e.g., create_tables, custom table names, etc.).

Returns:
SQLAlchemySession: An instance of SQLAlchemySession connected to the specified database.
"""
engine_kwargs = engine_kwargs or {}
engine = create_async_engine(url, **engine_kwargs)
Expand All @@ -186,6 +180,15 @@ async def _ensure_tables(self) -> None:
self._create_tables = False # Only create once

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
limit: Maximum number of items to retrieve. If None, retrieves all items.
When specified, returns the latest N items in chronological order.

Returns:
List of input items representing the conversation history
"""
await self._ensure_tables()
async with self._session_factory() as sess:
if limit is None:
Expand Down Expand Up @@ -220,6 +223,11 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
"""Add new items to the conversation history.

Args:
items: List of input items to add to the history
"""
if not items:
return

Expand Down Expand Up @@ -258,6 +266,11 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
)

async def pop_item(self) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.

Returns:
The most recent item if it exists, None if the session is empty
"""
await self._ensure_tables()
async with self._session_factory() as sess:
async with sess.begin():
Expand Down Expand Up @@ -286,7 +299,8 @@ async def pop_item(self) -> TResponseInputItem | None:
except json.JSONDecodeError:
return None

async def clear_session(self) -> None: # noqa: D401 – imperative mood is fine
async def clear_session(self) -> None:
"""Clear all items for this session."""
await self._ensure_tables()
async with self._session_factory() as sess:
async with sess.begin():
Expand Down