Skip to content

Commit 82e2553

Browse files
authored
feat: enable SnapStart on lambda functions (#56)
1 parent 616b890 commit 82e2553

File tree

6 files changed

+216
-13
lines changed

6 files changed

+216
-13
lines changed

infrastructure/app.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def __init__(
181181
"handler": "handler.handler",
182182
"runtime": aws_lambda.Runtime.PYTHON_3_12,
183183
},
184+
enable_snap_start=True,
184185
)
185186

186187
#######################################################################
@@ -228,6 +229,7 @@ def __init__(
228229
"handler": "handler.handler",
229230
"runtime": aws_lambda.Runtime.PYTHON_3_12,
230231
},
232+
enable_snap_start=True,
231233
)
232234

233235
#######################################################################
@@ -274,6 +276,7 @@ def __init__(
274276
"handler": "handler.handler",
275277
"runtime": aws_lambda.Runtime.PYTHON_3_12,
276278
},
279+
enable_snap_start=True,
277280
)
278281

279282
if app_config.stac_ingestor:

infrastructure/handlers/raster_handler.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,70 @@
77
from eoapi.raster.app import app
88
from eoapi.raster.config import PostgresSettings
99
from mangum import Mangum
10+
from snapshot_restore_py import register_after_restore, register_before_snapshot
1011
from titiler.pgstac.db import connect_to_db
1112

1213
logging.getLogger("mangum.lifespan").setLevel(logging.ERROR)
1314
logging.getLogger("mangum.http").setLevel(logging.ERROR)
1415

16+
postgres_settings = PostgresSettings()
17+
_connection_initialized = False
18+
19+
20+
@register_before_snapshot
21+
def on_snapshot():
22+
"""
23+
Runtime hook called by Lambda before taking a snapshot.
24+
We close database connections that shouldn't be in the snapshot.
25+
"""
26+
27+
if hasattr(app, "state") and hasattr(app.state, "dbpool") and app.state.dbpool:
28+
try:
29+
app.state.dbpool.close()
30+
app.state.dbpool = None
31+
except Exception as e:
32+
print(f"SnapStart: Error closing database pool: {e}")
33+
34+
return {"statusCode": 200}
35+
36+
37+
@register_after_restore
38+
def on_snap_restore():
39+
"""
40+
Runtime hook called by Lambda after restoring from a snapshot.
41+
We recreate database connections that were closed before the snapshot.
42+
"""
43+
global _connection_initialized
44+
45+
try:
46+
try:
47+
loop = asyncio.get_running_loop()
48+
except RuntimeError:
49+
loop = asyncio.new_event_loop()
50+
asyncio.set_event_loop(loop)
51+
52+
if hasattr(app.state, "dbpool") and app.state.dbpool:
53+
try:
54+
app.state.dbpool.close()
55+
except Exception as e:
56+
print(f"SnapStart: Error closing stale pool: {e}")
57+
app.state.dbpool = None
58+
59+
loop.run_until_complete(connect_to_db(app, settings=postgres_settings))
60+
61+
_connection_initialized = True
62+
63+
except Exception as e:
64+
print(f"SnapStart: Failed to initialize database connection: {e}")
65+
raise
66+
67+
return {"statusCode": 200}
68+
1569

1670
@app.on_event("startup")
1771
async def startup_event() -> None:
1872
"""Connect to database on startup."""
19-
await connect_to_db(app, settings=PostgresSettings())
73+
await connect_to_db(app, settings=postgres_settings)
2074

2175

2276
handler = Mangum(app, lifespan="off")

infrastructure/handlers/stac_handler.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,98 @@
77
from eoapi.stac.app import app
88
from eoapi.stac.config import PostgresSettings, Settings
99
from mangum import Mangum
10+
from snapshot_restore_py import register_after_restore, register_before_snapshot
1011
from stac_fastapi.pgstac.db import connect_to_db
1112

1213
logging.getLogger("mangum.lifespan").setLevel(logging.ERROR)
1314
logging.getLogger("mangum.http").setLevel(logging.ERROR)
1415

