Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ STAC Auth Proxy is a proxy API that mediates between the client and your interna
## ✨Features✨

- **🔐 Authentication:** Apply [OpenID Connect (OIDC)](https://openid.net/developers/how-connect-works/) token validation and optional scope checks to specified endpoints and methods
- **🛂 Content Filtering:** Use CQL2 filters via the [Filter Extension](https://github.com/stac-api-extensions/filter?tab=readme-ov-file) to tailor API responses based on user context
- **🛂 Content Filtering:** Use CQL2 filters via the [Filter Extension](https://github.com/stac-api-extensions/filter?tab=readme-ov-file) to tailor API responses based on request context (e.g. user role)
- **🤝 External Policy Integration:** Integrate with external systems (e.g. [Open Policy Agent (OPA)](https://www.openpolicyagent.org/)) to generate CQL2 filters dynamically from policy decisions
- **🧩 Authentication Extension:** Add the [Authentication Extension](https://github.com/stac-extensions/authentication) to API responses to expose auth-related metadata
- **📘 OpenAPI Augmentation:** Enhance the [OpenAPI spec](https://swagger.io/specification/) with security details to keep auto-generated docs and UIs (e.g., [Swagger UI](https://swagger.io/tools/swagger-ui/)) accurate
Expand Down Expand Up @@ -158,6 +158,18 @@ The application is configurable via environment variables.
- **Type:** Dictionary of keyword arguments used to initialize the class
- **Required:** No, defaults to `{}`
- **Example:** `{"field_name": "properties.organization"}`
- **`COLLECTIONS_FILTER_CLS`**, CQL2 expression generator for collection-level filtering
- **Type:** JSON object with class configuration
- **Required:** No, defaults to `null` (disabled)
- **Example:** `stac_auth_proxy.filters:Opa`, `stac_auth_proxy.filters:Template`, `my_package:OrganizationFilter`
- **`COLLECTIONS_FILTER_ARGS`**, Positional arguments for CQL2 expression generator
- **Type:** List of positional arguments used to initialize the class
- **Required:** No, defaults to `[]`
- **Example:**: `["org1"]`
- **`COLLECTIONS_FILTER_KWARGS`**, Keyword arguments for CQL2 expression generator
- **Type:** Dictionary of keyword arguments used to initialize the class
- **Required:** No, defaults to `{}`
- **Example:** `{"field_name": "properties.organization"}`

### Tips

Expand Down Expand Up @@ -227,7 +239,7 @@ The system supports generating CQL2 filters based on request context to provide

#### Filters

If enabled, filters are intended to be applied to the following endpoints:
If enabled, filters are applied to the following endpoints:

- `GET /search`
- **Supported:** ✅
Expand All @@ -250,12 +262,12 @@ If enabled, filters are intended to be applied to the following endpoints:
- **Applied Filter:** `ITEMS_FILTER`
- **Strategy:** Validate response against CQL2 query.
- `GET /collections`
- **Supported:** ❌[^23]
- **Supported:**
- **Action:** Read Collection
- **Applied Filter:** `COLLECTIONS_FILTER`
- **Strategy:** Append query params with generated CQL2 query.
- `GET /collections/{collection_id}`
- **Supported:** ❌[^23]
- **Supported:**
- **Action:** Read Collection
- **Applied Filter:** `COLLECTIONS_FILTER`
- **Strategy:** Validate response against CQL2 query.
Expand Down Expand Up @@ -411,6 +423,5 @@ class ApprovedCollectionsFilter:

[^21]: https://github.com/developmentseed/stac-auth-proxy/issues/21
[^22]: https://github.com/developmentseed/stac-auth-proxy/issues/22
[^23]: https://github.com/developmentseed/stac-auth-proxy/issues/23
[^30]: https://github.com/developmentseed/stac-auth-proxy/issues/30
[^37]: https://github.com/developmentseed/stac-auth-proxy/issues/37
4 changes: 3 additions & 1 deletion examples/opa/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ services:
proxy:
environment:
ITEMS_FILTER_CLS: stac_auth_proxy.filters:Opa
ITEMS_FILTER_ARGS: '["http://opa:8181", "stac/cql2"]'
ITEMS_FILTER_ARGS: '["http://opa:8181", "stac/items_cql2"]'
COLLECTIONS_FILTER_CLS: stac_auth_proxy.filters:Opa
COLLECTIONS_FILTER_ARGS: '["http://opa:8181", "stac/collections_cql2"]'

opa:
image: openpolicyagent/opa:latest
Expand Down
10 changes: 8 additions & 2 deletions examples/opa/policies/stac/policy.rego
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
package stac

default cql2 := "\"naip:year\" = 2021"
default items_cql2 := "\"naip:year\" = 2021"

cql2 := "1=1" if {
items_cql2 := "1=1" if {
input.payload.sub != null
}

default collections_cql2 := "id = 'naip'"

collections_cql2 := "1=1" if {
input.payload.sub != null
}
7 changes: 5 additions & 2 deletions src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,16 @@ async def lifespan(app: FastAPI):
auth_scheme_override=settings.openapi_auth_scheme_override,
)

if settings.items_filter:
if settings.items_filter or settings.collections_filter:
app.add_middleware(
ApplyCql2FilterMiddleware,
)
app.add_middleware(
BuildCql2FilterMiddleware,
items_filter=settings.items_filter(),
items_filter=settings.items_filter() if settings.items_filter else None,
collections_filter=(
settings.collections_filter() if settings.collections_filter else None
),
)

app.add_middleware(
Expand Down
1 change: 1 addition & 0 deletions src/stac_auth_proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class Settings(BaseSettings):

# Filters
items_filter: Optional[ClassInput] = None
collections_filter: Optional[ClassInput] = None

model_config = SettingsConfigDict(
env_nested_delimiter="_",
Expand Down
56 changes: 39 additions & 17 deletions src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
r"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter",
r"https://api.stacspec.org/v1\.\d+\.\d+(?:-[\w\.]+)?/item-search#filter",
)
@dataclass(frozen=True)
class ApplyCql2FilterMiddleware:
Expand All @@ -31,6 +29,11 @@ class ApplyCql2FilterMiddleware:
app: ASGIApp
state_key: str = "cql2_filter"

single_record_endpoints = [
r"^/collections/([^/]+)/items/([^/]+)$",
r"^/collections/([^/]+)$",
]

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Add the Cql2Filter to the request."""
if scope["type"] != "http":
Expand All @@ -51,7 +54,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
)
return await req_body_handler(scope, receive, send)

if re.match(r"^/collections/([^/]+)/items/([^/]+)$", request.url.path):
# Handle single record requests (ie non-filterable endpoints)
if any(
re.match(expr, request.url.path) for expr in self.single_record_endpoints
):
res_body_validator = Cql2ResponseBodyValidator(
app=self.app,
cql2_filter=cql2_filter,
Expand Down Expand Up @@ -125,18 +131,22 @@ async def __call__(self, scope: Scope, send: Send, receive: Receive) -> None:
body = b""
initial_message: Optional[Message] = None

async def _send_error_response(status: int, message: str) -> None:
async def _send_error_response(status: int, code: str, message: str) -> None:
"""Send an error response with the given status and message."""
assert initial_message, "Initial message not set"
error_body = json.dumps({"message": message}).encode("utf-8")
response_dict = {
"code": code,
"description": message,
}
response_bytes = json.dumps(response_dict).encode("utf-8")
headers = MutableHeaders(scope=initial_message)
headers["content-length"] = str(len(error_body))
headers["content-length"] = str(len(response_bytes))
initial_message["status"] = status
await send(initial_message)
await send(
{
"type": "http.response.body",
"body": error_body,
"body": response_bytes,
"more_body": False,
}
)
Expand All @@ -145,28 +155,37 @@ async def buffered_send(message: Message) -> None:
"""Process a response message and apply filtering if needed."""
nonlocal body
nonlocal initial_message
initial_message = initial_message or message
# NOTE: to avoid data-leak, we process 404s so their responses are the same as rejected 200s
should_process = initial_message["status"] in [200, 404]

if not should_process:
return await send(message)

if message["type"] == "http.response.start":
initial_message = message
# Hold off on sending response headers until we've validated the response body
return

assert initial_message, "Initial message not set"

body += message["body"]
if message.get("more_body"):
return

try:
body_json = json.loads(body)
except json.JSONDecodeError:
logger.warning("Failed to parse response body as JSON")
await _send_error_response(502, "Not found")
msg = "Failed to parse response body as JSON"
logger.warning(msg)
await _send_error_response(status=502, code="ParseError", message=msg)
return

logger.debug(
"Applying %s filter to %s", self.cql2_filter.to_text(), body_json
)
if self.cql2_filter.matches(body_json):
try:
cql2_matches = self.cql2_filter.matches(body_json)
except Exception as e:
cql2_matches = False
logger.warning("Failed to apply filter: %s", e)

if cql2_matches:
logger.debug("Response matches filter, returning record")
await send(initial_message)
return await send(
{
Expand All @@ -175,6 +194,9 @@ async def buffered_send(message: Message) -> None:
"more_body": False,
}
)
return await _send_error_response(404, "Not found")
logger.debug("Response did not match filter, returning 404")
return await _send_error_response(
status=404, code="NotFoundError", message="Record not found."
)

return await self.app(scope, receive, buffered_send)
40 changes: 38 additions & 2 deletions src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@
from starlette.types import ASGIApp, Receive, Scope, Send

from ..utils import requests
from ..utils.middleware import required_conformance

logger = logging.getLogger(__name__)


@required_conformance(
"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
)
@dataclass(frozen=True)
class BuildCql2FilterMiddleware:
"""Middleware to build the Cql2Filter."""
Expand All @@ -25,7 +31,37 @@ class BuildCql2FilterMiddleware:

# Filters
collections_filter: Optional[Callable] = None
collections_filter_path: str = r"^/collections(/[^/]+)?$"
items_filter: Optional[Callable] = None
items_filter_path: str = r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)"

def __post_init__(self):
"""Set required conformances based on the filter functions."""
required_conformances = set()
if self.collections_filter:
logger.debug("Appending required conformance for collections filter")
required_conformances.update(
[
"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/filter",
"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
r"https://api.stacspec.org/v1\.0\.0(?:-[\w\.]+)?/item-search#filter",
"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter",
]
)
if self.items_filter:
logger.debug("Appending required conformance for items filter")
required_conformances.update(
[
"https://api.stacspec.org/v1.0.0/core",
r"https://api.stacspec.org/v1\.0\.0(?:-[\w\.]+)?/collection-search#filter",
"http://www.opengis.net/spec/ogcapi-common-2/1.0/conf/simple-query",
]
)

# Must set required conformances on class
self.__class__.__required_conformances__ = required_conformances.union(
getattr(self.__class__, "__required_conformances__", [])
)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Build the CQL2 filter, place on the request state."""
Expand Down Expand Up @@ -65,8 +101,8 @@ def _get_filter(
) -> Optional[Callable[..., Awaitable[str | dict[str, Any]]]]:
"""Get the CQL2 filter builder for the given path."""
endpoint_filters = [
(r"^/collections(/[^/]+)?$", self.collections_filter),
(r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)", self.items_filter),
(self.collections_filter_path, self.collections_filter),
(self.items_filter_path, self.items_filter),
]
for expr, builder in endpoint_filters:
if re.match(expr, path):
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def mock_env():
@pytest.fixture
async def mock_upstream() -> AsyncGenerator[MagicMock, None]:
"""Mock the HTTPX send method. Useful when we want to inspect the request is sent to upstream API."""
# NOTE: This fixture will interfere with the source_api_responses fixture

async def store_body(request, **kwargs):
"""Exhaust and store the request body."""
Expand Down
Loading
Loading