Skip to content

Commit 80064f7

Browse files
add tests for FieldsExtension impact on validation (#708)
Co-authored-by: Jonathan Healy <[email protected]>
1 parent 68dfbd5 commit 80064f7

File tree

1 file changed

+106
-2
lines changed

1 file changed

+106
-2
lines changed

stac_fastapi/api/tests/test_app.py

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
from stac_fastapi.api import app
1010
from stac_fastapi.api.models import create_get_request_model, create_post_request_model
11-
from stac_fastapi.extensions.core.filter.filter import FilterExtension
11+
from stac_fastapi.extensions.core import FieldsExtension, FilterExtension
1212
from stac_fastapi.types import stac
1313
from stac_fastapi.types.config import ApiSettings
14-
from stac_fastapi.types.core import NumType
14+
from stac_fastapi.types.core import BaseCoreClient, NumType
1515
from stac_fastapi.types.search import BaseSearchPostRequest
1616

1717

@@ -190,3 +190,107 @@ def get_search(
190190
assert landing.status_code == 200, landing.text
191191
assert get_search.status_code == 200, get_search.text
192192
assert post_search.status_code == 200, post_search.text
193+
194+
195+
@pytest.mark.parametrize("validate", [True, False])
196+
def test_fields_extension(validate, TestCoreClient, item_dict):
197+
"""Test if fields Parameters are passed correctly."""
198+
199+
class BadCoreClient(BaseCoreClient):
200+
def post_search(
201+
self, search_request: BaseSearchPostRequest, **kwargs
202+
) -> stac.ItemCollection:
203+
return {"not": "a proper stac item"}
204+
205+
def get_search(
206+
self,
207+
collections: Optional[List[str]] = None,
208+
ids: Optional[List[str]] = None,
209+
bbox: Optional[List[NumType]] = None,
210+
intersects: Optional[str] = None,
211+
datetime: Optional[Union[str, datetime]] = None,
212+
limit: Optional[int] = 10,
213+
**kwargs,
214+
) -> stac.ItemCollection:
215+
return {"not": "a proper stac item"}
216+
217+
def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item:
218+
raise NotImplementedError
219+
220+
def all_collections(self, **kwargs) -> stac.Collections:
221+
raise NotImplementedError
222+
223+
def get_collection(self, collection_id: str, **kwargs) -> stac.Collection:
224+
raise NotImplementedError
225+
226+
def item_collection(
227+
self,
228+
collection_id: str,
229+
bbox: Optional[List[Union[float, int]]] = None,
230+
datetime: Optional[Union[str, datetime]] = None,
231+
limit: int = 10,
232+
token: str = None,
233+
**kwargs,
234+
) -> stac.ItemCollection:
235+
raise NotImplementedError
236+
237+
test_app = app.StacApi(
238+
settings=ApiSettings(enable_response_models=validate),
239+
client=BadCoreClient(),
240+
search_get_request_model=create_get_request_model([FieldsExtension()]),
241+
search_post_request_model=create_post_request_model([FieldsExtension()]),
242+
extensions=[FieldsExtension()],
243+
)
244+
245+
with TestClient(test_app.app) as client:
246+
get_search = client.get(
247+
"/search",
248+
params={"fields": "properties.datetime"},
249+
)
250+
post_search = client.post(
251+
"/search",
252+
json={
253+
"collections": ["test"],
254+
"fields": {
255+
"include": ["properties.datetime"],
256+
"exclude": [],
257+
},
258+
},
259+
)
260+
261+
assert get_search.status_code == 200, get_search.text
262+
assert post_search.status_code == 200, post_search.text
263+
264+
test_app = app.StacApi(
265+
settings=ApiSettings(enable_response_models=validate),
266+
client=BadCoreClient(),
267+
search_get_request_model=create_get_request_model([FieldsExtension()]),
268+
search_post_request_model=create_post_request_model([FieldsExtension()]),
269+
extensions=[],
270+
)
271+
272+
with TestClient(test_app.app) as client:
273+
get_search = client.get(
274+
"/search",
275+
params={"fields": "properties.datetime"},
276+
)
277+
post_search = client.post(
278+
"/search",
279+
json={
280+
"collections": ["test"],
281+
"fields": {
282+
"include": ["properties.datetime"],
283+
"exclude": [],
284+
},
285+
},
286+
)
287+
if validate:
288+
assert get_search.status_code == 500, (
289+
get_search.json()["code"] == "ResponseValidationError"
290+
)
291+
assert post_search.status_code == 500, (
292+
post_search.json()["code"] == "ResponseValidationError"
293+
)
294+
else:
295+
assert get_search.status_code == 200, get_search.text
296+
assert post_search.status_code == 200, post_search.text

0 commit comments

Comments
 (0)