1516
settings = Settings()
17+
postgres_settings = PostgresSettings()
18+
19+
_connection_initialized = False
20+
21+
22+
@register_before_snapshot
23+
def on_snapshot():
24+
"""
25+
Runtime hook called by Lambda before taking a snapshot.
26+
We close database connections that shouldn't be in the snapshot.
27+
"""
28+
29+
if hasattr(app, "state") and hasattr(app.state, "readpool") and app.state.readpool:
30+
try:
31+
app.state.readpool.close()
32+
app.state.readpool = None
33+
except Exception as e:
34+
print(f"SnapStart: Error closing database readpool: {e}")
35+
36+
if (
37+
hasattr(app, "state")
38+
and hasattr(app.state, "writepool")
39+
and app.state.writepool
40+
):
41+
try:
42+
app.state.writepool.close()
43+
app.state.writepool = None
44+
except Exception as e:
45+
print(f"SnapStart: Error closing database writepool: {e}")
46+
47+
return {"statusCode": 200}
48+
49+
50+
@register_after_restore
51+
def on_snap_restore():
52+
"""
53+
Runtime hook called by Lambda after restoring from a snapshot.
54+
We recreate database connections that were closed before the snapshot.
55+
"""
56+
global _connection_initialized
57+
58+
try:
59+
try:
60+
loop = asyncio.get_running_loop()
61+
except RuntimeError:
62+
loop = asyncio.new_event_loop()
63+
asyncio.set_event_loop(loop)
64+
65+
if hasattr(app.state, "readpool") and app.state.readpool:
66+
try:
67+
app.state.readpool.close()
68+
except Exception as e:
69+
print(f"SnapStart: Error closing stale readpool: {e}")
70+
app.state.readpool = None
71+
72+
if hasattr(app.state, "writepool") and app.state.writepool:
73+
try:
74+
app.state.writepool.close()
75+
except Exception as e:
76+
print(f"SnapStart: Error closing stale writepool: {e}")
77+
app.state.writepool = None
78+
79+
loop.run_until_complete(
80+
connect_to_db(
81+
app,
82+
postgres_settings=postgres_settings,
83+
add_write_connection_pool=settings.enable_transaction,
84+
)
85+
)
86+
87+
_connection_initialized = True
88+
89+
except Exception as e:
90+
print(f"SnapStart: Failed to initialize database connection: {e}")
91+
raise
92+
93+
return {"statusCode": 200}
1694

1795

1896
@app.on_event("startup")
1997
async def startup_event() -> None:
2098
"""Connect to database on startup."""
2199
await connect_to_db(
22100
app,
23-
postgres_settings=PostgresSettings(),
101+
postgres_settings=postgres_settings,
24102
add_write_connection_pool=settings.enable_transaction,
25103
)
26104

infrastructure/handlers/vector_handler.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from eoapi.vector.app import app
99
from eoapi.vector.config import PostgresSettings
1010
from mangum import Mangum
11+
from snapshot_restore_py import register_after_restore, register_before_snapshot
1112
from tipg.collections import register_collection_catalog
1213
from tipg.database import connect_to_db
1314
from tipg.settings import DatabaseSettings
@@ -26,6 +27,73 @@
2627
# We allow non-spatial tables
2728
spatial=False,
2829
)
30+
postgres_settings = PostgresSettings()
31+
32+
_connection_initialized = False
33+
34+
35+
@register_before_snapshot
36+
def on_snapshot():
37+
"""
38+
Runtime hook called by Lambda before taking a snapshot.
39+
We close database connections that shouldn't be in the snapshot.
40+
"""
41+
42+
if hasattr(app, "state") and hasattr(app.state, "pool") and app.state.pool:
43+
try:
44+
app.state.pool.close()
45+
app.state.pool = None
46+
except Exception as e:
47+
print(f"SnapStart: Error closing database pool: {e}")
48+
49+
return {"statusCode": 200}
50+
51+
52+
@register_after_restore
53+
def on_snap_restore():
54+
"""
55+
Runtime hook called by Lambda after restoring from a snapshot.
56+
We recreate database connections that were closed before the snapshot.
57+
"""
58+
global _connection_initialized
59+
60+
try:
61+
try:
62+
loop = asyncio.get_running_loop()
63+
except RuntimeError:
64+
loop = asyncio.new_event_loop()
65+
asyncio.set_event_loop(loop)
66+
67+
if hasattr(app.state, "pool") and app.state.pool:
68+
try:
69+
app.state.pool.close()
70+
except Exception as e:
71+
print(f"SnapStart: Error closing stale pool: {e}")
72+
app.state.pool = None
73+
74+
loop.run_until_complete(
75+
connect_to_db(
76+
app,
77+
schemas=["pgstac", "public"],
78+
user_sql_files=list(CUSTOM_SQL_DIRECTORY.glob("*.sql")), # type: ignore
79+
settings=postgres_settings,
80+
)
81+
)
82+
83+
loop.run_until_complete(
84+
register_collection_catalog(
85+
app,
86+
db_settings=db_settings,
87+
)
88+
)
89+
90+
_connection_initialized = True
91+
92+
except Exception as e:
93+
print(f"SnapStart: Failed to initialize database connection: {e}")
94+
raise
95+
96+
return {"statusCode": 200}
2997

3098

3199
@app.on_event("startup")
@@ -36,7 +104,7 @@ async def startup_event() -> None:
36104
# We enable both pgstac and public schemas (pgstac will be used by custom functions)
37105
schemas=["pgstac", "public"],
38106
user_sql_files=list(CUSTOM_SQL_DIRECTORY.glob("*.sql")), # type: ignore
39-
settings=PostgresSettings(),
107+
settings=postgres_settings,
40108
)
41109
await register_collection_catalog(app, db_settings=db_settings)
42110

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ dependencies = []
99
[dependency-groups]
1010
deploy = [
1111
"boto3==1.24.15",
12-
"eoapi-cdk==10.2.5",
12+
"eoapi-cdk==10.3.0",
1313
"pydantic-settings[yaml]==2.2.1",
1414
"pydantic==2.7",
1515
"typing-extensions>=4.12.2",

uv.lock

Lines changed: 9 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)