diff --git a/.github/workflows/cicd.yaml b/.github/workflows/cicd.yaml index 95d85cfde..af5b2f819 100644 --- a/.github/workflows/cicd.yaml +++ b/.github/workflows/cicd.yaml @@ -155,7 +155,7 @@ jobs: cache: pip cache-dependency-path: stac_fastapi/pgstac/setup.cfg - name: Install stac-fastapi and stac-api-validator - run: pip install ./stac_fastapi/api ./stac_fastapi/types ./stac_fastapi/${{ matrix.backend }}[server] stac-api-validator==0.4.1 + run: pip install ./stac_fastapi/api ./stac_fastapi/types ./stac_fastapi/extensions ./stac_fastapi/${{ matrix.backend }}[server] stac-api-validator==0.4.1 - name: Run migration if: ${{ matrix.backend == 'sqlalchemy' }} run: cd stac_fastapi/sqlalchemy && alembic upgrade head diff --git a/CHANGES.md b/CHANGES.md index b29958eaa..ffe73dfcf 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -12,6 +12,7 @@ * Updated CI to test against [pgstac v0.6.12](https://github.com/stac-utils/pgstac/releases/tag/v0.6.12) ([#511](https://github.com/stac-utils/stac-fastapi/pull/511)) * Reworked `update_openapi` and added a test for it ([#523](https://github.com/stac-utils/stac-fastapi/pull/523)) * Limit values above 10,000 are now replaced with 10,000 instead of returning a 400 error ([#526](https://github.com/stac-utils/stac-fastapi/pull/526)) +* Default field include and exclude behavior now follows the fields extension's recommended include/exclude semantics ([#527](https://github.com/stac-utils/stac-fastapi/pull/527)) ### Removed diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py index 93a69a2bc..e781bf3ea 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py @@ -43,6 +43,7 @@ class FieldsExtension(ApiExtension): "assets", "properties.datetime", "collection", + "stac_extensions", } ) schema_href: Optional[str] = attr.ib(default=None) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py 
b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py index da967a576..eecf406fd 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py @@ -1,5 +1,8 @@ """Request models for the fields extension.""" +from __future__ import annotations + +import copy from typing import Dict, Optional, Set import attr @@ -48,16 +51,30 @@ def filter_fields(self) -> Dict: to the API Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude """ - # Always include default_includes, even if they - # exist in the exclude list. - include = (self.include or set()) - (self.exclude or set()) - include |= Settings.get().default_includes or set() + recommended = self.into_recommended() return { - "include": self._get_field_dict(include), - "exclude": self._get_field_dict(self.exclude), + "include": self._get_field_dict(recommended.include), + "exclude": self._get_field_dict(recommended.exclude), } + def into_recommended(self) -> PostFieldsExtension: + """Convert this fields extension into the recommended sets. + + Based on https://github.com/stac-api-extensions/fields#includeexclude-semantics + """ + include = self.include or set() + exclude = (self.exclude or set()) - include + # We can't do a simple set intersection, because we may include subkeys + # while excluding everything else. E.g. we may want to include an + # attribute of a specific asset, but exclude the rest of the asset + # dictionary. 
+ default_include = copy.deepcopy(Settings.get().default_includes or set()) + if any(incl.startswith("assets.") for incl in include): + default_include.discard("assets") + include = (include | default_include) - exclude + return PostFieldsExtension(include=include, exclude=exclude) + @attr.s class FieldsExtensionGetRequest(APIRequest): @@ -69,4 +86,4 @@ class FieldsExtensionGetRequest(APIRequest): class FieldsExtensionPostRequest(BaseModel): """Additional fields and schema for the POST request.""" - fields: Optional[PostFieldsExtension] = Field(PostFieldsExtension()) + fields: Optional[PostFieldsExtension] = Field(None) diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py index a8c73d9f8..3cb82186f 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py @@ -157,6 +157,14 @@ async def _search_base( """ items: Dict[str, Any] + if search_request.fields is None: + exclude = None + include = None + else: + search_request.fields = search_request.fields.into_recommended() + exclude = search_request.fields.exclude + include = search_request.fields.include + request: Request = kwargs["request"] settings: Settings = request.app.state.settings pool = request.app.state.readpool @@ -183,13 +191,6 @@ async def _search_base( prev: Optional[str] = items.pop("prev", None) collection = ItemCollection(**items) - exclude = search_request.fields.exclude - if exclude and len(exclude) == 0: - exclude = None - include = search_request.fields.include - if include and len(include) == 0: - include = None - async def _add_item_links( feature: Item, collection_id: Optional[str] = None, @@ -204,8 +205,8 @@ async def _add_item_links( item_id = feature.get("id") or item_id if ( - search_request.fields.exclude is None - or "links" not in search_request.fields.exclude + exclude is None + or "links" not in exclude and all([collection_id, item_id]) ): feature["links"] = await ItemLinks( 
-233,12 +234,15 @@ async def _get_base_item(collection_id: str) -> Dict[str, Any]: collection_id = feature.get("collection") item_id = feature.get("id") - feature = filter_fields(feature, include, exclude) + if include or exclude: + feature = filter_fields(feature, include, exclude) await _add_item_links(feature, collection_id, item_id) cleaned_features.append(feature) else: for feature in collection.get("features") or []: + if include or exclude: + feature = filter_fields(feature, include, exclude) await _add_item_links(feature) cleaned_features.append(feature) diff --git a/stac_fastapi/pgstac/tests/resources/test_item.py b/stac_fastapi/pgstac/tests/resources/test_item.py index 43e1f22ef..4d32fdaac 100644 --- a/stac_fastapi/pgstac/tests/resources/test_item.py +++ b/stac_fastapi/pgstac/tests/resources/test_item.py @@ -1083,28 +1083,6 @@ async def test_field_extension_post(app_client, load_test_data, load_test_collec } -async def test_field_extension_exclude_and_include( - app_client, load_test_data, load_test_collection -): - """Test POST search including/excluding same field (fields extension)""" - test_item = load_test_data("test_item.json") - resp = await app_client.post( - f"/collections/{test_item['collection']}/items", json=test_item - ) - assert resp.status_code == 200 - - body = { - "fields": { - "exclude": ["properties.eo:cloud_cover"], - "include": ["properties.eo:cloud_cover", "collection"], - } - } - - resp = await app_client.post("/search", json=body) - resp_json = resp.json() - assert "properties" not in resp_json["features"][0] - - async def test_field_extension_exclude_default_includes( app_client, load_test_data, load_test_collection ): @@ -1133,14 +1111,18 @@ async def test_field_extension_include_multiple_subkeys( resp_json = resp.json() resp_prop_keys = resp_json["features"][0]["properties"].keys() - assert set(resp_prop_keys) == set(["width", "height"]) + assert set(["width", "height"]) <= set(resp_prop_keys) async def 
test_field_extension_include_multiple_deeply_nested_subkeys( app_client, load_test_item, load_test_collection ): """Test that multiple deeply nested subkeys of an object field are included""" - body = {"fields": {"include": ["assets.ANG.type", "assets.ANG.href"]}} + body = { + "fields": { + "include": ["assets.ANG.type", "assets.ANG.href"], + } + } resp = await app_client.post("/search", json=body) assert resp.status_code == 200 @@ -1184,7 +1166,7 @@ async def test_field_extension_exclude_deeply_nested_included_subkeys( resp_assets = resp_json["features"][0]["assets"] assert "type" in resp_assets["ANG"] - assert "href" not in resp_assets["ANG"] + assert "href" in resp_assets["ANG"] async def test_field_extension_exclude_links( @@ -1209,8 +1191,127 @@ async def test_field_extension_include_only_non_existant_field( resp = await app_client.post("/search", json=body) assert resp.status_code == 200 resp_json = resp.json() + feature = resp_json["features"][0] + + assert set(feature.keys()) == { + "id", + "stac_version", + "type", + "geometry", + "bbox", + "assets", + "properties", + "collection", + "stac_extensions", + "links", + } + assert set(feature["properties"].keys()) == {"datetime"} + + +@pytest.mark.parametrize( + "fields", ({}, {"include": None, "exclude": None}, {"include": [], "exclude": []}) +) +async def test_field_extension_default_includes( + app_client, load_test_item, load_test_collection, fields +): + """Per https://github.com/stac-api-extensions/fields#includeexclude-semantics""" + body = {"fields": fields} + + resp = await app_client.post("/search", json=body) + assert resp.status_code == 200 + resp_json = resp.json() + feature = resp_json["features"][0] + + assert set(feature.keys()) == { + "id", + "stac_version", + "type", + "geometry", + "bbox", + "assets", + "properties", + "collection", + "stac_extensions", + "links", + } + assert set(feature["properties"].keys()) == {"datetime"} + - assert list(resp_json["features"][0].keys()) == ["id", 
"collection", "links"] +async def test_field_extension_single_include( + app_client, load_test_item, load_test_collection +): + """Per https://github.com/stac-api-extensions/fields#includeexclude-semantics""" + body = {"fields": {"include": ["properties.gsd"]}} + + resp = await app_client.post("/search", json=body) + assert resp.status_code == 200 + resp_json = resp.json() + feature = resp_json["features"][0] + + assert set(feature.keys()) == { + "id", + "stac_version", + "type", + "geometry", + "bbox", + "assets", + "properties", + "collection", + "stac_extensions", + "links", + } + assert set(feature["properties"].keys()) == {"datetime", "gsd"} + + +async def test_field_extension_single_exclude( + app_client, load_test_item, load_test_collection +): + """Per https://github.com/stac-api-extensions/fields#includeexclude-semantics""" + body = {"fields": {"exclude": ["assets"]}} + + resp = await app_client.post("/search", json=body) + assert resp.status_code == 200 + resp_json = resp.json() + feature = resp_json["features"][0] + + assert set(feature.keys()) == { + "id", + "stac_version", + "type", + "geometry", + "bbox", + "properties", + "collection", + "stac_extensions", + "links", + } + assert set(feature["properties"].keys()) == {"datetime"} + + +async def test_field_extension_include_and_exclude( + app_client, load_test_item, load_test_collection +): + """Per https://github.com/stac-api-extensions/fields#includeexclude-semantics""" + body = {"fields": {"include": ["assets"], "exclude": ["assets"]}} + + resp = await app_client.post("/search", json=body) + assert resp.status_code == 200 + resp_json = resp.json() + feature = resp_json["features"][0] + + assert set(feature.keys()) == { + "id", + "stac_version", + "type", + "geometry", + "assets", + "bbox", + "properties", + "collection", + "stac_extensions", + "links", + } + assert set(feature["properties"].keys()) == {"datetime"} async def test_search_intersects_and_bbox(app_client): diff --git 
a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py index 68995d209..1a4e2624d 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py @@ -20,6 +20,7 @@ from stac_pydantic.links import Relations from stac_pydantic.shared import MimeTypes +from stac_fastapi.extensions.core.fields.request import PostFieldsExtension from stac_fastapi.sqlalchemy import serializers from stac_fastapi.sqlalchemy.extensions.query import Operator from stac_fastapi.sqlalchemy.models import database @@ -281,7 +282,7 @@ def get_search( ) base_args["sortby"] = sort_param - if fields: + if fields is not None: includes = set() excludes = set() for field in fields: @@ -291,7 +292,13 @@ def get_search( includes.add(field[1:]) else: includes.add(field) - base_args["fields"] = {"include": includes, "exclude": excludes} + post_fields = PostFieldsExtension( + include=includes, exclude=excludes + ).into_recommended() + base_args["fields"] = { + "include": post_fields.include, + "exclude": post_fields.exclude, + } # Do the request try: @@ -481,7 +488,10 @@ def post_search( ) # Use pydantic includes/excludes syntax to implement fields extension - if self.extension_is_enabled("FieldsExtension"): + if ( + self.extension_is_enabled("FieldsExtension") + and search_request.fields is not None + ): if search_request.query is not None: query_include: Set[str] = set( [ diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index 6fdbb6ed8..f78c15e16 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -136,18 +136,6 @@ def test_app_context_extension(load_test_data, app_client, postgres_transactions assert resp_json["context"]["returned"] == resp_json["context"]["matched"] == 1 -def test_app_fields_extension(load_test_data, app_client, postgres_transactions): - item = 
load_test_data("test_item.json") - postgres_transactions.create_item( - item["collection"], item, request=MockStarletteRequest - ) - - resp = app_client.get("/search", params={"collections": ["test-collection"]}) - assert resp.status_code == 200 - resp_json = resp.json() - assert list(resp_json["features"][0]["properties"]) == ["datetime"] - - def test_app_query_extension_gt(load_test_data, app_client, postgres_transactions): test_item = load_test_data("test_item.json") postgres_transactions.create_item( diff --git a/stac_fastapi/sqlalchemy/tests/resources/test_item.py b/stac_fastapi/sqlalchemy/tests/resources/test_item.py index b44aa9afc..d074c5f20 100644 --- a/stac_fastapi/sqlalchemy/tests/resources/test_item.py +++ b/stac_fastapi/sqlalchemy/tests/resources/test_item.py @@ -8,6 +8,7 @@ from urllib.parse import parse_qs, urlparse, urlsplit import pystac +import pytest from pydantic.datetime_parse import parse_datetime from pystac.utils import datetime_to_str from shapely.geometry import Polygon @@ -844,39 +845,171 @@ def test_field_extension_post(app_client, load_test_data): } -def test_field_extension_exclude_and_include(app_client, load_test_data): - """Test POST search including/excluding same field (fields extension)""" +def test_field_extension_exclude_default_includes(app_client, load_test_data): + """Test POST search excluding a forbidden field (fields extension)""" test_item = load_test_data("test_item.json") resp = app_client.post( f"/collections/{test_item['collection']}/items", json=test_item ) assert resp.status_code == 200 - body = { - "fields": { - "exclude": ["properties.eo:cloud_cover"], - "include": ["properties.eo:cloud_cover"], - } + body = {"fields": {"exclude": ["geometry"]}} + + resp = app_client.post("/search", json=body) + resp_json = resp.json() + assert "geometry" not in resp_json["features"][0] + + +@pytest.mark.parametrize( + "fields", ({}, {"include": None, "exclude": None}, {"include": [], "exclude": []}) +) +def 
test_field_extension_default_includes(app_client, load_test_data, fields): + """Per https://github.com/stac-api-extensions/fields#includeexclude-semantics""" + test_item = load_test_data("test_item.json") + resp = app_client.post( + f"/collections/{test_item['collection']}/items", json=test_item + ) + assert resp.status_code == 200 + + body = {"fields": fields} + + resp = app_client.post("/search", json=body) + assert resp.status_code == 200 + resp_json = resp.json() + feature = resp_json["features"][0] + + assert set(feature.keys()) == { + "id", + "stac_version", + "type", + "geometry", + "bbox", + "assets", + "properties", + "collection", + "stac_extensions", + "links", } + assert set(feature["properties"].keys()) == {"datetime"} + + +def test_field_extension_get_default_includes(app_client, load_test_data): + """Per https://github.com/stac-api-extensions/fields#includeexclude-semantics""" + test_item = load_test_data("test_item.json") + resp = app_client.post( + f"/collections/{test_item['collection']}/items", json=test_item + ) + assert resp.status_code == 200 + + resp = app_client.get( + "/search?fields=-gsd" + ) # tough to get an empty string in to fields, so we just exclude one + assert resp.status_code == 200 + resp_json = resp.json() + feature = resp_json["features"][0] + + assert set(feature.keys()) == { + "id", + "stac_version", + "type", + "geometry", + "bbox", + "assets", + "properties", + "collection", + "stac_extensions", + "links", + } + assert set(feature["properties"].keys()) == {"datetime"} + + +def test_field_extension_single_include(app_client, load_test_data): + """Per https://github.com/stac-api-extensions/fields#includeexclude-semantics""" + test_item = load_test_data("test_item.json") + resp = app_client.post( + f"/collections/{test_item['collection']}/items", json=test_item + ) + assert resp.status_code == 200 + + body = {"fields": {"include": ["properties.gsd"]}} resp = app_client.post("/search", json=body) + assert resp.status_code == 
200 resp_json = resp.json() - assert "eo:cloud_cover" not in resp_json["features"][0]["properties"] + feature = resp_json["features"][0] + + assert set(feature.keys()) == { + "id", + "stac_version", + "type", + "geometry", + "bbox", + "assets", + "properties", + "collection", + "stac_extensions", + "links", + } + assert set(feature["properties"].keys()) == {"datetime", "gsd"} -def test_field_extension_exclude_default_includes(app_client, load_test_data): - """Test POST search excluding a forbidden field (fields extension)""" +def test_field_extension_single_exclude(app_client, load_test_data): + """Per https://github.com/stac-api-extensions/fields#includeexclude-semantics""" test_item = load_test_data("test_item.json") resp = app_client.post( f"/collections/{test_item['collection']}/items", json=test_item ) assert resp.status_code == 200 - body = {"fields": {"exclude": ["geometry"]}} + body = {"fields": {"exclude": ["assets"]}} resp = app_client.post("/search", json=body) + assert resp.status_code == 200 resp_json = resp.json() - assert "geometry" not in resp_json["features"][0] + feature = resp_json["features"][0] + + assert set(feature.keys()) == { + "id", + "stac_version", + "type", + "geometry", + "bbox", + "properties", + "collection", + "stac_extensions", + "links", + } + assert set(feature["properties"].keys()) == {"datetime"} + + +def test_field_extension_include_and_exclude(app_client, load_test_data): + """Per https://github.com/stac-api-extensions/fields#includeexclude-semantics""" + test_item = load_test_data("test_item.json") + resp = app_client.post( + f"/collections/{test_item['collection']}/items", json=test_item + ) + assert resp.status_code == 200 + + body = {"fields": {"include": ["assets"], "exclude": ["assets"]}} + + resp = app_client.post("/search", json=body) + assert resp.status_code == 200 + resp_json = resp.json() + feature = resp_json["features"][0] + + assert set(feature.keys()) == { + "id", + "stac_version", + "type", + "geometry", + 
"assets", + "bbox", + "properties", + "collection", + "stac_extensions", + "links", + } + assert set(feature["properties"].keys()) == {"datetime"} def test_search_intersects_and_bbox(app_client):