diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py
index a5c0aa95..7af5f7a7 100644
--- a/src/stac_auth_proxy/app.py
+++ b/src/stac_auth_proxy/app.py
@@ -20,6 +20,7 @@
     Cql2ApplyFilterBodyMiddleware,
     Cql2ApplyFilterQueryStringMiddleware,
     Cql2BuildFilterMiddleware,
+    Cql2RewriteLinksFilterMiddleware,
     Cql2ValidateResponseBodyMiddleware,
     EnforceAuthMiddleware,
     OpenApiMiddleware,
@@ -110,6 +111,7 @@ def configure_app(
         app.add_middleware(Cql2ValidateResponseBodyMiddleware)
         app.add_middleware(Cql2ApplyFilterBodyMiddleware)
         app.add_middleware(Cql2ApplyFilterQueryStringMiddleware)
+        app.add_middleware(Cql2RewriteLinksFilterMiddleware)
         app.add_middleware(
             Cql2BuildFilterMiddleware,
             items_filter=settings.items_filter() if settings.items_filter else None,
diff --git a/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py b/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py
new file mode 100644
index 00000000..1909bfd4
--- /dev/null
+++ b/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py
@@ -0,0 +1,155 @@
+"""Middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state."""
+
+import json
+from dataclasses import dataclass
+from logging import getLogger
+from typing import Any, Optional
+from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
+
+from cql2 import Expr
+from starlette.requests import Request
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
+
+logger = getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class Cql2RewriteLinksFilterMiddleware:
+    """ASGI middleware to rewrite 'filter' in .links of the JSON response.
+
+    When a filter is stored on the request state under ``state_key``, the
+    response is buffered and every ``filter`` found in ``.links`` (in the query
+    string of ``href`` or in the ``body`` of POST links) is replaced with the
+    filter the client sent in the ``filter`` query parameter, or removed together
+    with ``filter-lang`` when the client sent none.
+    """
+
+    app: ASGIApp
+    state_key: str = "cql2_filter"
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        """Replace 'filter' in .links of the JSON response to state before we had applied the filter."""
+        if scope["type"] != "http":
+            return await self.app(scope, receive, send)
+
+        request = Request(scope)
+        if getattr(request.state, self.state_key, None) is None:
+            # No filter set, just pass through
+            return await self.app(scope, receive, send)
+
+        original_filter = request.query_params.get("filter")
+
+        # Intercept the response: hold back the start message and buffer the body
+        # so links and content-length can be patched before anything is sent.
+        response_start: Optional[Message] = None
+        body_chunks: list[bytes] = []
+
+        async def send_wrapper(message: Message) -> None:
+            nonlocal response_start
+            if message["type"] == "http.response.start":
+                response_start = message
+            elif message["type"] == "http.response.body":
+                body_chunks.append(message.get("body", b""))
+                if not message.get("more_body", False):
+                    await self._process_and_send_response(
+                        response_start, body_chunks, send, original_filter
+                    )
+            else:
+                await send(message)
+
+        await self.app(scope, receive, send_wrapper)
+
+    async def _process_and_send_response(
+        self,
+        response_start: Message,
+        body_chunks: list[bytes],
+        send: Send,
+        original_filter: Optional[str],
+    ) -> None:
+        """Rewrite filters in the buffered response's links, then send the response.
+
+        Responses whose body is not a JSON object holding a list of links are
+        forwarded exactly as received.
+        """
+        body = b"".join(body_chunks)
+        try:
+            data = json.loads(body)
+        except ValueError:
+            # Not JSON (or not decodable as text)
+            data = None
+
+        links = data.get("links") if isinstance(data, dict) else None
+        if not isinstance(links, list):
+            # Nothing to rewrite: forward the upstream bytes untouched
+            await send(response_start)
+            await send({"type": "http.response.body", "body": body, "more_body": False})
+            return
+
+        user_filter = self._parse_user_filter(original_filter)
+        for link in links:
+            if isinstance(link, dict):
+                self._rewrite_href(link, user_filter)
+                self._rewrite_body(link, user_filter)
+
+        # Send the modified response
+        new_body = json.dumps(data).encode("utf-8")
+
+        # Patch content-length to match the rewritten body
+        headers = [
+            (k, v)
+            for k, v in response_start.get("headers", [])
+            if k.lower() != b"content-length"
+        ]
+        headers.append((b"content-length", str(len(new_body)).encode("latin1")))
+        await send({**response_start, "headers": headers})
+        await send({"type": "http.response.body", "body": new_body, "more_body": False})
+
+    @staticmethod
+    def _parse_user_filter(original_filter: Optional[str]) -> Optional[Expr]:
+        """Parse the client-supplied filter, treating an unparseable one as absent."""
+        if not original_filter:
+            return None
+        try:
+            return Expr(original_filter)
+        except Exception:
+            # Rewriting links must never break the response. Dropping the filter
+            # from the links is preferable to leaving the system filter exposed.
+            logger.warning(
+                "Unparseable filter in query string; removing filter from links",
+                exc_info=True,
+            )
+            return None
+
+    @staticmethod
+    def _rewrite_href(link: dict[str, Any], user_filter: Optional[Expr]) -> None:
+        """Replace or remove the filter in the query string of a link's href."""
+        href = link.get("href")
+        if not isinstance(href, str):
+            return
+
+        url = urlparse(href)
+        # Keep blank values so params such as 'token=' survive the round trip
+        qs = parse_qs(url.query, keep_blank_values=True)
+        if "filter" not in qs:
+            # No filter to rewrite: leave the href exactly as received
+            return
+
+        if user_filter is not None:
+            qs["filter"] = [user_filter.to_text()]
+        else:
+            qs.pop("filter", None)
+            qs.pop("filter-lang", None)
+        link["href"] = urlunparse(url._replace(query=urlencode(qs, doseq=True)))
+
+    @staticmethod
+    def _rewrite_body(link: dict[str, Any], user_filter: Optional[Expr]) -> None:
+        """Replace or remove the filter in the body of a POST link."""
+        body = link.get("body")
+        if not isinstance(body, dict) or "filter" not in body:
+            return
+
+        if user_filter is not None:
+            body["filter"] = user_filter.to_json()
+        else:
+            body.pop("filter", None)
+            body.pop("filter-lang", None)
diff --git a/src/stac_auth_proxy/middleware/__init__.py b/src/stac_auth_proxy/middleware/__init__.py
index bc1ae4f2..c5dc005b 100644
--- a/src/stac_auth_proxy/middleware/__init__.py
+++ b/src/stac_auth_proxy/middleware/__init__.py
@@ -5,6 +5,7 @@
 from .Cql2ApplyFilterBodyMiddleware import Cql2ApplyFilterBodyMiddleware
 from .Cql2ApplyFilterQueryStringMiddleware import Cql2ApplyFilterQueryStringMiddleware
 from .Cql2BuildFilterMiddleware import Cql2BuildFilterMiddleware
+from .Cql2RewriteLinksFilterMiddleware import Cql2RewriteLinksFilterMiddleware
 from .Cql2ValidateResponseBodyMiddleware import Cql2ValidateResponseBodyMiddleware
 from .EnforceAuthMiddleware import EnforceAuthMiddleware
 from .ProcessLinksMiddleware import ProcessLinksMiddleware
@@ -17,6 +18,7 @@
     "Cql2ApplyFilterBodyMiddleware",
     "Cql2ApplyFilterQueryStringMiddleware",
     "Cql2BuildFilterMiddleware",
+    "Cql2RewriteLinksFilterMiddleware",
     "Cql2ValidateResponseBodyMiddleware",
     "EnforceAuthMiddleware",
     "OpenApiMiddleware",
diff --git a/tests/test_cql2_rewrite_links_filter_middleware.py b/tests/test_cql2_rewrite_links_filter_middleware.py
new file mode 100644
index 00000000..d5b07cdb
--- /dev/null
+++ b/tests/test_cql2_rewrite_links_filter_middleware.py
@@ -0,0 +1,401 @@
+"""Test Cql2RewriteLinksFilterMiddleware."""
+
+import re
+
+import pytest
+from cql2 import Expr
+from fastapi import FastAPI, Request, Response
+from starlette.testclient import TestClient
+
+from stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware import (
+    Cql2RewriteLinksFilterMiddleware,
+)
+
+
+def test_non_json_response():
+    """Test middleware behavior with non-JSON responses."""
+    app = FastAPI()
+    app.add_middleware(Cql2RewriteLinksFilterMiddleware)
+
+    @app.get("/plain")
+    async def plain():
+        return Response(content="not json", media_type="text/plain")
+
+    client = TestClient(app)
+    response = client.get("/plain")
+    assert response.status_code == 200
+    assert response.text == "not json"
+
+
+class TestEdgeCases:
+    """Test middleware behavior with edge cases."""
+
+    def test_no_links_in_response(self):
+        """Test middleware behavior when response has no links."""
+        app = FastAPI()
+        app.add_middleware(Cql2RewriteLinksFilterMiddleware)
+
+        @app.get("/test")
+        async def test_endpoint(request: Request):
+            return {"data": "no links here"}
+
+        client = TestClient(app)
+        response = client.get("/test")
+        assert response.status_code == 200
+        data = response.json()
+        assert data == {"data": "no links here"}
+
+    def test_malformed_json_response(self):
+        """Test middleware behavior with malformed JSON response."""
+        app = FastAPI()
+        app.add_middleware(Cql2RewriteLinksFilterMiddleware)
+
+        @app.get("/test")
+        async def test_endpoint(request: Request):
+            return Response(content="invalid json", media_type="application/json")
+
+        client = TestClient(app)
+        response = client.get("/test")
+        assert response.status_code == 200
+        assert response.text == "invalid json"
+
+    def test_links_not_list(self):
+        """Test middleware behavior when links is not a list."""
+        app = FastAPI()
+        app.add_middleware(Cql2RewriteLinksFilterMiddleware)
+
+        @app.get("/test")
+        async def test_endpoint(request: Request):
+            return {"links": "not a list"}
+
+        client = TestClient(app)
+        response = client.get("/test")
+        assert response.status_code == 200
+        data = response.json()
+        assert data == {"links": "not a list"}
+
+
+class TestMiddlewareStackSimulation:
+    """Test middleware behavior by simulating the full middleware stack."""
+
+    @pytest.mark.parametrize(
+        "system_filter,user_filter,state_key",
+        [
+            # Test 1: Basic system filter removal
+            (
+                "private = false",
+                "cloud_coverage < 50",
+                "cql2_filter",
+            ),
+            # Test 2: Different system filter
+            (
+                "collection = 'landsat'",
+                "datetime > '2023-01-01'",
+                "cql2_filter",
+            ),
+            # Test 3: Custom state key
+            (
+                "access_level = 'public'",
+                "quality > 0.8",
+                "custom_filter",
+            ),
+            # Test 4: Complex system filter
+            (
+                "(private = false) and (status = 'active')",
+                "cloud_coverage < 30",
+                "cql2_filter",
+            ),
+            # Test 5: No user filter provided
+            (
+                "private = false",
+                None,
+                "cql2_filter",
+            ),
+            # Test 6: No user filter with different system filter
+            (
+                "collection = 'landsat'",
+                None,
+                "cql2_filter",
+            ),
+        ],
+    )
+    def test_middleware_removes_system_filter_from_query_string_links(
+        self,
+        system_filter,
+        user_filter,
+        state_key,
+    ):
+        """Test that middleware removes system-applied filter from query string links."""
+        app = FastAPI()
+
+        # Add a middleware that simulates Cql2BuildFilterMiddleware
+        class MockBuildFilterMiddleware:
+            def __init__(self, app, state_key="cql2_filter"):
+                self.app = app
+                self.state_key = state_key
+
+            async def __call__(self, scope, receive, send):
+                if scope["type"] == "http":
+                    request = Request(scope)
+                    setattr(request.state, self.state_key, Expr(system_filter))
+                await self.app(scope, receive, send)
+
+        app.add_middleware(Cql2RewriteLinksFilterMiddleware, state_key=state_key)
+        app.add_middleware(MockBuildFilterMiddleware, state_key=state_key)
+
+        @app.get("/test")
+        async def test_endpoint(request: Request):
+            # Automatically join system and user filters using CQL2 operators
+            system_expr = getattr(request.state, state_key, None)
+            user_filter_param = request.query_params.get("filter")
+
+            # Build combined expression using CQL2 operators
+            combined_expr = None
+            if system_expr and user_filter_param:
+                # Both system and user filters exist - join them with &
+                user_expr = Expr(user_filter_param)
+                combined_expr = system_expr + user_expr
+            elif system_expr:
+                # Only system filter exists
+                combined_expr = system_expr
+            elif user_filter_param:
+                # Only user filter exists
+                combined_expr = Expr(user_filter_param)
+
+            filter_param = f"filter={combined_expr.to_text()}" if combined_expr else ""
+            separator = "&" if filter_param else ""
+
+            return {
+                "links": [
+                    {
+                        "rel": "self",
+                        "href": f"http://example.com/search?{filter_param}{separator}other=param",
+                    }
+                ]
+            }
+
+        # Build the request URL
+        if user_filter:
+            url = f"/test?filter={user_filter}"
+        else:
+            url = "/test"
+
+        client = TestClient(app)
+        response = client.get(url)
+        assert response.status_code == 200
+        data = response.json()
+
+        # System filter should be removed, leaving only user filter
+        href = data["links"][0]["href"]
+
+        if user_filter:
+            # When user filter exists, it should be present in the result
+            assert "filter=" in href
+            # Check that key terms from user filter are present
+            user_expr = Expr(user_filter)
+            user_text = user_expr.to_text()
+            # Extract meaningful terms (skip operators and literals)
+            user_terms = [
+                term
+                for term in re.findall(r"\b\w+\b", user_text)
+                if term not in ["and", "or", "not", "true", "false"]
+            ]
+            for term in user_terms:
+                assert term in href
+        else:
+            # When no user filter, the filter parameter should be completely removed
+            assert "filter=" not in href
+
+        # The system filter should NOT be in the result
+        system_expr = Expr(system_filter)
+        system_text = system_expr.to_text()
+        # Extract meaningful terms from system filter
+        system_terms = [
+            term
+            for term in re.findall(r"\b\w+\b", system_text)
+            if term not in ["and", "or", "not", "true", "false"]
+        ]
+        for term in system_terms:
+            assert term not in href
+
+        # Other parameters should remain
+        assert "other=param" in href
+
+    @pytest.mark.parametrize(
+        "system_filter,user_filter,expected_filter,state_key",
+        [
+            # Test 1: Basic request body filter removal
+            (
+                "private = false",
+                "cloud_coverage < 50",
+                {"op": "<", "args": [{"property": "cloud_coverage"}, 50]},
+                "cql2_filter",
+            ),
+            # Test 2: Different system filter in body
+            (
+                "collection = 'landsat'",
+                "datetime > '2023-01-01'",
+                {"op": ">", "args": [{"property": "datetime"}, "2023-01-01"]},
+                "cql2_filter",
+            ),
+            # Test 3: Custom state key
+            (
+                "access_level = 'public'",
+                "quality > 0.8",
+                {"op": ">", "args": [{"property": "quality"}, 0.8]},
+                "custom_filter",
+            ),
+            # Test 4: No user filter provided
+            (
+                "private = false",
+                None,
+                None,  # Should be completely removed
+                "cql2_filter",
+            ),
+        ],
+    )
+    def test_middleware_removes_system_filter_from_request_body_links(
+        self, system_filter, user_filter, expected_filter, state_key
+    ):
+        """Test that middleware removes system filter from request body links."""
+        app = FastAPI()
+
+        class MockBuildFilterMiddleware:
+            def __init__(self, app, state_key="cql2_filter"):
+                self.app = app
+                self.state_key = state_key
+
+            async def __call__(self, scope, receive, send):
+                if scope["type"] == "http":
+                    request = Request(scope)
+                    setattr(request.state, self.state_key, Expr(system_filter))
+                await self.app(scope, receive, send)
+
+        app.add_middleware(Cql2RewriteLinksFilterMiddleware, state_key=state_key)
+        app.add_middleware(MockBuildFilterMiddleware, state_key=state_key)
+
+        @app.get("/test")
+        async def test_endpoint(request: Request):
+            # Automatically create combined filter for request body using CQL2 operators
+            system_expr = getattr(request.state, state_key, None)
+            user_filter_param = request.query_params.get("filter")
+
+            # Build combined expression using CQL2 operators
+            combined_expr = None
+            if system_expr and user_filter_param:
+                # Both system and user filters exist - join them with &
+                user_expr = Expr(user_filter_param)
+                combined_expr = system_expr + user_expr
+            elif system_expr:
+                # Only system filter exists
+                combined_expr = system_expr
+            elif user_filter_param:
+                # Only user filter exists
+                combined_expr = Expr(user_filter_param)
+
+            body_data = {
+                "other_data": "preserved",
+            }
+
+            if combined_expr:
+                body_data["filter"] = combined_expr.to_json()
+                body_data["filter-lang"] = "cql2-json"
+
+            return {
+                "links": [
+                    {
+                        "rel": "post",
+                        "body": body_data,
+                    }
+                ]
+            }
+
+        # Build the request URL
+        if user_filter:
+            url = f"/test?filter={user_filter}"
+        else:
+            url = "/test"
+
+        client = TestClient(app)
+        response = client.get(url)
+        assert response.status_code == 200
+        data = response.json()
+
+        body = data["links"][0]["body"]
+
+        if expected_filter:
+            # System filter should be removed from request body, leaving only user filter
+            assert body["filter"] == expected_filter
+            # filter-lang should remain since there's still a filter
+            assert body["filter-lang"] == "cql2-json"
+        else:
+            # When no user filter, the filter should be completely removed
+            assert "filter" not in body
+            assert "filter-lang" not in body
+
+        # Other data should be preserved
+        assert body["other_data"] == "preserved"
+
+
+def _client_with_system_filter(payload):
+    """Build a client whose app sets a system filter and responds with payload."""
+    app = FastAPI()
+
+    class MockBuildFilterMiddleware:
+        def __init__(self, app):
+            self.app = app
+
+        async def __call__(self, scope, receive, send):
+            if scope["type"] == "http":
+                Request(scope).state.cql2_filter = Expr("private = false")
+            await self.app(scope, receive, send)
+
+    app.add_middleware(Cql2RewriteLinksFilterMiddleware)
+    app.add_middleware(MockBuildFilterMiddleware)
+
+    @app.get("/test")
+    async def test_endpoint():
+        return payload
+
+    return TestClient(app)
+
+
+class TestRewriteWithSystemFilter:
+    """Test the rewrite path, taken when a system filter is on the request state."""
+
+    @pytest.mark.parametrize(
+        "payload",
+        [
+            ["not", "an", "object"],
+            {"links": "not a list"},
+            {"links": ["not a dict", 1, None]},
+            {"links": [{"rel": "self", "href": None}]},
+        ],
+    )
+    def test_payloads_without_rewritable_links(self, payload):
+        """Test that payloads without rewritable links are returned unchanged."""
+        response = _client_with_system_filter(payload).get("/test")
+        assert response.status_code == 200
+        assert response.json() == payload
+
+    def test_blank_query_params_are_preserved(self):
+        """Test that blank-valued query params survive the href rewrite."""
+        payload = {
+            "links": [
+                {
+                    "rel": "next",
+                    "href": "http://example.com/search?filter=private%3Dfalse&token=",
+                }
+            ]
+        }
+        response = _client_with_system_filter(payload).get("/test")
+        assert response.status_code == 200
+        href = response.json()["links"][0]["href"]
+        assert href == "http://example.com/search?token="
+
+    def test_links_without_filter_are_untouched(self):
+        """Test that an href carrying no filter is returned exactly as received."""
+        payload = {"links": [{"rel": "root", "href": "http://example.com/?token="}]}
+        response = _client_with_system_filter(payload).get("/test")
+        assert response.status_code == 200
+        assert response.json() == payload