diff --git a/infrastructure/app.py b/infrastructure/app.py index 0e15a9a..f9e0890 100644 --- a/infrastructure/app.py +++ b/infrastructure/app.py @@ -181,6 +181,7 @@ def __init__( "handler": "handler.handler", "runtime": aws_lambda.Runtime.PYTHON_3_12, }, + enable_snap_start=True, ) ####################################################################### @@ -228,6 +229,7 @@ def __init__( "handler": "handler.handler", "runtime": aws_lambda.Runtime.PYTHON_3_12, }, + enable_snap_start=True, ) ####################################################################### @@ -274,6 +276,7 @@ def __init__( "handler": "handler.handler", "runtime": aws_lambda.Runtime.PYTHON_3_12, }, + enable_snap_start=True, ) if app_config.stac_ingestor: diff --git a/infrastructure/handlers/raster_handler.py b/infrastructure/handlers/raster_handler.py index 439a2de..2e46e51 100644 --- a/infrastructure/handlers/raster_handler.py +++ b/infrastructure/handlers/raster_handler.py @@ -7,16 +7,70 @@ from eoapi.raster.app import app from eoapi.raster.config import PostgresSettings from mangum import Mangum +from snapshot_restore_py import register_after_restore, register_before_snapshot from titiler.pgstac.db import connect_to_db logging.getLogger("mangum.lifespan").setLevel(logging.ERROR) logging.getLogger("mangum.http").setLevel(logging.ERROR) +postgres_settings = PostgresSettings() +_connection_initialized = False + + +@register_before_snapshot +def on_snapshot(): + """ + Runtime hook called by Lambda before taking a snapshot. + We close database connections that shouldn't be in the snapshot. + """ + + if hasattr(app, "state") and hasattr(app.state, "dbpool") and app.state.dbpool: + try: + app.state.dbpool.close() + app.state.dbpool = None + except Exception as e: + print(f"SnapStart: Error closing database pool: {e}") + + return {"statusCode": 200} + + +@register_after_restore +def on_snap_restore(): + """ + Runtime hook called by Lambda after restoring from a snapshot. + We recreate database connections that were closed before the snapshot. + """ + global _connection_initialized + + try: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + if hasattr(app.state, "dbpool") and app.state.dbpool: + try: + app.state.dbpool.close() + except Exception as e: + print(f"SnapStart: Error closing stale pool: {e}") + app.state.dbpool = None + + loop.run_until_complete(connect_to_db(app, settings=postgres_settings)) + + _connection_initialized = True + + except Exception as e: + print(f"SnapStart: Failed to initialize database connection: {e}") + raise + + return {"statusCode": 200} + @app.on_event("startup") async def startup_event() -> None: """Connect to database on startup.""" - await connect_to_db(app, settings=PostgresSettings()) + await connect_to_db(app, settings=postgres_settings) handler = Mangum(app, lifespan="off") diff --git a/infrastructure/handlers/stac_handler.py b/infrastructure/handlers/stac_handler.py index b7f763e..fee7706 100644 --- a/infrastructure/handlers/stac_handler.py +++ b/infrastructure/handlers/stac_handler.py @@ -7,12 +7,90 @@ from eoapi.stac.app import app from eoapi.stac.config import PostgresSettings, Settings from mangum import Mangum +from snapshot_restore_py import register_after_restore, register_before_snapshot from stac_fastapi.pgstac.db import connect_to_db logging.getLogger("mangum.lifespan").setLevel(logging.ERROR) logging.getLogger("mangum.http").setLevel(logging.ERROR) settings = Settings() +postgres_settings = PostgresSettings() + +_connection_initialized = False + + +@register_before_snapshot +def on_snapshot(): + """ + Runtime hook called by Lambda before taking a snapshot. + We close database connections that shouldn't be in the snapshot. + """ + + if hasattr(app, "state") and hasattr(app.state, "readpool") and app.state.readpool: + try: + app.state.readpool.close() + app.state.readpool = None + except Exception as e: + print(f"SnapStart: Error closing database readpool: {e}") + + if ( + hasattr(app, "state") + and hasattr(app.state, "writepool") + and app.state.writepool + ): + try: + app.state.writepool.close() + app.state.writepool = None + except Exception as e: + print(f"SnapStart: Error closing database writepool: {e}") + + return {"statusCode": 200} + + +@register_after_restore +def on_snap_restore(): + """ + Runtime hook called by Lambda after restoring from a snapshot. + We recreate database connections that were closed before the snapshot. + """ + global _connection_initialized + + try: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + if hasattr(app.state, "readpool") and app.state.readpool: + try: + app.state.readpool.close() + except Exception as e: + print(f"SnapStart: Error closing stale readpool: {e}") + app.state.readpool = None + + if hasattr(app.state, "writepool") and app.state.writepool: + try: + app.state.writepool.close() + except Exception as e: + print(f"SnapStart: Error closing stale writepool: {e}") + app.state.writepool = None + + loop.run_until_complete( + connect_to_db( + app, + postgres_settings=postgres_settings, + add_write_connection_pool=settings.enable_transaction, + ) + ) + + _connection_initialized = True + + except Exception as e: + print(f"SnapStart: Failed to initialize database connection: {e}") + raise + + return {"statusCode": 200} @app.on_event("startup") @@ -20,7 +98,7 @@ async def startup_event() -> None: """Connect to database on startup.""" await connect_to_db( app, - postgres_settings=PostgresSettings(), + postgres_settings=postgres_settings, add_write_connection_pool=settings.enable_transaction, ) diff --git a/infrastructure/handlers/vector_handler.py b/infrastructure/handlers/vector_handler.py index c93a5dd..2a202a7 100644 --- a/infrastructure/handlers/vector_handler.py +++ b/infrastructure/handlers/vector_handler.py @@ -8,6 +8,7 @@ from eoapi.vector.app import app from eoapi.vector.config import PostgresSettings from mangum import Mangum +from snapshot_restore_py import register_after_restore, register_before_snapshot from tipg.collections import register_collection_catalog from tipg.database import connect_to_db from tipg.settings import DatabaseSettings @@ -26,6 +27,73 @@ # We allow non-spatial tables spatial=False, ) +postgres_settings = PostgresSettings() + +_connection_initialized = False + + +@register_before_snapshot +def on_snapshot(): + """ + Runtime hook called by Lambda before taking a snapshot. + We close database connections that shouldn't be in the snapshot. + """ + + if hasattr(app, "state") and hasattr(app.state, "pool") and app.state.pool: + try: + app.state.pool.close() + app.state.pool = None + except Exception as e: + print(f"SnapStart: Error closing database pool: {e}") + + return {"statusCode": 200} + + +@register_after_restore +def on_snap_restore(): + """ + Runtime hook called by Lambda after restoring from a snapshot. + We recreate database connections that were closed before the snapshot. + """ + global _connection_initialized + + try: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + if hasattr(app.state, "pool") and app.state.pool: + try: + app.state.pool.close() + except Exception as e: + print(f"SnapStart: Error closing stale pool: {e}") + app.state.pool = None + + loop.run_until_complete( + connect_to_db( + app, + schemas=["pgstac", "public"], + user_sql_files=list(CUSTOM_SQL_DIRECTORY.glob("*.sql")), # type: ignore + settings=postgres_settings, + ) + ) + + loop.run_until_complete( + register_collection_catalog( + app, + db_settings=db_settings, + ) + ) + + _connection_initialized = True + + except Exception as e: + print(f"SnapStart: Failed to initialize database connection: {e}") + raise + + return {"statusCode": 200} @app.on_event("startup") @@ -36,7 +104,7 @@ async def startup_event() -> None: # We enable both pgstac and public schemas (pgstac will be used by custom functions) schemas=["pgstac", "public"], user_sql_files=list(CUSTOM_SQL_DIRECTORY.glob("*.sql")), # type: ignore - settings=PostgresSettings(), + settings=postgres_settings, ) await register_collection_catalog(app, db_settings=db_settings) diff --git a/pyproject.toml b/pyproject.toml index b7352cb..e5b1ddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [] [dependency-groups] deploy = [ "boto3==1.24.15", - "eoapi-cdk==10.2.5", + "eoapi-cdk==10.3.0", "pydantic-settings[yaml]==2.2.1", "pydantic==2.7", "typing-extensions>=4.12.2", diff --git a/uv.lock b/uv.lock index 7953820..27752bd 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.12" [[package]] @@ -213,16 +213,16 @@ wheels = [ [[package]] name = "constructs" -version = "10.3.0" +version = "10.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jsii" }, { name = "publication" }, { name = "typeguard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/49/24/62b6b537a7fa0348086b5942bd054cd919153ec392dc5594f4b2c5f19218/constructs-10.3.0.tar.gz", hash = "sha256:518551135ec236f9cc6b86500f4fbbe83b803ccdc6c2cb7684e0b7c4d234e7b1", size = 59978, upload-time = "2023-10-07T12:44:42.876Z" } +sdist = { url = "https://files.pythonhosted.org/packages/46/84/f608a0a71a05a476b2f1761ab8f3f776677d39f7996ecf1092a1ce741a7c/constructs-10.4.2.tar.gz", hash = "sha256:ce54724360fffe10bab27d8a081844eb81f5ace7d7c62c84b719c49f164d5307", size = 65434, upload-time = "2024-10-14T12:58:02.822Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8f/82/5b1407b9747a8c0133d56433dd9b67ec622558b4b47dad6b1751c9d5aeeb/constructs-10.3.0-py3-none-any.whl", hash = "sha256:2972f514837565ff5b09171cfba50c0159dfa75ee86a42921ea8c86f2941b3d2", size = 58188, upload-time = "2023-10-07T12:44:39.598Z" }, + { url = "https://files.pythonhosted.org/packages/f2/d9/c5e7458f323bf063a9a54200742f2494e2ce3c7c6873e0ff80f88033c75f/constructs-10.4.2-py3-none-any.whl", hash = "sha256:1f0f59b004edebfde0f826340698b8c34611f57848139b7954904c61645f13c1", size = 63509, upload-time = "2024-10-14T12:57:59.828Z" }, ] [[package]] @@ -236,7 +236,7 @@ wheels = [ [[package]] name = "eoapi-cdk" -version = "10.2.4" +version = "10.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aws-cdk-lib" }, @@ -245,14 +245,14 @@ dependencies = [ { name = "publication" }, { name = "typeguard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/bb/dc/e505eb7917f156f23dbd0aa6ea155dfe6512adeeeb3c974791b2236ab49d/eoapi_cdk-10.2.4.tar.gz", hash = "sha256:0473581c4c94877e3926f3eed77ca9d7dafa59d60511e80d2dc7704f34a07f03", size = 250134, upload-time = "2025-09-16T10:35:41.352Z" } +sdist = { url = "https://files.pythonhosted.org/packages/39/cb/a9b83c9871392b93d39b21c306e9dfa27a98a95ac929ac90feb5a183d90f/eoapi_cdk-10.3.0.tar.gz", hash = "sha256:1505b5fcc4c465d5b4e542faf335499a8d4407d27b074f95b4146bcb020a934a", size = 258033, upload-time = "2025-09-30T16:28:33.063Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/5d/ba6ab37c783a69b6a6ff44a1c24e97a679ea760e02420672de9698a3d9ce/eoapi_cdk-10.2.4-py3-none-any.whl", hash = "sha256:62b6166a074a0372eab2c30265c9a463b5f8caca7abc74941054381aa58046c1", size = 248064, upload-time = "2025-09-16T10:35:37.579Z" }, + { url = "https://files.pythonhosted.org/packages/55/c0/e5aa84595d7f16cb7bc0870648dcf7ac6af75cebaca51bab0a56dfd0288f/eoapi_cdk-10.3.0-py3-none-any.whl", hash = "sha256:2f382a45aaa1f13a6e1f519bad80e4201fcb2aa81ff3cc6dcd1a1f0cd314d451", size = 256393, upload-time = "2025-09-30T16:28:31.486Z" }, ] [[package]] name = "eoapi-devseed" -version = "0.2.0" +version = "0.3.1" source = { virtual = "." } [package.dev-dependencies] @@ -281,7 +281,7 @@ load = [ [package.metadata.requires-dev] deploy = [ { name = "boto3", specifier = "==1.24.15" }, - { name = "eoapi-cdk", specifier = "==10.2.4" }, + { name = "eoapi-cdk", specifier = "==10.3.0" }, { name = "pydantic", specifier = "==2.7" }, { name = "pydantic-settings", extras = ["yaml"], specifier = "==2.2.1" }, { name = "typing-extensions", specifier = ">=4.12.2" },