diff --git a/fastapi_async_sqlalchemy/__init__.py b/fastapi_async_sqlalchemy/__init__.py index 2821653..8afd39f 100644 --- a/fastapi_async_sqlalchemy/__init__.py +++ b/fastapi_async_sqlalchemy/__init__.py @@ -2,4 +2,4 @@ __all__ = ["db", "SQLAlchemyMiddleware"] -__version__ = "0.7.0.dev2" +__version__ = "0.7.0.dev3" diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py index 13cb359..1171ede 100644 --- a/fastapi_async_sqlalchemy/middleware.py +++ b/fastapi_async_sqlalchemy/middleware.py @@ -21,9 +21,6 @@ def create_middleware_and_session_proxy(): _Session: Optional[async_sessionmaker] = None _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None) _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False) - _task_session_ctx: ContextVar[Optional[AsyncSession]] = ContextVar( - "_task_session_ctx", default=None - ) _commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False) # Usage of context vars inside closures is not recommended, since they are not properly # garbage collected, but in our use case context var is created on program startup and @@ -92,25 +89,22 @@ async def execute_query(query): ``` """ commit_on_exit = _commit_on_exit_ctx.get() - session = _task_session_ctx.get() - if session is None: - session = _Session() - _task_session_ctx.set(session) - - async def cleanup(): - try: - if commit_on_exit: - await session.commit() - except Exception: - await session.rollback() - raise - finally: - await session.close() - _task_session_ctx.set(None) - - task = asyncio.current_task() - if task is not None: - task.add_done_callback(lambda t: asyncio.create_task(cleanup())) + # Always create a new session for each access when multi_sessions=True + session = _Session() + + async def cleanup(): + try: + if commit_on_exit: + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() + + task = asyncio.current_task() + if task is not None: + task.add_done_callback(lambda t: asyncio.create_task(cleanup())) return session else: session = _session.get() @@ -126,7 +120,6 @@ def __init__( multi_sessions: bool = False, ): self.token = None - self.multi_sessions_token = None self.commit_on_exit_token = None self.session_args = session_args or {} self.commit_on_exit = commit_on_exit