Skip to content

Commit 0cf131a

Browse files
committed
merge with stac-fastapi-pgstac settings and defaults to all extensions
1 parent 58f2985 commit 0cf131a

File tree

2 files changed

+82
-147
lines changed

2 files changed

+82
-147
lines changed

runtimes/eoapi/stac/eoapi/stac/app.py

Lines changed: 55 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from fastapi.responses import ORJSONResponse
1010
from stac_fastapi.api.app import StacApi
1111
from stac_fastapi.api.models import (
12-
EmptyRequest,
1312
ItemCollectionUri,
1413
create_get_request_model,
1514
create_post_request_model,
@@ -25,17 +24,14 @@
2524
SearchFilterExtension,
2625
SortExtension,
2726
TokenPaginationExtension,
28-
TransactionExtension,
2927
)
3028
from stac_fastapi.extensions.core.fields import FieldsConformanceClasses
3129
from stac_fastapi.extensions.core.free_text import FreeTextConformanceClasses
3230
from stac_fastapi.extensions.core.query import QueryConformanceClasses
3331
from stac_fastapi.extensions.core.sort import SortConformanceClasses
34-
from stac_fastapi.extensions.third_party import BulkTransactionExtension
3532
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
3633
from stac_fastapi.pgstac.extensions import QueryExtension
3734
from stac_fastapi.pgstac.extensions.filter import FiltersClient
38-
from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient
3935
from stac_fastapi.pgstac.types.search import PgstacSearch
4036
from starlette.middleware import Middleware
4137
from starlette.middleware.cors import CORSMiddleware
@@ -45,7 +41,7 @@
4541
from starlette_cramjam.middleware import CompressionMiddleware
4642

4743
from .client import PgSTACClient
48-
from .config import ApiSettings
44+
from .config import Settings
4945
from .extension import TiTilerExtension
5046
from .logs import init_logging
5147

@@ -58,115 +54,78 @@
5854
)
5955
templates = Jinja2Templates(env=jinja2_env)
6056

61-
api_settings = ApiSettings()
57+
settings = Settings()
6258
auth_settings = OpenIdConnectSettings()
63-
settings = api_settings.load_postgres_settings()
6459

65-
enabled_extensions = api_settings.extensions or []
6660

6761
# Logs
68-
init_logging(debug=api_settings.debug)
62+
init_logging(debug=settings.debug)
6963
logger = logging.getLogger(__name__)
7064

7165
# Extensions
7266
# application extensions
73-
application_extensions_map = {
74-
"transaction": TransactionExtension(
75-
client=TransactionsClient(),
76-
settings=settings,
77-
response_class=ORJSONResponse,
78-
),
79-
"bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()),
80-
}
81-
if "titiler" in enabled_extensions and api_settings.titiler_endpoint:
82-
application_extensions_map["titiler"] = TiTilerExtension(
83-
titiler_endpoint=api_settings.titiler_endpoint
67+
application_extensions = []
68+
69+
if settings.titiler_endpoint:
70+
application_extensions.append(
71+
TiTilerExtension(titiler_endpoint=settings.titiler_endpoint)
8472
)
8573

8674
# search extensions
87-
search_extensions_map = {
88-
"query": QueryExtension(),
89-
"sort": SortExtension(),
90-
"fields": FieldsExtension(),
91-
"filter": SearchFilterExtension(client=FiltersClient()),
92-
"pagination": TokenPaginationExtension(),
93-
}
75+
search_extensions = [
76+
QueryExtension(),
77+
SortExtension(),
78+
FieldsExtension(),
79+
SearchFilterExtension(client=FiltersClient()),
80+
TokenPaginationExtension(),
81+
]
9482

9583
# collection_search extensions
96-
cs_extensions_map = {
97-
"query": QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
98-
"sort": SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]),
99-
"fields": FieldsExtension(
100-
conformance_classes=[FieldsConformanceClasses.COLLECTIONS]
101-
),
102-
"filter": CollectionSearchFilterExtension(client=FiltersClient()),
103-
"free_text": FreeTextExtension(
84+
cs_extensions = [
85+
QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
86+
SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]),
87+
FieldsExtension(conformance_classes=[FieldsConformanceClasses.COLLECTIONS]),
88+
CollectionSearchFilterExtension(client=FiltersClient()),
89+
FreeTextExtension(
10490
conformance_classes=[FreeTextConformanceClasses.COLLECTIONS],
10591
),
106-
"pagination": OffsetPaginationExtension(),
107-
}
92+
OffsetPaginationExtension(),
93+
]
10894

