diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index e852ed774..85cf846a7 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -17,7 +17,7 @@ from stac_fastapi.core.base_database_logic import BaseDatabaseLogic from stac_fastapi.core.serializers import CollectionSerializer, ItemSerializer -from stac_fastapi.core.utilities import bbox2polygon, get_max_limit +from stac_fastapi.core.utilities import bbox2polygon, get_bool_env, get_max_limit from stac_fastapi.elasticsearch.config import AsyncElasticsearchSettings from stac_fastapi.elasticsearch.config import ( ElasticsearchSettings as SyncElasticsearchSettings, @@ -289,26 +289,99 @@ def apply_datetime_filter( Returns: The filtered search object. """ + # USE_DATETIME env var + # True: Search by datetime, if null search by start/end datetime + # False: Always search only by start/end datetime + USE_DATETIME = get_bool_env("USE_DATETIME", default=True) + datetime_search = return_date(datetime) if not datetime_search: return search, datetime_search - if "eq" in datetime_search: - # For exact matches, include: - # 1. Items with matching exact datetime - # 2. Items with datetime:null where the time falls within their range - should = [ - Q( - "bool", - filter=[ - Q("exists", field="properties.datetime"), - Q("term", **{"properties__datetime": datetime_search["eq"]}), - ], - ), - Q( + if USE_DATETIME: + if "eq" in datetime_search: + # For exact matches, include: + # 1. Items with matching exact datetime + # 2. Items with datetime:null where the time falls within their range + should = [ + Q( + "bool", + filter=[ + Q("exists", field="properties.datetime"), + Q( + "term", + **{"properties__datetime": datetime_search["eq"]}, + ), + ], + ), + Q( + "bool", + must_not=[Q("exists", field="properties.datetime")], + filter=[ + Q("exists", field="properties.start_datetime"), + Q("exists", field="properties.end_datetime"), + Q( + "range", + properties__start_datetime={ + "lte": datetime_search["eq"] + }, + ), + Q( + "range", + properties__end_datetime={"gte": datetime_search["eq"]}, + ), + ], + ), + ] + else: + # For date ranges, include: + # 1. Items with datetime in the range + # 2. Items with datetime:null that overlap the search range + should = [ + Q( + "bool", + filter=[ + Q("exists", field="properties.datetime"), + Q( + "range", + properties__datetime={ + "gte": datetime_search["gte"], + "lte": datetime_search["lte"], + }, + ), + ], + ), + Q( + "bool", + must_not=[Q("exists", field="properties.datetime")], + filter=[ + Q("exists", field="properties.start_datetime"), + Q("exists", field="properties.end_datetime"), + Q( + "range", + properties__start_datetime={ + "lte": datetime_search["lte"] + }, + ), + Q( + "range", + properties__end_datetime={ + "gte": datetime_search["gte"] + }, + ), + ], + ), + ] + + return ( + search.query(Q("bool", should=should, minimum_should_match=1)), + datetime_search, + ) + else: + if "eq" in datetime_search: + filter_query = Q( "bool", - must_not=[Q("exists", field="properties.datetime")], filter=[ Q("exists", field="properties.start_datetime"), Q("exists", field="properties.end_datetime"), @@ -321,29 +394,10 @@ def apply_datetime_filter( properties__end_datetime={"gte": datetime_search["eq"]}, ), ], - ), - ] - else: - # For date ranges, include: - # 1. Items with datetime in the range - # 2. 
Items with datetime:null that overlap the search range - should = [ - Q( - "bool", - filter=[ - Q("exists", field="properties.datetime"), - Q( - "range", - properties__datetime={ - "gte": datetime_search["gte"], - "lte": datetime_search["lte"], - }, - ), - ], - ), - Q( + ) + else: + filter_query = Q( "bool", - must_not=[Q("exists", field="properties.datetime")], filter=[ Q("exists", field="properties.start_datetime"), Q("exists", field="properties.end_datetime"), @@ -356,13 +410,8 @@ def apply_datetime_filter( properties__end_datetime={"gte": datetime_search["gte"]}, ), ], - ), - ] - - return ( - search.query(Q("bool", should=should, minimum_should_match=1)), - datetime_search, - ) + ) + return search.query(filter_query), datetime_search @staticmethod def apply_bbox_filter(search: Search, bbox: List): diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py index e54397bab..dca1ea077 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py @@ -7,25 +7,26 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Type import attr +import elasticsearch.helpers as helpers import orjson +from elasticsearch.dsl import Q, Search +from elasticsearch.exceptions import BadRequestError +from elasticsearch.exceptions import NotFoundError as ESNotFoundError from fastapi import HTTPException -from opensearchpy import exceptions, helpers -from opensearchpy.helpers.query import Q -from opensearchpy.helpers.search import Search from starlette.requests import Request from stac_fastapi.core.base_database_logic import BaseDatabaseLogic from stac_fastapi.core.serializers import CollectionSerializer, ItemSerializer -from stac_fastapi.core.utilities import bbox2polygon, get_max_limit +from stac_fastapi.core.utilities import bbox2polygon, get_bool_env, get_max_limit +from stac_fastapi.elasticsearch.config import AsyncElasticsearchSettings +from stac_fastapi.elasticsearch.config import ( + ElasticsearchSettings as SyncElasticsearchSettings, +) from stac_fastapi.extensions.core.transaction.request import ( PartialCollection, PartialItem, PatchOperation, ) -from stac_fastapi.opensearch.config import ( - AsyncOpensearchSettings as AsyncSearchSettings, -) -from stac_fastapi.opensearch.config import OpensearchSettings as SyncSearchSettings from stac_fastapi.sfeos_helpers import filter as filter_module from stac_fastapi.sfeos_helpers.database import ( apply_free_text_filter_shared, @@ -34,6 +35,7 @@ delete_item_index_shared, get_queryables_mapping_shared, index_alias_by_collection_id, + index_by_collection_id, mk_actions, mk_item_id, populate_sort_shared, @@ -52,7 +54,6 @@ AGGREGATION_MAPPING, COLLECTIONS_INDEX, DEFAULT_SORT, - ES_COLLECTIONS_MAPPINGS, ITEM_INDICES, ITEMS_INDEX_PREFIX, Geometry, @@ -78,7 +79,7 @@ async def create_index_templates() -> None: None """ - await create_index_templates_shared(settings=AsyncSearchSettings()) + await create_index_templates_shared(settings=AsyncElasticsearchSettings()) async def create_collection_index() -> None: @@ -89,23 +90,36 @@ async def create_collection_index() -> None: None """ - client = AsyncSearchSettings().create_client + client = AsyncElasticsearchSettings().create_client - index = f"{COLLECTIONS_INDEX}-000001" + await client.options(ignore_status=400).indices.create( + index=f"{COLLECTIONS_INDEX}-000001", + body={"aliases": {COLLECTIONS_INDEX: {}}}, + ) + await client.close() - 
exists = await client.indices.exists(index=index) - if not exists: - await client.indices.create( - index=index, - body={ - "aliases": {COLLECTIONS_INDEX: {}}, - "mappings": ES_COLLECTIONS_MAPPINGS, - }, - ) + +async def create_item_index(collection_id: str): + """ + Create the index for Items. The settings of the index template will be used implicitly. + + Args: + collection_id (str): Collection identifier. + + Returns: + None + + """ + client = AsyncElasticsearchSettings().create_client + + await client.options(ignore_status=400).indices.create( + index=f"{index_by_collection_id(collection_id)}-000001", + body={"aliases": {index_alias_by_collection_id(collection_id): {}}}, + ) await client.close() -async def delete_item_index(collection_id: str) -> None: +async def delete_item_index(collection_id: str): """Delete the index for items in a collection. Args: @@ -115,7 +129,7 @@ async def delete_item_index(collection_id: str) -> None: This function delegates to the shared implementation in delete_item_index_shared. """ await delete_item_index_shared( - settings=AsyncSearchSettings(), collection_id=collection_id + settings=AsyncElasticsearchSettings(), collection_id=collection_id ) @@ -123,9 +137,12 @@ async def delete_item_index(collection_id: str) -> None: class DatabaseLogic(BaseDatabaseLogic): """Database logic.""" - async_settings: AsyncSearchSettings = attr.ib(factory=AsyncSearchSettings) - sync_settings: SyncSearchSettings = attr.ib(factory=SyncSearchSettings) - + async_settings: AsyncElasticsearchSettings = attr.ib( + factory=AsyncElasticsearchSettings + ) + sync_settings: SyncElasticsearchSettings = attr.ib( + factory=SyncElasticsearchSettings + ) async_index_selector: BaseIndexSelector = attr.ib(init=False) async_index_inserter: BaseIndexInserter = attr.ib(init=False) @@ -155,8 +172,7 @@ def __attrs_post_init__(self): async def get_all_collections( self, token: Optional[str], limit: int, request: Request ) -> Tuple[List[Dict[str, Any]], Optional[str]]: - """ - Retrieve a list of all collections from Opensearch, supporting pagination. + """Retrieve a list of all collections from Elasticsearch, supporting pagination. Args: token (Optional[str]): The pagination token. @@ -165,19 +181,17 @@ async def get_all_collections( Returns: A tuple of (collections, next pagination token if any). 
""" - search_body = { - "sort": [{"id": {"order": "asc"}}], - "size": limit, - } - - # Only add search_after to the query if token is not None and not empty + search_after = None if token: search_after = [token] - search_body["search_after"] = search_after response = await self.client.search( index=COLLECTIONS_INDEX, - body=search_body, + body={ + "sort": [{"id": {"order": "asc"}}], + "size": limit, + **({"search_after": search_after} if search_after is not None else {}), + }, ) hits = response["hits"]["hits"] @@ -190,10 +204,7 @@ async def get_all_collections( next_token = None if len(hits) == limit: - # Ensure we have a valid sort value for next_token - next_token_values = hits[-1].get("sort") - if next_token_values: - next_token = next_token_values[0] + next_token = hits[-1]["sort"][0] return collections, next_token @@ -228,7 +239,7 @@ async def get_one_item(self, collection_id: str, item_id: str) -> Dict: ) return response["hits"]["hits"][0]["_source"] - except exceptions.NotFoundError: + except ESNotFoundError: raise NotFoundError( f"Item {item_id} does not exist inside Collection {collection_id}" ) @@ -265,24 +276,6 @@ def apply_collections_filter(search: Search, collection_ids: List[str]): """Database logic to search a list of STAC collection ids.""" return search.filter("terms", collection=collection_ids) - @staticmethod - def apply_free_text_filter(search: Search, free_text_queries: Optional[List[str]]): - """Create a free text query for OpenSearch queries. - - This method delegates to the shared implementation in apply_free_text_filter_shared. - - Args: - search (Search): The search object to apply the query to. - free_text_queries (Optional[List[str]]): A list of text strings to search for in the properties. - - Returns: - Search: The search object with the free text query applied, or the original search - object if no free_text_queries were provided. - """ - return apply_free_text_filter_shared( - search=search, free_text_queries=free_text_queries - ) - @staticmethod def apply_datetime_filter( search: Search, datetime: Optional[str] @@ -296,26 +289,96 @@ def apply_datetime_filter( Returns: The filtered search object. """ + USE_DATETIME = get_bool_env("USE_DATETIME", default=True) + datetime_search = return_date(datetime) if not datetime_search: return search, datetime_search - if "eq" in datetime_search: - # For exact matches, include: - # 1. Items with matching exact datetime - # 2. Items with datetime:null where the time falls within their range - should = [ - Q( - "bool", - filter=[ - Q("exists", field="properties.datetime"), - Q("term", **{"properties__datetime": datetime_search["eq"]}), - ], - ), - Q( + if USE_DATETIME: + if "eq" in datetime_search: + # For exact matches, include: + # 1. Items with matching exact datetime + # 2. Items with datetime:null where the time falls within their range + should = [ + Q( + "bool", + filter=[ + Q("exists", field="properties.datetime"), + Q( + "term", + **{"properties__datetime": datetime_search["eq"]}, + ), + ], + ), + Q( + "bool", + must_not=[Q("exists", field="properties.datetime")], + filter=[ + Q("exists", field="properties.start_datetime"), + Q("exists", field="properties.end_datetime"), + Q( + "range", + properties__start_datetime={ + "lte": datetime_search["eq"] + }, + ), + Q( + "range", + properties__end_datetime={"gte": datetime_search["eq"]}, + ), + ], + ), + ] + else: + # For date ranges, include: + # 1. Items with datetime in the range + # 2. 
Items with datetime:null that overlap the search range + should = [ + Q( + "bool", + filter=[ + Q("exists", field="properties.datetime"), + Q( + "range", + properties__datetime={ + "gte": datetime_search["gte"], + "lte": datetime_search["lte"], + }, + ), + ], + ), + Q( + "bool", + must_not=[Q("exists", field="properties.datetime")], + filter=[ + Q("exists", field="properties.start_datetime"), + Q("exists", field="properties.end_datetime"), + Q( + "range", + properties__start_datetime={ + "lte": datetime_search["lte"] + }, + ), + Q( + "range", + properties__end_datetime={ + "gte": datetime_search["gte"] + }, + ), + ], + ), + ] + + return ( + search.query(Q("bool", should=should, minimum_should_match=1)), + datetime_search, + ) + else: + if "eq" in datetime_search: + filter_query = Q( "bool", - must_not=[Q("exists", field="properties.datetime")], filter=[ Q("exists", field="properties.start_datetime"), Q("exists", field="properties.end_datetime"), @@ -328,29 +391,10 @@ def apply_datetime_filter( properties__end_datetime={"gte": datetime_search["eq"]}, ), ], - ), - ] - else: - # For date ranges, include: - # 1. Items with datetime in the range - # 2. Items with datetime:null that overlap the search range - should = [ - Q( - "bool", - filter=[ - Q("exists", field="properties.datetime"), - Q( - "range", - properties__datetime={ - "gte": datetime_search["gte"], - "lte": datetime_search["lte"], - }, - ), - ], - ), - Q( + ) + else: + filter_query = Q( "bool", - must_not=[Q("exists", field="properties.datetime")], filter=[ Q("exists", field="properties.start_datetime"), Q("exists", field="properties.end_datetime"), @@ -363,13 +407,8 @@ def apply_datetime_filter( properties__end_datetime={"gte": datetime_search["gte"]}, ), ], - ), - ] - - return ( - search.query(Q("bool", should=should, minimum_should_match=1)), - datetime_search, - ) + ) + return search.query(filter_query), datetime_search @staticmethod def apply_bbox_filter(search: Search, bbox: List): @@ -437,29 +476,47 @@ def apply_stacql_filter(search: Search, op: str, field: str, value: float): search (Search): The search object with the specified filter applied. """ if op != "eq": - key_filter = {field: {f"{op}": value}} + key_filter = {field: {op: value}} search = search.filter(Q("range", **key_filter)) else: search = search.filter("term", **{field: value}) return search + @staticmethod + def apply_free_text_filter(search: Search, free_text_queries: Optional[List[str]]): + """Create a free text query for Elasticsearch queries. + + This method delegates to the shared implementation in apply_free_text_filter_shared. + + Args: + search (Search): The search object to apply the query to. + free_text_queries (Optional[List[str]]): A list of text strings to search for in the properties. + + Returns: + Search: The search object with the free text query applied, or the original search + object if no free_text_queries were provided. + """ + return apply_free_text_filter_shared( + search=search, free_text_queries=free_text_queries + ) + async def apply_cql2_filter( self, search: Search, _filter: Optional[Dict[str, Any]] ): """ - Apply a CQL2 filter to an Opensearch Search object. + Apply a CQL2 filter to an Elasticsearch Search object. - This method transforms a dictionary representing a CQL2 filter into an Opensearch query + This method transforms a dictionary representing a CQL2 filter into an Elasticsearch query and applies it to the provided Search object. If the filter is None, the original Search object is returned unmodified. 
Args: - search (Search): The Opensearch Search object to which the filter will be applied. + search (Search): The Elasticsearch Search object to which the filter will be applied. _filter (Optional[Dict[str, Any]]): The filter in dictionary form that needs to be applied to the search. The dictionary should follow the structure required by the `to_es` function which converts it - to an Opensearch query. + to an Elasticsearch query. Returns: Search: The modified Search object with the filter applied if a filter is provided, @@ -467,13 +524,13 @@ async def apply_cql2_filter( """ if _filter is not None: es_query = filter_module.to_es(await self.get_queryables_mapping(), _filter) - search = search.filter(es_query) + search = search.query(es_query) return search @staticmethod def populate_sort(sortby: List) -> Optional[Dict[str, Dict[str, str]]]: - """Create a sort configuration for OpenSearch queries. + """Create a sort configuration for Elasticsearch queries. This method delegates to the shared implementation in populate_sort_shared. @@ -518,7 +575,11 @@ async def execute_search( Raises: NotFoundError: If the collections specified in `collection_ids` do not exist. """ - search_body: Dict[str, Any] = {} + search_after = None + + if token: + search_after = orjson.loads(urlsafe_b64decode(token)) + query = search.query.to_dict() if search.query else None index_param = await self.async_index_selector.select_indexes( @@ -528,18 +589,6 @@ async def execute_search( index_param = ITEM_INDICES query = add_collections_to_body(collection_ids, query) - if query: - search_body["query"] = query - - search_after = None - - if token: - search_after = orjson.loads(urlsafe_b64decode(token)) - if search_after: - search_body["search_after"] = search_after - - search_body["sort"] = sort if sort else DEFAULT_SORT - max_result_window = get_max_limit() size_limit = min(limit + 1, max_result_window) @@ -548,7 +597,9 @@ async def execute_search( self.client.search( index=index_param, ignore_unavailable=ignore_unavailable, - body=search_body, + query=query, + sort=sort or DEFAULT_SORT, + **({"search_after": search_after} if search_after is not None else {}), size=size_limit, ) ) @@ -563,7 +614,7 @@ async def execute_search( try: es_response = await search_task - except exceptions.NotFoundError: + except ESNotFoundError: raise NotFoundError(f"Collections '{collection_ids}' do not exist") hits = es_response["hits"]["hits"] @@ -609,6 +660,8 @@ async def aggregate( if query: search_body["query"] = query + logger.debug("Aggregations: %s", aggregations) + def _fill_aggregation_parameters(name: str, agg: dict) -> dict: [key] = agg.keys() agg_precision = { @@ -648,7 +701,7 @@ def _fill_aggregation_parameters(name: str, agg: dict) -> dict: try: db_response = await search_task - except exceptions.NotFoundError: + except ESNotFoundError: raise NotFoundError(f"Collections '{collection_ids}' do not exist") return db_response @@ -679,14 +732,21 @@ async def async_prep_create_item( """ await self.check_collection_exists(collection_id=item["collection"]) + alias = index_alias_by_collection_id(item["collection"]) + doc_id = mk_item_id(item["id"], item["collection"]) - if not exist_ok and await self.client.exists( - index=index_alias_by_collection_id(item["collection"]), - id=mk_item_id(item["id"], item["collection"]), - ): - raise ConflictError( - f"Item {item['id']} in collection {item['collection']} already exists" - ) + if not exist_ok: + alias_exists = await self.client.indices.exists_alias(name=alias) + + if alias_exists: + 
alias_info = await self.client.indices.get_alias(name=alias) + indices = list(alias_info.keys()) + + for index in indices: + if await self.client.exists(index=index, id=doc_id): + raise ConflictError( + f"Item {item['id']} in collection {item['collection']} already exists" + ) return self.item_serializer.stac_to_db(item, base_url) @@ -734,6 +794,7 @@ async def bulk_async_prep_create_item( logger.warning( f"{error_message} Continuing as `RAISE_ON_BULK_ERROR` is set to false." ) + # Serialize the item into a database-compatible format prepped_item = self.item_serializer.stac_to_db(item, base_url) logger.debug(f"Item {item['id']} prepared successfully.") @@ -803,7 +864,9 @@ async def create_item( item (Item): The item to be created. base_url (str, optional): The base URL for the item. Defaults to an empty string. exist_ok (bool, optional): Whether to allow the item to exist already. Defaults to False. - **kwargs: Additional keyword arguments like refresh. + **kwargs: Additional keyword arguments. + - refresh (str): Whether to refresh the index after the operation. Can be "true", "false", or "wait_for". + - refresh (bool): Whether to refresh the index after the operation. Defaults to the value in `self.async_settings.database_refresh`. Raises: ConflictError: If the item already exists in the database. @@ -811,10 +874,9 @@ async def create_item( Returns: None """ - # todo: check if collection exists, but cache + # Extract item and collection IDs item_id = item["id"] collection_id = item["collection"] - # Ensure kwargs is a dictionary kwargs = kwargs or {} @@ -827,6 +889,7 @@ async def create_item( f"Creating item {item_id} in collection {collection_id} with refresh={refresh}" ) + # Prepare the item for insertion item = await self.async_prep_create_item( item=item, base_url=base_url, exist_ok=exist_ok ) @@ -834,11 +897,11 @@ async def create_item( target_index = await self.async_index_inserter.get_target_index( collection_id, item ) - + # Index the item in the database await self.client.index( index=target_index, id=mk_item_id(item_id, collection_id), - body=item, + document=item, refresh=refresh, ) @@ -932,14 +995,14 @@ async def json_patch_item( await self.client.update( index=document_index, id=mk_item_id(item_id, collection_id), - body={"script": script}, + script=script, refresh=True, ) - except exceptions.NotFoundError: + except ESNotFoundError: raise NotFoundError( f"Item {item_id} does not exist inside Collection {collection_id}" ) - except exceptions.RequestError as exc: + except BadRequestError as exc: raise HTTPException( status_code=400, detail=exc.info["error"]["caused_by"] ) from exc @@ -949,7 +1012,9 @@ async def json_patch_item( if new_collection_id: await self.client.reindex( body={ - "dest": {"index": f"{ITEMS_INDEX_PREFIX}{new_collection_id}"}, + "dest": { + "index": f"{ITEMS_INDEX_PREFIX}{new_collection_id}" + }, # # noqa "source": { "index": f"{ITEMS_INDEX_PREFIX}{collection_id}", "query": {"term": {"id": {"value": item_id}}}, @@ -957,8 +1022,8 @@ async def json_patch_item( "script": { "lang": "painless", "source": ( - f"""ctx._id = ctx._id.replace('{collection_id}', '{new_collection_id}');""" # noqa: E702 - f"""ctx._source.collection = '{new_collection_id}';""" # noqa: E702 + f"""ctx._id = ctx._id.replace('{collection_id}', '{new_collection_id}');""" # noqa + f"""ctx._source.collection = '{new_collection_id}';""" # noqa ), }, }, @@ -994,10 +1059,15 @@ async def delete_item(self, item_id: str, collection_id: str, **kwargs: Any): Args: item_id (str): The id of the Item to be 
deleted. collection_id (str): The id of the Collection that the Item belongs to. - **kwargs: Additional keyword arguments like refresh. + **kwargs: Additional keyword arguments. + - refresh (str): Whether to refresh the index after the operation. Can be "true", "false", or "wait_for". + - refresh (bool): Whether to refresh the index after the operation. Defaults to the value in `self.async_settings.database_refresh`. Raises: NotFoundError: If the Item does not exist in the database. + + Returns: + None """ # Ensure kwargs is a dictionary kwargs = kwargs or {} @@ -1012,12 +1082,14 @@ async def delete_item(self, item_id: str, collection_id: str, **kwargs: Any): ) try: + # Perform the delete operation await self.client.delete_by_query( index=index_alias_by_collection_id(collection_id), body={"query": {"term": {"_id": mk_item_id(item_id, collection_id)}}}, refresh=refresh, ) - except exceptions.NotFoundError: + except ESNotFoundError: + # Raise a custom NotFoundError if the item does not exist raise NotFoundError( f"Item {item_id} in collection {collection_id} not found" ) @@ -1034,10 +1106,10 @@ async def get_items_mapping(self, collection_id: str) -> Dict[str, Any]: index_name = index_alias_by_collection_id(collection_id) try: mapping = await self.client.indices.get_mapping( - index=index_name, params={"allow_no_indices": "false"} + index=index_name, allow_no_indices=False ) - return mapping - except exceptions.NotFoundError: + return mapping.body + except ESNotFoundError: raise NotFoundError(f"Mapping for index {index_name} not found") async def get_items_unique_values( @@ -1076,11 +1148,16 @@ async def create_collection(self, collection: Collection, **kwargs: Any): Args: collection (Collection): The Collection object to be created. - **kwargs: Additional keyword arguments like refresh. + **kwargs: Additional keyword arguments. + - refresh (str): Whether to refresh the index after the operation. Can be "true", "false", or "wait_for". + - refresh (bool): Whether to refresh the index after the operation. Defaults to the value in `self.async_settings.database_refresh`. Raises: ConflictError: If a Collection with the same id already exists in the database. + Returns: + None + Notes: A new index is created for the items in the Collection using the `create_item_index` function. 
""" @@ -1096,15 +1173,18 @@ async def create_collection(self, collection: Collection, **kwargs: Any): # Log the creation attempt logger.info(f"Creating collection {collection_id} with refresh={refresh}") + # Check if the collection already exists if await self.client.exists(index=COLLECTIONS_INDEX, id=collection_id): raise ConflictError(f"Collection {collection_id} already exists") + # Index the collection in the database await self.client.index( index=COLLECTIONS_INDEX, id=collection_id, - body=collection, + document=collection, refresh=refresh, ) + if self.async_index_inserter.should_create_collection_index(): await self.async_index_inserter.create_simple_index( self.client, collection_id @@ -1131,7 +1211,7 @@ async def find_collection(self, collection_id: str) -> Collection: collection = await self.client.get( index=COLLECTIONS_INDEX, id=collection_id ) - except exceptions.NotFoundError: + except ESNotFoundError: raise NotFoundError(f"Collection {collection_id} not found") return collection["_source"] @@ -1139,21 +1219,26 @@ async def find_collection(self, collection_id: str) -> Collection: async def update_collection( self, collection_id: str, collection: Collection, **kwargs: Any ): - """Update a collection from the database. + """Update a collection in the database. Args: collection_id (str): The ID of the collection to be updated. collection (Collection): The Collection object to be used for the update. - **kwargs: Additional keyword arguments like refresh. + **kwargs: Additional keyword arguments. + - refresh (str): Whether to refresh the index after the operation. Can be "true", "false", or "wait_for". + - refresh (bool): Whether to refresh the index after the operation. Defaults to the value in `self.async_settings.database_refresh`. + Returns: + None Raises: - NotFoundError: If the collection with the given `collection_id` is not - found in the database. + NotFoundError: If the collection with the given `collection_id` is not found in the database. + ConflictError: If a conflict occurs during the update. Notes: This function updates the collection in the database using the specified - `collection_id` and with the collection specified in the `Collection` object. - If the collection is not found, a `NotFoundError` is raised. + `collection_id` and the provided `Collection` object. If the collection ID + changes, the function creates a new collection, reindexes the items, and deletes + the old collection. 
""" # Ensure kwargs is a dictionary kwargs = kwargs or {} @@ -1165,15 +1250,19 @@ async def update_collection( # Log the update attempt logger.info(f"Updating collection {collection_id} with refresh={refresh}") + # Ensure the collection exists await self.find_collection(collection_id=collection_id) + # Handle collection ID change if collection_id != collection["id"]: logger.info( f"Collection ID change detected: {collection_id} -> {collection['id']}" ) + # Create the new collection await self.create_collection(collection, refresh=refresh) + # Reindex items from the old collection to the new collection await self.client.reindex( body={ "dest": {"index": f"{ITEMS_INDEX_PREFIX}{collection['id']}"}, @@ -1187,13 +1276,15 @@ async def update_collection( refresh=refresh, ) - await self.delete_collection(collection_id=collection_id, **kwargs) + # Delete the old collection + await self.delete_collection(collection_id) else: + # Update the existing collection await self.client.index( index=COLLECTIONS_INDEX, id=collection_id, - body=collection, + document=collection, refresh=refresh, ) @@ -1265,11 +1356,11 @@ async def json_patch_collection( await self.client.update( index=COLLECTIONS_INDEX, id=collection_id, - body={"script": script}, + script=script, refresh=True, ) - except exceptions.RequestError as exc: + except BadRequestError as exc: raise HTTPException( status_code=400, detail=exc.info["error"]["caused_by"] ) from exc @@ -1292,13 +1383,17 @@ async def delete_collection(self, collection_id: str, **kwargs: Any): """Delete a collection from the database. Parameters: - self: The instance of the object calling this function. collection_id (str): The ID of the collection to be deleted. - **kwargs: Additional keyword arguments like refresh. + kwargs (Any, optional): Additional keyword arguments, including `refresh`. + - refresh (str): Whether to refresh the index after the operation. Can be "true", "false", or "wait_for". + - refresh (bool): Whether to refresh the index after the operation. Defaults to the value in `self.async_settings.database_refresh`. Raises: NotFoundError: If the collection with the given `collection_id` is not found in the database. + Returns: + None + Notes: This function first verifies that the collection with the specified `collection_id` exists in the database, and then deletes the collection. 
If `refresh` is set to "true", "false", or "wait_for", the index is refreshed accordingly after
@@ -1307,19 +1402,14 @@ async def delete_collection(self, collection_id: str, **kwargs: Any):
         # Ensure kwargs is a dictionary
         kwargs = kwargs or {}

-        await self.find_collection(collection_id=collection_id)
-
-        # Resolve the `refresh` parameter
         refresh = kwargs.get("refresh", self.async_settings.database_refresh)
         refresh = validate_refresh(refresh)

-        # Log the deletion attempt
-        logger.info(f"Deleting collection {collection_id} with refresh={refresh}")
-
+        # Verify that the collection exists
+        await self.find_collection(collection_id=collection_id)
         await self.client.delete(
             index=COLLECTIONS_INDEX, id=collection_id, refresh=refresh
         )
-        # Delete the item index for the collection
         await delete_item_index(collection_id)

     async def bulk_async(
@@ -1372,21 +1462,23 @@ async def bulk_async(
             logger.warning(f"No items to insert for collection {collection_id}")
             return 0, []

+        # Perform the bulk insert
         raise_on_error = self.async_settings.raise_on_bulk_error
         actions = await self.async_index_inserter.prepare_bulk_actions(
             collection_id, processed_items
         )
-
         success, errors = await helpers.async_bulk(
             self.client,
             actions,
             refresh=refresh,
             raise_on_error=raise_on_error,
         )
+        # Log the result
         logger.info(
             f"Bulk insert completed for collection {collection_id}: {success} successes, {len(errors)} errors"
         )
+
         return success, errors

     def bulk_sync(
@@ -1396,7 +1488,7 @@ def bulk_sync(
         **kwargs: Any,
     ) -> Tuple[int, List[Dict[str, Any]]]:
         """
-        Perform a bulk insert of items into the database asynchronously.
+        Perform a bulk insert of items into the database synchronously.

         Args:
             collection_id (str): The ID of the collection to which the items belong.
@@ -1439,11 +1531,7 @@ def bulk_sync(
             logger.warning(f"No items to insert for collection {collection_id}")
             return 0, []

-        # Handle empty processed_items
-        if not processed_items:
-            logger.warning(f"No items to insert for collection {collection_id}")
-            return 0, []
-
+        # Perform the bulk insert
         raise_on_error = self.sync_settings.raise_on_bulk_error
         success, errors = helpers.bulk(
             self.sync_client,
@@ -1451,6 +1539,12 @@
             refresh=refresh,
             raise_on_error=raise_on_error,
         )
+
+        # Log the result
+        logger.info(
+            f"Bulk insert completed for collection {collection_id}: {success} successes, {len(errors)} errors"
+        )
+
         return success, errors

     # DANGER
diff --git a/stac_fastapi/tests/api/test_api.py b/stac_fastapi/tests/api/test_api.py
index cdf383f96..e74ab5600 100644
--- a/stac_fastapi/tests/api/test_api.py
+++ b/stac_fastapi/tests/api/test_api.py
@@ -1537,3 +1537,89 @@ async def test_search_max_item_limit(
     assert resp.status_code == 200
     resp_json = resp.json()
     assert int(limit) == len(resp_json["features"])
+
+
+@pytest.mark.asyncio
+async def test_use_datetime_true(app_client, load_test_data, txn_client, monkeypatch):
+    monkeypatch.setenv("USE_DATETIME", "true")
+
+    test_collection = load_test_data("test_collection.json")
+    test_collection["id"] = "test-collection-datetime-true"
+    await create_collection(txn_client, test_collection)
+
+    item = load_test_data("test_item.json")
+
+    item1 = item.copy()
+    item1["id"] = "test-item-datetime"
+    item1["collection"] = test_collection["id"]
+    item1["properties"]["datetime"] = "2020-01-01T12:00:00Z"
+    await create_item(txn_client, item1)
+
+    item2 = item.copy()
+    item2["id"] = "test-item-start-end"
+    item2["collection"] = test_collection["id"]
+    item2["properties"]["datetime"] = None
+    item2["properties"]["start_datetime"] = 
"2020-01-01T10:00:00Z" + item2["properties"]["end_datetime"] = "2020-01-01T13:00:00Z" + await create_item(txn_client, item2) + + resp = await app_client.post( + "/search", + json={ + "datetime": "2020-01-01T12:00:00Z", + "collections": [test_collection["id"]], + }, + ) + + assert resp.status_code == 200 + resp_json = resp.json() + + found_ids = {feature["id"] for feature in resp_json["features"]} + assert "test-item-datetime" in found_ids + assert "test-item-start-end" in found_ids + + +@pytest.mark.asyncio +async def test_use_datetime_false(app_client, load_test_data, txn_client, monkeypatch): + monkeypatch.setenv("USE_DATETIME", "false") + + test_collection = load_test_data("test_collection.json") + test_collection["id"] = "test-collection-datetime-false" + await create_collection(txn_client, test_collection) + + item = load_test_data("test_item.json") + + # Item 1: Should NOT be found + item1 = item.copy() + item1["id"] = "test-item-datetime-only" + item1["collection"] = test_collection["id"] + item1["properties"]["datetime"] = "2020-01-01T12:00:00Z" + item1["properties"]["start_datetime"] = "2021-01-01T10:00:00Z" + item1["properties"]["end_datetime"] = "2021-01-01T14:00:00Z" + await create_item(txn_client, item1) + + # Item 2: Should be found + item2 = item.copy() + item2["id"] = "test-item-start-end-only" + item2["collection"] = test_collection["id"] + item2["properties"]["datetime"] = None + item2["properties"]["start_datetime"] = "2020-01-01T10:00:00Z" + item2["properties"]["end_datetime"] = "2020-01-01T14:00:00Z" + await create_item(txn_client, item2) + + resp = await app_client.post( + "/search", + json={ + "datetime": "2020-01-01T12:00:00Z", + "collections": [test_collection["id"]], + "limit": 10, + }, + ) + + assert resp.status_code == 200 + resp_json = resp.json() + + found_ids = {feature["id"] for feature in resp_json["features"]} + + assert "test-item-datetime-only" not in found_ids + assert "test-item-start-end-only" in found_ids