diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py index 66685d15..52445dfd 100644 --- a/src/stac_auth_proxy/app.py +++ b/src/stac_auth_proxy/app.py @@ -21,7 +21,7 @@ EnforceAuthMiddleware, OpenApiMiddleware, ) -from .utils.lifespan import check_server_health +from .utils.lifespan import check_conformance, check_server_health logger = logging.getLogger(__name__) @@ -40,9 +40,22 @@ async def lifespan(app: FastAPI): # Wait for upstream servers to become available if settings.wait_for_upstream: + logger.info("Running upstream server health checks...") for url in [settings.upstream_url, settings.oidc_discovery_internal_url]: await check_server_health(url=url) + # Log all middleware connected to the app + logger.debug( + "Connected middleware:\n%s", + "\n".join([f" - {m.cls.__name__}" for m in app.user_middleware]), + ) + + if settings.check_conformance: + await check_conformance( + app.user_middleware, + str(settings.upstream_url), + ) + yield app = FastAPI( @@ -88,19 +101,19 @@ async def lifespan(app: FastAPI): ) app.add_middleware( - EnforceAuthMiddleware, - public_endpoints=settings.public_endpoints, - private_endpoints=settings.private_endpoints, - default_public=settings.default_public, - oidc_config_url=settings.oidc_discovery_internal_url, + CompressionMiddleware, ) app.add_middleware( - CompressionMiddleware, + AddProcessTimeHeaderMiddleware, ) app.add_middleware( - AddProcessTimeHeaderMiddleware, + EnforceAuthMiddleware, + public_endpoints=settings.public_endpoints, + private_endpoints=settings.private_endpoints, + default_public=settings.default_public, + oidc_config_url=settings.oidc_discovery_internal_url, ) return app diff --git a/src/stac_auth_proxy/config.py b/src/stac_auth_proxy/config.py index 3748bcfd..73bcdbe8 100644 --- a/src/stac_auth_proxy/config.py +++ b/src/stac_auth_proxy/config.py @@ -39,6 +39,7 @@ class Settings(BaseSettings): oidc_discovery_internal_url: HttpUrl wait_for_upstream: bool = True + check_conformance: bool = 
True # Endpoints healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz") diff --git a/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py b/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py index 6d10b1f1..9499f5ff 100644 --- a/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py +++ b/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py @@ -13,10 +13,18 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send from ..utils import filters +from ..utils.middleware import required_conformance logger = getLogger(__name__) +@required_conformance( + r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json", + r"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter", + r"https://api.stacspec.org/v1\.\d+\.\d+(?:-[\w\.]+)?/item-search#filter", +) @dataclass(frozen=True) class ApplyCql2FilterMiddleware: """Middleware to apply the Cql2Filter to the request.""" diff --git a/src/stac_auth_proxy/utils/lifespan.py b/src/stac_auth_proxy/utils/lifespan.py index 0cc427d0..f9432b80 100644 --- a/src/stac_auth_proxy/utils/lifespan.py +++ b/src/stac_auth_proxy/utils/lifespan.py @@ -2,9 +2,11 @@ import asyncio import logging +import re import httpx from pydantic import HttpUrl +from starlette.middleware import Middleware logger = logging.getLogger(__name__) @@ -21,14 +23,16 @@ async def check_server_health( if isinstance(url, HttpUrl): url = str(url) - async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client: + async with httpx.AsyncClient( + base_url=url, timeout=timeout, follow_redirects=True + ) as client: for attempt in range(max_retries): try: - response = await client.get(url) + response = await client.get("/") response.raise_for_status() logger.info(f"Upstream API {url!r} is healthy") return - except Exception as e: + except httpx.ConnectError as e: 
async def check_conformance(
    middleware_classes: "list[Middleware]",
    api_url: str,
    attr_name: str = "__required_conformances__",
    endpoint: str = "/conformance",
) -> None:
    """Verify that the upstream API advertises every required conformance class.

    Each entry of ``middleware_classes`` is inspected for a class-level
    attribute named ``attr_name`` holding regular-expression patterns. Every
    pattern must match, in full, at least one string in the ``conformsTo``
    array returned by the upstream ``endpoint``.

    Args:
        middleware_classes: Middleware entries whose ``.cls`` attribute is the
            middleware class to inspect.
        api_url: Base URL of the upstream API.
        attr_name: Name of the class attribute listing required conformance
            patterns. Classes without it contribute no requirements.
        endpoint: Path, relative to ``api_url``, of the conformance document.

    Raises:
        RuntimeError: If one or more required patterns match none of the
            advertised conformance classes. The message lists each missing
            pattern together with the middleware that require it.
        Exception: Whatever ``response.raise_for_status()`` or the HTTP client
            raises when the conformance document cannot be fetched.
    """
    # Map each required pattern to the names of the middleware requiring it.
    required_conformances: dict[str, list[str]] = {}
    for middleware in middleware_classes:
        for conformance in getattr(middleware.cls, attr_name, []):
            required_conformances.setdefault(conformance, []).append(
                middleware.cls.__name__
            )

    # Nothing to verify: do not make startup depend on an upstream request
    # whose result could not change the outcome.
    if not required_conformances:
        return

    async with httpx.AsyncClient(base_url=api_url) as client:
        response = await client.get(endpoint)
        response.raise_for_status()
        api_conforms_to = response.json().get("conformsTo", [])

    # Only strings can be matched; ignore malformed entries instead of letting
    # ``re`` raise a TypeError on them.
    advertised = [entry for entry in api_conforms_to if isinstance(entry, str)]

    # ``fullmatch`` rather than ``match``: a required class must not be
    # considered satisfied merely because it is a prefix of an advertised one.
    missing = [
        pattern
        for pattern in required_conformances
        if not any(re.fullmatch(pattern, entry) for entry in advertised)
    ]

    def conformance_str(conformance: str) -> str:
        return f" - {conformance} [{','.join(required_conformances[conformance])}]"

    if missing:
        raise RuntimeError(
            "\n".join(
                [
                    "Upstream catalog is missing the following conformance classes:",
                    *(conformance_str(c) for c in missing),
                ]
            )
        )

    logger.debug(
        "Upstream catalog conforms to the following required conformance classes: \n%s",
        "\n".join(conformance_str(c) for c in required_conformances),
    )
