Skip to content

Commit bd13682

Browse files
fallback to all_collections when CollectionSearchExtension is not enabled (#179)
* fallback to all_collections when `CollectionSearchExtension` is not enabled * test all_collection fallback
1 parent 9293dd8 commit bd13682

File tree

3 files changed

+136
-33
lines changed

3 files changed

+136
-33
lines changed

stac_fastapi/pgstac/core.py

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -65,42 +65,51 @@ async def all_collections( # noqa: C901
6565
"""
6666
base_url = get_base_url(request)
6767

68-
# Parse request parameters
69-
base_args = {
70-
"bbox": bbox,
71-
"limit": limit,
72-
"offset": offset,
73-
"query": orjson.loads(unquote_plus(query)) if query else query,
74-
}
68+
next_link: Optional[Dict[str, Any]] = None
69+
prev_link: Optional[Dict[str, Any]] = None
70+
collections_result: Collections
71+
72+
if self.extension_is_enabled("CollectionSearchExtension"):
73+
base_args = {
74+
"bbox": bbox,
75+
"limit": limit,
76+
"offset": offset,
77+
"query": orjson.loads(unquote_plus(query)) if query else query,
78+
}
79+
80+
clean_args = clean_search_args(
81+
base_args=base_args,
82+
datetime=datetime,
83+
fields=fields,
84+
sortby=sortby,
85+
filter_query=filter,
86+
filter_lang=filter_lang,
87+
)
7588

76-
clean_args = clean_search_args(
77-
base_args=base_args,
78-
datetime=datetime,
79-
fields=fields,
80-
sortby=sortby,
81-
filter_query=filter,
82-
filter_lang=filter_lang,
83-
)
89+
async with request.app.state.get_connection(request, "r") as conn:
90+
q, p = render(
91+
"""
92+
SELECT * FROM collection_search(:req::text::jsonb);
93+
""",
94+
req=json.dumps(clean_args),
95+
)
96+
collections_result = await conn.fetchval(q, *p)
8497

85-
async with request.app.state.get_connection(request, "r") as conn:
86-
q, p = render(
87-
"""
88-
SELECT * FROM collection_search(:req::text::jsonb);
89-
""",
90-
req=json.dumps(clean_args),
91-
)
92-
collections_result: Collections = await conn.fetchval(q, *p)
98+
if links := collections_result.get("links"):
99+
for link in links:
100+
if link["rel"] == "next":
101+
next_link = link
102+
elif link["rel"] == "prev":
103+
prev_link = link
93104

94-
next_link: Optional[Dict[str, Any]] = None
95-
prev_link: Optional[Dict[str, Any]] = None
96-
if links := collections_result.get("links"):
97-
next_link = None
98-
prev_link = None
99-
for link in links:
100-
if link["rel"] == "next":
101-
next_link = link
102-
elif link["rel"] == "prev":
103-
prev_link = link
105+
else:
106+
async with request.app.state.get_connection(request, "r") as conn:
107+
cols = await conn.fetchval(
108+
"""
109+
SELECT * FROM all_collections();
110+
"""
111+
)
112+
collections_result = {"collections": cols, "links": []}
104113

105114
linked_collections: List[Collection] = []
106115
collections = collections_result["collections"]

tests/conftest.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,48 @@ async def load_test2_item(app_client, load_test_data, load_test2_collection):
268268
)
269269
assert resp.status_code == 201
270270
return Item.model_validate(resp.json())
271+
272+
273+
@pytest.fixture(
274+
scope="session",
275+
)
276+
def api_client_no_ext(database):
277+
api_settings = Settings(
278+
postgres_user=database.user,
279+
postgres_pass=database.password,
280+
postgres_host_reader=database.host,
281+
postgres_host_writer=database.host,
282+
postgres_port=database.port,
283+
postgres_dbname=database.dbname,
284+
testing=True,
285+
)
286+
return StacApi(
287+
settings=api_settings,
288+
extensions=[
289+
TransactionExtension(client=TransactionsClient(), settings=api_settings)
290+
],
291+
client=CoreCrudClient(),
292+
)
293+
294+
295+
@pytest.fixture(scope="function")
296+
async def app_no_ext(api_client_no_ext):
297+
logger.info("Creating app Fixture")
298+
time.time()
299+
app = api_client_no_ext.app
300+
await connect_to_db(app)
301+
302+
yield app
303+
304+
await close_db_connection(app)
305+
306+
logger.info("Closed Pools.")
307+
308+
309+
@pytest.fixture(scope="function")
310+
async def app_client_no_ext(app_no_ext):
311+
logger.info("creating app_client")
312+
async with AsyncClient(
313+
transport=ASGITransport(app=app_no_ext), base_url="http://test"
314+
) as c:
315+
yield c

tests/resources/test_collection.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,55 @@ async def test_get_collections_search(
307307
assert len(resp.json()["collections"]) == 2
308308

309309

310+
@requires_pgstac_0_9_2
311+
@pytest.mark.asyncio
312+
async def test_all_collections_with_pagination(app_client, load_test_data):
313+
data = load_test_data("test_collection.json")
314+
collection_id = data["id"]
315+
for ii in range(0, 12):
316+
data["id"] = collection_id + f"_{ii}"
317+
resp = await app_client.post(
318+
"/collections",
319+
json=data,
320+
)
321+
assert resp.status_code == 201
322+
323+
resp = await app_client.get("/collections")
324+
cols = resp.json()["collections"]
325+
assert len(cols) == 10
326+
links = resp.json()["links"]
327+
assert len(links) == 3
328+
assert {"root", "self", "next"} == {link["rel"] for link in links}
329+
330+
resp = await app_client.get("/collections", params={"limit": 12})
331+
cols = resp.json()["collections"]
332+
assert len(cols) == 12
333+
links = resp.json()["links"]
334+
assert len(links) == 2
335+
assert {"root", "self"} == {link["rel"] for link in links}
336+
337+
338+
@requires_pgstac_0_9_2
339+
@pytest.mark.asyncio
340+
async def test_all_collections_without_pagination(app_client_no_ext, load_test_data):
341+
data = load_test_data("test_collection.json")
342+
collection_id = data["id"]
343+
for ii in range(0, 12):
344+
data["id"] = collection_id + f"_{ii}"
345+
resp = await app_client_no_ext.post(
346+
"/collections",
347+
json=data,
348+
)
349+
assert resp.status_code == 201
350+
351+
resp = await app_client_no_ext.get("/collections")
352+
cols = resp.json()["collections"]
353+
assert len(cols) == 12
354+
links = resp.json()["links"]
355+
assert len(links) == 2
356+
assert {"root", "self"} == {link["rel"] for link in links}
357+
358+
310359
@requires_pgstac_0_9_2
311360
@pytest.mark.asyncio
312361
async def test_get_collections_search_pagination(

0 commit comments

Comments
 (0)