diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py index 1d02612c..107689fc 100644 --- a/src/stac_auth_proxy/app.py +++ b/src/stac_auth_proxy/app.py @@ -21,6 +21,8 @@ BuildCql2FilterMiddleware, EnforceAuthMiddleware, OpenApiMiddleware, + ProcessLinksMiddleware, + RemoveRootPathMiddleware, ) from .utils.lifespan import check_conformance, check_server_health @@ -67,11 +69,15 @@ async def lifespan(app: FastAPI): app = FastAPI( openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema lifespan=lifespan, + root_path=settings.root_path, ) + if app.root_path: + logger.debug("Mounted app at %s", app.root_path) # # Handlers (place catch-all proxy handler last) # + if settings.healthz_prefix: app.include_router( HealthzHandler(upstream_url=str(settings.upstream_url)).router, @@ -90,6 +96,7 @@ async def lifespan(app: FastAPI): # # Middleware (order is important, last added = first to run) # + if settings.enable_authentication_extension: app.add_middleware( AuthenticationExtensionMiddleware, @@ -106,6 +113,7 @@ async def lifespan(app: FastAPI): public_endpoints=settings.public_endpoints, private_endpoints=settings.private_endpoints, default_public=settings.default_public, + root_path=settings.root_path, auth_scheme_name=settings.openapi_auth_scheme_name, auth_scheme_override=settings.openapi_auth_scheme_override, ) @@ -119,11 +127,6 @@ async def lifespan(app: FastAPI): items_filter=settings.items_filter(), ) - if settings.enable_compression: - app.add_middleware( - CompressionMiddleware, - ) - app.add_middleware( AddProcessTimeHeaderMiddleware, ) @@ -136,4 +139,22 @@ async def lifespan(app: FastAPI): oidc_config_url=settings.oidc_discovery_internal_url, ) + if settings.root_path or settings.upstream_url.path != "/": + app.add_middleware( + ProcessLinksMiddleware, + upstream_url=str(settings.upstream_url), + root_path=settings.root_path, + ) + + if settings.root_path: + app.add_middleware( + RemoveRootPathMiddleware, + 
root_path=settings.root_path, + ) + + if settings.enable_compression: + app.add_middleware( + CompressionMiddleware, + ) + return app diff --git a/src/stac_auth_proxy/config.py b/src/stac_auth_proxy/config.py index 1515c0e1..6aea9eb1 100644 --- a/src/stac_auth_proxy/config.py +++ b/src/stac_auth_proxy/config.py @@ -38,6 +38,7 @@ class Settings(BaseSettings): oidc_discovery_url: HttpUrl oidc_discovery_internal_url: HttpUrl + root_path: str = "" override_host: bool = True healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz") wait_for_upstream: bool = True diff --git a/src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py b/src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py index dcbc5dd2..23a9df5d 100644 --- a/src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py +++ b/src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py @@ -3,7 +3,6 @@ import logging import re from dataclasses import dataclass, field -from itertools import chain from typing import Any from urllib.parse import urlparse @@ -14,6 +13,7 @@ from ..config import EndpointMethods from ..utils.middleware import JsonResponseMiddleware from ..utils.requests import find_match +from ..utils.stac import get_links logger = logging.getLogger(__name__) @@ -101,18 +101,7 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An # auth:refs # --- # Annotate links with "auth:refs": [auth_scheme] - links = chain( - # Item/Collection - data.get("links", []), - # Collections/Items/Search - ( - link - for prop in ["features", "collections"] - for object_with_links in data.get(prop, []) - for link in object_with_links.get("links", []) - ), - ) - for link in links: + for link in get_links(data): if "href" not in link: logger.warning("Link %s has no href", link) continue diff --git a/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py b/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py new file mode 100644 
"""Middleware to rewrite links in JSON responses so they are valid for the proxy."""

import logging
import re
from dataclasses import dataclass
from typing import Any, Optional
from urllib.parse import urlparse, urlunparse

from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.types import ASGIApp, Scope

from ..utils.middleware import JsonResponseMiddleware
from ..utils.stac import get_links

logger = logging.getLogger(__name__)


