Skip to content

Commit a6d149b

Browse files
committed
move from attr to dataclass+fastapi.Query() for GET models
1 parent c92c88a commit a6d149b

File tree

10 files changed

+136
-63
lines changed

10 files changed

+136
-63
lines changed

stac_fastapi/api/stac_fastapi/api/models.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""Api request/response models."""
22

33
import importlib.util
4+
from dataclasses import dataclass, make_dataclass
45
from typing import List, Optional, Type, Union
56

6-
import attr
7-
from fastapi import Path
7+
from fastapi import Path, Query
88
from pydantic import BaseModel, create_model
99
from stac_pydantic.shared import BBox
10+
from typing_extensions import Annotated
1011

1112
from stac_fastapi.types.extension import ApiExtension
1213
from stac_fastapi.types.rfc3339 import DateTimeType
@@ -37,11 +38,11 @@ def create_request_model(
3738

3839
mixins = mixins or []
3940

40-
models = [base_model] + extension_models + mixins
41+
models = extension_models + mixins + [base_model]
4142

4243
# Handle GET requests
4344
if all([issubclass(m, APIRequest) for m in models]):
44-
return attr.make_class(model_name, attrs={}, bases=tuple(models))
45+
return make_dataclass(model_name, [], bases=tuple(models))
4546

4647
# Handle POST requests
4748
elif all([issubclass(m, BaseModel) for m in models]):
@@ -80,34 +81,43 @@ def create_post_request_model(
8081
)
8182

8283

83-
@attr.s # type:ignore
84+
@dataclass
8485
class CollectionUri(APIRequest):
8586
"""Get or delete collection."""
8687

87-
collection_id: str = attr.ib(default=Path(..., description="Collection ID"))
88+
collection_id: Annotated[str, Path(description="Collection ID")]
8889

8990

90-
@attr.s
91-
class ItemUri(CollectionUri):
91+
@dataclass
92+
class ItemUri(APIRequest):
9293
"""Get or delete item."""
9394

94-
item_id: str = attr.ib(default=Path(..., description="Item ID"))
95+
collection_id: Annotated[str, Path(description="Collection ID")]
96+
item_id: Annotated[str, Path(description="Item ID")]
9597

9698

97-
@attr.s
99+
@dataclass
98100
class EmptyRequest(APIRequest):
99101
"""Empty request."""
100102

101103
...
102104

103105

104-
@attr.s
105-
class ItemCollectionUri(CollectionUri):
106+
@dataclass
107+
class ItemCollectionUri(APIRequest):
106108
"""Get item collection."""
107109

108-
limit: int = attr.ib(default=10)
109-
bbox: Optional[BBox] = attr.ib(default=None, converter=str2bbox)
110-
datetime: Optional[DateTimeType] = attr.ib(default=None, converter=str_to_interval)
110+
collection_id: Annotated[str, Path(description="Collection ID")]
111+
limit: Annotated[int, Query()] = 10
112+
bbox: Annotated[Optional[BBox], Query()] = None
113+
datetime: Annotated[Optional[DateTimeType], Query()] = None
114+
115+
def __post_init__(self):
116+
"""convert attributes."""
117+
if self.bbox:
118+
self.bbox = str2bbox(self.bbox) # type: ignore
119+
if self.datetime:
120+
self.datetime = str_to_interval(self.datetime) # type: ignore
111121

112122

113123
class POSTTokenPagination(BaseModel):
@@ -116,11 +126,11 @@ class POSTTokenPagination(BaseModel):
116126
token: Optional[str] = None
117127

118128

119-
@attr.s
129+
@dataclass
120130
class GETTokenPagination(APIRequest):
121131
"""Token pagination for GET requests."""
122132

123-
token: Optional[str] = attr.ib(default=None)
133+
token: Annotated[Optional[str], Query()] = None
124134

125135

126136
class POSTPagination(BaseModel):
@@ -129,11 +139,11 @@ class POSTPagination(BaseModel):
129139
page: Optional[str] = None
130140

131141

132-
@attr.s
142+
@dataclass
133143
class GETPagination(APIRequest):
134144
"""Page based pagination for GET requests."""
135145

136-
page: Optional[str] = attr.ib(default=None)
146+
page: Annotated[Optional[str], Query()] = None
137147

138148

139149
# Test for ORJSON and use it rather than stdlib JSON where supported

stac_fastapi/api/tests/test_models.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import json
22

33
import pytest
4+
from fastapi import Depends, FastAPI
5+
from fastapi.testclient import TestClient
46
from pydantic import ValidationError
57

68
from stac_fastapi.api.models import create_get_request_model, create_post_request_model
@@ -26,13 +28,33 @@ def test_create_get_request_model():
2628
datetime="2020-01-01T00:00:00Z",
2729
limit=10,
2830
filter="test==test",
29-
# FIXME: https://github.com/stac-utils/stac-fastapi/issues/638
30-
# hyphen aliases are not properly working
31-
# **{"filter-crs": "epsg:4326", "filter-lang": "cql2-text"},
31+
filter_crs="epsg:4326",
32+
filter_lang="cql2-text",
3233
)
3334

3435
assert model.collections == ["test1", "test2"]
35-
# assert model.filter_crs == "epsg:4326"
36+
assert model.filter_crs == "epsg:4326"
37+
38+
app = FastAPI()
39+
40+
@app.get("/test")
41+
def route(model=Depends(request_model)):
42+
return model
43+
44+
with TestClient(app) as client:
45+
resp = client.get(
46+
"/test",
47+
params={
48+
"collections": "test1,test2",
49+
"filter-crs": "epsg:4326",
50+
"filter-lang": "cql2-text",
51+
},
52+
)
53+
assert resp.status_code == 200
54+
response_dict = resp.json()
55+
assert response_dict["collections"] == ["test1", "test2"]
56+
assert response_dict["filter_crs"] == "epsg:4326"
57+
assert response_dict["filter_lang"] == "cql2-text"
3658

3759

3860
@pytest.mark.parametrize(
Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,23 @@
11
"""Request model for the Aggregation extension."""
22

3+
from dataclasses import dataclass
34
from typing import List, Optional, Union
45

5-
import attr
6+
from fastapi import Query
7+
from pydantic import BaseModel, Field
8+
from typing_extensions import Annotated
69

7-
from stac_fastapi.extensions.core.filter.request import (
8-
FilterExtensionGetRequest,
9-
FilterExtensionPostRequest,
10-
)
11-
from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest
10+
from stac_fastapi.types.search import APIRequest
1211

1312

14-
@attr.s
15-
class AggregationExtensionGetRequest(BaseSearchGetRequest, FilterExtensionGetRequest):
13+
@dataclass
14+
class AggregationExtensionGetRequest(APIRequest):
1615
"""Aggregation Extension GET request model."""
1716

18-
aggregations: Optional[str] = attr.ib(default=None)
17+
aggregations: Annotated[Optional[str], Query()] = None
1918

2019

21-
class AggregationExtensionPostRequest(BaseSearchPostRequest, FilterExtensionPostRequest):
20+
class AggregationExtensionPostRequest(BaseModel):
2221
"""Aggregation Extension POST request model."""
2322

24-
aggregations: Optional[Union[str, List[str]]] = attr.ib(default=None)
23+
aggregations: Optional[Union[str, List[str]]] = Field(default=None)

stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""Request models for the fields extension."""
22

