diff --git a/.gitignore b/.gitignore index 407d0fe9a..1cc0fbf51 100644 --- a/.gitignore +++ b/.gitignore @@ -129,4 +129,7 @@ docs/api/* .envrc # Virtualenv -venv \ No newline at end of file +venv + +# IDE +.vscode \ No newline at end of file diff --git a/CHANGES.md b/CHANGES.md index 2355d3534..aee0b028a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -4,9 +4,12 @@ ### Added +* Added ability to configure CORS middleware via environment variables ([#341](https://github.com/stac-utils/stac-fastapi/pull/341)) * Add hook to allow adding dependencies to routes. ([#295](https://github.com/stac-utils/stac-fastapi/pull/295)) * Ability to POST an ItemCollection to the collections/{collectionId}/items route. ([#367](https://github.com/stac-utils/stac-fastapi/pull/367)) * Add STAC API - Collections conformance class. ([383](https://github.com/stac-utils/stac-fastapi/pull/383)) +* Added ability to configure CORS middleware via JSON configuration file and environment variable, rather than having to modify code. ([341](https://github.com/stac-utils/stac-fastapi/pull/341)) +* Added ability to configure CORS middleware via environment variables ([#341](https://github.com/stac-utils/stac-fastapi/pull/341)) ### Changed diff --git a/Makefile b/Makefile index fe2b6fe32..26174d4ab 100644 --- a/Makefile +++ b/Makefile @@ -23,19 +23,19 @@ docker-run-all: docker-compose up .PHONY: docker-run-sqlalchemy -docker-run-sqlalchemy: image +docker-run-sqlalchemy: image run-joplin-sqlalchemy $(run_sqlalchemy) .PHONY: docker-run-pgstac -docker-run-pgstac: image +docker-run-pgstac: image run-joplin-pgstac $(run_pgstac) .PHONY: docker-shell-sqlalchemy -docker-shell-sqlalchemy: +docker-shell-sqlalchemy: run-joplin-sqlalchemy $(run_sqlalchemy) /bin/bash .PHONY: docker-shell-pgstac -docker-shell-pgstac: +docker-shell-pgstac: run-joplin-pgstac $(run_pgstac) /bin/bash .PHONY: test-sqlalchemy @@ -43,9 +43,13 @@ test-sqlalchemy: run-joplin-sqlalchemy $(run_sqlalchemy) /bin/bash -c 'export && ./scripts/wait-for-it.sh database:5432 && cd /app/stac_fastapi/sqlalchemy/tests/ && pytest -vvv' .PHONY: test-pgstac -test-pgstac: +test-pgstac: run-joplin-pgstac $(run_pgstac) /bin/bash -c 'export && ./scripts/wait-for-it.sh database:5432 && cd /app/stac_fastapi/pgstac/tests/ && pytest -vvv' +.PHONY: test-api +test-api: + docker-compose run api-tester + .PHONY: run-database run-database: docker-compose run --rm database @@ -59,7 +63,7 @@ run-joplin-pgstac: docker-compose run --rm loadjoplin-pgstac .PHONY: test -test: test-sqlalchemy test-pgstac +test: test-sqlalchemy test-pgstac test-api .PHONY: pybase-install pybase-install: diff --git a/docker-compose.yml b/docker-compose.yml index 996bb6593..bc8feba62 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -118,6 +118,19 @@ services: - database - app-pgstac + api-tester: + image: stac-utils/stac-fastapi + build: + context: . + dockerfile: Dockerfile + profiles: + # prevent tester from starting with `docker-compose up` + - api-test + working_dir: /app/stac_fastapi/api + volumes: + - ./:/app + command: pytest -svvv + networks: default: name: stac-fastapi-network diff --git a/docs/tips-and-tricks.md b/docs/tips-and-tricks.md index 3d4c9ac0f..3d41bf1c3 100644 --- a/docs/tips-and-tricks.md +++ b/docs/tips-and-tricks.md @@ -5,20 +5,15 @@ This page contains a few 'tips and tricks' for getting stac-fastapi working in v CORS (Cross-Origin Resource Sharing) support may be required to use stac-fastapi in certain situations. For example, if you are running [stac-browser](https://github.com/radiantearth/stac-browser) to browse the STAC catalog created by stac-fastapi, then you will need to enable CORS support. -To do this, edit `stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py` (or the equivalent in the `pgstac` folder) and add the following import: - +To do this, configure environment variables for the configuration options described in [FastAPI docs](https://fastapi.tiangolo.com/tutorial/cors/) using a `cors_` prefix e.g. ``` -from fastapi.middleware.cors import CORSMiddleware +cors_allow_credentials=true [or 1] ``` - -and then edit the `api = StacApi(...` call to add the following parameter: - +Sequences, such as `allow_origins`, should be in JSON format e.g. ``` -middlewares=[lambda app: CORSMiddleware(app, allow_origins=["*"])] +cors_allow_origins='["http://domain.one", "http://domain.two"]' ``` -If needed, you can edit the `allow_origins` parameter to only allow CORS requests from specific origins. - ## Enable the Context extension The Context STAC extension provides information on the number of items matched and returned from a STAC search. This is required by various other STAC-related tools, such as the pystac command-line client. To enable the extension, edit `stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/app.py` (or the equivalent in the `pgstac` folder) and add the following import: diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index a9f8a5542..f1ab10e29 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -14,6 +14,7 @@ from starlette.responses import JSONResponse, Response from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers +from stac_fastapi.api.middleware import CORSMiddleware from stac_fastapi.api.models import ( APIRequest, CollectionUri, @@ -91,7 +92,9 @@ class StacApi: ) pagination_extension = attr.ib(default=TokenPaginationExtension) response_class: Type[Response] = attr.ib(default=JSONResponse) - middlewares: List = attr.ib(default=attr.Factory(lambda: [BrotliMiddleware])) + middlewares: List = attr.ib( + default=attr.Factory(lambda: [BrotliMiddleware, CORSMiddleware]) + ) route_dependencies: List[Tuple[List[Scope], List[Depends]]] = attr.ib(default=[]) def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]: diff --git a/stac_fastapi/api/stac_fastapi/api/config.py b/stac_fastapi/api/stac_fastapi/api/config.py index 3a423e45d..988f246ba 100644 --- a/stac_fastapi/api/stac_fastapi/api/config.py +++ b/stac_fastapi/api/stac_fastapi/api/config.py @@ -1,5 +1,11 @@ """Application settings.""" import enum +from logging import getLogger +from typing import Final, Sequence + +from pydantic import BaseSettings, Field + +logger: Final = getLogger(__file__) # TODO: Move to stac-pydantic @@ -22,3 +28,18 @@ class AddOns(enum.Enum): """Enumeration of available third party add ons.""" bulk_transaction = "bulk-transaction" + + +class FastApiAppSettings(BaseSettings): + """API settings.""" + + allow_origins: Sequence[str] = Field(("*",), env="cors_allow_origins") + allow_methods: Sequence[str] = Field(("*",), env="cors_allow_methods") + allow_headers: Sequence[str] = Field(("*",), env="cors_allow_headers") + allow_credentials: bool = Field(False, env="cors_allow_credentials") + allow_origin_regex: str = Field(None, env="cors_allow_origin_regex") + expose_headers: Sequence[str] = Field(("*",), env="cors_expose_headers") + max_age: int = Field(600, env="cors_max_age") + + +fastapi_app_settings: Final = FastApiAppSettings() diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index acb00915b..4858aeb35 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -1,11 +1,18 @@ """api middleware.""" -from typing import Callable +from logging import getLogger +from typing import Callable, Final, Optional, Sequence from fastapi import APIRouter, FastAPI +from fastapi.middleware import cors from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.routing import Match +from starlette.types import ASGIApp + +from stac_fastapi.api.config import fastapi_app_settings + +logger: Final = getLogger(__file__) def router_middleware(app: FastAPI, router: APIRouter): @@ -29,3 +36,76 @@ async def _middleware(request: Request, call_next): return func return deco + + +class CORSMiddleware(cors.CORSMiddleware): + """Starlette CORS Middleware with configuration.""" + + def __init__( + self, + app: ASGIApp, + allow_origins: Optional[Sequence[str]] = None, + allow_methods: Optional[Sequence[str]] = None, + allow_headers: Optional[Sequence[str]] = None, + allow_credentials: Optional[bool] = None, + allow_origin_regex: Optional[str] = None, + expose_headers: Optional[Sequence[str]] = None, + max_age: Optional[int] = None, + ) -> None: + """Create CORSMiddleware Object.""" + allow_origins = ( + fastapi_app_settings.allow_origins + if allow_origins is None + else allow_origins + ) + allow_methods = ( + fastapi_app_settings.allow_methods + if allow_methods is None + else allow_methods + ) + allow_headers = ( + fastapi_app_settings.allow_headers + if allow_headers is None + else allow_headers + ) + allow_credentials = ( + fastapi_app_settings.allow_credentials + if allow_credentials is None + else allow_credentials + ) + allow_origin_regex = ( + fastapi_app_settings.allow_origin_regex + if allow_origin_regex is None + else allow_origin_regex + ) + if allow_origin_regex is not None: + logger.info("allow_origin_regex present and will override allow_origins") + allow_origins = "" + expose_headers = ( + fastapi_app_settings.expose_headers + if expose_headers is None + else expose_headers + ) + max_age = fastapi_app_settings.max_age if max_age is None else max_age + logger.debug( + f""" + CORS configuration + allow_origins: {allow_origins} + allow_methods: {allow_methods} + allow_headers: {allow_headers} + allow_credentials: {allow_credentials} + allow_origin_regex: {allow_origin_regex} + expose_headers: {expose_headers} + max_age: {max_age} + """ + ) + super().__init__( + app, + allow_origins=allow_origins, + allow_methods=allow_methods, + allow_headers=allow_headers, + allow_credentials=allow_credentials, + allow_origin_regex=allow_origin_regex, + expose_headers=expose_headers, + max_age=max_age, + ) diff --git a/stac_fastapi/api/tests/__init__.py b/stac_fastapi/api/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/stac_fastapi/api/tests/cors_support.py b/stac_fastapi/api/tests/cors_support.py new file mode 100644 index 000000000..15f1b7375 --- /dev/null +++ b/stac_fastapi/api/tests/cors_support.py @@ -0,0 +1,60 @@ +from copy import deepcopy +from json import dumps +from typing import Final + +from stac_fastapi.api.config import fastapi_app_settings + +settings_fallback = deepcopy(fastapi_app_settings) +cors_origin_1: Final = "http://permit.one" +cors_origin_2: Final = "http://permit.two" +cors_origin_3: Final = "http://permit.three" +cors_origin_deny: Final = "http://deny.me" + + +def cors_permit_1(): + fastapi_app_settings.allow_origins = dumps((cors_origin_1,)) + + +def cors_permit_2(): + fastapi_app_settings.allow_origins = dumps((cors_origin_2,)) + + +def cors_permit_3(): + fastapi_app_settings.allow_origins = dumps((cors_origin_3,)) + + +def cors_permit_12(): + fastapi_app_settings.allow_origins = dumps((cors_origin_1, cors_origin_2)) + + +def cors_permit_123_regex(): + fastapi_app_settings.allow_origin_regex = "http\\://permit\\..+" + + +def cors_deny(): + fastapi_app_settings.allow_origins = dumps((cors_origin_deny,)) + + +def cors_disable_get(): + fastapi_app_settings.allow_methods = dumps( + ( + "HEAD", + "POST", + "PUT", + "DELETE", + "CONNECT", + "OPTIONS", + "TRACE", + "PATCH", + ) + ) + + +def cors_clear_config(): + fastapi_app_settings.allow_origins = settings_fallback.allow_origins + fastapi_app_settings.allow_methods = settings_fallback.allow_methods + fastapi_app_settings.allow_headers = settings_fallback.allow_headers + fastapi_app_settings.allow_credentials = settings_fallback.allow_credentials + fastapi_app_settings.allow_origin_regex = settings_fallback.allow_origin_regex + fastapi_app_settings.expose_headers = settings_fallback.expose_headers + fastapi_app_settings.max_age = settings_fallback.max_age diff --git a/stac_fastapi/api/tests/test_cors.py b/stac_fastapi/api/tests/test_cors.py new file mode 100644 index 000000000..6dc7a2b0e --- /dev/null +++ b/stac_fastapi/api/tests/test_cors.py @@ -0,0 +1,76 @@ +from http import HTTPStatus + +from starlette.testclient import TestClient +from tests.cors_support import ( + cors_clear_config, + cors_deny, + cors_origin_1, + cors_origin_deny, + cors_permit_1, + cors_permit_12, + cors_permit_123_regex, +) +from tests.util import build_api + +from stac_fastapi.extensions.core import TokenPaginationExtension + + +def teardown_function(): + cors_clear_config() + + +def _get_api(): + return build_api([TokenPaginationExtension()]) + + +def test_with_default_cors_origin(): + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == "*" + + +def test_with_match_cors_single(): + cors_permit_1() + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == cors_origin_1 + + +def test_with_match_cors_double(): + cors_permit_12() + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == cors_origin_1 + + +def test_with_match_cors_all_regex_match(): + cors_permit_123_regex() + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert resp.headers["access-control-allow-origin"] == cors_origin_1 + + +def test_with_match_cors_all_regex_mismatch(): + cors_permit_123_regex() + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_deny}) + assert resp.status_code == HTTPStatus.OK + assert "access-control-allow-origin" not in resp.headers + + +def test_with_mismatch_cors_origin(): + cors_deny() + api = _get_api() + with TestClient(api.app) as client: + resp = client.get("/conformance", headers={"Origin": cors_origin_1}) + assert resp.status_code == HTTPStatus.OK + assert "access-control-allow-origin" not in resp.headers diff --git a/stac_fastapi/api/tests/test_route_dependencies.py b/stac_fastapi/api/tests/test_route_dependencies.py new file mode 100644 index 000000000..ba2281773 --- /dev/null +++ b/stac_fastapi/api/tests/test_route_dependencies.py @@ -0,0 +1,106 @@ +from fastapi import Depends, HTTPException, security, status +from starlette.testclient import TestClient +from tests.util import build_api + +from stac_fastapi.extensions.core import TokenPaginationExtension, TransactionExtension +from stac_fastapi.types import config, core + + +class TestRouteDependencies: + @staticmethod + def _get_extensions(): + return [ + TransactionExtension( + client=DummyTransactionsClient(), settings=config.ApiSettings() + ), + TokenPaginationExtension(), + ] + + @staticmethod + def _assert_dependency_applied(api, routes): + with TestClient(api.app) as client: + for route in routes: + response = getattr(client, route["method"].lower())(route["path"]) + assert ( + response.status_code == 401 + ), "Unauthenticated requests should be rejected" + assert response.json() == {"detail": "Not authenticated"} + + make_request = getattr(client, route["method"].lower()) + path = route["path"].format( + collectionId="test_collection", itemId="test_item" + ) + response = make_request( + path, + auth=("bob", "dobbs"), + data='{"dummy": "payload"}', + headers={"content-type": "application/json"}, + ) + assert ( + response.status_code == 200 + ), "Authenticated requests should be accepted" + assert response.json() == "dummy response" + + def test_build_api_with_route_dependencies(self): + routes = [ + {"path": "/collections", "method": "POST"}, + {"path": "/collections", "method": "PUT"}, + {"path": "/collections/{collectionId}", "method": "DELETE"}, + {"path": "/collections/{collectionId}/items", "method": "POST"}, + {"path": "/collections/{collectionId}/items", "method": "PUT"}, + {"path": "/collections/{collectionId}/items/{itemId}", "method": "DELETE"}, + ] + dependencies = [Depends(must_be_bob)] + api = build_api( + TestRouteDependencies._get_extensions(), + route_dependencies=[(routes, dependencies)], + ) + self._assert_dependency_applied(api, routes) + + def test_add_route_dependencies_after_building_api(self): + routes = [ + {"path": "/collections", "method": "POST"}, + {"path": "/collections", "method": "PUT"}, + {"path": "/collections/{collectionId}", "method": "DELETE"}, + {"path": "/collections/{collectionId}/items", "method": "POST"}, + {"path": "/collections/{collectionId}/items", "method": "PUT"}, + {"path": "/collections/{collectionId}/items/{itemId}", "method": "DELETE"}, + ] + api = build_api(TestRouteDependencies._get_extensions()) + api.add_route_dependencies(scopes=routes, dependencies=[Depends(must_be_bob)]) + self._assert_dependency_applied(api, routes) + + +class DummyTransactionsClient(core.BaseTransactionsClient): + """Defines a pattern for implementing the STAC transaction extension.""" + + def create_item(self, *args, **kwargs): + return "dummy response" + + def update_item(self, *args, **kwargs): + return "dummy response" + + def delete_item(self, *args, **kwargs): + return "dummy response" + + def create_collection(self, *args, **kwargs): + return "dummy response" + + def update_collection(self, *args, **kwargs): + return "dummy response" + + def delete_collection(self, *args, **kwargs): + return "dummy response" + + +def must_be_bob( + credentials: security.HTTPBasicCredentials = Depends(security.HTTPBasic()), +): + if credentials.username == "bob": + return True + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="You're not Bob", + headers={"WWW-Authenticate": "Basic"}, + ) diff --git a/stac_fastapi/api/tests/util.py b/stac_fastapi/api/tests/util.py new file mode 100644 index 000000000..aa4aea972 --- /dev/null +++ b/stac_fastapi/api/tests/util.py @@ -0,0 +1,37 @@ +from typing import List, Type + +from stac_fastapi.api.app import StacApi +from stac_fastapi.types import config, core +from stac_fastapi.types.extension import ApiExtension + + +class DummyCoreClient(core.BaseCoreClient): + def all_collections(self, *args, **kwargs): + ... + + def get_collection(self, *args, **kwargs): + ... + + def get_item(self, *args, **kwargs): + ... + + def get_search(self, *args, **kwargs): + ... + + def post_search(self, *args, **kwargs): + ... + + def item_collection(self, *args, **kwargs): + ... + + +def build_api(extensions: List[Type[ApiExtension]] = [], **overrides): + settings = config.ApiSettings() + return StacApi( + **{ + "settings": settings, + "client": DummyCoreClient(), + "extensions": extensions, + **overrides, + } + ) diff --git a/stac_fastapi/pgstac/tests/conftest.py b/stac_fastapi/pgstac/tests/conftest.py index 170877a7d..8e9000310 100644 --- a/stac_fastapi/pgstac/tests/conftest.py +++ b/stac_fastapi/pgstac/tests/conftest.py @@ -1,7 +1,6 @@ import asyncio import json import os -import time from typing import Callable, Dict import asyncpg @@ -26,6 +25,7 @@ from stac_fastapi.pgstac.extensions import QueryExtension from stac_fastapi.pgstac.transactions import TransactionsClient from stac_fastapi.pgstac.types.search import PgstacSearch +from stac_fastapi.types.errors import ConflictError DATA_DIR = os.path.join(os.path.dirname(__file__), "data") @@ -92,8 +92,7 @@ async def pgstac(pg): await conn.close() -@pytest.fixture(scope="session") -def api_client(pg): +def _api_client_provider(): print("creating client with settings") extensions = [ @@ -119,20 +118,22 @@ def api_client(pg): @pytest.fixture(scope="session") -async def app(api_client): - time.time() - app = api_client.app - await connect_to_db(app) - - yield app - - await close_db_connection(app) +def api_client(pg): + return _api_client_provider() @pytest.fixture(scope="session") -async def app_client(app): +async def app_client(pg, request): + # support custom behaviours driven by fixture caller + if hasattr(request, "param"): + setup_func = request.param.get("setup_func") + if setup_func is not None: + setup_func() + app = _api_client_provider().app async with AsyncClient(app=app, base_url="http://test") as c: + await connect_to_db(app) yield c + await close_db_connection(app) @pytest.fixture @@ -147,20 +148,30 @@ def load_file(filename: str) -> Dict: @pytest.fixture async def load_test_collection(app_client, load_test_data): data = load_test_data("test_collection.json") - resp = await app_client.post( - "/collections", - json=data, - ) - assert resp.status_code == 200 + try: + resp = await app_client.post( + "/collections", + json=data, + ) + assert resp.status_code == 200 + except ConflictError: + resp = await app_client.get(f"/collections/{data['id']}") + assert resp.status_code == 200 return Collection.parse_obj(resp.json()) @pytest.fixture -async def load_test_item(app_client, load_test_data, load_test_collection): +async def load_test_item(app_client, load_test_data): data = load_test_data("test_item.json") - resp = await app_client.post( - "/collections/{coll.id}/items", - json=data, - ) - assert resp.status_code == 200 + try: + resp = await app_client.post( + "/collections/{coll.id}/items", + json=data, + ) + assert resp.status_code == 200 + except ConflictError: + resp = await app_client.get( + f"/collections/{data['collection']}/items/{data['id']}" + ) + assert resp.status_code == 200 return Item.parse_obj(resp.json()) diff --git a/stac_fastapi/sqlalchemy/tests/conftest.py b/stac_fastapi/sqlalchemy/tests/conftest.py index 7abd9150f..1ef44f62b 100644 --- a/stac_fastapi/sqlalchemy/tests/conftest.py +++ b/stac_fastapi/sqlalchemy/tests/conftest.py @@ -104,8 +104,7 @@ def postgres_bulk_transactions(db_session): return BulkTransactionsClient(session=db_session) -@pytest.fixture -def api_client(db_session): +def _api_client_provider(db_session): settings = SqlalchemySettings() extensions = [ TransactionExtension( @@ -146,9 +145,19 @@ def api_client(db_session): @pytest.fixture -def app_client(api_client, load_test_data, postgres_transactions): +def api_client(db_session): + return _api_client_provider(db_session) + + +@pytest.fixture +def app_client(db_session, load_test_data, postgres_transactions, request): + # support custom behaviours driven by fixture caller + if hasattr(request, "param"): + setup_func = request.param.get("setup_func") + if setup_func is not None: + setup_func() coll = load_test_data("test_collection.json") postgres_transactions.create_collection(coll, request=MockStarletteRequest) - with TestClient(api_client.app) as test_app: + with TestClient(_api_client_provider(db_session).app) as test_app: yield test_app