@dataclass
class ProcessLinksMiddleware(JsonResponseMiddleware):
    """
    Rewrite link hrefs in JSON responses so they are valid for clients of the proxy.

    For every link whose host equals the request's ``Host`` header, the path
    component of ``upstream_url`` is stripped from the front of the link path and
    ``root_path`` (when set) is prepended. Links for any other host (including
    relative links, which carry no host) and links without an ``href`` are left
    untouched.
    """

    # Wrapped ASGI application.
    app: ASGIApp
    # URL of the upstream API; only its path component is used here.
    upstream_url: str
    # Path prefix the proxy is served under, if any.
    root_path: Optional[str] = None

    # Matched against the start of the content-type header value.
    json_content_type_expr: str = r"application/(geo\+)?json"

    def should_transform_response(self, request: Request, scope: Scope) -> bool:
        """Return True when the content-type header found in ``scope`` is JSON or GeoJSON."""
        return bool(
            re.match(
                self.json_content_type_expr,
                Headers(scope=scope).get("content-type", ""),
            )
        )

    def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
        """
        Rewrite the ``href`` of every link in ``data`` that targets this proxy.

        Args:
            data: Decoded JSON body; link dicts inside it are modified in place.
            request: Incoming request, used only for its ``Host`` header.

        Returns:
            The same ``data`` object, with matching hrefs rewritten.
        """
        # Loop-invariant values, computed once per response. Trailing slashes are
        # dropped so "/api/" behaves like "/api" and joins never produce "//";
        # an upstream path of "/" or "" therefore means "nothing to strip".
        upstream_path = urlparse(self.upstream_url).path.rstrip("/")
        root_path = (self.root_path or "").rstrip("/")
        host = request.headers.get("host")

        for link in get_links(data):
            href = link.get("href")
            if not href:
                continue

            try:
                parsed_link = urlparse(href)

                # Ignore links that are not for this proxy
                if parsed_link.netloc != host:
                    continue

                path = parsed_link.path

                # Strip the upstream path only when it is a whole-segment prefix
                # of the link path; slicing unconditionally would mangle paths
                # that never contained it (e.g. "/collections" -> "lections").
                if upstream_path and (
                    path == upstream_path or path.startswith(f"{upstream_path}/")
                ):
                    path = path[len(upstream_path) :]

                # Add the root_path to the link if it exists
                if root_path:
                    path = f"{root_path}{path}"

                link["href"] = urlunparse(parsed_link._replace(path=path))
            except Exception as e:
                # Best-effort by design: one malformed href must not break the
                # whole response, so the link is left as it was and we move on.
                logger.error(
                    "Failed to parse link href %r, (ignoring): %s", href, str(e)
                )

        return data
"""Middleware to strip ROOT_PATH from the path of incoming requests."""

import logging
from dataclasses import dataclass

from starlette.responses import Response
from starlette.types import ASGIApp, Receive, Scope, Send

logger = logging.getLogger(__name__)