33
import warnings
4+
from dataclasses import dataclass
45
from typing import Dict, Optional, Set
56

6-
import attr
7+
from fastapi import Query
78
from pydantic import BaseModel, Field
9+
from typing_extensions import Annotated
810

911
from stac_fastapi.types.search import APIRequest, str2list
1012

@@ -68,11 +70,16 @@ def filter_fields(self) -> Dict:
6870
}
6971

7072

71-
@attr.s
73+
@dataclass
7274
class FieldsExtensionGetRequest(APIRequest):
7375
"""Additional fields for the GET request."""
7476

75-
fields: Optional[str] = attr.ib(default=None, converter=str2list)
77+
fields: Annotated[Optional[str], Query()] = None
78+
79+
def __post_init__(self):
80+
"""convert attributes."""
81+
if self.fields:
82+
self.fields = str2list(self.fields) # type: ignore
7683

7784

7885
class FieldsExtensionPostRequest(BaseModel):

stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
"""Filter extension request models."""
22

3+
from dataclasses import dataclass
34
from typing import Any, Dict, Literal, Optional
45

5-
import attr
6+
from fastapi import Query
67
from pydantic import BaseModel, Field
8+
from typing_extensions import Annotated
79

810
from stac_fastapi.types.search import APIRequest
911

1012
FilterLang = Literal["cql-json", "cql2-json", "cql2-text"]
1113

1214

