diff --git a/CHANGELOG.md b/CHANGELOG.md index 31b96da49..69538718e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,15 +8,16 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] ### Added - +- Environment variable `ENABLE_COLLECTIONS_SEARCH_ROUTE` to enable or disable the `/collections-search` endpoint. Defaults to `false`. [#478](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/478) +- POST and GET `/collections-search` endpoint for collection search queries, needed because POST `/collections` is reserved for collection creation when the Transactions Extension is enabled. [#478](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/478) - GET `/collections` collection search structured filter extension with support for both cql2-json and cql2-text formats. [#475](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/475) - GET `/collections` collection search query extension. [#477](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/477) - GET `/collections` collections search datetime filtering support. [#476](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/476) ### Changed +- Refactored `/collections` endpoint implementation to support both GET and POST methods. [#478](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/478) ### Fixed - - support of disabled nested attributes in the properties dictionary. 
[#474](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/474) ## [v6.4.0] - 2025-09-24 diff --git a/README.md b/README.md index d6e74912f..a8e2a2973 100644 --- a/README.md +++ b/README.md @@ -66,13 +66,13 @@ This project is built on the following technologies: STAC, stac-fastapi, FastAPI ## Table of Contents - [stac-fastapi-elasticsearch-opensearch](#stac-fastapi-elasticsearch-opensearch) - - [Sponsors \& Supporters](#sponsors--supporters) + - [Sponsors & Supporters](#sponsors--supporters) - [Project Introduction - What is SFEOS?](#project-introduction---what-is-sfeos) - [Common Deployment Patterns](#common-deployment-patterns) - [Technologies](#technologies) - [Table of Contents](#table-of-contents) - [Collection Search Extensions](#collection-search-extensions) - - [Documentation \& Resources](#documentation--resources) + - [Documentation & Resources](#documentation--resources) - [Package Structure](#package-structure) - [Examples](#examples) - [Performance](#performance) @@ -115,7 +115,11 @@ This project is built on the following technologies: STAC, stac-fastapi, FastAPI ## Collection Search Extensions -SFEOS implements extended capabilities for the `/collections` endpoint, allowing for more powerful collection discovery: +SFEOS provides enhanced collection search capabilities through two primary routes: +- **GET/POST `/collections`**: The standard STAC endpoint with extended query parameters +- **GET/POST `/collections-search`**: A custom endpoint that supports the same parameters, created to avoid conflicts with the STAC Transactions extension if enabled (which uses POST `/collections` for collection creation) + +These endpoints support advanced collection discovery features including: - **Sorting**: Sort collections by sortable fields using the `sortby` parameter - Example: `/collections?sortby=+id` (ascending sort by ID) @@ -146,11 +150,11 @@ SFEOS implements extended capabilities for the `/collections` endpoint, allowing - Collections 
are matched if their temporal extent overlaps with the provided datetime parameter - This allows for efficient discovery of collections based on time periods -> **Note on HTTP Methods**: All collection search extensions (sorting, field selection, free text search, structured filtering, and datetime filtering) currently only support GET requests. POST requests with these parameters in the request body are not yet supported. - These extensions make it easier to build user interfaces that display and navigate through collections efficiently. -> **Configuration**: Collection search extensions (sorting, field selection, free text search, structured filtering, and datetime filtering) can be disabled by setting the `ENABLE_COLLECTIONS_SEARCH` environment variable to `false`. By default, these extensions are enabled. +> **Configuration**: Collection search extensions (sorting, field selection, free text search, structured filtering, and datetime filtering) for the `/collections` endpoint can be disabled by setting the `ENABLE_COLLECTIONS_SEARCH` environment variable to `false`. By default, these extensions are enabled. +> +> **Configuration**: The custom `/collections-search` endpoint can be enabled by setting the `ENABLE_COLLECTIONS_SEARCH_ROUTE` environment variable to `true`. By default, this endpoint is **disabled**. > **Note**: Sorting is only available on fields that are indexed for sorting in Elasticsearch/OpenSearch. With the default mappings, you can sort on: > - `id` (keyword field) @@ -161,6 +165,7 @@ These extensions make it easier to build user interfaces that display and naviga > > **Important**: Adding keyword fields to make text fields sortable can significantly increase the index size, especially for large text fields. Consider the storage implications when deciding which fields to make sortable. 
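The README notes above describe sorting with `+`/`-` prefixes on the GET route, while the refactored `post_all_collections` handler translates POST-style `{field, direction}` sort clauses into that prefix form. A minimal sketch of the round trip between the two representations (the helper names `sortby_to_post` and `post_to_sortby` are hypothetical, not part of the codebase):

```python
from typing import Dict, List


def sortby_to_post(sortby: str) -> List[Dict[str, str]]:
    """Convert a GET-style sortby string, e.g. "-datetime,+id",
    into POST-style sort clauses."""
    clauses = []
    for part in sortby.split(","):
        if not part:
            continue
        direction = "desc" if part.startswith("-") else "asc"
        field = part[1:] if part[0] in "+-" else part
        clauses.append({"field": field, "direction": direction})
    return clauses


def post_to_sortby(clauses: List[Dict[str, str]]) -> List[str]:
    """Inverse conversion, mirroring the prefix logic in post_all_collections:
    "desc" becomes a "-" prefix, "asc" (the default) has no prefix."""
    return [
        ("-" if c.get("direction", "asc").lower() == "desc" else "") + c["field"]
        for c in clauses
    ]
```

A bare field name with no prefix sorts ascending, matching the handler's default.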
+ ## Package Structure This project is organized into several packages, each with a specific purpose: @@ -291,8 +296,9 @@ You can customize additional settings in your `.env` file: | `ENABLE_DIRECT_RESPONSE` | Enable direct response for maximum performance (disables all FastAPI dependencies, including authentication, custom status codes, and validation) | `false` | Optional | | `RAISE_ON_BULK_ERROR` | Controls whether bulk insert operations raise exceptions on errors. If set to `true`, the operation will stop and raise an exception when an error occurs. If set to `false`, errors will be logged, and the operation will continue. **Note:** STAC Item and ItemCollection validation errors will always raise, regardless of this flag. | `false` | Optional | | `DATABASE_REFRESH` | Controls whether database operations refresh the index immediately after changes. If set to `true`, changes will be immediately searchable. If set to `false`, changes may not be immediately visible but can improve performance for bulk operations. If set to `wait_for`, changes will wait for the next refresh cycle to become visible. | `false` | Optional | -| `ENABLE_COLLECTIONS_SEARCH` | Enable collection search extensions (sort, fields, free text search, structured filtering, and datetime filtering). | `true` | Optional | -| `ENABLE_TRANSACTIONS_EXTENSIONS` | Enables or disables the Transactions and Bulk Transactions API extensions. If set to `false`, the POST `/collections` route and related transaction endpoints (including bulk transaction operations) will be unavailable in the API. This is useful for deployments where mutating the catalog via the API should be prevented. | `true` | Optional | +| `ENABLE_COLLECTIONS_SEARCH` | Enable collection search extensions (sort, fields, free text search, structured filtering, and datetime filtering) on the core `/collections` endpoint. 
| `true` | Optional | +| `ENABLE_COLLECTIONS_SEARCH_ROUTE` | Enable the custom `/collections-search` endpoint (both GET and POST methods). When disabled, the custom endpoint will not be available, but collection search extensions will still be available on the core `/collections` endpoint if `ENABLE_COLLECTIONS_SEARCH` is true. | `false` | Optional | +| `ENABLE_TRANSACTIONS_EXTENSIONS` | Enables or disables the Transactions and Bulk Transactions API extensions. This is useful for deployments where mutating the catalog via the API should be prevented. If set to `true`, the POST `/collections` route for search will be unavailable in the API. | `true` | Optional | | `STAC_ITEM_LIMIT` | Sets the environment variable for result limiting to SFEOS for the number of returned items and STAC collections. | `10` | Optional | | `STAC_INDEX_ASSETS` | Controls if Assets are indexed when added to Elasticsearch/Opensearch. This allows asset fields to be included in search queries. | `false` | Optional | | `ENV_MAX_LIMIT` | Configures the environment variable in SFEOS to override the default `MAX_LIMIT`, which controls the limit parameter for returned items and STAC collections. | `10,000` | Optional | @@ -442,7 +448,6 @@ The system uses a precise naming convention: - `ENABLE_COLLECTIONS_SEARCH`: Set to `true` (default) to enable collection search extensions (sort, fields). Set to `false` to disable. - `ENABLE_TRANSACTIONS_EXTENSIONS`: Set to `true` (default) to enable transaction extensions. Set to `false` to disable. - ## Collection Pagination - **Overview**: The collections route supports pagination through optional query parameters. 
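The `STAC_ITEM_LIMIT` row above acts as a global cap on page size; the refactored `all_collections` handler resolves the effective limit by taking the smaller of the caller's `limit` and the environment cap, falling back to a default of 10. A condensed sketch of that resolution logic (the `resolve_limit` helper is hypothetical; the real code inlines this):

```python
import os
from typing import Optional


def resolve_limit(requested: Optional[int], default: int = 10) -> int:
    """Clamp a requested page size against the STAC_ITEM_LIMIT env var."""
    global_limit: Optional[int] = None
    env_limit = os.getenv("STAC_ITEM_LIMIT")
    if env_limit:
        try:
            global_limit = int(env_limit)
        except ValueError:
            pass  # an unparseable value is ignored, as in the handler
    if global_limit is not None:
        # Use the smaller of the requested and global limits
        return min(requested, global_limit) if requested is not None else global_limit
    return requested if requested is not None else default
```

With `STAC_ITEM_LIMIT=50`, a request for 100 collections is clamped to 50, while a request for 5 passes through unchanged.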
diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py index a6862cf25..ac2f228d2 100644 --- a/stac_fastapi/core/stac_fastapi/core/core.py +++ b/stac_fastapi/core/stac_fastapi/core/core.py @@ -136,6 +136,20 @@ def _landing_page( "href": urljoin(base_url, "search"), "method": "POST", }, + { + "rel": "collections-search", + "type": "application/json", + "title": "Collections Search", + "href": urljoin(base_url, "collections-search"), + "method": "GET", + }, + { + "rel": "collections-search", + "type": "application/json", + "title": "Collections Search", + "href": urljoin(base_url, "collections-search"), + "method": "POST", + }, ], stac_extensions=extension_schemas, ) @@ -227,8 +241,9 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage: async def all_collections( self, datetime: Optional[str] = None, + limit: Optional[int] = None, fields: Optional[List[str]] = None, - sortby: Optional[str] = None, + sortby: Optional[Union[str, List[str]]] = None, filter_expr: Optional[str] = None, filter_lang: Optional[str] = None, q: Optional[Union[str, List[str]]] = None, @@ -239,6 +254,7 @@ async def all_collections( Args: datetime (Optional[str]): Filter collections by datetime range. + limit (Optional[int]): Maximum number of collections to return. fields (Optional[List[str]]): Fields to include or exclude from the results. sortby (Optional[str]): Sorting options for the results. filter_expr (Optional[str]): Structured filter expression in CQL2 JSON or CQL2-text format. 
@@ -252,7 +268,36 @@ async def all_collections( """ request = kwargs["request"] base_url = str(request.base_url) - limit = int(request.query_params.get("limit", os.getenv("STAC_ITEM_LIMIT", 10))) + + # Get the global limit from environment variable + global_limit = None + env_limit = os.getenv("STAC_ITEM_LIMIT") + if env_limit: + try: + global_limit = int(env_limit) + except ValueError: + # Handle invalid integer in environment variable + pass + + # Apply global limit if it exists + if global_limit is not None: + # If a limit was provided, use the smaller of the two + if limit is not None: + limit = min(limit, global_limit) + else: + limit = global_limit + else: + # No global limit, use provided limit or default + if limit is None: + query_limit = request.query_params.get("limit") + if query_limit: + try: + limit = int(query_limit) + except ValueError: + limit = 10 + else: + limit = 10 + token = request.query_params.get("token") # Process fields parameter for filtering collection properties @@ -262,7 +307,8 @@ async def all_collections( if field[0] == "-": excludes.add(field[1:]) else: - includes.add(field[1:] if field[0] in "+ " else field) + include_field = field[1:] if field[0] in "+ " else field + includes.add(include_field) sort = None if sortby: @@ -337,6 +383,7 @@ async def all_collections( raise HTTPException( status_code=400, detail=f"Error parsing filter: {e}" ) + except Exception as e: raise HTTPException( status_code=400, detail=f"Invalid filter parameter: {e}" @@ -346,7 +393,7 @@ async def all_collections( if datetime: parsed_datetime = format_datetime_range(date_str=datetime) - collections, next_token = await self.database.get_all_collections( + collections, next_token, maybe_count = await self.database.get_all_collections( token=token, limit=limit, request=request, @@ -380,7 +427,91 @@ async def all_collections( next_link = PagingLinks(next=next_token, request=request).link_next() links.append(next_link) - return 
stac_types.Collections(collections=filtered_collections, links=links) + return stac_types.Collections( + collections=filtered_collections, + links=links, + numberMatched=maybe_count, + numberReturned=len(filtered_collections), + ) + + async def post_all_collections( + self, search_request: BaseSearchPostRequest, request: Request, **kwargs + ) -> stac_types.Collections: + """Search collections with POST request. + + Args: + search_request (BaseSearchPostRequest): The search request. + request (Request): The request. + + Returns: + A Collections object containing all the collections in the database and links to various resources. + """ + request.postbody = search_request.model_dump(exclude_unset=True) + + fields = None + + # Check for field attribute (ExtendedSearch format) + if hasattr(search_request, "field") and search_request.field: + fields = [] + + # Handle include fields + if ( + hasattr(search_request.field, "includes") + and search_request.field.includes + ): + for field in search_request.field.includes: + fields.append(f"+{field}") + + # Handle exclude fields + if ( + hasattr(search_request.field, "excludes") + and search_request.field.excludes + ): + for field in search_request.field.excludes: + fields.append(f"-{field}") + + # Convert sortby parameter from POST format to all_collections format + sortby = None + # Check for sortby attribute + if hasattr(search_request, "sortby") and search_request.sortby: + # Create a list of sort strings in the format expected by all_collections + sortby = [] + for sort_item in search_request.sortby: + # Handle different types of sort items + if hasattr(sort_item, "field") and hasattr(sort_item, "direction"): + # This is a Pydantic model with field and direction attributes + field = sort_item.field + direction = sort_item.direction + elif isinstance(sort_item, dict): + # This is a dictionary with field and direction keys + field = sort_item.get("field") + direction = sort_item.get("direction", "asc") + else: + # Skip this 
item if we can't extract field and direction + continue + + if field: + # Create a sort string in the format expected by all_collections + # e.g., "-id" for descending sort on id field + prefix = "-" if direction.lower() == "desc" else "" + sortby.append(f"{prefix}{field}") + + # Pass all parameters from search_request to all_collections + return await self.all_collections( + limit=search_request.limit if hasattr(search_request, "limit") else None, + fields=fields, + sortby=sortby, + filter_expr=search_request.filter + if hasattr(search_request, "filter") + else None, + filter_lang=search_request.filter_lang + if hasattr(search_request, "filter_lang") + else None, + query=search_request.query if hasattr(search_request, "query") else None, + q=search_request.q if hasattr(search_request, "q") else None, + request=request, + **kwargs, + ) async def get_collection( self, collection_id: str, **kwargs diff --git a/stac_fastapi/core/stac_fastapi/core/extensions/__init__.py b/stac_fastapi/core/stac_fastapi/core/extensions/__init__.py index 7ee6eea5c..9216e8ec0 100644 --- a/stac_fastapi/core/stac_fastapi/core/extensions/__init__.py +++ b/stac_fastapi/core/stac_fastapi/core/extensions/__init__.py @@ -1,5 +1,11 @@ """elasticsearch extensions modifications.""" +from .collections_search import CollectionsSearchEndpointExtension from .query import Operator, QueryableTypes, QueryExtension -__all__ = ["Operator", "QueryableTypes", "QueryExtension"] +__all__ = [ + "Operator", + "QueryableTypes", + "QueryExtension", + "CollectionsSearchEndpointExtension", +] diff --git a/stac_fastapi/core/stac_fastapi/core/extensions/collections_search.py b/stac_fastapi/core/stac_fastapi/core/extensions/collections_search.py new file mode 100644 index 000000000..0ddbefeda --- /dev/null +++ b/stac_fastapi/core/stac_fastapi/core/extensions/collections_search.py @@ -0,0 +1,190 @@ +"""Collections search extension.""" + +from typing import List, Optional, Type, Union + +from fastapi import APIRouter, 
FastAPI, Request +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from stac_pydantic.api.search import ExtendedSearch +from starlette.responses import Response + +from stac_fastapi.api.models import APIRequest +from stac_fastapi.types.core import BaseCoreClient +from stac_fastapi.types.extension import ApiExtension +from stac_fastapi.types.stac import Collections + + +class CollectionsSearchRequest(ExtendedSearch): + """Extended search model for collections with free text search support.""" + + q: Optional[Union[str, List[str]]] = None + + +class CollectionsSearchEndpointExtension(ApiExtension): + """Collections search endpoint extension. + + This extension adds a dedicated /collections-search endpoint for collection search operations. + """ + + def __init__( + self, + client: Optional[BaseCoreClient] = None, + settings: dict = None, + GET: Optional[Type[Union[BaseModel, APIRequest]]] = None, + POST: Optional[Type[Union[BaseModel, APIRequest]]] = None, + conformance_classes: Optional[List[str]] = None, + ): + """Initialize the extension. + + Args: + client: Optional BaseCoreClient instance to use for this extension. + settings: Dictionary of settings to pass to the extension. + GET: Optional GET request model. + POST: Optional POST request model. + conformance_classes: Optional list of conformance classes to add to the API. + """ + super().__init__() + self.client = client + self.settings = settings or {} + self.GET = GET + self.POST = POST + self.conformance_classes = conformance_classes or [] + self.router = APIRouter() + self.create_endpoints() + + def register(self, app: FastAPI) -> None: + """Register the extension with a FastAPI application. + + Args: + app: target FastAPI application. 
+ + Returns: + None + """ + app.include_router(self.router) + + def create_endpoints(self) -> None: + """Create endpoints for the extension.""" + if self.GET: + self.router.add_api_route( + name="Get Collections Search", + path="/collections-search", + response_model=None, + response_class=JSONResponse, + methods=["GET"], + endpoint=self.collections_search_get_endpoint, + **(self.settings if isinstance(self.settings, dict) else {}), + ) + + if self.POST: + self.router.add_api_route( + name="Post Collections Search", + path="/collections-search", + response_model=None, + response_class=JSONResponse, + methods=["POST"], + endpoint=self.collections_search_post_endpoint, + **(self.settings if isinstance(self.settings, dict) else {}), + ) + + async def collections_search_get_endpoint( + self, request: Request + ) -> Union[Collections, Response]: + """GET /collections-search endpoint. + + Args: + request: Request object. + + Returns: + Collections: Collections object. + """ + # Extract query parameters from the request + params = dict(request.query_params) + + # Convert query parameters to appropriate types + if "limit" in params: + try: + params["limit"] = int(params["limit"]) + except ValueError: + pass + + # Handle fields parameter + if "fields" in params: + fields_str = params.pop("fields") + fields = fields_str.split(",") + params["fields"] = fields + + # Handle sortby parameter + if "sortby" in params: + sortby_str = params.pop("sortby") + sortby = sortby_str.split(",") + params["sortby"] = sortby + + collections = await self.client.all_collections(request=request, **params) + return collections + + async def collections_search_post_endpoint( + self, request: Request, body: dict + ) -> Union[Collections, Response]: + """POST /collections-search endpoint. + + Args: + request: Request object. + body: Search request body. + + Returns: + Collections: Collections object. 
+ """ + # Convert the dict to an ExtendedSearch model + search_request = CollectionsSearchRequest.model_validate(body) + + # Check if fields are present in the body + if "fields" in body: + # Extract fields from body and add them to search_request + if hasattr(search_request, "field"): + from stac_pydantic.api.extensions.fields import FieldsExtension + + fields_data = body["fields"] + search_request.field = FieldsExtension( + includes=fields_data.get("include"), + excludes=fields_data.get("exclude"), + ) + + # Set the postbody on the request for pagination links + request.postbody = body + + collections = await self.client.post_all_collections( + search_request=search_request, request=request + ) + + return collections + + @classmethod + def from_extensions( + cls, extensions: List[ApiExtension] + ) -> "CollectionsSearchEndpointExtension": + """Create a CollectionsSearchEndpointExtension from a list of extensions. + + Args: + extensions: List of extensions to include in the CollectionsSearchEndpointExtension. + + Returns: + CollectionsSearchEndpointExtension: A new CollectionsSearchEndpointExtension instance. 
+ """ + from stac_fastapi.api.models import ( + create_get_request_model, + create_post_request_model, + ) + + get_model = create_get_request_model(extensions) + post_model = create_post_request_model(extensions) + + return cls( + GET=get_model, + POST=post_model, + conformance_classes=[ + ext.conformance_classes + for ext in extensions + if hasattr(ext, "conformance_classes") + ], + ) diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py index 8cc32088c..6012c1906 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py @@ -23,6 +23,9 @@ EsAggregationExtensionGetRequest, EsAggregationExtensionPostRequest, ) +from stac_fastapi.core.extensions.collections_search import ( + CollectionsSearchEndpointExtension, +) from stac_fastapi.core.extensions.fields import FieldsExtension from stac_fastapi.core.rate_limit import setup_rate_limit from stac_fastapi.core.route_dependencies import get_route_dependencies @@ -38,6 +41,7 @@ AggregationExtension, CollectionSearchExtension, CollectionSearchFilterExtension, + CollectionSearchPostExtension, FilterExtension, FreeTextExtension, SortExtension, @@ -58,8 +62,14 @@ TRANSACTIONS_EXTENSIONS = get_bool_env("ENABLE_TRANSACTIONS_EXTENSIONS", default=True) ENABLE_COLLECTIONS_SEARCH = get_bool_env("ENABLE_COLLECTIONS_SEARCH", default=True) +ENABLE_COLLECTIONS_SEARCH_ROUTE = get_bool_env( + "ENABLE_COLLECTIONS_SEARCH_ROUTE", default=False +) logger.info("TRANSACTIONS_EXTENSIONS is set to %s", TRANSACTIONS_EXTENSIONS) logger.info("ENABLE_COLLECTIONS_SEARCH is set to %s", ENABLE_COLLECTIONS_SEARCH) +logger.info( + "ENABLE_COLLECTIONS_SEARCH_ROUTE is set to %s", ENABLE_COLLECTIONS_SEARCH_ROUTE +) settings = ElasticsearchSettings() session = Session.create_from_settings(settings) @@ -117,8 +127,10 @@ extensions = [aggregation_extension] + search_extensions -# Create collection 
search extensions if enabled -if ENABLE_COLLECTIONS_SEARCH: +# Collection search related variables +collections_get_request_model = None + +if ENABLE_COLLECTIONS_SEARCH or ENABLE_COLLECTIONS_SEARCH_ROUTE: # Create collection search extensions collection_search_extensions = [ QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]), @@ -136,7 +148,58 @@ ) collections_get_request_model = collection_search_ext.GET + # Create a post request model for collection search + collection_search_post_request_model = create_post_request_model( + collection_search_extensions + ) + +# Create collection search extensions if enabled +if ENABLE_COLLECTIONS_SEARCH: + # Initialize collection search POST extension + collection_search_post_ext = CollectionSearchPostExtension( + client=CoreClient( + database=database_logic, + session=session, + post_request_model=collection_search_post_request_model, + landing_page_id=os.getenv("STAC_FASTAPI_LANDING_PAGE_ID", "stac-fastapi"), + ), + settings=settings, + POST=collection_search_post_request_model, + conformance_classes=[ + "https://api.stacspec.org/v1.0.0-rc.1/collection-search", + QueryConformanceClasses.COLLECTIONS, + FilterConformanceClasses.COLLECTIONS, + FreeTextConformanceClasses.COLLECTIONS, + SortConformanceClasses.COLLECTIONS, + FieldsConformanceClasses.COLLECTIONS, + ], + ) extensions.append(collection_search_ext) + extensions.append(collection_search_post_ext) + +if ENABLE_COLLECTIONS_SEARCH_ROUTE: + # Initialize collections-search endpoint extension + collections_search_endpoint_ext = CollectionsSearchEndpointExtension( + client=CoreClient( + database=database_logic, + session=session, + post_request_model=collection_search_post_request_model, + landing_page_id=os.getenv("STAC_FASTAPI_LANDING_PAGE_ID", "stac-fastapi"), + ), + settings=settings, + GET=collections_get_request_model, + POST=collection_search_post_request_model, + conformance_classes=[ + "https://api.stacspec.org/v1.0.0-rc.1/collection-search", + 
QueryConformanceClasses.COLLECTIONS, + FilterConformanceClasses.COLLECTIONS, + FreeTextConformanceClasses.COLLECTIONS, + SortConformanceClasses.COLLECTIONS, + FieldsConformanceClasses.COLLECTIONS, + ], + ) + extensions.append(collections_search_endpoint_ext) + database_logic.extensions = [type(ext).__name__ for ext in extensions] @@ -176,8 +239,8 @@ "route_dependencies": get_route_dependencies(), } -# Add collections_get_request_model if collection search is enabled -if ENABLE_COLLECTIONS_SEARCH: +# Add collections_get_request_model if it was created +if collections_get_request_model: app_config["collections_get_request_model"] = collections_get_request_model api = StacApi(**app_config) @@ -194,6 +257,7 @@ async def lifespan(app: FastAPI): app = api.app app.router.lifespan_context = lifespan app.root_path = os.getenv("STAC_FASTAPI_ROOT_PATH", "") + # Add rate limit setup_rate_limit(app, rate_limit=os.getenv("STAC_FASTAPI_RATE_LIMIT")) diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index a7893dc8b..f4f33cb97 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -179,7 +179,7 @@ async def get_all_collections( filter: Optional[Dict[str, Any]] = None, query: Optional[Dict[str, Dict[str, Any]]] = None, datetime: Optional[str] = None, - ) -> Tuple[List[Dict[str, Any]], Optional[str]]: + ) -> Tuple[List[Dict[str, Any]], Optional[str], Optional[int]]: """Retrieve a list of collections from Elasticsearch, supporting pagination. 
Args: @@ -228,8 +228,21 @@ async def get_all_collections( "size": limit, } + # Handle search_after token - split by '|' to get all sort values + search_after = None if token: - body["search_after"] = [token] + try: + # The token should be a pipe-separated string of sort values + # e.g., "2023-01-01T00:00:00Z|collection-1" + search_after = token.split("|") + # If the number of sort fields doesn't match token parts, ignore the token + if len(search_after) != len(formatted_sort): + search_after = None + except Exception: + search_after = None + + if search_after is not None: + body["search_after"] = search_after # Build the query part of the body query_parts = [] @@ -317,12 +330,30 @@ async def get_all_collections( else {"bool": {"must": query_parts}} ) - # Execute the search - response = await self.client.search( - index=COLLECTIONS_INDEX, - body=body, + # Create a copy of the body for count query (without pagination and sorting) + count_body = body.copy() + if "search_after" in count_body: + del count_body["search_after"] + count_body["size"] = 0 + + # Create async tasks for both search and count + search_task = asyncio.create_task( + self.client.search( + index=COLLECTIONS_INDEX, + body=body, + ) + ) + + count_task = asyncio.create_task( + self.client.count( + index=COLLECTIONS_INDEX, + body={"query": body.get("query", {"match_all": {}})}, + ) ) + # Wait for search task to complete + response = await search_task + hits = response["hits"]["hits"] collections = [ self.collection_serializer.db_to_stac( @@ -335,9 +366,25 @@ async def get_all_collections( if len(hits) == limit: next_token_values = hits[-1].get("sort") if next_token_values: - next_token = next_token_values[0] + # Join all sort values with '|' to create the token + next_token = "|".join(str(val) for val in next_token_values) + + # Get the total count of collections + matched = ( + response["hits"]["total"]["value"] + if response["hits"]["total"]["relation"] == "eq" + else None + ) + + # If count task is 
done, use its result + if count_task.done(): + try: + matched = count_task.result().get("count") + except Exception as e: + logger = logging.getLogger(__name__) + logger.error(f"Count task failed: {e}") - return collections, next_token + return collections, next_token, matched @staticmethod def _apply_collection_datetime_filter( diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py index 56f717a34..181d8a7aa 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py @@ -23,6 +23,9 @@ EsAggregationExtensionGetRequest, EsAggregationExtensionPostRequest, ) +from stac_fastapi.core.extensions.collections_search import ( + CollectionsSearchEndpointExtension, +) from stac_fastapi.core.extensions.fields import FieldsExtension from stac_fastapi.core.rate_limit import setup_rate_limit from stac_fastapi.core.route_dependencies import get_route_dependencies @@ -32,6 +35,7 @@ AggregationExtension, CollectionSearchExtension, CollectionSearchFilterExtension, + CollectionSearchPostExtension, FilterExtension, FreeTextExtension, SortExtension, @@ -58,8 +62,14 @@ TRANSACTIONS_EXTENSIONS = get_bool_env("ENABLE_TRANSACTIONS_EXTENSIONS", default=True) ENABLE_COLLECTIONS_SEARCH = get_bool_env("ENABLE_COLLECTIONS_SEARCH", default=True) +ENABLE_COLLECTIONS_SEARCH_ROUTE = get_bool_env( + "ENABLE_COLLECTIONS_SEARCH_ROUTE", default=False +) logger.info("TRANSACTIONS_EXTENSIONS is set to %s", TRANSACTIONS_EXTENSIONS) logger.info("ENABLE_COLLECTIONS_SEARCH is set to %s", ENABLE_COLLECTIONS_SEARCH) +logger.info( + "ENABLE_COLLECTIONS_SEARCH_ROUTE is set to %s", ENABLE_COLLECTIONS_SEARCH_ROUTE +) settings = OpensearchSettings() session = Session.create_from_settings(settings) @@ -117,8 +127,10 @@ extensions = [aggregation_extension] + search_extensions -# Create collection search extensions if enabled -if ENABLE_COLLECTIONS_SEARCH: +# Collection search 
related variables +collections_get_request_model = None + +if ENABLE_COLLECTIONS_SEARCH or ENABLE_COLLECTIONS_SEARCH_ROUTE: # Create collection search extensions collection_search_extensions = [ QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]), @@ -136,7 +148,58 @@ ) collections_get_request_model = collection_search_ext.GET + # Create a post request model for collection search + collection_search_post_request_model = create_post_request_model( + collection_search_extensions + ) + +# Create collection search extensions if enabled +if ENABLE_COLLECTIONS_SEARCH: + # Initialize collection search POST extension + collection_search_post_ext = CollectionSearchPostExtension( + client=CoreClient( + database=database_logic, + session=session, + post_request_model=collection_search_post_request_model, + landing_page_id=os.getenv("STAC_FASTAPI_LANDING_PAGE_ID", "stac-fastapi"), + ), + settings=settings, + POST=collection_search_post_request_model, + conformance_classes=[ + "https://api.stacspec.org/v1.0.0-rc.1/collection-search", + QueryConformanceClasses.COLLECTIONS, + FilterConformanceClasses.COLLECTIONS, + FreeTextConformanceClasses.COLLECTIONS, + SortConformanceClasses.COLLECTIONS, + FieldsConformanceClasses.COLLECTIONS, + ], + ) extensions.append(collection_search_ext) + extensions.append(collection_search_post_ext) + +if ENABLE_COLLECTIONS_SEARCH_ROUTE: + # Initialize collections-search endpoint extension + collections_search_endpoint_ext = CollectionsSearchEndpointExtension( + client=CoreClient( + database=database_logic, + session=session, + post_request_model=collection_search_post_request_model, + landing_page_id=os.getenv("STAC_FASTAPI_LANDING_PAGE_ID", "stac-fastapi"), + ), + settings=settings, + GET=collections_get_request_model, + POST=collection_search_post_request_model, + conformance_classes=[ + "https://api.stacspec.org/v1.0.0-rc.1/collection-search", + QueryConformanceClasses.COLLECTIONS, + FilterConformanceClasses.COLLECTIONS, + 
FreeTextConformanceClasses.COLLECTIONS, + SortConformanceClasses.COLLECTIONS, + FieldsConformanceClasses.COLLECTIONS, + ], + ) + extensions.append(collections_search_endpoint_ext) + database_logic.extensions = [type(ext).__name__ for ext in extensions] @@ -177,8 +240,8 @@ "route_dependencies": get_route_dependencies(), } -# Add collections_get_request_model if collection search is enabled -if ENABLE_COLLECTIONS_SEARCH: +# Add collections_get_request_model if it was created +if collections_get_request_model: app_config["collections_get_request_model"] = collections_get_request_model api = StacApi(**app_config) diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py index 694d6cfae..8791390bb 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py @@ -163,7 +163,7 @@ async def get_all_collections( filter: Optional[Dict[str, Any]] = None, query: Optional[Dict[str, Dict[str, Any]]] = None, datetime: Optional[str] = None, - ) -> Tuple[List[Dict[str, Any]], Optional[str]]: + ) -> Tuple[List[Dict[str, Any]], Optional[str], Optional[int]]: """Retrieve a list of collections from OpenSearch, supporting pagination. Args: @@ -172,8 +172,8 @@ async def get_all_collections( request (Request): The FastAPI request object. sort (Optional[List[Dict[str, Any]]]): Optional sort parameter from the request. q (Optional[List[str]]): Free text search terms. - filter (Optional[Dict[str, Any]]): Structured filter in CQL2 format. query (Optional[Dict[str, Dict[str, Any]]]): Query extension parameters. + filter (Optional[Dict[str, Any]]): Structured filter in CQL2 format. datetime (Optional[str]): Temporal filter.
Returns: @@ -212,8 +212,21 @@ async def get_all_collections( "size": limit, } + # Handle search_after token - split by '|' to get all sort values + search_after = None if token: - body["search_after"] = [token] + try: + # The token should be a pipe-separated string of sort values + # e.g., "2023-01-01T00:00:00Z|collection-1" + search_after = token.split("|") + # If the number of sort fields doesn't match token parts, ignore the token + if len(search_after) != len(formatted_sort): + search_after = None + except Exception: + search_after = None + + if search_after is not None: + body["search_after"] = search_after # Build the query part of the body query_parts = [] @@ -278,6 +291,7 @@ async def get_all_collections( search_dict = search.to_dict() if "query" in search_dict: query_parts.append(search_dict["query"]) + except Exception as e: logger = logging.getLogger(__name__) logger.error(f"Error converting query to OpenSearch: {e}") @@ -285,6 +299,7 @@ async def get_all_collections( query_parts.append({"bool": {"must_not": {"match_all": {}}}}) raise + # Combine all query parts with AND logic if there are multiple datetime_filter = None if datetime: datetime_filter = self._apply_collection_datetime_filter(datetime) @@ -299,12 +314,30 @@ async def get_all_collections( else {"bool": {"must": query_parts}} ) - # Execute the search - response = await self.client.search( - index=COLLECTIONS_INDEX, - body=body, + # Create a copy of the body for count query (without pagination and sorting) + count_body = body.copy() + if "search_after" in count_body: + del count_body["search_after"] + count_body["size"] = 0 + + # Create async tasks for both search and count + search_task = asyncio.create_task( + self.client.search( + index=COLLECTIONS_INDEX, + body=body, ) + ) + count_task = asyncio.create_task( + self.client.count( + index=COLLECTIONS_INDEX, + body={"query": count_body.get("query", {"match_all": {}})}, + ) + ) + + # Wait for search task to complete + response = await search_task + 
hits = response["hits"]["hits"] collections = [ self.collection_serializer.db_to_stac( @@ -317,9 +350,25 @@ async def get_all_collections( if len(hits) == limit: next_token_values = hits[-1].get("sort") if next_token_values: - next_token = next_token_values[0] + # Join all sort values with '|' to create the token + next_token = "|".join(str(val) for val in next_token_values) + + # Get the total count of collections + matched = ( + response["hits"]["total"]["value"] + if response["hits"]["total"]["relation"] == "eq" + else None + ) + + # If count task is done, use its result + if count_task.done(): + try: + matched = count_task.result().get("count") + except Exception as e: + logger = logging.getLogger(__name__) + logger.error(f"Count task failed: {e}") - return collections, next_token + return collections, next_token, matched async def get_one_item(self, collection_id: str, item_id: str) -> Dict: """Retrieve a single item from the database. diff --git a/stac_fastapi/tests/api/test_api.py b/stac_fastapi/tests/api/test_api.py index e74ab5600..6fdc2fb60 100644 --- a/stac_fastapi/tests/api/test_api.py +++ b/stac_fastapi/tests/api/test_api.py @@ -48,6 +48,8 @@ "GET /collections/{collection_id}/aggregate", "POST /collections/{collection_id}/aggregations", "POST /collections/{collection_id}/aggregate", + "GET /collections-search", + "POST /collections-search", } diff --git a/stac_fastapi/tests/api/test_api_search_collections.py b/stac_fastapi/tests/api/test_api_search_collections.py index 668ba0603..8f5bed73b 100644 --- a/stac_fastapi/tests/api/test_api_search_collections.py +++ b/stac_fastapi/tests/api/test_api_search_collections.py @@ -85,8 +85,8 @@ async def test_collections_sort_id_desc(app_client, txn_client, ctx): @pytest.mark.asyncio -async def test_collections_fields(app_client, txn_client, ctx): - """Verify GET /collections honors the fields parameter.""" +async def test_collections_fields_all_endpoints(app_client, txn_client, ctx): + """Verify GET /collections, 
GET /collections-search, and POST /collections-search honor the fields parameter.""" # Create multiple collections with different ids base_collection = ctx.collection @@ -104,136 +104,143 @@ async def test_collections_fields(app_client, txn_client, ctx): await refresh_indices(txn_client) - # Test include fields parameter - resp = await app_client.get( - "/collections", - params=[("fields", "id"), ("fields", "title")], - ) - assert resp.status_code == 200 - resp_json = resp.json() + # Define endpoints to test + endpoints = [ + {"method": "GET", "path": "/collections", "params": [("fields", "id,title")]}, + { + "method": "GET", + "path": "/collections-search", + "params": [("fields", "id,title")], + }, + { + "method": "POST", + "path": "/collections-search", + "body": {"fields": {"include": ["id", "title"]}}, + }, + ] - # Check if collections exist in the response - assert "collections" in resp_json, "No collections in response" + for endpoint in endpoints: + if endpoint["method"] == "GET": + resp = await app_client.get(endpoint["path"], params=endpoint["params"]) + else: # POST + resp = await app_client.post(endpoint["path"], json=endpoint["body"]) - # Filter collections to only include the ones we created for this test - test_collections = [] - for c in resp_json["collections"]: - if "id" in c and c["id"].startswith(test_prefix): - test_collections.append(c) + assert resp.status_code == 200 + resp_json = resp.json() - # Filter collections to only include the ones we created for this test - test_collections = [] - for c in resp_json["collections"]: - if "id" in c and c["id"].startswith(test_prefix): - test_collections.append(c) + collections_list = resp_json["collections"] - # Collections should only have id and title fields - for collection in test_collections: - assert "id" in collection - assert "title" in collection - assert "description" not in collection - assert "links" in collection # links are always included + # Filter collections to only include the ones 
we created for this test + test_collections = [ + c for c in collections_list if c["id"].startswith(test_prefix) + ] + + # Collections should only have id and title fields + for collection in test_collections: + assert "id" in collection + assert "title" in collection + assert "description" not in collection # Test exclude fields parameter - resp = await app_client.get( - "/collections", - params=[("fields", "-description")], - ) - assert resp.status_code == 200 - resp_json = resp.json() + endpoints = [ + { + "method": "GET", + "path": "/collections", + "params": [("fields", "-description")], + }, + { + "method": "GET", + "path": "/collections-search", + "params": [("fields", "-description")], + }, + { + "method": "POST", + "path": "/collections-search", + "body": {"fields": {"exclude": ["description"]}}, + }, + ] - # Check if collections exist in the response - assert ( - "collections" in resp_json - ), "No collections in response for exclude fields test" + for endpoint in endpoints: + if endpoint["method"] == "GET": + resp = await app_client.get(endpoint["path"], params=endpoint["params"]) + else: # POST + resp = await app_client.post(endpoint["path"], json=endpoint["body"]) - # Filter collections to only include the ones we created for this test - test_collections = [] - for c in resp_json["collections"]: - if "id" in c and c["id"].startswith(test_prefix): - test_collections.append(c) + assert resp.status_code == 200 + resp_json = resp.json() - # Collections should have all fields except description - for collection in test_collections: - assert "id" in collection - assert "title" in collection - assert "description" not in collection - assert "links" in collection + collections_list = resp_json["collections"] + + # Filter collections to only include the ones we created for this test + test_collections = [ + c for c in collections_list if c["id"].startswith(test_prefix) + ] + + # Collections should have all fields except description + for collection in 
test_collections: + assert "id" in collection + assert "title" in collection + assert "description" not in collection + assert "links" in collection # links are always included @pytest.mark.asyncio -async def test_collections_free_text_search_get(app_client, txn_client, ctx): - """Verify GET /collections honors the q parameter for free text search.""" - # Create multiple collections with different content +async def test_collections_free_text_all_endpoints( + app_client, txn_client, ctx, monkeypatch +): + """Test free text search across all collection endpoints.""" + # Create test data + test_prefix = f"free-text-{uuid.uuid4().hex[:8]}" base_collection = ctx.collection + search_term = "SEARCHABLETERM" - # Use unique prefixes to avoid conflicts between tests - test_prefix = f"q-get-{uuid.uuid4().hex[:8]}" + monkeypatch.setenv("ENABLE_COLLECTIONS_SEARCH_ROUTE", "true") - test_collections = [ - { - "id": f"{test_prefix}-sentinel", - "title": "Sentinel-2 Collection", - "description": "Collection of Sentinel-2 data", - "summaries": {"platform": ["sentinel-2a", "sentinel-2b"]}, - }, - { - "id": f"{test_prefix}-landsat", - "title": "Landsat Collection", - "description": "Collection of Landsat data", - "summaries": {"platform": ["landsat-8", "landsat-9"]}, - }, - { - "id": f"{test_prefix}-modis", - "title": "MODIS Collection", - "description": "Collection of MODIS data", - "summaries": {"platform": ["terra", "aqua"]}, - }, - ] + # Create test collections + target_collection = base_collection.copy() + target_collection["id"] = f"{test_prefix}-target" + target_collection["title"] = f"Collection with {search_term} in title" + await create_collection(txn_client, target_collection) - for i, coll in enumerate(test_collections): - test_collection = base_collection.copy() - test_collection["id"] = coll["id"] - test_collection["title"] = coll["title"] - test_collection["description"] = coll["description"] - test_collection["summaries"] = coll["summaries"] - await 
create_collection(txn_client, test_collection) + decoy_collection = base_collection.copy() + decoy_collection["id"] = f"{test_prefix}-decoy" + decoy_collection["title"] = "Collection without the term" + await create_collection(txn_client, decoy_collection) await refresh_indices(txn_client) - # Test free text search for "sentinel" - resp = await app_client.get( - "/collections", - params=[("q", "sentinel")], - ) - assert resp.status_code == 200 - resp_json = resp.json() - - # Filter collections to only include the ones we created for this test - found_collections = [ - c for c in resp_json["collections"] if c["id"].startswith(test_prefix) + # Define endpoints to test + endpoints = [ + {"method": "GET", "path": "/collections", "param": "q"}, + {"method": "GET", "path": "/collections-search", "param": "q"}, + {"method": "POST", "path": "/collections-search", "body_key": "q"}, ] - # Should only find the sentinel collection - assert len(found_collections) == 1 - assert found_collections[0]["id"] == f"{test_prefix}-sentinel" + for endpoint in endpoints: - # Test free text search for "landsat" - resp = await app_client.get( - "/collections", - params=[("q", "modis")], - ) - assert resp.status_code == 200 - resp_json = resp.json() + if endpoint["method"] == "GET": + params = [(endpoint["param"], search_term)] + resp = await app_client.get(endpoint["path"], params=params) + else: # POST + body = {endpoint["body_key"]: search_term} + resp = await app_client.post(endpoint["path"], json=body) - # Filter collections to only include the ones we created for this test - found_collections = [ - c for c in resp_json["collections"] if c["id"].startswith(test_prefix) - ] + assert ( + resp.status_code == 200 + ), f"Failed for {endpoint['method']} {endpoint['path']} with status {resp.status_code}" + resp_json = resp.json() - # Should only find the landsat collection - assert len(found_collections) == 1 - 
assert found_collections[0]["id"] == f"{test_prefix}-modis" + collections = resp_json["collections"] + + # Filter to our test collections + found = [c for c in collections if c["id"].startswith(test_prefix)] + assert ( + len(found) == 1 + ), f"Expected 1 collection, found {len(found)} for {endpoint['method']} {endpoint['path']}" + assert ( + found[0]["id"] == target_collection["id"] + ), f"Expected {target_collection['id']}, found {found[0]['id']} for {endpoint['method']} {endpoint['path']}" @pytest.mark.asyncio @@ -409,22 +416,12 @@ async def test_collections_query_extension(app_client, txn_client, ctx): # Test query extension with not-equal operator on ID query = {"id": {"neq": f"{test_prefix}-sentinel"}} - print(f"\nTesting neq query: {query}") - print(f"JSON query: {json.dumps(query)}") - resp = await app_client.get( "/collections", params=[("query", json.dumps(query))], ) - print(f"Response status: {resp.status_code}") assert resp.status_code == 200 resp_json = resp.json() - print(f"Response JSON keys: {resp_json.keys()}") - print(f"Number of collections in response: {len(resp_json.get('collections', []))}") - - # Print all collection IDs in the response - all_ids = [c["id"] for c in resp_json.get("collections", [])] - print(f"All collection IDs in response: {all_ids}") # Filter collections to only include the ones we created for this test found_collections = [ @@ -439,6 +436,7 @@ async def test_collections_query_extension(app_client, txn_client, ctx): assert f"{test_prefix}-modis" in found_ids +@pytest.mark.asyncio async def test_collections_datetime_filter(app_client, load_test_data, txn_client): """Test filtering collections by datetime.""" # Create a test collection with a specific temporal extent @@ -526,3 +524,336 @@ async def test_collections_datetime_filter(app_client, load_test_data, txn_clien found_collections = [c for c in resp_json["collections"] if c["id"] == test_collection_id] assert len(found_collections) == 1, f"Expected to find collection 
{test_collection_id} with open-ended past range to a date within its range" """ + + +@pytest.mark.asyncio +async def test_collections_number_matched_returned(app_client, txn_client, ctx): + """Verify GET /collections returns correct numberMatched and numberReturned values.""" + # Create multiple collections with different ids + base_collection = ctx.collection + + # Create collections with ids in a specific order to test pagination + # Use unique prefixes to avoid conflicts between tests + test_prefix = f"count-{uuid.uuid4().hex[:8]}" + collection_ids = [f"{test_prefix}-{i}" for i in range(10)] + + for i, coll_id in enumerate(collection_ids): + test_collection = base_collection.copy() + test_collection["id"] = coll_id + test_collection["title"] = f"Test Collection {i}" + await create_collection(txn_client, test_collection) + + await refresh_indices(txn_client) + + # Test with limit=5 + resp = await app_client.get( + "/collections", + params=[("limit", "5")], + ) + assert resp.status_code == 200 + resp_json = resp.json() + + # Filter collections to only include the ones we created for this test + test_collections = [ + c for c in resp_json["collections"] if c["id"].startswith(test_prefix) + ] + + # Should return 5 collections + assert len(test_collections) == 5 + + # Check that numberReturned matches the number of collections returned + assert resp_json["numberReturned"] == len(resp_json["collections"]) + + # Check that numberMatched is greater than or equal to numberReturned + # (since there might be other collections in the database) + assert resp_json["numberMatched"] >= resp_json["numberReturned"] + + # Check that numberMatched includes at least all our test collections + assert resp_json["numberMatched"] >= len(collection_ids) + + # Now test with a query that should match only some collections + query = {"id": {"eq": f"{test_prefix}-1"}} + resp = await app_client.get( + "/collections", + params=[("query", json.dumps(query))], + ) + assert resp.status_code == 
200 + resp_json = resp.json() + + # Filter collections to only include the ones we created for this test + test_collections = [ + c for c in resp_json["collections"] if c["id"].startswith(test_prefix) + ] + + # Should return only 1 collection + assert len(test_collections) == 1 + assert test_collections[0]["id"] == f"{test_prefix}-1" + + # Check that numberReturned matches the number of collections returned + assert resp_json["numberReturned"] == len(resp_json["collections"]) + + # Check that numberMatched matches the number of collections that match the query + # (should be 1 in this case) + assert resp_json["numberMatched"] >= 1 + + +@pytest.mark.asyncio +async def test_collections_post(app_client, txn_client, ctx): + """Verify POST /collections-search endpoint works.""" + + # Create multiple collections with different ids + base_collection = ctx.collection + + # Create collections with ids in a specific order to test search + # Use unique prefixes to avoid conflicts between tests + test_prefix = f"post-{uuid.uuid4().hex[:8]}" + collection_ids = [f"{test_prefix}-{i}" for i in range(10)] + + for i, coll_id in enumerate(collection_ids): + test_collection = base_collection.copy() + test_collection["id"] = coll_id + test_collection["title"] = f"Test Collection {i}" + await create_collection(txn_client, test_collection) + + await refresh_indices(txn_client) + + # Test basic POST search + resp = await app_client.post( + "/collections-search", + json={"limit": 5}, + ) + assert resp.status_code == 200 + resp_json = resp.json() + + # Filter collections to only include the ones we created for this test + test_collections = [ + c for c in resp_json["collections"] if c["id"].startswith(test_prefix) + ] + + # Should return 5 collections + assert len(test_collections) == 5 + + # Check that numberReturned matches the number of collections returned + assert resp_json["numberReturned"] == len(resp_json["collections"]) + + # Check that numberMatched is greater than or equal to 
numberReturned + assert resp_json["numberMatched"] >= resp_json["numberReturned"] + + # Test POST search with sortby + resp = await app_client.post( + "/collections-search", + json={"sortby": [{"field": "id", "direction": "desc"}]}, + ) + assert resp.status_code == 200 + resp_json = resp.json() + + # Filter collections to only include the ones we created for this test + test_collections = [ + c for c in resp_json["collections"] if c["id"].startswith(test_prefix) + ] + + # Check that collections are sorted by id in descending order + if len(test_collections) >= 2: + assert test_collections[0]["id"] > test_collections[1]["id"] + + # Check that numberReturned matches the number of collections returned + assert resp_json["numberReturned"] == len(resp_json["collections"]) + + # Test POST search with fields + resp = await app_client.post( + "/collections-search", + json={"fields": {"exclude": ["stac_version"]}}, + ) + assert resp.status_code == 200 + resp_json = resp.json() + + # Filter collections to only include the ones we created for this test + test_collections = [ + c for c in resp_json["collections"] if c["id"].startswith(test_prefix) + ] + + # Check that stac_version is excluded from the collections + for collection in test_collections: + assert "stac_version" not in collection + + +@pytest.mark.asyncio +async def test_collections_search_cql2_text(app_client, txn_client, ctx): + """Test collections search with CQL2-text filter.""" + # Create a unique prefix for test collections + test_prefix = f"test-{uuid.uuid4()}" + + # Create test collections + collection_data = ctx.collection.copy() + collection_data["id"] = f"{test_prefix}-collection" + await create_collection(txn_client, collection_data) + await refresh_indices(txn_client) + + # Test GET search with CQL2-text filter + collection_id = collection_data["id"] + resp = await app_client.get( + f"/collections-search?filter-lang=cql2-text&filter=id='{collection_id}'" + ) + assert resp.status_code == 200 + resp_json 
= resp.json() + + # Filter collections to only include the ones with our test prefix + filtered_collections = [ + c for c in resp_json["collections"] if c["id"].startswith(test_prefix) + ] + + # Check that only the filtered collection is returned + assert len(filtered_collections) == 1 + assert filtered_collections[0]["id"] == collection_id + + # Test GET search with more complex CQL2-text filter (LIKE operator) + test_prefix_escaped = test_prefix.replace("-", "\\-") + resp = await app_client.get( + f"/collections-search?filter-lang=cql2-text&filter=id LIKE '{test_prefix_escaped}%'" + ) + assert resp.status_code == 200 + resp_json = resp.json() + + # Filter collections to only include the ones with our test prefix + filtered_collections = [ + c for c in resp_json["collections"] if c["id"].startswith(test_prefix) + ] + + # Check that all test collections are returned + assert ( + len(filtered_collections) == 1 + ) # We only created one collection with this prefix + assert filtered_collections[0]["id"] == collection_id + + +@pytest.mark.asyncio +async def test_collections_pagination_all_endpoints(app_client, txn_client, ctx): + """Test pagination works correctly across all collection endpoints.""" + # Create test data + test_prefix = f"pagination-{uuid.uuid4().hex[:8]}" + base_collection = ctx.collection + + # Create 10 test collections with predictable IDs for sorting + test_collections = [] + for i in range(10): + test_coll = base_collection.copy() + test_coll["id"] = f"{test_prefix}-{i:02d}" + test_coll["title"] = f"Test Collection {i}" + test_collections.append(test_coll) + await create_collection(txn_client, test_coll) + + await refresh_indices(txn_client) + + # Define endpoints to test + endpoints = [ + {"method": "GET", "path": "/collections", "param": "limit"}, + {"method": "GET", "path": "/collections-search", "param": "limit"}, + {"method": "POST", "path": "/collections-search", "body_key": "limit"}, + ] + + # Test pagination for each endpoint + for 
endpoint in endpoints: + # Test first page with limit=3 + limit = 3 + + # Make the request + if endpoint["method"] == "GET": + params = [(endpoint["param"], str(limit))] + resp = await app_client.get(endpoint["path"], params=params) + else: # POST + body = {endpoint["body_key"]: limit} + resp = await app_client.post(endpoint["path"], json=body) + + assert ( + resp.status_code == 200 + ), f"Failed for {endpoint['method']} {endpoint['path']}" + resp_json = resp.json() + + # Filter to our test collections + found_collections = resp_json["collections"] + + test_found = [c for c in found_collections if c["id"].startswith(test_prefix)] + + # Should return exactly limit collections + assert ( + len(test_found) == limit + ), f"Expected {limit} collections, got {len(test_found)}" + + # Verify collections are in correct order (ascending by ID) + expected_ids = [f"{test_prefix}-{i:02d}" for i in range(limit)] + for i, expected_id in enumerate(expected_ids): + assert test_found[i]["id"] == expected_id + + # Test second page using the token from the first page + if "token" in resp_json and resp_json["token"]: + token = resp_json["token"] + + # Make the request with token + if endpoint["method"] == "GET": + params = [(endpoint["param"], str(limit)), ("token", token)] + resp = await app_client.get(endpoint["path"], params=params) + else: # POST + body = {endpoint["body_key"]: limit, "token": token} + resp = await app_client.post(endpoint["path"], json=body) + + assert ( + resp.status_code == 200 + ), f"Failed for {endpoint['method']} {endpoint['path']} with token" + resp_json = resp.json() + + # Filter to our test collections + found_collections = resp_json["collections"] + + test_found = [ + c for c in found_collections if 
c["id"].startswith(test_prefix) + ] + + # Should return next set of collections + expected_ids = [f"{test_prefix}-{i:02d}" for i in range(limit, limit * 2)] + assert len(test_found) == min( + limit, len(expected_ids) + ), f"Expected {min(limit, len(expected_ids))} collections, got {len(test_found)}" + + # Verify collections are in correct order + for i, expected_id in enumerate(expected_ids[: len(test_found)]): + assert test_found[i]["id"] == expected_id + + # Test with sortby parameter to ensure token works with sorting + if endpoint["method"] == "GET": + params = [("sortby", "-id"), (endpoint["param"], str(limit))] + resp = await app_client.get(endpoint["path"], params=params) + else: # POST + body = { + "sortby": [{"field": "id", "direction": "desc"}], + endpoint["body_key"]: limit, + } + resp = await app_client.post(endpoint["path"], json=body) + + assert ( + resp.status_code == 200 + ), f"Failed for {endpoint['method']} {endpoint['path']} with sortby" + resp_json = resp.json() + + found_collections = resp_json["collections"] + + test_found = [c for c in found_collections if c["id"].startswith(test_prefix)] + + # Verify collections are sorted in descending order + # We expect the highest IDs first (09, 08, 07, etc.) 
+ expected_ids = sorted( + [f"{test_prefix}-{i:02d}" for i in range(10)], reverse=True + )[:limit] + + # Filter expected_ids to only include collections that actually exist in the response + expected_ids = [ + id for id in expected_ids if any(c["id"] == id for c in found_collections) + ] + + for i, expected_id in enumerate(expected_ids): + assert test_found[i]["id"] == expected_id diff --git a/stac_fastapi/tests/conftest.py b/stac_fastapi/tests/conftest.py index 08e3277dc..b461e7221 100644 --- a/stac_fastapi/tests/conftest.py +++ b/stac_fastapi/tests/conftest.py @@ -25,8 +25,20 @@ ) from stac_fastapi.core.rate_limit import setup_rate_limit from stac_fastapi.core.utilities import get_bool_env +from stac_fastapi.extensions.core import ( + AggregationExtension, + FieldsExtension, + FilterExtension, + FreeTextExtension, + SortExtension, + TokenPaginationExtension, + TransactionExtension, +) from stac_fastapi.sfeos_helpers.aggregation import EsAsyncBaseAggregationClient from stac_fastapi.sfeos_helpers.mappings import ITEMS_INDEX_PREFIX +from stac_fastapi.types.config import Settings + +os.environ.setdefault("ENABLE_COLLECTIONS_SEARCH_ROUTE", "true") if os.getenv("BACKEND", "elasticsearch").lower() == "opensearch": from stac_fastapi.opensearch.app import app_config @@ -51,17 +63,6 @@ create_index_templates, ) -from stac_fastapi.extensions.core import ( - AggregationExtension, - FieldsExtension, - FilterExtension, - FreeTextExtension, - SortExtension, - TokenPaginationExtension, - TransactionExtension, -) -from stac_fastapi.types.config import Settings - DATA_DIR = os.path.join(os.path.dirname(__file__), "data") @@ -315,7 +316,6 @@ def must_be_bob( @pytest_asyncio.fixture(scope="session") async def route_dependencies_app(): """Fixture to get the FastAPI app with custom route dependencies.""" - # Create a copy of the app config test_config = app_config.copy()
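The pagination token scheme introduced in `get_all_collections` above (join a hit's `sort` values with `|` to build `next_token`, split the incoming token back apart, and discard it when the number of parts does not match the active sort fields) can be sketched in isolation. This is a minimal illustration of the round-trip, not code from the PR; the helper names `encode_token` and `decode_token` are hypothetical:

```python
from typing import List, Optional


def encode_token(sort_values: List[object]) -> str:
    """Build a pagination token by joining a hit's sort values with '|'."""
    return "|".join(str(v) for v in sort_values)


def decode_token(token: str, num_sort_fields: int) -> Optional[List[str]]:
    """Split a token back into search_after values.

    Returns None (i.e. the token is ignored) when the number of parts does
    not match the number of active sort fields, mirroring the guard in the
    diff's token-parsing branch.
    """
    values = token.split("|")
    if len(values) != num_sort_fields:
        return None
    return values
```

One caveat this scheme inherits: a sort value that itself contains `|` would split into extra parts and be rejected by the arity check, so the design assumes collection IDs and datetimes are pipe-free.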