Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
BuildCql2FilterMiddleware,
EnforceAuthMiddleware,
OpenApiMiddleware,
ProcessLinksMiddleware,
RemoveRootPathMiddleware,
)
from .utils.lifespan import check_conformance, check_server_health

Expand Down Expand Up @@ -67,11 +69,15 @@ async def lifespan(app: FastAPI):
app = FastAPI(
openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema
lifespan=lifespan,
root_path=settings.root_path,
)
if app.root_path:
logger.debug("Mounted app at %s", app.root_path)

#
# Handlers (place catch-all proxy handler last)
#

if settings.healthz_prefix:
app.include_router(
HealthzHandler(upstream_url=str(settings.upstream_url)).router,
Expand All @@ -90,6 +96,7 @@ async def lifespan(app: FastAPI):
#
# Middleware (order is important, last added = first to run)
#

if settings.enable_authentication_extension:
app.add_middleware(
AuthenticationExtensionMiddleware,
Expand All @@ -106,6 +113,7 @@ async def lifespan(app: FastAPI):
public_endpoints=settings.public_endpoints,
private_endpoints=settings.private_endpoints,
default_public=settings.default_public,
root_path=settings.root_path,
auth_scheme_name=settings.openapi_auth_scheme_name,
auth_scheme_override=settings.openapi_auth_scheme_override,
)
Expand All @@ -119,11 +127,6 @@ async def lifespan(app: FastAPI):
items_filter=settings.items_filter(),
)

if settings.enable_compression:
app.add_middleware(
CompressionMiddleware,
)

app.add_middleware(
AddProcessTimeHeaderMiddleware,
)
Expand All @@ -136,4 +139,22 @@ async def lifespan(app: FastAPI):
oidc_config_url=settings.oidc_discovery_internal_url,
)

if settings.root_path or settings.upstream_url.path != "/":
app.add_middleware(
ProcessLinksMiddleware,
upstream_url=str(settings.upstream_url),
root_path=settings.root_path,
)

if settings.root_path:
app.add_middleware(
RemoveRootPathMiddleware,
root_path=settings.root_path,
)

if settings.enable_compression:
app.add_middleware(
CompressionMiddleware,
)

return app
1 change: 1 addition & 0 deletions src/stac_auth_proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Settings(BaseSettings):
oidc_discovery_url: HttpUrl
oidc_discovery_internal_url: HttpUrl

root_path: str = ""
override_host: bool = True
healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
wait_for_upstream: bool = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
import re
from dataclasses import dataclass, field
from itertools import chain
from typing import Any
from urllib.parse import urlparse

Expand All @@ -14,6 +13,7 @@
from ..config import EndpointMethods
from ..utils.middleware import JsonResponseMiddleware
from ..utils.requests import find_match
from ..utils.stac import get_links

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -101,18 +101,7 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
# auth:refs
# ---
# Annotate links with "auth:refs": [auth_scheme]
links = chain(
# Item/Collection
data.get("links", []),
# Collections/Items/Search
(
link
for prop in ["features", "collections"]
for object_with_links in data.get(prop, [])
for link in object_with_links.get("links", [])
),
)
for link in links:
for link in get_links(data):
if "href" not in link:
logger.warning("Link %s has no href", link)
continue
Expand Down
73 changes: 73 additions & 0 deletions src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Middleware to remove the application root path from incoming requests and update links in responses."""

import logging
import re
from dataclasses import dataclass
from typing import Any, Optional
from urllib.parse import urlparse, urlunparse

from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.types import ASGIApp, Scope

from ..utils.middleware import JsonResponseMiddleware
from ..utils.stac import get_links

logger = logging.getLogger(__name__)


@dataclass
class ProcessLinksMiddleware(JsonResponseMiddleware):
"""
Middleware to update links in responses, removing the upstream_url path and adding
the root_path if it exists.
"""

app: ASGIApp
upstream_url: str
root_path: Optional[str] = None

json_content_type_expr: str = r"application/(geo\+)?json"

def should_transform_response(self, request: Request, scope: Scope) -> bool:
"""Only transform responses with JSON content type."""
return bool(
re.match(
self.json_content_type_expr,
Headers(scope=scope).get("content-type", ""),
)
)

def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
"""Update links in the response to include root_path."""
for link in get_links(data):
href = link.get("href")
if not href:
continue

try:
parsed_link = urlparse(href)

# Ignore links that are not for this proxy
if parsed_link.netloc != request.headers.get("host"):
continue

# Remove the upstream_url path from the link if it exists
if urlparse(self.upstream_url).path != "/":
parsed_link = parsed_link._replace(
path=parsed_link.path[len(urlparse(self.upstream_url).path) :]
)

# Add the root_path to the link if it exists
if self.root_path:
parsed_link = parsed_link._replace(
path=f"{self.root_path}{parsed_link.path}"
)

link["href"] = urlunparse(parsed_link)
except Exception as e:
logger.error(
"Failed to parse link href %r, (ignoring): %s", href, str(e)
)

return data
45 changes: 45 additions & 0 deletions src/stac_auth_proxy/middleware/RemoveRootPathMiddleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Middleware to remove ROOT_PATH from incoming requests and update links in responses."""

