diff --git a/CHANGES.md b/CHANGES.md index d19e1607f..649bd2edf 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,12 @@ ## [Unreleased] - TBD +## [3.0.0b2] - 2024-07-09 + +### Changed + +* move back to `@attrs` (instead of dataclass) for `APIRequest` (model for GET request) class type [#729](https://github.com/stac-utils/stac-fastapi/pull/729) + ## [3.0.0b1] - 2024-07-05 ### Added @@ -432,7 +438,8 @@ * First PyPi release! -[Unreleased]: +[Unreleased]: +[3.0.0b2]: [3.0.0b1]: [3.0.0a4]: [3.0.0a3]: diff --git a/docs/src/migrations/v3.0.0.md b/docs/src/migrations/v3.0.0.md index 0cb66653a..e9b2ee649 100644 --- a/docs/src/migrations/v3.0.0.md +++ b/docs/src/migrations/v3.0.0.md @@ -23,49 +23,6 @@ In addition to pydantic v2 update, `stac-pydantic` has been updated to better ma * `PostFieldsExtension.filter_fields` property has been removed. -## `attr` -> `dataclass` for APIRequest models - -Models for **GET** requests, defining the path and query parameters, now uses python `dataclass` instead of `attr`. - -```python -# before -@attr.s -class CollectionModel(APIRequest): - collections: Optional[str] = attr.ib(default=None, converter=str2list) - -# now -@dataclass -class CollectionModel(APIRequest): - collections: Annotated[Optional[str], Query()] = None - - def __post_init__(self): - """convert attributes.""" - if self.collections: - self.collections = str2list(self.collections) # type: ignore - -``` - -!!! warning - - if you want to extend a class with a `required` attribute (without default), you will have to write all the attributes to avoid having *non-default* attributes defined after *default* attributes (ref: https://github.com/stac-utils/stac-fastapi/pull/714/files#r1651557338) - - ```python - @dataclass - class A: - value: Annotated[str, Query()] - - # THIS WON'T WORK - @dataclass - class B(A): - another_value: Annotated[str, Query(...)] - - # DO THIS - @dataclass - class B(A): - another_value: Annotated[str, Query(...)] - value: Annotated[str, Query()] - ``` - ## Middlewares configuration The `StacApi.middlewares` attribute has been updated to accept a list of `starlette.middleware.Middleware`. This enables dynamic configuration of middlewares (see https://github.com/stac-utils/stac-fastapi/pull/442). @@ -113,9 +70,9 @@ stac = StacApi( ) # now -@dataclass +@attr.s class CollectionsRequest(APIRequest): - user: str = Query(...) + user: Annotated[str, Query(...)] = attr.ib() stac = StacApi( search_get_request_model=getSearchModel, @@ -127,6 +84,37 @@ stac = StacApi( ) ``` +## APIRequest - GET Request Model + +Most of the **GET** endpoints are configured with `stac_fastapi.types.search.APIRequest` base class. + +e.g the BaseSearchGetRequest, default for the `GET - /search` endpoint: + +```python +@attr.s +class BaseSearchGetRequest(APIRequest): + """Base arguments for GET Request.""" + + collections: Annotated[Optional[str], Query()] = attr.ib( + default=None, converter=str2list + ) + ids: Annotated[Optional[str], Query()] = attr.ib(default=None, converter=str2list) + bbox: Annotated[Optional[BBox], Query()] = attr.ib(default=None, converter=str2bbox) + intersects: Annotated[Optional[str], Query()] = attr.ib(default=None) + datetime: Annotated[Optional[DateTimeType], Query()] = attr.ib( + default=None, converter=str_to_interval + ) + limit: Annotated[Optional[int], Query()] = attr.ib(default=10) +``` + +We use [*python attrs*](https://www.attrs.org/en/stable/) to construct those classes. **Type Hint** for each attribute is important and should be defined using `Annotated[{type}, fastapi.Query()]` form. + +```python +@attr.s +class SomeRequest(APIRequest): + user_number: Annotated[Optional[int], Query(alias="user-number")] = attr.ib(default=None) +``` + ## Filter extension `default_includes` attribute has been removed from the `ApiSettings` object. If you need `defaults` includes you can overwrite the `FieldExtension` models (see https://github.com/stac-utils/stac-fastapi/pull/706). diff --git a/stac_fastapi/api/stac_fastapi/api/models.py b/stac_fastapi/api/stac_fastapi/api/models.py index 1c2146d44..737089c2d 100644 --- a/stac_fastapi/api/stac_fastapi/api/models.py +++ b/stac_fastapi/api/stac_fastapi/api/models.py @@ -1,8 +1,8 @@ """Api request/response models.""" -from dataclasses import dataclass, make_dataclass from typing import List, Optional, Type, Union +import attr from fastapi import Path, Query from pydantic import BaseModel, create_model from stac_pydantic.shared import BBox @@ -43,11 +43,11 @@ def create_request_model( mixins = mixins or [] - models = extension_models + mixins + [base_model] + models = [base_model] + extension_models + mixins # Handle GET requests if all([issubclass(m, APIRequest) for m in models]): - return make_dataclass(model_name, [], bases=tuple(models)) + return attr.make_class(model_name, attrs={}, bases=tuple(models)) # Handle POST requests elif all([issubclass(m, BaseModel) for m in models]): @@ -86,43 +86,38 @@ def create_post_request_model( ) -@dataclass +@attr.s class CollectionUri(APIRequest): """Get or delete collection.""" - collection_id: Annotated[str, Path(description="Collection ID")] + collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib() -@dataclass +@attr.s class ItemUri(APIRequest): """Get or delete item.""" - collection_id: Annotated[str, Path(description="Collection ID")] - item_id: Annotated[str, Path(description="Item ID")] + collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib() + item_id: Annotated[str, Path(description="Item ID")] = attr.ib() -@dataclass +@attr.s class EmptyRequest(APIRequest): """Empty request.""" ... -@dataclass +@attr.s class ItemCollectionUri(APIRequest): """Get item collection.""" - collection_id: Annotated[str, Path(description="Collection ID")] - limit: Annotated[int, Query()] = 10 - bbox: Annotated[Optional[BBox], Query()] = None - datetime: Annotated[Optional[DateTimeType], Query()] = None - - def __post_init__(self): - """convert attributes.""" - if self.bbox: - self.bbox = str2bbox(self.bbox) # type: ignore - if self.datetime: - self.datetime = str_to_interval(self.datetime) # type: ignore + collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib() + limit: Annotated[int, Query()] = attr.ib(default=10) + bbox: Annotated[Optional[BBox], Query()] = attr.ib(default=None, converter=str2bbox) + datetime: Annotated[Optional[DateTimeType], Query()] = attr.ib( + default=None, converter=str_to_interval + ) class GeoJSONResponse(JSONResponse): diff --git a/stac_fastapi/api/tests/test_app.py b/stac_fastapi/api/tests/test_app.py index 9fb2c52e0..0ddcb2429 100644 --- a/stac_fastapi/api/tests/test_app.py +++ b/stac_fastapi/api/tests/test_app.py @@ -1,12 +1,13 @@ -from dataclasses import dataclass from datetime import datetime from typing import List, Optional, Union +import attr import pytest from fastapi import Path, Query from fastapi.testclient import TestClient from pydantic import ValidationError from stac_pydantic import api +from typing_extensions import Annotated from stac_fastapi.api import app from stac_fastapi.api.models import ( @@ -328,25 +329,25 @@ def item_collection( def test_request_model(AsyncTestCoreClient): """Test if request models are passed correctly.""" - @dataclass + @attr.s class CollectionsRequest(APIRequest): - user: str = Query(...) + user: Annotated[str, Query(...)] = attr.ib() - @dataclass + @attr.s class CollectionRequest(APIRequest): - collection_id: str = Path(description="Collection ID") - user: str = Query(...) + collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib() + user: Annotated[str, Query(...)] = attr.ib() - @dataclass + @attr.s class ItemsRequest(APIRequest): - collection_id: str = Path(description="Collection ID") - user: str = Query(...) + collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib() + user: Annotated[str, Query(...)] = attr.ib() - @dataclass + @attr.s class ItemRequest(APIRequest): - collection_id: str = Path(description="Collection ID") - item_id: str = Path(description="Item ID") - user: str = Query(...) + collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib() + item_id: Annotated[str, Path(description="Item ID")] = attr.ib() + user: Annotated[str, Query(...)] = attr.ib() test_app = app.StacApi( settings=ApiSettings(), diff --git a/stac_fastapi/api/tests/test_models.py b/stac_fastapi/api/tests/test_models.py index 24ed59a18..b0c2ad90e 100644 --- a/stac_fastapi/api/tests/test_models.py +++ b/stac_fastapi/api/tests/test_models.py @@ -1,19 +1,20 @@ import json import pytest -from fastapi import Depends, FastAPI +from fastapi import Depends, FastAPI, HTTPException from fastapi.testclient import TestClient from pydantic import ValidationError from stac_fastapi.api.models import create_get_request_model, create_post_request_model -from stac_fastapi.extensions.core.filter.filter import FilterExtension -from stac_fastapi.extensions.core.sort.sort import SortExtension +from stac_fastapi.extensions.core import FieldsExtension, FilterExtension, SortExtension from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest def test_create_get_request_model(): - extensions = [FilterExtension()] - request_model = create_get_request_model(extensions, BaseSearchGetRequest) + request_model = create_get_request_model( + extensions=[FilterExtension(), FieldsExtension()], + base_model=BaseSearchGetRequest, + ) model = request_model( collections="test1,test2", @@ -35,6 +36,9 @@ def test_create_get_request_model(): assert model.collections == ["test1", "test2"] assert model.filter_crs == "epsg:4326" + with pytest.raises(HTTPException): + request_model(datetime="yo") + app = FastAPI() @app.get("/test") @@ -62,8 +66,10 @@ def route(model=Depends(request_model)): [(None, True), ({"test": "test"}, True), ("test==test", False), ([], False)], ) def test_create_post_request_model(filter, passes): - extensions = [FilterExtension()] - request_model = create_post_request_model(extensions, BaseSearchPostRequest) + request_model = create_post_request_model( + extensions=[FilterExtension(), FieldsExtension()], + base_model=BaseSearchPostRequest, + ) if not passes: with pytest.raises(ValidationError): @@ -100,8 +106,10 @@ def test_create_post_request_model(filter, passes): ], ) def test_create_post_request_model_nested_fields(sortby, passes): - extensions = [SortExtension()] - request_model = create_post_request_model(extensions, BaseSearchPostRequest) + request_model = create_post_request_model( + extensions=[SortExtension()], + base_model=BaseSearchPostRequest, + ) if not passes: with pytest.raises(ValidationError): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py index 325fc55ee..1f4b6a93b 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py @@ -1,8 +1,8 @@ """Request model for the Aggregation extension.""" -from dataclasses import dataclass from typing import List, Optional +import attr from fastapi import Query from pydantic import Field from typing_extensions import Annotated @@ -14,17 +14,13 @@ ) -@dataclass +@attr.s class AggregationExtensionGetRequest(BaseSearchGetRequest): """Aggregation Extension GET request model.""" - aggregations: Annotated[Optional[str], Query()] = None - - def __post_init__(self): - """convert attributes.""" - super().__post_init__() - if self.aggregations: - self.aggregations = str2list(self.aggregations) # type: ignore + aggregations: Annotated[Optional[str], Query()] = attr.ib( + default=None, converter=str2list + ) class AggregationExtensionPostRequest(BaseSearchPostRequest): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py index a77539c0b..e0c42a574 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py @@ -1,9 +1,9 @@ """Request models for the fields extension.""" import warnings -from dataclasses import dataclass from typing import Dict, Optional, Set +import attr from fastapi import Query from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -70,16 +70,11 @@ def filter_fields(self) -> Dict: } -@dataclass +@attr.s class FieldsExtensionGetRequest(APIRequest): """Additional fields for the GET request.""" - fields: Annotated[Optional[str], Query()] = None - - def __post_init__(self): - """convert attributes.""" - if self.fields: - self.fields = str2list(self.fields) # type: ignore + fields: Annotated[Optional[str], Query()] = attr.ib(default=None, converter=str2list) class FieldsExtensionPostRequest(BaseModel): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py index 970804b6d..917f5f086 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py @@ -1,8 +1,8 @@ """Filter extension request models.""" -from dataclasses import dataclass from typing import Any, Dict, Literal, Optional +import attr from fastapi import Query from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -12,13 +12,17 @@ FilterLang = Literal["cql-json", "cql2-json", "cql2-text"] -@dataclass +@attr.s class FilterExtensionGetRequest(APIRequest): """Filter extension GET request model.""" - filter: Annotated[Optional[str], Query()] = None - filter_crs: Annotated[Optional[str], Query(alias="filter-crs")] = None - filter_lang: Annotated[Optional[FilterLang], Query(alias="filter-lang")] = "cql2-text" + filter: Annotated[Optional[str], Query()] = attr.ib(default=None) + filter_crs: Annotated[Optional[str], Query(alias="filter-crs")] = attr.ib( + default=None + ) + filter_lang: Annotated[Optional[FilterLang], Query(alias="filter-lang")] = attr.ib( + default="cql2-text" + ) class FilterExtensionPostRequest(BaseModel): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/request.py index 94d98df65..66391c7f9 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/request.py @@ -1,8 +1,8 @@ """Pagination extension request models.""" -from dataclasses import dataclass from typing import Optional +import attr from fastapi import Query from pydantic import BaseModel from typing_extensions import Annotated @@ -10,11 +10,11 @@ from stac_fastapi.types.search import APIRequest -@dataclass +@attr.s class GETTokenPagination(APIRequest): """Token pagination for GET requests.""" - token: Annotated[Optional[str], Query()] = None + token: Annotated[Optional[str], Query()] = attr.ib(default=None) class POSTTokenPagination(BaseModel): @@ -23,11 +23,11 @@ class POSTTokenPagination(BaseModel): token: Optional[str] = None -@dataclass +@attr.s class GETPagination(APIRequest): """Page based pagination for GET requests.""" - page: Annotated[Optional[str], Query()] = None + page: Annotated[Optional[str], Query()] = attr.ib(default=None) class POSTPagination(BaseModel): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py index d431b0dea..5d403a677 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py @@ -1,8 +1,8 @@ """Request model for the Query extension.""" -from dataclasses import dataclass from typing import Any, Dict, Optional +import attr from fastapi import Query from pydantic import BaseModel from typing_extensions import Annotated @@ -10,11 +10,11 @@ from stac_fastapi.types.search import APIRequest -@dataclass +@attr.s class QueryExtensionGetRequest(APIRequest): """Query Extension GET request model.""" - query: Annotated[Optional[str], Query()] = None + query: Annotated[Optional[str], Query()] = attr.ib(default=None) class QueryExtensionPostRequest(BaseModel): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py index 7165d2e31..8eeccba0c 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py @@ -1,9 +1,8 @@ -# encoding: utf-8 """Request model for the Sort Extension.""" -from dataclasses import dataclass from typing import List, Optional +import attr from fastapi import Query from pydantic import BaseModel from stac_pydantic.api.extensions.sort import SortExtension as PostSortModel @@ -12,16 +11,11 @@ from stac_fastapi.types.search import APIRequest, str2list -@dataclass +@attr.s class SortExtensionGetRequest(APIRequest): """Sortby Parameter for GET requests.""" - sortby: Annotated[Optional[str], Query()] = None - - def __post_init__(self): - """convert attributes.""" - if self.sortby: - self.sortby = str2list(self.sortby) # type: ignore + sortby: Annotated[Optional[str], Query()] = attr.ib(default=None, converter=str2list) class SortExtensionPostRequest(BaseModel): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py index 27f2291d1..4e940a0ea 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py @@ -1,6 +1,5 @@ """Transaction extension.""" -from dataclasses import dataclass from typing import List, Optional, Type, Union import attr @@ -17,25 +16,25 @@ from stac_fastapi.types.extension import ApiExtension -@dataclass +@attr.s class PostItem(CollectionUri): """Create Item.""" - item: Annotated[Union[Item, ItemCollection], Body()] = None + item: Annotated[Union[Item, ItemCollection], Body()] = attr.ib(default=None) -@dataclass +@attr.s class PutItem(ItemUri): """Update Item.""" - item: Annotated[Item, Body()] = None + item: Annotated[Item, Body()] = attr.ib(default=None) -@dataclass +@attr.s class PutCollection(CollectionUri): """Update Collection.""" - collection: Annotated[Collection, Body()] = None + collection: Annotated[Collection, Body()] = attr.ib(default=None) @attr.s diff --git a/stac_fastapi/types/stac_fastapi/types/search.py b/stac_fastapi/types/stac_fastapi/types/search.py index 649a1a8ef..b8ae23c86 100644 --- a/stac_fastapi/types/stac_fastapi/types/search.py +++ b/stac_fastapi/types/stac_fastapi/types/search.py @@ -2,10 +2,9 @@ """ -import abc -from dataclasses import dataclass from typing import Dict, List, Optional, Union +import attr from fastapi import Query from pydantic import PositiveInt from pydantic.functional_validators import AfterValidator @@ -43,8 +42,8 @@ def str2bbox(x: str) -> Optional[BBox]: Limit = Annotated[PositiveInt, AfterValidator(crop)] -@dataclass -class APIRequest(abc.ABC): +@attr.s +class APIRequest: """Generic API Request base class.""" def kwargs(self) -> Dict: @@ -53,27 +52,20 @@ def kwargs(self) -> Dict: return self.__dict__ -@dataclass +@attr.s class BaseSearchGetRequest(APIRequest): """Base arguments for GET Request.""" - collections: Annotated[Optional[str], Query()] = None - ids: Annotated[Optional[str], Query()] = None - bbox: Annotated[Optional[BBox], Query()] = None - intersects: Annotated[Optional[str], Query()] = None - datetime: Annotated[Optional[DateTimeType], Query()] = None - limit: Annotated[Optional[int], Query()] = 10 - - def __post_init__(self): - """convert attributes.""" - if self.collections: - self.collections = str2list(self.collections) # type: ignore - if self.ids: - self.ids = str2list(self.ids) # type: ignore - if self.bbox: - self.bbox = str2bbox(self.bbox) # type: ignore - if self.datetime: - self.datetime = str_to_interval(self.datetime) # type: ignore + collections: Annotated[Optional[str], Query()] = attr.ib( + default=None, converter=str2list + ) + ids: Annotated[Optional[str], Query()] = attr.ib(default=None, converter=str2list) + bbox: Annotated[Optional[BBox], Query()] = attr.ib(default=None, converter=str2bbox) + intersects: Annotated[Optional[str], Query()] = attr.ib(default=None) + datetime: Annotated[Optional[DateTimeType], Query()] = attr.ib( + default=None, converter=str_to_interval + ) + limit: Annotated[Optional[int], Query()] = attr.ib(default=10) class BaseSearchPostRequest(Search):