Skip to content

Commit 767c4d1

Browse files
Merge pull request #24 from developmentseed/feature/simplify-settings-and-extensions
merge with stac-fastapi-pgstac settings and defaults to all extensions
2 parents 58f2985 + 4d8b0b0 commit 767c4d1

File tree

2 files changed

+87
-147
lines changed

2 files changed

+87
-147
lines changed

runtimes/eoapi/stac/eoapi/stac/app.py

Lines changed: 61 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from fastapi.responses import ORJSONResponse
1010
from stac_fastapi.api.app import StacApi
1111
from stac_fastapi.api.models import (
12-
EmptyRequest,
1312
ItemCollectionUri,
1413
create_get_request_model,
1514
create_post_request_model,
@@ -25,17 +24,14 @@
2524
SearchFilterExtension,
2625
SortExtension,
2726
TokenPaginationExtension,
28-
TransactionExtension,
2927
)
3028
from stac_fastapi.extensions.core.fields import FieldsConformanceClasses
3129
from stac_fastapi.extensions.core.free_text import FreeTextConformanceClasses
3230
from stac_fastapi.extensions.core.query import QueryConformanceClasses
3331
from stac_fastapi.extensions.core.sort import SortConformanceClasses
34-
from stac_fastapi.extensions.third_party import BulkTransactionExtension
3532
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
3633
from stac_fastapi.pgstac.extensions import QueryExtension
3734
from stac_fastapi.pgstac.extensions.filter import FiltersClient
38-
from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient
3935
from stac_fastapi.pgstac.types.search import PgstacSearch
4036
from starlette.middleware import Middleware
4137
from starlette.middleware.cors import CORSMiddleware
@@ -44,8 +40,9 @@
4440
from starlette.templating import Jinja2Templates
4541
from starlette_cramjam.middleware import CompressionMiddleware
4642

43+
from . import __version__ as eoapi_devseed_version
4744
from .client import PgSTACClient
48-
from .config import ApiSettings
45+
from .config import Settings
4946
from .extension import TiTilerExtension
5047
from .logs import init_logging
5148

@@ -58,115 +55,78 @@
5855
)
5956
templates = Jinja2Templates(env=jinja2_env)
6057

61-
api_settings = ApiSettings()
58+
settings = Settings()
6259
auth_settings = OpenIdConnectSettings()
63-
settings = api_settings.load_postgres_settings()
6460

65-
enabled_extensions = api_settings.extensions or []
6661

6762
# Logs
68-
init_logging(debug=api_settings.debug)
63+
init_logging(debug=settings.debug)
6964
logger = logging.getLogger(__name__)
7065

7166
# Extensions
7267
# application extensions
73-
application_extensions_map = {
74-
"transaction": TransactionExtension(
75-
client=TransactionsClient(),
76-
settings=settings,
77-
response_class=ORJSONResponse,
78-
),
79-
"bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()),
80-
}
81-
if "titiler" in enabled_extensions and api_settings.titiler_endpoint:
82-
application_extensions_map["titiler"] = TiTilerExtension(
83-
titiler_endpoint=api_settings.titiler_endpoint
68+
application_extensions = []
69+
70+
if settings.titiler_endpoint:
71+
application_extensions.append(
72+
TiTilerExtension(titiler_endpoint=settings.titiler_endpoint)
8473
)
8574

8675
# search extensions
87-
search_extensions_map = {
88-
"query": QueryExtension(),
89-
"sort": SortExtension(),
90-
"fields": FieldsExtension(),
91-
"filter": SearchFilterExtension(client=FiltersClient()),
92-
"pagination": TokenPaginationExtension(),
93-
}
76+
search_extensions = [
77+
QueryExtension(),
78+
SortExtension(),
79+
FieldsExtension(),
80+
SearchFilterExtension(client=FiltersClient()),
81+
TokenPaginationExtension(),
82+
]
9483

9584
# collection_search extensions
96-
cs_extensions_map = {
97-
"query": QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
98-
"sort": SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]),
99-
"fields": FieldsExtension(
100-
conformance_classes=[FieldsConformanceClasses.COLLECTIONS]
101-
),
102-
"filter": CollectionSearchFilterExtension(client=FiltersClient()),
103-
"free_text": FreeTextExtension(
85+
cs_extensions = [
86+
QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
87+
SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]),
88+
FieldsExtension(conformance_classes=[FieldsConformanceClasses.COLLECTIONS]),
89+
CollectionSearchFilterExtension(client=FiltersClient()),
90+
FreeTextExtension(
10491
conformance_classes=[FreeTextConformanceClasses.COLLECTIONS],
10592
),
106-
"pagination": OffsetPaginationExtension(),
107-
}
93+
OffsetPaginationExtension(),
94+
]
10895

