diff --git a/CHANGES.md b/CHANGES.md index 02e616a60..e409d7d02 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -6,6 +6,9 @@ ### Changed +* Refactor to remove hardcoded search request models. Request models are now dynamically created based on the enabled extensions. + ([#213](https://github.com/stac-utils/stac-fastapi/pull/213)) + ### Removed ### Fixed diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index b7efac822..0af7d9e7e 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -7,7 +7,7 @@ from fastapi.openapi.utils import get_openapi from pydantic import BaseModel from stac_pydantic import Collection, Item, ItemCollection -from stac_pydantic.api import ConformanceClasses, LandingPage, Search +from stac_pydantic.api import ConformanceClasses, LandingPage from stac_pydantic.api.collections import Collections from stac_pydantic.version import STAC_VERSION from starlette.responses import JSONResponse, Response @@ -20,18 +20,17 @@ GeoJSONResponse, ItemCollectionUri, ItemUri, - SearchGetRequest, - _create_request_model, + create_request_model, ) from stac_fastapi.api.openapi import update_openapi from stac_fastapi.api.routes import create_async_endpoint, create_sync_endpoint # TODO: make this module not depend on `stac_fastapi.extensions` -from stac_fastapi.extensions.core import FieldsExtension +from stac_fastapi.extensions.core import FieldsExtension, TokenPaginationExtension from stac_fastapi.types.config import ApiSettings, Settings from stac_fastapi.types.core import AsyncBaseCoreClient, BaseCoreClient from stac_fastapi.types.extension import ApiExtension -from stac_fastapi.types.search import STACSearch +from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest @attr.s @@ -76,9 +75,13 @@ class StacApi: api_version: str = attr.ib(default="0.1") stac_version: str = attr.ib(default=STAC_VERSION) description: str = attr.ib(default="stac-fastapi") - search_request_model: Type[Search] = attr.ib(default=STACSearch) - search_get_request: Type[SearchGetRequest] = attr.ib(default=SearchGetRequest) - item_collection_uri: Type[ItemCollectionUri] = attr.ib(default=ItemCollectionUri) + search_get_request_model: Type[BaseSearchGetRequest] = attr.ib( + default=BaseSearchGetRequest + ) + search_post_request_model: Type[BaseSearchPostRequest] = attr.ib( + default=BaseSearchPostRequest + ) + pagination_extension = attr.ib(default=TokenPaginationExtension) response_class: Type[Response] = attr.ib(default=JSONResponse) middlewares: List = attr.ib(default=attr.Factory(lambda: [BrotliMiddleware])) @@ -176,7 +179,6 @@ def register_post_search(self): Returns: None """ - search_request_model = _create_request_model(self.search_request_model) fields_ext = self.get_extension(FieldsExtension) self.router.add_api_route( name="Search", @@ -189,7 +191,7 @@ def register_post_search(self): response_model_exclude_none=True, methods=["POST"], endpoint=self._create_endpoint( - self.client.post_search, search_request_model, GeoJSONResponse + self.client.post_search, self.search_post_request_model, GeoJSONResponse ), ) @@ -211,7 +213,7 @@ def register_get_search(self): response_model_exclude_none=True, methods=["GET"], endpoint=self._create_endpoint( - self.client.get_search, self.search_get_request, GeoJSONResponse + self.client.get_search, self.search_get_request_model, GeoJSONResponse ), ) @@ -261,6 +263,12 @@ def register_get_item_collection(self): Returns: None """ + get_pagination_model = self.get_extension(self.pagination_extension).GET + request_model = create_request_model( + "ItemCollectionURI", + base_model=ItemCollectionUri, + mixins=[get_pagination_model], + ) self.router.add_api_route( name="Get ItemCollection", path="/collections/{collection_id}/items", @@ -272,9 +280,7 @@ def register_get_item_collection(self): response_model_exclude_none=True, methods=["GET"], endpoint=self._create_endpoint( - self.client.item_collection, - self.item_collection_uri, - self.response_class, + self.client.item_collection, request_model, self.response_class ), ) diff --git a/stac_fastapi/api/stac_fastapi/api/models.py b/stac_fastapi/api/stac_fastapi/api/models.py index e472aaec9..704cf6ed0 100644 --- a/stac_fastapi/api/stac_fastapi/api/models.py +++ b/stac_fastapi/api/stac_fastapi/api/models.py @@ -1,55 +1,100 @@ """api request/response models.""" -import abc import importlib -from typing import Dict, Optional, Type, Union +from typing import Optional, Type, Union import attr from fastapi import Body, Path from pydantic import BaseModel, create_model from pydantic.fields import UndefinedType - -def _create_request_model(model: Type[BaseModel]) -> Type[BaseModel]: +from stac_fastapi.types.extension import ApiExtension +from stac_fastapi.types.search import ( + APIRequest, + BaseSearchGetRequest, + BaseSearchPostRequest, +) + + +def create_request_model( + model_name="SearchGetRequest", + base_model: Union[Type[BaseModel], APIRequest] = BaseSearchGetRequest, + extensions: Optional[ApiExtension] = None, + mixins: Optional[Union[BaseModel, APIRequest]] = None, + request_type: Optional[str] = "GET", +) -> Union[Type[BaseModel], APIRequest]: """Create a pydantic model for validating request bodies.""" fields = {} - for (k, v) in model.__fields__.items(): - # TODO: Filter out fields based on which extensions are present - field_info = v.field_info - body = Body( - None - if isinstance(field_info.default, UndefinedType) - else field_info.default, - default_factory=field_info.default_factory, - alias=field_info.alias, - alias_priority=field_info.alias_priority, - title=field_info.title, - description=field_info.description, - const=field_info.const, - gt=field_info.gt, - ge=field_info.ge, - lt=field_info.lt, - le=field_info.le, - multiple_of=field_info.multiple_of, - min_items=field_info.min_items, - max_items=field_info.max_items, - min_length=field_info.min_length, - max_length=field_info.max_length, - regex=field_info.regex, - extra=field_info.extra, - ) - fields[k] = (v.outer_type_, body) - return create_model(model.__name__, **fields, __base__=model) - - -@attr.s # type:ignore -class APIRequest(abc.ABC): - """Generic API Request base class.""" - - @abc.abstractmethod - def kwargs(self) -> Dict: - """Transform api request params into format which matches the signature of the endpoint.""" - ... + extension_models = [] + + # Check extensions for additional parameters to search + for extension in extensions or []: + if extension_model := extension.get_request_model(request_type): + extension_models.append(extension_model) + + mixins = mixins or [] + + models = [base_model] + extension_models + mixins + + # Handle GET requests + if all([issubclass(m, APIRequest) for m in models]): + return attr.make_class(model_name, attrs={}, bases=tuple(models)) + + # Handle POST requests + elif all([issubclass(m, BaseModel) for m in models]): + for model in models: + for (k, v) in model.__fields__.items(): + field_info = v.field_info + body = Body( + None + if isinstance(field_info.default, UndefinedType) + else field_info.default, + default_factory=field_info.default_factory, + alias=field_info.alias, + alias_priority=field_info.alias_priority, + title=field_info.title, + description=field_info.description, + const=field_info.const, + gt=field_info.gt, + ge=field_info.ge, + lt=field_info.lt, + le=field_info.le, + multiple_of=field_info.multiple_of, + min_items=field_info.min_items, + max_items=field_info.max_items, + min_length=field_info.min_length, + max_length=field_info.max_length, + regex=field_info.regex, + extra=field_info.extra, + ) + fields[k] = (v.outer_type_, body) + return create_model(model_name, **fields, __base__=base_model) + + raise TypeError("Mixed Request Model types. Check extension request types.") + + +def create_get_request_model( + extensions, base_model: BaseSearchGetRequest = BaseSearchGetRequest +): + """Wrap create_request_model to create the GET request model.""" + return create_request_model( + "SearchGetRequest", + base_model=BaseSearchGetRequest, + extensions=extensions, + request_type="GET", + ) + + +def create_post_request_model( + extensions, base_model: BaseSearchPostRequest = BaseSearchGetRequest +): + """Wrap create_request_model to create the POST request model.""" + return create_request_model( + "SearchPostRequest", + base_model=BaseSearchPostRequest, + extensions=extensions, + request_type="POST", + ) @attr.s # type:ignore @@ -58,10 +103,6 @@ class CollectionUri(APIRequest): collection_id: str = attr.ib(default=Path(..., description="Collection ID")) - def kwargs(self) -> Dict: - """kwargs.""" - return {"id": self.collection_id} - @attr.s class ItemUri(CollectionUri): @@ -69,18 +110,12 @@ class ItemUri(CollectionUri): item_id: str = attr.ib(default=Path(..., description="Item ID")) - def kwargs(self) -> Dict: - """kwargs.""" - return {"collection_id": self.collection_id, "item_id": self.item_id} - @attr.s class EmptyRequest(APIRequest): """Empty request.""" - def kwargs(self) -> Dict: - """kwargs.""" - return {} + ... @attr.s @@ -88,46 +123,32 @@ class ItemCollectionUri(CollectionUri): """Get item collection.""" limit: int = attr.ib(default=10) - token: str = attr.ib(default=None) - def kwargs(self) -> Dict: - """kwargs.""" - return { - "id": self.collection_id, - "limit": self.limit, - "token": self.token, - } + +class POSTTokenPagination(BaseModel): + """Token pagination model for POST requests.""" + + token: Optional[str] = None @attr.s -class SearchGetRequest(APIRequest): - """GET search request.""" - - collections: Optional[str] = attr.ib(default=None) - ids: Optional[str] = attr.ib(default=None) - bbox: Optional[str] = attr.ib(default=None) - datetime: Optional[Union[str]] = attr.ib(default=None) - limit: Optional[int] = attr.ib(default=10) - query: Optional[str] = attr.ib(default=None) +class GETTokenPagination(APIRequest): + """Token pagination for GET requests.""" + token: Optional[str] = attr.ib(default=None) - fields: Optional[str] = attr.ib(default=None) - sortby: Optional[str] = attr.ib(default=None) - - def kwargs(self) -> Dict: - """kwargs.""" - return { - "collections": self.collections.split(",") - if self.collections - else self.collections, - "ids": self.ids.split(",") if self.ids else self.ids, - "bbox": self.bbox.split(",") if self.bbox else self.bbox, - "datetime": self.datetime, - "limit": self.limit, - "query": self.query, - "token": self.token, - "fields": self.fields.split(",") if self.fields else self.fields, - "sortby": self.sortby.split(",") if self.sortby else self.sortby, - } + + +class POSTPagination(BaseModel): + """Page based pagination for POST requests.""" + + page: Optional[str] = None + + +@attr.s +class GETPagination(APIRequest): + """Page based pagination for GET requests.""" + + page: Optional[str] = attr.ib(default=None) # Test for ORJSON and use it rather than stdlib JSON where supported diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py index beb3e41d1..d720a6377 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py @@ -4,6 +4,7 @@ from .context import ContextExtension from .fields import FieldsExtension from .filter import FilterExtension +from .pagination import PaginationExtension, TokenPaginationExtension from .query import QueryExtension from .sort import SortExtension from .transaction import TransactionExtension @@ -12,8 +13,10 @@ "ContextExtension", "FieldsExtension", "FilterExtension", + "PaginationExtension", "QueryExtension", "SortExtension", "TilesExtension", + "TokenPaginationExtension", "TransactionExtension", ) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/__init__.py new file mode 100644 index 000000000..b9a246b63 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/__init__.py @@ -0,0 +1,6 @@ +"""Fields extension module.""" + + +from .fields import FieldsExtension + +__all__ = ["FieldsExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py similarity index 91% rename from stac_fastapi/extensions/stac_fastapi/extensions/core/fields.py rename to stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py index ba8d84542..13fe62d9c 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py @@ -6,6 +6,8 @@ from stac_fastapi.types.extension import ApiExtension +from .request import FieldsExtensionGetRequest, FieldsExtensionPostRequest + @attr.s class FieldsExtension(ApiExtension): @@ -24,10 +26,12 @@ class FieldsExtension(ApiExtension): """ + GET = FieldsExtensionGetRequest + POST = FieldsExtensionPostRequest + conformance_classes: List[str] = attr.ib( factory=lambda: ["https://api.stacspec.org/v1.0.0-beta.3/item-search/#fields"] ) - schema_href: Optional[str] = attr.ib(default=None) default_includes: Set[str] = attr.ib( factory=lambda: { "id", @@ -41,6 +45,7 @@ class FieldsExtension(ApiExtension): "collection", } ) + schema_href: Optional[str] = attr.ib(default=None) def register(self, app: FastAPI) -> None: """Register the extension with a FastAPI application. diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py new file mode 100644 index 000000000..52ea3af2c --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py @@ -0,0 +1,71 @@ +"""Request models for the fields extension.""" + +from typing import Dict, Optional, Set + +import attr +from pydantic import BaseModel, Field + +from stac_fastapi.types.config import Settings +from stac_fastapi.types.search import APIRequest, str2list + + +class PostFieldsExtension(BaseModel): + """FieldsExtension. + + Attributes: + include: set of fields to include. + exclude: set of fields to exclude. + """ + + include: Optional[Set[str]] = set() + exclude: Optional[Set[str]] = set() + + @staticmethod + def _get_field_dict(fields: Optional[Set[str]]) -> Dict: + """Pydantic include/excludes notation. + + Internal method to create a dictionary for advanced include or exclude of pydantic fields on model export + Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude + """ + field_dict = {} + for field in fields or []: + if "." in field: + parent, key = field.split(".") + if parent not in field_dict: + field_dict[parent] = {key} + else: + field_dict[parent].add(key) + else: + field_dict[field] = ... # type:ignore + return field_dict + + @property + def filter_fields(self) -> Dict: + """Create pydantic include/exclude expression. + + Create dictionary of fields to include/exclude on model export based on the included and excluded fields passed + to the API + Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude + """ + # Always include default_includes, even if they + # exist in the exclude list. + include = (self.include or set()) - (self.exclude or set()) + include |= Settings.get().default_includes or set() + + return { + "include": self._get_field_dict(include), + "exclude": self._get_field_dict(self.exclude), + } + + +@attr.s +class FieldsExtensionGetRequest(APIRequest): + """Additional fields for the GET request.""" + + fields: Optional[str] = attr.ib(default=None, converter=str2list) + + +class FieldsExtensionPostRequest(BaseModel): + """Additional fields and schema for the POST request.""" + + fields: Optional[PostFieldsExtension] = Field(PostFieldsExtension()) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/__init__.py new file mode 100644 index 000000000..78256bfd2 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/__init__.py @@ -0,0 +1,6 @@ +"""Filter extension module.""" + + +from .filter import FilterExtension + +__all__ = ["FilterExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py similarity index 96% rename from stac_fastapi/extensions/stac_fastapi/extensions/core/filter.py rename to stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py index 1a2323d60..8a51a6657 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py @@ -12,6 +12,8 @@ from stac_fastapi.types.core import AsyncBaseFiltersClient, BaseFiltersClient from stac_fastapi.types.extension import ApiExtension +from .request import FilterExtensionGetRequest, FilterExtensionPostRequest + class FilterConformanceClasses(str, Enum): """Conformance classes for the Filter extension. @@ -54,6 +56,9 @@ class FilterExtension(ApiExtension): """ + GET = FilterExtensionGetRequest + POST = FilterExtensionPostRequest + client: Union[AsyncBaseFiltersClient, BaseFiltersClient] = attr.ib( factory=BaseFiltersClient ) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py new file mode 100644 index 000000000..afd5b947c --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py @@ -0,0 +1,21 @@ +"""Filter extension request models.""" + +from typing import Any, Dict, Optional + +import attr +from pydantic import BaseModel + +from stac_fastapi.types.search import APIRequest + + +@attr.s +class FilterExtensionGetRequest(APIRequest): + """Filter extension GET request model.""" + + filter: Optional[str] = attr.ib(default=None) + + +class FilterExtensionPostRequest(BaseModel): + """Filter extension POST request model.""" + + filter: Optional[Dict[str, Any]] = None diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/__init__.py new file mode 100644 index 000000000..255701226 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/__init__.py @@ -0,0 +1,6 @@ +"""pagination classes as extensions.""" + +from .pagination import PaginationExtension +from .token_pagination import TokenPaginationExtension + +__all__ = ["PaginationExtension", "TokenPaginationExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/pagination.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/pagination.py new file mode 100644 index 000000000..5e834ed38 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/pagination.py @@ -0,0 +1,37 @@ +"""Pagination API extension.""" + +from typing import List, Optional + +import attr +from fastapi import FastAPI + +from stac_fastapi.api.models import GETPagination, POSTPagination +from stac_fastapi.types.extension import ApiExtension + + +@attr.s +class PaginationExtension(ApiExtension): + """Token Pagination. + + Though not strictly an extension, the chosen pagination will modify the + form of the request object. By making pagination an extension class, we can + use create_request_model to dynamically add the correct pagination parameter + to the request model for OpenAPI generation. + """ + + GET = GETPagination + POST = POSTPagination + + conformance_classes: List[str] = attr.ib(factory=list) + schema_href: Optional[str] = attr.ib(default=None) + + def register(self, app: FastAPI) -> None: + """Register the extension with a FastAPI application. + + Args: + app: target FastAPI application. + + Returns: + None + """ + pass diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/token_pagination.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/token_pagination.py new file mode 100644 index 000000000..1e1399971 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/token_pagination.py @@ -0,0 +1,37 @@ +"""Token pagination API extension.""" + +from typing import List, Optional + +import attr +from fastapi import FastAPI + +from stac_fastapi.api.models import GETTokenPagination, POSTTokenPagination +from stac_fastapi.types.extension import ApiExtension + + +@attr.s +class TokenPaginationExtension(ApiExtension): + """Token Pagination. + + Though not strictly an extension, the chosen pagination will modify the + form of the request object. By making pagination an extension class, we can + use create_request_model to dynamically add the correct pagination parameter + to the request model for OpenAPI generation. + """ + + GET = GETTokenPagination + POST = POSTTokenPagination + + conformance_classes: List[str] = attr.ib(factory=list) + schema_href: Optional[str] = attr.ib(default=None) + + def register(self, app: FastAPI) -> None: + """Register the extension with a FastAPI application. + + Args: + app: target FastAPI application. + + Returns: + None + """ + pass diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/__init__.py new file mode 100644 index 000000000..5bbe70595 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/__init__.py @@ -0,0 +1,5 @@ +"""Query extension module.""" + +from .query import QueryExtension + +__all__ = ["QueryExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/query.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/query.py similarity index 86% rename from stac_fastapi/extensions/stac_fastapi/extensions/core/query.py rename to stac_fastapi/extensions/stac_fastapi/extensions/core/query/query.py index 1c01b1f3f..6bede1164 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/query.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/query.py @@ -6,6 +6,8 @@ from stac_fastapi.types.extension import ApiExtension +from .request import QueryExtensionGetRequest, QueryExtensionPostRequest + @attr.s class QueryExtension(ApiExtension): @@ -17,6 +19,9 @@ class QueryExtension(ApiExtension): https://github.com/radiantearth/stac-api-spec/blob/master/item-search/README.md#query """ + GET = QueryExtensionGetRequest + POST = QueryExtensionPostRequest + conformance_classes: List[str] = attr.ib( factory=lambda: ["https://api.stacspec.org/v1.0.0-beta.3/item-search/#query"] ) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py new file mode 100644 index 000000000..8b282884a --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py @@ -0,0 +1,21 @@ +"""Request model for the Query extension.""" + +from typing import Any, Dict, Optional + +import attr +from pydantic import BaseModel + +from stac_fastapi.types.search import APIRequest + + +@attr.s +class QueryExtensionGetRequest(APIRequest): + """Query Extension GET request model.""" + + query: Optional[str] = attr.ib(default=None) + + +class QueryExtensionPostRequest(BaseModel): + """Query Extension POST request model.""" + + query: Optional[Dict[str, Dict[str, Any]]] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/__init__.py new file mode 100644 index 000000000..b6996b018 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/__init__.py @@ -0,0 +1,5 @@ +"""Sort extension module.""" + +from .sort import SortExtension + +__all__ = ["SortExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py new file mode 100644 index 000000000..c19f40dba --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py @@ -0,0 +1,23 @@ +# encoding: utf-8 +"""Request model for the Sort Extension.""" + +from typing import List, Optional + +import attr +from pydantic import BaseModel +from stac_pydantic.api.extensions.sort import SortExtension as PostSortModel + +from stac_fastapi.types.search import APIRequest, str2list + + +@attr.s +class SortExtensionGetRequest(APIRequest): + """Sortby Parameter for GET requests.""" + + sortby: Optional[str] = attr.ib(default=None, converter=str2list) + + +class SortExtensionPostRequest(BaseModel): + """Sortby parameter for POST requests.""" + + sortby: Optional[List[PostSortModel]] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/sort.py similarity index 86% rename from stac_fastapi/extensions/stac_fastapi/extensions/core/sort.py rename to stac_fastapi/extensions/stac_fastapi/extensions/core/sort/sort.py index 73d4bb4bf..fcc9d20dd 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/sort.py @@ -6,6 +6,8 @@ from stac_fastapi.types.extension import ApiExtension +from .request import SortExtensionGetRequest, SortExtensionPostRequest + @attr.s class SortExtension(ApiExtension): @@ -17,6 +19,9 @@ class SortExtension(ApiExtension): https://github.com/radiantearth/stac-api-spec/blob/master/item-search/README.md#sort """ + GET = SortExtensionGetRequest + POST = SortExtensionPostRequest + conformance_classes: List[str] = attr.ib( factory=lambda: ["https://api.stacspec.org/v1.0.0-beta.3/item-search/#sort"] ) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py index efde44a50..a776d12d9 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py @@ -6,7 +6,7 @@ from fastapi import APIRouter, FastAPI from pydantic import BaseModel -from stac_fastapi.api.models import _create_request_model +from stac_fastapi.api.models import create_request_model from stac_fastapi.api.routes import create_sync_endpoint from stac_fastapi.types.extension import ApiExtension @@ -72,7 +72,7 @@ def register(self, app: FastAPI) -> None: Returns: None """ - items_request_model = _create_request_model(Items) + items_request_model = create_request_model("Items", base_model=Items) router = APIRouter() router.add_api_route( diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/app.py index 96ab724a0..80eeccedf 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/app.py @@ -2,10 +2,12 @@ from fastapi.responses import ORJSONResponse from stac_fastapi.api.app import StacApi +from stac_fastapi.api.models import create_get_request_model, create_post_request_model +from stac_fastapi.extensions import QueryExtension from stac_fastapi.extensions.core import ( FieldsExtension, - QueryExtension, SortExtension, + TokenPaginationExtension, TransactionExtension, ) from stac_fastapi.pgstac.config import Settings @@ -15,22 +17,27 @@ from stac_fastapi.pgstac.types.search import PgstacSearch settings = Settings() +extensions = [ + TransactionExtension( + client=TransactionsClient(), + settings=settings, + response_class=ORJSONResponse, + ), + QueryExtension(), + SortExtension(), + FieldsExtension(), + TokenPaginationExtension(), +] + +post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) api = StacApi( settings=settings, - extensions=[ - TransactionExtension( - client=TransactionsClient(), - settings=settings, - response_class=ORJSONResponse, - ), - QueryExtension(), - SortExtension(), - FieldsExtension(), - ], - client=CoreCrudClient(), - search_request_model=PgstacSearch, + extensions=extensions, + client=CoreCrudClient(post_request_model=post_request_model), response_class=ORJSONResponse, + search_get_request_model=create_get_request_model(extensions), + search_post_request_model=post_request_model, ) app = api.app diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py index a01dc7f33..c0ba8e183 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py @@ -1,7 +1,7 @@ """Item crud client.""" import re from datetime import datetime -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Union from urllib.parse import urljoin import attr @@ -27,8 +27,6 @@ class CoreCrudClient(AsyncBaseCoreClient): """Client for core endpoints defined by stac.""" - search_request_model: Type[PgstacSearch] = attr.ib(init=False, default=PgstacSearch) - async def all_collections(self, **kwargs) -> Collections: """Read all collections from the database.""" request: Request = kwargs["request"] @@ -71,7 +69,7 @@ async def all_collections(self, **kwargs) -> Collections: collection_list = Collections(collections=linked_collections or [], links=links) return collection_list - async def get_collection(self, id: str, **kwargs) -> Collection: + async def get_collection(self, collection_id: str, **kwargs) -> Collection: """Get collection by id. Called with `GET /collections/{collection_id}`. @@ -91,14 +89,14 @@ async def get_collection(self, id: str, **kwargs) -> Collection: """ SELECT * FROM get_collection(:id::text); """, - id=id, + id=collection_id, ) collection = await conn.fetchval(q, *p) if collection is None: raise NotFoundError(f"Collection {id} does not exist.") collection["links"] = await CollectionLinks( - collection_id=id, request=request + collection_id=collection_id, request=request ).get_links(extra_links=collection.get("links")) return Collection(**collection) @@ -175,7 +173,11 @@ async def _search_base( return collection async def item_collection( - self, id: str, limit: Optional[int] = None, token: str = None, **kwargs + self, + collection_id: str, + limit: Optional[int] = None, + token: str = None, + **kwargs, ) -> ItemCollection: """Get all items from a specific collection. @@ -190,12 +192,14 @@ async def item_collection( An ItemCollection. """ # If collection does not exist, NotFoundError wil be raised - await self.get_collection(id, **kwargs) + await self.get_collection(collection_id, **kwargs) - req = self.search_request_model(collections=[id], limit=limit, token=token) + req = self.post_request_model( + collections=[collection_id], limit=limit, token=token + ) item_collection = await self._search_base(req, **kwargs) links = await CollectionLinks( - collection_id=id, request=kwargs["request"] + collection_id=collection_id, request=kwargs["request"] ).get_links(extra_links=item_collection["links"]) item_collection["links"] = links return item_collection @@ -214,7 +218,7 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: # If collection does not exist, NotFoundError wil be raised await self.get_collection(collection_id, **kwargs) - req = self.search_request_model( + req = self.post_request_model( ids=[item_id], collections=[collection_id], limit=1 ) item_collection = await self._search_base(req, **kwargs) @@ -301,7 +305,7 @@ async def get_search( # Do the request try: - search_request = self.search_request_model(**base_args) + search_request = self.post_request_model(**base_args) except ValidationError: raise HTTPException(status_code=400, detail="Invalid parameters provided") return await self.post_search(search_request, request=kwargs["request"]) diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/extensions/__init__.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/extensions/__init__.py new file mode 100644 index 000000000..410bc63f1 --- /dev/null +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/extensions/__init__.py @@ -0,0 +1,5 @@ +"""pgstac extension customisations.""" + +from .query import QueryExtension + +__all__ = ["QueryExtension"] diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/extensions/query.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/extensions/query.py new file mode 100644 index 000000000..91df8539d --- /dev/null +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/extensions/query.py @@ -0,0 +1,48 @@ +"""Pgstac query customisation.""" + +import operator +from enum import auto +from types import DynamicClassAttribute +from typing import Any, Callable, Dict, Optional + +from pydantic import BaseModel +from stac_pydantic.utils import AutoValueEnum + +from stac_fastapi.extensions.core.query import QueryExtension as QueryExtensionBase + + +class Operator(str, AutoValueEnum): + """Defines the set of operators supported by the API.""" + + eq = auto() + ne = auto() + lt = auto() + lte = auto() + gt = auto() + gte = auto() + # TODO: These are defined in the spec but aren't currently implemented by the api + # startsWith = auto() + # endsWith = auto() + # contains = auto() + # in = auto() + + @DynamicClassAttribute + def operator(self) -> Callable[[Any, Any], bool]: + """Return python operator.""" + return getattr(operator, self._value_) + + +class QueryExtensionPostRequest(BaseModel): + """Query Extension POST request model.""" + + query: Optional[Dict[str, Dict[Operator, Any]]] + + +class QueryExtension(QueryExtensionBase): + """Query Extension. + + Override the POST request model to add validation against + supported fields + """ + + POST = QueryExtensionPostRequest diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py index ca835a7bd..4a06a928f 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py @@ -56,9 +56,9 @@ async def delete_item(self, item_id: str, collection_id: str, **kwargs) -> Dict: await dbfunc(pool, "delete_item", item_id) return {"deleted item": item_id} - async def delete_collection(self, id: str, **kwargs) -> Dict: + async def delete_collection(self, collection_id: str, **kwargs) -> Dict: """Delete collection.""" request = kwargs["request"] pool = request.app.state.writepool - await dbfunc(pool, "delete_collection", id) - return {"deleted collection": id} + await dbfunc(pool, "delete_collection", collection_id) + return {"deleted collection": collection_id} diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/types/search.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/types/search.py index ed966ce32..a13126f1e 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/types/search.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/types/search.py @@ -1,111 +1,19 @@ """stac_fastapi.types.search module.""" -import operator -from enum import auto -from types import DynamicClassAttribute -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Optional -from pydantic import Field, conint, root_validator, validator -from stac_pydantic.api import Search -from stac_pydantic.api.extensions.fields import FieldsExtension as FieldsBase -from stac_pydantic.utils import AutoValueEnum +from pydantic import validator -from stac_fastapi.types.config import Settings +from stac_fastapi.types.search import BaseSearchPostRequest -# Be careful: https://github.com/samuelcolvin/pydantic/issues/1423#issuecomment-642797287 -NumType = Union[float, int] +class PgstacSearch(BaseSearchPostRequest): + """Search model. -class Operator(str, AutoValueEnum): - """Defines the set of operators supported by the API.""" - - eq = auto() - ne = auto() - lt = auto() - lte = auto() - gt = auto() - gte = auto() - # TODO: These are defined in the spec but aren't currently implemented by the api - # startsWith = auto() - # endsWith = auto() - # contains = auto() - # in = auto() - - @DynamicClassAttribute - def operator(self) -> Callable[[Any, Any], bool]: - """Return python operator.""" - return getattr(operator, self._value_) - - -class FieldsExtension(FieldsBase): - """FieldsExtension. - - Attributes: - include: set of fields to include. - exclude: set of fields to exclude. + Overrides the validation for datetime from the base request model. """ - include: Optional[Set[str]] = set() - exclude: Optional[Set[str]] = set() - - @staticmethod - def _get_field_dict(fields: Optional[Set[str]]) -> Dict: - """Pydantic include/excludes notation. - - Internal method to create a dictionary for advanced include or exclude of pydantic fields on model export - Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude - """ - field_dict = {} - for field in fields or []: - if "." in field: - parent, key = field.split(".") - if parent not in field_dict: - field_dict[parent] = {key} - else: - field_dict[parent].add(key) - else: - field_dict[field] = ... # type:ignore - return field_dict - - @property - def filter_fields(self) -> Dict: - """Create pydantic include/exclude expression. - - Create dictionary of fields to include/exclude on model export based on the included and excluded fields passed - to the API - Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude - """ - # Always include default_includes, even if they - # exist in the exclude list. - include = (self.include or set()) - (self.exclude or set()) - include |= Settings.get().default_includes or set() - - return { - "include": self._get_field_dict(include), - "exclude": self._get_field_dict(self.exclude), - } - - -class PgstacSearch(Search): - """Search model.""" - - # Make collections optional, default to searching all collections if none are provided - collections: Optional[List[str]] = None - ids: Optional[List[str]] = None - # Override default field extension to include default fields and pydantic includes/excludes factory - fields: FieldsExtension = Field(FieldsExtension()) - # Override query extension with supported operators - query: Optional[Dict[str, Dict[Operator, Any]]] - filter: Optional[Dict] - token: Optional[str] = None datetime: Optional[str] = None - sortby: Any - limit: Optional[conint(gt=0, le=10000)] = 10 - - @root_validator(pre=True) - def validate_query_fields(cls, values: Dict) -> Dict: - """Pgstac does not require the base validator for query fields.""" - return values @validator("datetime") def validate_datetime(cls, v): diff --git a/stac_fastapi/pgstac/tests/conftest.py b/stac_fastapi/pgstac/tests/conftest.py index a82af074a..12f8274f2 100644 --- a/stac_fastapi/pgstac/tests/conftest.py +++ b/stac_fastapi/pgstac/tests/conftest.py @@ -12,15 +12,18 @@ from stac_pydantic import Collection, Item from stac_fastapi.api.app import StacApi +from stac_fastapi.api.models import create_get_request_model, create_post_request_model from stac_fastapi.extensions.core import ( FieldsExtension, - QueryExtension, + FilterExtension, SortExtension, + TokenPaginationExtension, TransactionExtension, ) from stac_fastapi.pgstac.config import Settings from stac_fastapi.pgstac.core import CoreCrudClient from stac_fastapi.pgstac.db import close_db_connection, connect_to_db +from stac_fastapi.pgstac.extensions import QueryExtension from stac_fastapi.pgstac.transactions import TransactionsClient from stac_fastapi.pgstac.types.search import PgstacSearch @@ -82,16 +85,23 @@ async def pgstac(pg): @pytest.fixture(scope="session") def api_client(pg): print("creating client with settings") + + extensions = [ + TransactionExtension(client=TransactionsClient(), settings=settings), + QueryExtension(), + FilterExtension(), + SortExtension(), + FieldsExtension(), + TokenPaginationExtension(), + ] + post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) + api = StacApi( settings=settings, - extensions=[ - TransactionExtension(client=TransactionsClient(), settings=settings), - QueryExtension(), - SortExtension(), - FieldsExtension(), - ], - client=CoreCrudClient(), - search_request_model=PgstacSearch, + extensions=extensions, + client=CoreCrudClient(post_request_model=post_request_model), + search_get_request_model=create_get_request_model(extensions), + search_post_request_model=post_request_model, response_class=ORJSONResponse, ) diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py index c70231171..09fdf2e73 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py @@ -1,36 +1,43 @@ """FastAPI application.""" from stac_fastapi.api.app import StacApi +from stac_fastapi.api.models import create_get_request_model, create_post_request_model from stac_fastapi.extensions.core import ( FieldsExtension, - QueryExtension, SortExtension, + TokenPaginationExtension, TransactionExtension, ) from stac_fastapi.extensions.third_party import BulkTransactionExtension from stac_fastapi.sqlalchemy.config import SqlalchemySettings from stac_fastapi.sqlalchemy.core import CoreCrudClient +from stac_fastapi.sqlalchemy.extensions import QueryExtension from stac_fastapi.sqlalchemy.session import Session from stac_fastapi.sqlalchemy.transactions import ( BulkTransactionsClient, TransactionsClient, ) -from stac_fastapi.sqlalchemy.types.search import SQLAlchemySTACSearch settings = SqlalchemySettings() session = Session.create_from_settings(settings) +extensions = [ + TransactionExtension(client=TransactionsClient(session=session), settings=settings), + BulkTransactionExtension(client=BulkTransactionsClient(session=session)), + FieldsExtension(), + QueryExtension(), + SortExtension(), + TokenPaginationExtension(), +] + +post_request_model = create_post_request_model(extensions) + api = StacApi( settings=settings, - extensions=[ - TransactionExtension( - client=TransactionsClient(session=session), settings=settings - ), - BulkTransactionExtension(client=BulkTransactionsClient(session=session)), - FieldsExtension(), - QueryExtension(), - SortExtension(), - ], - client=CoreCrudClient(session=session), - search_request_model=SQLAlchemySTACSearch, + extensions=extensions, + client=CoreCrudClient( + session=session, extensions=extensions, post_request_model=post_request_model + ), + search_get_request_model=create_get_request_model(extensions), + search_post_request_model=post_request_model, ) app = api.app diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py index 3102b4f6c..b43241d51 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py @@ -21,13 +21,14 @@ from stac_pydantic.shared import MimeTypes from stac_fastapi.sqlalchemy import serializers +from stac_fastapi.sqlalchemy.extensions.query import Operator from stac_fastapi.sqlalchemy.models import database from stac_fastapi.sqlalchemy.session import Session from stac_fastapi.sqlalchemy.tokens import PaginationTokenClient -from stac_fastapi.sqlalchemy.types.search import Operator, SQLAlchemySTACSearch from stac_fastapi.types.config import Settings from stac_fastapi.types.core import BaseCoreClient from stac_fastapi.types.errors import NotFoundError +from stac_fastapi.types.search import BaseSearchPostRequest from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection logger = logging.getLogger(__name__) @@ -90,15 +91,15 @@ def all_collections(self, **kwargs) -> Collections: ) return collection_list - def get_collection(self, id: str, **kwargs) -> Collection: + def get_collection(self, collection_id: str, **kwargs) -> Collection: """Get collection by id.""" base_url = str(kwargs["request"].base_url) with self.session.reader.context_session() as session: - collection = self._lookup_id(id, self.collection_table, session) + collection = self._lookup_id(collection_id, self.collection_table, session) return self.collection_serializer.db_to_stac(collection, base_url) def item_collection( - self, id: str, limit: int = 10, token: str = None, **kwargs + self, collection_id: str, limit: int = 10, token: str = None, **kwargs ) -> ItemCollection: """Read an item collection from the database.""" base_url = str(kwargs["request"].base_url) @@ -106,7 +107,7 @@ def item_collection( collection_children = ( session.query(self.item_table) .join(self.collection_table) - .filter(self.collection_table.id == id) + .filter(self.collection_table.id == collection_id) .order_by(self.item_table.datetime.desc(), self.item_table.id) ) count = None @@ -135,7 +136,7 @@ def item_collection( { "rel": Relations.next.value, "type": "application/geo+json", - "href": f"{kwargs['request'].base_url}collections/{id}/items?token={page.next}&limit={limit}", + "href": f"{kwargs['request'].base_url}collections/{collection_id}/items?token={page.next}&limit={limit}", "method": "GET", } ) @@ -144,7 +145,7 @@ def item_collection( { "rel": Relations.previous.value, "type": "application/geo+json", - "href": f"{kwargs['request'].base_url}collections/{id}/items?token={page.previous}&limit={limit}", + "href": f"{kwargs['request'].base_url}collections/{collection_id}/items?token={page.previous}&limit={limit}", "method": "GET", } ) @@ -179,7 +180,7 @@ def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: db_query = db_query.filter(self.item_table.id == item_id) item = db_query.first() if not item: - raise NotFoundError(f"{self.item_table.__name__} {id} not found") + raise NotFoundError(f"{self.item_table.__name__} {item_id} not found") return self.item_serializer.db_to_stac(item, base_url=base_url) def get_search( @@ -233,7 +234,7 @@ def get_search( # Do the request try: - search_request = SQLAlchemySTACSearch(**base_args) + search_request = self.post_request_model(**base_args) except ValidationError: raise HTTPException(status_code=400, detail="Invalid parameters provided") resp = self.post_search(search_request, request=kwargs["request"]) @@ -256,7 +257,7 @@ def get_search( return resp def post_search( - self, search_request: SQLAlchemySTACSearch, **kwargs + self, search_request: BaseSearchPostRequest, **kwargs ) -> ItemCollection: """POST search catalog.""" base_url = str(kwargs["request"].base_url) @@ -428,12 +429,12 @@ def post_search( for k in search_request.query.keys() ] ) - if not search_request.field.include: - search_request.field.include = query_include + if not search_request.fields.include: + search_request.fields.include = query_include else: - search_request.field.include.union(query_include) + search_request.fields.include.union(query_include) - filter_kwargs = search_request.field.filter_fields + filter_kwargs = search_request.fields.filter_fields # Need to pass through `.json()` for proper serialization # of datetime response_features = [ diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/extensions/__init__.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/extensions/__init__.py new file mode 100644 index 000000000..d97a001cd --- /dev/null +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/extensions/__init__.py @@ -0,0 +1,5 @@ +"""sqlalchemy extensions modifications.""" + +from .query import Operator, QueryableTypes, QueryExtension + +__all__ = ["Operator", "QueryableTypes", "QueryExtension"] diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/types/search.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/extensions/query.py similarity index 53% rename from stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/types/search.py rename to stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/extensions/query.py index 7acc4f3fa..36f7a7710 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/types/search.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/extensions/query.py @@ -1,4 +1,4 @@ -"""stac_fastapi.types.search module. +"""STAC SQLAlchemy specific query search model. # TODO: replace with stac-pydantic """ @@ -8,16 +8,14 @@ from dataclasses import dataclass from enum import auto from types import DynamicClassAttribute -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, Optional, Union import sqlalchemy as sa -from pydantic import Field, ValidationError, conint, root_validator +from pydantic import BaseModel, ValidationError, root_validator from pydantic.error_wrappers import ErrorWrapper -from stac_pydantic.api import Search -from stac_pydantic.api.extensions.fields import FieldsExtension as FieldsBase from stac_pydantic.utils import AutoValueEnum -from stac_fastapi.types.config import Settings +from stac_fastapi.extensions.core.query import QueryExtension as QueryExtensionBase logger = logging.getLogger("uvicorn") logger.setLevel(logging.INFO) @@ -34,6 +32,7 @@ class Operator(str, AutoValueEnum): lte = auto() gt = auto() gte = auto() + # TODO: These are defined in the spec but aren't currently implemented by the api # startsWith = auto() # endsWith = auto() @@ -86,66 +85,14 @@ class QueryableTypes: dtype = sa.String -class FieldsExtension(FieldsBase): - """FieldsExtension. +class QueryExtensionPostRequest(BaseModel): + """Queryable validation. - Attributes: - include: set of fields to include. - exclude: set of fields to exclude. + Add queryables validation to the POST request + to raise errors for unsupported querys. """ - include: Optional[Set[str]] = set() - exclude: Optional[Set[str]] = set() - - @staticmethod - def _get_field_dict(fields: Optional[Set[str]]) -> Dict: - """Pydantic include/excludes notation. - - Internal method to create a dictionary for advanced include or exclude of pydantic fields on model export - Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude - """ - field_dict = {} - for field in fields or []: - if "." in field: - parent, key = field.split(".") - if parent not in field_dict: - field_dict[parent] = {key} - else: - field_dict[parent].add(key) - else: - field_dict[field] = ... # type:ignore - return field_dict - - @property - def filter_fields(self) -> Dict: - """Create pydantic include/exclude expression. - - Create dictionary of fields to include/exclude on model export based on the included and excluded fields passed - to the API - Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude - """ - # Always include default_includes, even if they - # exist in the exclude list. - include = (self.include or set()) - (self.exclude or set()) - include |= Settings.get().default_includes or set() - - return { - "include": self._get_field_dict(include), - "exclude": self._get_field_dict(self.exclude), - } - - -class SQLAlchemySTACSearch(Search): - """Search model.""" - - # Make collections optional, default to searching all collections if none are provided - collections: Optional[List[str]] = None - # Override default field extension to include default fields and pydantic includes/excludes factory - field: FieldsExtension = Field(FieldsExtension(), alias="fields") - # Override query extension with supported operators query: Optional[Dict[Queryables, Dict[Operator, Any]]] - token: Optional[str] = None - limit: Optional[conint(gt=0, le=10000)] = 10 @root_validator(pre=True) def validate_query_fields(cls, values: Dict) -> Dict: @@ -162,6 +109,16 @@ def validate_query_fields(cls, values: Dict) -> Dict: "STACSearch", ) ], - SQLAlchemySTACSearch, + QueryExtensionPostRequest, ) return values + + +class QueryExtension(QueryExtensionBase): + """Query Extenson. + + Override the POST request model to add validation against + supported fields + """ + + POST = QueryExtensionPostRequest diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/models/database.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/models/database.py index 128a145c7..e521d453f 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/models/database.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/models/database.py @@ -8,7 +8,7 @@ from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.declarative import declarative_base -from stac_fastapi.sqlalchemy.types.search import Queryables, QueryableTypes +from stac_fastapi.sqlalchemy.extensions.query import Queryables, QueryableTypes BaseModel = declarative_base() diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py index 555458869..47b6cd6dc 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py @@ -98,16 +98,16 @@ def delete_item( query.delete() return self.item_serializer.db_to_stac(data, base_url=base_url) - def delete_collection(self, id: str, **kwargs) -> stac_types.Collection: + def delete_collection(self, collection_id: str, **kwargs) -> stac_types.Collection: """Delete collection.""" base_url = str(kwargs["request"].base_url) with self.session.writer.context_session() as session: query = session.query(self.collection_table).filter( - self.collection_table.id == id + self.collection_table.id == collection_id ) data = query.first() if not data: - raise NotFoundError(f"Collection {id} not found") + raise NotFoundError(f"Collection {collection_id} not found") query.delete() return self.collection_serializer.db_to_stac(data, base_url=base_url) diff --git a/stac_fastapi/sqlalchemy/tests/conftest.py b/stac_fastapi/sqlalchemy/tests/conftest.py index 795868e2e..7abd9150f 100644 --- a/stac_fastapi/sqlalchemy/tests/conftest.py +++ b/stac_fastapi/sqlalchemy/tests/conftest.py @@ -6,23 +6,25 @@ from starlette.testclient import TestClient from stac_fastapi.api.app import StacApi +from stac_fastapi.api.models import create_request_model from stac_fastapi.extensions.core import ( ContextExtension, FieldsExtension, - QueryExtension, SortExtension, + TokenPaginationExtension, TransactionExtension, ) from stac_fastapi.sqlalchemy.config import SqlalchemySettings from stac_fastapi.sqlalchemy.core import CoreCrudClient +from stac_fastapi.sqlalchemy.extensions import QueryExtension from stac_fastapi.sqlalchemy.models import database from stac_fastapi.sqlalchemy.session import Session from stac_fastapi.sqlalchemy.transactions import ( BulkTransactionsClient, TransactionsClient, ) -from stac_fastapi.sqlalchemy.types.search import SQLAlchemySTACSearch from stac_fastapi.types.config import Settings +from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest DATA_DIR = os.path.join(os.path.dirname(__file__), "data") @@ -105,19 +107,41 @@ def postgres_bulk_transactions(db_session): @pytest.fixture def api_client(db_session): settings = SqlalchemySettings() + extensions = [ + TransactionExtension( + client=TransactionsClient(session=db_session), settings=settings + ), + ContextExtension(), + SortExtension(), + FieldsExtension(), + QueryExtension(), + TokenPaginationExtension(), + ] + + get_request_model = create_request_model( + "SearchGetRequest", + base_model=BaseSearchGetRequest, + extensions=extensions, + request_type="GET", + ) + + post_request_model = create_request_model( + "SearchPostRequest", + base_model=BaseSearchPostRequest, + extensions=extensions, + request_type="POST", + ) + return StacApi( settings=settings, - client=CoreCrudClient(session=db_session), - extensions=[ - TransactionExtension( - client=TransactionsClient(session=db_session), settings=settings - ), - ContextExtension(), - SortExtension(), - FieldsExtension(), - QueryExtension(), - ], - search_request_model=SQLAlchemySTACSearch, + client=CoreCrudClient( + session=db_session, + extensions=extensions, + post_request_model=post_request_model, + ), + extensions=extensions, + search_get_request_model=get_request_model, + search_post_request_model=post_request_model, ) diff --git a/stac_fastapi/types/stac_fastapi/types/core.py b/stac_fastapi/types/stac_fastapi/types/core.py index 615587eb5..2cf08277b 100644 --- a/stac_fastapi/types/stac_fastapi/types/core.py +++ b/stac_fastapi/types/stac_fastapi/types/core.py @@ -14,6 +14,7 @@ from stac_fastapi.types import stac as stac_types from stac_fastapi.types.conformance import BASE_CONFORMANCE_CLASSES from stac_fastapi.types.extension import ApiExtension +from stac_fastapi.types.search import BaseSearchPostRequest from stac_fastapi.types.stac import Conformance NumType = Union[float, int] @@ -306,6 +307,7 @@ class BaseCoreClient(LandingPageMixin, abc.ABC): factory=lambda: BASE_CONFORMANCE_CLASSES ) extensions: List[ApiExtension] = attr.ib(default=attr.Factory(list)) + post_request_model = attr.ib(default=BaseSearchPostRequest) def conformance_classes(self) -> List[str]: """Generate conformance classes by adding extension conformance to base conformance classes.""" @@ -495,6 +497,7 @@ class AsyncBaseCoreClient(LandingPageMixin, abc.ABC): factory=lambda: BASE_CONFORMANCE_CLASSES ) extensions: List[ApiExtension] = attr.ib(default=attr.Factory(list)) + post_request_model = attr.ib(default=BaseSearchPostRequest) def conformance_classes(self) -> List[str]: """Generate conformance classes by adding extension conformance to base conformance classes.""" diff --git a/stac_fastapi/types/stac_fastapi/types/extension.py b/stac_fastapi/types/stac_fastapi/types/extension.py index a1beb62bf..1e4774b4c 100644 --- a/stac_fastapi/types/stac_fastapi/types/extension.py +++ b/stac_fastapi/types/stac_fastapi/types/extension.py @@ -4,12 +4,23 @@ import attr from fastapi import FastAPI +from pydantic import BaseModel @attr.s class ApiExtension(abc.ABC): """Abstract base class for defining API extensions.""" + GET = None + POST = None + + def get_request_model(self, verb: Optional[str] = "GET") -> Optional[BaseModel]: + """Return the request model for the extension.method. + + The model can differ based on HTTP verb + """ + return getattr(self, verb) + conformance_classes: List[str] = attr.ib(factory=list) schema_href: Optional[str] = attr.ib(default=None) diff --git a/stac_fastapi/types/stac_fastapi/types/search.py b/stac_fastapi/types/stac_fastapi/types/search.py index 631cb5c65..3ef9a80c1 100644 --- a/stac_fastapi/types/stac_fastapi/types/search.py +++ b/stac_fastapi/types/stac_fastapi/types/search.py @@ -3,18 +3,28 @@ # TODO: replace with stac-pydantic """ +import abc import operator +from datetime import datetime from enum import auto from types import DynamicClassAttribute -from typing import Any, Callable, Dict, List, Optional, Set, Union - -from pydantic import Field, root_validator -from stac_pydantic.api import Search -from stac_pydantic.api.extensions.fields import FieldsExtension as FieldsBase +from typing import Any, Callable, Dict, List, Optional, Union + +import attr +from geojson_pydantic.geometries import ( + LineString, + MultiLineString, + MultiPoint, + MultiPolygon, + Point, + Polygon, + _GeometryBase, +) +from pydantic import BaseModel, conint, validator +from pydantic.datetime_parse import parse_datetime +from stac_pydantic.shared import BBox from stac_pydantic.utils import AutoValueEnum -from stac_fastapi.types.config import Settings - # Be careful: https://github.com/samuelcolvin/pydantic/issues/1423#issuecomment-642797287 NumType = Union[float, int] @@ -40,67 +50,160 @@ def operator(self) -> Callable[[Any, Any], bool]: return getattr(operator, self._value_) -class FieldsExtension(FieldsBase): - """FieldsExtension. +def str2list(x: str) -> Optional[List]: + """Convert string to list base on , delimiter.""" + if x: + return x.split(",") + + +@attr.s # type:ignore +class APIRequest(abc.ABC): + """Generic API Request base class.""" + + def kwargs(self) -> Dict: + """Transform api request params into format which matches the signature of the endpoint.""" + return self.__dict__ + + +@attr.s +class BaseSearchGetRequest(APIRequest): + """Base arguments for GET Request.""" - Attributes: - include: set of fields to include. - exclude: set of fields to exclude. + collections: Optional[str] = attr.ib(default=None, converter=str2list) + ids: Optional[str] = attr.ib(default=None, converter=str2list) + bbox: Optional[str] = attr.ib(default=None, converter=str2list) + intersects: Optional[str] = attr.ib(default=None, converter=str2list) + datetime: Optional[Union[str]] = attr.ib(default=None) + limit: Optional[int] = attr.ib(default=10) + + +class BaseSearchPostRequest(BaseModel): + """Search model. + + Replace base model in STAC-pydantic as it includes additional fields, + not in the core model. + https://github.com/radiantearth/stac-api-spec/tree/master/item-search#query-parameter-table + + PR to fix this: + https://github.com/stac-utils/stac-pydantic/pull/100 """ - include: Optional[Set[str]] = set() - exclude: Optional[Set[str]] = set() + collections: Optional[List[str]] + ids: Optional[List[str]] + bbox: Optional[BBox] + intersects: Optional[ + Union[Point, MultiPoint, LineString, MultiLineString, Polygon, MultiPolygon] + ] + datetime: Optional[str] + limit: Optional[conint(gt=0, le=10000)] = 10 - @staticmethod - def _get_field_dict(fields: Optional[Set[str]]) -> Dict: - """Pydantic include/excludes notation. + @property + def start_date(self) -> Optional[datetime]: + """Extract the start date from the datetime string.""" + if not self.datetime: + return + + values = self.datetime.split("/") + if len(values) == 1: + return None + if values[0] == "..": + return None + return parse_datetime(values[0]) - Internal method to create a dictionary for advanced include or exclude of pydantic fields on model export - Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude - """ - field_dict = {} - for field in fields or []: - if "." in field: - parent, key = field.split(".") - if parent not in field_dict: - field_dict[parent] = {key} - else: - field_dict[parent].add(key) + @property + def end_date(self) -> Optional[datetime]: + """Extract the end date from the datetime string.""" + if not self.datetime: + return + + values = self.datetime.split("/") + if len(values) == 1: + return parse_datetime(values[0]) + if values[1] == "..": + return None + return parse_datetime(values[1]) + + @validator("intersects") + def validate_spatial(cls, v, values): + """Check bbox and intersects are not both supplied.""" + if v and values["bbox"]: + raise ValueError("intersects and bbox parameters are mutually exclusive") + return v + + @validator("bbox") + def validate_bbox(cls, v: BBox): + """Check order of supplied bbox coordinates.""" + if v: + # Validate order + if len(v) == 4: + xmin, ymin, xmax, ymax = v else: - field_dict[field] = ... # type:ignore - return field_dict + xmin, ymin, min_elev, xmax, ymax, max_elev = v + if max_elev < min_elev: + raise ValueError( + "Maximum elevation must greater than minimum elevation" + ) + + if xmax < xmin: + raise ValueError( + "Maximum longitude must be greater than minimum longitude" + ) + + if ymax < ymin: + raise ValueError( + "Maximum longitude must be greater than minimum longitude" + ) + + # Validate against WGS84 + if xmin < -180 or ymin < -90 or xmax > 180 or ymax > 90: + raise ValueError("Bounding box must be within (-180, -90, 180, 90)") + + return v + + @validator("datetime") + def validate_datetime(cls, v): + """Validate datetime.""" + if "/" in v: + values = v.split("/") + else: + # Single date is interpreted as end date + values = ["..", v] + + dates = [] + for value in values: + if value == "..": + dates.append(value) + continue + + parse_datetime(value) + dates.append(value) + + if ".." not in dates: + if parse_datetime(dates[0]) > parse_datetime(dates[1]): + raise ValueError( + "Invalid datetime range, must match format (begin_date, end_date)" + ) + + return v @property - def filter_fields(self) -> Dict: - """Create pydantic include/exclude expression. + def spatial_filter(self) -> Optional[_GeometryBase]: + """Return a geojson-pydantic object representing the spatial filter for the search request. - Create dictionary of fields to include/exclude on model export based on the included and excluded fields passed - to the API - Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude + Check for both because the ``bbox`` and ``intersects`` parameters are mutually exclusive. """ - # Always include default_includes, even if they - # exist in the exclude list. - include = (self.include or set()) - (self.exclude or set()) - include |= Settings.get().default_includes or set() - - return { - "include": self._get_field_dict(include), - "exclude": self._get_field_dict(self.exclude), - } - - -class STACSearch(Search): - """Search model.""" - - # Make collections optional, default to searching all collections if none are provided - collections: Optional[List[str]] = None - # Override default field extension to include default fields and pydantic includes/excludes factory - field: FieldsExtension = Field(FieldsExtension(), alias="fields") - # Override query extension with supported operators - query: Optional[Dict[str, Dict[Operator, Any]]] - token: Optional[str] = None - - @root_validator(pre=True) - def validate_query_fields(cls, values: Dict) -> Dict: - """Validate query fields (placeholder).""" - return values + if self.bbox: + return Polygon( + coordinates=[ + [ + [self.bbox[0], self.bbox[3]], + [self.bbox[2], self.bbox[3]], + [self.bbox[2], self.bbox[1]], + [self.bbox[0], self.bbox[1]], + [self.bbox[0], self.bbox[3]], + ] + ] + ) + if self.intersects: + return self.intersects + return