3939class CoreCrudClient (AsyncBaseCoreClient ):
4040 """Client for core endpoints defined by stac."""
4141
42- async def all_collections (self , request : Request , ** kwargs ) -> Collections :
43- """Read all collections from the database."""
42+ async def all_collections ( # noqa: C901
43+ self ,
44+ request : Request ,
45+ bbox : Optional [BBox ] = None ,
46+ datetime : Optional [DateTimeType ] = None ,
47+ limit : Optional [int ] = None ,
48+ # Extensions
49+ query : Optional [str ] = None ,
50+ token : Optional [str ] = None ,
51+ fields : Optional [List [str ]] = None ,
52+ sortby : Optional [str ] = None ,
53+ filter : Optional [str ] = None ,
54+ filter_lang : Optional [str ] = None ,
55+ ** kwargs ,
56+ ) -> Collections :
57+ """Cross catalog search (GET).
58+
59+ Called with `GET /collections`.
60+
61+ Returns:
62+ Collections which match the search criteria, returns all
63+ collections by default.
64+ """
65+ query_params = str (request .query_params )
66+
67+ # Kludgy fix because using factory does not allow alias for filter-lang
68+ if filter_lang is None :
69+ match = re .search (r"filter-lang=([a-z0-9-]+)" , query_params , re .IGNORECASE )
70+ if match :
71+ filter_lang = match .group (1 )
72+
73+ # Parse request parameters
74+ base_args = {
75+ "bbox" : bbox ,
76+ "limit" : limit ,
77+ "token" : token ,
78+ "query" : orjson .loads (unquote_plus (query )) if query else query ,
79+ }
80+
81+ clean = clean_search_args (
82+ base_args = base_args ,
83+ datetime = datetime ,
84+ fields = fields ,
85+ sortby = sortby ,
86+ filter = filter ,
87+ filter_lang = filter_lang ,
88+ )
89+
90+ # Do the request
91+ try :
92+ search_request = self .post_request_model (** clean )
93+ except ValidationError as e :
94+ raise HTTPException (
95+ status_code = 400 , detail = f"Invalid parameters provided { e } "
96+ ) from e
97+
98+ return await self ._collection_search_base (search_request , request = request )
99+
100+ async def _collection_search_base ( # noqa: C901
101+ self ,
102+ search_request : PgstacSearch ,
103+ request : Request ,
104+ ) -> Collections :
105+ """Cross catalog search (POST).
106+
107+ Called with `POST /search`.
108+
109+ Args:
110+ search_request: search request parameters.
111+
112+ Returns:
113+ All collections which match the search criteria.
114+ """
115+
44116 base_url = get_base_url (request )
45117
46- async with request .app .state .get_connection (request , "r" ) as conn :
47- collections = await conn .fetchval (
48- """
49- SELECT * FROM all_collections();
50- """
51- )
118+ settings : Settings = request .app .state .settings
119+
120+ if search_request .datetime :
121+ search_request .datetime = format_datetime_range (search_request .datetime )
122+
123+ search_request .conf = search_request .conf or {}
124+ search_request .conf ["nohydrate" ] = settings .use_api_hydrate
125+
126+ search_request_json = search_request .model_dump_json (
127+ exclude_none = True , by_alias = True
128+ )
129+
130+ try :
131+ async with request .app .state .get_connection (request , "r" ) as conn :
132+ q , p = render (
133+ """
134+ SELECT * FROM collection_search(:req::text::jsonb);
135+ """ ,
136+ req = search_request_json ,
137+ )
138+ collections_result : Collections = await conn .fetchval (q , * p )
139+ except InvalidDatetimeFormatError as e :
140+ raise InvalidQueryParameter (
141+ f"Datetime parameter { search_request .datetime } is invalid."
142+ ) from e
143+
144+ # next: Optional[str] = collections_result["links"].pop("next")
145+ # prev: Optional[str] = collections_result["links"].pop("prev")
146+
52147 linked_collections : List [Collection ] = []
148+ collections = collections_result ["collections" ]
53149 if collections is not None and len (collections ) > 0 :
54150 for c in collections :
55151 coll = Collection (** c )
@@ -71,6 +167,12 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
71167
72168 linked_collections .append (coll )
73169
170+ # paging_links = await PagingLinks(
171+ # request=request,
172+ # next=next,
173+ # prev=prev,
174+ # ).get_links()
175+
74176 links = [
75177 {
76178 "rel" : Relations .root .value ,
@@ -88,8 +190,10 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
88190 "href" : urljoin (base_url , "collections" ),
89191 },
90192 ]
91- collection_list = Collections (collections = linked_collections or [], links = links )
92- return collection_list
193+ return Collections (
194+ collections = linked_collections or [],
195+ links = links , # + paging_links
196+ )
93197
94198 async def get_collection (
95199 self , collection_id : str , request : Request , ** kwargs
@@ -383,7 +487,7 @@ async def post_search(
383487
384488 return ItemCollection (** item_collection )
385489
386- async def get_search ( # noqa: C901
490+ async def get_search (
387491 self ,
388492 request : Request ,
389493 collections : Optional [List [str ]] = None ,
@@ -418,49 +522,15 @@ async def get_search( # noqa: C901
418522 "query" : orjson .loads (unquote_plus (query )) if query else query ,
419523 }
420524
421- if filter :
422- if filter_lang == "cql2-text" :
423- ast = parse_cql2_text (filter )
424- base_args ["filter" ] = orjson .loads (to_cql2 (ast ))
425- base_args ["filter-lang" ] = "cql2-json"
426-
427- if datetime :
428- base_args ["datetime" ] = format_datetime_range (datetime )
429-
430- if intersects :
431- base_args ["intersects" ] = orjson .loads (unquote_plus (intersects ))
432-
433- if sortby :
434- # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
435- sort_param = []
436- for sort in sortby :
437- sortparts = re .match (r"^([+-]?)(.*)$" , sort )
438- if sortparts :
439- sort_param .append (
440- {
441- "field" : sortparts .group (2 ).strip (),
442- "direction" : "desc" if sortparts .group (1 ) == "-" else "asc" ,
443- }
444- )
445- base_args ["sortby" ] = sort_param
446-
447- if fields :
448- includes = set ()
449- excludes = set ()
450- for field in fields :
451- if field [0 ] == "-" :
452- excludes .add (field [1 :])
453- elif field [0 ] == "+" :
454- includes .add (field [1 :])
455- else :
456- includes .add (field )
457- base_args ["fields" ] = {"include" : includes , "exclude" : excludes }
458-
459- # Remove None values from dict
460- clean = {}
461- for k , v in base_args .items ():
462- if v is not None and v != []:
463- clean [k ] = v
525+ clean = clean_search_args (
526+ base_args = base_args ,
527+ intersects = intersects ,
528+ datetime = datetime ,
529+ fields = fields ,
530+ sortby = sortby ,
531+ filter = filter ,
532+ filter_lang = filter_lang ,
533+ )
464534
465535 # Do the request
466536 try :
@@ -471,3 +541,60 @@ async def get_search( # noqa: C901
471541 ) from e
472542
473543 return await self .post_search (search_request , request = request )
544+
545+
546+ def clean_search_args ( # noqa: C901
547+ base_args : dict [str , Any ],
548+ intersects : Optional [str ] = None ,
549+ datetime : Optional [DateTimeType ] = None ,
550+ fields : Optional [List [str ]] = None ,
551+ sortby : Optional [str ] = None ,
552+ filter : Optional [str ] = None ,
553+ filter_lang : Optional [str ] = None ,
554+ ) -> dict [str , Any ]:
555+ """Clean up search arguments to match format expected by pgstac"""
556+ if filter :
557+ if filter_lang == "cql2-text" :
558+ ast = parse_cql2_text (filter )
559+ base_args ["filter" ] = orjson .loads (to_cql2 (ast ))
560+ base_args ["filter-lang" ] = "cql2-json"
561+
562+ if datetime :
563+ base_args ["datetime" ] = format_datetime_range (datetime )
564+
565+ if intersects :
566+ base_args ["intersects" ] = orjson .loads (unquote_plus (intersects ))
567+
568+ if sortby :
569+ # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
570+ sort_param = []
571+ for sort in sortby :
572+ sortparts = re .match (r"^([+-]?)(.*)$" , sort )
573+ if sortparts :
574+ sort_param .append (
575+ {
576+ "field" : sortparts .group (2 ).strip (),
577+ "direction" : "desc" if sortparts .group (1 ) == "-" else "asc" ,
578+ }
579+ )
580+ base_args ["sortby" ] = sort_param
581+
582+ if fields :
583+ includes = set ()
584+ excludes = set ()
585+ for field in fields :
586+ if field [0 ] == "-" :
587+ excludes .add (field [1 :])
588+ elif field [0 ] == "+" :
589+ includes .add (field [1 :])
590+ else :
591+ includes .add (field )
592+ base_args ["fields" ] = {"include" : includes , "exclude" : excludes }
593+
594+ # Remove None values from dict
595+ clean = {}
596+ for k , v in base_args .items ():
597+ if v is not None and v != []:
598+ clean [k ] = v
599+
600+ return clean
0 commit comments