88import orjson
99from asyncpg .exceptions import InvalidDatetimeFormatError
1010from buildpg import render
11- from fastapi import HTTPException
11+ from fastapi import HTTPException , Request
1212from pydantic import ValidationError
1313from pygeofilter .backends .cql2_json import to_cql2
1414from pygeofilter .parsers .cql2_text import parse as parse_cql2_text
1515from pypgstac .hydration import hydrate
1616from stac_pydantic .links import Relations
1717from stac_pydantic .shared import MimeTypes
18- from starlette .requests import Request
1918
2019from stac_fastapi .pgstac .config import Settings
2120from stac_fastapi .pgstac .models .links import (
3837class CoreCrudClient (AsyncBaseCoreClient ):
3938 """Client for core endpoints defined by stac."""
4039
41- async def all_collections (self , ** kwargs ) -> Collections :
40+ async def all_collections (self , request : Request , ** kwargs ) -> Collections :
4241 """Read all collections from the database."""
43- request : Request = kwargs ["request" ]
4442 base_url = get_base_url (request )
45- pool = request .app .state .readpool
4643
47- async with pool . acquire ( ) as conn :
44+ async with request . app . state . get_connection ( request , "r" ) as conn :
4845 collections = await conn .fetchval (
4946 """
5047 SELECT * FROM all_collections();
@@ -80,7 +77,9 @@ async def all_collections(self, **kwargs) -> Collections:
8077 collection_list = Collections (collections = linked_collections or [], links = links )
8178 return collection_list
8279
83- async def get_collection (self , collection_id : str , ** kwargs ) -> Collection :
80+ async def get_collection (
81+ self , collection_id : str , request : Request , ** kwargs
82+ ) -> Collection :
8483 """Get collection by id.
8584
8685 Called with `GET /collections/{collection_id}`.
@@ -93,9 +92,7 @@ async def get_collection(self, collection_id: str, **kwargs) -> Collection:
9392 """
9493 collection : Optional [Dict [str , Any ]]
9594
96- request : Request = kwargs ["request" ]
97- pool = request .app .state .readpool
98- async with pool .acquire () as conn :
95+ async with request .app .state .get_connection (request , "r" ) as conn :
9996 q , p = render (
10097 """
10198 SELECT * FROM get_collection(:id::text);
@@ -125,8 +122,7 @@ async def _get_base_item(
125122 """
126123 item : Optional [Dict [str , Any ]]
127124
128- pool = request .app .state .readpool
129- async with pool .acquire () as conn :
125+ async with request .app .state .get_connection (request , "r" ) as conn :
130126 q , p = render (
131127 """
132128 SELECT * FROM collection_base_item(:collection_id::text);
@@ -143,7 +139,7 @@ async def _get_base_item(
143139 async def _search_base (
144140 self ,
145141 search_request : PgstacSearch ,
146- ** kwargs : Any ,
142+ request : Request ,
147143 ) -> ItemCollection :
148144 """Cross catalog search (POST).
149145
@@ -157,21 +153,19 @@ async def _search_base(
157153 """
158154 items : Dict [str , Any ]
159155
160- request : Request = kwargs ["request" ]
161156 settings : Settings = request .app .state .settings
162- pool = request .app .state .readpool
163157
164158 search_request .conf = search_request .conf or {}
165159 search_request .conf ["nohydrate" ] = settings .use_api_hydrate
166- req = search_request .json (exclude_none = True , by_alias = True )
160+ search_request_json = search_request .json (exclude_none = True , by_alias = True )
167161
168162 try :
169- async with pool . acquire ( ) as conn :
163+ async with request . app . state . get_connection ( request , "r" ) as conn :
170164 q , p = render (
171165 """
172166 SELECT * FROM search(:req::text::jsonb);
173167 """ ,
174- req = req ,
168+ req = search_request_json ,
175169 )
176170 items = await conn .fetchval (q , * p )
177171 except InvalidDatetimeFormatError :
@@ -253,6 +247,7 @@ async def _get_base_item(collection_id: str) -> Dict[str, Any]:
253247 async def item_collection (
254248 self ,
255249 collection_id : str ,
250+ request : Request ,
256251 bbox : Optional [List [NumType ]] = None ,
257252 datetime : Optional [Union [str , datetime ]] = None ,
258253 limit : Optional [int ] = None ,
@@ -272,7 +267,7 @@ async def item_collection(
272267 An ItemCollection.
273268 """
274269 # If collection does not exist, NotFoundError wil be raised
275- await self .get_collection (collection_id , ** kwargs )
270+ await self .get_collection (collection_id , request )
276271
277272 base_args = {
278273 "collections" : [collection_id ],
@@ -287,17 +282,19 @@ async def item_collection(
287282 if v is not None and v != []:
288283 clean [k ] = v
289284
290- req = self .post_request_model (
285+ search_request = self .post_request_model (
291286 ** clean ,
292287 )
293- item_collection = await self ._search_base (req , ** kwargs )
288+ item_collection = await self ._search_base (search_request , request )
294289 links = await ItemCollectionLinks (
295- collection_id = collection_id , request = kwargs [ " request" ]
290+ collection_id = collection_id , request = request
296291 ).get_links (extra_links = item_collection ["links" ])
297292 item_collection ["links" ] = links
298293 return item_collection
299294
300- async def get_item (self , item_id : str , collection_id : str , ** kwargs ) -> Item :
295+ async def get_item (
296+ self , item_id : str , collection_id : str , request : Request , ** kwargs
297+ ) -> Item :
301298 """Get item by id.
302299
303300 Called with `GET /collections/{collection_id}/items/{item_id}`.
@@ -310,12 +307,12 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
310307 Item.
311308 """
312309 # If collection does not exist, NotFoundError wil be raised
313- await self .get_collection (collection_id , ** kwargs )
310+ await self .get_collection (collection_id , request )
314311
315- req = self .post_request_model (
312+ search_request = self .post_request_model (
316313 ids = [item_id ], collections = [collection_id ], limit = 1
317314 )
318- item_collection = await self ._search_base (req , ** kwargs )
315+ item_collection = await self ._search_base (search_request , request )
319316 if not item_collection ["features" ]:
320317 raise NotFoundError (
321318 f"Item { item_id } in Collection { collection_id } does not exist."
@@ -324,7 +321,7 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
324321 return Item (** item_collection ["features" ][0 ])
325322
326323 async def post_search (
327- self , search_request : PgstacSearch , ** kwargs
324+ self , search_request : PgstacSearch , request : Request , ** kwargs
328325 ) -> ItemCollection :
329326 """Cross catalog search (POST).
330327
@@ -336,11 +333,12 @@ async def post_search(
336333 Returns:
337334 ItemCollection containing items which match the search criteria.
338335 """
339- item_collection = await self ._search_base (search_request , ** kwargs )
336+ item_collection = await self ._search_base (search_request , request )
340337 return ItemCollection (** item_collection )
341338
342339 async def get_search (
343340 self ,
341+ request : Request ,
344342 collections : Optional [List [str ]] = None ,
345343 ids : Optional [List [str ]] = None ,
346344 bbox : Optional [List [NumType ]] = None ,
@@ -362,7 +360,6 @@ async def get_search(
362360 Returns:
363361 ItemCollection containing items which match the search criteria.
364362 """
365- request = kwargs ["request" ]
366363 query_params = str (request .query_params )
367364
368365 # Kludgy fix because using factory does not allow alias for filter-lang
@@ -432,4 +429,4 @@ async def get_search(
432429 raise HTTPException (
433430 status_code = 400 , detail = f"Invalid parameters provided { e } "
434431 )
435- return await self .post_search (search_request , request = kwargs [ " request" ] )
432+ return await self .post_search (search_request , request = request )
0 commit comments