13-
@attr.s
15+
@dataclass
1416
class FilterExtensionGetRequest(APIRequest):
1517
"""Filter extension GET request model."""
1618

17-
filter: Optional[str] = attr.ib(default=None)
18-
filter_crs: Optional[str] = Field(alias="filter-crs", default=None)
19-
filter_lang: Optional[FilterLang] = Field(alias="filter-lang", default="cql2-text")
19+
filter: Annotated[Optional[str], Query()] = None
20+
filter_crs: Annotated[Optional[str], Query(alias="filter-crs")] = None
21+
filter_lang: Annotated[Optional[FilterLang], Query(alias="filter-lang")] = "cql2-text"
2022

2123

2224
class FilterExtensionPostRequest(BaseModel):

stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
"""Request model for the Query extension."""
22

3+
from dataclasses import dataclass
34
from typing import Any, Dict, Optional
45

5-
import attr
6+
from fastapi import Query
67
from pydantic import BaseModel
8+
from typing_extensions import Annotated
79

810
from stac_fastapi.types.search import APIRequest
911

1012

11-
@attr.s
13+
@dataclass
1214
class QueryExtensionGetRequest(APIRequest):
1315
"""Query Extension GET request model."""
1416

15-
query: Optional[str] = attr.ib(default=None)
17+
query: Annotated[Optional[str], Query()] = None
1618

1719

1820
class QueryExtensionPostRequest(BaseModel):

stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,27 @@
11
# encoding: utf-8
22
"""Request model for the Sort Extension."""
33

4+
from dataclasses import dataclass
45
from typing import List, Optional
56

6-
import attr
7+
from fastapi import Query
78
from pydantic import BaseModel
89
from stac_pydantic.api.extensions.sort import SortExtension as PostSortModel
10+
from typing_extensions import Annotated
911

1012
from stac_fastapi.types.search import APIRequest, str2list
1113

1214

13-
@attr.s
15+
@dataclass
1416
class SortExtensionGetRequest(APIRequest):
1517
"""Sortby Parameter for GET requests."""
1618

17-
sortby: Optional[str] = attr.ib(default=None, converter=str2list)
19+
sortby: Annotated[Optional[str], Query()] = None
20+
21+
def __post_init__(self):
22+
"""convert attributes."""
23+
if self.sortby:
24+
self.sortby = str2list(self.sortby) # type: ignore
1825

1926

2027
class SortExtensionPostRequest(BaseModel):

stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
"""Transaction extension."""
22

3+
from dataclasses import dataclass
34
from typing import List, Optional, Type, Union
45

56
import attr
67
from fastapi import APIRouter, Body, FastAPI
78
from stac_pydantic import Collection, Item, ItemCollection
89
from stac_pydantic.shared import MimeTypes
910
from starlette.responses import JSONResponse, Response
11+
from typing_extensions import Annotated
1012

1113
from stac_fastapi.api.models import CollectionUri, ItemUri
1214
from stac_fastapi.api.routes import create_async_endpoint
@@ -15,25 +17,25 @@
1517
from stac_fastapi.types.extension import ApiExtension
1618

1719

18-
@attr.s
20+
@dataclass
1921
class PostItem(CollectionUri):
2022
"""Create Item."""
2123

22-
item: Union[Item, ItemCollection] = attr.ib(default=Body(None))
24+
item: Annotated[Union[Item, ItemCollection], Body()] = None
2325

2426

25-
@attr.s
27+
@dataclass
2628
class PutItem(ItemUri):
2729
"""Update Item."""
2830

29-
item: Item = attr.ib(default=Body(None))
31+
item: Annotated[Item, Body()] = None
3032

3133

32-
@attr.s
34+
@dataclass
3335
class PutCollection(CollectionUri):
3436
"""Update Collection."""
3537

36-
collection: Collection = attr.ib(default=Body(None))
38+
collection: Annotated[Collection, Body()] = None
3739

3840

3941
@attr.s

stac_fastapi/extensions/tests/test_filter.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,13 @@ def test_search_filter_get(client: TestClient):
107107
)
108108
assert not response_dict["filter_crs"]
109109
assert response_dict["filter_lang"] == "cql2-json"
110+
111+
response = client.get(
112+
"/search",
113+
params={
114+
"collections": "collection1,collection2",
115+
},
116+
)
117+
assert response.is_success, response.json()
118+
response_dict = response.json()
119+
assert response_dict["collections"] == ["collection1", "collection2"]

0 commit comments

Comments
 (0)