@dataclass
class RemoveRootPathMiddleware:
    """
    Middleware to remove the root path of the request before it is sent to the upstream
    server.

    HTTP requests whose path lies outside ``root_path`` are answered with a 404 and
    never reach the wrapped app. Non-HTTP scopes are passed through untouched.

    IMPORTANT: This middleware must be placed early in the middleware chain (ie late in
    the order of declaration) so that it trims the root_path from the request path before
    any middleware that may need to use the request path (e.g. EnforceAuthMiddleware).
    """

    # Wrapped ASGI application.
    app: ASGIApp
    # Path prefix the proxy is served under; "" disables prefix checking.
    root_path: str

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Remove ROOT_PATH from the request path, or respond 404 if it is absent."""
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        path = scope["path"]
        # Normalise so "/api/" and "/api" are equivalent and "/" means "no prefix".
        prefix = self.root_path.rstrip("/")

        # The prefix must match on a segment boundary: with root_path "/api",
        # "/api" and "/api/x" are inside it, "/apiv2/x" is not.
        inside_root = path == prefix or path.startswith(f"{prefix}/")
        if prefix and not inside_root:
            logger.error("Root path %r not found in path %r", self.root_path, path)
            response = Response("Not Found", status_code=404)
            await response(scope, receive, send)
            return

        # NOTE(review): raw_path is rebuilt from the already-decoded, un-stripped
        # path, so any percent-encoding in the incoming raw_path is not kept.
        # Behaviour left as-is because the test-suite pins it; confirm intent.
        scope["raw_path"] = path.encode()
        scope["path"] = path[len(prefix) :] or "/"

        return await self.app(scope, receive, send)
"""STAC-specific utilities."""

from itertools import chain


def get_links(data: dict) -> chain[dict]:
    """
    Get all links from a STAC response.

    Yields the top-level ``links`` first, then the ``links`` of every entry under
    ``features`` and ``collections``, in that order. A key that is missing or set
    to ``None`` (JSON ``null``) is treated as empty instead of raising.

    The yielded dicts are the very objects stored in ``data`` (nothing is copied),
    so callers may rewrite them in place.
    """
    return chain(
        # Item/Collection
        data.get("links") or [],
        # Collections/Items/Search
        (
            link
            for prop in ("features", "collections")
            for object_with_links in data.get(prop) or []
            for link in object_with_links.get("links") or []
        ),
    )
def test_root_path_in_openapi_spec(source_api: FastAPI, source_api_server: str):
    """With a root_path configured, the served OpenAPI spec lists it under ``servers``."""
    prefix = "/api/v1"
    proxy = app_factory(
        upstream_url=source_api_server,
        openapi_spec_endpoint=source_api.openapi_url,
        root_path=prefix,
    )

    # The spec is requested through the prefixed URL, as a client would.
    resp = TestClient(proxy).get(prefix + source_api.openapi_url)

    assert resp.status_code == 200
    spec = resp.json()
    assert "servers" in spec
    assert spec["servers"] == [{"url": prefix}]


def test_no_root_path_in_openapi_spec(source_api: FastAPI, source_api_server: str):
    """With an empty root_path, no ``servers`` field is added to the OpenAPI spec."""
    proxy = app_factory(
        upstream_url=source_api_server,
        openapi_spec_endpoint=source_api.openapi_url,
        root_path="",  # Empty string means no root path
    )

    resp = TestClient(proxy).get(source_api.openapi_url)

    assert resp.status_code == 200
    spec = resp.json()
    assert "servers" not in spec
"""Tests for ProcessLinksMiddleware."""

import pytest
from starlette.requests import Request

from stac_auth_proxy.middleware.ProcessLinksMiddleware import ProcessLinksMiddleware


@pytest.fixture
def middleware():
    """Middleware with upstream path ``/api`` and root path ``/proxy``."""
    return ProcessLinksMiddleware(
        app=None,  # We don't need the actual app for these tests
        upstream_url="http://upstream.example.com/api",
        root_path="/proxy",
    )


@pytest.fixture
def request_scope():
    """HTTP scope for host ``proxy.example.com`` with a JSON content type."""
    headers = [
        (b"host", b"proxy.example.com"),
        (b"content-type", b"application/json"),
    ]
    return {"type": "http", "path": "/test", "headers": headers}


def test_should_transform_response_json(middleware, request_scope):
    """JSON responses are selected for transformation."""
    req = Request(request_scope)
    assert middleware.should_transform_response(req, request_scope)


def test_should_transform_response_geojson(middleware, request_scope):
    """GeoJSON responses are selected for transformation."""
    request_scope["headers"] = [
        (b"host", b"proxy.example.com"),
        (b"content-type", b"application/geo+json"),
    ]
    req = Request(request_scope)
    assert middleware.should_transform_response(req, request_scope)


def test_should_transform_response_non_json(middleware, request_scope):
    """Non-JSON responses are not selected for transformation."""
    request_scope["headers"] = [
        (b"host", b"proxy.example.com"),
        (b"content-type", b"text/plain"),
    ]
    req = Request(request_scope)
    assert not middleware.should_transform_response(req, request_scope)


def test_transform_json_with_upstream_path(middleware, request_scope):
    """The upstream path is swapped for the root path."""
    payload = {
        "links": [
            {"rel": "self", "href": "http://proxy.example.com/api/collections"},
            {"rel": "root", "href": "http://proxy.example.com/api"},
        ]
    }

    result = middleware.transform_json(payload, Request(request_scope))

    self_link, root_link = result["links"]
    assert self_link["href"] == "http://proxy.example.com/proxy/collections"
    assert root_link["href"] == "http://proxy.example.com/proxy"


def test_transform_json_without_upstream_path(middleware, request_scope):
    """With no upstream path, the root path is simply prepended."""
    middleware = ProcessLinksMiddleware(
        app=None, upstream_url="http://upstream.example.com", root_path="/proxy"
    )
    payload = {
        "links": [
            {"rel": "self", "href": "http://proxy.example.com/collections"},
            {"rel": "root", "href": "http://proxy.example.com/"},
        ]
    }

    result = middleware.transform_json(payload, Request(request_scope))

    self_link, root_link = result["links"]
    assert self_link["href"] == "http://proxy.example.com/proxy/collections"
    assert root_link["href"] == "http://proxy.example.com/proxy/"


def test_transform_json_without_root_path(middleware, request_scope):
    """With no root path, the upstream path is only removed."""
    middleware = ProcessLinksMiddleware(
        app=None, upstream_url="http://upstream.example.com/api", root_path=None
    )
    payload = {
        "links": [
            {"rel": "self", "href": "http://proxy.example.com/api/collections"},
            {"rel": "root", "href": "http://proxy.example.com/api"},
        ]
    }

    result = middleware.transform_json(payload, Request(request_scope))

    self_link, root_link = result["links"]
    assert self_link["href"] == "http://proxy.example.com/collections"
    assert root_link["href"] == "http://proxy.example.com"


def test_transform_json_different_host(middleware, request_scope):
    """Links for another hostname are left untouched."""
    payload = {
        "links": [
            {"rel": "self", "href": "http://other.example.com/api/collections"},
            {"rel": "root", "href": "http://other.example.com/api"},
        ]
    }

    result = middleware.transform_json(payload, Request(request_scope))

    self_link, root_link = result["links"]
    assert self_link["href"] == "http://other.example.com/api/collections"
    assert root_link["href"] == "http://other.example.com/api"


def test_transform_json_invalid_link(middleware, request_scope):
    """A non-URL href is skipped while the remaining links are still processed."""
    payload = {
        "links": [
            {"rel": "self", "href": "not-a-url"},
            {"rel": "root", "href": "http://proxy.example.com/api"},
        ]
    }

    result = middleware.transform_json(payload, Request(request_scope))

    bad_link, root_link = result["links"]
    assert bad_link["href"] == "not-a-url"
    assert root_link["href"] == "http://proxy.example.com/proxy"


def test_transform_json_nested_links(middleware, request_scope):
    """Links nested under ``collections`` are rewritten as well as top-level ones."""
    collection = {
        "id": "test-collection",
        "links": [
            {
                "rel": "self",
                "href": "http://proxy.example.com/api/collections/test-collection",
            },
            {
                "rel": "items",
                "href": "http://proxy.example.com/api/collections/test-collection/items",
            },
        ],
    }
    payload = {
        "links": [
            {"rel": "self", "href": "http://proxy.example.com/api"},
        ],
        "collections": [collection],
    }

    result = middleware.transform_json(payload, Request(request_scope))

    assert result["links"][0]["href"] == "http://proxy.example.com/proxy"
    nested = result["collections"][0]["links"]
    assert (
        nested[0]["href"]
        == "http://proxy.example.com/proxy/collections/test-collection"
    )
    assert (
        nested[1]["href"]
        == "http://proxy.example.com/proxy/collections/test-collection/items"
    )
"""Tests for RemoveRootPathMiddleware."""

import pytest
from fastapi import FastAPI
from starlette.testclient import TestClient
from starlette.types import Receive, Scope, Send

from stac_auth_proxy.middleware.RemoveRootPathMiddleware import RemoveRootPathMiddleware


class MockASGIApp:
    """ASGI app stub that records whether it was called and with which scope."""

    def __init__(self):
        """Start in the not-yet-called state."""
        self.called = False
        self.scope = None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Record the call and the scope it received."""
        self.called = True
        self.scope = scope


