diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 2f5d03a772..f68b349d9c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index cfd850b3a3..1bc4ee58c8 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 91c22fd21e..8865bd68c6 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -22,6 +22,7 @@ import pickle from typing import Any from typing import Optional +from typing import overload import uuid from google.genai import types @@ -413,36 +414,86 @@ def set_sqlite_pragma(dbapi_connection, connection_record): class DatabaseSessionService(BaseSessionService): """A session service that uses a database for storage.""" - def __init__(self, db_url: str, **kwargs: Any): - """Initializes the database session service with a database URL.""" - # 1. Create DB engine for db connection - # 2. Create all tables based on schema - # 3. Initialize all properties - try: - db_engine = create_async_engine(db_url, **kwargs) + @overload + def __init__( + self, + db_url: str, + **kwargs: Any, + ) -> None: + """Initializes the database session service with a database URL. - if db_engine.dialect.name == "sqlite": - # Set sqlite pragma to enable foreign keys constraints - event.listen(db_engine.sync_engine, "connect", set_sqlite_pragma) + Args: + db_url: Database URL string for creating a new engine. + **kwargs: Additional keyword arguments passed to create_async_engine. + """ - except Exception as e: - if isinstance(e, ArgumentError): - raise ValueError( - f"Invalid database URL format or argument '{db_url}'." - ) from e - if isinstance(e, ImportError): + @overload + def __init__( + self, + *, + db_engine: AsyncEngine, + ) -> None: + """Initializes the database session service with an existing SQLAlchemy AsyncEngine. + + Args: + db_engine: Existing SQLAlchemy AsyncEngine instance to use. + """ + + def __init__( + self, + db_url: Optional[str] = None, + db_engine: Optional[AsyncEngine] = None, + **kwargs: Any, + ) -> None: + """Initializes the database session service. + + Args: + db_url: Database URL string for creating a new engine. Mutually exclusive + with db_engine. + db_engine: Existing AsyncEngine instance. Mutually exclusive with db_url. + **kwargs: Additional keyword arguments passed to create_async_engine when + db_url is provided. Ignored when db_engine is provided. + + Raises: + ValueError: If neither or both db_url and db_engine are provided, or if + engine creation fails. + """ + if (db_url is None) == (db_engine is None): + raise ValueError( + "Exactly one of 'db_url' or 'db_engine' must be provided." + ) + + # 1. Create or use provided DB engine for db connection + # 2. Create all tables based on schema + # 3. Initialize all properties + if db_engine is not None: + engine = db_engine + else: + try: + engine = create_async_engine(db_url, **kwargs) + + if engine.dialect.name == "sqlite": + # Set sqlite pragma to enable foreign keys constraints + event.listen(engine.sync_engine, "connect", set_sqlite_pragma) + + except Exception as e: + if isinstance(e, ArgumentError): + raise ValueError( + f"Invalid database URL format or argument '{db_url}'." + ) from e + if isinstance(e, ImportError): + raise ValueError( + f"Database related module not found for URL '{db_url}'." + ) from e raise ValueError( - f"Database related module not found for URL '{db_url}'." + f"Failed to create database engine for URL '{db_url}'" ) from e - raise ValueError( - f"Failed to create database engine for URL '{db_url}'" - ) from e # Get the local timezone local_timezone = get_localzone() logger.info("Local timezone: %s", local_timezone) - self.db_engine: AsyncEngine = db_engine + self.db_engine: AsyncEngine = engine self.metadata: MetaData = MetaData() # DB session factory method diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 7fb91c9db6..facda562bc 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -25,12 +25,13 @@ from google.adk.sessions.sqlite_session_service import SqliteSessionService from google.genai import types import pytest +from sqlalchemy.ext.asyncio import create_async_engine class SessionServiceType(enum.Enum): - IN_MEMORY = 'IN_MEMORY' - DATABASE = 'DATABASE' - SQLITE = 'SQLITE' + IN_MEMORY = "IN_MEMORY" + DATABASE = "DATABASE" + SQLITE = "SQLITE" def get_session_service( @@ -39,15 +40,16 @@ def get_session_service( ): """Creates a session service for testing.""" if service_type == SessionServiceType.DATABASE: - return DatabaseSessionService('sqlite+aiosqlite:///:memory:') + # Using positional argument to test backward compatibility + return DatabaseSessionService("sqlite+aiosqlite:///:memory:") if service_type == SessionServiceType.SQLITE: - return SqliteSessionService(str(tmp_path / 'sqlite.db')) + return SqliteSessionService(str(tmp_path / "sqlite.db")) return InMemorySessionService() @pytest.mark.asyncio @pytest.mark.parametrize( - 'service_type', + "service_type", [ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, @@ -57,13 +59,13 @@ def get_session_service( async def test_get_empty_session(service_type, tmp_path): session_service = get_session_service(service_type, tmp_path) assert not await session_service.get_session( - app_name='my_app', user_id='test_user', session_id='123' + app_name="my_app", user_id="test_user", session_id="123" ) @pytest.mark.asyncio @pytest.mark.parametrize( - 'service_type', + "service_type", [ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, @@ -72,9 +74,9 @@ async def test_get_empty_session(service_type, tmp_path): ) async def test_create_get_session(service_type, tmp_path): session_service = get_session_service(service_type, tmp_path) - app_name = 'my_app' - user_id = 'test_user' - state = {'key': 'value'} + app_name = "my_app" + user_id = "test_user" + state = {"key": "value"} session = await session_service.create_session( app_name=app_name, user_id=user_id, state=state @@ -112,7 +114,7 @@ async def test_create_get_session(service_type, tmp_path): @pytest.mark.asyncio @pytest.mark.parametrize( - 'service_type', + "service_type", [ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, @@ -121,16 +123,16 @@ async def test_create_get_session(service_type, tmp_path): ) async def test_create_and_list_sessions(service_type, tmp_path): session_service = get_session_service(service_type, tmp_path) - app_name = 'my_app' - user_id = 'test_user' + app_name = "my_app" + user_id = "test_user" - session_ids = ['session' + str(i) for i in range(5)] + session_ids = ["session" + str(i) for i in range(5)] for session_id in session_ids: await session_service.create_session( app_name=app_name, user_id=user_id, session_id=session_id, - state={'key': 'value' + session_id}, + state={"key": "value" + session_id}, ) list_sessions_response = await session_service.list_sessions( @@ -140,12 +142,12 @@ async def test_create_and_list_sessions(service_type, tmp_path): assert len(sessions) == len(session_ids) assert {s.id for s in sessions} == set(session_ids) for session in sessions: - assert session.state == {'key': 'value' + session.id} + assert session.state == {"key": "value" + session.id} @pytest.mark.asyncio @pytest.mark.parametrize( - 'service_type', + "service_type", [ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, @@ -154,27 +156,27 @@ async def test_create_and_list_sessions(service_type, tmp_path): ) async def test_list_sessions_all_users(service_type, tmp_path): session_service = get_session_service(service_type, tmp_path) - app_name = 'my_app' - user_id_1 = 'user1' - user_id_2 = 'user2' + app_name = "my_app" + user_id_1 = "user1" + user_id_2 = "user2" await session_service.create_session( app_name=app_name, user_id=user_id_1, - session_id='session1a', - state={'key': 'value1a'}, + session_id="session1a", + state={"key": "value1a"}, ) await session_service.create_session( app_name=app_name, user_id=user_id_1, - session_id='session1b', - state={'key': 'value1b'}, + session_id="session1b", + state={"key": "value1b"}, ) await session_service.create_session( app_name=app_name, user_id=user_id_2, - session_id='session2a', - state={'key': 'value2a'}, + session_id="session2a", + state={"key": "value2a"}, ) # List sessions for user1 - should contain merged state @@ -184,8 +186,8 @@ async def test_list_sessions_all_users(service_type, tmp_path): sessions_1 = list_sessions_response_1.sessions assert len(sessions_1) == 2 sessions_1_map = {s.id: s for s in sessions_1} - assert sessions_1_map['session1a'].state == {'key': 'value1a'} - assert sessions_1_map['session1b'].state == {'key': 'value1b'} + assert sessions_1_map["session1a"].state == {"key": "value1a"} + assert sessions_1_map["session1b"].state == {"key": "value1b"} # List sessions for user2 - should contain merged state list_sessions_response_2 = await session_service.list_sessions( @@ -193,8 +195,8 @@ async def test_list_sessions_all_users(service_type, tmp_path): ) sessions_2 = list_sessions_response_2.sessions assert len(sessions_2) == 1 - assert sessions_2[0].id == 'session2a' - assert sessions_2[0].state == {'key': 'value2a'} + assert sessions_2[0].id == "session2a" + assert sessions_2[0].state == {"key": "value2a"} # List sessions for all users - should contain merged state list_sessions_response_all = await session_service.list_sessions( @@ -203,14 +205,14 @@ async def test_list_sessions_all_users(service_type, tmp_path): sessions_all = list_sessions_response_all.sessions assert len(sessions_all) == 3 sessions_all_map = {s.id: s for s in sessions_all} - assert sessions_all_map['session1a'].state == {'key': 'value1a'} - assert sessions_all_map['session1b'].state == {'key': 'value1b'} - assert sessions_all_map['session2a'].state == {'key': 'value2a'} + assert sessions_all_map["session1a"].state == {"key": "value1a"} + assert sessions_all_map["session1b"].state == {"key": "value1b"} + assert sessions_all_map["session2a"].state == {"key": "value2a"} @pytest.mark.asyncio @pytest.mark.parametrize( - 'service_type', + "service_type", [ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, @@ -219,36 +221,36 @@ async def test_list_sessions_all_users(service_type, tmp_path): ) async def test_app_state_is_shared_by_all_users_of_app(service_type, tmp_path): session_service = get_session_service(service_type, tmp_path) - app_name = 'my_app' + app_name = "my_app" # User 1 creates a session, establishing app:k1 session1 = await session_service.create_session( - app_name=app_name, user_id='u1', session_id='s1', state={'app:k1': 'v1'} + app_name=app_name, user_id="u1", session_id="s1", state={"app:k1": "v1"} ) # User 1 appends an event to session1, establishing app:k2 event = Event( - invocation_id='inv1', - author='user', - actions=EventActions(state_delta={'app:k2': 'v2'}), + invocation_id="inv1", + author="user", + actions=EventActions(state_delta={"app:k2": "v2"}), ) await session_service.append_event(session=session1, event=event) # User 2 creates a new session session2, it should see app:k1 and app:k2 session2 = await session_service.create_session( - app_name=app_name, user_id='u2', session_id='s2' + app_name=app_name, user_id="u2", session_id="s2" ) - assert session2.state == {'app:k1': 'v1', 'app:k2': 'v2'} + assert session2.state == {"app:k1": "v1", "app:k2": "v2"} # If we get session session1 again, it should also see both session1_got = await session_service.get_session( - app_name=app_name, user_id='u1', session_id='s1' + app_name=app_name, user_id="u1", session_id="s1" ) - assert session1_got.state.get('app:k1') == 'v1' - assert session1_got.state.get('app:k2') == 'v2' + assert session1_got.state.get("app:k1") == "v1" + assert session1_got.state.get("app:k2") == "v2" @pytest.mark.asyncio @pytest.mark.parametrize( - 'service_type', + "service_type", [ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, @@ -259,35 +261,35 @@ async def test_user_state_is_shared_only_by_user_sessions( service_type, tmp_path ): session_service = get_session_service(service_type, tmp_path) - app_name = 'my_app' + app_name = "my_app" # User 1 creates a session, establishing user:k1 for user 1 session1 = await session_service.create_session( - app_name=app_name, user_id='u1', session_id='s1', state={'user:k1': 'v1'} + app_name=app_name, user_id="u1", session_id="s1", state={"user:k1": "v1"} ) # User 1 appends an event to session1, establishing user:k2 for user 1 event = Event( - invocation_id='inv1', - author='user', - actions=EventActions(state_delta={'user:k2': 'v2'}), + invocation_id="inv1", + author="user", + actions=EventActions(state_delta={"user:k2": "v2"}), ) await session_service.append_event(session=session1, event=event) # Another session for User 1 should see user:k1 and user:k2 session1b = await session_service.create_session( - app_name=app_name, user_id='u1', session_id='s1b' + app_name=app_name, user_id="u1", session_id="s1b" ) - assert session1b.state == {'user:k1': 'v1', 'user:k2': 'v2'} + assert session1b.state == {"user:k1": "v1", "user:k2": "v2"} # A session for User 2 should NOT see user:k1 or user:k2 session2 = await session_service.create_session( - app_name=app_name, user_id='u2', session_id='s2' + app_name=app_name, user_id="u2", session_id="s2" ) assert session2.state == {} @pytest.mark.asyncio @pytest.mark.parametrize( - 'service_type', + "service_type", [ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, @@ -296,36 +298,36 @@ async def test_user_state_is_shared_only_by_user_sessions( ) async def test_session_state_is_not_shared(service_type, tmp_path): session_service = get_session_service(service_type, tmp_path) - app_name = 'my_app' + app_name = "my_app" # User 1 creates a session session1, establishing sk1 only for session1 session1 = await session_service.create_session( - app_name=app_name, user_id='u1', session_id='s1', state={'sk1': 'v1'} + app_name=app_name, user_id="u1", session_id="s1", state={"sk1": "v1"} ) # User 1 appends an event to session1, establishing sk2 only for session1 event = Event( - invocation_id='inv1', - author='user', - actions=EventActions(state_delta={'sk2': 'v2'}), + invocation_id="inv1", + author="user", + actions=EventActions(state_delta={"sk2": "v2"}), ) await session_service.append_event(session=session1, event=event) # Getting session1 should show sk1 and sk2 session1_got = await session_service.get_session( - app_name=app_name, user_id='u1', session_id='s1' + app_name=app_name, user_id="u1", session_id="s1" ) - assert session1_got.state.get('sk1') == 'v1' - assert session1_got.state.get('sk2') == 'v2' + assert session1_got.state.get("sk1") == "v1" + assert session1_got.state.get("sk2") == "v2" # Creating another session session1b for User 1 should NOT see sk1 or sk2 session1b = await session_service.create_session( - app_name=app_name, user_id='u1', session_id='s1b' + app_name=app_name, user_id="u1", session_id="s1b" ) assert session1b.state == {} @pytest.mark.asyncio @pytest.mark.parametrize( - 'service_type', + "service_type", [ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, @@ -336,33 +338,33 @@ async def test_temp_state_is_not_persisted_in_state_or_events( service_type, tmp_path ): session_service = get_session_service(service_type, tmp_path) - app_name = 'my_app' - user_id = 'u1' + app_name = "my_app" + user_id = "u1" session = await session_service.create_session( - app_name=app_name, user_id=user_id, session_id='s1' + app_name=app_name, user_id=user_id, session_id="s1" ) event = Event( - invocation_id='inv1', - author='user', - actions=EventActions(state_delta={'temp:k1': 'v1', 'sk': 'v2'}), + invocation_id="inv1", + author="user", + actions=EventActions(state_delta={"temp:k1": "v1", "sk": "v2"}), ) await session_service.append_event(session=session, event=event) # Refetch session and check state and event session_got = await session_service.get_session( - app_name=app_name, user_id=user_id, session_id='s1' + app_name=app_name, user_id=user_id, session_id="s1" ) # Check session state does not contain temp keys - assert session_got.state.get('sk') == 'v2' - assert 'temp:k1' not in session_got.state + assert session_got.state.get("sk") == "v2" + assert "temp:k1" not in session_got.state # Check event as stored in session does not contain temp keys in state_delta - assert 'temp:k1' not in session_got.events[0].actions.state_delta - assert session_got.events[0].actions.state_delta.get('sk') == 'v2' + assert "temp:k1" not in session_got.events[0].actions.state_delta + assert session_got.events[0].actions.state_delta.get("sk") == "v2" @pytest.mark.asyncio @pytest.mark.parametrize( - 'service_type', + "service_type", [ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, @@ -371,29 +373,29 @@ async def test_temp_state_is_not_persisted_in_state_or_events( ) async def test_get_session_respects_user_id(service_type, tmp_path): session_service = get_session_service(service_type, tmp_path) - app_name = 'my_app' + app_name = "my_app" # u1 creates session 's1' and adds an event session1 = await session_service.create_session( - app_name=app_name, user_id='u1', session_id='s1' + app_name=app_name, user_id="u1", session_id="s1" ) - event = Event(invocation_id='inv1', author='user') + event = Event(invocation_id="inv1", author="user") await session_service.append_event(session1, event) # u2 creates a session with the same session_id 's1' await session_service.create_session( - app_name=app_name, user_id='u2', session_id='s1' + app_name=app_name, user_id="u2", session_id="s1" ) # Check that getting s1 for u2 returns u2's session (with no events) # not u1's session. session2_got = await session_service.get_session( - app_name=app_name, user_id='u2', session_id='s1' + app_name=app_name, user_id="u2", session_id="s1" ) - assert session2_got.user_id == 'u2' + assert session2_got.user_id == "u2" assert len(session2_got.events) == 0 @pytest.mark.asyncio @pytest.mark.parametrize( - 'service_type', + "service_type", [ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, @@ -404,9 +406,9 @@ async def test_create_session_with_existing_id_raises_error( service_type, tmp_path ): session_service = get_session_service(service_type, tmp_path) - app_name = 'my_app' - user_id = 'test_user' - session_id = 'existing_session' + app_name = "my_app" + user_id = "test_user" + session_id = "existing_session" # Create the first session await session_service.create_session( @@ -426,7 +428,7 @@ async def test_create_session_with_existing_id_raises_error( @pytest.mark.asyncio @pytest.mark.parametrize( - 'service_type', + "service_type", [ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, @@ -435,25 +437,25 @@ async def test_create_session_with_existing_id_raises_error( ) async def test_append_event_bytes(service_type, tmp_path): session_service = get_session_service(service_type, tmp_path) - app_name = 'my_app' - user_id = 'user' + app_name = "my_app" + user_id = "user" session = await session_service.create_session( app_name=app_name, user_id=user_id ) test_content = types.Content( - role='user', + role="user", parts=[ - types.Part.from_bytes(data=b'test_image_data', mime_type='image/png'), + types.Part.from_bytes(data=b"test_image_data", mime_type="image/png"), ], ) test_grounding_metadata = types.GroundingMetadata( - search_entry_point=types.SearchEntryPoint(sdk_blob=b'test_sdk_blob') + search_entry_point=types.SearchEntryPoint(sdk_blob=b"test_sdk_blob") ) event = Event( - invocation_id='invocation', - author='user', + invocation_id="invocation", + author="user", content=test_content, grounding_metadata=test_grounding_metadata, ) @@ -472,7 +474,7 @@ async def test_append_event_bytes(service_type, tmp_path): @pytest.mark.asyncio @pytest.mark.parametrize( - 'service_type', + "service_type", [ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, @@ -481,37 +483,37 @@ async def test_append_event_bytes(service_type, tmp_path): ) async def test_append_event_complete(service_type, tmp_path): session_service = get_session_service(service_type, tmp_path) - app_name = 'my_app' - user_id = 'user' + app_name = "my_app" + user_id = "user" session = await session_service.create_session( app_name=app_name, user_id=user_id ) event = Event( - invocation_id='invocation', - author='user', - content=types.Content(role='user', parts=[types.Part(text='test_text')]), + invocation_id="invocation", + author="user", + content=types.Content(role="user", parts=[types.Part(text="test_text")]), turn_complete=True, partial=False, actions=EventActions( artifact_delta={ - 'file': 0, + "file": 0, }, - transfer_to_agent='agent', + transfer_to_agent="agent", escalate=True, ), - long_running_tool_ids={'tool1'}, - error_code='error_code', - error_message='error_message', + long_running_tool_ids={"tool1"}, + error_code="error_code", + error_message="error_message", interrupted=True, grounding_metadata=types.GroundingMetadata( - web_search_queries=['query1'], + web_search_queries=["query1"], ), usage_metadata=types.GenerateContentResponseUsageMetadata( prompt_token_count=1, candidates_token_count=1, total_token_count=2 ), citation_metadata=types.CitationMetadata(), - custom_metadata={'custom_key': 'custom_value'}, + custom_metadata={"custom_key": "custom_value"}, ) await session_service.append_event(session=session, event=event) @@ -525,7 +527,7 @@ async def test_append_event_complete(service_type, tmp_path): @pytest.mark.asyncio @pytest.mark.parametrize( - 'service_type', + "service_type", [ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, @@ -534,15 +536,15 @@ async def test_append_event_complete(service_type, tmp_path): ) async def test_get_session_with_config(service_type, tmp_path): session_service = get_session_service(service_type, tmp_path) - app_name = 'my_app' - user_id = 'user' + app_name = "my_app" + user_id = "user" num_test_events = 5 session = await session_service.create_session( app_name=app_name, user_id=user_id ) for i in range(1, num_test_events + 1): - event = Event(author='user', timestamp=i) + event = Event(author="user", timestamp=i) await session_service.append_event(session, event) # No config, expect all events to be returned. @@ -594,7 +596,7 @@ async def test_get_session_with_config(service_type, tmp_path): @pytest.mark.asyncio @pytest.mark.parametrize( - 'service_type', + "service_type", [ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, @@ -603,12 +605,12 @@ async def test_get_session_with_config(service_type, tmp_path): ) async def test_partial_events_are_not_persisted(service_type, tmp_path): session_service = get_session_service(service_type, tmp_path) - app_name = 'my_app' - user_id = 'user' + app_name = "my_app" + user_id = "user" session = await session_service.create_session( app_name=app_name, user_id=user_id ) - event = Event(author='user', partial=True) + event = Event(author="user", partial=True) await session_service.append_event(session, event) # Check in-memory session @@ -618,3 +620,80 @@ async def test_partial_events_are_not_persisted(service_type, tmp_path): app_name=app_name, user_id=user_id, session_id=session.id ) assert len(session_got.events) == 0 + + +@pytest.mark.asyncio +async def test_database_session_service_with_db_url(): + """Test DatabaseSessionService initialization with db_url.""" + # Test db_url as positional argument + service = DatabaseSessionService("sqlite+aiosqlite:///:memory:") + app_name = "test_app" + user_id = "test_user" + + # Create and retrieve a session + session = await service.create_session( + app_name=app_name, user_id=user_id, state={"key": "value"} + ) + assert session.app_name == app_name + assert session.user_id == user_id + assert session.state == {"key": "value"} + + # Let's check that we can retrieve it + retrieved = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert retrieved == session + + # test db_url as keyword argument + service2 = DatabaseSessionService(db_url="sqlite+aiosqlite:///:memory:") + session2 = await service2.create_session( + app_name=app_name, user_id=user_id, state={"key": "value2"} + ) + assert session2.state == {"key": "value2"} + + +@pytest.mark.asyncio +async def test_database_session_service_with_db_engine(): + """Test DatabaseSessionService initialization with db_engine.""" + # Create an engine manually + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + + # Create service with db_engine + service = DatabaseSessionService(db_engine=engine) + app_name = "test_app" + user_id = "test_user" + + # Create and retrieve a session + session = await service.create_session( + app_name=app_name, user_id=user_id, state={"key": "value"} + ) + assert session.app_name == app_name + assert session.user_id == user_id + assert session.state == {"key": "value"} + + # Let's check that we can retrieve it + retrieved = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert retrieved == session + + +@pytest.mark.asyncio +async def test_database_session_service_requires_one_argument(): + """Test that DatabaseSessionService requires exactly one of db_url or db_engine.""" + # Neither argument provided + with pytest.raises( + ValueError, + match="Exactly one of 'db_url' or 'db_engine' must be provided", + ): + DatabaseSessionService() + + # Both arguments provided + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + with pytest.raises( + ValueError, + match="Exactly one of 'db_url' or 'db_engine' must be provided", + ): + DatabaseSessionService( + db_url="sqlite+aiosqlite:///:memory:", db_engine=engine + )