10996
# item_collection extensions
110-
itm_col_extensions_map = {
111-
"query": QueryExtension(
97+
itm_col_extensions = [
98+
QueryExtension(
11299
conformance_classes=[QueryConformanceClasses.ITEMS],
113100
),
114-
"sort": SortExtension(
101+
SortExtension(
115102
conformance_classes=[SortConformanceClasses.ITEMS],
116103
),
117-
"fields": FieldsExtension(conformance_classes=[FieldsConformanceClasses.ITEMS]),
118-
"filter": ItemCollectionFilterExtension(client=FiltersClient()),
119-
"pagination": TokenPaginationExtension(),
120-
}
121-
122-
application_extensions = [
123-
extension
124-
for key, extension in application_extensions_map.items()
125-
if key in enabled_extensions
104+
FieldsExtension(conformance_classes=[FieldsConformanceClasses.ITEMS]),
105+
ItemCollectionFilterExtension(client=FiltersClient()),
106+
TokenPaginationExtension(),
126107
]
127108

128109
# Request Models
129110
# /search models
130-
search_extensions = [
131-
extension
132-
for key, extension in search_extensions_map.items()
133-
if key in enabled_extensions
134-
]
135-
post_request_model = create_post_request_model(
111+
search_post_model = create_post_request_model(
136112
search_extensions, base_model=PgstacSearch
137113
)
138-
get_request_model = create_get_request_model(search_extensions)
114+
search_get_model = create_get_request_model(search_extensions)
139115
application_extensions.extend(search_extensions)
140116

141117
# /collections/{collectionId}/items model
142-
items_get_request_model = ItemCollectionUri
143-
itm_col_extensions = [
144-
extension
145-
for key, extension in itm_col_extensions_map.items()
146-
if key in enabled_extensions
147-
]
148-
if itm_col_extensions:
149-
items_get_request_model = create_request_model(
150-
model_name="ItemCollectionUri",
151-
base_model=ItemCollectionUri,
152-
extensions=itm_col_extensions,
153-
request_type="GET",
154-
)
155-
application_extensions.extend(itm_col_extensions)
118+
items_get_model = create_request_model(
119+
model_name="ItemCollectionUri",
120+
base_model=ItemCollectionUri,
121+
extensions=itm_col_extensions,
122+
request_type="GET",
123+
)
124+
application_extensions.extend(itm_col_extensions)
156125

157126
# /collections model
158-
collections_get_request_model = EmptyRequest
159-
if "collection_search" in enabled_extensions:
160-
cs_extensions = [
161-
extension
162-
for key, extension in cs_extensions_map.items()
163-
if key in enabled_extensions
164-
]
165-
collection_search_extension = CollectionSearchExtension.from_extensions(
166-
cs_extensions
167-
)
168-
collections_get_request_model = collection_search_extension.GET
169-
application_extensions.append(collection_search_extension)
127+
collection_search_extension = CollectionSearchExtension.from_extensions(cs_extensions)
128+
collections_get_model = collection_search_extension.GET
129+
application_extensions.append(collection_search_extension)
170130

171131

172132
@asynccontextmanager
@@ -179,38 +139,44 @@ async def lifespan(app: FastAPI):
179139

180140
# Middlewares
181141
middlewares = [Middleware(CompressionMiddleware)]
182-
if api_settings.cors_origins:
142+
if settings.cors_origins:
183143
middlewares.append(
184144
Middleware(
185145
CORSMiddleware,
186-
allow_origins=api_settings.cors_origins,
146+
allow_origins=settings.cors_origins,
187147
allow_credentials=True,
188-
allow_methods=api_settings.cors_methods,
148+
allow_methods=settings.cors_methods,
189149
allow_headers=["*"],
190150
)
191151
)
192152

193153
api = StacApi(
194154
app=FastAPI(
195-
title=api_settings.name,
155+
title=settings.stac_fastapi_title,
156+
description=settings.stac_fastapi_description,
157+
version=eoapi_devseed_version,
196158
lifespan=lifespan,
197-
openapi_url="/api",
198-
docs_url="/api.html",
159+
openapi_url=settings.openapi_url,
160+
docs_url=settings.docs_url,
199161
redoc_url=None,
200162
swagger_ui_init_oauth={
201163
"clientId": auth_settings.client_id,
202164
"usePkceWithAuthorizationCodeGrant": auth_settings.use_pkce,
203165
},
204166
),
205-
title=api_settings.name,
206-
description=api_settings.name,
167+
api_version=eoapi_devseed_version,
207168
settings=settings,
208169
extensions=application_extensions,
209-
client=PgSTACClient(pgstac_search_model=post_request_model),
210-
items_get_request_model=items_get_request_model,
211-
search_get_request_model=get_request_model,
212-
search_post_request_model=post_request_model,
213-
collections_get_request_model=collections_get_request_model,
170+
client=PgSTACClient( # type: ignore
171+
landing_page_id=settings.stac_fastapi_landing_id,
172+
title=settings.stac_fastapi_title,
173+
description=settings.stac_fastapi_description,
174+
pgstac_search_model=search_post_model,
175+
),
176+
items_get_request_model=items_get_model,
177+
search_get_request_model=search_get_model,
178+
search_post_request_model=search_post_model,
179+
collections_get_request_model=collections_get_model,
214180
response_class=ORJSONResponse,
215181
middlewares=middlewares,
216182
)

runtimes/eoapi/stac/eoapi/stac/config.py

Lines changed: 26 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
import base64
44
import json
5-
from typing import List, Optional
5+
from typing import Any, Optional
66

77
import boto3
8-
from pydantic import field_validator
9-
from pydantic_settings import BaseSettings
10-
from stac_fastapi.pgstac.config import Settings
8+
from pydantic import model_validator
9+
from stac_fastapi.pgstac import config
1110

1211

1312
def get_secret_dict(secret_name: str):
@@ -33,59 +32,34 @@ def get_secret_dict(secret_name: str):
3332
return json.loads(base64.b64decode(get_secret_value_response["SecretBinary"]))
3433

3534

36-
class ApiSettings(BaseSettings):
37-
"""API settings"""
35+
class Settings(config.Settings):
36+
"""Extent stac-fastapi-pgstac settings"""
37+
38+
stac_fastapi_title: str = "eoAPI-stac"
39+
stac_fastapi_description: str = "Custom stac-fastapi application for eoAPI-Devseed"
40+
stac_fastapi_landing_id: str = "eoapi-devseed-stac"
3841

39-
name: str = "eoAPI-stac"
40-
cors_origins: str = "*"
41-
cors_methods: str = "GET,POST,OPTIONS"
4242
cachecontrol: str = "public, max-age=3600"
43-
debug: bool = False
4443

4544
pgstac_secret_arn: Optional[str] = None
45+
4646
titiler_endpoint: Optional[str] = None
4747

48-
extensions: List[str] = [
49-
"filter",
50-
"query",
51-
"sort",
52-
"fields",
53-
"pagination",
54-
"titiler",
55-
"free_text",
56-
"transaction",
57-
# "bulk_transactions",
58-
"collection_search",
59-
]
60-
61-
@field_validator("cors_origins")
62-
def parse_cors_origin(cls, v):
63-
"""Parse CORS origins."""
64-
return [origin.strip() for origin in v.split(",")]
65-
66-
@field_validator("cors_methods")
67-
def parse_cors_methods(cls, v):
68-
"""Parse CORS methods."""
69-
return [method.strip() for method in v.split(",")]
70-
71-
def load_postgres_settings(self) -> "Settings":
72-
"""Load postgres connection params from AWS secret"""
73-
74-
if self.pgstac_secret_arn:
75-
secret = get_secret_dict(self.pgstac_secret_arn)
76-
77-
return Settings(
78-
postgres_host_reader=secret["host"],
79-
postgres_host_writer=secret["host"],
80-
postgres_dbname=secret["dbname"],
81-
postgres_user=secret["username"],
82-
postgres_pass=secret["password"],
83-
postgres_port=secret["port"],
48+
debug: bool = False
49+
50+
@model_validator(mode="before")
51+
def get_postgres_setting(cls, data: Any) -> Any:
52+
if arn := data.get("pgstac_secret_arn"):
53+
secret = get_secret_dict(arn)
54+
data.update(
55+
{
56+
"postgres_host_reader": secret["host"],
57+
"postgres_host_writer": secret["host"],
58+
"postgres_dbname": secret["dbname"],
59+
"postgres_user": secret["username"],
60+
"postgres_pass": secret["password"],
61+
"postgres_port": secret["port"],
62+
}
8463
)
85-
else:
86-
return Settings()
8764

88-
model_config = {
89-
"env_file": ".env",
90-
"extra": "allow",
91-
}
65+
return data

0 commit comments

Comments
 (0)