From 7a4e9843d6a7790e68117429a31378864da09e65 Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Mon, 21 Jul 2025 22:49:16 -0700 Subject: [PATCH 1/6] fix: handle empty search body --- src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py b/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py index aa4f8d58..871fe6e8 100644 --- a/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py +++ b/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py @@ -89,7 +89,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # Modify body try: - body = json.loads(body) + body = json.loads(body) if body else {} except json.JSONDecodeError as e: logger.warning("Failed to parse request body as JSON") # TODO: Return a 400 error From 7e4729ab43194232d89c5870f6b1f68abbe60401 Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Mon, 21 Jul 2025 23:27:19 -0700 Subject: [PATCH 2/6] refactor: breakup CQL2 middleware --- src/stac_auth_proxy/app.py | 14 +- .../middleware/ApplyCql2FilterMiddleware.py | 202 ------------------ .../Cql2ApplyFilterBodyMiddleware.py | 98 +++++++++ .../Cql2ApplyFilterQueryStringMiddleware.py | 56 +++++ ...leware.py => Cql2BuildFilterMiddleware.py} | 2 +- .../Cql2ValidateResponseBodyMiddleware.py | 133 ++++++++++++ src/stac_auth_proxy/middleware/__init__.py | 14 +- 7 files changed, 305 insertions(+), 214 deletions(-) delete mode 100644 src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py create mode 100644 src/stac_auth_proxy/middleware/Cql2ApplyFilterBodyMiddleware.py create mode 100644 src/stac_auth_proxy/middleware/Cql2ApplyFilterQueryStringMiddleware.py rename src/stac_auth_proxy/middleware/{BuildCql2FilterMiddleware.py => Cql2BuildFilterMiddleware.py} (99%) create mode 100644 src/stac_auth_proxy/middleware/Cql2ValidateResponseBodyMiddleware.py diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py index 6329ab1d..eb62edab 100644 --- a/src/stac_auth_proxy/app.py +++ b/src/stac_auth_proxy/app.py @@ -16,9 +16,11 @@ from .handlers import HealthzHandler, ReverseProxyHandler, SwaggerUI from .middleware import ( AddProcessTimeHeaderMiddleware, - ApplyCql2FilterMiddleware, AuthenticationExtensionMiddleware, - BuildCql2FilterMiddleware, + Cql2ApplyFilterBodyMiddleware, + Cql2ApplyFilterQueryStringMiddleware, + Cql2BuildFilterMiddleware, + Cql2ValidateResponseBodyMiddleware, EnforceAuthMiddleware, OpenApiMiddleware, ProcessLinksMiddleware, @@ -132,11 +134,11 @@ async def lifespan(app: FastAPI): ) if settings.items_filter or settings.collections_filter: + app.add_middleware(Cql2ValidateResponseBodyMiddleware) + app.add_middleware(Cql2ApplyFilterBodyMiddleware) + app.add_middleware(Cql2ApplyFilterQueryStringMiddleware) app.add_middleware( - ApplyCql2FilterMiddleware, - ) - app.add_middleware( - BuildCql2FilterMiddleware, + Cql2BuildFilterMiddleware, items_filter=settings.items_filter() if settings.items_filter else None, collections_filter=( settings.collections_filter() if settings.collections_filter else None diff --git a/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py b/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py deleted file mode 100644 index 871fe6e8..00000000 --- a/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py +++ /dev/null @@ -1,202 +0,0 @@ -"""Middleware to apply CQL2 filters.""" - -import json -import re -from dataclasses import dataclass -from logging import getLogger -from typing import Optional - -from cql2 import Expr -from starlette.datastructures import MutableHeaders -from starlette.requests import Request -from starlette.types import ASGIApp, Message, Receive, Scope, Send - -from ..utils import filters -from ..utils.middleware import required_conformance - -logger = getLogger(__name__) - - -@required_conformance( - r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2", - r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text", - r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json", -) -@dataclass(frozen=True) -class ApplyCql2FilterMiddleware: - """Middleware to apply the Cql2Filter to the request.""" - - app: ASGIApp - state_key: str = "cql2_filter" - - single_record_endpoints = [ - r"^/collections/([^/]+)/items/([^/]+)$", - r"^/collections/([^/]+)$", - ] - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Add the Cql2Filter to the request.""" - if scope["type"] != "http": - return await self.app(scope, receive, send) - - request = Request(scope) - - cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None) - - if not cql2_filter: - return await self.app(scope, receive, send) - - # Handle POST, PUT, PATCH - if request.method in ["POST", "PUT", "PATCH"]: - req_body_handler = Cql2RequestBodyAugmentor( - app=self.app, - cql2_filter=cql2_filter, - ) - return await req_body_handler(scope, receive, send) - - # Handle single record requests (ie non-filterable endpoints) - if any( - re.match(expr, request.url.path) for expr in self.single_record_endpoints - ): - res_body_validator = Cql2ResponseBodyValidator( - app=self.app, - cql2_filter=cql2_filter, - ) - return await res_body_validator(scope, send, receive) - - scope["query_string"] = filters.append_qs_filter(request.url.query, cql2_filter) - return await self.app(scope, receive, send) - - -@dataclass(frozen=True) -class Cql2RequestBodyAugmentor: - """Handler to augment the request body with a CQL2 filter.""" - - app: ASGIApp - cql2_filter: Expr - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Augment the request body with a CQL2 filter.""" - body = b"" - more_body = True - - # Read the body - while more_body: - message = await receive() - if message["type"] == "http.request": - body += message.get("body", b"") - more_body = message.get("more_body", False) - - # Modify body - try: - body = json.loads(body) if body else {} - except json.JSONDecodeError as e: - logger.warning("Failed to parse request body as JSON") - # TODO: Return a 400 error - raise e - - # Augment the body - assert isinstance(body, dict), "Request body must be a JSON object" - new_body = json.dumps( - filters.append_body_filter(body, self.cql2_filter) - ).encode("utf-8") - - # Patch content-length in the headers - headers = dict(scope["headers"]) - headers[b"content-length"] = str(len(new_body)).encode("latin1") - scope["headers"] = list(headers.items()) - - async def new_receive(): - return { - "type": "http.request", - "body": new_body, - "more_body": False, - } - - await self.app(scope, new_receive, send) - - -@dataclass -class Cql2ResponseBodyValidator: - """Handler to validate response body with CQL2.""" - - app: ASGIApp - cql2_filter: Expr - - async def __call__(self, scope: Scope, send: Send, receive: Receive) -> None: - """Process a response message and apply filtering if needed.""" - if scope["type"] != "http": - return await self.app(scope, send, receive) - - body = b"" - initial_message: Optional[Message] = None - - async def _send_error_response(status: int, code: str, message: str) -> None: - """Send an error response with the given status and message.""" - assert initial_message, "Initial message not set" - response_dict = { - "code": code, - "description": message, - } - response_bytes = json.dumps(response_dict).encode("utf-8") - headers = MutableHeaders(scope=initial_message) - headers["content-length"] = str(len(response_bytes)) - initial_message["status"] = status - await send(initial_message) - await send( - { - "type": "http.response.body", - "body": response_bytes, - "more_body": False, - } - ) - - async def buffered_send(message: Message) -> None: - """Process a response message and apply filtering if needed.""" - nonlocal body - nonlocal initial_message - initial_message = initial_message or message - # NOTE: to avoid data-leak, we process 404s so their responses are the same as rejected 200s - should_process = initial_message["status"] in [200, 404] - - if not should_process: - return await send(message) - - if message["type"] == "http.response.start": - # Hold off on sending response headers until we've validated the response body - return - - body += message["body"] - if message.get("more_body"): - return - - try: - body_json = json.loads(body) - except json.JSONDecodeError: - msg = "Failed to parse response body as JSON" - logger.warning(msg) - await _send_error_response(status=502, code="ParseError", message=msg) - return - - try: - cql2_matches = self.cql2_filter.matches(body_json) - except Exception as e: - cql2_matches = False - logger.warning("Failed to apply filter: %s", e) - - if cql2_matches: - logger.debug("Response matches filter, returning record") - await send(initial_message) - return await send( - { - "type": "http.response.body", - "body": json.dumps(body_json).encode("utf-8"), - "more_body": False, - } - ) - logger.debug("Response did not match filter, returning 404") - return await _send_error_response( - status=404, code="NotFoundError", message="Record not found." - ) - - return await self.app(scope, receive, buffered_send) diff --git a/src/stac_auth_proxy/middleware/Cql2ApplyFilterBodyMiddleware.py b/src/stac_auth_proxy/middleware/Cql2ApplyFilterBodyMiddleware.py new file mode 100644 index 00000000..b1d46d42 --- /dev/null +++ b/src/stac_auth_proxy/middleware/Cql2ApplyFilterBodyMiddleware.py @@ -0,0 +1,98 @@ +"""Middleware to augment the request body with a CQL2 filter for POST/PUT/PATCH requests.""" + +import json +from dataclasses import dataclass +from logging import getLogger +from typing import Optional + +from cql2 import Expr +from starlette.requests import Request +from starlette.types import ASGIApp, Receive, Scope, Send + +from ..utils import filters +from ..utils.middleware import required_conformance + +logger = getLogger(__name__) + + +@required_conformance( + r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json", +) +@dataclass(frozen=True) +class Cql2ApplyFilterBodyMiddleware: + """Middleware to augment the request body with a CQL2 filter for POST/PUT/PATCH requests.""" + + app: ASGIApp + state_key: str = "cql2_filter" + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Apply the CQL2 filter to the request body.""" + if scope["type"] != "http": + return await self.app(scope, receive, send) + + request = Request(scope) + cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None) + if not cql2_filter: + return await self.app(scope, receive, send) + + if request.method not in ["POST", "PUT", "PATCH"]: + return await self.app(scope, receive, send) + + body = b"" + more_body = True + while more_body: + message = await receive() + if message["type"] == "http.request": + body += message.get("body", b"") + more_body = message.get("more_body", False) + + try: + body_json = json.loads(body) if body else {} + except json.JSONDecodeError: + logger.warning("Failed to parse request body as JSON") + from starlette.responses import JSONResponse + + response = JSONResponse( + { + "code": "ParseError", + "description": "Request body must be valid JSON.", + }, + status_code=400, + ) + await response(scope, receive, send) + return + + if not isinstance(body_json, dict): + logger.warning("Request body must be a JSON object") + from starlette.responses import JSONResponse + + response = JSONResponse( + { + "code": "TypeError", + "description": "Request body must be a JSON object.", + }, + status_code=400, + ) + await response(scope, receive, send) + return + + new_body = json.dumps( + filters.append_body_filter(body_json, cql2_filter) + ).encode("utf-8") + + # Patch content-length in the headers + headers = dict(scope["headers"]) + headers[b"content-length"] = str(len(new_body)).encode("latin1") + scope = dict(scope) + scope["headers"] = list(headers.items()) + + async def new_receive(): + return { + "type": "http.request", + "body": new_body, + "more_body": False, + } + + await self.app(scope, new_receive, send) diff --git a/src/stac_auth_proxy/middleware/Cql2ApplyFilterQueryStringMiddleware.py b/src/stac_auth_proxy/middleware/Cql2ApplyFilterQueryStringMiddleware.py new file mode 100644 index 00000000..539731e0 --- /dev/null +++ b/src/stac_auth_proxy/middleware/Cql2ApplyFilterQueryStringMiddleware.py @@ -0,0 +1,56 @@ +"""Middleware to inject CQL2 filters into the query string for GET/list endpoints.""" + +import re +from dataclasses import dataclass +from logging import getLogger +from typing import Optional + +from cql2 import Expr +from starlette.requests import Request +from starlette.types import ASGIApp, Receive, Scope, Send + +from ..utils import filters +from ..utils.middleware import required_conformance + +logger = getLogger(__name__) + + +@required_conformance( + r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json", +) +@dataclass(frozen=True) +class Cql2ApplyFilterQueryStringMiddleware: + """Middleware to inject CQL2 filters into the query string for GET/list endpoints.""" + + app: ASGIApp + state_key: str = "cql2_filter" + + single_record_endpoints = [ + r"^/collections/([^/]+)/items/([^/]+)$", + r"^/collections/([^/]+)$", + ] + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Apply the CQL2 filter to the query string.""" + if scope["type"] != "http": + return await self.app(scope, receive, send) + + request = Request(scope) + cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None) + if not cql2_filter: + return await self.app(scope, receive, send) + + # Only handle GET requests that are not single-record endpoints + if request.method != "GET": + return await self.app(scope, receive, send) + if any( + re.match(expr, request.url.path) for expr in self.single_record_endpoints + ): + return await self.app(scope, receive, send) + + # Inject filter into query string + scope = dict(scope) + scope["query_string"] = filters.append_qs_filter(request.url.query, cql2_filter) + return await self.app(scope, receive, send) diff --git a/src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py b/src/stac_auth_proxy/middleware/Cql2BuildFilterMiddleware.py similarity index 99% rename from src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py rename to src/stac_auth_proxy/middleware/Cql2BuildFilterMiddleware.py index cfa153d6..03083e6c 100644 --- a/src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py +++ b/src/stac_auth_proxy/middleware/Cql2BuildFilterMiddleware.py @@ -22,7 +22,7 @@ "http://www.opengis.net/spec/cql2/1.0/conf/cql2-json", ) @dataclass(frozen=True) -class BuildCql2FilterMiddleware: +class Cql2BuildFilterMiddleware: """Middleware to build the Cql2Filter.""" app: ASGIApp diff --git a/src/stac_auth_proxy/middleware/Cql2ValidateResponseBodyMiddleware.py b/src/stac_auth_proxy/middleware/Cql2ValidateResponseBodyMiddleware.py new file mode 100644 index 00000000..c55a9a09 --- /dev/null +++ b/src/stac_auth_proxy/middleware/Cql2ValidateResponseBodyMiddleware.py @@ -0,0 +1,133 @@ +"""Middleware to validate the response body with a CQL2 filter for single-record endpoints.""" + +import json +import re +from dataclasses import dataclass +from logging import getLogger +from typing import Optional + +from cql2 import Expr +from starlette.requests import Request +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +from ..utils.middleware import required_conformance + +logger = getLogger(__name__) + + +@required_conformance( + r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json", +) +@dataclass +class Cql2ValidateResponseBodyMiddleware: + """ASGI middleware to validate the response body with a CQL2 filter for single-record endpoints.""" + + app: ASGIApp + state_key: str = "cql2_filter" + + single_record_endpoints = [ + r"^/collections/([^/]+)/items/([^/]+)$", + r"^/collections/([^/]+)$", + ] + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Validate the response body with a CQL2 filter for single-record endpoints.""" + if scope["type"] != "http": + return await self.app(scope, receive, send) + + request = Request(scope) + cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None) + if not cql2_filter: + return await self.app(scope, receive, send) + + if not any( + re.match(expr, request.url.path) for expr in self.single_record_endpoints + ): + return await self.app(scope, receive, send) + + # Intercept the response + response_start = None + body_chunks = [] + more_body = True + + async def send_wrapper(message: Message): + nonlocal response_start, body_chunks, more_body + if message["type"] == "http.response.start": + response_start = message + elif message["type"] == "http.response.body": + body_chunks.append(message.get("body", b"")) + more_body = message.get("more_body", False) + if not more_body: + await self._process_and_send_response( + response_start, body_chunks, send, cql2_filter + ) + else: + await send(message) + + await self.app(scope, receive, send_wrapper) + + async def _process_and_send_response( + self, response_start, body_chunks, send, cql2_filter + ): + body = b"".join(body_chunks) + try: + body_json = json.loads(body) + except json.JSONDecodeError: + logger.warning("Failed to parse response body as JSON") + await self._send_json_response( + send, + status=502, + content={ + "code": "ParseError", + "description": "Failed to parse response body as JSON", + }, + ) + return + + try: + cql2_matches = cql2_filter.matches(body_json) + except Exception as e: + cql2_matches = False + logger.warning("Failed to apply filter: %s", e) + + if cql2_matches: + logger.debug("Response matches filter, returning record") + # Send the original response start + await send(response_start) + # Send the filtered body + await send( + { + "type": "http.response.body", + "body": json.dumps(body_json).encode("utf-8"), + "more_body": False, + } + ) + else: + logger.debug("Response did not match filter, returning 404") + await self._send_json_response( + send, + status=404, + content={"code": "NotFoundError", "description": "Record not found."}, + ) + + async def _send_json_response(self, send, status, content): + response_bytes = json.dumps(content).encode("utf-8") + await send( + { + "type": "http.response.start", + "status": status, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(response_bytes)).encode("latin1")), + ], + } + ) + await send( + { + "type": "http.response.body", + "body": response_bytes, + "more_body": False, + } + ) diff --git a/src/stac_auth_proxy/middleware/__init__.py b/src/stac_auth_proxy/middleware/__init__.py index 6ad1875d..bc1ae4f2 100644 --- a/src/stac_auth_proxy/middleware/__init__.py +++ b/src/stac_auth_proxy/middleware/__init__.py @@ -1,9 +1,11 @@ """Custom middleware.""" from .AddProcessTimeHeaderMiddleware import AddProcessTimeHeaderMiddleware -from .ApplyCql2FilterMiddleware import ApplyCql2FilterMiddleware from .AuthenticationExtensionMiddleware import AuthenticationExtensionMiddleware -from .BuildCql2FilterMiddleware import BuildCql2FilterMiddleware +from .Cql2ApplyFilterBodyMiddleware import Cql2ApplyFilterBodyMiddleware +from .Cql2ApplyFilterQueryStringMiddleware import Cql2ApplyFilterQueryStringMiddleware +from .Cql2BuildFilterMiddleware import Cql2BuildFilterMiddleware +from .Cql2ValidateResponseBodyMiddleware import Cql2ValidateResponseBodyMiddleware from .EnforceAuthMiddleware import EnforceAuthMiddleware from .ProcessLinksMiddleware import ProcessLinksMiddleware from .RemoveRootPathMiddleware import RemoveRootPathMiddleware @@ -11,11 +13,13 @@ __all__ = [ "AddProcessTimeHeaderMiddleware", - "ApplyCql2FilterMiddleware", "AuthenticationExtensionMiddleware", - "BuildCql2FilterMiddleware", + "Cql2ApplyFilterBodyMiddleware", + "Cql2ApplyFilterQueryStringMiddleware", + "Cql2BuildFilterMiddleware", + "Cql2ValidateResponseBodyMiddleware", "EnforceAuthMiddleware", + "OpenApiMiddleware", "ProcessLinksMiddleware", "RemoveRootPathMiddleware", - "OpenApiMiddleware", ] From 729104e4ed6ffc3b28f6830544d9d47add76f10e Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Tue, 22 Jul 2025 08:22:11 -0700 Subject: [PATCH 3/6] feat: remove added filters from response links #65 --- src/stac_auth_proxy/app.py | 2 + .../Cql2RewriteLinksFilterMiddleware.py | 108 ++++++++++++++++++ src/stac_auth_proxy/middleware/__init__.py | 2 + 3 files changed, 112 insertions(+) create mode 100644 src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py index eb62edab..313dc512 100644 --- a/src/stac_auth_proxy/app.py +++ b/src/stac_auth_proxy/app.py @@ -20,6 +20,7 @@ Cql2ApplyFilterBodyMiddleware, Cql2ApplyFilterQueryStringMiddleware, Cql2BuildFilterMiddleware, + Cql2RewriteLinksFilterMiddleware, Cql2ValidateResponseBodyMiddleware, EnforceAuthMiddleware, OpenApiMiddleware, @@ -137,6 +138,7 @@ async def lifespan(app: FastAPI): app.add_middleware(Cql2ValidateResponseBodyMiddleware) app.add_middleware(Cql2ApplyFilterBodyMiddleware) app.add_middleware(Cql2ApplyFilterQueryStringMiddleware) + app.add_middleware(Cql2RewriteLinksFilterMiddleware) app.add_middleware( Cql2BuildFilterMiddleware, items_filter=settings.items_filter() if settings.items_filter else None, diff --git a/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py b/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py new file mode 100644 index 00000000..1909bfd4 --- /dev/null +++ b/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py @@ -0,0 +1,108 @@ +"""Middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state.""" + +import json +from dataclasses import dataclass +from logging import getLogger +from typing import Optional +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse + +from cql2 import Expr +from starlette.requests import Request +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +logger = getLogger(__name__) + + +@dataclass(frozen=True) +class Cql2RewriteLinksFilterMiddleware: + """ASGI middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state.""" + + app: ASGIApp + state_key: str = "cql2_filter" + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Replace 'filter' in .links of the JSON response to state before we had applied the filter.""" + if scope["type"] != "http": + return await self.app(scope, receive, send) + + request = Request(scope) + original_filter = request.query_params.get("filter") + cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None) + if cql2_filter is None: + # No filter set, just pass through + return await self.app(scope, receive, send) + + # Intercept the response + response_start = None + body_chunks = [] + more_body = True + + async def send_wrapper(message: Message): + nonlocal response_start, body_chunks, more_body + if message["type"] == "http.response.start": + response_start = message + elif message["type"] == "http.response.body": + body_chunks.append(message.get("body", b"")) + more_body = message.get("more_body", False) + if not more_body: + await self._process_and_send_response( + response_start, body_chunks, send, original_filter + ) + else: + await send(message) + + await self.app(scope, receive, send_wrapper) + + async def _process_and_send_response( + self, + response_start: Message, + body_chunks: list[bytes], + send: Send, + original_filter: Optional[str], + ): + body = b"".join(body_chunks) + try: + data = json.loads(body) + except Exception: + await send(response_start) + await send({"type": "http.response.body", "body": body, "more_body": False}) + return + + cql2_filter = Expr(original_filter) if original_filter else None + links = data.get("links") + if isinstance(links, list): + for link in links: + # Handle filter in query string + if "href" in link: + url = urlparse(link["href"]) + qs = parse_qs(url.query) + if "filter" in qs: + if cql2_filter: + qs["filter"] = [cql2_filter.to_text()] + else: + qs.pop("filter", None) + qs.pop("filter-lang", None) + new_query = urlencode(qs, doseq=True) + link["href"] = urlunparse(url._replace(query=new_query)) + + # Handle filter in body (for POST links) + if "body" in link and isinstance(link["body"], dict): + if "filter" in link["body"]: + if cql2_filter: + link["body"]["filter"] = cql2_filter.to_json() + else: + link["body"].pop("filter", None) + link["body"].pop("filter-lang", None) + + # Send the modified response + new_body = json.dumps(data).encode("utf-8") + + # Patch content-length + headers = [ + (k, v) for k, v in response_start["headers"] if k != b"content-length" + ] + headers.append((b"content-length", str(len(new_body)).encode("latin1"))) + response_start = dict(response_start) + response_start["headers"] = headers + await send(response_start) + await send({"type": "http.response.body", "body": new_body, "more_body": False}) diff --git a/src/stac_auth_proxy/middleware/__init__.py b/src/stac_auth_proxy/middleware/__init__.py index bc1ae4f2..c5dc005b 100644 --- a/src/stac_auth_proxy/middleware/__init__.py +++ b/src/stac_auth_proxy/middleware/__init__.py @@ -5,6 +5,7 @@ from .Cql2ApplyFilterBodyMiddleware import Cql2ApplyFilterBodyMiddleware from .Cql2ApplyFilterQueryStringMiddleware import Cql2ApplyFilterQueryStringMiddleware from .Cql2BuildFilterMiddleware import Cql2BuildFilterMiddleware +from .Cql2RewriteLinksFilterMiddleware import Cql2RewriteLinksFilterMiddleware from .Cql2ValidateResponseBodyMiddleware import Cql2ValidateResponseBodyMiddleware from .EnforceAuthMiddleware import EnforceAuthMiddleware from .ProcessLinksMiddleware import ProcessLinksMiddleware @@ -17,6 +18,7 @@ "Cql2ApplyFilterBodyMiddleware", "Cql2ApplyFilterQueryStringMiddleware", "Cql2BuildFilterMiddleware", + "Cql2RewriteLinksFilterMiddleware", "Cql2ValidateResponseBodyMiddleware", "EnforceAuthMiddleware", "OpenApiMiddleware", From 771b02b9293b3130e2af46262eac4f531c515e0a Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Thu, 24 Jul 2025 11:34:28 -0700 Subject: [PATCH 4/6] in progress --- ...st_cql2_rewrite_links_filter_middleware.py | 110 ++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 tests/test_cql2_rewrite_links_filter_middleware.py diff --git a/tests/test_cql2_rewrite_links_filter_middleware.py b/tests/test_cql2_rewrite_links_filter_middleware.py new file mode 100644 index 00000000..d80e2383 --- /dev/null +++ b/tests/test_cql2_rewrite_links_filter_middleware.py @@ -0,0 +1,110 @@ +from unittest.mock import patch, MagicMock + +import pytest +from fastapi import FastAPI, Request, Response +from starlette.testclient import TestClient + +from stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware import ( + Cql2RewriteLinksFilterMiddleware, +) + + +@pytest.fixture +def app_with_middleware(): + app = FastAPI() + app.add_middleware(Cql2RewriteLinksFilterMiddleware) + + @app.get("/test") + async def test_endpoint(request: Request): + # Simulate a response with links containing a filter in the query and body + return { + "links": [ + { + "rel": "self", + "href": "http://example.com/search?filter=foo&filter-lang=cql2-text", + }, + { + "rel": "post", + "body": {"filter": "foo", "filter-lang": "cql2-json"}, + }, + ] + } + + return app + + +def test_rewrite_links_with_filter(app_with_middleware): + # Patch cql2.Expr to simulate to_text and to_json + with patch( + "stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware.Expr" + ) as MockExpr: + mock_expr = MagicMock() + mock_expr.to_text.return_value = "bar" + mock_expr.to_json.return_value = {"foo": "bar"} + MockExpr.return_value = mock_expr + + client = TestClient(app_with_middleware) + response = client.get("/test?filter=foo") + assert response.status_code == 200 + data = response.json() + # The filter in the href should be rewritten + assert any( + "filter=bar" in link["href"] for link in data["links"] if "href" in link + ) + # The filter in the body should be rewritten + assert any( + link.get("body", {}).get("filter") == {"foo": "bar"} + for link in data["links"] + ) + + +def test_remove_filter_from_links(app_with_middleware): + # Patch cql2.Expr to return None (no filter) + with patch( + "stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware.Expr" + ) as MockExpr: + MockExpr.return_value = None + client = TestClient(app_with_middleware) + response = client.get("/test") + assert response.status_code == 200 + data = response.json() + # The filter should be removed from href and body + for link in data["links"]: + if "href" in link: + assert "filter=" not in link["href"] + if "body" in link: + assert "filter" not in link["body"] + assert "filter-lang" not in link["body"] + + +def test_passthrough_when_no_filter_state(app_with_middleware): + # Simulate no filter in request.state + with patch( + "stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware.Expr" + ) as MockExpr: + MockExpr.return_value = None + client = TestClient(app_with_middleware) + response = client.get("/test") + assert response.status_code == 200 + data = response.json() + # Should be unchanged (filter removed) + for link in data["links"]: + if "href" in link: + assert "filter=" not in link["href"] + if "body" in link: + assert "filter" not in link["body"] + assert "filter-lang" not in link["body"] + + +def test_non_json_response(app_with_middleware): + # Add a route that returns plain text + app = app_with_middleware + + @app.get("/plain") + async def plain(): + return Response(content="not json", media_type="text/plain") + + client = TestClient(app) + response = client.get("/plain") + assert response.status_code == 200 + assert response.text == "not json" From ca5eafabcaf6824db501aa22672f7aed6a92dfeb Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Thu, 24 Jul 2025 11:34:38 -0700 Subject: [PATCH 5/6] in progress --- tests/test_cql2_rewrite_links_filter_middleware.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_cql2_rewrite_links_filter_middleware.py b/tests/test_cql2_rewrite_links_filter_middleware.py index d80e2383..c81649ed 100644 --- a/tests/test_cql2_rewrite_links_filter_middleware.py +++ b/tests/test_cql2_rewrite_links_filter_middleware.py @@ -1,4 +1,4 @@ -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch import pytest from fastapi import FastAPI, Request, Response From ea17bff056df3b9f073259152f93c30d6374ba8b Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Tue, 2 Sep 2025 20:44:14 -0700 Subject: [PATCH 6/6] chore: add tests --- ...st_cql2_rewrite_links_filter_middleware.py | 391 ++++++++++++++---- 1 file changed, 309 insertions(+), 82 deletions(-) diff --git a/tests/test_cql2_rewrite_links_filter_middleware.py b/tests/test_cql2_rewrite_links_filter_middleware.py index c81649ed..d5b07cdb 100644 --- a/tests/test_cql2_rewrite_links_filter_middleware.py +++ b/tests/test_cql2_rewrite_links_filter_middleware.py @@ -1,6 +1,9 @@ -from unittest.mock import MagicMock, patch +"""Test Cql2RewriteLinksFilterMiddleware.""" + +import re import pytest +from cql2 import Expr from fastapi import FastAPI, Request, Response from starlette.testclient import TestClient @@ -9,102 +12,326 @@ ) -@pytest.fixture -def app_with_middleware(): +def test_non_json_response(): + """Test middleware behavior with non-JSON responses.""" app = FastAPI() app.add_middleware(Cql2RewriteLinksFilterMiddleware) - @app.get("/test") - async def test_endpoint(request: Request): - # Simulate a response with links containing a filter in the query and body - return { - "links": [ - { - "rel": "self", - "href": "http://example.com/search?filter=foo&filter-lang=cql2-text", - }, - { - "rel": "post", - "body": {"filter": "foo", "filter-lang": "cql2-json"}, - }, - ] - } + @app.get("/plain") + async def plain(): + return Response(content="not json", media_type="text/plain") + + client = TestClient(app) + response = client.get("/plain") + assert response.status_code == 200 + assert response.text == "not json" - return app +class TestEdgeCases: + """Test middleware behavior with edge cases.""" -def test_rewrite_links_with_filter(app_with_middleware): - # Patch cql2.Expr to simulate to_text and to_json - with patch( - "stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware.Expr" - ) as MockExpr: - mock_expr = MagicMock() - mock_expr.to_text.return_value = "bar" - mock_expr.to_json.return_value = {"foo": "bar"} - MockExpr.return_value = mock_expr + def test_no_links_in_response(self): + """Test middleware behavior when response has no links.""" + app = FastAPI() + app.add_middleware(Cql2RewriteLinksFilterMiddleware) - client = TestClient(app_with_middleware) - response = client.get("/test?filter=foo") + @app.get("/test") + async def test_endpoint(request: Request): + return {"data": "no links here"} + + client = TestClient(app) + response = client.get("/test") assert response.status_code == 200 data = response.json() - # The filter in the href should be rewritten - assert any( - "filter=bar" in link["href"] for link in data["links"] if "href" in link - ) - # The filter in the body should be rewritten - assert any( - link.get("body", {}).get("filter") == {"foo": "bar"} - for link in data["links"] - ) - - -def test_remove_filter_from_links(app_with_middleware): - # Patch cql2.Expr to return None (no filter) - with patch( - "stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware.Expr" - ) as MockExpr: - MockExpr.return_value = None - client = TestClient(app_with_middleware) + assert data == {"data": "no links here"} + + def test_malformed_json_response(self): + """Test middleware behavior with malformed JSON response.""" + app = FastAPI() + app.add_middleware(Cql2RewriteLinksFilterMiddleware) + + @app.get("/test") + async def test_endpoint(request: Request): + return Response(content="invalid json", media_type="application/json") + + client = TestClient(app) response = client.get("/test") assert response.status_code == 200 - data = response.json() - # The filter should be removed from href and body - for link in data["links"]: - if "href" in link: - assert "filter=" not in link["href"] - if "body" in link: - assert "filter" not in link["body"] - assert "filter-lang" not in link["body"] - - -def test_passthrough_when_no_filter_state(app_with_middleware): - # Simulate no filter in request.state - with patch( - "stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware.Expr" - ) as MockExpr: - MockExpr.return_value = None - client = TestClient(app_with_middleware) + assert response.text == "invalid json" + + def test_links_not_list(self): + """Test middleware behavior when links is not a list.""" + app = FastAPI() + app.add_middleware(Cql2RewriteLinksFilterMiddleware) + + @app.get("/test") + async def test_endpoint(request: Request): + return {"links": "not a list"} + + client = TestClient(app) response = client.get("/test") assert response.status_code == 200 data = response.json() - # Should be unchanged (filter removed) - for link in data["links"]: - if "href" in link: - assert "filter=" not in link["href"] - if "body" in link: - assert "filter" not in link["body"] - assert "filter-lang" not in link["body"] + assert data == {"links": "not a list"} -def test_non_json_response(app_with_middleware): - # Add a route that returns plain text - app = app_with_middleware +class TestMiddlewareStackSimulation: + """Test middleware behavior by simulating the full middleware stack.""" - @app.get("/plain") - async def plain(): - return Response(content="not json", media_type="text/plain") + @pytest.mark.parametrize( + "system_filter,user_filter,state_key", + [ + # Test 1: Basic system filter removal + ( + "private = false", + "cloud_coverage < 50", + "cql2_filter", + ), + # Test 2: Different system filter + ( + "collection = 'landsat'", + "datetime > '2023-01-01'", + "cql2_filter", + ), + # Test 3: Custom state key + ( + "access_level = 'public'", + "quality > 0.8", + "custom_filter", + ), + # Test 4: Complex system filter + ( + "(private = false) and (status = 'active')", + "cloud_coverage < 30", + "cql2_filter", + ), + # Test 5: No user filter provided + ( + "private = false", + None, + "cql2_filter", + ), + # Test 6: No user filter with different system filter + ( + "collection = 'landsat'", + None, + "cql2_filter", + ), + ], + ) + def test_middleware_removes_system_filter_from_query_string_links( + self, + system_filter, + user_filter, + state_key, + ): + """Test that middleware removes system-applied filter from query string links.""" + app = FastAPI() - client = TestClient(app) - response = client.get("/plain") - assert response.status_code == 200 - assert response.text == "not json" + # Add a middleware that simulates Cql2BuildFilterMiddleware + class MockBuildFilterMiddleware: + def __init__(self, app, state_key="cql2_filter"): + self.app = app + self.state_key = state_key + + async def __call__(self, scope, receive, send): + if scope["type"] == "http": + request = Request(scope) + setattr(request.state, self.state_key, Expr(system_filter)) + await self.app(scope, receive, send) + + app.add_middleware(Cql2RewriteLinksFilterMiddleware, state_key=state_key) + app.add_middleware(MockBuildFilterMiddleware, state_key=state_key) + + @app.get("/test") + async def test_endpoint(request: Request): + # Automatically join system and user filters using CQL2 operators + system_expr = getattr(request.state, state_key, None) + user_filter_param = request.query_params.get("filter") + + # Build combined expression using CQL2 operators + combined_expr = None + if system_expr and user_filter_param: + # Both system and user filters exist - join them with & + user_expr = Expr(user_filter_param) + combined_expr = system_expr + user_expr + elif system_expr: + # Only system filter exists + combined_expr = system_expr + elif user_filter_param: + # Only user filter exists + combined_expr = Expr(user_filter_param) + + filter_param = f"filter={combined_expr.to_text()}" if combined_expr else "" + separator = "&" if filter_param else "" + + return { + "links": [ + { + "rel": "self", + "href": f"http://example.com/search?{filter_param}{separator}other=param", + } + ] + } + + # Build the request URL + if user_filter: + url = f"/test?filter={user_filter}" + else: + url = "/test" + + client = TestClient(app) + response = client.get(url) + assert response.status_code == 200 + data = response.json() + + # System filter should be removed, leaving only user filter + href = data["links"][0]["href"] + + if user_filter: + # When user filter exists, it should be present in the result + assert "filter=" in href + # Check that key terms from user filter are present + user_expr = Expr(user_filter) + user_text = user_expr.to_text() + # Extract meaningful terms (skip operators and literals) + user_terms = [ + term + for term in re.findall(r"\b\w+\b", user_text) + if term not in ["and", "or", "not", "true", "false"] + ] + for term in user_terms: + assert term in href + else: + # When no user filter, the filter parameter should be completely removed + assert "filter=" not in href + + # The system filter should NOT be in the result + system_expr = Expr(system_filter) + system_text = system_expr.to_text() + # Extract meaningful terms from system filter + system_terms = [ + term + for term in re.findall(r"\b\w+\b", system_text) + if term not in ["and", "or", "not", "true", "false"] + ] + for term in system_terms: + assert term not in href + + # Other parameters should remain + assert "other=param" in href + + @pytest.mark.parametrize( + "system_filter,user_filter,expected_filter,state_key", + [ + # Test 1: Basic request body filter removal + ( + "private = false", + "cloud_coverage < 50", + {"op": "<", "args": [{"property": "cloud_coverage"}, 50]}, + "cql2_filter", + ), + # Test 2: Different system filter in body + ( + "collection = 'landsat'", + "datetime > '2023-01-01'", + {"op": ">", "args": [{"property": "datetime"}, "2023-01-01"]}, + "cql2_filter", + ), + # Test 3: Custom state key + ( + "access_level = 'public'", + "quality > 0.8", + {"op": ">", "args": [{"property": "quality"}, 0.8]}, + "custom_filter", + ), + # Test 4: No user filter provided + ( + "private = false", + None, + None, # Should be completely removed + "cql2_filter", + ), + ], + ) + def test_middleware_removes_system_filter_from_request_body_links( + self, system_filter, user_filter, expected_filter, state_key + ): + """Test that middleware removes system filter from request body links.""" + app = FastAPI() + + class MockBuildFilterMiddleware: + def __init__(self, app, state_key="cql2_filter"): + self.app = app + self.state_key = state_key + + async def __call__(self, scope, receive, send): + if scope["type"] == "http": + request = Request(scope) + setattr(request.state, self.state_key, Expr(system_filter)) + await self.app(scope, receive, send) + + app.add_middleware(Cql2RewriteLinksFilterMiddleware, state_key=state_key) + app.add_middleware(MockBuildFilterMiddleware, state_key=state_key) + + @app.get("/test") + async def test_endpoint(request: Request): + # Automatically create combined filter for request body using CQL2 operators + system_expr = getattr(request.state, state_key, None) + user_filter_param = request.query_params.get("filter") + + # Build combined expression using CQL2 operators + combined_expr = None + if system_expr and user_filter_param: + # Both system and user filters exist - join them with & + user_expr = Expr(user_filter_param) + combined_expr = system_expr + user_expr + elif system_expr: + # Only system filter exists + combined_expr = system_expr + elif user_filter_param: + # Only user filter exists + combined_expr = Expr(user_filter_param) + + body_data = { + "other_data": "preserved", + } + + if combined_expr: + body_data["filter"] = combined_expr.to_json() + body_data["filter-lang"] = "cql2-json" + + return { + "links": [ + { + "rel": "post", + "body": body_data, + } + ] + } + + # Build the request URL + if user_filter: + url = f"/test?filter={user_filter}" + else: + url = "/test" + + client = TestClient(app) + response = client.get(url) + assert response.status_code == 200 + data = response.json() + + body = data["links"][0]["body"] + + if expected_filter: + # System filter should be removed from request body, leaving only user filter + assert body["filter"] == expected_filter + # filter-lang should remain since there's still a filter + assert body["filter-lang"] == "cql2-json" + else: + # When no user filter, the filter should be completely removed + assert "filter" not in body + assert "filter-lang" not in body + + # Other data should be preserved + assert body["other_data"] == "preserved"