diff --git a/.github/workflows/cicd.yaml b/.github/workflows/cicd.yaml index b363efa2..28724d80 100644 --- a/.github/workflows/cicd.yaml +++ b/.github/workflows/cicd.yaml @@ -27,7 +27,7 @@ jobs: cache-dependency-path: setup.py - name: Lint code - if: ${{ matrix.python-version == 3.8 }} + if: ${{ matrix.python-version == 3.11 }} run: | python -m pip install pre-commit pre-commit run --all-files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 49f90ba2..b0c6c0c4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,57 +1,24 @@ repos: - repo: https://github.com/PyCQA/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort language_version: python - - repo: https://github.com/psf/black - rev: 22.12.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.5 hooks: - - id: black - args: ["--safe"] - language_version: python - - - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 - hooks: - - id: flake8 - language_version: python - args: [ - # E501 let black handle all line length decisions - # W503 black conflicts with "line break before operator" rule - # E203 black conflicts with "whitespace before ':'" rule - "--ignore=E501,W503,E203,C901", - ] - - repo: https://github.com/chewse/pre-commit-mirrors-pydocstyle - # 2.1.1 - rev: v2.1.1 - hooks: - - id: pydocstyle - language_version: python - exclude: ".*(test|scripts).*" - args: - [ - # Check for docstring presence only - "--select=D1", - ] - # Don't require docstrings for tests - # '--match=(?!test).*\.py'] - # - - # repo: https://github.com/pre-commit/mirrors-mypy - # rev: v0.770 - # hooks: - # - id: mypy - # language_version: python3.8 - # args: [--no-strict-optional, --ignore-missing-imports] + - id: ruff + args: ["--fix"] + - id: ruff-format - - repo: https://github.com/PyCQA/pydocstyle - rev: 6.3.0 + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.9.0 hooks: - - id: pydocstyle + - id: mypy language_version: python - exclude: ".*(test|scripts).*" - #args: [ - # Don't require docstrings for tests - #'--match=(?!test|scripts).*\.py', - #] + exclude: tests/.* + additional_dependencies: + - types-requests + - types-attrs + - pydantic~=1.10 diff --git a/pyproject.toml b/pyproject.toml index 1c6dacaf..7862afa6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,28 @@ -[flake8] -ignore = "D203" -exclude = [".git", "__pycache__", "docs/source/conf.py", "build", "dist"] -max-complexity = 12 -max-line-length = 90 - [tool.isort] profile = "black" known_first_party = "stac_fastapi.pgstac" known_third_party = ["rasterio", "stac-pydantic", "sqlalchemy", "geoalchemy2", "fastapi", "stac_fastapi"] sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] + +[tool.mypy] +ignore_missing_imports = true +namespace_packages = true +explicit_package_bases = true +exclude = ["tests", ".venv"] + +[tool.ruff] +line-length = 90 + +[tool.ruff.lint] +select = [ + "C", + "E", + "F", + "W", + "B", +] +ignore = [ + "E203", # line too long, handled by black + "E501", # do not perform function calls in argument defaults + "B028", # No explicit `stacklevel` keyword argument found +] diff --git a/scripts/ingest_joplin.py b/scripts/ingest_joplin.py index 0440d094..58e30013 100644 --- a/scripts/ingest_joplin.py +++ b/scripts/ingest_joplin.py @@ -1,4 +1,5 @@ """Ingest sample data during docker-compose""" + import json import sys from pathlib import Path diff --git a/setup.py b/setup.py index af3669dc..3110150e 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,5 @@ install_requires=install_requires, tests_require=extra_reqs["dev"], extras_require=extra_reqs, - entry_points={ - "console_scripts": ["stac-fastapi-pgstac=stac_fastapi.pgstac.app:run"] - }, + entry_points={"console_scripts": ["stac-fastapi-pgstac=stac_fastapi.pgstac.app:run"]}, ) diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 42586998..92c43089 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -46,8 +46,7 @@ if enabled_extensions := os.getenv("ENABLED_EXTENSIONS"): extensions = [ - extensions_map[extension_name] - for extension_name in enabled_extensions.split(",") + extensions_map[extension_name] for extension_name in enabled_extensions.split(",") ] else: extensions = list(extensions_map.values()) @@ -57,7 +56,7 @@ api = StacApi( settings=settings, extensions=extensions, - client=CoreCrudClient(post_request_model=post_request_model), + client=CoreCrudClient(post_request_model=post_request_model), # type: ignore response_class=ORJSONResponse, search_get_request_model=create_get_request_model(extensions), search_post_request_model=post_request_model, @@ -90,8 +89,8 @@ def run(): reload=settings.reload, root_path=os.getenv("UVICORN_ROOT_PATH", ""), ) - except ImportError: - raise RuntimeError("Uvicorn must be installed in order to use command") + except ImportError as e: + raise RuntimeError("Uvicorn must be installed in order to use command") from e if __name__ == "__main__": diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index e8cf1219..6301c43e 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -1,4 +1,5 @@ """Item crud client.""" + import re from typing import Any, Dict, List, Optional, Union from urllib.parse import unquote_plus, urljoin @@ -136,7 +137,7 @@ async def _get_base_item( return item - async def _search_base( + async def _search_base( # noqa: C901 self, search_request: PgstacSearch, request: Request, @@ -172,10 +173,10 @@ async def _search_base( req=search_request_json, ) items = await conn.fetchval(q, *p) - except InvalidDatetimeFormatError: + except InvalidDatetimeFormatError as e: raise InvalidQueryParameter( f"Datetime parameter {search_request.datetime} is invalid." - ) + ) from e next: Optional[str] = items.pop("next", None) prev: Optional[str] = items.pop("prev", None) @@ -207,8 +208,8 @@ async def _add_item_links( and all([collection_id, item_id]) ): feature["links"] = await ItemLinks( - collection_id=collection_id, - item_id=item_id, + collection_id=collection_id, # type: ignore + item_id=item_id, # type: ignore request=request, ).get_links(extra_links=feature.get("links")) @@ -255,7 +256,7 @@ async def item_collection( bbox: Optional[BBox] = None, datetime: Optional[DateTimeType] = None, limit: Optional[int] = None, - token: str = None, + token: Optional[str] = None, **kwargs, ) -> ItemCollection: """Get all items from a specific collection. @@ -340,7 +341,7 @@ async def post_search( item_collection = await self._search_base(search_request, request=request) return ItemCollection(**item_collection) - async def get_search( + async def get_search( # noqa: C901 self, request: Request, collections: Optional[List[str]] = None, @@ -432,5 +433,6 @@ async def get_search( except ValidationError as e: raise HTTPException( status_code=400, detail=f"Invalid parameters provided {e}" - ) + ) from e + return await self.post_search(search_request, request=request) diff --git a/stac_fastapi/pgstac/db.py b/stac_fastapi/pgstac/db.py index c70611b2..b684a138 100644 --- a/stac_fastapi/pgstac/db.py +++ b/stac_fastapi/pgstac/db.py @@ -2,7 +2,16 @@ import json from contextlib import asynccontextmanager, contextmanager -from typing import AsyncIterator, Callable, Dict, Generator, Literal, Union +from typing import ( + AsyncIterator, + Callable, + Dict, + Generator, + List, + Literal, + Optional, + Union, +) import attr import orjson @@ -36,7 +45,9 @@ async def con_init(conn): ConnectionGetter = Callable[[Request, Literal["r", "w"]], AsyncIterator[Connection]] -async def connect_to_db(app: FastAPI, get_conn: ConnectionGetter = None) -> None: +async def connect_to_db( + app: FastAPI, get_conn: Optional[ConnectionGetter] = None +) -> None: """Create connection pools & connection retriever on application.""" settings = app.state.settings if app.state.settings.testing: @@ -44,6 +55,7 @@ async def connect_to_db(app: FastAPI, get_conn: ConnectionGetter = None) -> None else: readpool = settings.reader_connection_string writepool = settings.writer_connection_string + db = DB() app.state.readpool = await db.create_pool(readpool, settings) app.state.writepool = await db.create_pool(writepool, settings) @@ -62,15 +74,13 @@ async def get_connection( readwrite: Literal["r", "w"] = "r", ) -> AsyncIterator[Connection]: """Retrieve connection from database conection pool.""" - pool = ( - request.app.state.writepool if readwrite == "w" else request.app.state.readpool - ) + pool = request.app.state.writepool if readwrite == "w" else request.app.state.readpool with translate_pgstac_errors(): async with pool.acquire() as conn: yield conn -async def dbfunc(conn: Connection, func: str, arg: Union[str, Dict]): +async def dbfunc(conn: Connection, func: str, arg: Union[str, Dict, List]): """Wrap PLPGSQL Functions. Keyword arguments: diff --git a/stac_fastapi/pgstac/extensions/filter.py b/stac_fastapi/pgstac/extensions/filter.py index 0d249a1f..148a952a 100644 --- a/stac_fastapi/pgstac/extensions/filter.py +++ b/stac_fastapi/pgstac/extensions/filter.py @@ -1,4 +1,5 @@ """Get Queryables.""" + from typing import Any, Optional from buildpg import render diff --git a/stac_fastapi/pgstac/models/links.py b/stac_fastapi/pgstac/models/links.py index 8ba435d4..0e6d9071 100644 --- a/stac_fastapi/pgstac/models/links.py +++ b/stac_fastapi/pgstac/models/links.py @@ -59,13 +59,19 @@ def resolve(self, url): def link_self(self) -> Dict: """Return the self link.""" - return dict(rel=Relations.self.value, type=MimeTypes.json.value, href=self.url) + return { + "rel": Relations.self.value, + "type": MimeTypes.json.value, + "href": self.url, + } def link_root(self) -> Dict: """Return the catalog root.""" - return dict( - rel=Relations.root.value, type=MimeTypes.json.value, href=self.base_url - ) + return { + "rel": Relations.root.value, + "type": MimeTypes.json.value, + "href": self.base_url, + } def create_links(self) -> List[Dict[str, Any]]: """Return all inferred links.""" @@ -124,13 +130,14 @@ def link_next(self) -> Optional[Dict[str, Any]]: method = self.request.method if method == "GET": href = merge_params(self.url, {"token": f"next:{self.next}"}) - link = dict( - rel=Relations.next.value, - type=MimeTypes.geojson.value, - method=method, - href=href, - ) + link = { + "rel": Relations.next.value, + "type": MimeTypes.geojson.value, + "method": method, + "href": href, + } return link + if method == "POST": return { "rel": Relations.next, @@ -148,12 +155,13 @@ def link_prev(self) -> Optional[Dict[str, Any]]: method = self.request.method if method == "GET": href = merge_params(self.url, {"token": f"prev:{self.prev}"}) - return dict( - rel=Relations.previous.value, - type=MimeTypes.geojson.value, - method=method, - href=href, - ) + return { + "rel": Relations.previous.value, + "type": MimeTypes.geojson.value, + "method": method, + "href": href, + } + if method == "POST": return { "rel": Relations.previous, @@ -173,11 +181,11 @@ class CollectionLinksBase(BaseLinks): def collection_link(self, rel: str = Relations.collection.value) -> Dict: """Create a link to a collection.""" - return dict( - rel=rel, - type=MimeTypes.json.value, - href=self.resolve(f"collections/{self.collection_id}"), - ) + return { + "rel": rel, + "type": MimeTypes.json.value, + "href": self.resolve(f"collections/{self.collection_id}"), + } @attr.s @@ -190,19 +198,19 @@ def link_self(self) -> Dict: def link_parent(self) -> Dict: """Create the `parent` link.""" - return dict( - rel=Relations.parent.value, - type=MimeTypes.json.value, - href=self.base_url, - ) + return { + "rel": Relations.parent.value, + "type": MimeTypes.json.value, + "href": self.base_url, + } def link_items(self) -> Dict: """Create the `item` link.""" - return dict( - rel="items", - type=MimeTypes.geojson.value, - href=self.resolve(f"collections/{self.collection_id}/items"), - ) + return { + "rel": "items", + "type": MimeTypes.geojson.value, + "href": self.resolve(f"collections/{self.collection_id}/items"), + } @attr.s @@ -211,11 +219,11 @@ class ItemCollectionLinks(CollectionLinksBase): def link_self(self) -> Dict: """Return the self link.""" - return dict( - rel=Relations.self.value, - type=MimeTypes.geojson.value, - href=self.resolve(f"collections/{self.collection_id}/items"), - ) + return { + "rel": Relations.self.value, + "type": MimeTypes.geojson.value, + "href": self.resolve(f"collections/{self.collection_id}/items"), + } def link_parent(self) -> Dict: """Create the `parent` link.""" @@ -234,11 +242,13 @@ class ItemLinks(CollectionLinksBase): def link_self(self) -> Dict: """Create the self link.""" - return dict( - rel=Relations.self.value, - type=MimeTypes.geojson.value, - href=self.resolve(f"collections/{self.collection_id}/items/{self.item_id}"), - ) + return { + "rel": Relations.self.value, + "type": MimeTypes.geojson.value, + "href": self.resolve( + f"collections/{self.collection_id}/items/{self.item_id}" + ), + } def link_parent(self) -> Dict: """Create the `parent` link.""" diff --git a/stac_fastapi/pgstac/transactions.py b/stac_fastapi/pgstac/transactions.py index 5e2edeb1..3687068a 100644 --- a/stac_fastapi/pgstac/transactions.py +++ b/stac_fastapi/pgstac/transactions.py @@ -28,7 +28,7 @@ class TransactionsClient(AsyncBaseTransactionsClient): """Transactions extension specific CRUD operations.""" - def _validate_id(self, id: str, settings: Settings) -> bool: + def _validate_id(self, id: str, settings: Settings): invalid_chars = settings.invalid_id_chars id_regex = "[" + "".join(re.escape(char) for char in invalid_chars) + "]" @@ -76,7 +76,7 @@ async def create_item( """Create item.""" if item["type"] == "FeatureCollection": valid_items = [] - for item in item["features"]: + for item in item["features"]: # noqa: B020 self._validate_item(request, item, collection_id) item["collection"] = collection_id valid_items.append(item) diff --git a/stac_fastapi/pgstac/types/base_item_cache.py b/stac_fastapi/pgstac/types/base_item_cache.py index 9b92e759..519f687a 100644 --- a/stac_fastapi/pgstac/types/base_item_cache.py +++ b/stac_fastapi/pgstac/types/base_item_cache.py @@ -1,4 +1,5 @@ """base_item_cache classes for pgstac fastapi.""" + import abc from typing import Any, Callable, Coroutine, Dict @@ -43,7 +44,7 @@ def __init__( request: Request, ): """Initialize the base item cache.""" - self._base_items = {} + self._base_items: Dict = {} super().__init__(fetch_base_item, request) async def get(self, collection_id: str): diff --git a/stac_fastapi/pgstac/utils.py b/stac_fastapi/pgstac/utils.py index f696ca51..21cfec1b 100644 --- a/stac_fastapi/pgstac/utils.py +++ b/stac_fastapi/pgstac/utils.py @@ -1,4 +1,5 @@ """stac-fastapi utility methods.""" + from datetime import datetime from typing import Any, Dict, Optional, Set, Union @@ -6,7 +7,7 @@ from stac_fastapi.types.stac import Item -def filter_fields( +def filter_fields( # noqa: C901 item: Union[Item, Dict[str, Any]], include: Optional[Set[str]] = None, exclude: Optional[Set[str]] = None, @@ -38,7 +39,7 @@ def include_fields( # key path indicates a sub-key to be included. Walk the dict # from the root key and get the full nested value to include. value = include_fields( - source[key_root], fields=set([".".join(key_path_parts[1:])]) + source[key_root], fields={".".join(key_path_parts[1:])} ) if isinstance(clean_item.get(key_root), dict): @@ -70,9 +71,7 @@ def exclude_fields(source: Dict[str, Any], fields: Optional[Set[str]]) -> None: if key_root in source: if isinstance(source[key_root], dict) and len(key_path_part) > 1: # Walk the nested path of this key to remove the leaf-key - exclude_fields( - source[key_root], fields=set([".".join(key_path_part[1:])]) - ) + exclude_fields(source[key_root], fields={".".join(key_path_part[1:])}) # If, after removing the leaf-key, the root is now an empty # dict, remove it entirely if not source[key_root]: @@ -93,7 +92,7 @@ def exclude_fields(source: Dict[str, Any], fields: Optional[Set[str]]) -> None: # If, after including all the specified fields, there are no included properties, # return just id and collection. if not clean_item: - return Item({"id": item.get(id), "collection": item.get("collection")}) + return Item({"id": item["id"], "collection": item["collection"]}) exclude_fields(clean_item, exclude) diff --git a/stac_fastapi/pgstac/version.py b/stac_fastapi/pgstac/version.py index 1b96f9f8..db1a4897 100644 --- a/stac_fastapi/pgstac/version.py +++ b/stac_fastapi/pgstac/version.py @@ -1,2 +1,3 @@ """library version.""" + __version__ = "2.4.11" diff --git a/tests/api/test_api.py b/tests/api/test_api.py index cdac75ff..02460ce8 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -87,16 +87,12 @@ async def test_get_features_self_link(app_client, load_test_collection): resp = await app_client.get(f"collections/{load_test_collection.id}/items") assert resp.status_code == 200 resp_json = resp.json() - self_link = next( - (link for link in resp_json["links"] if link["rel"] == "self"), None - ) + self_link = next((link for link in resp_json["links"] if link["rel"] == "self"), None) assert self_link is not None assert self_link["href"].endswith("/items") -async def test_get_feature_content_type( - app_client, load_test_collection, load_test_item -): +async def test_get_feature_content_type(app_client, load_test_collection, load_test_item): resp = await app_client.get( f"collections/{load_test_collection.id}/items/{load_test_item.id}" ) @@ -105,9 +101,7 @@ async def test_get_feature_content_type( async def test_api_headers(app_client): resp = await app_client.get("/api") - assert ( - resp.headers["content-type"] == "application/vnd.oai.openapi+json;version=3.0" - ) + assert resp.headers["content-type"] == "application/vnd.oai.openapi+json;version=3.0" assert resp.status_code == 200 @@ -117,9 +111,9 @@ async def test_core_router(api_client, app): method, path = core_route.split(" ") core_routes.add("{} {}".format(method, app.state.router_prefix + path)) - api_routes = set( - [f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes] - ) + api_routes = { + f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes + } assert not core_routes - api_routes @@ -136,9 +130,9 @@ async def test_transactions_router(api_client, app): method, path = transaction_route.split(" ") transaction_routes.add("{} {}".format(method, app.state.router_prefix + path)) - api_routes = set( - [f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes] - ) + api_routes = { + f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes + } assert not transaction_routes - api_routes @@ -230,9 +224,7 @@ async def test_app_query_extension_gt(load_test_data, app_client, load_test_coll assert len(resp_json["features"]) == 0 -async def test_app_query_extension_gte( - load_test_data, app_client, load_test_collection -): +async def test_app_query_extension_gte(load_test_data, app_client, load_test_collection): coll = load_test_collection item = load_test_data("test_item.json") resp = await app_client.post(f"/collections/{coll.id}/items", json=item) @@ -333,20 +325,18 @@ async def test_app_search_response(load_test_data, app_client, load_test_collect assert resp_json.get("stac_extensions") is None -async def test_search_point_intersects( - load_test_data, app_client, load_test_collection -): +async def test_search_point_intersects(load_test_data, app_client, load_test_collection): coll = load_test_collection item = load_test_data("test_item.json") resp = await app_client.post(f"/collections/{coll.id}/items", json=item) assert resp.status_code == 200 - new_coordinates = list() + new_coordinates = [] for coordinate in item["geometry"]["coordinates"][0]: new_coordinates.append([coordinate[0] * -1, coordinate[1] * -1]) item["id"] = "test-item-other-hemispheres" item["geometry"]["coordinates"] = [new_coordinates] - item["bbox"] = list(value * -1 for value in item["bbox"]) + item["bbox"] = [value * -1 for value in item["bbox"]] resp = await app_client.post(f"/collections/{coll.id}/items", json=item) assert resp.status_code == 200 @@ -391,9 +381,7 @@ async def test_search_line_string_intersects( @pytest.mark.asyncio -async def test_landing_forwarded_header( - load_test_data, app_client, load_test_collection -): +async def test_landing_forwarded_header(load_test_data, app_client, load_test_collection): coll = load_test_collection item = load_test_data("test_item.json") await app_client.post(f"/collections/{coll.id}/items", json=item) @@ -412,9 +400,7 @@ async def test_landing_forwarded_header( @pytest.mark.asyncio -async def test_search_forwarded_header( - load_test_data, app_client, load_test_collection -): +async def test_search_forwarded_header(load_test_data, app_client, load_test_collection): coll = load_test_collection item = load_test_data("test_item.json") await app_client.post(f"/collections/{coll.id}/items", json=item) @@ -612,7 +598,7 @@ async def test_sorting_and_paging(app_client, load_test_collection, direction: s assert response.status_code == 200 async def search(query: Dict[str, Any]) -> List[Item]: - items: List[Item] = list() + items: List[Item] = [] while True: response = await app_client.post("/search", json=query) json = response.json() @@ -649,7 +635,7 @@ def wrap() -> ( ] ): def decorator( - fn: Callable[..., Coroutine[Any, Any, T]] + fn: Callable[..., Coroutine[Any, Any, T]], ) -> Callable[..., Coroutine[Any, Any, T]]: async def _wrapper(*args: Any, **kwargs: Any) -> T: request: Optional[Request] = kwargs.get("request") @@ -668,9 +654,7 @@ class Client(CoreCrudClient): async def get_collection( self, collection_id: str, request: Request, **kwargs ) -> stac_types.Item: - return await super().get_collection( - collection_id, request=request, **kwargs - ) + return await super().get_collection(collection_id, request=request, **kwargs) settings = Settings( postgres_user=database.user, diff --git a/tests/resources/test_collection.py b/tests/resources/test_collection.py index 980428ef..6b8795b0 100644 --- a/tests/resources/test_collection.py +++ b/tests/resources/test_collection.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, Optional import pystac import pytest @@ -20,9 +20,7 @@ async def test_create_collection(app_client, load_test_data: Callable): get_coll = Collection.parse_obj(resp.json()) assert post_coll.dict(exclude={"links"}) == get_coll.dict(exclude={"links"}) - post_self_link = next( - (link for link in post_coll.links if link.rel == "self"), None - ) + post_self_link = next((link for link in post_coll.links if link.rel == "self"), None) get_self_link = next((link for link in get_coll.links if link.rel == "self"), None) assert post_self_link is not None and get_self_link is not None assert post_self_link.href == get_self_link.href @@ -153,7 +151,8 @@ async def test_returns_valid_links_in_collections(app_client, load_test_data): # Find collection in list by ID single_coll = next(coll for coll in collections if coll["id"] == in_json["id"]) is_coll_from_list_valid = False - single_coll_mocked_link = dict() + + single_coll_mocked_link: Optional[pystac.Collection] = None if single_coll is not None: single_coll_mocked_link = pystac.Collection.from_dict( single_coll, root=mock_root, preserve_dict=False diff --git a/tests/resources/test_item.py b/tests/resources/test_item.py index 5fd4b044..0f268839 100644 --- a/tests/resources/test_item.py +++ b/tests/resources/test_item.py @@ -82,9 +82,7 @@ async def test_create_item(app_client, load_test_data: Callable, load_test_colle get_item = Item.parse_obj(resp.json()) assert in_item.dict(exclude={"links"}) == get_item.dict(exclude={"links"}) - post_self_link = next( - (link for link in post_item.links if link.rel == "self"), None - ) + post_self_link = next((link for link in post_item.links if link.rel == "self"), None) get_self_link = next((link for link in get_item.links if link.rel == "self"), None) assert post_self_link is not None and get_self_link is not None assert post_self_link.href == get_self_link.href @@ -371,7 +369,7 @@ async def test_item_search_by_id_post(app_client, load_test_data, load_test_coll assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == len(ids) - assert set([feat["id"] for feat in resp_json["features"]]) == set(ids) + assert {feat["id"] for feat in resp_json["features"]} == set(ids) async def test_item_search_by_id_no_results_post( @@ -533,7 +531,7 @@ async def test_item_search_by_id_get(app_client, load_test_data, load_test_colle assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == len(ids) - assert set([feat["id"] for feat in resp_json["features"]]) == set(ids) + assert {feat["id"] for feat in resp_json["features"]} == set(ids) async def test_item_search_bbox_get(app_client, load_test_data, load_test_collection): @@ -915,7 +913,7 @@ async def test_pagination_item_collection( ids = [] # Ingest 5 items - for idx in range(5): + for _ in range(5): uid = str(uuid.uuid4()) test_item["id"] = uid resp = await app_client.post( @@ -934,14 +932,13 @@ async def test_pagination_item_collection( idx += 1 page_data = page.json() item_ids.append(page_data["features"][0]["id"]) - nextlink = [ - link["href"] for link in page_data["links"] if link["rel"] == "next" - ] + nextlink = [link["href"] for link in page_data["links"] if link["rel"] == "next"] if len(nextlink) < 1: break + page = await app_client.get(nextlink.pop()) - if idx >= 10: - assert False + + assert idx < 10 # Our limit is 1 so we expect len(ids) number of requests before we run out of pages assert idx == len(ids) @@ -956,7 +953,7 @@ async def test_pagination_post(app_client, load_test_data, load_test_collection) ids = [] # Ingest 5 items - for idx in range(5): + for _ in range(5): uid = str(uuid.uuid4()) test_item["id"] = uid resp = await app_client.post( @@ -981,12 +978,12 @@ async def test_pagination_post(app_client, load_test_data, load_test_collection) next_link = list(filter(lambda link: link["rel"] == "next", page_data["links"])) if not next_link: break + # Merge request bodies request_body.update(next_link[0]["body"]) page = await app_client.post("/search", json=request_body) - if idx > 10: - assert False + assert idx < 10 # Our limit is 1 so we expect len(ids) number of requests before we run out of pages assert idx == len(ids) @@ -1003,7 +1000,7 @@ async def test_pagination_token_idempotent( ids = [] # Ingest 5 items - for idx in range(5): + for _ in range(5): uid = str(uuid.uuid4()) test_item["id"] = uid resp = await app_client.post( @@ -1133,7 +1130,7 @@ async def test_field_extension_include_multiple_subkeys( resp_json = resp.json() resp_prop_keys = resp_json["features"][0]["properties"].keys() - assert set(resp_prop_keys) == set(["width", "height"]) + assert set(resp_prop_keys) == {"width", "height"} async def test_field_extension_include_multiple_deeply_nested_subkeys( @@ -1147,8 +1144,8 @@ async def test_field_extension_include_multiple_deeply_nested_subkeys( resp_json = resp.json() resp_assets = resp_json["features"][0]["assets"] - assert set(resp_assets.keys()) == set(["ANG"]) - assert set(resp_assets["ANG"].keys()) == set(["type", "href"]) + assert set(resp_assets.keys()) == {"ANG"} + assert set(resp_assets["ANG"].keys()) == {"type", "href"} async def test_field_extension_exclude_multiple_deeply_nested_subkeys(