Skip to content

Commit 9b2f382

Browse files
committed
Add middleware for processing links
1 parent 240828b commit 9b2f382

File tree

5 files changed

+78
-34
lines changed

5 files changed

+78
-34
lines changed

src/stac_auth_proxy/app.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818
AddProcessTimeHeaderMiddleware,
1919
ApplyCql2FilterMiddleware,
2020
AuthenticationExtensionMiddleware,
21-
BasePathMiddleware,
2221
BuildCql2FilterMiddleware,
2322
EnforceAuthMiddleware,
2423
OpenApiMiddleware,
24+
ProcessLinksMiddleware,
25+
RemoveRootPathMiddleware,
2526
)
2627
from .utils.lifespan import check_conformance, check_server_health
2728

@@ -133,9 +134,16 @@ async def lifespan(app: FastAPI):
133134
oidc_config_url=settings.oidc_discovery_internal_url,
134135
)
135136

137+
if settings.root_path or settings.upstream_url.path != "/":
138+
app.add_middleware(
139+
ProcessLinksMiddleware,
140+
upstream_url=str(settings.upstream_url),
141+
base_path=settings.root_path,
142+
)
143+
136144
if settings.root_path:
137145
app.add_middleware(
138-
BasePathMiddleware,
146+
RemoveRootPathMiddleware,
139147
base_path=settings.root_path,
140148
)
141149

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
"""Middleware to remove BASE_PATH from incoming requests and update links in responses."""
1+
"""Middleware to remove the application root path from incoming requests and update links in responses."""
22

33
import logging
44
import re
55
from dataclasses import dataclass
6-
from typing import Any
6+
from typing import Any, Optional
77
from urllib.parse import urlparse, urlunparse
88

99
from starlette.datastructures import Headers
1010
from starlette.requests import Request
11-
from starlette.types import ASGIApp, Receive, Scope, Send
11+
from starlette.types import ASGIApp, Scope
1212

1313
from ..utils.middleware import JsonResponseMiddleware
1414
from ..utils.stac import get_links
@@ -17,35 +17,18 @@
1717

1818

1919
@dataclass
20-
class BasePathMiddleware(JsonResponseMiddleware):
20+
class ProcessLinksMiddleware(JsonResponseMiddleware):
2121
"""
22-
Middleware to handle the base path of the request and update links in responses.
23-
24-
IMPORTANT: This middleware must be the first middleware in the chain (ie last in the
25-
order of declaration) so that it trims the base_path from the request path before
26-
other middleware review the request.
22+
Middleware to update links in responses, removing the upstream_url path and adding
23+
the root_path if it exists.
2724
"""
2825

2926
app: ASGIApp
30-
base_path: str
31-
transform_links: bool = True
27+
upstream_url: str
28+
root_path: Optional[str] = None
3229

3330
json_content_type_expr: str = r"application/(geo\+)?json"
3431

35-
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
36-
"""Remove BASE_PATH from the request path if it exists."""
37-
if scope["type"] != "http":
38-
return await self.app(scope, receive, send)
39-
40-
path = scope["path"]
41-
42-
# Remove base_path if it exists at the start of the path
43-
if path.startswith(self.base_path):
44-
scope["raw_path"] = scope["path"].encode()
45-
scope["path"] = path[len(self.base_path) :] or "/"
46-
47-
return await super().__call__(scope, receive, send)
48-
4932
def should_transform_response(self, request: Request, scope: Scope) -> bool:
5033
"""Only transform responses with JSON content type."""
5134
return bool(
@@ -69,11 +52,22 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
6952
if parsed_link.netloc != request.headers.get("host"):
7053
continue
7154

72-
parsed_link = parsed_link._replace(
73-
path=f"{self.base_path}{parsed_link.path}"
74-
)
55+
# Remove the upstream_url path from the link if it exists
56+
if urlparse(self.upstream_url).path != "/":
57+
parsed_link = parsed_link._replace(
58+
path=parsed_link.path[len(urlparse(self.upstream_url).path) :]
59+
)
60+
61+
# Add the root_path to the link if it exists
62+
if self.root_path:
63+
parsed_link = parsed_link._replace(
64+
path=f"{self.root_path}{parsed_link.path}"
65+
)
66+
7567
link["href"] = urlunparse(parsed_link)
7668
except Exception as e:
77-
logger.warning("Failed to parse link href %s: %s", href, str(e))
69+
logger.error(
70+
"Failed to parse link href %r, (ignoring): %s", href, str(e)
71+
)
7872

7973
return data
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Middleware to remove BASE_PATH from incoming requests and update links in responses."""
2+
3+
import logging
4+
from dataclasses import dataclass
5+
6+
from starlette.types import ASGIApp, Receive, Scope, Send
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
@dataclass
12+
class RemoveRootPathMiddleware:
13+
"""
14+
Middleware to remove the base path of the request before it is sent to the upstream
15+
server.
16+
17+
IMPORTANT: This middleware must be place early in the middleware chain (ie late in the
18+
order of declaration) so that it trims the base_path from the request path before any
19+
middleware that may need to use the request path (e.g. EnforceAuthMiddleware).
20+
"""
21+
22+
app: ASGIApp
23+
base_path: str
24+
transform_links: bool = True
25+
26+
json_content_type_expr: str = r"application/(geo\+)?json"
27+
28+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
29+
"""Remove BASE_PATH from the request path if it exists."""
30+
if scope["type"] != "http":
31+
return await self.app(scope, receive, send)
32+
33+
path = scope["path"]
34+
35+
# Remove base_path if it exists at the start of the path
36+
if path.startswith(self.base_path):
37+
scope["raw_path"] = scope["path"].encode()
38+
scope["path"] = path[len(self.base_path) :] or "/"
39+
40+
return await self.app(scope, receive, send)

src/stac_auth_proxy/middleware/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@
33
from .AddProcessTimeHeaderMiddleware import AddProcessTimeHeaderMiddleware
44
from .ApplyCql2FilterMiddleware import ApplyCql2FilterMiddleware
55
from .AuthenticationExtensionMiddleware import AuthenticationExtensionMiddleware
6-
from .BasePathMiddleware import BasePathMiddleware
76
from .BuildCql2FilterMiddleware import BuildCql2FilterMiddleware
87
from .EnforceAuthMiddleware import EnforceAuthMiddleware
8+
from .ProcessLinksMiddleware import ProcessLinksMiddleware
9+
from .RemoveRootPathMiddleware import RemoveRootPathMiddleware
910
from .UpdateOpenApiMiddleware import OpenApiMiddleware
1011

1112
__all__ = [
1213
"AddProcessTimeHeaderMiddleware",
1314
"ApplyCql2FilterMiddleware",
1415
"AuthenticationExtensionMiddleware",
15-
"BasePathMiddleware",
1616
"BuildCql2FilterMiddleware",
1717
"EnforceAuthMiddleware",
18+
"ProcessLinksMiddleware",
19+
"RemoveRootPathMiddleware",
1820
"OpenApiMiddleware",
1921
]

src/stac_auth_proxy/utils/stac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from itertools import chain
44

55

6-
def get_links(data: dict) -> chain[dict]: # type: ignore
6+
def get_links(data: dict) -> chain[dict]:
77
"""Get all links from a STAC response."""
88
return chain(
99
# Item/Collection

0 commit comments

Comments
 (0)