From 2fd8c8f4168906b16da37cfd173fe283f19dccc5 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Thu, 10 Jul 2025 12:04:45 +0200 Subject: [PATCH 1/5] fix type for advanced freetext --- CHANGES.md | 4 ++++ stac_fastapi/pgstac/core.py | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index b838437b..4bffb411 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -52,6 +52,10 @@ - `writer_connection_string` in `PostgresSettings` class - `testing_connection_string` in `PostgresSettings` class +### Fixed + +- Allow the `q` parameter to be a `str` as well as a `list[str]`, as needed by the Advanced Free-Text extension + ## [5.0.2] - 2025-04-07 ### Fixed diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 7854ad0d..aaf5db56 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -54,7 +54,7 @@ async def all_collections( # noqa: C901 sortby: Optional[str] = None, filter_expr: Optional[str] = None, filter_lang: Optional[str] = None, - q: Optional[List[str]] = None, + q: Optional[Union[str, List[str]]] = None, **kwargs, ) -> Collections: """Cross catalog search (GET). 
@@ -550,7 +550,7 @@ def _clean_search_args( # noqa: C901 sortby: Optional[str] = None, filter_query: Optional[str] = None, filter_lang: Optional[str] = None, - q: Optional[List[str]] = None, + q: Optional[Union[str, List[str]]] = None, ) -> Dict[str, Any]: """Clean up search arguments to match format expected by pgstac""" if filter_query: @@ -596,7 +596,7 @@ def _clean_search_args( # noqa: C901 base_args["fields"] = {"include": includes, "exclude": excludes} if q: - base_args["q"] = " OR ".join(q) + base_args["q"] = " OR ".join(q) if isinstance(q, list) else q # Remove None values from dict clean = {} From c0767f3f16a743f395dfb5d86fee292ee4d86452 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Thu, 24 Jul 2025 16:32:25 +0200 Subject: [PATCH 2/5] add tests and remove free-text from method annotations --- stac_fastapi/pgstac/core.py | 32 +++++++++++++++++++++--------- tests/resources/test_collection.py | 26 ++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index aaf5db56..948dca29 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -54,8 +54,7 @@ async def all_collections( # noqa: C901 sortby: Optional[str] = None, filter_expr: Optional[str] = None, filter_lang: Optional[str] = None, - q: Optional[Union[str, List[str]]] = None, - **kwargs, + **kwargs: Any, ) -> Collections: """Cross catalog search (GET). @@ -86,7 +85,7 @@ async def all_collections( # noqa: C901 sortby=sortby, filter_query=filter_expr, filter_lang=filter_lang, - q=q, + **kwargs, ) async with request.app.state.get_connection(request, "r") as conn: @@ -157,7 +156,10 @@ async def all_collections( # noqa: C901 ) async def get_collection( - self, collection_id: str, request: Request, **kwargs + self, + collection_id: str, + request: Request, + **kwargs: Any, ) -> Collection: """Get collection by id. 
@@ -202,7 +204,9 @@ async def get_collection( return Collection(**collection) async def _get_base_item( - self, collection_id: str, request: Request + self, + collection_id: str, + request: Request, ) -> Dict[str, Any]: """Get the base item of a collection for use in rehydrating full item collection properties. @@ -359,7 +363,7 @@ async def item_collection( filter_expr: Optional[str] = None, filter_lang: Optional[str] = None, token: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> ItemCollection: """Get all items from a specific collection. @@ -391,6 +395,7 @@ async def item_collection( filter_lang=filter_lang, fields=fields, sortby=sortby, + **kwargs, ) try: @@ -417,7 +422,11 @@ async def item_collection( return ItemCollection(**item_collection) async def get_item( - self, item_id: str, collection_id: str, request: Request, **kwargs + self, + item_id: str, + collection_id: str, + request: Request, + **kwargs: Any, ) -> Item: """Get item by id. @@ -445,7 +454,10 @@ async def get_item( return Item(**item_collection["features"][0]) async def post_search( - self, search_request: PgstacSearch, request: Request, **kwargs + self, + search_request: PgstacSearch, + request: Request, + **kwargs: Any, ) -> ItemCollection: """Cross catalog search (POST). @@ -489,7 +501,7 @@ async def get_search( filter_expr: Optional[str] = None, filter_lang: Optional[str] = None, token: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> ItemCollection: """Cross catalog search (GET). 
@@ -516,6 +528,7 @@ async def get_search( sortby=sortby, filter_query=filter_expr, filter_lang=filter_lang, + **kwargs, ) try: @@ -551,6 +564,7 @@ def _clean_search_args( # noqa: C901 filter_query: Optional[str] = None, filter_lang: Optional[str] = None, q: Optional[Union[str, List[str]]] = None, + **kwargs: Any, ) -> Dict[str, Any]: """Clean up search arguments to match format expected by pgstac""" if filter_query: diff --git a/tests/resources/test_collection.py b/tests/resources/test_collection.py index 745d4230..62405725 100644 --- a/tests/resources/test_collection.py +++ b/tests/resources/test_collection.py @@ -364,6 +364,32 @@ async def test_collection_search_freetext( assert len(resp.json()["collections"]) == 1 assert resp.json()["collections"][0]["id"] == load_test2_collection.id + resp = await app_client.get( + "/collections", + params={"q": "temperature,calibrated"}, + ) + assert resp.json()["numberReturned"] == 2 + assert resp.json()["numberMatched"] == 2 + assert len(resp.json()["collections"]) == 2 + + resp = await app_client.get( + "/collections", + params={"q": "temperature,yo"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["numberMatched"] == 1 + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + resp = await app_client.get( + "/collections", + params={"q": "temperature OR yo"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["numberMatched"] == 1 + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + resp = await app_client.get( "/collections", params={"q": "nosuchthing"}, From bbf0cb59f3533ee06825620085d735ac2e58192b Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Thu, 24 Jul 2025 16:54:34 +0200 Subject: [PATCH 3/5] add advanced tests --- tests/conftest.py | 56 ++++++++++++++++++++++++++++++ tests/resources/test_collection.py | 41 +++++++++++++++++++++- 2 files 
changed, 96 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 05846bec..73da7eec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,6 +26,7 @@ CollectionSearchExtension, CollectionSearchFilterExtension, FieldsExtension, + FreeTextAdvancedExtension, FreeTextExtension, ItemCollectionFilterExtension, OffsetPaginationExtension, @@ -402,3 +403,58 @@ async def default_client(default_app): transport=ASGITransport(app=default_app), base_url="http://test" ) as c: yield c + + +@pytest.fixture(scope="function") +async def app_advanced_freetext(database): + """stac-fastapi-pgstac application with the transaction extension and a collection-search extension built from the Advanced Free-Text and offset-pagination extensions.""" + api_settings = Settings(testing=True) + + application_extensions = [ + TransactionExtension(client=TransactionsClient(), settings=api_settings) + ] + + collection_extensions = [ + FreeTextAdvancedExtension(), + OffsetPaginationExtension(), + ] + collection_search_extension = CollectionSearchExtension.from_extensions( + collection_extensions + ) + application_extensions.append(collection_search_extension) + + app = StacApi( + settings=api_settings, + extensions=application_extensions, + client=CoreCrudClient(), + health_check=health_check, + collections_get_request_model=collection_search_extension.GET, + ) + + postgres_settings = PostgresSettings( + pguser=database.user, + pgpassword=database.password, + pghost=database.host, + pgport=database.port, + pgdatabase=database.dbname, + ) + logger.info("Creating app Fixture") + time.time() + await connect_to_db( + app.app, + postgres_settings=postgres_settings, + add_write_connection_pool=True, + ) + yield app.app + await close_db_connection(app.app) + + logger.info("Closed Pools.") + + +@pytest.fixture(scope="function") +async def app_client_advanced_freetext(app_advanced_freetext): + logger.info("creating app_client") + async with AsyncClient( + transport=ASGITransport(app=app_advanced_freetext), base_url="http://test" + ) as c: + yield c diff 
--git a/tests/resources/test_collection.py b/tests/resources/test_collection.py index 62405725..013f9baa 100644 --- a/tests/resources/test_collection.py +++ b/tests/resources/test_collection.py @@ -382,6 +382,45 @@ async def test_collection_search_freetext( assert resp.json()["collections"][0]["id"] == load_test2_collection.id resp = await app_client.get( + "/collections", + params={"q": "nosuchthing"}, + ) + assert len(resp.json()["collections"]) == 0 + + +@requires_pgstac_0_9_2 +@pytest.mark.asyncio +async def test_collection_search_freetext_advanced( + app_client_advanced_freetext, load_test_collection, load_test2_collection +): + # free-text + resp = await app_client_advanced_freetext.get( + "/collections", + params={"q": "temperature"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["numberMatched"] == 1 + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + resp = await app_client_advanced_freetext.get( + "/collections", + params={"q": "temperature,calibrated"}, + ) + assert resp.json()["numberReturned"] == 2 + assert resp.json()["numberMatched"] == 2 + assert len(resp.json()["collections"]) == 2 + + resp = await app_client_advanced_freetext.get( + "/collections", + params={"q": "temperature,yo"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["numberMatched"] == 1 + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + resp = await app_client_advanced_freetext.get( "/collections", params={"q": "temperature OR yo"}, ) @@ -390,7 +429,7 @@ async def test_collection_search_freetext( assert len(resp.json()["collections"]) == 1 assert resp.json()["collections"][0]["id"] == load_test2_collection.id - resp = await app_client.get( + resp = await app_client_advanced_freetext.get( "/collections", params={"q": "nosuchthing"}, ) From b5fdaba1da9e162d5e4fb655186f44e5b4ae39a3 Mon Sep 17 00:00:00 
2001 From: vincentsarago Date: Fri, 25 Jul 2025 10:06:02 +0200 Subject: [PATCH 4/5] add failing tests --- tests/conftest.py | 7 ++--- tests/resources/test_item.py | 55 ++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 73da7eec..f9afd2b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,6 @@ import json import logging import os -import time from typing import Callable, Dict from urllib.parse import quote_plus as quote from urllib.parse import urljoin @@ -140,6 +139,7 @@ def api_client(request): FieldsExtension(), SearchFilterExtension(client=FiltersClient()), TokenPaginationExtension(), + FreeTextExtension(), # not recommended by PgSTAC ] application_extensions.extend(search_extensions) @@ -168,6 +168,7 @@ def api_client(request): FieldsExtension(conformance_classes=[FieldsConformanceClasses.ITEMS]), ItemCollectionFilterExtension(client=FiltersClient()), TokenPaginationExtension(), + FreeTextExtension(), # not recommended by PgSTAC ] application_extensions.extend(item_collection_extensions) @@ -208,7 +209,6 @@ async def app(api_client, database): pgdatabase=database.dbname, ) logger.info("Creating app Fixture") - time.time() app = api_client.app await connect_to_db( app, @@ -315,7 +315,6 @@ async def app_no_ext(database): pgdatabase=database.dbname, ) logger.info("Creating app Fixture") - time.time() await connect_to_db( api_client_no_ext.app, postgres_settings=postgres_settings, @@ -355,7 +354,6 @@ async def app_no_transaction(database): pgdatabase=database.dbname, ) logger.info("Creating app Fixture") - time.time() await connect_to_db( api.app, postgres_settings=postgres_settings, @@ -439,7 +437,6 @@ async def app_advanced_freetext(database): pgdatabase=database.dbname, ) logger.info("Creating app Fixture") - time.time() await connect_to_db( app.app, postgres_settings=postgres_settings, diff --git a/tests/resources/test_item.py 
b/tests/resources/test_item.py index 4ea70193..65112ed5 100644 --- a/tests/resources/test_item.py +++ b/tests/resources/test_item.py @@ -18,6 +18,8 @@ from stac_fastapi.pgstac.models.links import CollectionLinks +from ..conftest import requires_pgstac_0_9_2 + async def test_create_collection(app_client, load_test_data: Callable): in_json = load_test_data("test_collection.json") @@ -1689,3 +1691,56 @@ async def test_get_search_link_media(app_client): assert len(links) == 2 get_self_link = next((link for link in links if link["rel"] == "self"), None) assert get_self_link["type"] == "application/geo+json" + + +@requires_pgstac_0_9_2 +@pytest.mark.asyncio +async def test_item_search_freetext(app_client, load_test_data, load_test_collection): + test_item = load_test_data("test_item.json") + resp = await app_client.post( + f"/collections/{test_item['collection']}/items", json=test_item + ) + assert resp.status_code == 201 + + # free-text + resp = await app_client.get( + "/search", + params={"q": "temperature"}, + ) + print(resp.json()) + # assert resp.json()["numberReturned"] == 1 + # assert resp.json()["numberMatched"] == 1 + # assert len(resp.json()["collections"]) == 1 + # assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + # resp = await app_client_advanced_freetext.get( + # "/collections", + # params={"q": "temperature,calibrated"}, + # ) + # assert resp.json()["numberReturned"] == 2 + # assert resp.json()["numberMatched"] == 2 + # assert len(resp.json()["collections"]) == 2 + + # resp = await app_client_advanced_freetext.get( + # "/collections", + # params={"q": "temperature,yo"}, + # ) + # assert resp.json()["numberReturned"] == 1 + # assert resp.json()["numberMatched"] == 1 + # assert len(resp.json()["collections"]) == 1 + # assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + # resp = await app_client_advanced_freetext.get( + # "/collections", + # params={"q": "temperature OR yo"}, + # ) + # assert 
resp.json()["numberReturned"] == 1 + # assert resp.json()["numberMatched"] == 1 + # assert len(resp.json()["collections"]) == 1 + # assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + # resp = await app_client_advanced_freetext.get( + # "/collections", + # params={"q": "nosuchthing"}, + # ) + # assert len(resp.json()["collections"]) == 0 From 947a9770b9ad96b18ebd1082341c1c8acb6c6df1 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Fri, 1 Aug 2025 10:49:13 +0200 Subject: [PATCH 5/5] fix and enable free-text for items --- stac_fastapi/pgstac/app.py | 3 +- stac_fastapi/pgstac/core.py | 8 ++- stac_fastapi/pgstac/extensions/__init__.py | 3 +- stac_fastapi/pgstac/extensions/free_text.py | 31 ++++++++++++ tests/conftest.py | 3 +- tests/data/test_item.json | 1 + tests/resources/test_item.py | 56 +++++++-------------- 7 files changed, 60 insertions(+), 45 deletions(-) create mode 100644 stac_fastapi/pgstac/extensions/free_text.py diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 5d42f769..7d93bfc0 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -24,7 +24,6 @@ CollectionSearchExtension, CollectionSearchFilterExtension, FieldsExtension, - FreeTextExtension, ItemCollectionFilterExtension, OffsetPaginationExtension, SearchFilterExtension, @@ -42,7 +41,7 @@ from stac_fastapi.pgstac.config import Settings from stac_fastapi.pgstac.core import CoreCrudClient, health_check from stac_fastapi.pgstac.db import close_db_connection, connect_to_db -from stac_fastapi.pgstac.extensions import QueryExtension +from stac_fastapi.pgstac.extensions import FreeTextExtension, QueryExtension from stac_fastapi.pgstac.extensions.filter import FiltersClient from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient from stac_fastapi.pgstac.types.search import PgstacSearch diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 948dca29..d159ba67 100644 --- 
a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -88,6 +88,12 @@ async def all_collections( # noqa: C901 **kwargs, ) + # NOTE: `FreeTextExtension` - pgstac will only accept `str` so we need to + # join the list[str] with ` OR ` + # ref: https://github.com/stac-utils/stac-fastapi-pgstac/pull/263 + if q := clean_args.pop("q", None): + clean_args["q"] = " OR ".join(q) if isinstance(q, list) else q + async with request.app.state.get_connection(request, "r") as conn: q, p = render( """ @@ -610,7 +616,7 @@ def _clean_search_args( # noqa: C901 base_args["fields"] = {"include": includes, "exclude": excludes} if q: - base_args["q"] = " OR ".join(q) if isinstance(q, list) else q + base_args["q"] = q # Remove None values from dict clean = {} diff --git a/stac_fastapi/pgstac/extensions/__init__.py b/stac_fastapi/pgstac/extensions/__init__.py index 00544179..6c2812b6 100644 --- a/stac_fastapi/pgstac/extensions/__init__.py +++ b/stac_fastapi/pgstac/extensions/__init__.py @@ -1,6 +1,7 @@ """pgstac extension customisations.""" from .filter import FiltersClient +from .free_text import FreeTextExtension from .query import QueryExtension -__all__ = ["QueryExtension", "FiltersClient"] +__all__ = ["QueryExtension", "FiltersClient", "FreeTextExtension"] diff --git a/stac_fastapi/pgstac/extensions/free_text.py b/stac_fastapi/pgstac/extensions/free_text.py new file mode 100644 index 00000000..cadab7fe --- /dev/null +++ b/stac_fastapi/pgstac/extensions/free_text.py @@ -0,0 +1,31 @@ +"""Free-Text model for PgSTAC.""" + +from typing import List, Optional + +from pydantic import BaseModel, Field +from pydantic.functional_serializers import PlainSerializer +from stac_fastapi.extensions.core.free_text import ( + FreeTextExtension as FreeTextExtensionBase, +) +from typing_extensions import Annotated + + +class FreeTextExtensionPostRequest(BaseModel): + """Free-text Extension POST request model.""" + + q: Annotated[ + Optional[List[str]], + PlainSerializer(lambda x: " OR 
".join(x), return_type=str, when_used="json"), + ] = Field( + None, + description="Parameter to perform free-text queries against STAC metadata", + ) + + +class FreeTextExtension(FreeTextExtensionBase): + """FreeText Extension. + + Override the POST request model to add custom serialization + """ + + POST = FreeTextExtensionPostRequest diff --git a/tests/conftest.py b/tests/conftest.py index f9afd2b8..d3495936 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,7 +26,6 @@ CollectionSearchFilterExtension, FieldsExtension, FreeTextAdvancedExtension, - FreeTextExtension, ItemCollectionFilterExtension, OffsetPaginationExtension, SearchFilterExtension, @@ -44,7 +43,7 @@ from stac_fastapi.pgstac.config import PostgresSettings, Settings from stac_fastapi.pgstac.core import CoreCrudClient, health_check from stac_fastapi.pgstac.db import close_db_connection, connect_to_db -from stac_fastapi.pgstac.extensions import QueryExtension +from stac_fastapi.pgstac.extensions import FreeTextExtension, QueryExtension from stac_fastapi.pgstac.extensions.filter import FiltersClient from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient from stac_fastapi.pgstac.types.search import PgstacSearch diff --git a/tests/data/test_item.json b/tests/data/test_item.json index 1c68b959..cac06d66 100644 --- a/tests/data/test_item.json +++ b/tests/data/test_item.json @@ -34,6 +34,7 @@ "type": "Polygon" }, "properties": { + "description": "Landat 8 imagery radiometrically calibrated and orthorectified using gound points and Digital Elevation Model (DEM) data to correct relief displacement.", "datetime": "2020-02-12T12:30:22Z", "landsat:scene_id": "LC82081612020043LGN00", "landsat:row": "161", diff --git a/tests/resources/test_item.py b/tests/resources/test_item.py index 65112ed5..490d652a 100644 --- a/tests/resources/test_item.py +++ b/tests/resources/test_item.py @@ -1705,42 +1705,20 @@ async def test_item_search_freetext(app_client, load_test_data, 
load_test_collec # free-text resp = await app_client.get( "/search", - params={"q": "temperature"}, - ) - print(resp.json()) - # assert resp.json()["numberReturned"] == 1 - # assert resp.json()["numberMatched"] == 1 - # assert len(resp.json()["collections"]) == 1 - # assert resp.json()["collections"][0]["id"] == load_test2_collection.id - - # resp = await app_client_advanced_freetext.get( - # "/collections", - # params={"q": "temperature,calibrated"}, - # ) - # assert resp.json()["numberReturned"] == 2 - # assert resp.json()["numberMatched"] == 2 - # assert len(resp.json()["collections"]) == 2 - - # resp = await app_client_advanced_freetext.get( - # "/collections", - # params={"q": "temperature,yo"}, - # ) - # assert resp.json()["numberReturned"] == 1 - # assert resp.json()["numberMatched"] == 1 - # assert len(resp.json()["collections"]) == 1 - # assert resp.json()["collections"][0]["id"] == load_test2_collection.id - - # resp = await app_client_advanced_freetext.get( - # "/collections", - # params={"q": "temperature OR yo"}, - # ) - # assert resp.json()["numberReturned"] == 1 - # assert resp.json()["numberMatched"] == 1 - # assert len(resp.json()["collections"]) == 1 - # assert resp.json()["collections"][0]["id"] == load_test2_collection.id - - # resp = await app_client_advanced_freetext.get( - # "/collections", - # params={"q": "nosuchthing"}, - # ) - # assert len(resp.json()["collections"]) == 0 + params={"q": "orthorectified"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["features"][0]["id"] == "test-item" + + resp = await app_client.get( + "/search", + params={"q": "orthorectified,yo"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["features"][0]["id"] == "test-item" + + resp = await app_client.get( + "/search", + params={"q": "yo"}, + ) + assert resp.json()["numberReturned"] == 0