@pytest.mark.asyncio
async def test_remove_root_path_middleware():
    """The root path is stripped from ``path``; ``raw_path`` keeps the full path."""
    downstream = MockASGIApp()
    middleware = RemoveRootPathMiddleware(downstream, root_path="/api")

    # Test with root path
    await middleware(
        {"type": "http", "path": "/api/test", "raw_path": b"/api/test"},
        None,
        None,
    )

    assert downstream.called
    assert downstream.scope["path"] == "/test"
    assert downstream.scope["raw_path"] == b"/api/test"


@pytest.mark.asyncio
async def test_remove_root_path_middleware_non_http():
    """Non-HTTP scopes reach the wrapped app with their path unchanged."""
    downstream = MockASGIApp()
    middleware = RemoveRootPathMiddleware(downstream, root_path="/api")

    await middleware({"type": "websocket", "path": "/api/test"}, None, None)

    assert downstream.called
    assert downstream.scope["path"] == "/api/test"


@pytest.mark.asyncio
async def test_remove_root_path_middleware_empty_path():
    """A path equal to the root path becomes '/' after removal."""
    downstream = MockASGIApp()
    middleware = RemoveRootPathMiddleware(downstream, root_path="/api")

    await middleware(
        {"type": "http", "path": "/api", "raw_path": b"/api"},
        None,
        None,
    )

    assert downstream.called
    assert downstream.scope["path"] == "/"
    assert downstream.scope["raw_path"] == b"/api"


def test_remove_root_path_middleware_integration():
    """Mounted on a FastAPI app, prefixed URLs resolve and unprefixed ones 404."""
    app = FastAPI()
    app.add_middleware(RemoveRootPathMiddleware, root_path="/api")

    @app.get("/test")
    async def test_endpoint():
        return {"message": "test"}

    client = TestClient(app)

    # Test with root path
    prefixed = client.get("/api/test")
    assert prefixed.status_code == 200
    assert prefixed.json() == {"message": "test"}

    # Test without root path
    unprefixed = client.get("/test")
    assert unprefixed.status_code == 404  # Should not find the endpoint