Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ async def lifespan(app: FastAPI):
upstream=str(settings.upstream_url),
override_host=settings.override_host,
).proxy_request,
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
)

#
Expand Down
11 changes: 10 additions & 1 deletion src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return await self.app(scope, receive, send)

request = Request(scope)

# Skip authentication for OPTIONS requests, https://fetch.spec.whatwg.org/#cors-protocol-and-credentials
if request.method == "OPTIONS":
return await self.app(scope, receive, send)

match = find_match(
request.url.path,
request.method,
Expand Down Expand Up @@ -148,7 +153,11 @@ def validate_token(
# NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
audience=self.allowed_jwt_audiences,
)
except (jwt.exceptions.InvalidTokenError, jwt.exceptions.DecodeError) as e:
except (
jwt.exceptions.InvalidTokenError,
jwt.exceptions.DecodeError,
jwt.exceptions.PyJWKClientError,
) as e:
logger.error("InvalidTokenError: %r", e)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
Expand Down
3 changes: 3 additions & 0 deletions src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
# Add security to private endpoints
for path, method_config in data["paths"].items():
for method, config in method_config.items():
if method == "options":
# OPTIONS requests are not authenticated, https://fetch.spec.whatwg.org/#cors-protocol-and-credentials
continue
match = find_match(
path,
method,
Expand Down
21 changes: 18 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,35 +87,50 @@ def source_api():

# Default responses for each endpoint
default_responses = {
"/": {"GET": {"id": "Response from GET@"}},
"/conformance": {"GET": {"conformsTo": ["http://example.com/conformance"]}},
"/queryables": {"GET": {"queryables": {}}},
"/": {
"GET": {"id": "Response from GET@"},
"OPTIONS": {"id": "Response from OPTIONS@"},
},
"/conformance": {
"GET": {"conformsTo": ["http://example.com/conformance"]},
"OPTIONS": {"id": "Response from OPTIONS@"},
},
"/queryables": {
"GET": {"queryables": {}},
"OPTIONS": {"id": "Response from OPTIONS@"},
},
"/search": {
"GET": {"type": "FeatureCollection", "features": []},
"POST": {"type": "FeatureCollection", "features": []},
"OPTIONS": {"id": "Response from OPTIONS@"},
},
"/collections": {
"GET": {"collections": []},
"POST": {"id": "Response from POST@"},
"OPTIONS": {"id": "Response from OPTIONS@"},
},
"/collections/{collection_id}": {
"GET": {"id": "Response from GET@"},
"PUT": {"id": "Response from PUT@"},
"PATCH": {"id": "Response from PATCH@"},
"DELETE": {"id": "Response from DELETE@"},
"OPTIONS": {"id": "Response from OPTIONS@"},
},
"/collections/{collection_id}/items": {
"GET": {"type": "FeatureCollection", "features": []},
"POST": {"id": "Response from POST@"},
"OPTIONS": {"id": "Response from OPTIONS@"},
},
"/collections/{collection_id}/items/{item_id}": {
"GET": {"id": "Response from GET@"},
"PUT": {"id": "Response from PUT@"},
"PATCH": {"id": "Response from PATCH@"},
"DELETE": {"id": "Response from DELETE@"},
"OPTIONS": {"id": "Response from OPTIONS@"},
},
"/collections/{collection_id}/bulk_items": {
"POST": {"id": "Response from POST@"},
"OPTIONS": {"id": "Response from OPTIONS@"},
},
}

Expand Down
175 changes: 175 additions & 0 deletions tests/test_authn.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,178 @@ def test_scopes(
)
expected_status_code = 200 if expected_permitted else 401
assert response.status_code == expected_status_code


@pytest.mark.parametrize(
"path,default_public,private_endpoints",
[
("/", False, {}),
("/collections", False, {}),
("/search", False, {}),
("/collections", True, {r"^/collections$": [("POST", "collection:create")]}),
("/search", True, {r"^/search$": [("POST", "search:write")]}),
(
"/collections/example-collection/items",
True,
{r"^/collections/.*/items$": [("POST", "item:create")]},
),
],
)
def test_options_bypass_auth(
path, default_public, private_endpoints, source_api_server
):
"""OPTIONS requests should bypass authentication regardless of endpoint configuration."""
test_app = app_factory(
upstream_url=source_api_server,
default_public=default_public,
private_endpoints=private_endpoints,
)
client = TestClient(test_app)
response = client.options(path)
assert response.status_code == 200, "OPTIONS request should bypass authentication"


@pytest.mark.parametrize(
"path,method,default_public,private_endpoints,expected_status",
[
# Test that non-OPTIONS requests still require auth when endpoints are private
("/collections", "GET", False, {}, 403),
("/collections", "POST", False, {}, 403),
("/search", "GET", False, {}, 403),
# Test that OPTIONS requests bypass auth even when endpoints are private
("/collections", "OPTIONS", False, {}, 200),
("/search", "OPTIONS", False, {}, 200),
# Test with specific private endpoint configurations
(
"/collections",
"POST",
True,
{r"^/collections$": [("POST", "collection:create")]},
403,
),
(
"/collections",
"OPTIONS",
True,
{r"^/collections$": [("POST", "collection:create")]},
200,
),
],
)
def test_options_vs_other_methods_auth_behavior(
path, method, default_public, private_endpoints, expected_status, source_api_server
):
"""Compare authentication behavior between OPTIONS and other HTTP methods."""
test_app = app_factory(
upstream_url=source_api_server,
default_public=default_public,
private_endpoints=private_endpoints,
)
client = TestClient(test_app)
response = client.request(method=method, url=path, headers={})
assert response.status_code == expected_status


@pytest.mark.parametrize(
"path,method,default_public,private_endpoints,expected_status",
[
# Test that requests with valid auth succeed
("/collections", "GET", False, {}, 200),
("/collections", "POST", False, {}, 200),
("/search", "GET", False, {}, 200),
("/collections", "OPTIONS", False, {}, 200),
("/search", "OPTIONS", False, {}, 200),
# Test with specific private endpoint configurations
(
"/collections",
"POST",
True,
{r"^/collections$": [("POST", "collection:create")]},
200,
),
(
"/collections",
"OPTIONS",
True,
{r"^/collections$": [("POST", "collection:create")]},
200,
),
],
)
def test_options_vs_other_methods_with_valid_auth(
path,
method,
default_public,
private_endpoints,
expected_status,
source_api_server,
token_builder,
):
"""Compare authentication behavior between OPTIONS and other HTTP methods with valid auth."""
test_app = app_factory(
upstream_url=source_api_server,
default_public=default_public,
private_endpoints=private_endpoints,
)
valid_auth_token = token_builder({"scope": "collection:create"})
client = TestClient(test_app)
response = client.request(
method=method,
url=path,
headers={"Authorization": f"Bearer {valid_auth_token}"},
)
assert response.status_code == expected_status


@pytest.mark.parametrize(
"invalid_token,expected_status",
[
("Bearer invalid-token", 401),
(
"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
401,
),
("InvalidFormat", 401),
("Bearer", 401),
("", 403), # No auth header returns 403, not 401
],
)
def test_with_invalid_tokens_fails(invalid_token, expected_status, source_api_server):
"""GET requests should fail with invalid or malformed tokens."""
test_app = app_factory(
upstream_url=source_api_server,
default_public=False, # All endpoints private
private_endpoints={},
)
client = TestClient(test_app)
response = client.get("/collections", headers={"Authorization": invalid_token})
assert (
response.status_code == expected_status
), f"GET request should fail with token: {invalid_token}"

response = client.options("/collections", headers={"Authorization": invalid_token})
assert (
response.status_code == 200
), f"OPTIONS request should succeed with token: {invalid_token}"


def test_options_requests_with_cors_headers(source_api_server):
"""OPTIONS requests should work properly with CORS headers."""
test_app = app_factory(
upstream_url=source_api_server,
default_public=False, # All endpoints private
private_endpoints={},
)
client = TestClient(test_app)

# Test OPTIONS request with CORS headers
cors_headers = {
"Origin": "https://example.com",
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "Content-Type,Authorization",
}

response = client.options("/collections", headers=cors_headers)
assert (
response.status_code == 200
), "OPTIONS request with CORS headers should succeed"
28 changes: 20 additions & 8 deletions tests/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_oidc_in_openapi_spec_public_endpoints(
source_api: FastAPI, source_api_server: str
):
"""When OpenAPI spec endpoint is set & endpoints are marked public, those endpoints are not marked private in the spec."""
public = {r"^/queryables$": ["GET"], r"^/api": ["GET"]}
public = {r"^/queryables$": ["GET"], r"^/api$": ["GET"]}
app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint=source_api.openapi_url,
Expand All @@ -140,17 +140,29 @@ def test_oidc_in_openapi_spec_public_endpoints(

openapi = client.get(source_api.openapi_url).raise_for_status().json()

expected_auth = {"/queryables": ["GET"]}
expected_required_auth = {"/queryables": ["GET"]}
for path, method_config in openapi["paths"].items():
for method, config in method_config.items():
security = config.get("security")

if method == "options":
assert (
not security
), f"OPTIONS {path} requests should not require authentication"
continue

if security:
assert path not in expected_auth
else:
assert path in expected_auth
assert any(
method.casefold() == m.casefold() for m in expected_auth[path]
)
assert (
path not in expected_required_auth
), f"Path {path} should not require authentication"
continue

assert (
path in expected_required_auth
), f"Path {path} should require authentication"
assert any(
method.casefold() == m.casefold() for m in expected_required_auth[path]
)


def test_auth_scheme_name_override(source_api: FastAPI, source_api_server: str):
Expand Down
Loading