From 7c028600dad85e667c03ed1fdb164dbf0566e86b Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Tue, 25 Jan 2022 10:47:14 -0800 Subject: [PATCH 01/25] feature/1 added runtime configuration of CORS --- .gitignore | 5 +- stac_fastapi/api/stac_fastapi/api/app.py | 9 ++- .../api/stac_fastapi/api/middleware.py | 55 +++++++++++++- stac_fastapi/pgstac/tests/api/cors_support.py | 23 ++++++ stac_fastapi/pgstac/tests/api/test_api.py | 73 +++++++++++++++++++ stac_fastapi/pgstac/tests/conftest.py | 25 +++---- .../sqlalchemy/tests/api/cors_support.py | 23 ++++++ stac_fastapi/sqlalchemy/tests/api/test_api.py | 71 ++++++++++++++++++ stac_fastapi/sqlalchemy/tests/conftest.py | 15 +++- 9 files changed, 277 insertions(+), 22 deletions(-) create mode 100644 stac_fastapi/pgstac/tests/api/cors_support.py create mode 100644 stac_fastapi/sqlalchemy/tests/api/cors_support.py diff --git a/.gitignore b/.gitignore index 407d0fe9a..1cc0fbf51 100644 --- a/.gitignore +++ b/.gitignore @@ -129,4 +129,7 @@ docs/api/* .envrc # Virtualenv -venv \ No newline at end of file +venv + +# IDE +.vscode \ No newline at end of file diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index 53012cc51..fbb5991be 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -13,6 +13,7 @@ from starlette.responses import JSONResponse, Response from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers +from stac_fastapi.api.middleware import MiddlewareConfig, append_runtime_middlewares from stac_fastapi.api.models import ( APIRequest, CollectionUri, @@ -87,7 +88,9 @@ class StacApi: ) pagination_extension = attr.ib(default=TokenPaginationExtension) response_class: Type[Response] = attr.ib(default=JSONResponse) - middlewares: List = attr.ib(default=attr.Factory(lambda: [BrotliMiddleware])) + middlewares: List[MiddlewareConfig] = attr.ib( + default=attr.Factory(lambda: [MiddlewareConfig(middleware=BrotliMiddleware)]) + ) def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]: """Get an extension. @@ -376,5 +379,5 @@ def __attrs_post_init__(self): self.app.openapi = self.customize_openapi # add middlewares - for middleware in self.middlewares: - self.app.add_middleware(middleware) + for entry in append_runtime_middlewares(self.middlewares): + self.app.add_middleware(entry.middleware, **entry.config) diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index acb00915b..e7ed896b0 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -1,12 +1,18 @@ """api middleware.""" -from typing import Callable +from json import loads +from logging import getLogger +from os import environ, path +from typing import Any, Callable, Dict, Final, List, Optional from fastapi import APIRouter, FastAPI +from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.routing import Match +logger: Final = getLogger(__file__) + def router_middleware(app: FastAPI, router: APIRouter): """Add middleware to a specific router, assumes no router prefix.""" @@ -29,3 +35,50 @@ async def _middleware(request: Request, call_next): return func return deco + + +class MiddlewareConfig: + """Represents a middleware class plus any configuration detail.""" + + def __init__(self, middleware: Any, config: Optional[Dict[str, Any]] = None): + """Defaults config to empty dictionary if not provided.""" + self.middleware = middleware + self.config = {} if config is None else config + + +def append_runtime_middlewares( + middlewares: List[MiddlewareConfig], +) -> List[MiddlewareConfig]: + """Add any middlewares specified via environment variable and configure if appropriate.""" + extended_middlewares = middlewares.copy() + has_cors_middleware = ( + len( + [ + entry + for entry in middlewares + if isinstance(entry.middleware, CORSMiddleware) + ] + ) + > 0 + ) + if not has_cors_middleware: + cors_config_location_key: Final = "CORS_CONFIG_LOCATION" + if cors_config_location_key in environ: + cors_config_path = environ[cors_config_location_key] + logger.info(f"looking for CORS config file at {cors_config_path}") + if path.exists(cors_config_path): + try: + with open(cors_config_path, "r") as cors_config_file: + cors_config = loads("".join(cors_config_file.readlines())) + extended_middlewares.append( + MiddlewareConfig(CORSMiddleware, cors_config) + ) + logger.debug(f"loaded CORS config {cors_config}") + except ValueError as e: + logger.error(f"error parsing JSON at {cors_config_path}: {e}") + except OSError as e: + logger.error(f"error reading {cors_config_path}: {e}") + else: + logger.warning(f"CORS config not found at {cors_config_path}") + + return extended_middlewares diff --git a/stac_fastapi/pgstac/tests/api/cors_support.py b/stac_fastapi/pgstac/tests/api/cors_support.py new file mode 100644 index 000000000..9f891a8d6 --- /dev/null +++ b/stac_fastapi/pgstac/tests/api/cors_support.py @@ -0,0 +1,23 @@ +from json import dumps +from os import environ, fdopen, path, sep +from tempfile import mkstemp +from typing import Final + +cors_config_location_key: Final = "CORS_CONFIG_LOCATION" +cors_permit_origin: Final = "http://cors.pass" +cors_deny_origin: Final = "http://cors.fail" + + +def cors_enable(): + tmp_file, tmp_filename = mkstemp() + with fdopen(tmp_file, "w") as f: + f.write(dumps({"allow_origins": [cors_permit_origin]})) + environ[cors_config_location_key] = tmp_filename + + +def cors_disable() -> None: + environ.pop(cors_config_location_key, None) + + +def cors_missing(): + environ[cors_config_location_key] = path.join(path.abspath(sep), "missing.file") diff --git a/stac_fastapi/pgstac/tests/api/test_api.py b/stac_fastapi/pgstac/tests/api/test_api.py index 9ab4c2c07..a5e7ed073 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -1,6 +1,16 @@ from datetime import datetime, timedelta +from http import HTTPStatus +from os import environ import pytest +from tests.api.cors_support import ( + cors_config_location_key, + cors_deny_origin, + cors_disable, + cors_enable, + cors_missing, + cors_permit_origin, +) STAC_CORE_ROUTES = [ "GET /", @@ -23,6 +33,10 @@ ] +def teardown_function(): + environ.pop(cors_config_location_key, None) + + @pytest.mark.asyncio async def test_post_search_content_type(app_client): params = {"limit": 1} @@ -302,3 +316,62 @@ async def test_search_line_string_intersects( assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("app_client", [{"setup_func": cors_disable}], indirect=True) +async def test_without_cors(app_client): + resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) +async def test_with_match_cors(app_client): + resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == cors_permit_origin + + +@pytest.mark.asyncio +@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) +async def test_with_mismatch_cors(app_client): + resp = await app_client.get("/", headers={"Origin": cors_deny_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("app_client", [{"setup_func": cors_missing}], indirect=True) +async def test_with_missing_config(app_client): + resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) diff --git a/stac_fastapi/pgstac/tests/conftest.py b/stac_fastapi/pgstac/tests/conftest.py index d2fd9a00d..929860c75 100644 --- a/stac_fastapi/pgstac/tests/conftest.py +++ b/stac_fastapi/pgstac/tests/conftest.py @@ -1,7 +1,6 @@ import asyncio import json import os -import time from typing import Callable, Dict import asyncpg @@ -92,8 +91,7 @@ async def pgstac(pg): await conn.close() -@pytest.fixture(scope="session") -def api_client(pg): +def _api_client_provider(): print("creating client with settings") extensions = [ @@ -118,24 +116,25 @@ def api_client(pg): return api -@pytest.mark.asyncio @pytest.fixture(scope="session") -async def app(api_client): - time.time() - app = api_client.app - await connect_to_db(app) - - yield app - - await close_db_connection(app) +def api_client(pg): + return _api_client_provider() @pytest.mark.asyncio @pytest.fixture(scope="session") -async def app_client(app): +async def app_client(pg, request): + setup_func = request.param.get("setup_func") if hasattr(request, "param") else None + if setup_func is not None: + setup_func() + app = _api_client_provider().app async with AsyncClient(app=app, base_url="http://test") as c: + await connect_to_db(app) + yield c + await close_db_connection(app) + @pytest.fixture def load_test_data() -> Callable[[str], Dict]: diff --git a/stac_fastapi/sqlalchemy/tests/api/cors_support.py b/stac_fastapi/sqlalchemy/tests/api/cors_support.py new file mode 100644 index 000000000..9f891a8d6 --- /dev/null +++ b/stac_fastapi/sqlalchemy/tests/api/cors_support.py @@ -0,0 +1,23 @@ +from json import dumps +from os import environ, fdopen, path, sep +from tempfile import mkstemp +from typing import Final + +cors_config_location_key: Final = "CORS_CONFIG_LOCATION" +cors_permit_origin: Final = "http://cors.pass" +cors_deny_origin: Final = "http://cors.fail" + + +def cors_enable(): + tmp_file, tmp_filename = mkstemp() + with fdopen(tmp_file, "w") as f: + f.write(dumps({"allow_origins": [cors_permit_origin]})) + environ[cors_config_location_key] = tmp_filename + + +def cors_disable() -> None: + environ.pop(cors_config_location_key, None) + + +def cors_missing(): + environ[cors_config_location_key] = path.join(path.abspath(sep), "missing.file") diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index b5c531cfb..9433abbf2 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -1,4 +1,16 @@ from datetime import datetime, timedelta +from http import HTTPStatus +from os import environ + +import pytest +from tests.api.cors_support import ( + cors_config_location_key, + cors_deny_origin, + cors_disable, + cors_enable, + cors_missing, + cors_permit_origin, +) from ..conftest import MockStarletteRequest @@ -23,6 +35,10 @@ ] +def teardown_function(): + environ.pop(cors_config_location_key, None) + + def test_post_search_content_type(app_client): params = {"limit": 1} resp = app_client.post("search", json=params) @@ -285,3 +301,58 @@ def test_search_line_string_intersects( assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == 1 + + +@pytest.mark.parametrize("app_client", [{"setup_func": cors_disable}], indirect=True) +def test_without_cors(app_client): + resp = app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) + + +@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) +def test_with_match_cors(app_client): + resp = app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == cors_permit_origin + + +@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) +def test_with_mismatch_cors(app_client): + resp = app_client.get("/", headers={"Origin": cors_deny_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) + + +@pytest.mark.parametrize("app_client", [{"setup_func": cors_missing}], indirect=True) +def test_with_missing_config(app_client): + resp = app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) diff --git a/stac_fastapi/sqlalchemy/tests/conftest.py b/stac_fastapi/sqlalchemy/tests/conftest.py index 7abd9150f..4012eeb5a 100644 --- a/stac_fastapi/sqlalchemy/tests/conftest.py +++ b/stac_fastapi/sqlalchemy/tests/conftest.py @@ -104,8 +104,7 @@ def postgres_bulk_transactions(db_session): return BulkTransactionsClient(session=db_session) -@pytest.fixture -def api_client(db_session): +def _api_client_provider(db_session): settings = SqlalchemySettings() extensions = [ TransactionExtension( @@ -146,9 +145,17 @@ def api_client(db_session): @pytest.fixture -def app_client(api_client, load_test_data, postgres_transactions): +def api_client(db_session): + return _api_client_provider(db_session) + + +@pytest.fixture +def app_client(db_session, load_test_data, postgres_transactions, request): + setup_func = request.param.get("setup_func") if hasattr(request, "param") else None + if setup_func is not None: + setup_func() coll = load_test_data("test_collection.json") postgres_transactions.create_collection(coll, request=MockStarletteRequest) - with TestClient(api_client.app) as test_app: + with TestClient(_api_client_provider(db_session).app) as test_app: yield test_app From 602e7b57d3dde1e8f87930e39247bafd2da0130a Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Wed, 26 Jan 2022 13:44:00 -0800 Subject: [PATCH 02/25] feature/1 updated documentation --- CHANGES.md | 2 ++ docs/tips-and-tricks.md | 18 ++++++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 121257512..1f3a4fa3d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -4,6 +4,8 @@ ### Added +* Added ability to configure CORS middleware via JSON configuration file and environment variable, rather than having to modify code. + ### Changed ### Removed diff --git a/docs/tips-and-tricks.md b/docs/tips-and-tricks.md index 3d4c9ac0f..997edfd0e 100644 --- a/docs/tips-and-tricks.md +++ b/docs/tips-and-tricks.md @@ -5,16 +5,26 @@ This page contains a few 'tips and tricks' for getting stac-fastapi working in v CORS (Cross-Origin Resource Sharing) support may be required to use stac-fastapi in certain situations. For example, if you are running [stac-browser](https://github.com/radiantearth/stac-browser) to browse the STAC catalog created by stac-fastapi, then you will need to enable CORS support. -To do this, edit `stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py` (or the equivalent in the `pgstac` folder) and add the following import: +To do this, create a JSON configuration file whose schema matches the options described in the [FastAPI docs](https://fastapi.tiangolo.com/tutorial/cors/), e.g. ``` -from fastapi.middleware.cors import CORSMiddleware +{ + "allow_origins": ["*"], + "allow_methods": ["*"] +} ``` -and then edit the `api = StacApi(...` call to add the following parameter: +Deploy this file to a location accessible by stac-fastapi, e.g. in Dockerfile ``` -middlewares=[lambda app: CORSMiddleware(app, allow_origins=["*"])] +RUN mkdir /config +COPY cors.json /config/cors.json +``` + +Set an environment variable `CORS_CONFIG_LOCATION` pointing to this file, e.g. in Dockerfile + +``` +ENV CORS_CONFIG_LOCATION=/config/cors.json ``` If needed, you can edit the `allow_origins` parameter to only allow CORS requests from specific origins. From d2bc362c6a641f191744881e4f08ddd20014570f Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Wed, 26 Jan 2022 13:53:54 -0800 Subject: [PATCH 03/25] feature/1 doc updates --- stac_fastapi/api/stac_fastapi/api/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index fbb5991be..69834dce5 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -89,7 +89,7 @@ class StacApi: pagination_extension = attr.ib(default=TokenPaginationExtension) response_class: Type[Response] = attr.ib(default=JSONResponse) middlewares: List[MiddlewareConfig] = attr.ib( - default=attr.Factory(lambda: [MiddlewareConfig(middleware=BrotliMiddleware)]) + default=attr.Factory(lambda: [MiddlewareConfig(BrotliMiddleware)]) ) def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]: From 1e4edd1aa7ef3398823e4872356cd776557d5599 Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Thu, 27 Jan 2022 09:08:59 -0800 Subject: [PATCH 04/25] feature/1 PR feedback and additional test --- .../api/stac_fastapi/api/middleware.py | 37 ++++++++++++------- stac_fastapi/pgstac/tests/api/test_api.py | 33 +++++++++++++++++ stac_fastapi/pgstac/tests/conftest.py | 20 ++++++---- stac_fastapi/sqlalchemy/tests/api/test_api.py | 32 ++++++++++++++++ stac_fastapi/sqlalchemy/tests/conftest.py | 22 ++++++++--- 5 files changed, 116 insertions(+), 28 deletions(-) diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index e7ed896b0..f1b292690 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -47,38 +47,47 @@ def __init__(self, middleware: Any, config: Optional[Dict[str, Any]] = None): def append_runtime_middlewares( - middlewares: List[MiddlewareConfig], + existing_middlewares: List[MiddlewareConfig], ) -> List[MiddlewareConfig]: """Add any middlewares specified via environment variable and configure if appropriate.""" - extended_middlewares = middlewares.copy() + return existing_middlewares + [ + addition + for addition in [_append_cors_middleware(existing_middlewares)] + if addition is not None + ] + + +def _append_cors_middleware( + existing_middlewares: List[MiddlewareConfig], +) -> Optional[MiddlewareConfig]: has_cors_middleware = ( len( [ entry - for entry in middlewares - if isinstance(entry.middleware, CORSMiddleware) + for entry in existing_middlewares + if entry.middleware == CORSMiddleware ] ) > 0 ) - if not has_cors_middleware: - cors_config_location_key: Final = "CORS_CONFIG_LOCATION" - if cors_config_location_key in environ: - cors_config_path = environ[cors_config_location_key] + cors_config_location_key: Final = "CORS_CONFIG_LOCATION" + if cors_config_location_key in environ: + cors_config_path = environ[cors_config_location_key] + if has_cors_middleware: + logger.warning( + f"CORSMiddleware already configured; ignoring config at {cors_config_path}" + ) + else: logger.info(f"looking for CORS config file at {cors_config_path}") if path.exists(cors_config_path): try: with open(cors_config_path, "r") as cors_config_file: - cors_config = loads("".join(cors_config_file.readlines())) - extended_middlewares.append( - MiddlewareConfig(CORSMiddleware, cors_config) - ) + cors_config = loads(cors_config_file.read()) logger.debug(f"loaded CORS config {cors_config}") + return MiddlewareConfig(CORSMiddleware, cors_config) except ValueError as e: logger.error(f"error parsing JSON at {cors_config_path}: {e}") except OSError as e: logger.error(f"error reading {cors_config_path}: {e}") else: logger.warning(f"CORS config not found at {cors_config_path}") - - return extended_middlewares diff --git a/stac_fastapi/pgstac/tests/api/test_api.py b/stac_fastapi/pgstac/tests/api/test_api.py index a5e7ed073..2485cb07b 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -3,6 +3,7 @@ from os import environ import pytest +from fastapi.middleware.cors import CORSMiddleware from tests.api.cors_support import ( cors_config_location_key, cors_deny_origin, @@ -12,6 +13,8 @@ cors_permit_origin, ) +from stac_fastapi.api.middleware import MiddlewareConfig + STAC_CORE_ROUTES = [ "GET /", "GET /collections", @@ -375,3 +378,33 @@ async def test_with_missing_config(app_client): ) == 0 ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "app_client", + [ + { + "setup_func": cors_enable, + "middleware_configs": [ + MiddlewareConfig( + CORSMiddleware, {"allow_origins": ["http://different.origin"]} + ) + ], + } + ], + indirect=True, +) +async def test_with_existing_cors(app_client): + resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) diff --git a/stac_fastapi/pgstac/tests/conftest.py b/stac_fastapi/pgstac/tests/conftest.py index 929860c75..aa2729890 100644 --- a/stac_fastapi/pgstac/tests/conftest.py +++ b/stac_fastapi/pgstac/tests/conftest.py @@ -1,7 +1,7 @@ import asyncio import json import os -from typing import Callable, Dict +from typing import Callable, Dict, Optional import asyncpg import pytest @@ -11,6 +11,7 @@ from stac_pydantic import Collection, Item from stac_fastapi.api.app import StacApi +from stac_fastapi.api.middleware import MiddlewareConfig from stac_fastapi.api.models import create_get_request_model, create_post_request_model from stac_fastapi.extensions.core import ( FieldsExtension, @@ -91,7 +92,7 @@ async def pgstac(pg): await conn.close() -def _api_client_provider(): +def _api_client_provider(middleware_configs: Optional[MiddlewareConfig] = []): print("creating client with settings") extensions = [ @@ -111,6 +112,7 @@ def _api_client_provider(): search_get_request_model=create_get_request_model(extensions), search_post_request_model=post_request_model, response_class=ORJSONResponse, + middlewares=middleware_configs, ) return api @@ -124,15 +126,17 @@ def api_client(pg): @pytest.mark.asyncio @pytest.fixture(scope="session") async def app_client(pg, request): - setup_func = request.param.get("setup_func") if hasattr(request, "param") else None - if setup_func is not None: - setup_func() - app = _api_client_provider().app + # support custom behaviours driven by fixture caller + middleware_configs = [] + if hasattr(request, "param"): + setup_func = request.param.get("setup_func") + if setup_func is not None: + setup_func() + middleware_configs = request.param.get("middleware_configs", []) + app = _api_client_provider(middleware_configs=middleware_configs).app async with AsyncClient(app=app, base_url="http://test") as c: await connect_to_db(app) - yield c - await close_db_connection(app) diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index 9433abbf2..8eda6157a 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -3,6 +3,7 @@ from os import environ import pytest +from fastapi.middleware.cors import CORSMiddleware from tests.api.cors_support import ( cors_config_location_key, cors_deny_origin, @@ -12,6 +13,8 @@ cors_permit_origin, ) +from stac_fastapi.api.middleware import MiddlewareConfig + from ..conftest import MockStarletteRequest STAC_CORE_ROUTES = [ @@ -356,3 +359,32 @@ def test_with_missing_config(app_client): ) == 0 ) + + +@pytest.mark.parametrize( + "app_client", + [ + { + "setup_func": cors_enable, + "middleware_configs": [ + MiddlewareConfig( + CORSMiddleware, {"allow_origins": ["http://different.origin"]} + ) + ], + } + ], + indirect=True, +) +async def test_with_existing_cors(app_client): + resp = app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) diff --git a/stac_fastapi/sqlalchemy/tests/conftest.py b/stac_fastapi/sqlalchemy/tests/conftest.py index 4012eeb5a..2f1889327 100644 --- a/stac_fastapi/sqlalchemy/tests/conftest.py +++ b/stac_fastapi/sqlalchemy/tests/conftest.py @@ -1,11 +1,12 @@ import json import os -from typing import Callable, Dict +from typing import Callable, Dict, Optional import pytest from starlette.testclient import TestClient from stac_fastapi.api.app import StacApi +from stac_fastapi.api.middleware import MiddlewareConfig from stac_fastapi.api.models import create_request_model from stac_fastapi.extensions.core import ( ContextExtension, @@ -104,7 +105,9 @@ def postgres_bulk_transactions(db_session): return BulkTransactionsClient(session=db_session) -def _api_client_provider(db_session): +def _api_client_provider( + db_session, middleware_configs: Optional[MiddlewareConfig] = [] +): settings = SqlalchemySettings() extensions = [ TransactionExtension( @@ -141,6 +144,7 @@ def _api_client_provider(db_session): extensions=extensions, search_get_request_model=get_request_model, search_post_request_model=post_request_model, + middlewares=middleware_configs, ) @@ -151,11 +155,17 @@ def api_client(db_session): @pytest.fixture def app_client(db_session, load_test_data, postgres_transactions, request): - setup_func = request.param.get("setup_func") if hasattr(request, "param") else None - if setup_func is not None: - setup_func() + # support custom behaviours driven by fixture caller + middleware_configs = [] + if hasattr(request, "param"): + setup_func = request.param.get("setup_func") + if setup_func is not None: + setup_func() + middleware_configs = request.param.get("middleware_configs", []) coll = load_test_data("test_collection.json") postgres_transactions.create_collection(coll, request=MockStarletteRequest) - with TestClient(_api_client_provider(db_session).app) as test_app: + with TestClient( + _api_client_provider(db_session, middleware_configs=middleware_configs).app + ) as test_app: yield test_app From 15c08665b41c47d129cfd910b74154930081d3fe Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Fri, 28 Jan 2022 16:12:13 -0800 Subject: [PATCH 05/25] feature/1 removed unwanted async on test --- stac_fastapi/sqlalchemy/tests/api/test_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index 8eda6157a..63f5132c6 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -375,7 +375,7 @@ def test_with_missing_config(app_client): ], indirect=True, ) -async def test_with_existing_cors(app_client): +def test_with_existing_cors(app_client): resp = app_client.get("/", headers={"Origin": cors_permit_origin}) assert resp.status_code == HTTPStatus.OK assert ( From 1a3c55b078eac3de381f66f3c7598728bc5ebbcf Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Tue, 1 Feb 2022 11:27:39 -0800 Subject: [PATCH 06/25] feature/1 updated with PR feedback from stac-fastapi --- stac_fastapi/api/stac_fastapi/api/app.py | 10 +- stac_fastapi/api/stac_fastapi/api/config.py | 51 +++++++ .../api/stac_fastapi/api/middleware.py | 130 ++++++++++-------- stac_fastapi/pgstac/tests/api/cors_support.py | 52 +++++-- stac_fastapi/pgstac/tests/api/test_api.py | 115 +++++----------- stac_fastapi/pgstac/tests/conftest.py | 10 +- .../sqlalchemy/tests/api/cors_support.py | 52 +++++-- stac_fastapi/sqlalchemy/tests/api/test_api.py | 114 +++++---------- stac_fastapi/sqlalchemy/tests/conftest.py | 14 +- 9 files changed, 284 insertions(+), 264 deletions(-) diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index 69834dce5..b525552e2 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -13,7 +13,7 @@ from starlette.responses import JSONResponse, Response from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers -from stac_fastapi.api.middleware import MiddlewareConfig, append_runtime_middlewares +from stac_fastapi.api.middleware import CORSMiddleware from stac_fastapi.api.models import ( APIRequest, CollectionUri, @@ -88,8 +88,8 @@ class StacApi: ) pagination_extension = attr.ib(default=TokenPaginationExtension) response_class: Type[Response] = attr.ib(default=JSONResponse) - middlewares: List[MiddlewareConfig] = attr.ib( - default=attr.Factory(lambda: [MiddlewareConfig(BrotliMiddleware)]) + middlewares: List = attr.ib( + default=attr.Factory(lambda: [BrotliMiddleware, CORSMiddleware]) ) def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]: @@ -379,5 +379,5 @@ def __attrs_post_init__(self): self.app.openapi = self.customize_openapi # add middlewares - for entry in append_runtime_middlewares(self.middlewares): - self.app.add_middleware(entry.middleware, **entry.config) + for middleware in self.middlewares: + self.app.add_middleware(middleware) diff --git a/stac_fastapi/api/stac_fastapi/api/config.py b/stac_fastapi/api/stac_fastapi/api/config.py index 3a423e45d..9f1b8c90a 100644 --- a/stac_fastapi/api/stac_fastapi/api/config.py +++ b/stac_fastapi/api/stac_fastapi/api/config.py @@ -1,5 +1,11 @@ """Application settings.""" import enum +import re +from logging import getLogger +from os import environ +from typing import Final, Sequence + +logger: Final = getLogger(__file__) # TODO: Move to stac-pydantic @@ -22,3 +28,48 @@ class AddOns(enum.Enum): """Enumeration of available third party add ons.""" bulk_transaction = "bulk-transaction" + + +def env_to_sequence( + env_var: str, default: Sequence[str], sequence_separator: str = "|" +) -> Sequence[str]: + """Retrieve a sequence of values from an env var string, or default if missing.""" + if env_var in environ: + if re.search(re.escape(sequence_separator), environ[env_var]): + return tuple( + [part for part in environ[env_var].split(sequence_separator) if part] + ) + else: + return (environ[env_var],) + else: + return default + + +def env_to_str(env_var: str, default: str) -> str: + """Retrieve a string from an env var, or default if missing.""" + if env_var in environ: + return environ[env_var] + else: + return default + + +def env_to_bool(env_var: str, default: bool) -> bool: + """Retrieve a bool from an env var, or default if missing.""" + if env_var in environ: + if re.match("^(true|1)$", environ[env_var], re.IGNORECASE): + return True + if re.match("^(false|0)$", environ[env_var], re.IGNORECASE): + return False + logger.warning(f"{env_var} set but not a valid bool") + return default + + +def env_to_int(env_var: str, default: int) -> int: + """Retrieve an int from an env var, or default if missing.""" + if env_var in environ: + value = environ[env_var].strip() + if value.isdigit(): + return int(value) + else: + logger.warning(f"{env_var} set but not a valid int") + return default diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index f1b292690..d1974ff7c 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -1,15 +1,16 @@ """api middleware.""" -from json import loads from logging import getLogger -from os import environ, path -from typing import Any, Callable, Dict, Final, List, Optional +from typing import Callable, Final, Optional, Sequence from fastapi import APIRouter, FastAPI -from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware import cors from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.routing import Match +from starlette.types import ASGIApp + +from stac_fastapi.api.config import env_to_bool, env_to_int, env_to_sequence, env_to_str logger: Final = getLogger(__file__) @@ -37,57 +38,76 @@ async def _middleware(request: Request, call_next): return deco -class MiddlewareConfig: - """Represents a middleware class plus any configuration detail.""" - - def __init__(self, middleware: Any, config: Optional[Dict[str, Any]] = None): - """Defaults config to empty dictionary if not provided.""" - self.middleware = middleware - self.config = {} if config is None else config - +class CORSMiddleware(cors.CORSMiddleware): + """Starlette CORS Middleware with default.""" -def append_runtime_middlewares( - existing_middlewares: List[MiddlewareConfig], -) -> List[MiddlewareConfig]: - """Add any middlewares specified via environment variable and configure if appropriate.""" - return existing_middlewares + [ - addition - for addition in [_append_cors_middleware(existing_middlewares)] - if addition is not None - ] - - -def _append_cors_middleware( - existing_middlewares: List[MiddlewareConfig], -) -> Optional[MiddlewareConfig]: - has_cors_middleware = ( - len( - [ - entry - for entry in existing_middlewares - if entry.middleware == CORSMiddleware - ] + def __init__( + self, + app: ASGIApp, + allow_origins: Optional[Sequence[str]] = None, + allow_methods: Optional[Sequence[str]] = None, + allow_headers: Optional[Sequence[str]] = None, + allow_credentials: Optional[bool] = None, + allow_origin_regex: Optional[str] = None, + expose_headers: Optional[Sequence[str]] = None, + max_age: Optional[int] = None, + ) -> None: + """Create CORSMiddleware Object.""" + allow_origins = ( + env_to_sequence("CORS_ALLOW_ORIGINS", ("*",)) + if allow_origins is None + else allow_origins + ) + allow_methods = ( + env_to_sequence("CORS_ALLOW_METHODS", ("*",)) + if allow_methods is None + else allow_methods ) - > 0 - ) - cors_config_location_key: Final = "CORS_CONFIG_LOCATION" - if cors_config_location_key in environ: - cors_config_path = environ[cors_config_location_key] - if has_cors_middleware: - logger.warning( - f"CORSMiddleware already configured; ignoring config at {cors_config_path}" + allow_headers = ( + env_to_sequence("CORS_ALLOW_HEADERS", ("*",)) + if allow_headers is None + else allow_headers + ) + allow_credentials = ( + env_to_bool("CORS_ALLOW_CREDENTIALS", False) + if allow_credentials is None + else allow_credentials + ) + allow_origin_regex = ( + env_to_str("CORS_ALLOW_ORIGIN_REGEX", None) + if allow_origin_regex is None + else allow_origin_regex + ) + if allow_origin_regex is not None: + logger.info( + "CORS_ALLOW_ORIGIN_REGEX present and will override CORS_ALLOW_ORIGINS" ) - else: - logger.info(f"looking for CORS config file at {cors_config_path}") - if path.exists(cors_config_path): - try: - with open(cors_config_path, "r") as cors_config_file: - cors_config = loads(cors_config_file.read()) - logger.debug(f"loaded CORS config {cors_config}") - return MiddlewareConfig(CORSMiddleware, cors_config) - except ValueError as e: - logger.error(f"error parsing JSON at {cors_config_path}: {e}") - except OSError as e: - logger.error(f"error reading {cors_config_path}: {e}") - else: - logger.warning(f"CORS config not found at {cors_config_path}") + allow_origins = "" + expose_headers = ( + env_to_sequence("CORS_EXPOSE_HEADERS", ("*",)) + if expose_headers is None + else expose_headers + ) + max_age = env_to_int("CORS_MAX_AGE", 600) if max_age is None else max_age + logger.debug( + f""" + CORS configuration + allow_origins: {allow_origins} + allow_methods: {allow_methods} + allow_headers: {allow_headers} + allow_credentials: {allow_credentials} + allow_origin_regex: {allow_origin_regex} + expose_headers: {expose_headers} + max_age: {max_age} + """ + ) + super().__init__( + app, + allow_origins=allow_origins, + allow_methods=allow_methods, + allow_headers=allow_headers, + allow_credentials=allow_credentials, + allow_origin_regex=allow_origin_regex, + expose_headers=expose_headers, + max_age=max_age, + ) diff --git a/stac_fastapi/pgstac/tests/api/cors_support.py b/stac_fastapi/pgstac/tests/api/cors_support.py index 9f891a8d6..b0471c957 100644 --- a/stac_fastapi/pgstac/tests/api/cors_support.py +++ b/stac_fastapi/pgstac/tests/api/cors_support.py @@ -1,23 +1,45 @@ -from json import dumps -from os import environ, fdopen, path, sep -from tempfile import mkstemp +from os import environ from typing import Final -cors_config_location_key: Final = "CORS_CONFIG_LOCATION" -cors_permit_origin: Final = "http://cors.pass" -cors_deny_origin: Final = "http://cors.fail" +cors_origin_1: Final = "http://permit.one" +cors_origin_2: Final = "http://permit.two" +cors_origin_3: Final = "http://permit.three" +cors_origin_deny: Final = "http://deny.me" -def cors_enable(): - tmp_file, tmp_filename = mkstemp() - with fdopen(tmp_file, "w") as f: - f.write(dumps({"allow_origins": [cors_permit_origin]})) - environ[cors_config_location_key] = tmp_filename +def cors_permit_1(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_1 -def cors_disable() -> None: - environ.pop(cors_config_location_key, None) +def cors_permit_2(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_2 -def cors_missing(): - environ[cors_config_location_key] = path.join(path.abspath(sep), "missing.file") +def cors_permit_3(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_3 + + +def cors_permit_12(): + environ["CORS_ALLOW_ORIGINS"] = f"{cors_origin_1}|{cors_origin_2}" + + +def cors_permit_123_regex(): + environ["CORS_ALLOW_ORIGIN_REGEX"] = "http\\://permit\\..+" + + +def cors_deny(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_deny + + +def cors_disable_get(): + environ["CORS_ALLOW_METHODS"] = "HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH" + + +def cors_clear_config(): + environ.pop("CORS_ALLOW_ORIGINS", None) + environ.pop("CORS_ALLOW_METHODS", None) + environ.pop("CORS_ALLOW_HEADERS", None) + environ.pop("CORS_ALLOW_CREDENTIALS", None) + environ.pop("CORS_ALLOW_ORIGIN_REGEX", None) + environ.pop("CORS_EXPOSE_HEADERS", None) + environ.pop("CORS_MAX_AGE", None) diff --git a/stac_fastapi/pgstac/tests/api/test_api.py b/stac_fastapi/pgstac/tests/api/test_api.py index 2485cb07b..ebc432907 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -1,20 +1,17 @@ from datetime import datetime, timedelta from http import HTTPStatus -from os import environ import pytest -from fastapi.middleware.cors import CORSMiddleware from tests.api.cors_support import ( - cors_config_location_key, - cors_deny_origin, - cors_disable, - cors_enable, - cors_missing, - cors_permit_origin, + cors_clear_config, + cors_deny, + cors_origin_1, + cors_origin_deny, + cors_permit_1, + cors_permit_12, + cors_permit_123_regex, ) -from stac_fastapi.api.middleware import MiddlewareConfig - STAC_CORE_ROUTES = [ "GET /", "GET /collections", @@ -37,7 +34,7 @@ def teardown_function(): - environ.pop(cors_config_location_key, None) + cors_clear_config() @pytest.mark.asyncio @@ -322,89 +319,51 @@ async def test_search_line_string_intersects( @pytest.mark.asyncio -@pytest.mark.parametrize("app_client", [{"setup_func": cors_disable}], indirect=True) -async def test_without_cors(app_client): - resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) +async def test_with_default_cors_origin(app_client): + resp = await app_client.get("/", headers={"Origin": cors_origin_1}) assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) + assert resp.headers["access-control-allow-origin"] == "*" @pytest.mark.asyncio -@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) -async def test_with_match_cors(app_client): - resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) +@pytest.mark.parametrize("app_client", [{"setup_func": cors_permit_1}], indirect=True) +async def test_with_match_cors_single(app_client): + resp = await app_client.get("/", headers={"Origin": cors_origin_1}) assert resp.status_code == HTTPStatus.OK - assert resp.headers["access-control-allow-origin"] == cors_permit_origin + assert resp.headers["access-control-allow-origin"] == cors_origin_1 @pytest.mark.asyncio -@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) -async def test_with_mismatch_cors(app_client): - resp = await app_client.get("/", headers={"Origin": cors_deny_origin}) +@pytest.mark.parametrize("app_client", [{"setup_func": cors_permit_12}], indirect=True) +async def test_with_match_cors_double(app_client): + resp = await app_client.get("/", headers={"Origin": cors_origin_1}) assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) + assert resp.headers["access-control-allow-origin"] == cors_origin_1 @pytest.mark.asyncio -@pytest.mark.parametrize("app_client", [{"setup_func": cors_missing}], indirect=True) -async def test_with_missing_config(app_client): - resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) +@pytest.mark.parametrize( + "app_client", [{"setup_func": cors_permit_123_regex}], indirect=True +) +async def test_with_match_cors_all_regex_match(app_client): + resp = await app_client.get("/", headers={"Origin": cors_origin_1}) assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) + assert resp.headers["access-control-allow-origin"] == cors_origin_1 @pytest.mark.asyncio @pytest.mark.parametrize( - "app_client", - [ - { - "setup_func": cors_enable, - "middleware_configs": [ - MiddlewareConfig( - CORSMiddleware, {"allow_origins": ["http://different.origin"]} - ) - ], - } - ], - indirect=True, + "app_client", [{"setup_func": cors_permit_123_regex}], indirect=True ) -async def test_with_existing_cors(app_client): - resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) +async def test_with_match_cors_all_regex_mismatch(app_client): + resp = await app_client.get("/", headers={"Origin": cors_origin_deny}) assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) + assert "access-control-allow-origin" not in resp.headers + + +@pytest.mark.asyncio +@pytest.mark.parametrize("app_client", [{"setup_func": cors_deny}], indirect=True) +async def test_with_mismatch_cors_origin(app_client): + resp = await app_client.get("/", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert "access-control-allow-origin" not in resp.headers diff --git a/stac_fastapi/pgstac/tests/conftest.py b/stac_fastapi/pgstac/tests/conftest.py index aa2729890..793c809a2 100644 --- a/stac_fastapi/pgstac/tests/conftest.py +++ b/stac_fastapi/pgstac/tests/conftest.py @@ -1,7 +1,7 @@ import asyncio import json import os -from typing import Callable, Dict, Optional +from typing import Callable, Dict import asyncpg import pytest @@ -11,7 +11,6 @@ from stac_pydantic import Collection, Item from stac_fastapi.api.app import StacApi -from stac_fastapi.api.middleware import MiddlewareConfig from stac_fastapi.api.models import create_get_request_model, create_post_request_model from stac_fastapi.extensions.core import ( FieldsExtension, @@ -92,7 +91,7 @@ async def pgstac(pg): await conn.close() -def _api_client_provider(middleware_configs: Optional[MiddlewareConfig] = []): +def _api_client_provider(): print("creating client with settings") extensions = [ @@ -112,7 +111,6 @@ def _api_client_provider(middleware_configs: Optional[MiddlewareConfig] = []): search_get_request_model=create_get_request_model(extensions), search_post_request_model=post_request_model, response_class=ORJSONResponse, - middlewares=middleware_configs, ) return api @@ -127,13 +125,11 @@ def api_client(pg): @pytest.fixture(scope="session") async def app_client(pg, request): # support custom behaviours driven by fixture caller - middleware_configs = [] if hasattr(request, "param"): setup_func = request.param.get("setup_func") if setup_func is not None: setup_func() - middleware_configs = request.param.get("middleware_configs", []) - app = _api_client_provider(middleware_configs=middleware_configs).app + app = _api_client_provider().app async with AsyncClient(app=app, base_url="http://test") as c: await connect_to_db(app) yield c diff --git a/stac_fastapi/sqlalchemy/tests/api/cors_support.py b/stac_fastapi/sqlalchemy/tests/api/cors_support.py index 9f891a8d6..b0471c957 100644 --- a/stac_fastapi/sqlalchemy/tests/api/cors_support.py +++ b/stac_fastapi/sqlalchemy/tests/api/cors_support.py @@ -1,23 +1,45 @@ -from json import dumps -from os import environ, fdopen, path, sep -from tempfile import mkstemp +from os import environ from typing import Final -cors_config_location_key: Final = "CORS_CONFIG_LOCATION" -cors_permit_origin: Final = "http://cors.pass" -cors_deny_origin: Final = "http://cors.fail" +cors_origin_1: Final = "http://permit.one" +cors_origin_2: Final = "http://permit.two" +cors_origin_3: Final = "http://permit.three" +cors_origin_deny: Final = "http://deny.me" -def cors_enable(): - tmp_file, tmp_filename = mkstemp() - with fdopen(tmp_file, "w") as f: - f.write(dumps({"allow_origins": [cors_permit_origin]})) - environ[cors_config_location_key] = tmp_filename +def cors_permit_1(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_1 -def cors_disable() -> None: - environ.pop(cors_config_location_key, None) +def cors_permit_2(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_2 -def cors_missing(): - environ[cors_config_location_key] = path.join(path.abspath(sep), "missing.file") +def cors_permit_3(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_3 + + +def cors_permit_12(): + environ["CORS_ALLOW_ORIGINS"] = f"{cors_origin_1}|{cors_origin_2}" + + +def cors_permit_123_regex(): + environ["CORS_ALLOW_ORIGIN_REGEX"] = "http\\://permit\\..+" + + +def cors_deny(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_deny + + +def cors_disable_get(): + environ["CORS_ALLOW_METHODS"] = "HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH" + + +def cors_clear_config(): + environ.pop("CORS_ALLOW_ORIGINS", None) + environ.pop("CORS_ALLOW_METHODS", None) + environ.pop("CORS_ALLOW_HEADERS", None) + environ.pop("CORS_ALLOW_CREDENTIALS", None) + environ.pop("CORS_ALLOW_ORIGIN_REGEX", None) + environ.pop("CORS_EXPOSE_HEADERS", None) + environ.pop("CORS_MAX_AGE", None) diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index 63f5132c6..c03fa3473 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -1,20 +1,17 @@ from datetime import datetime, timedelta from http import HTTPStatus -from os import environ import pytest -from fastapi.middleware.cors import CORSMiddleware from tests.api.cors_support import ( - cors_config_location_key, - cors_deny_origin, - cors_disable, - cors_enable, - cors_missing, - cors_permit_origin, + cors_clear_config, + cors_deny, + cors_origin_1, + cors_origin_deny, + cors_permit_1, + cors_permit_12, + cors_permit_123_regex, ) -from stac_fastapi.api.middleware import MiddlewareConfig - from ..conftest import MockStarletteRequest STAC_CORE_ROUTES = [ @@ -39,7 +36,7 @@ def teardown_function(): - environ.pop(cors_config_location_key, None) + cors_clear_config() def test_post_search_content_type(app_client): @@ -306,85 +303,46 @@ def test_search_line_string_intersects( assert len(resp_json["features"]) == 1 -@pytest.mark.parametrize("app_client", [{"setup_func": cors_disable}], indirect=True) -def test_without_cors(app_client): - resp = app_client.get("/", headers={"Origin": cors_permit_origin}) +def test_with_default_cors_origin(app_client): + resp = app_client.get("/", headers={"Origin": cors_origin_1}) assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) + assert resp.headers["access-control-allow-origin"] == "*" -@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) -def test_with_match_cors(app_client): - resp = app_client.get("/", headers={"Origin": cors_permit_origin}) +@pytest.mark.parametrize("app_client", [{"setup_func": cors_permit_1}], indirect=True) +def test_with_match_cors_single(app_client): + resp = app_client.get("/", headers={"Origin": cors_origin_1}) assert resp.status_code == HTTPStatus.OK - assert resp.headers["access-control-allow-origin"] == cors_permit_origin + assert resp.headers["access-control-allow-origin"] == cors_origin_1 -@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) -def test_with_mismatch_cors(app_client): - resp = app_client.get("/", headers={"Origin": cors_deny_origin}) +@pytest.mark.parametrize("app_client", [{"setup_func": cors_permit_12}], indirect=True) +def test_with_match_cors_double(app_client): + resp = app_client.get("/", headers={"Origin": cors_origin_1}) assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) + assert resp.headers["access-control-allow-origin"] == cors_origin_1 -@pytest.mark.parametrize("app_client", [{"setup_func": cors_missing}], indirect=True) -def test_with_missing_config(app_client): - resp = app_client.get("/", headers={"Origin": cors_permit_origin}) +@pytest.mark.parametrize( + "app_client", [{"setup_func": cors_permit_123_regex}], indirect=True +) +def test_with_match_cors_all_regex_match(app_client): + resp = app_client.get("/", headers={"Origin": cors_origin_1}) assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) + assert resp.headers["access-control-allow-origin"] == cors_origin_1 @pytest.mark.parametrize( - "app_client", - [ - { - "setup_func": cors_enable, - "middleware_configs": [ - MiddlewareConfig( - CORSMiddleware, {"allow_origins": ["http://different.origin"]} - ) - ], - } - ], - indirect=True, + "app_client", [{"setup_func": cors_permit_123_regex}], indirect=True ) -def test_with_existing_cors(app_client): - resp = app_client.get("/", headers={"Origin": cors_permit_origin}) +def test_with_match_cors_all_regex_mismatch(app_client): + resp = app_client.get("/", headers={"Origin": cors_origin_deny}) assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) + assert "access-control-allow-origin" not in resp.headers + + +@pytest.mark.parametrize("app_client", [{"setup_func": cors_deny}], indirect=True) +def test_with_mismatch_cors_origin(app_client): + resp = app_client.get("/", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert "access-control-allow-origin" not in resp.headers diff --git a/stac_fastapi/sqlalchemy/tests/conftest.py b/stac_fastapi/sqlalchemy/tests/conftest.py index 2f1889327..1ef44f62b 100644 --- a/stac_fastapi/sqlalchemy/tests/conftest.py +++ b/stac_fastapi/sqlalchemy/tests/conftest.py @@ -1,12 +1,11 @@ import json import os -from typing import Callable, Dict, Optional +from typing import Callable, Dict import pytest from starlette.testclient import TestClient from stac_fastapi.api.app import StacApi -from stac_fastapi.api.middleware import MiddlewareConfig from stac_fastapi.api.models import create_request_model from stac_fastapi.extensions.core import ( ContextExtension, @@ -105,9 +104,7 @@ def postgres_bulk_transactions(db_session): return BulkTransactionsClient(session=db_session) -def _api_client_provider( - db_session, middleware_configs: Optional[MiddlewareConfig] = [] -): +def _api_client_provider(db_session): settings = SqlalchemySettings() extensions = [ TransactionExtension( @@ -144,7 +141,6 @@ def _api_client_provider( extensions=extensions, search_get_request_model=get_request_model, search_post_request_model=post_request_model, - middlewares=middleware_configs, ) @@ -156,16 +152,12 @@ def api_client(db_session): @pytest.fixture def app_client(db_session, load_test_data, postgres_transactions, request): # support custom behaviours driven by fixture caller - middleware_configs = [] if hasattr(request, "param"): setup_func = request.param.get("setup_func") if setup_func is not None: setup_func() - middleware_configs = request.param.get("middleware_configs", []) coll = load_test_data("test_collection.json") postgres_transactions.create_collection(coll, request=MockStarletteRequest) - with TestClient( - _api_client_provider(db_session, middleware_configs=middleware_configs).app - ) as test_app: + with TestClient(_api_client_provider(db_session).app) as test_app: yield test_app From ee6934750cf9d1a9e4a48ad9320fbe88cc0a8ac1 Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Tue, 1 Feb 2022 11:32:20 -0800 Subject: [PATCH 07/25] feature/1 updated documentation --- CHANGES.md | 2 +- docs/tips-and-tricks.md | 23 +++-------------------- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 1f3a4fa3d..d78ff3ef8 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -4,7 +4,7 @@ ### Added -* Added ability to configure CORS middleware via JSON configuration file and environment variable, rather than having to modify code. +* Added ability to configure CORS middleware via environment variables ([#341](https://github.com/stac-utils/stac-fastapi/pull/341)) ### Changed diff --git a/docs/tips-and-tricks.md b/docs/tips-and-tricks.md index 997edfd0e..294274c77 100644 --- a/docs/tips-and-tricks.md +++ b/docs/tips-and-tricks.md @@ -5,30 +5,13 @@ This page contains a few 'tips and tricks' for getting stac-fastapi working in v CORS (Cross-Origin Resource Sharing) support may be required to use stac-fastapi in certain situations. For example, if you are running [stac-browser](https://github.com/radiantearth/stac-browser) to browse the STAC catalog created by stac-fastapi, then you will need to enable CORS support. -To do this, create a JSON configuration file whose schema matches the options described in the [FastAPI docs](https://fastapi.tiangolo.com/tutorial/cors/), e.g. +To do this, configure environment variables for the configuration options described in [FastAPI docs](https://fastapi.tiangolo.com/tutorial/cors/), using a `CORS_` prefix and upper-case, e.g. ``` -{ - "allow_origins": ["*"], - "allow_methods": ["*"] -} +CORS_ALLOW_ORIGINS=http://domain.one|http://domain.two +CORS_ALLOW_METHODS="*" ``` -Deploy this file to a location accessible by stac-fastapi, e.g. in Dockerfile - -``` -RUN mkdir /config -COPY cors.json /config/cors.json -``` - -Set an environment variable `CORS_CONFIG_LOCATION` pointing to this file, e.g. in Dockerfile - -``` -ENV CORS_CONFIG_LOCATION=/config/cors.json -``` - -If needed, you can edit the `allow_origins` parameter to only allow CORS requests from specific origins. - ## Enable the Context extension The Context STAC extension provides information on the number of items matched and returned from a STAC search. This is required by various other STAC-related tools, such as the pystac command-line client. To enable the extension, edit `stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py` (or the equivalent in the `pgstac` folder) and add the following import: From 939ae130da32837af88293eabaf5d4919d6d376b Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Tue, 1 Feb 2022 11:33:14 -0800 Subject: [PATCH 08/25] feature/1 updated documentation --- docs/tips-and-tricks.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/tips-and-tricks.md b/docs/tips-and-tricks.md index 294274c77..4e763d107 100644 --- a/docs/tips-and-tricks.md +++ b/docs/tips-and-tricks.md @@ -9,7 +9,6 @@ To do this, configure environment variables for the configuration options descri ``` CORS_ALLOW_ORIGINS=http://domain.one|http://domain.two -CORS_ALLOW_METHODS="*" ``` ## Enable the Context extension From 1007d52ee4e4d8ae968c3752cfd658dc102fca5e Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Wed, 2 Feb 2022 09:54:18 -0800 Subject: [PATCH 09/25] feature/1 follow pydantic configuration standard --- docs/tips-and-tricks.md | 9 ++- stac_fastapi/api/stac_fastapi/api/__init__.py | 5 ++ stac_fastapi/api/stac_fastapi/api/config.py | 55 ++++--------------- .../api/stac_fastapi/api/middleware.py | 30 ++++------ stac_fastapi/pgstac/tests/api/cors_support.py | 45 ++++++++++----- .../sqlalchemy/tests/api/cors_support.py | 45 ++++++++++----- 6 files changed, 92 insertions(+), 97 deletions(-) diff --git a/docs/tips-and-tricks.md b/docs/tips-and-tricks.md index 4e763d107..3d41bf1c3 100644 --- a/docs/tips-and-tricks.md +++ b/docs/tips-and-tricks.md @@ -5,10 +5,13 @@ This page contains a few 'tips and tricks' for getting stac-fastapi working in v CORS (Cross-Origin Resource Sharing) support may be required to use stac-fastapi in certain situations. For example, if you are running [stac-browser](https://github.com/radiantearth/stac-browser) to browse the STAC catalog created by stac-fastapi, then you will need to enable CORS support. -To do this, configure environment variables for the configuration options described in [FastAPI docs](https://fastapi.tiangolo.com/tutorial/cors/), using a `CORS_` prefix and upper-case, e.g. - +To do this, configure environment variables for the configuration options described in [FastAPI docs](https://fastapi.tiangolo.com/tutorial/cors/) using a `cors_` prefix e.g. +``` +cors_allow_credentials=true [or 1] +``` +Sequences, such as `allow_origins`, should be in JSON format e.g. ``` -CORS_ALLOW_ORIGINS=http://domain.one|http://domain.two +cors_allow_origins='["http://domain.one", "http://domain.two"]' ``` ## Enable the Context extension diff --git a/stac_fastapi/api/stac_fastapi/api/__init__.py b/stac_fastapi/api/stac_fastapi/api/__init__.py index df6f6249b..b616b5927 100644 --- a/stac_fastapi/api/stac_fastapi/api/__init__.py +++ b/stac_fastapi/api/stac_fastapi/api/__init__.py @@ -1 +1,6 @@ """api submodule.""" +from typing import Final + +from stac_fastapi.api.config import Settings + +settings: Final = Settings() diff --git a/stac_fastapi/api/stac_fastapi/api/config.py b/stac_fastapi/api/stac_fastapi/api/config.py index 9f1b8c90a..933023d37 100644 --- a/stac_fastapi/api/stac_fastapi/api/config.py +++ b/stac_fastapi/api/stac_fastapi/api/config.py @@ -1,10 +1,10 @@ """Application settings.""" import enum -import re from logging import getLogger -from os import environ from typing import Final, Sequence +from pydantic import BaseSettings, Field + logger: Final = getLogger(__file__) @@ -30,46 +30,13 @@ class AddOns(enum.Enum): bulk_transaction = "bulk-transaction" -def env_to_sequence( - env_var: str, default: Sequence[str], sequence_separator: str = "|" -) -> Sequence[str]: - """Retrieve a sequence of values from an env var string, or default if missing.""" - if env_var in environ: - if re.search(re.escape(sequence_separator), environ[env_var]): - return tuple( - [part for part in environ[env_var].split(sequence_separator) if part] - ) - else: - return (environ[env_var],) - else: - return default - - -def env_to_str(env_var: str, default: str) -> str: - """Retrieve a string from an env var, or default if missing.""" - if env_var in environ: - return environ[env_var] - else: - return default - - -def env_to_bool(env_var: str, default: bool) -> bool: - """Retrieve a bool from an env var, or default if missing.""" - if env_var in environ: - if re.match("^(true|1)$", environ[env_var], re.IGNORECASE): - return True - if re.match("^(false|0)$", environ[env_var], re.IGNORECASE): - return False - logger.warning(f"{env_var} set but not a valid bool") - return default - +class Settings(BaseSettings): + """API settings.""" -def env_to_int(env_var: str, default: int) -> int: - """Retrieve an int from an env var, or default if missing.""" - if env_var in environ: - value = environ[env_var].strip() - if value.isdigit(): - return int(value) - else: - logger.warning(f"{env_var} set but not a valid int") - return default + allow_origins: Sequence[str] = Field(("*",), env="cors_allow_origins") + allow_methods: Sequence[str] = Field(("*",), env="cors_allow_methods") + allow_headers: Sequence[str] = Field(("*",), env="cors_allow_headers") + allow_credentials: bool = Field(False, env="cors_allow_credentials") + allow_origin_regex: str = Field(None, env="cors_allow_origin_regex") + expose_headers: Sequence[str] = Field(("*",), env="cors_expose_headers") + max_age: int = Field(600, env="cors_max_age") diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index d1974ff7c..ffcbcde23 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -10,7 +10,7 @@ from starlette.routing import Match from starlette.types import ASGIApp -from stac_fastapi.api.config import env_to_bool, env_to_int, env_to_sequence, env_to_str +from stac_fastapi.api import settings logger: Final = getLogger(__file__) @@ -39,7 +39,7 @@ async def _middleware(request: Request, call_next): class CORSMiddleware(cors.CORSMiddleware): - """Starlette CORS Middleware with default.""" + """Starlette CORS Middleware with configuration.""" def __init__( self, @@ -54,41 +54,31 @@ def __init__( ) -> None: """Create CORSMiddleware Object.""" allow_origins = ( - env_to_sequence("CORS_ALLOW_ORIGINS", ("*",)) - if allow_origins is None - else allow_origins + settings.allow_origins if allow_origins is None else allow_origins ) allow_methods = ( - env_to_sequence("CORS_ALLOW_METHODS", ("*",)) - if allow_methods is None - else allow_methods + settings.allow_methods if allow_methods is None else allow_methods ) allow_headers = ( - env_to_sequence("CORS_ALLOW_HEADERS", ("*",)) - if allow_headers is None - else allow_headers + settings.allow_headers if allow_headers is None else allow_headers ) allow_credentials = ( - env_to_bool("CORS_ALLOW_CREDENTIALS", False) + settings.allow_credentials if allow_credentials is None else allow_credentials ) allow_origin_regex = ( - env_to_str("CORS_ALLOW_ORIGIN_REGEX", None) + settings.allow_origin_regex if allow_origin_regex is None else allow_origin_regex ) if allow_origin_regex is not None: - logger.info( - "CORS_ALLOW_ORIGIN_REGEX present and will override CORS_ALLOW_ORIGINS" - ) + logger.info("allow_origin_regex present and will override allow_origins") allow_origins = "" expose_headers = ( - env_to_sequence("CORS_EXPOSE_HEADERS", ("*",)) - if expose_headers is None - else expose_headers + settings.expose_headers if expose_headers is None else expose_headers ) - max_age = env_to_int("CORS_MAX_AGE", 600) if max_age is None else max_age + max_age = settings.max_age if max_age is None else max_age logger.debug( f""" CORS configuration diff --git a/stac_fastapi/pgstac/tests/api/cors_support.py b/stac_fastapi/pgstac/tests/api/cors_support.py index b0471c957..bf6996ed0 100644 --- a/stac_fastapi/pgstac/tests/api/cors_support.py +++ b/stac_fastapi/pgstac/tests/api/cors_support.py @@ -1,6 +1,10 @@ -from os import environ +from copy import deepcopy +from json import dumps from typing import Final +from stac_fastapi.api import settings + +settings_fallback = deepcopy(settings) cors_origin_1: Final = "http://permit.one" cors_origin_2: Final = "http://permit.two" cors_origin_3: Final = "http://permit.three" @@ -8,38 +12,49 @@ def cors_permit_1(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_1 + settings.allow_origins = dumps((cors_origin_1,)) def cors_permit_2(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_2 + settings.allow_origins = dumps((cors_origin_2,)) def cors_permit_3(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_3 + settings.allow_origins = dumps((cors_origin_3,)) def cors_permit_12(): - environ["CORS_ALLOW_ORIGINS"] = f"{cors_origin_1}|{cors_origin_2}" + settings.allow_origins = dumps((cors_origin_1, cors_origin_2)) def cors_permit_123_regex(): - environ["CORS_ALLOW_ORIGIN_REGEX"] = "http\\://permit\\..+" + settings.allow_origin_regex = "http\\://permit\\..+" def cors_deny(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_deny + settings.allow_origins = dumps((cors_origin_deny,)) def cors_disable_get(): - environ["CORS_ALLOW_METHODS"] = "HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH" + settings.allow_methods = dumps( + ( + "HEAD", + "POST", + "PUT", + "DELETE", + "CONNECT", + "OPTIONS", + "TRACE", + "PATCH", + ) + ) def cors_clear_config(): - environ.pop("CORS_ALLOW_ORIGINS", None) - environ.pop("CORS_ALLOW_METHODS", None) - environ.pop("CORS_ALLOW_HEADERS", None) - environ.pop("CORS_ALLOW_CREDENTIALS", None) - environ.pop("CORS_ALLOW_ORIGIN_REGEX", None) - environ.pop("CORS_EXPOSE_HEADERS", None) - environ.pop("CORS_MAX_AGE", None) + settings.allow_origins = settings_fallback.allow_origins + settings.allow_methods = settings_fallback.allow_methods + settings.allow_headers = settings_fallback.allow_headers + settings.allow_credentials = settings_fallback.allow_credentials + settings.allow_origin_regex = settings_fallback.allow_origin_regex + settings.expose_headers = settings_fallback.expose_headers + settings.max_age = settings_fallback.max_age diff --git a/stac_fastapi/sqlalchemy/tests/api/cors_support.py b/stac_fastapi/sqlalchemy/tests/api/cors_support.py index b0471c957..bf6996ed0 100644 --- a/stac_fastapi/sqlalchemy/tests/api/cors_support.py +++ b/stac_fastapi/sqlalchemy/tests/api/cors_support.py @@ -1,6 +1,10 @@ -from os import environ +from copy import deepcopy +from json import dumps from typing import Final +from stac_fastapi.api import settings + +settings_fallback = deepcopy(settings) cors_origin_1: Final = "http://permit.one" cors_origin_2: Final = "http://permit.two" cors_origin_3: Final = "http://permit.three" @@ -8,38 +12,49 @@ def cors_permit_1(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_1 + settings.allow_origins = dumps((cors_origin_1,)) def cors_permit_2(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_2 + settings.allow_origins = dumps((cors_origin_2,)) def cors_permit_3(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_3 + settings.allow_origins = dumps((cors_origin_3,)) def cors_permit_12(): - environ["CORS_ALLOW_ORIGINS"] = f"{cors_origin_1}|{cors_origin_2}" + settings.allow_origins = dumps((cors_origin_1, cors_origin_2)) def cors_permit_123_regex(): - environ["CORS_ALLOW_ORIGIN_REGEX"] = "http\\://permit\\..+" + settings.allow_origin_regex = "http\\://permit\\..+" def cors_deny(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_deny + settings.allow_origins = dumps((cors_origin_deny,)) def cors_disable_get(): - environ["CORS_ALLOW_METHODS"] = "HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH" + settings.allow_methods = dumps( + ( + "HEAD", + "POST", + "PUT", + "DELETE", + "CONNECT", + "OPTIONS", + "TRACE", + "PATCH", + ) + ) def cors_clear_config(): - environ.pop("CORS_ALLOW_ORIGINS", None) - environ.pop("CORS_ALLOW_METHODS", None) - environ.pop("CORS_ALLOW_HEADERS", None) - environ.pop("CORS_ALLOW_CREDENTIALS", None) - environ.pop("CORS_ALLOW_ORIGIN_REGEX", None) - environ.pop("CORS_EXPOSE_HEADERS", None) - environ.pop("CORS_MAX_AGE", None) + settings.allow_origins = settings_fallback.allow_origins + settings.allow_methods = settings_fallback.allow_methods + settings.allow_headers = settings_fallback.allow_headers + settings.allow_credentials = settings_fallback.allow_credentials + settings.allow_origin_regex = settings_fallback.allow_origin_regex + settings.expose_headers = settings_fallback.expose_headers + settings.max_age = settings_fallback.max_age From 9aeb28b7b36b8d7cd0e6a11238afeee963043306 Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Wed, 2 Feb 2022 14:02:57 -0800 Subject: [PATCH 10/25] feature/1 fix docs build --- stac_fastapi/api/stac_fastapi/api/__init__.py | 5 ----- stac_fastapi/api/stac_fastapi/api/config.py | 3 +++ stac_fastapi/api/stac_fastapi/api/middleware.py | 2 +- stac_fastapi/pgstac/tests/api/cors_support.py | 2 +- stac_fastapi/sqlalchemy/tests/api/cors_support.py | 2 +- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/stac_fastapi/api/stac_fastapi/api/__init__.py b/stac_fastapi/api/stac_fastapi/api/__init__.py index b616b5927..df6f6249b 100644 --- a/stac_fastapi/api/stac_fastapi/api/__init__.py +++ b/stac_fastapi/api/stac_fastapi/api/__init__.py @@ -1,6 +1 @@ """api submodule.""" -from typing import Final - -from stac_fastapi.api.config import Settings - -settings: Final = Settings() diff --git a/stac_fastapi/api/stac_fastapi/api/config.py b/stac_fastapi/api/stac_fastapi/api/config.py index 933023d37..96114b1f4 100644 --- a/stac_fastapi/api/stac_fastapi/api/config.py +++ b/stac_fastapi/api/stac_fastapi/api/config.py @@ -40,3 +40,6 @@ class Settings(BaseSettings): allow_origin_regex: str = Field(None, env="cors_allow_origin_regex") expose_headers: Sequence[str] = Field(("*",), env="cors_expose_headers") max_age: int = Field(600, env="cors_max_age") + + +settings: Final = Settings() diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index ffcbcde23..8dd8be2a8 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -10,7 +10,7 @@ from starlette.routing import Match from starlette.types import ASGIApp -from stac_fastapi.api import settings +from stac_fastapi.api.config import settings logger: Final = getLogger(__file__) diff --git a/stac_fastapi/pgstac/tests/api/cors_support.py b/stac_fastapi/pgstac/tests/api/cors_support.py index bf6996ed0..9f9c303e6 100644 --- a/stac_fastapi/pgstac/tests/api/cors_support.py +++ b/stac_fastapi/pgstac/tests/api/cors_support.py @@ -2,7 +2,7 @@ from json import dumps from typing import Final -from stac_fastapi.api import settings +from stac_fastapi.api.config import settings settings_fallback = deepcopy(settings) cors_origin_1: Final = "http://permit.one" diff --git a/stac_fastapi/sqlalchemy/tests/api/cors_support.py b/stac_fastapi/sqlalchemy/tests/api/cors_support.py index bf6996ed0..9f9c303e6 100644 --- a/stac_fastapi/sqlalchemy/tests/api/cors_support.py +++ b/stac_fastapi/sqlalchemy/tests/api/cors_support.py @@ -2,7 +2,7 @@ from json import dumps from typing import Final -from stac_fastapi.api import settings +from stac_fastapi.api.config import settings settings_fallback = deepcopy(settings) cors_origin_1: Final = "http://permit.one" From 3ff4d107c4626153b4a58db4badf3a24af9804ae Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Fri, 18 Feb 2022 12:13:23 -0800 Subject: [PATCH 11/25] feature/1 add CORS tests to api tests --- Makefile | 6 +- docker-compose.yml | 13 +++ stac_fastapi/api/tests/__init__.py | 0 stac_fastapi/api/tests/cors_support.py | 60 ++++++++++ stac_fastapi/api/tests/test_cors.py | 76 +++++++++++++ .../api/tests/test_route_dependencies.py | 106 ++++++++++++++++++ stac_fastapi/api/tests/util.py | 37 ++++++ 7 files changed, 297 insertions(+), 1 deletion(-) create mode 100644 stac_fastapi/api/tests/__init__.py create mode 100644 stac_fastapi/api/tests/cors_support.py create mode 100644 stac_fastapi/api/tests/test_cors.py create mode 100644 stac_fastapi/api/tests/test_route_dependencies.py create mode 100644 stac_fastapi/api/tests/util.py diff --git a/Makefile b/Makefile index 837db31da..eb90ae81e 100644 --- a/Makefile +++ b/Makefile @@ -42,6 +42,10 @@ test-sqlalchemy: run-joplin-sqlalchemy test-pgstac: $(run_pgstac) /bin/bash -c 'export && ./scripts/wait-for-it.sh database:5432 && cd /app/stac_fastapi/pgstac/tests/ && pytest' +.PHONY: test-api +test-api: + docker-compose run api-tester + .PHONY: run-database run-database: docker-compose run --rm database @@ -55,7 +59,7 @@ run-joplin-pgstac: docker-compose run --rm loadjoplin-pgstac .PHONY: test -test: test-sqlalchemy test-pgstac +test: test-sqlalchemy test-pgstac test-api .PHONY: pybase-install pybase-install: diff --git a/docker-compose.yml b/docker-compose.yml index f5be8ce3f..c134767d8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -121,6 +121,19 @@ services: - database - app-pgstac + api-tester: + image: stac-utils/stac-fastapi + build: + context: . + dockerfile: Dockerfile + profiles: + # prevent tester from starting with `docker-compose up` + - api-test + working_dir: /app/stac_fastapi/api + volumes: + - ./:/app + command: pytest -svvv + networks: default: name: stac-fastapi-network diff --git a/stac_fastapi/api/tests/__init__.py b/stac_fastapi/api/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/stac_fastapi/api/tests/cors_support.py b/stac_fastapi/api/tests/cors_support.py new file mode 100644 index 000000000..9f9c303e6 --- /dev/null +++ b/stac_fastapi/api/tests/cors_support.py @@ -0,0 +1,60 @@ +from copy import deepcopy +from json import dumps +from typing import Final + +from stac_fastapi.api.config import settings + +settings_fallback = deepcopy(settings) +cors_origin_1: Final = "http://permit.one" +cors_origin_2: Final = "http://permit.two" +cors_origin_3: Final = "http://permit.three" +cors_origin_deny: Final = "http://deny.me" + + +def cors_permit_1(): + settings.allow_origins = dumps((cors_origin_1,)) + + +def cors_permit_2(): + settings.allow_origins = dumps((cors_origin_2,)) + + +def cors_permit_3(): + settings.allow_origins = dumps((cors_origin_3,)) + + +def cors_permit_12(): + settings.allow_origins = dumps((cors_origin_1, cors_origin_2)) + + +def cors_permit_123_regex(): + settings.allow_origin_regex = "http\\://permit\\..+" + + +def cors_deny(): + settings.allow_origins = dumps((cors_origin_deny,)) + + +def cors_disable_get(): + settings.allow_methods = dumps( + ( + "HEAD", + "POST", + "PUT", + "DELETE", + "CONNECT", + "OPTIONS", + "TRACE", + "PATCH", + ) + ) + + +def cors_clear_config(): + settings.allow_origins = settings_fallback.allow_origins + settings.allow_methods = settings_fallback.allow_methods + settings.allow_headers = settings_fallback.allow_headers + settings.allow_credentials = settings_fallback.allow_credentials + settings.allow_origin_regex = settings_fallback.allow_origin_regex + settings.expose_headers = settings_fallback.expose_headers + settings.max_age = settings_fallback.max_age diff --git a/stac_fastapi/api/tests/test_cors.py b/stac_fastapi/api/tests/test_cors.py new file mode 100644 index 000000000..6dc7a2b0e --- /dev/null +++ b/stac_fastapi/api/tests/test_cors.py @@ -0,0 +1,76 @@ +from http import HTTPStatus + +from starlette.testclient import TestClient +from tests.cors_support import ( + cors_clear_config, + cors_deny, + cors_origin_1, + cors_origin_deny, + cors_permit_1, + cors_permit_12, + cors_permit_123_regex, +) +from tests.util import build_api + +from stac_fastapi.extensions.core import TokenPaginationExtension + + +def teardown_function(): + cors_clear_config() + + +def _get_api(): + return build_api([TokenPaginationExtension()]) + + +def test_with_default_cors_origin(): + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == "*" + + +def test_with_match_cors_single(): + cors_permit_1() + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == cors_origin_1 + + +def test_with_match_cors_double(): + cors_permit_12() + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == cors_origin_1 + + +def test_with_match_cors_all_regex_match(): + cors_permit_123_regex() + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == cors_origin_1 + + +def test_with_match_cors_all_regex_mismatch(): + cors_permit_123_regex() + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_deny}) + assert resp.status_code == HTTPStatus.OK + assert "access-control-allow-origin" not in resp.headers + + +def test_with_mismatch_cors_origin(): + cors_deny() + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert "access-control-allow-origin" not in resp.headers diff --git a/stac_fastapi/api/tests/test_route_dependencies.py b/stac_fastapi/api/tests/test_route_dependencies.py new file mode 100644 index 000000000..ba2281773 --- /dev/null +++ b/stac_fastapi/api/tests/test_route_dependencies.py @@ -0,0 +1,106 @@ +from fastapi import Depends, HTTPException, security, status +from starlette.testclient import TestClient +from tests.util import build_api + +from stac_fastapi.extensions.core import TokenPaginationExtension, TransactionExtension +from stac_fastapi.types import config, core + + +class TestRouteDependencies: + @staticmethod + def _get_extensions(): + return [ + TransactionExtension( + client=DummyTransactionsClient(), settings=config.ApiSettings() + ), + TokenPaginationExtension(), + ] + + @staticmethod + def _assert_dependency_applied(api, routes): + with TestClient(api.app) as client: + for route in routes: + response = getattr(client, route["method"].lower())(route["path"]) + assert ( + response.status_code == 401 + ), "Unauthenticated requests should be rejected" + assert response.json() == {"detail": "Not authenticated"} + + make_request = getattr(client, route["method"].lower()) + path = route["path"].format( + collectionId="test_collection", itemId="test_item" + ) + response = make_request( + path, + auth=("bob", "dobbs"), + data='{"dummy": "payload"}', + headers={"content-type": "application/json"}, + ) + assert ( + response.status_code == 200 + ), "Authenticated requests should be accepted" + assert response.json() == "dummy response" + + def test_build_api_with_route_dependencies(self): + routes = [ + {"path": "/collections", "method": "POST"}, + {"path": "/collections", "method": "PUT"}, + {"path": "/collections/{collectionId}", "method": "DELETE"}, + {"path": "/collections/{collectionId}/items", "method": "POST"}, + {"path": "/collections/{collectionId}/items", "method": "PUT"}, + {"path": "/collections/{collectionId}/items/{itemId}", "method": "DELETE"}, + ] + dependencies = [Depends(must_be_bob)] + api = build_api( + TestRouteDependencies._get_extensions(), + route_dependencies=[(routes, dependencies)], + ) + self._assert_dependency_applied(api, routes) + + def test_add_route_dependencies_after_building_api(self): + routes = [ + {"path": "/collections", "method": "POST"}, + {"path": "/collections", "method": "PUT"}, + {"path": "/collections/{collectionId}", "method": "DELETE"}, + {"path": "/collections/{collectionId}/items", "method": "POST"}, + {"path": "/collections/{collectionId}/items", "method": "PUT"}, + {"path": "/collections/{collectionId}/items/{itemId}", "method": "DELETE"}, + ] + api = build_api(TestRouteDependencies._get_extensions()) + api.add_route_dependencies(scopes=routes, dependencies=[Depends(must_be_bob)]) + self._assert_dependency_applied(api, routes) + + +class DummyTransactionsClient(core.BaseTransactionsClient): + """Defines a pattern for implementing the STAC transaction extension.""" + + def create_item(self, *args, **kwargs): + return "dummy response" + + def update_item(self, *args, **kwargs): + return "dummy response" + + def delete_item(self, *args, **kwargs): + return "dummy response" + + def create_collection(self, *args, **kwargs): + return "dummy response" + + def update_collection(self, *args, **kwargs): + return "dummy response" + + def delete_collection(self, *args, **kwargs): + return "dummy response" + + +def must_be_bob( + credentials: security.HTTPBasicCredentials = Depends(security.HTTPBasic()), +): + if credentials.username == "bob": + return True + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="You're not Bob", + headers={"WWW-Authenticate": "Basic"}, + ) diff --git a/stac_fastapi/api/tests/util.py b/stac_fastapi/api/tests/util.py new file mode 100644 index 000000000..aa4aea972 --- /dev/null +++ b/stac_fastapi/api/tests/util.py @@ -0,0 +1,37 @@ +from typing import List, Type + +from stac_fastapi.api.app import StacApi +from stac_fastapi.types import config, core +from stac_fastapi.types.extension import ApiExtension + + +class DummyCoreClient(core.BaseCoreClient): + def all_collections(self, *args, **kwargs): + ... + + def get_collection(self, *args, **kwargs): + ... + + def get_item(self, *args, **kwargs): + ... + + def get_search(self, *args, **kwargs): + ... + + def post_search(self, *args, **kwargs): + ... + + def item_collection(self, *args, **kwargs): + ... + + +def build_api(extensions: List[Type[ApiExtension]] = [], **overrides): + settings = config.ApiSettings() + return StacApi( + **{ + "settings": settings, + "client": DummyCoreClient(), + "extensions": extensions, + **overrides, + } + ) From 6f997097fe05c557dc1334b4846e487114d9ff69 Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Fri, 18 Feb 2022 12:36:30 -0800 Subject: [PATCH 12/25] feature/1 removed unnecessary tests --- stac_fastapi/pgstac/tests/api/cors_support.py | 60 ----------------- stac_fastapi/pgstac/tests/api/test_api.py | 65 ------------------- .../sqlalchemy/tests/api/cors_support.py | 60 ----------------- stac_fastapi/sqlalchemy/tests/api/test_api.py | 61 ----------------- 4 files changed, 246 deletions(-) delete mode 100644 stac_fastapi/pgstac/tests/api/cors_support.py delete mode 100644 stac_fastapi/sqlalchemy/tests/api/cors_support.py diff --git a/stac_fastapi/pgstac/tests/api/cors_support.py b/stac_fastapi/pgstac/tests/api/cors_support.py deleted file mode 100644 index 9f9c303e6..000000000 --- a/stac_fastapi/pgstac/tests/api/cors_support.py +++ /dev/null @@ -1,60 +0,0 @@ -from copy import deepcopy -from json import dumps -from typing import Final - -from stac_fastapi.api.config import settings - -settings_fallback = deepcopy(settings) -cors_origin_1: Final = "http://permit.one" -cors_origin_2: Final = "http://permit.two" -cors_origin_3: Final = "http://permit.three" -cors_origin_deny: Final = "http://deny.me" - - -def cors_permit_1(): - settings.allow_origins = dumps((cors_origin_1,)) - - -def cors_permit_2(): - settings.allow_origins = dumps((cors_origin_2,)) - - -def cors_permit_3(): - settings.allow_origins = dumps((cors_origin_3,)) - - -def cors_permit_12(): - settings.allow_origins = dumps((cors_origin_1, cors_origin_2)) - - -def cors_permit_123_regex(): - settings.allow_origin_regex = "http\\://permit\\..+" - - -def cors_deny(): - settings.allow_origins = dumps((cors_origin_deny,)) - - -def cors_disable_get(): - settings.allow_methods = dumps( - ( - "HEAD", - "POST", - "PUT", - "DELETE", - "CONNECT", - "OPTIONS", - "TRACE", - "PATCH", - ) - ) - - -def cors_clear_config(): - settings.allow_origins = settings_fallback.allow_origins - settings.allow_methods = settings_fallback.allow_methods - settings.allow_headers = settings_fallback.allow_headers - settings.allow_credentials = settings_fallback.allow_credentials - settings.allow_origin_regex = settings_fallback.allow_origin_regex - settings.expose_headers = settings_fallback.expose_headers - settings.max_age = settings_fallback.max_age diff --git a/stac_fastapi/pgstac/tests/api/test_api.py b/stac_fastapi/pgstac/tests/api/test_api.py index ebc432907..9ab4c2c07 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -1,16 +1,6 @@ from datetime import datetime, timedelta -from http import HTTPStatus import pytest -from tests.api.cors_support import ( - cors_clear_config, - cors_deny, - cors_origin_1, - cors_origin_deny, - cors_permit_1, - cors_permit_12, - cors_permit_123_regex, -) STAC_CORE_ROUTES = [ "GET /", @@ -33,10 +23,6 @@ ] -def teardown_function(): - cors_clear_config() - - @pytest.mark.asyncio async def test_post_search_content_type(app_client): params = {"limit": 1} @@ -316,54 +302,3 @@ async def test_search_line_string_intersects( assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == 1 - - -@pytest.mark.asyncio -async def test_with_default_cors_origin(app_client): - resp = await app_client.get("/", headers={"Origin": cors_origin_1}) - assert resp.status_code == HTTPStatus.OK - assert resp.headers["access-control-allow-origin"] == "*" - - -@pytest.mark.asyncio -@pytest.mark.parametrize("app_client", [{"setup_func": cors_permit_1}], indirect=True) -async def test_with_match_cors_single(app_client): - resp = await app_client.get("/", headers={"Origin": cors_origin_1}) - assert resp.status_code == HTTPStatus.OK - assert resp.headers["access-control-allow-origin"] == cors_origin_1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("app_client", [{"setup_func": cors_permit_12}], indirect=True) -async def test_with_match_cors_double(app_client): - resp = await app_client.get("/", headers={"Origin": cors_origin_1}) - assert resp.status_code == HTTPStatus.OK - assert resp.headers["access-control-allow-origin"] == cors_origin_1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "app_client", [{"setup_func": cors_permit_123_regex}], indirect=True -) -async def test_with_match_cors_all_regex_match(app_client): - resp = await app_client.get("/", headers={"Origin": cors_origin_1}) - assert resp.status_code == HTTPStatus.OK - assert resp.headers["access-control-allow-origin"] == cors_origin_1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "app_client", [{"setup_func": cors_permit_123_regex}], indirect=True -) -async def test_with_match_cors_all_regex_mismatch(app_client): - resp = await app_client.get("/", headers={"Origin": cors_origin_deny}) - assert resp.status_code == HTTPStatus.OK - assert "access-control-allow-origin" not in resp.headers - - -@pytest.mark.asyncio -@pytest.mark.parametrize("app_client", [{"setup_func": cors_deny}], indirect=True) -async def test_with_mismatch_cors_origin(app_client): - resp = await app_client.get("/", headers={"Origin": cors_origin_1}) - assert resp.status_code == HTTPStatus.OK - assert "access-control-allow-origin" not in resp.headers diff --git a/stac_fastapi/sqlalchemy/tests/api/cors_support.py b/stac_fastapi/sqlalchemy/tests/api/cors_support.py deleted file mode 100644 index 9f9c303e6..000000000 --- a/stac_fastapi/sqlalchemy/tests/api/cors_support.py +++ /dev/null @@ -1,60 +0,0 @@ -from copy import deepcopy -from json import dumps -from typing import Final - -from stac_fastapi.api.config import settings - -settings_fallback = deepcopy(settings) -cors_origin_1: Final = "http://permit.one" -cors_origin_2: Final = "http://permit.two" -cors_origin_3: Final = "http://permit.three" -cors_origin_deny: Final = "http://deny.me" - - -def cors_permit_1(): - settings.allow_origins = dumps((cors_origin_1,)) - - -def cors_permit_2(): - settings.allow_origins = dumps((cors_origin_2,)) - - -def cors_permit_3(): - settings.allow_origins = dumps((cors_origin_3,)) - - -def cors_permit_12(): - settings.allow_origins = dumps((cors_origin_1, cors_origin_2)) - - -def cors_permit_123_regex(): - settings.allow_origin_regex = "http\\://permit\\..+" - - -def cors_deny(): - settings.allow_origins = dumps((cors_origin_deny,)) - - -def cors_disable_get(): - settings.allow_methods = dumps( - ( - "HEAD", - "POST", - "PUT", - "DELETE", - "CONNECT", - "OPTIONS", - "TRACE", - "PATCH", - ) - ) - - -def cors_clear_config(): - settings.allow_origins = settings_fallback.allow_origins - settings.allow_methods = settings_fallback.allow_methods - settings.allow_headers = settings_fallback.allow_headers - settings.allow_credentials = settings_fallback.allow_credentials - settings.allow_origin_regex = settings_fallback.allow_origin_regex - settings.expose_headers = settings_fallback.expose_headers - settings.max_age = settings_fallback.max_age diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index 42e1fd3f8..4a7dff776 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -1,16 +1,4 @@ from datetime import datetime, timedelta -from http import HTTPStatus - -import pytest -from tests.api.cors_support import ( - cors_clear_config, - cors_deny, - cors_origin_1, - cors_origin_deny, - cors_permit_1, - cors_permit_12, - cors_permit_123_regex, -) from ..conftest import MockStarletteRequest @@ -35,10 +23,6 @@ ] -def teardown_function(): - cors_clear_config() - - def test_post_search_content_type(app_client): params = {"limit": 1} resp = app_client.post("search", json=params) @@ -321,48 +305,3 @@ def test_app_fields_extension_return_all_properties( assert feature["properties"][expected_prop][0:19] == expected_value[0:19] else: assert feature["properties"][expected_prop] == expected_value - - -def test_with_default_cors_origin(app_client): - resp = app_client.get("/", headers={"Origin": cors_origin_1}) - assert resp.status_code == HTTPStatus.OK - assert resp.headers["access-control-allow-origin"] == "*" - - -@pytest.mark.parametrize("app_client", [{"setup_func": cors_permit_1}], indirect=True) -def test_with_match_cors_single(app_client): - resp = app_client.get("/", headers={"Origin": cors_origin_1}) - assert resp.status_code == HTTPStatus.OK - assert resp.headers["access-control-allow-origin"] == cors_origin_1 - - -@pytest.mark.parametrize("app_client", [{"setup_func": cors_permit_12}], indirect=True) -def test_with_match_cors_double(app_client): - resp = app_client.get("/", headers={"Origin": cors_origin_1}) - assert resp.status_code == HTTPStatus.OK - assert resp.headers["access-control-allow-origin"] == cors_origin_1 - - -@pytest.mark.parametrize( - "app_client", [{"setup_func": cors_permit_123_regex}], indirect=True -) -def test_with_match_cors_all_regex_match(app_client): - resp = app_client.get("/", headers={"Origin": cors_origin_1}) - assert resp.status_code == HTTPStatus.OK - assert resp.headers["access-control-allow-origin"] == cors_origin_1 - - -@pytest.mark.parametrize( - "app_client", [{"setup_func": cors_permit_123_regex}], indirect=True -) -def test_with_match_cors_all_regex_mismatch(app_client): - resp = app_client.get("/", headers={"Origin": cors_origin_deny}) - assert resp.status_code == HTTPStatus.OK - assert "access-control-allow-origin" not in resp.headers - - -@pytest.mark.parametrize("app_client", [{"setup_func": cors_deny}], indirect=True) -def test_with_mismatch_cors_origin(app_client): - resp = app_client.get("/", headers={"Origin": cors_origin_1}) - assert resp.status_code == HTTPStatus.OK - assert "access-control-allow-origin" not in resp.headers From 2569aeeee1c8ee8bfb276e755e150512765982fc Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Tue, 25 Jan 2022 10:47:14 -0800 Subject: [PATCH 13/25] feature/1 added runtime configuration of CORS --- .gitignore | 5 +- stac_fastapi/api/stac_fastapi/api/app.py | 1 + .../api/stac_fastapi/api/middleware.py | 55 +++++++++++++- stac_fastapi/pgstac/tests/api/cors_support.py | 23 ++++++ stac_fastapi/pgstac/tests/api/test_api.py | 71 +++++++++++++++++++ stac_fastapi/pgstac/tests/conftest.py | 24 +++---- .../sqlalchemy/tests/api/cors_support.py | 23 ++++++ stac_fastapi/sqlalchemy/tests/api/test_api.py | 69 ++++++++++++++++++ stac_fastapi/sqlalchemy/tests/conftest.py | 15 ++-- 9 files changed, 268 insertions(+), 18 deletions(-) create mode 100644 stac_fastapi/pgstac/tests/api/cors_support.py create mode 100644 stac_fastapi/sqlalchemy/tests/api/cors_support.py diff --git a/.gitignore b/.gitignore index 407d0fe9a..1cc0fbf51 100644 --- a/.gitignore +++ b/.gitignore @@ -129,4 +129,7 @@ docs/api/* .envrc # Virtualenv -venv \ No newline at end of file +venv + +# IDE +.vscode \ No newline at end of file diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index a9f8a5542..46ceba109 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -14,6 +14,7 @@ from starlette.responses import JSONResponse, Response from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers +from stac_fastapi.api.middleware import MiddlewareConfig, append_runtime_middlewares from stac_fastapi.api.models import ( APIRequest, CollectionUri, diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index acb00915b..e7ed896b0 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -1,12 +1,18 @@ """api middleware.""" -from typing import Callable +from json import loads +from logging import getLogger +from os import environ, path +from typing import Any, Callable, Dict, Final, List, Optional from fastapi import APIRouter, FastAPI +from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.routing import Match +logger: Final = getLogger(__file__) + def router_middleware(app: FastAPI, router: APIRouter): """Add middleware to a specific router, assumes no router prefix.""" @@ -29,3 +35,50 @@ async def _middleware(request: Request, call_next): return func return deco + + +class MiddlewareConfig: + """Represents a middleware class plus any configuration detail.""" + + def __init__(self, middleware: Any, config: Optional[Dict[str, Any]] = None): + """Defaults config to empty dictionary if not provided.""" + self.middleware = middleware + self.config = {} if config is None else config + + +def append_runtime_middlewares( + middlewares: List[MiddlewareConfig], +) -> List[MiddlewareConfig]: + """Add any middlewares specified via environment variable and configure if appropriate.""" + extended_middlewares = middlewares.copy() + has_cors_middleware = ( + len( + [ + entry + for entry in middlewares + if isinstance(entry.middleware, CORSMiddleware) + ] + ) + > 0 + ) + if not has_cors_middleware: + cors_config_location_key: Final = "CORS_CONFIG_LOCATION" + if cors_config_location_key in environ: + cors_config_path = environ[cors_config_location_key] + logger.info(f"looking for CORS config file at {cors_config_path}") + if path.exists(cors_config_path): + try: + with open(cors_config_path, "r") as cors_config_file: + cors_config = loads("".join(cors_config_file.readlines())) + extended_middlewares.append( + MiddlewareConfig(CORSMiddleware, cors_config) + ) + logger.debug(f"loaded CORS config {cors_config}") + except ValueError as e: + logger.error(f"error parsing JSON at {cors_config_path}: {e}") + except OSError as e: + logger.error(f"error reading {cors_config_path}: {e}") + else: + logger.warning(f"CORS config not found at {cors_config_path}") + + return extended_middlewares diff --git a/stac_fastapi/pgstac/tests/api/cors_support.py b/stac_fastapi/pgstac/tests/api/cors_support.py new file mode 100644 index 000000000..9f891a8d6 --- /dev/null +++ b/stac_fastapi/pgstac/tests/api/cors_support.py @@ -0,0 +1,23 @@ +from json import dumps +from os import environ, fdopen, path, sep +from tempfile import mkstemp +from typing import Final + +cors_config_location_key: Final = "CORS_CONFIG_LOCATION" +cors_permit_origin: Final = "http://cors.pass" +cors_deny_origin: Final = "http://cors.fail" + + +def cors_enable(): + tmp_file, tmp_filename = mkstemp() + with fdopen(tmp_file, "w") as f: + f.write(dumps({"allow_origins": [cors_permit_origin]})) + environ[cors_config_location_key] = tmp_filename + + +def cors_disable() -> None: + environ.pop(cors_config_location_key, None) + + +def cors_missing(): + environ[cors_config_location_key] = path.join(path.abspath(sep), "missing.file") diff --git a/stac_fastapi/pgstac/tests/api/test_api.py b/stac_fastapi/pgstac/tests/api/test_api.py index f4d783b11..2bfe40b8d 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -1,4 +1,16 @@ from datetime import datetime, timedelta +from http import HTTPStatus +from os import environ + +import pytest +from tests.api.cors_support import ( + cors_config_location_key, + cors_deny_origin, + cors_disable, + cors_enable, + cors_missing, + cors_permit_origin, +) STAC_CORE_ROUTES = [ "GET /", @@ -21,6 +33,10 @@ ] +def teardown_function(): + environ.pop(cors_config_location_key, None) + + async def test_post_search_content_type(app_client): params = {"limit": 1} resp = await app_client.post("search", json=params) @@ -281,3 +297,58 @@ async def test_search_line_string_intersects( assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == 1 + + +@pytest.mark.parametrize("app_client", [{"setup_func": cors_disable}], indirect=True) +async def test_without_cors(app_client): + resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) + + +@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) +async def test_with_match_cors(app_client): + resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == cors_permit_origin + + +@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) +async def test_with_mismatch_cors(app_client): + resp = await app_client.get("/", headers={"Origin": cors_deny_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) + + +@pytest.mark.parametrize("app_client", [{"setup_func": cors_missing}], indirect=True) +async def test_with_missing_config(app_client): + resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) diff --git a/stac_fastapi/pgstac/tests/conftest.py b/stac_fastapi/pgstac/tests/conftest.py index 170877a7d..65b626ca2 100644 --- a/stac_fastapi/pgstac/tests/conftest.py +++ b/stac_fastapi/pgstac/tests/conftest.py @@ -1,7 +1,6 @@ import asyncio import json import os -import time from typing import Callable, Dict import asyncpg @@ -92,8 +91,7 @@ async def pgstac(pg): await conn.close() -@pytest.fixture(scope="session") -def api_client(pg): +def _api_client_provider(): print("creating client with settings") extensions = [ @@ -119,21 +117,23 @@ def api_client(pg): @pytest.fixture(scope="session") -async def app(api_client): - time.time() - app = api_client.app - await connect_to_db(app) - - yield app - - await close_db_connection(app) +def api_client(pg): + return _api_client_provider() @pytest.fixture(scope="session") -async def app_client(app): +async def app_client(pg, request): + setup_func = request.param.get("setup_func") if hasattr(request, "param") else None + if setup_func is not None: + setup_func() + app = _api_client_provider().app async with AsyncClient(app=app, base_url="http://test") as c: + await connect_to_db(app) + yield c + await close_db_connection(app) + @pytest.fixture def load_test_data() -> Callable[[str], Dict]: diff --git a/stac_fastapi/sqlalchemy/tests/api/cors_support.py b/stac_fastapi/sqlalchemy/tests/api/cors_support.py new file mode 100644 index 000000000..9f891a8d6 --- /dev/null +++ b/stac_fastapi/sqlalchemy/tests/api/cors_support.py @@ -0,0 +1,23 @@ +from json import dumps +from os import environ, fdopen, path, sep +from tempfile import mkstemp +from typing import Final + +cors_config_location_key: Final = "CORS_CONFIG_LOCATION" +cors_permit_origin: Final = "http://cors.pass" +cors_deny_origin: Final = "http://cors.fail" + + +def cors_enable(): + tmp_file, tmp_filename = mkstemp() + with fdopen(tmp_file, "w") as f: + f.write(dumps({"allow_origins": [cors_permit_origin]})) + environ[cors_config_location_key] = tmp_filename + + +def cors_disable() -> None: + environ.pop(cors_config_location_key, None) + + +def cors_missing(): + environ[cors_config_location_key] = path.join(path.abspath(sep), "missing.file") diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index 0abd7cb00..e05e36b15 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -1,4 +1,16 @@ from datetime import datetime, timedelta +from http import HTTPStatus +from os import environ + +import pytest +from tests.api.cors_support import ( + cors_config_location_key, + cors_deny_origin, + cors_disable, + cors_enable, + cors_missing, + cors_permit_origin, +) from ..conftest import MockStarletteRequest @@ -23,6 +35,10 @@ ] +def teardown_function(): + environ.pop(cors_config_location_key, None) + + def test_post_search_content_type(app_client): params = {"limit": 1} resp = app_client.post("search", json=params) @@ -304,3 +320,56 @@ def test_app_fields_extension_return_all_properties( assert feature["properties"][expected_prop][0:19] == expected_value[0:19] else: assert feature["properties"][expected_prop] == expected_value +@pytest.mark.parametrize("app_client", [{"setup_func": cors_disable}], indirect=True) +def test_without_cors(app_client): + resp = app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) + + +@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) +def test_with_match_cors(app_client): + resp = app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == cors_permit_origin + + +@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) +def test_with_mismatch_cors(app_client): + resp = app_client.get("/", headers={"Origin": cors_deny_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) + + +@pytest.mark.parametrize("app_client", [{"setup_func": cors_missing}], indirect=True) +def test_with_missing_config(app_client): + resp = app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) diff --git a/stac_fastapi/sqlalchemy/tests/conftest.py b/stac_fastapi/sqlalchemy/tests/conftest.py index 7abd9150f..4012eeb5a 100644 --- a/stac_fastapi/sqlalchemy/tests/conftest.py +++ b/stac_fastapi/sqlalchemy/tests/conftest.py @@ -104,8 +104,7 @@ def postgres_bulk_transactions(db_session): return BulkTransactionsClient(session=db_session) -@pytest.fixture -def api_client(db_session): +def _api_client_provider(db_session): settings = SqlalchemySettings() extensions = [ TransactionExtension( @@ -146,9 +145,17 @@ def api_client(db_session): @pytest.fixture -def app_client(api_client, load_test_data, postgres_transactions): +def api_client(db_session): + return _api_client_provider(db_session) + + +@pytest.fixture +def app_client(db_session, load_test_data, postgres_transactions, request): + setup_func = request.param.get("setup_func") if hasattr(request, "param") else None + if setup_func is not None: + setup_func() coll = load_test_data("test_collection.json") postgres_transactions.create_collection(coll, request=MockStarletteRequest) - with TestClient(api_client.app) as test_app: + with TestClient(_api_client_provider(db_session).app) as test_app: yield test_app From 41629331075939f9c11bd2af32ac70f4e26883f0 Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Wed, 26 Jan 2022 13:44:00 -0800 Subject: [PATCH 14/25] feature/1 updated documentation --- CHANGES.md | 1 + docs/tips-and-tricks.md | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 2355d3534..a76ffe705 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -7,6 +7,7 @@ * Add hook to allow adding dependencies to routes. ([#295](https://github.com/stac-utils/stac-fastapi/pull/295)) * Ability to POST an ItemCollection to the collections/{collectionId}/items route. ([#367](https://github.com/stac-utils/stac-fastapi/pull/367)) * Add STAC API - Collections conformance class. ([383](https://github.com/stac-utils/stac-fastapi/pull/383)) +* Added ability to configure CORS middleware via JSON configuration file and environment variable, rather than having to modify code. ([341](https://github.com/stac-utils/stac-fastapi/pull/341)) ### Changed diff --git a/docs/tips-and-tricks.md b/docs/tips-and-tricks.md index 3d4c9ac0f..997edfd0e 100644 --- a/docs/tips-and-tricks.md +++ b/docs/tips-and-tricks.md @@ -5,16 +5,26 @@ This page contains a few 'tips and tricks' for getting stac-fastapi working in v CORS (Cross-Origin Resource Sharing) support may be required to use stac-fastapi in certain situations. For example, if you are running [stac-browser](https://github.com/radiantearth/stac-browser) to browse the STAC catalog created by stac-fastapi, then you will need to enable CORS support. -To do this, edit `stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py` (or the equivalent in the `pgstac` folder) and add the following import: +To do this, create a JSON configuration file whose schema matches the options described in the [FastAPI docs](https://fastapi.tiangolo.com/tutorial/cors/), e.g. ``` -from fastapi.middleware.cors import CORSMiddleware +{ + "allow_origins": ["*"], + "allow_methods": ["*"] +} ``` -and then edit the `api = StacApi(...` call to add the following parameter: +Deploy this file to a location accessible by stac-fastapi, e.g. in Dockerfile ``` -middlewares=[lambda app: CORSMiddleware(app, allow_origins=["*"])] +RUN mkdir /config +COPY cors.json /config/cors.json +``` + +Set an environment variable `CORS_CONFIG_LOCATION` pointing to this file, e.g. in Dockerfile + +``` +ENV CORS_CONFIG_LOCATION=/config/cors.json ``` If needed, you can edit the `allow_origins` parameter to only allow CORS requests from specific origins. From 8736d1556077d8b12ff0c5999698921cd544f0ee Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Thu, 27 Jan 2022 09:08:59 -0800 Subject: [PATCH 15/25] feature/1 PR feedback and additional test --- .../api/stac_fastapi/api/middleware.py | 37 ++++++++++++------- stac_fastapi/pgstac/tests/api/test_api.py | 33 +++++++++++++++++ stac_fastapi/pgstac/tests/conftest.py | 20 ++++++---- stac_fastapi/sqlalchemy/tests/api/test_api.py | 32 ++++++++++++++++ stac_fastapi/sqlalchemy/tests/conftest.py | 22 ++++++++--- 5 files changed, 116 insertions(+), 28 deletions(-) diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index e7ed896b0..f1b292690 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -47,38 +47,47 @@ def __init__(self, middleware: Any, config: Optional[Dict[str, Any]] = None): def append_runtime_middlewares( - middlewares: List[MiddlewareConfig], + existing_middlewares: List[MiddlewareConfig], ) -> List[MiddlewareConfig]: """Add any middlewares specified via environment variable and configure if appropriate.""" - extended_middlewares = middlewares.copy() + return existing_middlewares + [ + addition + for addition in [_append_cors_middleware(existing_middlewares)] + if addition is not None + ] + + +def _append_cors_middleware( + existing_middlewares: List[MiddlewareConfig], +) -> Optional[MiddlewareConfig]: has_cors_middleware = ( len( [ entry - for entry in middlewares - if isinstance(entry.middleware, CORSMiddleware) + for entry in existing_middlewares + if entry.middleware == CORSMiddleware ] ) > 0 ) - if not has_cors_middleware: - cors_config_location_key: Final = "CORS_CONFIG_LOCATION" - if cors_config_location_key in environ: - cors_config_path = environ[cors_config_location_key] + cors_config_location_key: Final = "CORS_CONFIG_LOCATION" + if cors_config_location_key in environ: + cors_config_path = environ[cors_config_location_key] + if has_cors_middleware: + logger.warning( + f"CORSMiddleware already configured; ignoring config at {cors_config_path}" + ) + else: logger.info(f"looking for CORS config file at {cors_config_path}") if path.exists(cors_config_path): try: with open(cors_config_path, "r") as cors_config_file: - cors_config = loads("".join(cors_config_file.readlines())) - extended_middlewares.append( - MiddlewareConfig(CORSMiddleware, cors_config) - ) + cors_config = loads(cors_config_file.read()) logger.debug(f"loaded CORS config {cors_config}") + return MiddlewareConfig(CORSMiddleware, cors_config) except ValueError as e: logger.error(f"error parsing JSON at {cors_config_path}: {e}") except OSError as e: logger.error(f"error reading {cors_config_path}: {e}") else: logger.warning(f"CORS config not found at {cors_config_path}") - - return extended_middlewares diff --git a/stac_fastapi/pgstac/tests/api/test_api.py b/stac_fastapi/pgstac/tests/api/test_api.py index 2bfe40b8d..7b52d16a6 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -3,6 +3,7 @@ from os import environ import pytest +from fastapi.middleware.cors import CORSMiddleware from tests.api.cors_support import ( cors_config_location_key, cors_deny_origin, @@ -12,6 +13,8 @@ cors_permit_origin, ) +from stac_fastapi.api.middleware import MiddlewareConfig + STAC_CORE_ROUTES = [ "GET /", "GET /collections", @@ -352,3 +355,33 @@ async def test_with_missing_config(app_client): ) == 0 ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "app_client", + [ + { + "setup_func": cors_enable, + "middleware_configs": [ + MiddlewareConfig( + CORSMiddleware, {"allow_origins": ["http://different.origin"]} + ) + ], + } + ], + indirect=True, +) +async def test_with_existing_cors(app_client): + resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) diff --git a/stac_fastapi/pgstac/tests/conftest.py b/stac_fastapi/pgstac/tests/conftest.py index 65b626ca2..2adcf3279 100644 --- a/stac_fastapi/pgstac/tests/conftest.py +++ b/stac_fastapi/pgstac/tests/conftest.py @@ -1,7 +1,7 @@ import asyncio import json import os -from typing import Callable, Dict +from typing import Callable, Dict, Optional import asyncpg import pytest @@ -11,6 +11,7 @@ from stac_pydantic import Collection, Item from stac_fastapi.api.app import StacApi +from stac_fastapi.api.middleware import MiddlewareConfig from stac_fastapi.api.models import create_get_request_model, create_post_request_model from stac_fastapi.extensions.core import ( FieldsExtension, @@ -91,7 +92,7 @@ async def pgstac(pg): await conn.close() -def _api_client_provider(): +def _api_client_provider(middleware_configs: Optional[MiddlewareConfig] = []): print("creating client with settings") extensions = [ @@ -111,6 +112,7 @@ def _api_client_provider(): search_get_request_model=create_get_request_model(extensions), search_post_request_model=post_request_model, response_class=ORJSONResponse, + middlewares=middleware_configs, ) return api @@ -123,15 +125,17 @@ def api_client(pg): @pytest.fixture(scope="session") async def app_client(pg, request): - setup_func = request.param.get("setup_func") if hasattr(request, "param") else None - if setup_func is not None: - setup_func() - app = _api_client_provider().app + # support custom behaviours driven by fixture caller + middleware_configs = [] + if hasattr(request, "param"): + setup_func = request.param.get("setup_func") + if setup_func is not None: + setup_func() + middleware_configs = request.param.get("middleware_configs", []) + app = _api_client_provider(middleware_configs=middleware_configs).app async with AsyncClient(app=app, base_url="http://test") as c: await connect_to_db(app) - yield c - await close_db_connection(app) diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index e05e36b15..ad066634b 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -3,6 +3,7 @@ from os import environ import pytest +from fastapi.middleware.cors import CORSMiddleware from tests.api.cors_support import ( cors_config_location_key, cors_deny_origin, @@ -12,6 +13,8 @@ cors_permit_origin, ) +from stac_fastapi.api.middleware import MiddlewareConfig + from ..conftest import MockStarletteRequest STAC_CORE_ROUTES = [ @@ -373,3 +376,32 @@ def test_with_missing_config(app_client): ) == 0 ) + + +@pytest.mark.parametrize( + "app_client", + [ + { + "setup_func": cors_enable, + "middleware_configs": [ + MiddlewareConfig( + CORSMiddleware, {"allow_origins": ["http://different.origin"]} + ) + ], + } + ], + indirect=True, +) +async def test_with_existing_cors(app_client): + resp = app_client.get("/", headers={"Origin": cors_permit_origin}) + assert resp.status_code == HTTPStatus.OK + assert ( + len( + [ + header + for header in resp.headers + if header.startswith("access-control-allow-") + ] + ) + == 0 + ) diff --git a/stac_fastapi/sqlalchemy/tests/conftest.py b/stac_fastapi/sqlalchemy/tests/conftest.py index 4012eeb5a..2f1889327 100644 --- a/stac_fastapi/sqlalchemy/tests/conftest.py +++ b/stac_fastapi/sqlalchemy/tests/conftest.py @@ -1,11 +1,12 @@ import json import os -from typing import Callable, Dict +from typing import Callable, Dict, Optional import pytest from starlette.testclient import TestClient from stac_fastapi.api.app import StacApi +from stac_fastapi.api.middleware import MiddlewareConfig from stac_fastapi.api.models import create_request_model from stac_fastapi.extensions.core import ( ContextExtension, @@ -104,7 +105,9 @@ def postgres_bulk_transactions(db_session): return BulkTransactionsClient(session=db_session) -def _api_client_provider(db_session): +def _api_client_provider( + db_session, middleware_configs: Optional[MiddlewareConfig] = [] +): settings = SqlalchemySettings() extensions = [ TransactionExtension( @@ -141,6 +144,7 @@ def _api_client_provider(db_session): extensions=extensions, search_get_request_model=get_request_model, search_post_request_model=post_request_model, + middlewares=middleware_configs, ) @@ -151,11 +155,17 @@ def api_client(db_session): @pytest.fixture def app_client(db_session, load_test_data, postgres_transactions, request): - setup_func = request.param.get("setup_func") if hasattr(request, "param") else None - if setup_func is not None: - setup_func() + # support custom behaviours driven by fixture caller + middleware_configs = [] + if hasattr(request, "param"): + setup_func = request.param.get("setup_func") + if setup_func is not None: + setup_func() + middleware_configs = request.param.get("middleware_configs", []) coll = load_test_data("test_collection.json") postgres_transactions.create_collection(coll, request=MockStarletteRequest) - with TestClient(_api_client_provider(db_session).app) as test_app: + with TestClient( + _api_client_provider(db_session, middleware_configs=middleware_configs).app + ) as test_app: yield test_app From 1cc85068ecbb8ebc497f74d462a8362934047309 Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Fri, 28 Jan 2022 16:12:13 -0800 Subject: [PATCH 16/25] feature/1 removed unwanted async on test --- stac_fastapi/sqlalchemy/tests/api/test_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index ad066634b..1e1736b03 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -392,7 +392,7 @@ def test_with_missing_config(app_client): ], indirect=True, ) -async def test_with_existing_cors(app_client): +def test_with_existing_cors(app_client): resp = app_client.get("/", headers={"Origin": cors_permit_origin}) assert resp.status_code == HTTPStatus.OK assert ( From f0e69b543ed947fcb14bbd96b9fce1342258c723 Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Tue, 1 Feb 2022 11:27:39 -0800 Subject: [PATCH 17/25] feature/1 updated with PR feedback from stac-fastapi --- stac_fastapi/api/stac_fastapi/api/app.py | 6 +- stac_fastapi/api/stac_fastapi/api/config.py | 51 +++++++ .../api/stac_fastapi/api/middleware.py | 130 ++++++++++-------- stac_fastapi/pgstac/tests/api/cors_support.py | 52 +++++-- stac_fastapi/pgstac/tests/api/test_api.py | 104 ++------------ stac_fastapi/pgstac/tests/conftest.py | 10 +- .../sqlalchemy/tests/api/cors_support.py | 52 +++++-- stac_fastapi/sqlalchemy/tests/api/test_api.py | 103 ++------------ stac_fastapi/sqlalchemy/tests/conftest.py | 14 +- 9 files changed, 227 insertions(+), 295 deletions(-) diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index 46ceba109..f1ab10e29 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -14,7 +14,7 @@ from starlette.responses import JSONResponse, Response from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers -from stac_fastapi.api.middleware import MiddlewareConfig, append_runtime_middlewares +from stac_fastapi.api.middleware import CORSMiddleware from stac_fastapi.api.models import ( APIRequest, CollectionUri, @@ -92,7 +92,9 @@ class StacApi: ) pagination_extension = attr.ib(default=TokenPaginationExtension) response_class: Type[Response] = attr.ib(default=JSONResponse) - middlewares: List = attr.ib(default=attr.Factory(lambda: [BrotliMiddleware])) + middlewares: List = attr.ib( + default=attr.Factory(lambda: [BrotliMiddleware, CORSMiddleware]) + ) route_dependencies: List[Tuple[List[Scope], List[Depends]]] = attr.ib(default=[]) def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]: diff --git a/stac_fastapi/api/stac_fastapi/api/config.py b/stac_fastapi/api/stac_fastapi/api/config.py index 3a423e45d..9f1b8c90a 100644 --- a/stac_fastapi/api/stac_fastapi/api/config.py +++ b/stac_fastapi/api/stac_fastapi/api/config.py @@ -1,5 +1,11 @@ """Application settings.""" import enum +import re +from logging import getLogger +from os import environ +from typing import Final, Sequence + +logger: Final = getLogger(__file__) # TODO: Move to stac-pydantic @@ -22,3 +28,48 @@ class AddOns(enum.Enum): """Enumeration of available third party add ons.""" bulk_transaction = "bulk-transaction" + + +def env_to_sequence( + env_var: str, default: Sequence[str], sequence_separator: str = "|" +) -> Sequence[str]: + """Retrieve a sequence of values from an env var string, or default if missing.""" + if env_var in environ: + if re.search(re.escape(sequence_separator), environ[env_var]): + return tuple( + [part for part in environ[env_var].split(sequence_separator) if part] + ) + else: + return (environ[env_var],) + else: + return default + + +def env_to_str(env_var: str, default: str) -> str: + """Retrieve a string from an env var, or default if missing.""" + if env_var in environ: + return environ[env_var] + else: + return default + + +def env_to_bool(env_var: str, default: bool) -> bool: + """Retrieve a bool from an env var, or default if missing.""" + if env_var in environ: + if re.match("^(true|1)$", environ[env_var], re.IGNORECASE): + return True + if re.match("^(false|0)$", environ[env_var], re.IGNORECASE): + return False + logger.warning(f"{env_var} set but not a valid bool") + return default + + +def env_to_int(env_var: str, default: int) -> int: + """Retrieve an int from an env var, or default if missing.""" + if env_var in environ: + value = environ[env_var].strip() + if value.isdigit(): + return int(value) + else: + logger.warning(f"{env_var} set but not a valid int") + return default diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index f1b292690..d1974ff7c 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -1,15 +1,16 @@ """api middleware.""" -from json import loads from logging import getLogger -from os import environ, path -from typing import Any, Callable, Dict, Final, List, Optional +from typing import Callable, Final, Optional, Sequence from fastapi import APIRouter, FastAPI -from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware import cors from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.routing import Match +from starlette.types import ASGIApp + +from stac_fastapi.api.config import env_to_bool, env_to_int, env_to_sequence, env_to_str logger: Final = getLogger(__file__) @@ -37,57 +38,76 @@ async def _middleware(request: Request, call_next): return deco -class MiddlewareConfig: - """Represents a middleware class plus any configuration detail.""" - - def __init__(self, middleware: Any, config: Optional[Dict[str, Any]] = None): - """Defaults config to empty dictionary if not provided.""" - self.middleware = middleware - self.config = {} if config is None else config - +class CORSMiddleware(cors.CORSMiddleware): + """Starlette CORS Middleware with default.""" -def append_runtime_middlewares( - existing_middlewares: List[MiddlewareConfig], -) -> List[MiddlewareConfig]: - """Add any middlewares specified via environment variable and configure if appropriate.""" - return existing_middlewares + [ - addition - for addition in [_append_cors_middleware(existing_middlewares)] - if addition is not None - ] - - -def _append_cors_middleware( - existing_middlewares: List[MiddlewareConfig], -) -> Optional[MiddlewareConfig]: - has_cors_middleware = ( - len( - [ - entry - for entry in existing_middlewares - if entry.middleware == CORSMiddleware - ] + def __init__( + self, + app: ASGIApp, + allow_origins: Optional[Sequence[str]] = None, + allow_methods: Optional[Sequence[str]] = None, + allow_headers: Optional[Sequence[str]] = None, + allow_credentials: Optional[bool] = None, + allow_origin_regex: Optional[str] = None, + expose_headers: Optional[Sequence[str]] = None, + max_age: Optional[int] = None, + ) -> None: + """Create CORSMiddleware Object.""" + allow_origins = ( + env_to_sequence("CORS_ALLOW_ORIGINS", ("*",)) + if allow_origins is None + else allow_origins + ) + allow_methods = ( + env_to_sequence("CORS_ALLOW_METHODS", ("*",)) + if allow_methods is None + else allow_methods ) - > 0 - ) - cors_config_location_key: Final = "CORS_CONFIG_LOCATION" - if cors_config_location_key in environ: - cors_config_path = environ[cors_config_location_key] - if has_cors_middleware: - logger.warning( - f"CORSMiddleware already configured; ignoring config at {cors_config_path}" + allow_headers = ( + env_to_sequence("CORS_ALLOW_HEADERS", ("*",)) + if allow_headers is None + else allow_headers + ) + allow_credentials = ( + env_to_bool("CORS_ALLOW_CREDENTIALS", False) + if allow_credentials is None + else allow_credentials + ) + allow_origin_regex = ( + env_to_str("CORS_ALLOW_ORIGIN_REGEX", None) + if allow_origin_regex is None + else allow_origin_regex + ) + if allow_origin_regex is not None: + logger.info( + "CORS_ALLOW_ORIGIN_REGEX present and will override CORS_ALLOW_ORIGINS" ) - else: - logger.info(f"looking for CORS config file at {cors_config_path}") - if path.exists(cors_config_path): - try: - with open(cors_config_path, "r") as cors_config_file: - cors_config = loads(cors_config_file.read()) - logger.debug(f"loaded CORS config {cors_config}") - return MiddlewareConfig(CORSMiddleware, cors_config) - except ValueError as e: - logger.error(f"error parsing JSON at {cors_config_path}: {e}") - except OSError as e: - logger.error(f"error reading {cors_config_path}: {e}") - else: - logger.warning(f"CORS config not found at {cors_config_path}") + allow_origins = "" + expose_headers = ( + env_to_sequence("CORS_EXPOSE_HEADERS", ("*",)) + if expose_headers is None + else expose_headers + ) + max_age = env_to_int("CORS_MAX_AGE", 600) if max_age is None else max_age + logger.debug( + f""" + CORS configuration + allow_origins: {allow_origins} + allow_methods: {allow_methods} + allow_headers: {allow_headers} + allow_credentials: {allow_credentials} + allow_origin_regex: {allow_origin_regex} + expose_headers: {expose_headers} + max_age: {max_age} + """ + ) + super().__init__( + app, + allow_origins=allow_origins, + allow_methods=allow_methods, + allow_headers=allow_headers, + allow_credentials=allow_credentials, + allow_origin_regex=allow_origin_regex, + expose_headers=expose_headers, + max_age=max_age, + ) diff --git a/stac_fastapi/pgstac/tests/api/cors_support.py b/stac_fastapi/pgstac/tests/api/cors_support.py index 9f891a8d6..b0471c957 100644 --- a/stac_fastapi/pgstac/tests/api/cors_support.py +++ b/stac_fastapi/pgstac/tests/api/cors_support.py @@ -1,23 +1,45 @@ -from json import dumps -from os import environ, fdopen, path, sep -from tempfile import mkstemp +from os import environ from typing import Final -cors_config_location_key: Final = "CORS_CONFIG_LOCATION" -cors_permit_origin: Final = "http://cors.pass" -cors_deny_origin: Final = "http://cors.fail" +cors_origin_1: Final = "http://permit.one" +cors_origin_2: Final = "http://permit.two" +cors_origin_3: Final = "http://permit.three" +cors_origin_deny: Final = "http://deny.me" -def cors_enable(): - tmp_file, tmp_filename = mkstemp() - with fdopen(tmp_file, "w") as f: - f.write(dumps({"allow_origins": [cors_permit_origin]})) - environ[cors_config_location_key] = tmp_filename +def cors_permit_1(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_1 -def cors_disable() -> None: - environ.pop(cors_config_location_key, None) +def cors_permit_2(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_2 -def cors_missing(): - environ[cors_config_location_key] = path.join(path.abspath(sep), "missing.file") +def cors_permit_3(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_3 + + +def cors_permit_12(): + environ["CORS_ALLOW_ORIGINS"] = f"{cors_origin_1}|{cors_origin_2}" + + +def cors_permit_123_regex(): + environ["CORS_ALLOW_ORIGIN_REGEX"] = "http\\://permit\\..+" + + +def cors_deny(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_deny + + +def cors_disable_get(): + environ["CORS_ALLOW_METHODS"] = "HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH" + + +def cors_clear_config(): + environ.pop("CORS_ALLOW_ORIGINS", None) + environ.pop("CORS_ALLOW_METHODS", None) + environ.pop("CORS_ALLOW_HEADERS", None) + environ.pop("CORS_ALLOW_CREDENTIALS", None) + environ.pop("CORS_ALLOW_ORIGIN_REGEX", None) + environ.pop("CORS_EXPOSE_HEADERS", None) + environ.pop("CORS_MAX_AGE", None) diff --git a/stac_fastapi/pgstac/tests/api/test_api.py b/stac_fastapi/pgstac/tests/api/test_api.py index 7b52d16a6..7503bd0e0 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -1,20 +1,17 @@ from datetime import datetime, timedelta from http import HTTPStatus -from os import environ import pytest -from fastapi.middleware.cors import CORSMiddleware from tests.api.cors_support import ( - cors_config_location_key, - cors_deny_origin, - cors_disable, - cors_enable, - cors_missing, - cors_permit_origin, + cors_clear_config, + cors_deny, + cors_origin_1, + cors_origin_deny, + cors_permit_1, + cors_permit_12, + cors_permit_123_regex, ) -from stac_fastapi.api.middleware import MiddlewareConfig - STAC_CORE_ROUTES = [ "GET /", "GET /collections", @@ -37,7 +34,7 @@ def teardown_function(): - environ.pop(cors_config_location_key, None) + cors_clear_config() async def test_post_search_content_type(app_client): @@ -300,88 +297,3 @@ async def test_search_line_string_intersects( assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == 1 - - -@pytest.mark.parametrize("app_client", [{"setup_func": cors_disable}], indirect=True) -async def test_without_cors(app_client): - resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) - assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) - - -@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) -async def test_with_match_cors(app_client): - resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) - assert resp.status_code == HTTPStatus.OK - assert resp.headers["access-control-allow-origin"] == cors_permit_origin - - -@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) -async def test_with_mismatch_cors(app_client): - resp = await app_client.get("/", headers={"Origin": cors_deny_origin}) - assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) - - -@pytest.mark.parametrize("app_client", [{"setup_func": cors_missing}], indirect=True) -async def test_with_missing_config(app_client): - resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) - assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "app_client", - [ - { - "setup_func": cors_enable, - "middleware_configs": [ - MiddlewareConfig( - CORSMiddleware, {"allow_origins": ["http://different.origin"]} - ) - ], - } - ], - indirect=True, -) -async def test_with_existing_cors(app_client): - resp = await app_client.get("/", headers={"Origin": cors_permit_origin}) - assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) diff --git a/stac_fastapi/pgstac/tests/conftest.py b/stac_fastapi/pgstac/tests/conftest.py index 2adcf3279..29e7ff990 100644 --- a/stac_fastapi/pgstac/tests/conftest.py +++ b/stac_fastapi/pgstac/tests/conftest.py @@ -1,7 +1,7 @@ import asyncio import json import os -from typing import Callable, Dict, Optional +from typing import Callable, Dict import asyncpg import pytest @@ -11,7 +11,6 @@ from stac_pydantic import Collection, Item from stac_fastapi.api.app import StacApi -from stac_fastapi.api.middleware import MiddlewareConfig from stac_fastapi.api.models import create_get_request_model, create_post_request_model from stac_fastapi.extensions.core import ( FieldsExtension, @@ -92,7 +91,7 @@ async def pgstac(pg): await conn.close() -def _api_client_provider(middleware_configs: Optional[MiddlewareConfig] = []): +def _api_client_provider(): print("creating client with settings") extensions = [ @@ -112,7 +111,6 @@ def _api_client_provider(middleware_configs: Optional[MiddlewareConfig] = []): search_get_request_model=create_get_request_model(extensions), search_post_request_model=post_request_model, response_class=ORJSONResponse, - middlewares=middleware_configs, ) return api @@ -126,13 +124,11 @@ def api_client(pg): @pytest.fixture(scope="session") async def app_client(pg, request): # support custom behaviours driven by fixture caller - middleware_configs = [] if hasattr(request, "param"): setup_func = request.param.get("setup_func") if setup_func is not None: setup_func() - middleware_configs = request.param.get("middleware_configs", []) - app = _api_client_provider(middleware_configs=middleware_configs).app + app = _api_client_provider().app async with AsyncClient(app=app, base_url="http://test") as c: await connect_to_db(app) yield c diff --git a/stac_fastapi/sqlalchemy/tests/api/cors_support.py b/stac_fastapi/sqlalchemy/tests/api/cors_support.py index 9f891a8d6..b0471c957 100644 --- a/stac_fastapi/sqlalchemy/tests/api/cors_support.py +++ b/stac_fastapi/sqlalchemy/tests/api/cors_support.py @@ -1,23 +1,45 @@ -from json import dumps -from os import environ, fdopen, path, sep -from tempfile import mkstemp +from os import environ from typing import Final -cors_config_location_key: Final = "CORS_CONFIG_LOCATION" -cors_permit_origin: Final = "http://cors.pass" -cors_deny_origin: Final = "http://cors.fail" +cors_origin_1: Final = "http://permit.one" +cors_origin_2: Final = "http://permit.two" +cors_origin_3: Final = "http://permit.three" +cors_origin_deny: Final = "http://deny.me" -def cors_enable(): - tmp_file, tmp_filename = mkstemp() - with fdopen(tmp_file, "w") as f: - f.write(dumps({"allow_origins": [cors_permit_origin]})) - environ[cors_config_location_key] = tmp_filename +def cors_permit_1(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_1 -def cors_disable() -> None: - environ.pop(cors_config_location_key, None) +def cors_permit_2(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_2 -def cors_missing(): - environ[cors_config_location_key] = path.join(path.abspath(sep), "missing.file") +def cors_permit_3(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_3 + + +def cors_permit_12(): + environ["CORS_ALLOW_ORIGINS"] = f"{cors_origin_1}|{cors_origin_2}" + + +def cors_permit_123_regex(): + environ["CORS_ALLOW_ORIGIN_REGEX"] = "http\\://permit\\..+" + + +def cors_deny(): + environ["CORS_ALLOW_ORIGINS"] = cors_origin_deny + + +def cors_disable_get(): + environ["CORS_ALLOW_METHODS"] = "HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH" + + +def cors_clear_config(): + environ.pop("CORS_ALLOW_ORIGINS", None) + environ.pop("CORS_ALLOW_METHODS", None) + environ.pop("CORS_ALLOW_HEADERS", None) + environ.pop("CORS_ALLOW_CREDENTIALS", None) + environ.pop("CORS_ALLOW_ORIGIN_REGEX", None) + environ.pop("CORS_EXPOSE_HEADERS", None) + environ.pop("CORS_MAX_AGE", None) diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index 1e1736b03..6dc8511bf 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -1,20 +1,17 @@ from datetime import datetime, timedelta from http import HTTPStatus -from os import environ import pytest -from fastapi.middleware.cors import CORSMiddleware from tests.api.cors_support import ( - cors_config_location_key, - cors_deny_origin, - cors_disable, - cors_enable, - cors_missing, - cors_permit_origin, + cors_clear_config, + cors_deny, + cors_origin_1, + cors_origin_deny, + cors_permit_1, + cors_permit_12, + cors_permit_123_regex, ) -from stac_fastapi.api.middleware import MiddlewareConfig - from ..conftest import MockStarletteRequest STAC_CORE_ROUTES = [ @@ -39,7 +36,7 @@ def teardown_function(): - environ.pop(cors_config_location_key, None) + cors_clear_config() def test_post_search_content_type(app_client): @@ -322,86 +319,4 @@ def test_app_fields_extension_return_all_properties( if expected_prop in ("datetime", "created", "updated"): assert feature["properties"][expected_prop][0:19] == expected_value[0:19] else: - assert feature["properties"][expected_prop] == expected_value -@pytest.mark.parametrize("app_client", [{"setup_func": cors_disable}], indirect=True) -def test_without_cors(app_client): - resp = app_client.get("/", headers={"Origin": cors_permit_origin}) - assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) - - -@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) -def test_with_match_cors(app_client): - resp = app_client.get("/", headers={"Origin": cors_permit_origin}) - assert resp.status_code == HTTPStatus.OK - assert resp.headers["access-control-allow-origin"] == cors_permit_origin - - -@pytest.mark.parametrize("app_client", [{"setup_func": cors_enable}], indirect=True) -def test_with_mismatch_cors(app_client): - resp = app_client.get("/", headers={"Origin": cors_deny_origin}) - assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) - - -@pytest.mark.parametrize("app_client", [{"setup_func": cors_missing}], indirect=True) -def test_with_missing_config(app_client): - resp = app_client.get("/", headers={"Origin": cors_permit_origin}) - assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) - - -@pytest.mark.parametrize( - "app_client", - [ - { - "setup_func": cors_enable, - "middleware_configs": [ - MiddlewareConfig( - CORSMiddleware, {"allow_origins": ["http://different.origin"]} - ) - ], - } - ], - indirect=True, -) -def test_with_existing_cors(app_client): - resp = app_client.get("/", headers={"Origin": cors_permit_origin}) - assert resp.status_code == HTTPStatus.OK - assert ( - len( - [ - header - for header in resp.headers - if header.startswith("access-control-allow-") - ] - ) - == 0 - ) + assert feature["properties"][expected_prop] == expected_value \ No newline at end of file diff --git a/stac_fastapi/sqlalchemy/tests/conftest.py b/stac_fastapi/sqlalchemy/tests/conftest.py index 2f1889327..1ef44f62b 100644 --- a/stac_fastapi/sqlalchemy/tests/conftest.py +++ b/stac_fastapi/sqlalchemy/tests/conftest.py @@ -1,12 +1,11 @@ import json import os -from typing import Callable, Dict, Optional +from typing import Callable, Dict import pytest from starlette.testclient import TestClient from stac_fastapi.api.app import StacApi -from stac_fastapi.api.middleware import MiddlewareConfig from stac_fastapi.api.models import create_request_model from stac_fastapi.extensions.core import ( ContextExtension, @@ -105,9 +104,7 @@ def postgres_bulk_transactions(db_session): return BulkTransactionsClient(session=db_session) -def _api_client_provider( - db_session, middleware_configs: Optional[MiddlewareConfig] = [] -): +def _api_client_provider(db_session): settings = SqlalchemySettings() extensions = [ TransactionExtension( @@ -144,7 +141,6 @@ def _api_client_provider( extensions=extensions, search_get_request_model=get_request_model, search_post_request_model=post_request_model, - middlewares=middleware_configs, ) @@ -156,16 +152,12 @@ def api_client(db_session): @pytest.fixture def app_client(db_session, load_test_data, postgres_transactions, request): # support custom behaviours driven by fixture caller - middleware_configs = [] if hasattr(request, "param"): setup_func = request.param.get("setup_func") if setup_func is not None: setup_func() - middleware_configs = request.param.get("middleware_configs", []) coll = load_test_data("test_collection.json") postgres_transactions.create_collection(coll, request=MockStarletteRequest) - with TestClient( - _api_client_provider(db_session, middleware_configs=middleware_configs).app - ) as test_app: + with TestClient(_api_client_provider(db_session).app) as test_app: yield test_app From 76f467c39e8e6522137f294689365b9326eb49cd Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Tue, 1 Feb 2022 11:32:20 -0800 Subject: [PATCH 18/25] feature/1 updated documentation --- CHANGES.md | 1 + docs/tips-and-tricks.md | 23 +++-------------------- 2 files changed, 4 insertions(+), 20 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index a76ffe705..f530b67d7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -8,6 +8,7 @@ * Ability to POST an ItemCollection to the collections/{collectionId}/items route. ([#367](https://github.com/stac-utils/stac-fastapi/pull/367)) * Add STAC API - Collections conformance class. ([383](https://github.com/stac-utils/stac-fastapi/pull/383)) * Added ability to configure CORS middleware via JSON configuration file and environment variable, rather than having to modify code. ([341](https://github.com/stac-utils/stac-fastapi/pull/341)) +* Added ability to configure CORS middleware via environment variables ([#341](https://github.com/stac-utils/stac-fastapi/pull/341)) ### Changed diff --git a/docs/tips-and-tricks.md b/docs/tips-and-tricks.md index 997edfd0e..294274c77 100644 --- a/docs/tips-and-tricks.md +++ b/docs/tips-and-tricks.md @@ -5,30 +5,13 @@ This page contains a few 'tips and tricks' for getting stac-fastapi working in v CORS (Cross-Origin Resource Sharing) support may be required to use stac-fastapi in certain situations. For example, if you are running [stac-browser](https://github.com/radiantearth/stac-browser) to browse the STAC catalog created by stac-fastapi, then you will need to enable CORS support. -To do this, create a JSON configuration file whose schema matches the options described in the [FastAPI docs](https://fastapi.tiangolo.com/tutorial/cors/), e.g. +To do this, configure environment variables for the configuration options described in [FastAPI docs](https://fastapi.tiangolo.com/tutorial/cors/), using a `CORS_` prefix and upper-case, e.g. ``` -{ - "allow_origins": ["*"], - "allow_methods": ["*"] -} +CORS_ALLOW_ORIGINS=http://domain.one|http://domain.two +CORS_ALLOW_METHODS="*" ``` -Deploy this file to a location accessible by stac-fastapi, e.g. in Dockerfile - -``` -RUN mkdir /config -COPY cors.json /config/cors.json -``` - -Set an environment variable `CORS_CONFIG_LOCATION` pointing to this file, e.g. in Dockerfile - -``` -ENV CORS_CONFIG_LOCATION=/config/cors.json -``` - -If needed, you can edit the `allow_origins` parameter to only allow CORS requests from specific origins. - ## Enable the Context extension The Context STAC extension provides information on the number of items matched and returned from a STAC search. This is required by various other STAC-related tools, such as the pystac command-line client. To enable the extension, edit `stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py` (or the equivalent in the `pgstac` folder) and add the following import: From 0dc532fcf3d21f1668e744920fb8a6ae6488aa38 Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Tue, 1 Feb 2022 11:33:14 -0800 Subject: [PATCH 19/25] feature/1 updated documentation --- docs/tips-and-tricks.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/tips-and-tricks.md b/docs/tips-and-tricks.md index 294274c77..4e763d107 100644 --- a/docs/tips-and-tricks.md +++ b/docs/tips-and-tricks.md @@ -9,7 +9,6 @@ To do this, configure environment variables for the configuration options descri ``` CORS_ALLOW_ORIGINS=http://domain.one|http://domain.two -CORS_ALLOW_METHODS="*" ``` ## Enable the Context extension From 205eb6f86770ef2ac8915af395671ff4655d9120 Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Wed, 2 Feb 2022 09:54:18 -0800 Subject: [PATCH 20/25] feature/1 follow pydantic configuration standard --- docs/tips-and-tricks.md | 9 ++- stac_fastapi/api/stac_fastapi/api/__init__.py | 5 ++ stac_fastapi/api/stac_fastapi/api/config.py | 55 ++++--------------- .../api/stac_fastapi/api/middleware.py | 30 ++++------ stac_fastapi/pgstac/tests/api/cors_support.py | 45 ++++++++++----- .../sqlalchemy/tests/api/cors_support.py | 45 ++++++++++----- 6 files changed, 92 insertions(+), 97 deletions(-) diff --git a/docs/tips-and-tricks.md b/docs/tips-and-tricks.md index 4e763d107..3d41bf1c3 100644 --- a/docs/tips-and-tricks.md +++ b/docs/tips-and-tricks.md @@ -5,10 +5,13 @@ This page contains a few 'tips and tricks' for getting stac-fastapi working in v CORS (Cross-Origin Resource Sharing) support may be required to use stac-fastapi in certain situations. For example, if you are running [stac-browser](https://github.com/radiantearth/stac-browser) to browse the STAC catalog created by stac-fastapi, then you will need to enable CORS support. -To do this, configure environment variables for the configuration options described in [FastAPI docs](https://fastapi.tiangolo.com/tutorial/cors/), using a `CORS_` prefix and upper-case, e.g. - +To do this, configure environment variables for the configuration options described in [FastAPI docs](https://fastapi.tiangolo.com/tutorial/cors/) using a `cors_` prefix e.g. +``` +cors_allow_credentials=true [or 1] +``` +Sequences, such as `allow_origins`, should be in JSON format e.g. ``` -CORS_ALLOW_ORIGINS=http://domain.one|http://domain.two +cors_allow_origins='["http://domain.one", "http://domain.two"]' ``` ## Enable the Context extension diff --git a/stac_fastapi/api/stac_fastapi/api/__init__.py b/stac_fastapi/api/stac_fastapi/api/__init__.py index df6f6249b..b616b5927 100644 --- a/stac_fastapi/api/stac_fastapi/api/__init__.py +++ b/stac_fastapi/api/stac_fastapi/api/__init__.py @@ -1 +1,6 @@ """api submodule.""" +from typing import Final + +from stac_fastapi.api.config import Settings + +settings: Final = Settings() diff --git a/stac_fastapi/api/stac_fastapi/api/config.py b/stac_fastapi/api/stac_fastapi/api/config.py index 9f1b8c90a..933023d37 100644 --- a/stac_fastapi/api/stac_fastapi/api/config.py +++ b/stac_fastapi/api/stac_fastapi/api/config.py @@ -1,10 +1,10 @@ """Application settings.""" import enum -import re from logging import getLogger -from os import environ from typing import Final, Sequence +from pydantic import BaseSettings, Field + logger: Final = getLogger(__file__) @@ -30,46 +30,13 @@ class AddOns(enum.Enum): bulk_transaction = "bulk-transaction" -def env_to_sequence( - env_var: str, default: Sequence[str], sequence_separator: str = "|" -) -> Sequence[str]: - """Retrieve a sequence of values from an env var string, or default if missing.""" - if env_var in environ: - if re.search(re.escape(sequence_separator), environ[env_var]): - return tuple( - [part for part in environ[env_var].split(sequence_separator) if part] - ) - else: - return (environ[env_var],) - else: - return default - - -def env_to_str(env_var: str, default: str) -> str: - """Retrieve a string from an env var, or default if missing.""" - if env_var in environ: - return environ[env_var] - else: - return default - - -def env_to_bool(env_var: str, default: bool) -> bool: - """Retrieve a bool from an env var, or default if missing.""" - if env_var in environ: - if re.match("^(true|1)$", environ[env_var], re.IGNORECASE): - return True - if re.match("^(false|0)$", environ[env_var], re.IGNORECASE): - return False - logger.warning(f"{env_var} set but not a valid bool") - return default - +class Settings(BaseSettings): + """API settings.""" -def env_to_int(env_var: str, default: int) -> int: - """Retrieve an int from an env var, or default if missing.""" - if env_var in environ: - value = environ[env_var].strip() - if value.isdigit(): - return int(value) - else: - logger.warning(f"{env_var} set but not a valid int") - return default + allow_origins: Sequence[str] = Field(("*",), env="cors_allow_origins") + allow_methods: Sequence[str] = Field(("*",), env="cors_allow_methods") + allow_headers: Sequence[str] = Field(("*",), env="cors_allow_headers") + allow_credentials: bool = Field(False, env="cors_allow_credentials") + allow_origin_regex: str = Field(None, env="cors_allow_origin_regex") + expose_headers: Sequence[str] = Field(("*",), env="cors_expose_headers") + max_age: int = Field(600, env="cors_max_age") diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index d1974ff7c..ffcbcde23 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -10,7 +10,7 @@ from starlette.routing import Match from starlette.types import ASGIApp -from stac_fastapi.api.config import env_to_bool, env_to_int, env_to_sequence, env_to_str +from stac_fastapi.api import settings logger: Final = getLogger(__file__) @@ -39,7 +39,7 @@ async def _middleware(request: Request, call_next): class CORSMiddleware(cors.CORSMiddleware): - """Starlette CORS Middleware with default.""" + """Starlette CORS Middleware with configuration.""" def __init__( self, @@ -54,41 +54,31 @@ def __init__( ) -> None: """Create CORSMiddleware Object.""" allow_origins = ( - env_to_sequence("CORS_ALLOW_ORIGINS", ("*",)) - if allow_origins is None - else allow_origins + settings.allow_origins if allow_origins is None else allow_origins ) allow_methods = ( - env_to_sequence("CORS_ALLOW_METHODS", ("*",)) - if allow_methods is None - else allow_methods + settings.allow_methods if allow_methods is None else allow_methods ) allow_headers = ( - env_to_sequence("CORS_ALLOW_HEADERS", ("*",)) - if allow_headers is None - else allow_headers + settings.allow_headers if allow_headers is None else allow_headers ) allow_credentials = ( - env_to_bool("CORS_ALLOW_CREDENTIALS", False) + settings.allow_credentials if allow_credentials is None else allow_credentials ) allow_origin_regex = ( - env_to_str("CORS_ALLOW_ORIGIN_REGEX", None) + settings.allow_origin_regex if allow_origin_regex is None else allow_origin_regex ) if allow_origin_regex is not None: - logger.info( - "CORS_ALLOW_ORIGIN_REGEX present and will override CORS_ALLOW_ORIGINS" - ) + logger.info("allow_origin_regex present and will override allow_origins") allow_origins = "" expose_headers = ( - env_to_sequence("CORS_EXPOSE_HEADERS", ("*",)) - if expose_headers is None - else expose_headers + settings.expose_headers if expose_headers is None else expose_headers ) - max_age = env_to_int("CORS_MAX_AGE", 600) if max_age is None else max_age + max_age = settings.max_age if max_age is None else max_age logger.debug( f""" CORS configuration diff --git a/stac_fastapi/pgstac/tests/api/cors_support.py b/stac_fastapi/pgstac/tests/api/cors_support.py index b0471c957..bf6996ed0 100644 --- a/stac_fastapi/pgstac/tests/api/cors_support.py +++ b/stac_fastapi/pgstac/tests/api/cors_support.py @@ -1,6 +1,10 @@ -from os import environ +from copy import deepcopy +from json import dumps from typing import Final +from stac_fastapi.api import settings + +settings_fallback = deepcopy(settings) cors_origin_1: Final = "http://permit.one" cors_origin_2: Final = "http://permit.two" cors_origin_3: Final = "http://permit.three" @@ -8,38 +12,49 @@ def cors_permit_1(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_1 + settings.allow_origins = dumps((cors_origin_1,)) def cors_permit_2(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_2 + settings.allow_origins = dumps((cors_origin_2,)) def cors_permit_3(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_3 + settings.allow_origins = dumps((cors_origin_3,)) def cors_permit_12(): - environ["CORS_ALLOW_ORIGINS"] = f"{cors_origin_1}|{cors_origin_2}" + settings.allow_origins = dumps((cors_origin_1, cors_origin_2)) def cors_permit_123_regex(): - environ["CORS_ALLOW_ORIGIN_REGEX"] = "http\\://permit\\..+" + settings.allow_origin_regex = "http\\://permit\\..+" def cors_deny(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_deny + settings.allow_origins = dumps((cors_origin_deny,)) def cors_disable_get(): - environ["CORS_ALLOW_METHODS"] = "HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH" + settings.allow_methods = dumps( + ( + "HEAD", + "POST", + "PUT", + "DELETE", + "CONNECT", + "OPTIONS", + "TRACE", + "PATCH", + ) + ) def cors_clear_config(): - environ.pop("CORS_ALLOW_ORIGINS", None) - environ.pop("CORS_ALLOW_METHODS", None) - environ.pop("CORS_ALLOW_HEADERS", None) - environ.pop("CORS_ALLOW_CREDENTIALS", None) - environ.pop("CORS_ALLOW_ORIGIN_REGEX", None) - environ.pop("CORS_EXPOSE_HEADERS", None) - environ.pop("CORS_MAX_AGE", None) + settings.allow_origins = settings_fallback.allow_origins + settings.allow_methods = settings_fallback.allow_methods + settings.allow_headers = settings_fallback.allow_headers + settings.allow_credentials = settings_fallback.allow_credentials + settings.allow_origin_regex = settings_fallback.allow_origin_regex + settings.expose_headers = settings_fallback.expose_headers + settings.max_age = settings_fallback.max_age diff --git a/stac_fastapi/sqlalchemy/tests/api/cors_support.py b/stac_fastapi/sqlalchemy/tests/api/cors_support.py index b0471c957..bf6996ed0 100644 --- a/stac_fastapi/sqlalchemy/tests/api/cors_support.py +++ b/stac_fastapi/sqlalchemy/tests/api/cors_support.py @@ -1,6 +1,10 @@ -from os import environ +from copy import deepcopy +from json import dumps from typing import Final +from stac_fastapi.api import settings + +settings_fallback = deepcopy(settings) cors_origin_1: Final = "http://permit.one" cors_origin_2: Final = "http://permit.two" cors_origin_3: Final = "http://permit.three" @@ -8,38 +12,49 @@ def cors_permit_1(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_1 + settings.allow_origins = dumps((cors_origin_1,)) def cors_permit_2(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_2 + settings.allow_origins = dumps((cors_origin_2,)) def cors_permit_3(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_3 + settings.allow_origins = dumps((cors_origin_3,)) def cors_permit_12(): - environ["CORS_ALLOW_ORIGINS"] = f"{cors_origin_1}|{cors_origin_2}" + settings.allow_origins = dumps((cors_origin_1, cors_origin_2)) def cors_permit_123_regex(): - environ["CORS_ALLOW_ORIGIN_REGEX"] = "http\\://permit\\..+" + settings.allow_origin_regex = "http\\://permit\\..+" def cors_deny(): - environ["CORS_ALLOW_ORIGINS"] = cors_origin_deny + settings.allow_origins = dumps((cors_origin_deny,)) def cors_disable_get(): - environ["CORS_ALLOW_METHODS"] = "HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH" + settings.allow_methods = dumps( + ( + "HEAD", + "POST", + "PUT", + "DELETE", + "CONNECT", + "OPTIONS", + "TRACE", + "PATCH", + ) + ) def cors_clear_config(): - environ.pop("CORS_ALLOW_ORIGINS", None) - environ.pop("CORS_ALLOW_METHODS", None) - environ.pop("CORS_ALLOW_HEADERS", None) - environ.pop("CORS_ALLOW_CREDENTIALS", None) - environ.pop("CORS_ALLOW_ORIGIN_REGEX", None) - environ.pop("CORS_EXPOSE_HEADERS", None) - environ.pop("CORS_MAX_AGE", None) + settings.allow_origins = settings_fallback.allow_origins + settings.allow_methods = settings_fallback.allow_methods + settings.allow_headers = settings_fallback.allow_headers + settings.allow_credentials = settings_fallback.allow_credentials + settings.allow_origin_regex = settings_fallback.allow_origin_regex + settings.expose_headers = settings_fallback.expose_headers + settings.max_age = settings_fallback.max_age From bfffdddd6688979295d95e0f24ce062e63713fc6 Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Wed, 2 Feb 2022 14:02:57 -0800 Subject: [PATCH 21/25] feature/1 fix docs build --- stac_fastapi/api/stac_fastapi/api/__init__.py | 5 ----- stac_fastapi/api/stac_fastapi/api/config.py | 3 +++ stac_fastapi/api/stac_fastapi/api/middleware.py | 2 +- stac_fastapi/pgstac/tests/api/cors_support.py | 2 +- stac_fastapi/sqlalchemy/tests/api/cors_support.py | 2 +- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/stac_fastapi/api/stac_fastapi/api/__init__.py b/stac_fastapi/api/stac_fastapi/api/__init__.py index b616b5927..df6f6249b 100644 --- a/stac_fastapi/api/stac_fastapi/api/__init__.py +++ b/stac_fastapi/api/stac_fastapi/api/__init__.py @@ -1,6 +1 @@ """api submodule.""" -from typing import Final - -from stac_fastapi.api.config import Settings - -settings: Final = Settings() diff --git a/stac_fastapi/api/stac_fastapi/api/config.py b/stac_fastapi/api/stac_fastapi/api/config.py index 933023d37..96114b1f4 100644 --- a/stac_fastapi/api/stac_fastapi/api/config.py +++ b/stac_fastapi/api/stac_fastapi/api/config.py @@ -40,3 +40,6 @@ class Settings(BaseSettings): allow_origin_regex: str = Field(None, env="cors_allow_origin_regex") expose_headers: Sequence[str] = Field(("*",), env="cors_expose_headers") max_age: int = Field(600, env="cors_max_age") + + +settings: Final = Settings() diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index ffcbcde23..8dd8be2a8 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -10,7 +10,7 @@ from starlette.routing import Match from starlette.types import ASGIApp -from stac_fastapi.api import settings +from stac_fastapi.api.config import settings logger: Final = getLogger(__file__) diff --git a/stac_fastapi/pgstac/tests/api/cors_support.py b/stac_fastapi/pgstac/tests/api/cors_support.py index bf6996ed0..9f9c303e6 100644 --- a/stac_fastapi/pgstac/tests/api/cors_support.py +++ b/stac_fastapi/pgstac/tests/api/cors_support.py @@ -2,7 +2,7 @@ from json import dumps from typing import Final -from stac_fastapi.api import settings +from stac_fastapi.api.config import settings settings_fallback = deepcopy(settings) cors_origin_1: Final = "http://permit.one" diff --git a/stac_fastapi/sqlalchemy/tests/api/cors_support.py b/stac_fastapi/sqlalchemy/tests/api/cors_support.py index bf6996ed0..9f9c303e6 100644 --- a/stac_fastapi/sqlalchemy/tests/api/cors_support.py +++ b/stac_fastapi/sqlalchemy/tests/api/cors_support.py @@ -2,7 +2,7 @@ from json import dumps from typing import Final -from stac_fastapi.api import settings +from stac_fastapi.api.config import settings settings_fallback = deepcopy(settings) cors_origin_1: Final = "http://permit.one" From a0fa5fc7b84c71a7cc11651e3864acdb5eeebc1f Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Fri, 18 Feb 2022 12:13:23 -0800 Subject: [PATCH 22/25] feature/1 add CORS tests to api tests --- Makefile | 6 +- docker-compose.yml | 13 +++ stac_fastapi/api/tests/__init__.py | 0 stac_fastapi/api/tests/cors_support.py | 60 ++++++++++ stac_fastapi/api/tests/test_cors.py | 76 +++++++++++++ .../api/tests/test_route_dependencies.py | 106 ++++++++++++++++++ stac_fastapi/api/tests/util.py | 37 ++++++ 7 files changed, 297 insertions(+), 1 deletion(-) create mode 100644 stac_fastapi/api/tests/__init__.py create mode 100644 stac_fastapi/api/tests/cors_support.py create mode 100644 stac_fastapi/api/tests/test_cors.py create mode 100644 stac_fastapi/api/tests/test_route_dependencies.py create mode 100644 stac_fastapi/api/tests/util.py diff --git a/Makefile b/Makefile index fe2b6fe32..d38197a3b 100644 --- a/Makefile +++ b/Makefile @@ -46,6 +46,10 @@ test-sqlalchemy: run-joplin-sqlalchemy test-pgstac: $(run_pgstac) /bin/bash -c 'export && ./scripts/wait-for-it.sh database:5432 && cd /app/stac_fastapi/pgstac/tests/ && pytest -vvv' +.PHONY: test-api +test-api: + docker-compose run api-tester + .PHONY: run-database run-database: docker-compose run --rm database @@ -59,7 +63,7 @@ run-joplin-pgstac: docker-compose run --rm loadjoplin-pgstac .PHONY: test -test: test-sqlalchemy test-pgstac +test: test-sqlalchemy test-pgstac test-api .PHONY: pybase-install pybase-install: diff --git a/docker-compose.yml b/docker-compose.yml index 996bb6593..bc8feba62 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -118,6 +118,19 @@ services: - database - app-pgstac + api-tester: + image: stac-utils/stac-fastapi + build: + context: . + dockerfile: Dockerfile + profiles: + # prevent tester from starting with `docker-compose up` + - api-test + working_dir: /app/stac_fastapi/api + volumes: + - ./:/app + command: pytest -svvv + networks: default: name: stac-fastapi-network diff --git a/stac_fastapi/api/tests/__init__.py b/stac_fastapi/api/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/stac_fastapi/api/tests/cors_support.py b/stac_fastapi/api/tests/cors_support.py new file mode 100644 index 000000000..9f9c303e6 --- /dev/null +++ b/stac_fastapi/api/tests/cors_support.py @@ -0,0 +1,60 @@ +from copy import deepcopy +from json import dumps +from typing import Final + +from stac_fastapi.api.config import settings + +settings_fallback = deepcopy(settings) +cors_origin_1: Final = "http://permit.one" +cors_origin_2: Final = "http://permit.two" +cors_origin_3: Final = "http://permit.three" +cors_origin_deny: Final = "http://deny.me" + + +def cors_permit_1(): + settings.allow_origins = dumps((cors_origin_1,)) + + +def cors_permit_2(): + settings.allow_origins = dumps((cors_origin_2,)) + + +def cors_permit_3(): + settings.allow_origins = dumps((cors_origin_3,)) + + +def cors_permit_12(): + settings.allow_origins = dumps((cors_origin_1, cors_origin_2)) + + +def cors_permit_123_regex(): + settings.allow_origin_regex = "http\\://permit\\..+" + + +def cors_deny(): + settings.allow_origins = dumps((cors_origin_deny,)) + + +def cors_disable_get(): + settings.allow_methods = dumps( + ( + "HEAD", + "POST", + "PUT", + "DELETE", + "CONNECT", + "OPTIONS", + "TRACE", + "PATCH", + ) + ) + + +def cors_clear_config(): + settings.allow_origins = settings_fallback.allow_origins + settings.allow_methods = settings_fallback.allow_methods + settings.allow_headers = settings_fallback.allow_headers + settings.allow_credentials = settings_fallback.allow_credentials + settings.allow_origin_regex = settings_fallback.allow_origin_regex + settings.expose_headers = settings_fallback.expose_headers + settings.max_age = settings_fallback.max_age diff --git a/stac_fastapi/api/tests/test_cors.py b/stac_fastapi/api/tests/test_cors.py new file mode 100644 index 000000000..6dc7a2b0e --- /dev/null +++ b/stac_fastapi/api/tests/test_cors.py @@ -0,0 +1,76 @@ +from http import HTTPStatus + +from starlette.testclient import TestClient +from tests.cors_support import ( + cors_clear_config, + cors_deny, + cors_origin_1, + cors_origin_deny, + cors_permit_1, + cors_permit_12, + cors_permit_123_regex, +) +from tests.util import build_api + +from stac_fastapi.extensions.core import TokenPaginationExtension + + +def teardown_function(): + cors_clear_config() + + +def _get_api(): + return build_api([TokenPaginationExtension()]) + + +def test_with_default_cors_origin(): + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == "*" + + +def test_with_match_cors_single(): + cors_permit_1() + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == cors_origin_1 + + +def test_with_match_cors_double(): + cors_permit_12() + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == cors_origin_1 + + +def test_with_match_cors_all_regex_match(): + cors_permit_123_regex() + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == cors_origin_1 + + +def test_with_match_cors_all_regex_mismatch(): + cors_permit_123_regex() + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_deny}) + assert resp.status_code == HTTPStatus.OK + assert "access-control-allow-origin" not in resp.headers + + +def test_with_mismatch_cors_origin(): + cors_deny() + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert "access-control-allow-origin" not in resp.headers diff --git a/stac_fastapi/api/tests/test_route_dependencies.py b/stac_fastapi/api/tests/test_route_dependencies.py new file mode 100644 index 000000000..ba2281773 --- /dev/null +++ b/stac_fastapi/api/tests/test_route_dependencies.py @@ -0,0 +1,106 @@ +from fastapi import Depends, HTTPException, security, status +from starlette.testclient import TestClient +from tests.util import build_api + +from stac_fastapi.extensions.core import TokenPaginationExtension, TransactionExtension +from stac_fastapi.types import config, core + + +class TestRouteDependencies: + @staticmethod + def _get_extensions(): + return [ + TransactionExtension( + client=DummyTransactionsClient(), settings=config.ApiSettings() + ), + TokenPaginationExtension(), + ] + + @staticmethod + def _assert_dependency_applied(api, routes): + with TestClient(api.app) as client: + for route in routes: + response = getattr(client, route["method"].lower())(route["path"]) + assert ( + response.status_code == 401 + ), "Unauthenticated requests should be rejected" + assert response.json() == {"detail": "Not authenticated"} + + make_request = getattr(client, route["method"].lower()) + path = route["path"].format( + collectionId="test_collection", itemId="test_item" + ) + response = make_request( + path, + auth=("bob", "dobbs"), + data='{"dummy": "payload"}', + headers={"content-type": "application/json"}, + ) + assert ( + response.status_code == 200 + ), "Authenticated requests should be accepted" + assert response.json() == "dummy response" + + def test_build_api_with_route_dependencies(self): + routes = [ + {"path": "/collections", "method": "POST"}, + {"path": "/collections", "method": "PUT"}, + {"path": "/collections/{collectionId}", "method": "DELETE"}, + {"path": "/collections/{collectionId}/items", "method": "POST"}, + {"path": "/collections/{collectionId}/items", "method": "PUT"}, + {"path": "/collections/{collectionId}/items/{itemId}", "method": "DELETE"}, + ] + dependencies = [Depends(must_be_bob)] + api = build_api( + TestRouteDependencies._get_extensions(), + route_dependencies=[(routes, dependencies)], + ) + self._assert_dependency_applied(api, routes) + + def test_add_route_dependencies_after_building_api(self): + routes = [ + {"path": "/collections", "method": "POST"}, + {"path": "/collections", "method": "PUT"}, + {"path": "/collections/{collectionId}", "method": "DELETE"}, + {"path": "/collections/{collectionId}/items", "method": "POST"}, + {"path": "/collections/{collectionId}/items", "method": "PUT"}, + {"path": "/collections/{collectionId}/items/{itemId}", "method": "DELETE"}, + ] + api = build_api(TestRouteDependencies._get_extensions()) + api.add_route_dependencies(scopes=routes, dependencies=[Depends(must_be_bob)]) + self._assert_dependency_applied(api, routes) + + +class DummyTransactionsClient(core.BaseTransactionsClient): + """Defines a pattern for implementing the STAC transaction extension.""" + + def create_item(self, *args, **kwargs): + return "dummy response" + + def update_item(self, *args, **kwargs): + return "dummy response" + + def delete_item(self, *args, **kwargs): + return "dummy response" + + def create_collection(self, *args, **kwargs): + return "dummy response" + + def update_collection(self, *args, **kwargs): + return "dummy response" + + def delete_collection(self, *args, **kwargs): + return "dummy response" + + +def must_be_bob( + credentials: security.HTTPBasicCredentials = Depends(security.HTTPBasic()), +): + if credentials.username == "bob": + return True + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="You're not Bob", + headers={"WWW-Authenticate": "Basic"}, + ) diff --git a/stac_fastapi/api/tests/util.py b/stac_fastapi/api/tests/util.py new file mode 100644 index 000000000..aa4aea972 --- /dev/null +++ b/stac_fastapi/api/tests/util.py @@ -0,0 +1,37 @@ +from typing import List, Type + +from stac_fastapi.api.app import StacApi +from stac_fastapi.types import config, core +from stac_fastapi.types.extension import ApiExtension + + +class DummyCoreClient(core.BaseCoreClient): + def all_collections(self, *args, **kwargs): + ... + + def get_collection(self, *args, **kwargs): + ... + + def get_item(self, *args, **kwargs): + ... + + def get_search(self, *args, **kwargs): + ... + + def post_search(self, *args, **kwargs): + ... + + def item_collection(self, *args, **kwargs): + ... + + +def build_api(extensions: List[Type[ApiExtension]] = [], **overrides): + settings = config.ApiSettings() + return StacApi( + **{ + "settings": settings, + "client": DummyCoreClient(), + "extensions": extensions, + **overrides, + } + ) From 99fdd858db17873a555176becdc9939281d6678f Mon Sep 17 00:00:00 2001 From: captaincoordinates Date: Fri, 18 Feb 2022 12:36:30 -0800 Subject: [PATCH 23/25] feature/1 removed unnecessary tests --- stac_fastapi/pgstac/tests/api/cors_support.py | 60 ------------------- stac_fastapi/pgstac/tests/api/test_api.py | 10 ---- .../sqlalchemy/tests/api/cors_support.py | 60 ------------------- stac_fastapi/sqlalchemy/tests/api/test_api.py | 18 +----- 4 files changed, 1 insertion(+), 147 deletions(-) delete mode 100644 stac_fastapi/pgstac/tests/api/cors_support.py delete mode 100644 stac_fastapi/sqlalchemy/tests/api/cors_support.py diff --git a/stac_fastapi/pgstac/tests/api/cors_support.py b/stac_fastapi/pgstac/tests/api/cors_support.py deleted file mode 100644 index 9f9c303e6..000000000 --- a/stac_fastapi/pgstac/tests/api/cors_support.py +++ /dev/null @@ -1,60 +0,0 @@ -from copy import deepcopy -from json import dumps -from typing import Final - -from stac_fastapi.api.config import settings - -settings_fallback = deepcopy(settings) -cors_origin_1: Final = "http://permit.one" -cors_origin_2: Final = "http://permit.two" -cors_origin_3: Final = "http://permit.three" -cors_origin_deny: Final = "http://deny.me" - - -def cors_permit_1(): - settings.allow_origins = dumps((cors_origin_1,)) - - -def cors_permit_2(): - settings.allow_origins = dumps((cors_origin_2,)) - - -def cors_permit_3(): - settings.allow_origins = dumps((cors_origin_3,)) - - -def cors_permit_12(): - settings.allow_origins = dumps((cors_origin_1, cors_origin_2)) - - -def cors_permit_123_regex(): - settings.allow_origin_regex = "http\\://permit\\..+" - - -def cors_deny(): - settings.allow_origins = dumps((cors_origin_deny,)) - - -def cors_disable_get(): - settings.allow_methods = dumps( - ( - "HEAD", - "POST", - "PUT", - "DELETE", - "CONNECT", - "OPTIONS", - "TRACE", - "PATCH", - ) - ) - - -def cors_clear_config(): - settings.allow_origins = settings_fallback.allow_origins - settings.allow_methods = settings_fallback.allow_methods - settings.allow_headers = settings_fallback.allow_headers - settings.allow_credentials = settings_fallback.allow_credentials - settings.allow_origin_regex = settings_fallback.allow_origin_regex - settings.expose_headers = settings_fallback.expose_headers - settings.max_age = settings_fallback.max_age diff --git a/stac_fastapi/pgstac/tests/api/test_api.py b/stac_fastapi/pgstac/tests/api/test_api.py index 7503bd0e0..17dbba850 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -1,16 +1,6 @@ from datetime import datetime, timedelta -from http import HTTPStatus import pytest -from tests.api.cors_support import ( - cors_clear_config, - cors_deny, - cors_origin_1, - cors_origin_deny, - cors_permit_1, - cors_permit_12, - cors_permit_123_regex, -) STAC_CORE_ROUTES = [ "GET /", diff --git a/stac_fastapi/sqlalchemy/tests/api/cors_support.py b/stac_fastapi/sqlalchemy/tests/api/cors_support.py deleted file mode 100644 index 9f9c303e6..000000000 --- a/stac_fastapi/sqlalchemy/tests/api/cors_support.py +++ /dev/null @@ -1,60 +0,0 @@ -from copy import deepcopy -from json import dumps -from typing import Final - -from stac_fastapi.api.config import settings - -settings_fallback = deepcopy(settings) -cors_origin_1: Final = "http://permit.one" -cors_origin_2: Final = "http://permit.two" -cors_origin_3: Final = "http://permit.three" -cors_origin_deny: Final = "http://deny.me" - - -def cors_permit_1(): - settings.allow_origins = dumps((cors_origin_1,)) - - -def cors_permit_2(): - settings.allow_origins = dumps((cors_origin_2,)) - - -def cors_permit_3(): - settings.allow_origins = dumps((cors_origin_3,)) - - -def cors_permit_12(): - settings.allow_origins = dumps((cors_origin_1, cors_origin_2)) - - -def cors_permit_123_regex(): - settings.allow_origin_regex = "http\\://permit\\..+" - - -def cors_deny(): - settings.allow_origins = dumps((cors_origin_deny,)) - - -def cors_disable_get(): - settings.allow_methods = dumps( - ( - "HEAD", - "POST", - "PUT", - "DELETE", - "CONNECT", - "OPTIONS", - "TRACE", - "PATCH", - ) - ) - - -def cors_clear_config(): - settings.allow_origins = settings_fallback.allow_origins - settings.allow_methods = settings_fallback.allow_methods - settings.allow_headers = settings_fallback.allow_headers - settings.allow_credentials = settings_fallback.allow_credentials - settings.allow_origin_regex = settings_fallback.allow_origin_regex - settings.expose_headers = settings_fallback.expose_headers - settings.max_age = settings_fallback.max_age diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index 6dc8511bf..0abd7cb00 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -1,16 +1,4 @@ from datetime import datetime, timedelta -from http import HTTPStatus - -import pytest -from tests.api.cors_support import ( - cors_clear_config, - cors_deny, - cors_origin_1, - cors_origin_deny, - cors_permit_1, - cors_permit_12, - cors_permit_123_regex, -) from ..conftest import MockStarletteRequest @@ -35,10 +23,6 @@ ] -def teardown_function(): - cors_clear_config() - - def test_post_search_content_type(app_client): params = {"limit": 1} resp = app_client.post("search", json=params) @@ -319,4 +303,4 @@ def test_app_fields_extension_return_all_properties( if expected_prop in ("datetime", "created", "updated"): assert feature["properties"][expected_prop][0:19] == expected_value[0:19] else: - assert feature["properties"][expected_prop] == expected_value \ No newline at end of file + assert feature["properties"][expected_prop] == expected_value From 8699c3427c43e217fc5fbf60a486ebd624ff00c5 Mon Sep 17 00:00:00 2001 From: Nathan Zimmerman Date: Tue, 19 Apr 2022 13:32:43 -0500 Subject: [PATCH 24/25] Fix intermittent error while loading test data --- Makefile | 10 +++---- stac_fastapi/pgstac/tests/api/test_api.py | 6 ----- stac_fastapi/pgstac/tests/conftest.py | 33 +++++++++++++++-------- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/Makefile b/Makefile index d38197a3b..26174d4ab 100644 --- a/Makefile +++ b/Makefile @@ -23,19 +23,19 @@ docker-run-all: docker-compose up .PHONY: docker-run-sqlalchemy -docker-run-sqlalchemy: image +docker-run-sqlalchemy: image run-joplin-sqlalchemy $(run_sqlalchemy) .PHONY: docker-run-pgstac -docker-run-pgstac: image +docker-run-pgstac: image run-joplin-pgstac $(run_pgstac) .PHONY: docker-shell-sqlalchemy -docker-shell-sqlalchemy: +docker-shell-sqlalchemy: run-joplin-sqlalchemy $(run_sqlalchemy) /bin/bash .PHONY: docker-shell-pgstac -docker-shell-pgstac: +docker-shell-pgstac: run-joplin-pgstac $(run_pgstac) /bin/bash .PHONY: test-sqlalchemy @@ -43,7 +43,7 @@ test-sqlalchemy: run-joplin-sqlalchemy $(run_sqlalchemy) /bin/bash -c 'export && ./scripts/wait-for-it.sh database:5432 && cd /app/stac_fastapi/sqlalchemy/tests/ && pytest -vvv' .PHONY: test-pgstac -test-pgstac: +test-pgstac: run-joplin-pgstac $(run_pgstac) /bin/bash -c 'export && ./scripts/wait-for-it.sh database:5432 && cd /app/stac_fastapi/pgstac/tests/ && pytest -vvv' .PHONY: test-api diff --git a/stac_fastapi/pgstac/tests/api/test_api.py b/stac_fastapi/pgstac/tests/api/test_api.py index 17dbba850..f4d783b11 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -1,7 +1,5 @@ from datetime import datetime, timedelta -import pytest - STAC_CORE_ROUTES = [ "GET /", "GET /collections", @@ -23,10 +21,6 @@ ] -def teardown_function(): - cors_clear_config() - - async def test_post_search_content_type(app_client): params = {"limit": 1} resp = await app_client.post("search", json=params) diff --git a/stac_fastapi/pgstac/tests/conftest.py b/stac_fastapi/pgstac/tests/conftest.py index 29e7ff990..8e9000310 100644 --- a/stac_fastapi/pgstac/tests/conftest.py +++ b/stac_fastapi/pgstac/tests/conftest.py @@ -25,6 +25,7 @@ from stac_fastapi.pgstac.extensions import QueryExtension from stac_fastapi.pgstac.transactions import TransactionsClient from stac_fastapi.pgstac.types.search import PgstacSearch +from stac_fastapi.types.errors import ConflictError DATA_DIR = os.path.join(os.path.dirname(__file__), "data") @@ -147,20 +148,30 @@ def load_file(filename: str) -> Dict: @pytest.fixture async def load_test_collection(app_client, load_test_data): data = load_test_data("test_collection.json") - resp = await app_client.post( - "/collections", - json=data, - ) - assert resp.status_code == 200 + try: + resp = await app_client.post( + "/collections", + json=data, + ) + assert resp.status_code == 200 + except ConflictError: + resp = await app_client.get(f"/collections/{data['id']}") + assert resp.status_code == 200 return Collection.parse_obj(resp.json()) @pytest.fixture -async def load_test_item(app_client, load_test_data, load_test_collection): +async def load_test_item(app_client, load_test_data): data = load_test_data("test_item.json") - resp = await app_client.post( - "/collections/{coll.id}/items", - json=data, - ) - assert resp.status_code == 200 + try: + resp = await app_client.post( + "/collections/{coll.id}/items", + json=data, + ) + assert resp.status_code == 200 + except ConflictError: + resp = await app_client.get( + f"/collections/{data['collection']}/items/{data['id']}" + ) + assert resp.status_code == 200 return Item.parse_obj(resp.json()) From b0e740ea32456ed84e57ab8dd0cd507150155c24 Mon Sep 17 00:00:00 2001 From: Nathan Zimmerman Date: Fri, 13 May 2022 16:31:46 -0500 Subject: [PATCH 25/25] Rename settings (#7) --- stac_fastapi/api/stac_fastapi/api/config.py | 4 +-- .../api/stac_fastapi/api/middleware.py | 24 +++++++++----- stac_fastapi/api/tests/cors_support.py | 32 +++++++++---------- 3 files changed, 34 insertions(+), 26 deletions(-) diff --git a/stac_fastapi/api/stac_fastapi/api/config.py b/stac_fastapi/api/stac_fastapi/api/config.py index 96114b1f4..988f246ba 100644 --- a/stac_fastapi/api/stac_fastapi/api/config.py +++ b/stac_fastapi/api/stac_fastapi/api/config.py @@ -30,7 +30,7 @@ class AddOns(enum.Enum): bulk_transaction = "bulk-transaction" -class Settings(BaseSettings): +class FastApiAppSettings(BaseSettings): """API settings.""" allow_origins: Sequence[str] = Field(("*",), env="cors_allow_origins") @@ -42,4 +42,4 @@ class Settings(BaseSettings): max_age: int = Field(600, env="cors_max_age") -settings: Final = Settings() +fastapi_app_settings: Final = FastApiAppSettings() diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index 8dd8be2a8..4858aeb35 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -10,7 +10,7 @@ from starlette.routing import Match from starlette.types import ASGIApp -from stac_fastapi.api.config import settings +from stac_fastapi.api.config import fastapi_app_settings logger: Final = getLogger(__file__) @@ -54,21 +54,27 @@ def __init__( ) -> None: """Create CORSMiddleware Object.""" allow_origins = ( - settings.allow_origins if allow_origins is None else allow_origins + fastapi_app_settings.allow_origins + if allow_origins is None + else allow_origins ) allow_methods = ( - settings.allow_methods if allow_methods is None else allow_methods + fastapi_app_settings.allow_methods + if allow_methods is None + else allow_methods ) allow_headers = ( - settings.allow_headers if allow_headers is None else allow_headers + fastapi_app_settings.allow_headers + if allow_headers is None + else allow_headers ) allow_credentials = ( - settings.allow_credentials + fastapi_app_settings.allow_credentials if allow_credentials is None else allow_credentials ) allow_origin_regex = ( - settings.allow_origin_regex + fastapi_app_settings.allow_origin_regex if allow_origin_regex is None else allow_origin_regex ) @@ -76,9 +82,11 @@ def __init__( logger.info("allow_origin_regex present and will override allow_origins") allow_origins = "" expose_headers = ( - settings.expose_headers if expose_headers is None else expose_headers + fastapi_app_settings.expose_headers + if expose_headers is None + else expose_headers ) - max_age = settings.max_age if max_age is None else max_age + max_age = fastapi_app_settings.max_age if max_age is None else max_age logger.debug( f""" CORS configuration diff --git a/stac_fastapi/api/tests/cors_support.py b/stac_fastapi/api/tests/cors_support.py index 9f9c303e6..15f1b7375 100644 --- a/stac_fastapi/api/tests/cors_support.py +++ b/stac_fastapi/api/tests/cors_support.py @@ -2,9 +2,9 @@ from json import dumps from typing import Final -from stac_fastapi.api.config import settings +from stac_fastapi.api.config import fastapi_app_settings -settings_fallback = deepcopy(settings) +settings_fallback = deepcopy(fastapi_app_settings) cors_origin_1: Final = "http://permit.one" cors_origin_2: Final = "http://permit.two" cors_origin_3: Final = "http://permit.three" @@ -12,31 +12,31 @@ def cors_permit_1(): - settings.allow_origins = dumps((cors_origin_1,)) + fastapi_app_settings.allow_origins = dumps((cors_origin_1,)) def cors_permit_2(): - settings.allow_origins = dumps((cors_origin_2,)) + fastapi_app_settings.allow_origins = dumps((cors_origin_2,)) def cors_permit_3(): - settings.allow_origins = dumps((cors_origin_3,)) + fastapi_app_settings.allow_origins = dumps((cors_origin_3,)) def cors_permit_12(): - settings.allow_origins = dumps((cors_origin_1, cors_origin_2)) + fastapi_app_settings.allow_origins = dumps((cors_origin_1, cors_origin_2)) def cors_permit_123_regex(): - settings.allow_origin_regex = "http\\://permit\\..+" + fastapi_app_settings.allow_origin_regex = "http\\://permit\\..+" def cors_deny(): - settings.allow_origins = dumps((cors_origin_deny,)) + fastapi_app_settings.allow_origins = dumps((cors_origin_deny,)) def cors_disable_get(): - settings.allow_methods = dumps( + fastapi_app_settings.allow_methods = dumps( ( "HEAD", "POST", @@ -51,10 +51,10 @@ def cors_disable_get(): def cors_clear_config(): - settings.allow_origins = settings_fallback.allow_origins - settings.allow_methods = settings_fallback.allow_methods - settings.allow_headers = settings_fallback.allow_headers - settings.allow_credentials = settings_fallback.allow_credentials - settings.allow_origin_regex = settings_fallback.allow_origin_regex - settings.expose_headers = settings_fallback.expose_headers - settings.max_age = settings_fallback.max_age + fastapi_app_settings.allow_origins = settings_fallback.allow_origins + fastapi_app_settings.allow_methods = settings_fallback.allow_methods + fastapi_app_settings.allow_headers = settings_fallback.allow_headers + fastapi_app_settings.allow_credentials = settings_fallback.allow_credentials + fastapi_app_settings.allow_origin_regex = settings_fallback.allow_origin_regex + fastapi_app_settings.expose_headers = settings_fallback.expose_headers + fastapi_app_settings.max_age = settings_fallback.max_age