Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/user-guide/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,16 @@ The application is configurable via environment variables.
**Required:** No, defaults to the value of `OIDC_DISCOVERY_URL`
**Example:** `http://auth/.well-known/openid-configuration`

### `ALLOWED_JWT_AUDIENCES`

: Unique identifier(s) of API resource server(s)

**Type:** string
**Required:** No
**Example:** `https://auth.example.audience.1.net,https://auth.example.audience.2.net`
**Note** A comma-separated list of the intended recipient(s) of the JWT. At least one audience value must match the `aud` (audience) claim present in the incoming JWT. If undefined, the API will not impose a check on the `aud` claim


### `DEFAULT_PUBLIC`

: Default access policy for endpoints
Expand Down
1 change: 1 addition & 0 deletions src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ async def lifespan(app: FastAPI):
private_endpoints=settings.private_endpoints,
default_public=settings.default_public,
oidc_discovery_url=settings.oidc_discovery_internal_url,
allowed_jwt_audiences=settings.allowed_jwt_audiences,
)

if settings.root_path or settings.upstream_url.path != "/":
Expand Down
17 changes: 16 additions & 1 deletion src/stac_auth_proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import importlib
from typing import Any, Literal, Optional, Sequence, TypeAlias, Union

from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic.networks import HttpUrl
from pydantic_settings import BaseSettings, SettingsConfigDict

Expand All @@ -16,6 +16,14 @@
_PREFIX_PATTERN = r"^/.*$"


def str2list(x: Optional[str] = None) -> Optional[Sequence[str]]:
"""Convert string to list based on , delimiter."""
if x:
return x.replace(" ", "").split(",")

return None


class _ClassInput(BaseModel):
"""Input model for dynamically loading a class or function."""

Expand All @@ -39,6 +47,7 @@ class Settings(BaseSettings):
upstream_url: HttpUrl
oidc_discovery_url: HttpUrl
oidc_discovery_internal_url: HttpUrl
allowed_jwt_audiences: Optional[Sequence[str]] = None

root_path: str = ""
override_host: bool = True
Expand Down Expand Up @@ -92,3 +101,9 @@ def _default_oidc_discovery_internal_url(cls, data: Any) -> Any:
if not data.get("oidc_discovery_internal_url"):
data["oidc_discovery_internal_url"] = data.get("oidc_discovery_url")
return data

@field_validator("allowed_jwt_audiences", mode="before")
@classmethod
def parse_audience(cls, v) -> Optional[Sequence[str]]:
"""Parse a comma separated string list of audiences into a list."""
return str2list(v)
7 changes: 7 additions & 0 deletions src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ def validate_token(
# NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
audience=self.allowed_jwt_audiences,
)
except jwt.InvalidAudienceError as e:
logger.error("InvalidAudienceError: %r", e)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate Audience",
headers={"WWW-Authenticate": "Bearer"},
)
except (
jwt.exceptions.InvalidTokenError,
jwt.exceptions.DecodeError,
Expand Down
122 changes: 122 additions & 0 deletions tests/test_authn.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,125 @@ def test_options_requests_with_cors_headers(source_api_server):
assert (
response.status_code == 200
), "OPTIONS request with CORS headers should succeed"


@pytest.mark.parametrize(
"token_audiences,allowed_audiences,expected_status",
[
# Single audience scenarios
(["stac-api"], "stac-api", 200),
(["stac-api"], "different-api", 401),
(["stac-api"], "stac-api,other-api", 200),
# Multiple audiences in token
(["stac-api", "other-api"], "stac-api", 200),
(["stac-api", "other-api"], "other-api", 200),
(["stac-api", "other-api"], "different-api", 401),
(["stac-api", "other-api"], "stac-api, other-api,third-api", 200),
# No audience in token
(None, "stac-api", 401),
("", "stac-api", 401),
# Empty allowed audiences will regect tokens with an `aud` claim
("any-api", "", 401),
("any-api", None, 401),
# Backward compatibility - no audience configured
(None, None, 200),
("", None, 200),
],
)
def test_jwt_audience_validation(
source_api_server,
token_builder,
token_audiences,
allowed_audiences,
expected_status,
):
"""Test JWT audience validation with various configurations."""
# Build app with audience configuration
app_factory = AppFactory(
oidc_discovery_url="https://example-stac-api.com/.well-known/openid-configuration",
default_public=False,
allowed_jwt_audiences=allowed_audiences,
)
test_app = app_factory(upstream_url=source_api_server)

# Build token with audience claim
token_payload = {}
if token_audiences is not None:
token_payload["aud"] = token_audiences

valid_auth_token = token_builder(token_payload)

client = TestClient(test_app)
response = client.get(
"/collections",
headers={"Authorization": f"Bearer {valid_auth_token}"},
)
assert response.status_code == expected_status


@pytest.mark.parametrize(
"aud_value,scope,expected_status,description",
[
(["stac-api"], "openid", 401, "Valid audience but missing scope"),
(["stac-api"], "collection:create", 200, "Valid audience and valid scope"),
(["wrong-api"], "collection:create", 401, "Invalid audience but valid scope"),
],
)
def test_audience_validation_with_scopes(
source_api_server, token_builder, aud_value, scope, expected_status, description
):
"""Test that audience validation works alongside scope validation."""
app_factory = AppFactory(
oidc_discovery_url="https://example-stac-api.com/.well-known/openid-configuration",
default_public=False,
allowed_jwt_audiences="stac-api",
private_endpoints={r"^/collections$": [("POST", "collection:create")]},
)
test_app = app_factory(upstream_url=source_api_server)

client = TestClient(test_app)

token = token_builder({"aud": aud_value, "scope": scope})
response = client.post(
"/collections",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == expected_status


@pytest.mark.parametrize(
"allowed_audiences_config,test_audience,expected_status",
[
# Comma-separated string
("stac-api,other-api", "stac-api", 200),
("stac-api,other-api", "other-api", 200),
("stac-api,other-api", "unknown-api", 401),
# Comma-separated with spaces
("stac-api, other-api", "stac-api", 200),
("stac-api, other-api", "other-api", 200),
("stac-api, other-api", "unknown-api", 401),
],
)
def test_allowed_audiences_configuration_formats(
source_api_server,
token_builder,
allowed_audiences_config,
test_audience,
expected_status,
):
"""Test different configuration formats for ALLOWED_JWT_AUDIENCES."""
app_factory = AppFactory(
oidc_discovery_url="https://example-stac-api.com/.well-known/openid-configuration",
default_public=False,
allowed_jwt_audiences=allowed_audiences_config,
)
test_app = app_factory(upstream_url=source_api_server)

client = TestClient(test_app)

token = token_builder({"aud": [test_audience]})
response = client.get(
"/collections",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == expected_status
Loading