10995
# item_collection extensions
110-
itm_col_extensions_map = {
111-
"query": QueryExtension(
96+
itm_col_extensions = [
97+
QueryExtension(
11298
conformance_classes=[QueryConformanceClasses.ITEMS],
11399
),
114-
"sort": SortExtension(
100+
SortExtension(
115101
conformance_classes=[SortConformanceClasses.ITEMS],
116102
),
117-
"fields": FieldsExtension(conformance_classes=[FieldsConformanceClasses.ITEMS]),
118-
"filter": ItemCollectionFilterExtension(client=FiltersClient()),
119-
"pagination": TokenPaginationExtension(),
120-
}
121-
122-
application_extensions = [
123-
extension
124-
for key, extension in application_extensions_map.items()
125-
if key in enabled_extensions
103+
FieldsExtension(conformance_classes=[FieldsConformanceClasses.ITEMS]),
104+
ItemCollectionFilterExtension(client=FiltersClient()),
105+
TokenPaginationExtension(),
126106
]
127107

128108
# Request Models
129109
# /search models
130-
search_extensions = [
131-
extension
132-
for key, extension in search_extensions_map.items()
133-
if key in enabled_extensions
134-
]
135-
post_request_model = create_post_request_model(
110+
search_post_model = create_post_request_model(
136111
search_extensions, base_model=PgstacSearch
137112
)
138-
get_request_model = create_get_request_model(search_extensions)
113+
search_get_model = create_get_request_model(search_extensions)
139114
application_extensions.extend(search_extensions)
140115

141116
# /collections/{collectionId}/items model
142-
items_get_request_model = ItemCollectionUri
143-
itm_col_extensions = [
144-
extension
145-
for key, extension in itm_col_extensions_map.items()
146-
if key in enabled_extensions
147-
]
148-
if itm_col_extensions:
149-
items_get_request_model = create_request_model(
150-
model_name="ItemCollectionUri",
151-
base_model=ItemCollectionUri,
152-
extensions=itm_col_extensions,
153-
request_type="GET",
154-
)
155-
application_extensions.extend(itm_col_extensions)
117+
items_get_model = create_request_model(
118+
model_name="ItemCollectionUri",
119+
base_model=ItemCollectionUri,
120+
extensions=itm_col_extensions,
121+
request_type="GET",
122+
)
123+
application_extensions.extend(itm_col_extensions)
156124

157125
# /collections model
158-
collections_get_request_model = EmptyRequest
159-
if "collection_search" in enabled_extensions:
160-
cs_extensions = [
161-
extension
162-
for key, extension in cs_extensions_map.items()
163-
if key in enabled_extensions
164-
]
165-
collection_search_extension = CollectionSearchExtension.from_extensions(
166-
cs_extensions
167-
)
168-
collections_get_request_model = collection_search_extension.GET
169-
application_extensions.append(collection_search_extension)
126+
collection_search_extension = CollectionSearchExtension.from_extensions(cs_extensions)
127+
collections_get_model = collection_search_extension.GET
128+
application_extensions.append(collection_search_extension)
170129

171130

172131
@asynccontextmanager
@@ -179,38 +138,39 @@ async def lifespan(app: FastAPI):
179138

180139
# Middlewares
181140
middlewares = [Middleware(CompressionMiddleware)]
182-
if api_settings.cors_origins:
141+
if settings.cors_origins:
183142
middlewares.append(
184143
Middleware(
185144
CORSMiddleware,
186-
allow_origins=api_settings.cors_origins,
145+
allow_origins=settings.cors_origins,
187146
allow_credentials=True,
188-
allow_methods=api_settings.cors_methods,
147+
allow_methods=settings.cors_methods,
189148
allow_headers=["*"],
190149
)
191150
)
192151

193152
api = StacApi(
194153
app=FastAPI(
195-
title=api_settings.name,
154+
title=settings.stac_fastapi_title,
155+
description=settings.stac_fastapi_description,
196156
lifespan=lifespan,
197-
openapi_url="/api",
198-
docs_url="/api.html",
157+
openapi_url=settings.openapi_url,
158+
docs_url=settings.docs_url,
199159
redoc_url=None,
200160
swagger_ui_init_oauth={
201161
"clientId": auth_settings.client_id,
202162
"usePkceWithAuthorizationCodeGrant": auth_settings.use_pkce,
203163
},
204164
),
205-
title=api_settings.name,
206-
description=api_settings.name,
207165
settings=settings,
208166
extensions=application_extensions,
209-
client=PgSTACClient(pgstac_search_model=post_request_model),
210-
items_get_request_model=items_get_request_model,
211-
search_get_request_model=get_request_model,
212-
search_post_request_model=post_request_model,
213-
collections_get_request_model=collections_get_request_model,
167+
client=PgSTACClient(
168+
pgstac_search_model=search_post_model,
169+
),
170+
items_get_request_model=items_get_model,
171+
search_get_request_model=search_get_model,
172+
search_post_request_model=search_post_model,
173+
collections_get_request_model=collections_get_model,
214174
response_class=ORJSONResponse,
215175
middlewares=middlewares,
216176
)

runtimes/eoapi/stac/eoapi/stac/config.py

Lines changed: 27 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
import base64
44
import json
5-
from typing import List, Optional
5+
from typing import Any, Optional
66

77
import boto3
8-
from pydantic import field_validator
9-
from pydantic_settings import BaseSettings
10-
from stac_fastapi.pgstac.config import Settings
8+
from pydantic import model_validator
9+
from stac_fastapi.pgstac import config
1110

1211

1312
def get_secret_dict(secret_name: str):
@@ -33,59 +32,35 @@ def get_secret_dict(secret_name: str):
3332
return json.loads(base64.b64decode(get_secret_value_response["SecretBinary"]))
3433

3534

36-
class ApiSettings(BaseSettings):
37-
"""API settings"""
35+
class Settings(config.Settings):
36+
"""Extent stac-fastapi-pgstac settings"""
37+
38+
stac_fastapi_title: str = "eoAPI-stac"
39+
stac_fastapi_description: str = "Custom stac-fastapi application for eoAPI Devseed"
40+
stac_fastapi_version: str = "0.1.0"
41+
stac_fastapi_landing_id: str = "eoapi-devseed-stac"
3842

39-
name: str = "eoAPI-stac"
40-
cors_origins: str = "*"
41-
cors_methods: str = "GET,POST,OPTIONS"
4243
cachecontrol: str = "public, max-age=3600"
43-
debug: bool = False
4444

4545
pgstac_secret_arn: Optional[str] = None
46+
4647
titiler_endpoint: Optional[str] = None
4748

48-
extensions: List[str] = [
49-
"filter",
50-
"query",
51-
"sort",
52-
"fields",
53-
"pagination",
54-
"titiler",
55-
"free_text",
56-
"transaction",
57-
# "bulk_transactions",
58-
"collection_search",
59-
]
60-
61-
@field_validator("cors_origins")
62-
def parse_cors_origin(cls, v):
63-
"""Parse CORS origins."""
64-
return [origin.strip() for origin in v.split(",")]
65-
66-
@field_validator("cors_methods")
67-
def parse_cors_methods(cls, v):
68-
"""Parse CORS methods."""
69-
return [method.strip() for method in v.split(",")]
70-
71-
def load_postgres_settings(self) -> "Settings":
72-
"""Load postgres connection params from AWS secret"""
73-
74-
if self.pgstac_secret_arn:
75-
secret = get_secret_dict(self.pgstac_secret_arn)
76-
77-
return Settings(
78-
postgres_host_reader=secret["host"],
79-
postgres_host_writer=secret["host"],
80-
postgres_dbname=secret["dbname"],
81-
postgres_user=secret["username"],
82-
postgres_pass=secret["password"],
83-
postgres_port=secret["port"],
49+
debug: bool = False
50+
51+
@model_validator(mode="before")
52+
def get_postgres_setting(cls, data: Any) -> Any:
53+
if arn := data.get("pgstac_secret_arn"):
54+
secret = get_secret_dict(arn)
55+
data.update(
56+
{
57+
"postgres_host_reader": secret["host"],
58+
"postgres_host_writer": secret["host"],
59+
"postgres_dbname": secret["dbname"],
60+
"postgres_user": secret["username"],
61+
"postgres_pass": secret["password"],
62+
"postgres_port": secret["port"],
63+
}
8464
)
85-
else:
86-
return Settings()
8765

88-
model_config = {
89-
"env_file": ".env",
90-
"extra": "allow",
91-
}
66+
return data

0 commit comments

Comments
 (0)