From 410cf9a5aff75deef60b78484a659b5142b6aca6 Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Thu, 27 Mar 2025 22:34:51 -0700 Subject: [PATCH 1/8] feat: add middleware conformance checks --- src/stac_auth_proxy/app.py | 15 ++++- src/stac_auth_proxy/config.py | 1 + .../middleware/ApplyCql2FilterMiddleware.py | 8 +++ src/stac_auth_proxy/utils/lifespan.py | 57 +++++++++++++++++++ src/stac_auth_proxy/utils/middleware.py | 10 ++++ 5 files changed, 90 insertions(+), 1 deletion(-) diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py index 66685d15..1b602458 100644 --- a/src/stac_auth_proxy/app.py +++ b/src/stac_auth_proxy/app.py @@ -21,7 +21,11 @@ EnforceAuthMiddleware, OpenApiMiddleware, ) -from .utils.lifespan import check_server_health +from .utils.lifespan import ( + check_conformance, + check_server_health, + log_middleware_classes, +) logger = logging.getLogger(__name__) @@ -40,9 +44,18 @@ async def lifespan(app: FastAPI): # Wait for upstream servers to become available if settings.wait_for_upstream: + logger.info("Running upstream server health checks...") for url in [settings.upstream_url, settings.oidc_discovery_internal_url]: await check_server_health(url=url) + # Log all middleware connected to the app + await log_middleware_classes(app.user_middleware) + if settings.check_conformance: + await check_conformance( + app.user_middleware, + str(settings.upstream_url), + ) + yield app = FastAPI( diff --git a/src/stac_auth_proxy/config.py b/src/stac_auth_proxy/config.py index 3748bcfd..73bcdbe8 100644 --- a/src/stac_auth_proxy/config.py +++ b/src/stac_auth_proxy/config.py @@ -39,6 +39,7 @@ class Settings(BaseSettings): oidc_discovery_internal_url: HttpUrl wait_for_upstream: bool = True + check_conformance: bool = True # Endpoints healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz") diff --git a/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py b/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py index 6d10b1f1..9499f5ff 100644 --- a/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py +++ b/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py @@ -13,10 +13,18 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send from ..utils import filters +from ..utils.middleware import required_conformance logger = getLogger(__name__) +@required_conformance( + r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json", + r"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter", + r"https://api.stacspec.org/v1\.\d+\.\d+(?:-[\w\.]+)?/item-search#filter", +) @dataclass(frozen=True) class ApplyCql2FilterMiddleware: """Middleware to apply the Cql2Filter to the request.""" diff --git a/src/stac_auth_proxy/utils/lifespan.py b/src/stac_auth_proxy/utils/lifespan.py index 0cc427d0..b95a0044 100644 --- a/src/stac_auth_proxy/utils/lifespan.py +++ b/src/stac_auth_proxy/utils/lifespan.py @@ -2,9 +2,11 @@ import asyncio import logging +import re import httpx from pydantic import HttpUrl +from starlette.middleware import Middleware logger = logging.getLogger(__name__) @@ -40,3 +42,58 @@ async def check_server_health( raise RuntimeError( f"Upstream API {url!r} failed to respond after {max_retries} attempts" ) + + +async def log_middleware_classes(middleware_classes: list[Middleware]): + """Log the middleware classes connected to the application.""" + logger.debug( + "Connected middleware:\n%s", + "\n".join( + [f"- {middleware.cls.__name__}" for middleware in middleware_classes] + ), + ) + + +async def check_conformance( + middleware_classes: list[Middleware], + api_url: str, + attr_name: str = "__required_conformances__", +): + """Check if the upstream API supports a given conformance class.""" + required_conformances: dict[str, list[str]] = {} + for middleware in middleware_classes: + + for conformance in getattr(middleware.cls, attr_name, []): + required_conformances.setdefault(conformance, []).append( + middleware.cls.__name__ + ) + + async with httpx.AsyncClient() as client: + response = await client.get(api_url) + response.raise_for_status() + api_conforms_to = response.json().get("conformsTo", []) + missing = [ + req_conformance + for req_conformance in required_conformances.keys() + if not any( + re.match(req_conformance, conformance) for conformance in api_conforms_to + ) + ] + + def print_conformance(conformance): + return f" - {conformance} [{','.join(required_conformances[conformance])}]" + + if missing: + missing_str = [print_conformance(c) for c in missing] + raise RuntimeError( + "\n".join( + [ + "Upstream catalog is missing the following conformance classes:", + *missing_str, + ] + ) + ) + logger.debug( + "Upstream catalog conforms to the following required conformance classes: \n%s", + "\n".join([print_conformance(c) for c in required_conformances]), + ) diff --git a/src/stac_auth_proxy/utils/middleware.py b/src/stac_auth_proxy/utils/middleware.py index e67d9493..edb9ee9a 100644 --- a/src/stac_auth_proxy/utils/middleware.py +++ b/src/stac_auth_proxy/utils/middleware.py @@ -99,3 +99,13 @@ async def transform_response(message: Message) -> None: ) return await self.app(scope, receive, transform_response) + + +def required_conformance(*conformances: str): + """Register required conformance classes with a middleware class.""" + + def decorator(func): + func.__required_conformances__ = list(conformances) + return func + + return decorator From c5faa4abff0ea45668e7da842105cd008814a6a8 Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Fri, 28 Mar 2025 13:09:30 -0700 Subject: [PATCH 2/8] add tests --- tests/test_lifespan.py | 84 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 tests/test_lifespan.py diff --git a/tests/test_lifespan.py b/tests/test_lifespan.py new file mode 100644 index 00000000..9f0a1126 --- /dev/null +++ b/tests/test_lifespan.py @@ -0,0 +1,84 @@ +"""Tests for lifespan module.""" + +from dataclasses import dataclass +from unittest.mock import patch + +import pytest +from starlette.middleware import Middleware +from starlette.types import ASGIApp + +from stac_auth_proxy.utils.lifespan import check_conformance, check_server_health +from stac_auth_proxy.utils.middleware import required_conformance + + +@required_conformance("http://example.com/conformance") +@dataclass +class TestMiddleware: + """Test middleware with required conformance.""" + + app: ASGIApp + + +async def test_check_server_health_success(source_api_server): + """Test successful health check.""" + await check_server_health(source_api_server) + + +async def test_check_server_health_failure(): + """Test health check failure.""" + with patch("asyncio.sleep") as mock_sleep: + with pytest.raises(RuntimeError) as exc_info: + await check_server_health("http://localhost:9999") + assert "failed to respond after" in str(exc_info.value) + # Verify sleep was called with exponential backoff + assert mock_sleep.call_count > 0 + # First call should be with base delay + # NOTE: When testing individually, the mock_sleep strangely has a first call of + # 0 seconds (possibly by httpx), however when running all tests, this does not + # occur. So, we have to check for 1.0 in the first two calls. + assert 1.0 in [mock_sleep.call_args_list[i][0][0] for i in range(2)] + # Last call should be with max delay + assert mock_sleep.call_args_list[-1][0][0] == 5.0 + + +async def test_check_conformance_success(source_api_server, source_api_responses): + """Test successful conformance check.""" + middleware = [Middleware(TestMiddleware)] + await check_conformance(middleware, source_api_server) + + +async def test_check_conformance_failure(source_api_server, source_api_responses): + """Test conformance check failure.""" + # Override the conformance response to not include required conformance + source_api_responses["/conformance"]["GET"] = {"conformsTo": []} + + middleware = [Middleware(TestMiddleware)] + with pytest.raises(RuntimeError) as exc_info: + await check_conformance(middleware, source_api_server) + assert "missing the following conformance classes" in str(exc_info.value) + + +async def test_check_conformance_multiple_middleware(source_api_server): + """Test conformance check with multiple middleware.""" + + @required_conformance("http://example.com/conformance") + class TestMiddleware2: + def __init__(self, app): + self.app = app + + middleware = [ + Middleware(TestMiddleware), + Middleware(TestMiddleware2), + ] + await check_conformance(middleware, source_api_server) + + +async def test_check_conformance_no_required(source_api_server): + """Test conformance check with middleware that has no required conformances.""" + + class NoConformanceMiddleware: + def __init__(self, app): + self.app = app + + middleware = [Middleware(NoConformanceMiddleware)] + await check_conformance(middleware, source_api_server) From 927f415069c4b8dd1c227ac18f7470212303b1df Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Fri, 28 Mar 2025 13:09:42 -0700 Subject: [PATCH 3/8] cleanup --- src/stac_auth_proxy/utils/lifespan.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/stac_auth_proxy/utils/lifespan.py b/src/stac_auth_proxy/utils/lifespan.py index b95a0044..32fca68d 100644 --- a/src/stac_auth_proxy/utils/lifespan.py +++ b/src/stac_auth_proxy/utils/lifespan.py @@ -30,7 +30,7 @@ async def check_server_health( response.raise_for_status() logger.info(f"Upstream API {url!r} is healthy") return - except Exception as e: + except httpx.ConnectError as e: logger.warning(f"Upstream health check for {url!r} failed: {e}") retry_in = min(retry_delay * (2**attempt), retry_delay_max) logger.warning( @@ -69,7 +69,7 @@ async def check_conformance( ) async with httpx.AsyncClient() as client: - response = await client.get(api_url) + response = await client.get(f"{api_url}/conformance") response.raise_for_status() api_conforms_to = response.json().get("conformsTo", []) missing = [ @@ -80,11 +80,11 @@ async def check_conformance( ) ] - def print_conformance(conformance): + def conformance_str(conformance: str) -> str: return f" - {conformance} [{','.join(required_conformances[conformance])}]" if missing: - missing_str = [print_conformance(c) for c in missing] + missing_str = [conformance_str(c) for c in missing] raise RuntimeError( "\n".join( [ @@ -95,5 +95,5 @@ def print_conformance(conformance): ) logger.debug( "Upstream catalog conforms to the following required conformance classes: \n%s", - "\n".join([print_conformance(c) for c in required_conformances]), + "\n".join([conformance_str(c) for c in required_conformances]), ) From 6cc26faa82e780e98cc961b6934e5350e4fc69ca Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Fri, 28 Mar 2025 13:10:25 -0700 Subject: [PATCH 4/8] refactor: auth middleware runs first --- src/stac_auth_proxy/app.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py index 66685d15..1842612a 100644 --- a/src/stac_auth_proxy/app.py +++ b/src/stac_auth_proxy/app.py @@ -88,19 +88,19 @@ async def lifespan(app: FastAPI): ) app.add_middleware( - EnforceAuthMiddleware, - public_endpoints=settings.public_endpoints, - private_endpoints=settings.private_endpoints, - default_public=settings.default_public, - oidc_config_url=settings.oidc_discovery_internal_url, + CompressionMiddleware, ) app.add_middleware( - CompressionMiddleware, + AddProcessTimeHeaderMiddleware, ) app.add_middleware( - AddProcessTimeHeaderMiddleware, + EnforceAuthMiddleware, + public_endpoints=settings.public_endpoints, + private_endpoints=settings.private_endpoints, + default_public=settings.default_public, + oidc_config_url=settings.oidc_discovery_internal_url, ) return app From 5d36dd8c707a391255b4715d42252249ca3b6bbb Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Fri, 28 Mar 2025 15:15:07 -0700 Subject: [PATCH 5/8] slight refactor --- src/stac_auth_proxy/utils/lifespan.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/stac_auth_proxy/utils/lifespan.py b/src/stac_auth_proxy/utils/lifespan.py index 32fca68d..19659b59 100644 --- a/src/stac_auth_proxy/utils/lifespan.py +++ b/src/stac_auth_proxy/utils/lifespan.py @@ -58,6 +58,7 @@ async def check_conformance( middleware_classes: list[Middleware], api_url: str, attr_name: str = "__required_conformances__", + endpoint: str = "/conformance", ): """Check if the upstream API supports a given conformance class.""" required_conformances: dict[str, list[str]] = {} @@ -68,10 +69,11 @@ async def check_conformance( middleware.cls.__name__ ) - async with httpx.AsyncClient() as client: - response = await client.get(f"{api_url}/conformance") + async with httpx.AsyncClient(base_url=api_url) as client: + response = await client.get(endpoint) response.raise_for_status() api_conforms_to = response.json().get("conformsTo", []) + missing = [ req_conformance for req_conformance in required_conformances.keys() From cc598e7f294845767221195f3fba181a43d38d79 Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Fri, 28 Mar 2025 15:18:08 -0700 Subject: [PATCH 6/8] refactor --- src/stac_auth_proxy/utils/middleware.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/stac_auth_proxy/utils/middleware.py b/src/stac_auth_proxy/utils/middleware.py index edb9ee9a..814b9da5 100644 --- a/src/stac_auth_proxy/utils/middleware.py +++ b/src/stac_auth_proxy/utils/middleware.py @@ -101,11 +101,14 @@ async def transform_response(message: Message) -> None: return await self.app(scope, receive, transform_response) -def required_conformance(*conformances: str): +def required_conformance( + *conformances: str, + attr_name: str = "__required_conformances__", +): """Register required conformance classes with a middleware class.""" - def decorator(func): - func.__required_conformances__ = list(conformances) - return func + def decorator(middleware): + setattr(middleware, attr_name, list(conformances)) + return middleware return decorator From 3c12cda561feea8e90f97a1f116bd14f0823b45e Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Fri, 28 Mar 2025 15:21:46 -0700 Subject: [PATCH 7/8] refactor --- src/stac_auth_proxy/app.py | 12 ++++++------ src/stac_auth_proxy/utils/lifespan.py | 10 ---------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py index 1b602458..28c8745c 100644 --- a/src/stac_auth_proxy/app.py +++ b/src/stac_auth_proxy/app.py @@ -21,11 +21,7 @@ EnforceAuthMiddleware, OpenApiMiddleware, ) -from .utils.lifespan import ( - check_conformance, - check_server_health, - log_middleware_classes, -) +from .utils.lifespan import check_conformance, check_server_health logger = logging.getLogger(__name__) @@ -49,7 +45,11 @@ async def lifespan(app: FastAPI): await check_server_health(url=url) # Log all middleware connected to the app - await log_middleware_classes(app.user_middleware) + logger.debug( + "Connected middleware:\n%s", + "\n".join([f" - {m.cls.__name__}" for m in app.user_middleware]), + ) + if settings.check_conformance: await check_conformance( app.user_middleware, diff --git a/src/stac_auth_proxy/utils/lifespan.py b/src/stac_auth_proxy/utils/lifespan.py index 19659b59..cb83838e 100644 --- a/src/stac_auth_proxy/utils/lifespan.py +++ b/src/stac_auth_proxy/utils/lifespan.py @@ -44,16 +44,6 @@ async def check_server_health( ) -async def log_middleware_classes(middleware_classes: list[Middleware]): - """Log the middleware classes connected to the application.""" - logger.debug( - "Connected middleware:\n%s", - "\n".join( - [f"- {middleware.cls.__name__}" for middleware in middleware_classes] - ), - ) - - async def check_conformance( middleware_classes: list[Middleware], api_url: str, From 2ef6cb3ab36db926c71e5fa684e986649235b31f Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Fri, 28 Mar 2025 16:08:50 -0700 Subject: [PATCH 8/8] refactor --- src/stac_auth_proxy/utils/lifespan.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/stac_auth_proxy/utils/lifespan.py b/src/stac_auth_proxy/utils/lifespan.py index cb83838e..f9432b80 100644 --- a/src/stac_auth_proxy/utils/lifespan.py +++ b/src/stac_auth_proxy/utils/lifespan.py @@ -23,10 +23,12 @@ async def check_server_health( if isinstance(url, HttpUrl): url = str(url) - async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client: + async with httpx.AsyncClient( + base_url=url, timeout=timeout, follow_redirects=True + ) as client: for attempt in range(max_retries): try: - response = await client.get(url) + response = await client.get("/") response.raise_for_status() logger.info(f"Upstream API {url!r} is healthy") return