diff --git a/stac_api/app.py b/stac_api/app.py index ebd940a5f..7a821c005 100644 --- a/stac_api/app.py +++ b/stac_api/app.py @@ -4,7 +4,6 @@ from . import settings from .resources import mgmt, collection, conformance, item -from .utils import dependencies app = FastAPI() @@ -18,18 +17,18 @@ @app.on_event("startup") async def on_startup(): """Create database engines and sessions on startup""" - dependencies.ENGINE_READER = create_engine(settings.SQLALCHEMY_DATABASE_READER) - dependencies.ENGINE_WRITER = create_engine(settings.SQLALCHEMY_DATABASE_WRITER) - dependencies.DB_READER = sessionmaker( - autocommit=False, autoflush=False, bind=dependencies.ENGINE_READER + app.state.ENGINE_READER = create_engine(settings.SQLALCHEMY_DATABASE_READER) + app.state.ENGINE_WRITER = create_engine(settings.SQLALCHEMY_DATABASE_WRITER) + app.state.DB_READER = sessionmaker( + autocommit=False, autoflush=False, bind=app.state.ENGINE_READER ) - dependencies.DB_WRITER = sessionmaker( - autocommit=False, autoflush=False, bind=dependencies.ENGINE_WRITER + app.state.DB_WRITER = sessionmaker( + autocommit=False, autoflush=False, bind=app.state.ENGINE_WRITER ) @app.on_event("shutdown") async def on_shutdown(): """Dispose of database engines and sessions on app shutdown""" - dependencies.ENGINE_READER.dispose() - dependencies.ENGINE_WRITER.dispose() + app.state.ENGINE_READER.dispose() + app.state.ENGINE_WRITER.dispose() diff --git a/stac_api/utils/dependencies.py b/stac_api/utils/dependencies.py index 892bff856..bb7e0823f 100644 --- a/stac_api/utils/dependencies.py +++ b/stac_api/utils/dependencies.py @@ -1,17 +1,10 @@ from dataclasses import dataclass -from typing import Callable, List, Optional +from typing import Callable, List -from sqlalchemy.engine import Engine from sqlalchemy.orm import Session from starlette.requests import Request -ENGINE_READER: Optional[Engine] = None -ENGINE_WRITER: Optional[Engine] = None -DB_READER: Optional[Session] = None -DB_WRITER: Optional[Session] = None - - @dataclass class DatabaseConnectionError(Exception): message: str @@ -32,27 +25,19 @@ def _parse(request: Request): return _parse -def database_reader_factory() -> Session: +def database_reader_factory(request: Request) -> Session: """Instantiate the database reader session""" try: - if not DB_READER: - raise DatabaseConnectionError( - message="Database engine has not been created" - ) - db = DB_READER() + db = request.app.state.DB_READER() yield db finally: db.close() -def database_writer_factory() -> Session: +def database_writer_factory(request: Request) -> Session: """Instantiate the database writer session""" try: - if not DB_WRITER: - raise DatabaseConnectionError( - message="Database engine has not been created" - ) - db = DB_WRITER() + db = request.app.state.DB_WRITER() yield db finally: db.close()