diff --git a/README.md b/README.md index a3909054..6c7c2304 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,10 @@ The application is configurable via environment variables. - **Type:** string - **Required:** No, defaults to `/healthz` - **Example:** `''` (disabled) + - **`OVERRIDE_HOST`**, override the host header for the upstream API + - **Type:** boolean + - **Required:** No, defaults to `true` + - **Example:** `false`, `1`, `True` - Authentication - **`OIDC_DISCOVERY_URL`**, OpenID Connect discovery document URL - **Type:** HTTP(S) URL diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py index 03c26be2..1d02612c 100644 --- a/src/stac_auth_proxy/app.py +++ b/src/stac_auth_proxy/app.py @@ -80,7 +80,10 @@ async def lifespan(app: FastAPI): app.add_api_route( "/{path:path}", - ReverseProxyHandler(upstream=str(settings.upstream_url)).proxy_request, + ReverseProxyHandler( + upstream=str(settings.upstream_url), + override_host=settings.override_host, + ).proxy_request, methods=["GET", "POST", "PUT", "PATCH", "DELETE"], ) diff --git a/src/stac_auth_proxy/config.py b/src/stac_auth_proxy/config.py index aa4690f0..1515c0e1 100644 --- a/src/stac_auth_proxy/config.py +++ b/src/stac_auth_proxy/config.py @@ -38,6 +38,7 @@ class Settings(BaseSettings): oidc_discovery_url: HttpUrl oidc_discovery_internal_url: HttpUrl + override_host: bool = True healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz") wait_for_upstream: bool = True check_conformance: bool = True diff --git a/src/stac_auth_proxy/handlers/reverse_proxy.py b/src/stac_auth_proxy/handlers/reverse_proxy.py index e4c868b8..a1849171 100644 --- a/src/stac_auth_proxy/handlers/reverse_proxy.py +++ b/src/stac_auth_proxy/handlers/reverse_proxy.py @@ -20,6 +20,10 @@ class ReverseProxyHandler: client: httpx.AsyncClient = None timeout: httpx.Timeout = field(default_factory=lambda: httpx.Timeout(timeout=15.0)) + proxy_name: str = "stac-auth-proxy" + override_host: bool = True + legacy_forwarded_headers: bool = False + def __post_init__(self): """Initialize the HTTP client.""" self.client = self.client or httpx.AsyncClient( @@ -28,11 +32,34 @@ def __post_init__(self): http2=True, ) + def _prepare_headers(self, request: Request) -> MutableHeaders: + """Prepare headers for the proxied request.""" + headers = MutableHeaders(request.headers) + headers.setdefault("Via", f"1.1 {self.proxy_name}") + + proxy_client = request.client.host if request.client else "unknown" + proxy_proto = request.url.scheme + proxy_host = request.url.netloc + proxy_path = request.base_url.path + headers.setdefault( + "Forwarded", + f"for={proxy_client};host={proxy_host};proto={proxy_proto};path={proxy_path}", + ) + if self.legacy_forwarded_headers: + headers.setdefault("X-Forwarded-For", proxy_client) + headers.setdefault("X-Forwarded-Host", proxy_host) + headers.setdefault("X-Forwarded-Path", proxy_path) + headers.setdefault("X-Forwarded-Proto", proxy_proto) + + # Set host to the upstream host + if self.override_host: + headers["Host"] = self.client.base_url.netloc.decode("utf-8") + + return headers + async def proxy_request(self, request: Request) -> Response: """Proxy a request to the upstream STAC API.""" - headers = MutableHeaders(request.headers) - headers.setdefault("X-Forwarded-For", request.client.host) - headers.setdefault("X-Forwarded-Host", request.url.hostname) + headers = self._prepare_headers(request) # https://github.com/fastapi/fastapi/discussions/7382#discussioncomment-5136466 rp_req = self.client.build_request( diff --git a/tests/test_reverse_proxy.py b/tests/test_reverse_proxy.py new file mode 100644 index 00000000..7bbb7d9e --- /dev/null +++ b/tests/test_reverse_proxy.py @@ -0,0 +1,173 @@ +"""Tests for the reverse proxy handler's header functionality.""" + +import pytest +from fastapi import Request + +from stac_auth_proxy.handlers.reverse_proxy import ReverseProxyHandler + + +@pytest.fixture +def mock_request(): + """Create a mock FastAPI request.""" + scope = { + "type": "http", + "method": "GET", + "path": "/test", + "headers": [ + (b"host", b"localhost:8000"), + (b"user-agent", b"test-agent"), + (b"accept", b"application/json"), + ], + } + return Request(scope) + + +@pytest.fixture +def reverse_proxy_handler(): + """Create a reverse proxy handler instance.""" + return ReverseProxyHandler(upstream="http://upstream-api.com") + + +@pytest.mark.asyncio +async def test_basic_headers(mock_request, reverse_proxy_handler): + """Test that basic headers are properly set.""" + headers = reverse_proxy_handler._prepare_headers(mock_request) + + # Check standard headers + assert headers["Host"] == "upstream-api.com" + assert headers["User-Agent"] == "test-agent" + assert headers["Accept"] == "application/json" + + # Check modern forwarded header + assert "Forwarded" in headers + forwarded = headers["Forwarded"] + assert "for=unknown" in forwarded + assert "host=localhost:8000" in forwarded + assert "proto=http" in forwarded + assert "path=/" in forwarded + + # Check Via header + assert headers["Via"] == "1.1 stac-auth-proxy" + + # Legacy headers should not be present by default + assert "X-Forwarded-For" not in headers + assert "X-Forwarded-Host" not in headers + assert "X-Forwarded-Proto" not in headers + assert "X-Forwarded-Path" not in headers + + +@pytest.mark.asyncio +async def test_legacy_forwarded_headers(mock_request): + """Test that legacy X-Forwarded-* headers are set when enabled.""" + handler = ReverseProxyHandler( + upstream="http://upstream-api.com", legacy_forwarded_headers=True + ) + headers = handler._prepare_headers(mock_request) + + # Check legacy headers + assert headers["X-Forwarded-For"] == "unknown" + assert headers["X-Forwarded-Host"] == "localhost:8000" + assert headers["X-Forwarded-Proto"] == "http" + assert headers["X-Forwarded-Path"] == "/" + + # Modern Forwarded header should still be present + assert "Forwarded" in headers + + +@pytest.mark.asyncio +async def test_override_host_disabled(mock_request): + """Test that host override can be disabled.""" + handler = ReverseProxyHandler( + upstream="http://upstream-api.com", override_host=False + ) + headers = handler._prepare_headers(mock_request) + assert headers["Host"] == "localhost:8000" + + +@pytest.mark.asyncio +async def test_custom_proxy_name(mock_request): + """Test that custom proxy name is used in Via header.""" + handler = ReverseProxyHandler( + upstream="http://upstream-api.com", proxy_name="custom-proxy" + ) + headers = handler._prepare_headers(mock_request) + assert headers["Via"] == "1.1 custom-proxy" + + +@pytest.mark.asyncio +async def test_forwarded_headers_with_client(mock_request): + """Test forwarded headers when client information is available.""" + # Add client information to the request + mock_request.scope["client"] = ("192.168.1.1", 12345) + handler = ReverseProxyHandler(upstream="http://upstream-api.com") + headers = handler._prepare_headers(mock_request) + + # Check modern Forwarded header + forwarded = headers["Forwarded"] + assert "for=192.168.1.1" in forwarded + assert "host=localhost:8000" in forwarded + assert "proto=http" in forwarded + assert "path=/" in forwarded + + # Legacy headers should not be present by default + assert "X-Forwarded-For" not in headers + assert "X-Forwarded-Host" not in headers + assert "X-Forwarded-Proto" not in headers + assert "X-Forwarded-Path" not in headers + + +@pytest.mark.asyncio +async def test_legacy_forwarded_headers_with_client(mock_request): + """Test legacy forwarded headers when client information is available.""" + mock_request.scope["client"] = ("192.168.1.1", 12345) + handler = ReverseProxyHandler( + upstream="http://upstream-api.com", legacy_forwarded_headers=True + ) + headers = handler._prepare_headers(mock_request) + + # Check legacy headers + assert headers["X-Forwarded-For"] == "192.168.1.1" + assert headers["X-Forwarded-Host"] == "localhost:8000" + assert headers["X-Forwarded-Proto"] == "http" + assert headers["X-Forwarded-Path"] == "/" + + # Modern Forwarded header should still be present + assert "Forwarded" in headers + + +@pytest.mark.asyncio +async def test_https_proto(mock_request): + """Test that X-Forwarded-Proto is set correctly for HTTPS.""" + mock_request.scope["scheme"] = "https" + handler = ReverseProxyHandler(upstream="http://upstream-api.com") + headers = handler._prepare_headers(mock_request) + + # Check modern Forwarded header + assert "proto=https" in headers["Forwarded"] + + # Legacy headers should not be present by default + assert "X-Forwarded-Proto" not in headers + + +@pytest.mark.asyncio +async def test_https_proto_legacy(mock_request): + """Test that X-Forwarded-Proto is set correctly for HTTPS with legacy headers.""" + mock_request.scope["scheme"] = "https" + handler = ReverseProxyHandler( + upstream="http://upstream-api.com", legacy_forwarded_headers=True + ) + headers = handler._prepare_headers(mock_request) + assert headers["X-Forwarded-Proto"] == "https" + assert "proto=https" in headers["Forwarded"] + + +@pytest.mark.asyncio +async def test_non_standard_port(mock_request): + """Test handling of non-standard ports in host header.""" + mock_request.scope["headers"] = [ + (b"host", b"localhost:8080"), + (b"user-agent", b"test-agent"), + ] + handler = ReverseProxyHandler(upstream="http://upstream-api.com:8080") + headers = handler._prepare_headers(mock_request) + assert headers["Host"] == "upstream-api.com:8080"