From 4da7a6d0bffdfd9f16eb83fd96a6e6fa0d7d9265 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Thu, 13 Jun 2024 19:58:30 +0200 Subject: [PATCH] add tests for FieldsExtension impact on validation --- stac_fastapi/api/tests/test_app.py | 108 ++++++++++++++++++++++++++++- 1 file changed, 106 insertions(+), 2 deletions(-) diff --git a/stac_fastapi/api/tests/test_app.py b/stac_fastapi/api/tests/test_app.py index 062575915..829982b51 100644 --- a/stac_fastapi/api/tests/test_app.py +++ b/stac_fastapi/api/tests/test_app.py @@ -8,10 +8,10 @@ from stac_fastapi.api import app from stac_fastapi.api.models import create_get_request_model, create_post_request_model -from stac_fastapi.extensions.core.filter.filter import FilterExtension +from stac_fastapi.extensions.core import FieldsExtension, FilterExtension from stac_fastapi.types import stac from stac_fastapi.types.config import ApiSettings -from stac_fastapi.types.core import NumType +from stac_fastapi.types.core import BaseCoreClient, NumType from stac_fastapi.types.search import BaseSearchPostRequest @@ -190,3 +190,107 @@ def get_search( assert landing.status_code == 200, landing.text assert get_search.status_code == 200, get_search.text assert post_search.status_code == 200, post_search.text + + +@pytest.mark.parametrize("validate", [True, False]) +def test_fields_extension(validate, TestCoreClient, item_dict): + """Test if fields Parameters are passed correctly.""" + + class BadCoreClient(BaseCoreClient): + def post_search( + self, search_request: BaseSearchPostRequest, **kwargs + ) -> stac.ItemCollection: + return {"not": "a proper stac item"} + + def get_search( + self, + collections: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + bbox: Optional[List[NumType]] = None, + intersects: Optional[str] = None, + datetime: Optional[Union[str, datetime]] = None, + limit: Optional[int] = 10, + **kwargs, + ) -> stac.ItemCollection: + return {"not": "a proper stac item"} + + def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item: + raise NotImplementedError + + def all_collections(self, **kwargs) -> stac.Collections: + raise NotImplementedError + + def get_collection(self, collection_id: str, **kwargs) -> stac.Collection: + raise NotImplementedError + + def item_collection( + self, + collection_id: str, + bbox: Optional[List[Union[float, int]]] = None, + datetime: Optional[Union[str, datetime]] = None, + limit: int = 10, + token: str = None, + **kwargs, + ) -> stac.ItemCollection: + raise NotImplementedError + + test_app = app.StacApi( + settings=ApiSettings(enable_response_models=validate), + client=BadCoreClient(), + search_get_request_model=create_get_request_model([FieldsExtension()]), + search_post_request_model=create_post_request_model([FieldsExtension()]), + extensions=[FieldsExtension()], + ) + + with TestClient(test_app.app) as client: + get_search = client.get( + "/search", + params={"fields": "properties.datetime"}, + ) + post_search = client.post( + "/search", + json={ + "collections": ["test"], + "fields": { + "include": ["properties.datetime"], + "exclude": [], + }, + }, + ) + + assert get_search.status_code == 200, get_search.text + assert post_search.status_code == 200, post_search.text + + test_app = app.StacApi( + settings=ApiSettings(enable_response_models=validate), + client=BadCoreClient(), + search_get_request_model=create_get_request_model([FieldsExtension()]), + search_post_request_model=create_post_request_model([FieldsExtension()]), + extensions=[], + ) + + with TestClient(test_app.app) as client: + get_search = client.get( + "/search", + params={"fields": "properties.datetime"}, + ) + post_search = client.post( + "/search", + json={ + "collections": ["test"], + "fields": { + "include": ["properties.datetime"], + "exclude": [], + }, + }, + ) + if validate: + assert get_search.status_code == 500, ( + get_search.json()["code"] == "ResponseValidationError" + ) + assert post_search.status_code == 500, ( + post_search.json()["code"] == "ResponseValidationError" + ) + else: + assert get_search.status_code == 200, get_search.text + assert post_search.status_code == 200, post_search.text