import logging
from dataclasses import dataclass

from starlette.responses import Response
from starlette.types import ASGIApp, Receive, Scope, Send

logger = logging.getLogger(__name__)


@dataclass
class RemoveRootPathMiddleware:
"""
Middleware to remove the root path of the request before it is sent to the upstream
server.

IMPORTANT: This middleware must be placed early in the middleware chain (ie late in
the order of declaration) so that it trims the root_path from the request path before
any middleware that may need to use the request path (e.g. EnforceAuthMiddleware).
"""

app: ASGIApp
root_path: str

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Remove ROOT_PATH from the request path if it exists."""
if scope["type"] != "http":
return await self.app(scope, receive, send)

# If root_path is set and path doesn't start with it, return 404
if self.root_path and not scope["path"].startswith(self.root_path):
response = Response("Not Found", status_code=404)
logger.error(
f"Root path {self.root_path!r} not found in path {scope['path']!r}"
)
await response(scope, receive, send)
return

# Remove root_path if it exists at the start of the path
if scope["path"].startswith(self.root_path):
scope["raw_path"] = scope["path"].encode()
scope["path"] = scope["path"][len(self.root_path) :] or "/"

return await self.app(scope, receive, send)
8 changes: 8 additions & 0 deletions src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class OpenApiMiddleware(JsonResponseMiddleware):
private_endpoints: EndpointMethods
public_endpoints: EndpointMethods
default_public: bool
root_path: str = ""
auth_scheme_name: str = "oidcAuth"
auth_scheme_override: Optional[dict] = None

Expand All @@ -46,12 +47,19 @@ def should_transform_response(self, request: Request, scope: Scope) -> bool:

def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
"""Augment the OpenAPI spec with auth information."""
# Add servers field with root path if root_path is set
if self.root_path:
data["servers"] = [{"url": self.root_path}]

# Add security scheme
components = data.setdefault("components", {})
securitySchemes = components.setdefault("securitySchemes", {})
securitySchemes[self.auth_scheme_name] = self.auth_scheme_override or {
"type": "openIdConnect",
"openIdConnectUrl": self.oidc_config_url,
}

# Add security to private endpoints
for path, method_config in data["paths"].items():
for method, config in method_config.items():
match = find_match(
Expand Down
4 changes: 4 additions & 0 deletions src/stac_auth_proxy/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from .AuthenticationExtensionMiddleware import AuthenticationExtensionMiddleware
from .BuildCql2FilterMiddleware import BuildCql2FilterMiddleware
from .EnforceAuthMiddleware import EnforceAuthMiddleware
from .ProcessLinksMiddleware import ProcessLinksMiddleware
from .RemoveRootPathMiddleware import RemoveRootPathMiddleware
from .UpdateOpenApiMiddleware import OpenApiMiddleware

__all__ = [
Expand All @@ -13,5 +15,7 @@
"AuthenticationExtensionMiddleware",
"BuildCql2FilterMiddleware",
"EnforceAuthMiddleware",
"ProcessLinksMiddleware",
"RemoveRootPathMiddleware",
"OpenApiMiddleware",
]
18 changes: 18 additions & 0 deletions src/stac_auth_proxy/utils/stac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""STAC-specific utilities."""

from itertools import chain


def get_links(data: dict) -> chain[dict]:
"""Get all links from a STAC response."""
return chain(
# Item/Collection
data.get("links", []),
# Collections/Items/Search
(
link
for prop in ["features", "collections"]
for object_with_links in data.get(prop, [])
for link in object_with_links.get("links", [])
),
)
30 changes: 30 additions & 0 deletions tests/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,33 @@ def test_auth_scheme_override(source_api: FastAPI, source_api_server: str):
security_schemes = openapi.get("components", {}).get("securitySchemes", {})
assert "oidcAuth" in security_schemes
assert security_schemes["oidcAuth"] == custom_scheme


def test_root_path_in_openapi_spec(source_api: FastAPI, source_api_server: str):
"""When root_path is set, the OpenAPI spec includes the root path in the servers field."""
root_path = "/api/v1"
app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint=source_api.openapi_url,
root_path=root_path,
)
client = TestClient(app)
response = client.get(root_path + source_api.openapi_url)
assert response.status_code == 200
openapi = response.json()
assert "servers" in openapi
assert openapi["servers"] == [{"url": root_path}]


def test_no_root_path_in_openapi_spec(source_api: FastAPI, source_api_server: str):
"""When root_path is not set, the OpenAPI spec does not include a servers field."""
app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint=source_api.openapi_url,
root_path="", # Empty string means no root path
)
client = TestClient(app)
response = client.get(source_api.openapi_url)
assert response.status_code == 200
openapi = response.json()
assert "servers" not in openapi
Loading