11"""api request/response models."""
22
3- import abc
43import importlib
5- from typing import Dict , Optional , Type , Union
4+ from typing import Optional , Type , Union
65
76import attr
87from fastapi import Body , Path
98from pydantic import BaseModel , create_model
109from pydantic .fields import UndefinedType
1110
12-
13- def _create_request_model (model : Type [BaseModel ]) -> Type [BaseModel ]:
11+ from stac_fastapi .types .extension import ApiExtension
12+ from stac_fastapi .types .search import (
13+ APIRequest ,
14+ BaseSearchGetRequest ,
15+ BaseSearchPostRequest ,
16+ )
17+
18+
19+ def create_request_model (
20+ model_name = "SearchGetRequest" ,
21+ base_model : Union [Type [BaseModel ], APIRequest ] = BaseSearchGetRequest ,
22+ extensions : Optional [ApiExtension ] = None ,
23+ mixins : Optional [Union [BaseModel , APIRequest ]] = None ,
24+ request_type : Optional [str ] = "GET" ,
25+ ) -> Union [Type [BaseModel ], APIRequest ]:
1426 """Create a pydantic model for validating request bodies."""
1527 fields = {}
16- for (k , v ) in model .__fields__ .items ():
17- # TODO: Filter out fields based on which extensions are present
18- field_info = v .field_info
19- body = Body (
20- None
21- if isinstance (field_info .default , UndefinedType )
22- else field_info .default ,
23- default_factory = field_info .default_factory ,
24- alias = field_info .alias ,
25- alias_priority = field_info .alias_priority ,
26- title = field_info .title ,
27- description = field_info .description ,
28- const = field_info .const ,
29- gt = field_info .gt ,
30- ge = field_info .ge ,
31- lt = field_info .lt ,
32- le = field_info .le ,
33- multiple_of = field_info .multiple_of ,
34- min_items = field_info .min_items ,
35- max_items = field_info .max_items ,
36- min_length = field_info .min_length ,
37- max_length = field_info .max_length ,
38- regex = field_info .regex ,
39- extra = field_info .extra ,
40- )
41- fields [k ] = (v .outer_type_ , body )
42- return create_model (model .__name__ , ** fields , __base__ = model )
43-
44-
45- @attr .s # type:ignore
46- class APIRequest (abc .ABC ):
47- """Generic API Request base class."""
48-
49- @abc .abstractmethod
50- def kwargs (self ) -> Dict :
51- """Transform api request params into format which matches the signature of the endpoint."""
52- ...
28+ extension_models = []
29+
30+ # Check extensions for additional parameters to search
31+ for extension in extensions or []:
32+ if extension_model := extension .get_request_model (request_type ):
33+ extension_models .append (extension_model )
34+
35+ mixins = mixins or []
36+
37+ models = [base_model ] + extension_models + mixins
38+
39+ # Handle GET requests
40+ if all ([issubclass (m , APIRequest ) for m in models ]):
41+ return attr .make_class (model_name , attrs = {}, bases = tuple (models ))
42+
43+ # Handle POST requests
44+ elif all ([issubclass (m , BaseModel ) for m in models ]):
45+ for model in models :
46+ for (k , v ) in model .__fields__ .items ():
47+ field_info = v .field_info
48+ body = Body (
49+ None
50+ if isinstance (field_info .default , UndefinedType )
51+ else field_info .default ,
52+ default_factory = field_info .default_factory ,
53+ alias = field_info .alias ,
54+ alias_priority = field_info .alias_priority ,
55+ title = field_info .title ,
56+ description = field_info .description ,
57+ const = field_info .const ,
58+ gt = field_info .gt ,
59+ ge = field_info .ge ,
60+ lt = field_info .lt ,
61+ le = field_info .le ,
62+ multiple_of = field_info .multiple_of ,
63+ min_items = field_info .min_items ,
64+ max_items = field_info .max_items ,
65+ min_length = field_info .min_length ,
66+ max_length = field_info .max_length ,
67+ regex = field_info .regex ,
68+ extra = field_info .extra ,
69+ )
70+ fields [k ] = (v .outer_type_ , body )
71+ return create_model (model_name , ** fields , __base__ = base_model )
72+
73+ raise TypeError ("Mixed Request Model types. Check extension request types." )
74+
75+
76+ def create_get_request_model (
77+ extensions , base_model : BaseSearchGetRequest = BaseSearchGetRequest
78+ ):
79+ """Wrap create_request_model to create the GET request model."""
80+ return create_request_model (
81+ "SearchGetRequest" ,
82+ base_model = BaseSearchGetRequest ,
83+ extensions = extensions ,
84+ request_type = "GET" ,
85+ )
86+
87+
88+ def create_post_request_model (
89+ extensions , base_model : BaseSearchPostRequest = BaseSearchGetRequest
90+ ):
91+ """Wrap create_request_model to create the POST request model."""
92+ return create_request_model (
93+ "SearchPostRequest" ,
94+ base_model = BaseSearchPostRequest ,
95+ extensions = extensions ,
96+ request_type = "POST" ,
97+ )
5398
5499
55100@attr .s # type:ignore
@@ -58,76 +103,52 @@ class CollectionUri(APIRequest):
58103
59104 collection_id : str = attr .ib (default = Path (..., description = "Collection ID" ))
60105
61- def kwargs (self ) -> Dict :
62- """kwargs."""
63- return {"id" : self .collection_id }
64-
65106
66107@attr .s
67108class ItemUri (CollectionUri ):
68109 """Delete item."""
69110
70111 item_id : str = attr .ib (default = Path (..., description = "Item ID" ))
71112
72- def kwargs (self ) -> Dict :
73- """kwargs."""
74- return {"collection_id" : self .collection_id , "item_id" : self .item_id }
75-
76113
77114@attr .s
78115class EmptyRequest (APIRequest ):
79116 """Empty request."""
80117
81- def kwargs (self ) -> Dict :
82- """kwargs."""
83- return {}
118+ ...
84119
85120
86121@attr .s
87122class ItemCollectionUri (CollectionUri ):
88123 """Get item collection."""
89124
90125 limit : int = attr .ib (default = 10 )
91- token : str = attr .ib (default = None )
92126
93- def kwargs (self ) -> Dict :
94- """kwargs."""
95- return {
96- "id" : self .collection_id ,
97- "limit" : self .limit ,
98- "token" : self .token ,
99- }
127+
128+ class POSTTokenPagination (BaseModel ):
129+ """Token pagination model for POST requests."""
130+
131+ token : Optional [str ] = None
100132
101133
102134@attr .s
103- class SearchGetRequest (APIRequest ):
104- """GET search request."""
105-
106- collections : Optional [str ] = attr .ib (default = None )
107- ids : Optional [str ] = attr .ib (default = None )
108- bbox : Optional [str ] = attr .ib (default = None )
109- datetime : Optional [Union [str ]] = attr .ib (default = None )
110- limit : Optional [int ] = attr .ib (default = 10 )
111- query : Optional [str ] = attr .ib (default = None )
135+ class GETTokenPagination (APIRequest ):
136+ """Token pagination for GET requests."""
137+
112138 token : Optional [str ] = attr .ib (default = None )
113- fields : Optional [str ] = attr .ib (default = None )
114- sortby : Optional [str ] = attr .ib (default = None )
115-
116- def kwargs (self ) -> Dict :
117- """kwargs."""
118- return {
119- "collections" : self .collections .split ("," )
120- if self .collections
121- else self .collections ,
122- "ids" : self .ids .split ("," ) if self .ids else self .ids ,
123- "bbox" : self .bbox .split ("," ) if self .bbox else self .bbox ,
124- "datetime" : self .datetime ,
125- "limit" : self .limit ,
126- "query" : self .query ,
127- "token" : self .token ,
128- "fields" : self .fields .split ("," ) if self .fields else self .fields ,
129- "sortby" : self .sortby .split ("," ) if self .sortby else self .sortby ,
130- }
139+
140+
141+ class POSTPagination (BaseModel ):
142+ """Page based pagination for POST requests."""
143+
144+ page : Optional [str ] = None
145+
146+
147+ @attr .s
148+ class GETPagination (APIRequest ):
149+ """Page based pagination for GET requests."""
150+
151+ page : Optional [str ] = attr .ib (default = None )
131152
132153
133154# Test for ORJSON and use it rather than stdlib JSON where supported
0 commit comments