def required_conformance(
    *conformances: str,
    attr_name: str = "__required_conformances__",
):
    """Create a decorator that records required conformance classes on a middleware.

    Args:
        *conformances: Conformance class identifiers (or patterns) to record.
        attr_name: Name of the attribute under which the list is stored.

    Returns:
        A decorator that sets ``attr_name`` on the decorated object to a new
        list of ``conformances`` and returns that same object.
    """

    def _tag(target):
        # Build the list per application so decorated classes never share it.
        required = [*conformances]
        setattr(target, attr_name, required)
        return target

    return _tag
async def test_check_conformance_success(source_api_server, source_api_responses):
    """A middleware whose requirement is advertised upstream passes the check."""
    stack = [Middleware(TestMiddleware)]
    await check_conformance(stack, source_api_server)


async def test_check_conformance_failure(source_api_server, source_api_responses):
    """An upstream advertising nothing makes the check raise RuntimeError."""
    # Replace the canned conformance document with one that lists no classes.
    conformance_routes = source_api_responses["/conformance"]
    conformance_routes["GET"] = {"conformsTo": []}

    stack = [Middleware(TestMiddleware)]
    with pytest.raises(RuntimeError) as raised:
        await check_conformance(stack, source_api_server)
    message = str(raised.value)
    assert "missing the following conformance classes" in message


async def test_check_conformance_multiple_middleware(source_api_server):
    """Two middleware declaring the same requirement are checked together."""

    @required_conformance("http://example.com/conformance")
    class TestMiddleware2:
        def __init__(self, app):
            self.app = app

    stack = [Middleware(cls) for cls in (TestMiddleware, TestMiddleware2)]
    await check_conformance(stack, source_api_server)


async def test_check_conformance_no_required(source_api_server):
    """A middleware that declares no requirements passes the check."""

    class NoConformanceMiddleware:
        def __init__(self, app):
            self.app = app

    stack = [Middleware(NoConformanceMiddleware)]
    await check_conformance(stack, source_api_server)