Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
93 changes: 72 additions & 21 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pickle
from typing import Any
from typing import Optional
from typing import overload
import uuid

from google.genai import types
Expand Down Expand Up @@ -413,36 +414,86 @@ def set_sqlite_pragma(dbapi_connection, connection_record):
class DatabaseSessionService(BaseSessionService):
"""A session service that uses a database for storage."""

def __init__(self, db_url: str, **kwargs: Any):
"""Initializes the database session service with a database URL."""
# 1. Create DB engine for db connection
# 2. Create all tables based on schema
# 3. Initialize all properties
try:
db_engine = create_async_engine(db_url, **kwargs)
@overload
def __init__(
self,
db_url: str,
**kwargs: Any,
) -> None:
"""Initializes the database session service with a database URL.

if db_engine.dialect.name == "sqlite":
# Set sqlite pragma to enable foreign keys constraints
event.listen(db_engine.sync_engine, "connect", set_sqlite_pragma)
Args:
db_url: Database URL string for creating a new engine.
**kwargs: Additional keyword arguments passed to create_async_engine.
"""

except Exception as e:
if isinstance(e, ArgumentError):
raise ValueError(
f"Invalid database URL format or argument '{db_url}'."
) from e
if isinstance(e, ImportError):
@overload
def __init__(
self,
*,
db_engine: AsyncEngine,
) -> None:
"""Initializes the database session service with an existing SQLAlchemy AsyncEngine.

Args:
db_engine: Existing SQLAlchemy AsyncEngine instance to use.
"""

def __init__(
self,
db_url: Optional[str] = None,
db_engine: Optional[AsyncEngine] = None,
**kwargs: Any,
) -> None:
"""Initializes the database session service.

Args:
db_url: Database URL string for creating a new engine. Mutually exclusive
with db_engine.
db_engine: Existing AsyncEngine instance. Mutually exclusive with db_url.
**kwargs: Additional keyword arguments passed to create_async_engine when
db_url is provided. Ignored when db_engine is provided.

Raises:
ValueError: If neither or both db_url and db_engine are provided, or if
engine creation fails.
"""
if (db_url is None) == (db_engine is None):
raise ValueError(
"Exactly one of 'db_url' or 'db_engine' must be provided."
)

# 1. Create or use provided DB engine for db connection
# 2. Create all tables based on schema
# 3. Initialize all properties
if db_engine is not None:
engine = db_engine
else:
try:
engine = create_async_engine(db_url, **kwargs)

if engine.dialect.name == "sqlite":
# Set sqlite pragma to enable foreign keys constraints
event.listen(engine.sync_engine, "connect", set_sqlite_pragma)

except Exception as e:
if isinstance(e, ArgumentError):
raise ValueError(
f"Invalid database URL format or argument '{db_url}'."
) from e
if isinstance(e, ImportError):
raise ValueError(
f"Database related module not found for URL '{db_url}'."
) from e
raise ValueError(
f"Database related module not found for URL '{db_url}'."
f"Failed to create database engine for URL '{db_url}'"
) from e
raise ValueError(
f"Failed to create database engine for URL '{db_url}'"
) from e

# Get the local timezone
local_timezone = get_localzone()
logger.info("Local timezone: %s", local_timezone)

self.db_engine: AsyncEngine = db_engine
self.db_engine: AsyncEngine = engine
self.metadata: MetaData = MetaData()

# DB session factory method
Expand Down
Loading