Skip to content

Commit 240828b

Browse files
committed
Process links
1 parent 0d2fdb9 commit 240828b

File tree

5 files changed

+79
-26
lines changed

5 files changed

+79
-26
lines changed

src/stac_auth_proxy/app.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
AddProcessTimeHeaderMiddleware,
1919
ApplyCql2FilterMiddleware,
2020
AuthenticationExtensionMiddleware,
21+
BasePathMiddleware,
2122
BuildCql2FilterMiddleware,
2223
EnforceAuthMiddleware,
2324
OpenApiMiddleware,
24-
RemoveRootPathMiddleware,
2525
)
2626
from .utils.lifespan import check_conformance, check_server_health
2727

@@ -121,11 +121,6 @@ async def lifespan(app: FastAPI):
121121
items_filter=settings.items_filter(),
122122
)
123123

124-
if settings.enable_compression:
125-
app.add_middleware(
126-
CompressionMiddleware,
127-
)
128-
129124
app.add_middleware(
130125
AddProcessTimeHeaderMiddleware,
131126
)
@@ -140,8 +135,13 @@ async def lifespan(app: FastAPI):
140135

141136
if settings.root_path:
142137
app.add_middleware(
143-
RemoveRootPathMiddleware,
138+
BasePathMiddleware,
144139
base_path=settings.root_path,
145140
)
146141

142+
if settings.enable_compression:
143+
app.add_middleware(
144+
CompressionMiddleware,
145+
)
146+
147147
return app

src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44
import re
55
from dataclasses import dataclass, field
6-
from itertools import chain
76
from typing import Any
87
from urllib.parse import urlparse
98

@@ -14,6 +13,7 @@
1413
from ..config import EndpointMethods
1514
from ..utils.middleware import JsonResponseMiddleware
1615
from ..utils.requests import find_match
16+
from ..utils.stac import get_links
1717

1818
logger = logging.getLogger(__name__)
1919

@@ -101,18 +101,7 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
101101
# auth:refs
102102
# ---
103103
# Annotate links with "auth:refs": [auth_scheme]
104-
links = chain(
105-
# Item/Collection
106-
data.get("links", []),
107-
# Collections/Items/Search
108-
(
109-
link
110-
for prop in ["features", "collections"]
111-
for object_with_links in data.get(prop, [])
112-
for link in object_with_links.get("links", [])
113-
),
114-
)
115-
for link in links:
104+
for link in get_links(data):
116105
if "href" not in link:
117106
logger.warning("Link %s has no href", link)
118107
continue

src/stac_auth_proxy/middleware/BasePathMiddleware.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
1-
"""Middleware to remove BASE_PATH from incoming requests."""
1+
"""Middleware to remove BASE_PATH from incoming requests and update links in responses."""
22

3+
import logging
4+
import re
35
from dataclasses import dataclass
6+
from typing import Any
7+
from urllib.parse import urlparse, urlunparse
48

9+
from starlette.datastructures import Headers
10+
from starlette.requests import Request
511
from starlette.types import ASGIApp, Receive, Scope, Send
612

13+
from ..utils.middleware import JsonResponseMiddleware
14+
from ..utils.stac import get_links
15+
16+
logger = logging.getLogger(__name__)
17+
718

819
@dataclass
9-
class RemoveRootPathMiddleware:
20+
class BasePathMiddleware(JsonResponseMiddleware):
1021
"""
11-
Middleware to remove BASE_PATH from incoming requests.
22+
Middleware to handle the base path of the request and update links in responses.
1223
1324
IMPORTANT: This middleware must be the first middleware in the chain (ie last in the
1425
order of declaration) so that it trims the base_path from the request path before
@@ -17,6 +28,9 @@ class RemoveRootPathMiddleware:
1728

1829
app: ASGIApp
1930
base_path: str
31+
transform_links: bool = True
32+
33+
json_content_type_expr: str = r"application/(geo\+)?json"
2034

2135
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
2236
"""Remove BASE_PATH from the request path if it exists."""
@@ -30,4 +44,36 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3044
scope["raw_path"] = scope["path"].encode()
3145
scope["path"] = path[len(self.base_path) :] or "/"
3246

33-
return await self.app(scope, receive, send)
47+
return await super().__call__(scope, receive, send)
48+
49+
def should_transform_response(self, request: Request, scope: Scope) -> bool:
50+
"""Only transform responses with JSON content type."""
51+
return bool(
52+
re.match(
53+
self.json_content_type_expr,
54+
Headers(scope=scope).get("content-type", ""),
55+
)
56+
)
57+
58+
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
59+
"""Update links in the response to include base_path."""
60+
for link in get_links(data):
61+
href = link.get("href")
62+
if not href:
63+
continue
64+
65+
try:
66+
parsed_link = urlparse(href)
67+
68+
# Ignore links that are not for this proxy
69+
if parsed_link.netloc != request.headers.get("host"):
70+
continue
71+
72+
parsed_link = parsed_link._replace(
73+
path=f"{self.base_path}{parsed_link.path}"
74+
)
75+
link["href"] = urlunparse(parsed_link)
76+
except Exception as e:
77+
logger.warning("Failed to parse link href %s: %s", href, str(e))
78+
79+
return data

src/stac_auth_proxy/middleware/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from .AddProcessTimeHeaderMiddleware import AddProcessTimeHeaderMiddleware
44
from .ApplyCql2FilterMiddleware import ApplyCql2FilterMiddleware
55
from .AuthenticationExtensionMiddleware import AuthenticationExtensionMiddleware
6-
from .BasePathMiddleware import RemoveRootPathMiddleware
6+
from .BasePathMiddleware import BasePathMiddleware
77
from .BuildCql2FilterMiddleware import BuildCql2FilterMiddleware
88
from .EnforceAuthMiddleware import EnforceAuthMiddleware
99
from .UpdateOpenApiMiddleware import OpenApiMiddleware
@@ -12,7 +12,7 @@
1212
"AddProcessTimeHeaderMiddleware",
1313
"ApplyCql2FilterMiddleware",
1414
"AuthenticationExtensionMiddleware",
15-
"RemoveRootPathMiddleware",
15+
"BasePathMiddleware",
1616
"BuildCql2FilterMiddleware",
1717
"EnforceAuthMiddleware",
1818
"OpenApiMiddleware",

src/stac_auth_proxy/utils/stac.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""STAC-specific utilities."""
2+
3+
from itertools import chain
4+
5+
6+
def get_links(data: dict) -> chain[dict]: # type: ignore
7+
"""Get all links from a STAC response."""
8+
return chain(
9+
# Item/Collection
10+
data.get("links", []),
11+
# Collections/Items/Search
12+
(
13+
link
14+
for prop in ["features", "collections"]
15+
for object_with_links in data.get(prop, [])
16+
for link in object_with_links.get("links", [])
17+
),
18+
)

0 commit comments

Comments
 (0)