diff --git a/src/stac_auth_proxy/handlers/reverse_proxy.py b/src/stac_auth_proxy/handlers/reverse_proxy.py index a1849171..98aa9adf 100644 --- a/src/stac_auth_proxy/handlers/reverse_proxy.py +++ b/src/stac_auth_proxy/handlers/reverse_proxy.py @@ -37,14 +37,18 @@ def _prepare_headers(self, request: Request) -> MutableHeaders: headers = MutableHeaders(request.headers) headers.setdefault("Via", f"1.1 {self.proxy_name}") - proxy_client = request.client.host if request.client else "unknown" - proxy_proto = request.url.scheme - proxy_host = request.url.netloc - proxy_path = request.base_url.path + proxy_client = headers.get( + "X-Forwarded-For", request.client.host if request.client else "unknown" + ) + proxy_proto = headers.get("X-Forwarded-Proto", request.url.scheme) + proxy_host = headers.get("X-Forwarded-Host", request.url.netloc) + proxy_path = headers.get("X-Forwarded-Path", request.base_url.path) headers.setdefault( "Forwarded", f"for={proxy_client};host={proxy_host};proto={proxy_proto};path={proxy_path}", ) + + # NOTE: This is useful if the upstream API does not support the Forwarded header if self.legacy_forwarded_headers: headers.setdefault("X-Forwarded-For", proxy_client) headers.setdefault("X-Forwarded-Host", proxy_host) diff --git a/tests/test_reverse_proxy.py b/tests/test_reverse_proxy.py index 7bbb7d9e..5bd70eb3 100644 --- a/tests/test_reverse_proxy.py +++ b/tests/test_reverse_proxy.py @@ -6,10 +6,9 @@ from stac_auth_proxy.handlers.reverse_proxy import ReverseProxyHandler -@pytest.fixture -def mock_request(): - """Create a mock FastAPI request.""" - scope = { +def create_request(scope_overrides=None, headers=None): + """Create a mock FastAPI request with custom scope and headers.""" + default_scope = { "type": "http", "method": "GET", "path": "/test", @@ -19,7 +18,20 @@ def mock_request(): (b"accept", b"application/json"), ], } - return Request(scope) + + if scope_overrides: + default_scope.update(scope_overrides) + + if headers: + default_scope["headers"] = headers + + return Request(default_scope) + + +@pytest.fixture +def mock_request(): + """Create a mock FastAPI request.""" + return create_request() @pytest.fixture @@ -28,15 +40,33 @@ def reverse_proxy_handler(): return ReverseProxyHandler(upstream="http://upstream-api.com") +@pytest.mark.parametrize( + "legacy_headers,override_host,proxy_name,expected_host,expected_via", + [ + (False, True, "stac-auth-proxy", "upstream-api.com", "1.1 stac-auth-proxy"), + (True, True, "stac-auth-proxy", "upstream-api.com", "1.1 stac-auth-proxy"), + (False, False, "stac-auth-proxy", "localhost:8000", "1.1 stac-auth-proxy"), + (False, True, "custom-proxy", "upstream-api.com", "1.1 custom-proxy"), + ], +) @pytest.mark.asyncio -async def test_basic_headers(mock_request, reverse_proxy_handler): - """Test that basic headers are properly set.""" - headers = reverse_proxy_handler._prepare_headers(mock_request) +async def test_basic_headers( + mock_request, legacy_headers, override_host, proxy_name, expected_host, expected_via +): + """Test basic header functionality with various configurations.""" + handler = ReverseProxyHandler( + upstream="http://upstream-api.com", + legacy_forwarded_headers=legacy_headers, + override_host=override_host, + proxy_name=proxy_name, + ) + headers = handler._prepare_headers(mock_request) # Check standard headers - assert headers["Host"] == "upstream-api.com" + assert headers["Host"] == expected_host assert headers["User-Agent"] == "test-agent" assert headers["Accept"] == "application/json" + assert headers["Via"] == expected_via # Check modern forwarded header assert "Forwarded" in headers @@ -46,60 +76,28 @@ async def test_basic_headers(mock_request, reverse_proxy_handler): assert "proto=http" in forwarded assert "path=/" in forwarded - # Check Via header - assert headers["Via"] == "1.1 stac-auth-proxy" - - # Legacy headers should not be present by default - assert "X-Forwarded-For" not in headers - assert "X-Forwarded-Host" not in headers - assert "X-Forwarded-Proto" not in headers - assert "X-Forwarded-Path" not in headers - - -@pytest.mark.asyncio -async def test_legacy_forwarded_headers(mock_request): - """Test that legacy X-Forwarded-* headers are set when enabled.""" - handler = ReverseProxyHandler( - upstream="http://upstream-api.com", legacy_forwarded_headers=True - ) - headers = handler._prepare_headers(mock_request) - - # Check legacy headers - assert headers["X-Forwarded-For"] == "unknown" - assert headers["X-Forwarded-Host"] == "localhost:8000" - assert headers["X-Forwarded-Proto"] == "http" - assert headers["X-Forwarded-Path"] == "/" - - # Modern Forwarded header should still be present - assert "Forwarded" in headers - - -@pytest.mark.asyncio -async def test_override_host_disabled(mock_request): - """Test that host override can be disabled.""" - handler = ReverseProxyHandler( - upstream="http://upstream-api.com", override_host=False - ) - headers = handler._prepare_headers(mock_request) - assert headers["Host"] == "localhost:8000" - - -@pytest.mark.asyncio -async def test_custom_proxy_name(mock_request): - """Test that custom proxy name is used in Via header.""" - handler = ReverseProxyHandler( - upstream="http://upstream-api.com", proxy_name="custom-proxy" - ) - headers = handler._prepare_headers(mock_request) - assert headers["Via"] == "1.1 custom-proxy" + # Check legacy headers based on configuration + if legacy_headers: + assert headers["X-Forwarded-For"] == "unknown" + assert headers["X-Forwarded-Host"] == "localhost:8000" + assert headers["X-Forwarded-Proto"] == "http" + assert headers["X-Forwarded-Path"] == "/" + else: + assert "X-Forwarded-For" not in headers + assert "X-Forwarded-Host" not in headers + assert "X-Forwarded-Proto" not in headers + assert "X-Forwarded-Path" not in headers +@pytest.mark.parametrize("legacy_headers", [False, True]) @pytest.mark.asyncio -async def test_forwarded_headers_with_client(mock_request): +async def test_forwarded_headers_with_client(mock_request, legacy_headers): """Test forwarded headers when client information is available.""" # Add client information to the request mock_request.scope["client"] = ("192.168.1.1", 12345) - handler = ReverseProxyHandler(upstream="http://upstream-api.com") + handler = ReverseProxyHandler( + upstream="http://upstream-api.com", legacy_forwarded_headers=legacy_headers + ) headers = handler._prepare_headers(mock_request) # Check modern Forwarded header @@ -109,56 +107,37 @@ async def test_forwarded_headers_with_client(mock_request): assert "proto=http" in forwarded assert "path=/" in forwarded - # Legacy headers should not be present by default - assert "X-Forwarded-For" not in headers - assert "X-Forwarded-Host" not in headers - assert "X-Forwarded-Proto" not in headers - assert "X-Forwarded-Path" not in headers + # Check legacy headers based on configuration + if legacy_headers: + assert headers["X-Forwarded-For"] == "192.168.1.1" + assert headers["X-Forwarded-Host"] == "localhost:8000" + assert headers["X-Forwarded-Proto"] == "http" + assert headers["X-Forwarded-Path"] == "/" + else: + assert "X-Forwarded-For" not in headers + assert "X-Forwarded-Host" not in headers + assert "X-Forwarded-Proto" not in headers + assert "X-Forwarded-Path" not in headers +@pytest.mark.parametrize("legacy_headers", [False, True]) @pytest.mark.asyncio -async def test_legacy_forwarded_headers_with_client(mock_request): - """Test legacy forwarded headers when client information is available.""" - mock_request.scope["client"] = ("192.168.1.1", 12345) +async def test_https_proto(mock_request, legacy_headers): + """Test that protocol is set correctly for HTTPS.""" + mock_request.scope["scheme"] = "https" handler = ReverseProxyHandler( - upstream="http://upstream-api.com", legacy_forwarded_headers=True + upstream="http://upstream-api.com", legacy_forwarded_headers=legacy_headers ) headers = handler._prepare_headers(mock_request) - # Check legacy headers - assert headers["X-Forwarded-For"] == "192.168.1.1" - assert headers["X-Forwarded-Host"] == "localhost:8000" - assert headers["X-Forwarded-Proto"] == "http" - assert headers["X-Forwarded-Path"] == "/" - - # Modern Forwarded header should still be present - assert "Forwarded" in headers - - -@pytest.mark.asyncio -async def test_https_proto(mock_request): - """Test that X-Forwarded-Proto is set correctly for HTTPS.""" - mock_request.scope["scheme"] = "https" - handler = ReverseProxyHandler(upstream="http://upstream-api.com") - headers = handler._prepare_headers(mock_request) - # Check modern Forwarded header assert "proto=https" in headers["Forwarded"] - # Legacy headers should not be present by default - assert "X-Forwarded-Proto" not in headers - - -@pytest.mark.asyncio -async def test_https_proto_legacy(mock_request): - """Test that X-Forwarded-Proto is set correctly for HTTPS with legacy headers.""" - mock_request.scope["scheme"] = "https" - handler = ReverseProxyHandler( - upstream="http://upstream-api.com", legacy_forwarded_headers=True - ) - headers = handler._prepare_headers(mock_request) - assert headers["X-Forwarded-Proto"] == "https" - assert "proto=https" in headers["Forwarded"] + # Check legacy headers based on configuration + if legacy_headers: + assert headers["X-Forwarded-Proto"] == "https" + else: + assert "X-Forwarded-Proto" not in headers @pytest.mark.asyncio @@ -171,3 +150,135 @@ async def test_non_standard_port(mock_request): handler = ReverseProxyHandler(upstream="http://upstream-api.com:8080") headers = handler._prepare_headers(mock_request) assert headers["Host"] == "upstream-api.com:8080" + + +@pytest.mark.parametrize("legacy_headers", [False, True]) +@pytest.mark.asyncio +async def test_nginx_proxy_headers_preserved(legacy_headers): + """Test that existing proxy headers from NGINX are preserved.""" + # Simulate a request that already has proxy headers set by NGINX + headers = [ + (b"host", b"localhost:8000"), + (b"user-agent", b"test-agent"), + (b"x-forwarded-for", b"203.0.113.1, 198.51.100.1"), + (b"x-forwarded-proto", b"https"), + (b"x-forwarded-host", b"api.example.com"), + (b"x-forwarded-path", b"/api/v1"), + ] + request = create_request(headers=headers) + handler = ReverseProxyHandler( + upstream="http://upstream-api.com", legacy_forwarded_headers=legacy_headers + ) + headers = handler._prepare_headers(request) + + # Check that the existing proxy headers are preserved in the Forwarded header + forwarded = headers["Forwarded"] + assert "for=203.0.113.1, 198.51.100.1" in forwarded + assert "host=api.example.com" in forwarded + assert "proto=https" in forwarded + assert "path=/api/v1" in forwarded + + # The original headers should still be present (they're preserved from the request) + assert headers["X-Forwarded-For"] == "203.0.113.1, 198.51.100.1" + assert headers["X-Forwarded-Host"] == "api.example.com" + assert headers["X-Forwarded-Proto"] == "https" + assert headers["X-Forwarded-Path"] == "/api/v1" + + +@pytest.mark.parametrize( + "scope_overrides,headers,expected_forwarded", + [ + pytest.param( + {}, + [ + (b"host", b"localhost:8000"), + (b"user-agent", b"test-agent"), + (b"x-forwarded-for", b"203.0.113.1"), + (b"x-forwarded-proto", b"https"), + # Missing X-Forwarded-Host and X-Forwarded-Path + ], + { + "for": "203.0.113.1", # From existing header + "host": "localhost:8000", # Fallback to request host + "proto": "https", # From existing header + "path": "/", # Fallback to request path + }, + id="partial_headers_fallback", + ), + pytest.param( + {"client": ("192.168.1.1", 12345)}, # This should be ignored + [ + (b"host", b"localhost:8000"), + (b"user-agent", b"test-agent"), + (b"x-forwarded-for", b"203.0.113.1, 198.51.100.1"), + ], + { + "for": "203.0.113.1, 198.51.100.1", # From existing header + "host": "localhost:8000", + "proto": "http", + "path": "/", + }, + id="client_info_precedence", + ), + pytest.param( + {"scheme": "https"}, # This should be ignored + [ + (b"host", b"localhost:8000"), + (b"user-agent", b"test-agent"), + (b"x-forwarded-proto", b"http"), # NGINX says it's HTTP + ], + { + "for": "unknown", + "host": "localhost:8000", + "proto": "http", # From existing header + "path": "/", + }, + id="scheme_precedence", + ), + pytest.param( + {"path": "/custom/path"}, + [ + (b"host", b"localhost:8000"), + (b"user-agent", b"test-agent"), + (b"x-forwarded-path", b"/api/v1/root"), # NGINX says different path + ], + { + "for": "unknown", + "host": "localhost:8000", + "proto": "http", + "path": "/api/v1/root", # From existing header + }, + id="path_precedence", + ), + pytest.param( + {}, + [ + (b"host", b"localhost:8000"), + (b"user-agent", b"test-agent"), + (b"X-Forwarded-For", b"203.0.113.1"), # Mixed case + (b"x-forwarded-proto", b"https"), # Lower case + (b"X-FORWARDED-HOST", b"api.example.com"), # Upper case + ], + { + "for": "203.0.113.1", + "host": "api.example.com", + "proto": "https", + "path": "/", + }, + id="case_insensitive", + ), + ], +) +@pytest.mark.asyncio +async def test_nginx_headers_behavior(scope_overrides, headers, expected_forwarded): + """Test various NGINX header behaviors and precedence rules.""" + request = create_request(scope_overrides=scope_overrides, headers=headers) + handler = ReverseProxyHandler(upstream="http://upstream-api.com") + result_headers = handler._prepare_headers(request) + + # Check that the Forwarded header contains expected values + forwarded = result_headers["Forwarded"] + for key, expected_value in expected_forwarded.items(): + assert ( + f"{key}={expected_value}" in forwarded + ), f"Expected {key}={expected_value} in {forwarded}"