1414from pygeofilter .parsers .cql2_text import parse as parse_cql2_text
1515from pypgstac .hydration import hydrate
1616from stac_fastapi .api .models import JSONResponse
17- from stac_fastapi .types .core import AsyncBaseCoreClient
17+ from stac_fastapi .types .core import AsyncBaseCoreClient , Relations
1818from stac_fastapi .types .errors import InvalidQueryParameter , NotFoundError
1919from stac_fastapi .types .requests import get_base_url
2020from stac_fastapi .types .rfc3339 import DateTimeType
2121from stac_fastapi .types .stac import Collection , Collections , Item , ItemCollection
22- from stac_pydantic .links import Relations
2322from stac_pydantic .shared import BBox , MimeTypes
2423
2524from stac_fastapi .pgstac .config import Settings
3938class CoreCrudClient (AsyncBaseCoreClient ):
4039 """Client for core endpoints defined by stac."""
4140
42- async def all_collections (self , request : Request , ** kwargs ) -> Collections :
43- """Read all collections from the database."""
41+ async def all_collections ( # noqa: C901
42+ self ,
43+ request : Request ,
44+ # Extensions
45+ bbox : Optional [BBox ] = None ,
46+ datetime : Optional [DateTimeType ] = None ,
47+ limit : Optional [int ] = None ,
48+ query : Optional [str ] = None ,
49+ token : Optional [str ] = None ,
50+ fields : Optional [List [str ]] = None ,
51+ sortby : Optional [str ] = None ,
52+ filter : Optional [str ] = None ,
53+ filter_lang : Optional [str ] = None ,
54+ ** kwargs ,
55+ ) -> Collections :
56+ """Cross catalog search (GET).
57+
58+ Called with `GET /collections`.
59+
60+ Returns:
61+ Collections which match the search criteria, returns all
62+ collections by default.
63+ """
64+
65+ # Parse request parameters
66+ base_args = {
67+ "bbox" : bbox ,
68+ "limit" : limit ,
69+ "token" : token ,
70+ "query" : orjson .loads (unquote_plus (query )) if query else query ,
71+ }
72+
73+ clean = clean_search_args (
74+ base_args = base_args ,
75+ datetime = datetime ,
76+ fields = fields ,
77+ sortby = sortby ,
78+ filter = filter ,
79+ filter_lang = filter_lang ,
80+ )
81+
82+ # Do the request
83+ try :
84+ search_request = self .post_request_model (** clean )
85+ except ValidationError as e :
86+ raise HTTPException (
87+ status_code = 400 , detail = f"Invalid parameters provided { e } "
88+ ) from e
89+
90+ return await self ._collection_search_base (search_request , request = request )
91+
92+ async def _collection_search_base ( # noqa: C901
93+ self ,
94+ search_request : PgstacSearch ,
95+ request : Request ,
96+ ) -> Collections :
97+ """Cross catalog search (GET).
98+
99+ Called with `GET /search`.
100+
101+ Args:
102+ search_request: search request parameters.
103+
104+ Returns:
105+ All collections which match the search criteria.
106+ """
44107 base_url = get_base_url (request )
108+ search_request_json = search_request .model_dump_json (
109+ exclude_none = True , by_alias = True
110+ )
111+
112+ try :
113+ async with request .app .state .get_connection (request , "r" ) as conn :
114+ q , p = render (
115+ """
116+ SELECT * FROM collection_search(:req::text::jsonb);
117+ """ ,
118+ req = search_request_json ,
119+ )
120+ collections_result : Collections = await conn .fetchval (q , * p )
121+ except InvalidDatetimeFormatError as e :
122+ raise InvalidQueryParameter (
123+ f"Datetime parameter { search_request .datetime } is invalid."
124+ ) from e
125+
126+ next : Optional [str ] = None
127+ prev : Optional [str ] = None
128+
129+ if links := collections_result .get ("links" ):
130+ next = collections_result ["links" ].pop ("next" )
131+ prev = collections_result ["links" ].pop ("prev" )
45132
46- async with request .app .state .get_connection (request , "r" ) as conn :
47- collections = await conn .fetchval (
48- """
49- SELECT * FROM all_collections();
50- """
51- )
52133 linked_collections : List [Collection ] = []
134+ collections = collections_result ["collections" ]
53135 if collections is not None and len (collections ) > 0 :
54136 for c in collections :
55137 coll = Collection (** c )
@@ -71,25 +153,16 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
71153
72154 linked_collections .append (coll )
73155
74- links = [
75- {
76- "rel" : Relations .root .value ,
77- "type" : MimeTypes .json ,
78- "href" : base_url ,
79- },
80- {
81- "rel" : Relations .parent .value ,
82- "type" : MimeTypes .json ,
83- "href" : base_url ,
84- },
85- {
86- "rel" : Relations .self .value ,
87- "type" : MimeTypes .json ,
88- "href" : urljoin (base_url , "collections" ),
89- },
90- ]
91- collection_list = Collections (collections = linked_collections or [], links = links )
92- return collection_list
156+ links = await PagingLinks (
157+ request = request ,
158+ next = next ,
159+ prev = prev ,
160+ ).get_links ()
161+
162+ return Collections (
163+ collections = linked_collections or [],
164+ links = links ,
165+ )
93166
94167 async def get_collection (
95168 self , collection_id : str , request : Request , ** kwargs
@@ -383,7 +456,7 @@ async def post_search(
383456
384457 return ItemCollection (** item_collection )
385458
386- async def get_search ( # noqa: C901
459+ async def get_search (
387460 self ,
388461 request : Request ,
389462 collections : Optional [List [str ]] = None ,
@@ -418,49 +491,15 @@ async def get_search( # noqa: C901
418491 "query" : orjson .loads (unquote_plus (query )) if query else query ,
419492 }
420493
421- if filter :
422- if filter_lang == "cql2-text" :
423- ast = parse_cql2_text (filter )
424- base_args ["filter" ] = orjson .loads (to_cql2 (ast ))
425- base_args ["filter-lang" ] = "cql2-json"
426-
427- if datetime :
428- base_args ["datetime" ] = format_datetime_range (datetime )
429-
430- if intersects :
431- base_args ["intersects" ] = orjson .loads (unquote_plus (intersects ))
432-
433- if sortby :
434- # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
435- sort_param = []
436- for sort in sortby :
437- sortparts = re .match (r"^([+-]?)(.*)$" , sort )
438- if sortparts :
439- sort_param .append (
440- {
441- "field" : sortparts .group (2 ).strip (),
442- "direction" : "desc" if sortparts .group (1 ) == "-" else "asc" ,
443- }
444- )
445- base_args ["sortby" ] = sort_param
446-
447- if fields :
448- includes = set ()
449- excludes = set ()
450- for field in fields :
451- if field [0 ] == "-" :
452- excludes .add (field [1 :])
453- elif field [0 ] == "+" :
454- includes .add (field [1 :])
455- else :
456- includes .add (field )
457- base_args ["fields" ] = {"include" : includes , "exclude" : excludes }
458-
459- # Remove None values from dict
460- clean = {}
461- for k , v in base_args .items ():
462- if v is not None and v != []:
463- clean [k ] = v
494+ clean = clean_search_args (
495+ base_args = base_args ,
496+ intersects = intersects ,
497+ datetime = datetime ,
498+ fields = fields ,
499+ sortby = sortby ,
500+ filter = filter ,
501+ filter_lang = filter_lang ,
502+ )
464503
465504 # Do the request
466505 try :
@@ -471,3 +510,60 @@ async def get_search( # noqa: C901
471510 ) from e
472511
473512 return await self .post_search (search_request , request = request )
513+
514+
515+ def clean_search_args ( # noqa: C901
516+ base_args : Dict [str , Any ],
517+ intersects : Optional [str ] = None ,
518+ datetime : Optional [DateTimeType ] = None ,
519+ fields : Optional [List [str ]] = None ,
520+ sortby : Optional [str ] = None ,
521+ filter : Optional [str ] = None ,
522+ filter_lang : Optional [str ] = None ,
523+ ) -> Dict [str , Any ]:
524+ """Clean up search arguments to match format expected by pgstac"""
525+ if filter :
526+ if filter_lang == "cql2-text" :
527+ ast = parse_cql2_text (filter )
528+ base_args ["filter" ] = orjson .loads (to_cql2 (ast ))
529+ base_args ["filter-lang" ] = "cql2-json"
530+
531+ if datetime :
532+ base_args ["datetime" ] = format_datetime_range (datetime )
533+
534+ if intersects :
535+ base_args ["intersects" ] = orjson .loads (unquote_plus (intersects ))
536+
537+ if sortby :
538+ # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
539+ sort_param = []
540+ for sort in sortby :
541+ sortparts = re .match (r"^([+-]?)(.*)$" , sort )
542+ if sortparts :
543+ sort_param .append (
544+ {
545+ "field" : sortparts .group (2 ).strip (),
546+ "direction" : "desc" if sortparts .group (1 ) == "-" else "asc" ,
547+ }
548+ )
549+ base_args ["sortby" ] = sort_param
550+
551+ if fields :
552+ includes = set ()
553+ excludes = set ()
554+ for field in fields :
555+ if field [0 ] == "-" :
556+ excludes .add (field [1 :])
557+ elif field [0 ] == "+" :
558+ includes .add (field [1 :])
559+ else :
560+ includes .add (field )
561+ base_args ["fields" ] = {"include" : includes , "exclude" : excludes }
562+
563+ # Remove None values from dict
564+ clean = {}
565+ for k , v in base_args .items ():
566+ if v is not None and v != []:
567+ clean [k ] = v
568+
569+ return clean
0 commit comments