Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

### Fixed

* Content-type response headers for the /search endpoint now reflect the geojson response expected in the STAC api spec ([#220](https://github.com/stac-utils/stac-fastapi/issues/220)
* The minimum `limit` value for searches is now 1 ([#296](https://github.com/stac-utils/stac-fastapi/pull/296))
* Links stored with Collections and Items (e.g. license links) are now returned with those STAC objects ([#282](https://github.com/stac-utils/stac-fastapi/pull/282))
* Content-type response headers for the /api endpoint now reflect those expected in the STAC api spec ([#287](https://github.com/stac-utils/stac-fastapi/pull/287))
Expand Down
46 changes: 29 additions & 17 deletions stac_fastapi/api/stac_fastapi/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
APIRequest,
CollectionUri,
EmptyRequest,
GeoJSONResponse,
ItemCollectionUri,
ItemUri,
SearchGetRequest,
Expand Down Expand Up @@ -96,17 +97,16 @@ def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]
return None

def _create_endpoint(
self, func: Callable, request_type: Union[Type[APIRequest], Type[BaseModel]]
self,
func: Callable,
request_type: Union[Type[APIRequest], Type[BaseModel]],
resp_class: Type[Response],
) -> Callable:
"""Create a FastAPI endpoint."""
if isinstance(self.client, AsyncBaseCoreClient):
return create_async_endpoint(
func, request_type, response_class=self.response_class
)
return create_async_endpoint(func, request_type, response_class=resp_class)
elif isinstance(self.client, BaseCoreClient):
return create_sync_endpoint(
func, request_type, response_class=self.response_class
)
return create_sync_endpoint(func, request_type, response_class=resp_class)
raise NotImplementedError

def register_landing_page(self):
Expand All @@ -125,7 +125,9 @@ def register_landing_page(self):
response_model_exclude_unset=False,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(self.client.landing_page, EmptyRequest),
endpoint=self._create_endpoint(
self.client.landing_page, EmptyRequest, self.response_class
),
)

def register_conformance_classes(self):
Expand All @@ -144,7 +146,9 @@ def register_conformance_classes(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(self.client.conformance, EmptyRequest),
endpoint=self._create_endpoint(
self.client.conformance, EmptyRequest, self.response_class
),
)

def register_get_item(self):
Expand All @@ -161,7 +165,9 @@ def register_get_item(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(self.client.get_item, ItemUri),
endpoint=self._create_endpoint(
self.client.get_item, ItemUri, self.response_class
),
)

def register_post_search(self):
Expand All @@ -178,12 +184,12 @@ def register_post_search(self):
response_model=(ItemCollection if not fields_ext else None)
if self.settings.enable_response_models
else None,
response_class=self.response_class,
response_class=GeoJSONResponse,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will cause the GeoJSONResponse to not use ORJSONResponse if the self.response_class is ORJSONResponse, leading to slowdowns in what's the largest API responses.

Is there a way to just set the media_type at response class creation time instead of setting the media type through this mechanism?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To resolve this, I've used importlib to conditionally use ORJSONResponse when orjson is available and JSONResponse otherwise.

response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["POST"],
endpoint=self._create_endpoint(
self.client.post_search, search_request_model
self.client.post_search, search_request_model, GeoJSONResponse
),
)

Expand All @@ -200,12 +206,12 @@ def register_get_search(self):
response_model=(ItemCollection if not fields_ext else None)
if self.settings.enable_response_models
else None,
response_class=self.response_class,
response_class=GeoJSONResponse,
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(
self.client.get_search, self.search_get_request
self.client.get_search, self.search_get_request, GeoJSONResponse
),
)

Expand All @@ -225,7 +231,9 @@ def register_get_collections(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(self.client.all_collections, EmptyRequest),
endpoint=self._create_endpoint(
self.client.all_collections, EmptyRequest, self.response_class
),
)

def register_get_collection(self):
Expand All @@ -242,7 +250,9 @@ def register_get_collection(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(self.client.get_collection, CollectionUri),
endpoint=self._create_endpoint(
self.client.get_collection, CollectionUri, self.response_class
),
)

def register_get_item_collection(self):
Expand All @@ -262,7 +272,9 @@ def register_get_item_collection(self):
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(
self.client.item_collection, self.item_collection_uri
self.client.item_collection,
self.item_collection_uri,
self.response_class,
),
)

Expand Down
20 changes: 20 additions & 0 deletions stac_fastapi/api/stac_fastapi/api/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""api request/response models."""

import abc
import importlib
from typing import Dict, Optional, Type, Union

import attr
Expand Down Expand Up @@ -127,3 +128,22 @@ def kwargs(self) -> Dict:
"fields": self.fields.split(",") if self.fields else self.fields,
"sortby": self.sortby.split(",") if self.sortby else self.sortby,
}


# Test for ORJSON and use it rather than stdlib JSON where supported
if importlib.util.find_spec("orjson") is not None:
from fastapi.responses import ORJSONResponse

class GeoJSONResponse(ORJSONResponse):
"""JSON with custom, vendor content-type."""

media_type = "application/geo+json"


else:
from starlette.responses import JSONResponse

class GeoJSONResponse(JSONResponse):
"""JSON with custom, vendor content-type."""

media_type = "application/geo+json"
13 changes: 13 additions & 0 deletions stac_fastapi/pgstac/tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@
]


@pytest.mark.asyncio
async def test_post_search_content_type(app_client):
params = {"limit": 1}
resp = await app_client.post("search", json=params)
assert resp.headers["content-type"] == "application/geo+json"


@pytest.mark.asyncio
async def test_get_search_content_type(app_client):
resp = await app_client.get("search")
assert resp.headers["content-type"] == "application/geo+json"


@pytest.mark.asyncio
async def test_api_headers(app_client):
resp = await app_client.get("/api")
Expand Down
11 changes: 11 additions & 0 deletions stac_fastapi/sqlalchemy/tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@
]


def test_post_search_content_type(app_client):
params = {"limit": 1}
resp = app_client.post("search", json=params)
assert resp.headers["content-type"] == "application/geo+json"


def test_get_search_content_type(app_client):
resp = app_client.get("search")
assert resp.headers["content-type"] == "application/geo+json"


def test_api_headers(app_client):
resp = app_client.get("/api")
assert (
Expand Down