diff --git a/README.md b/README.md index 5d2ec4a0..96f0f085 100644 --- a/README.md +++ b/README.md @@ -9,11 +9,12 @@ STAC Auth Proxy is a proxy API that mediates between the client and your interna ## ✨Features✨ -- 🔐 Authentication: Selectively apply [OpenID Connect (OIDC)](https://openid.net/developers/how-connect-works/) auth*n token validation & optional scope requirements to some or all endpoints & methods -- 🛂 Content Filtering: Apply CQL2 filters to client requests, utilizing the [Filter Extension](https://github.com/stac-api-extensions/filter?tab=readme-ov-file) to filter API content based on user context -- 🧩 Authentication Extension: Integrate the [Authentication Extension](https://github.com/stac-extensions/authentication) into API responses -- 📘 OpenAPI Augmentation: Update API's [OpenAPI document](https://swagger.io/specification/) with security requirements, keeping auto-generated docs/UIs accurate (e.g. [Swagger UI](https://swagger.io/tools/swagger-ui/)) -- 🗜️ Response compression: Compress API responses via [`starlette-cramjam`](https://github.com/developmentseed/starlette-cramjam/) +- **🔐 Authentication:** Apply [OpenID Connect (OIDC)](https://openid.net/developers/how-connect-works/) token validation and optional scope checks to specified endpoints and methods +- **🛂 Content Filtering:** Use CQL2 filters via the [Filter Extension](https://github.com/stac-api-extensions/filter?tab=readme-ov-file) to tailor API responses based on user context +- **🤝 External Policy Integration:** Integrate with external systems (e.g. 
[Open Policy Agent (OPA)](https://www.openpolicyagent.org/)) to generate CQL2 filters dynamically from policy decisions +- **🧩 Authentication Extension:** Add the [Authentication Extension](https://github.com/stac-extensions/authentication) to API responses to expose auth-related metadata +- **📘 OpenAPI Augmentation:** Enhance the [OpenAPI spec](https://swagger.io/specification/) with security details to keep auto-generated docs and UIs (e.g., [Swagger UI](https://swagger.io/tools/swagger-ui/)) accurate +- **🗜️ Response Compression:** Optimize response sizes using [`starlette-cramjam`](https://github.com/developmentseed/starlette-cramjam/) ## Usage @@ -185,9 +186,6 @@ The system supports generating CQL2 filters based on request context to provide > [!IMPORTANT] > The upstream STAC API must support the [STAC API Filter Extension](https://github.com/stac-api-extensions/filter/blob/main/README.md), including the [Features Filter](http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter) conformance class on to the Features resource (`/collections/{cid}/items`)[^37]. -> [!TIP] -> Integration with external authorization systems (e.g. [Open Policy Agent](https://www.openpolicyagent.org/)) can be achieved by specifying an `ITEMS_FILTER` that points to a class/function that, once initialized, returns a [`cql2.Expr` object](https://developmentseed.org/cql2-rs/latest/python/#cql2.Expr) when called with the request context. - #### Filters If enabled, filters are intended to be applied to the following endpoints: @@ -270,6 +268,108 @@ sequenceDiagram STAC API->>Client: Response ``` +#### Authoring Filter Generators + +The `ITEMS_FILTER_CLS` configuration option can be used to specify a class that will be used to generate a CQL2 filter for the request. 
import dataclasses
from typing import Any


@dataclasses.dataclass
class ExampleFilter:
    """Minimal filter generator: every request gets the always-true filter."""

    async def __call__(self, context: dict[str, Any]) -> str:
        """
        Return a cql2-text expression for the given request context.

        The context is ignored here; a real generator would inspect it to decide
        which filter to return.
        """
        return "true"
+> +> Example +> +> ```py +> from typing import Any +> +> from cql2 import Expr +> +> +> def example_filter(): +> def example_filter(context: dict[str, Any]) -> str | dict[str, Any]: +> return Expr("true") +> return example_filter +> ``` +> +>
import dataclasses
from typing import Any, Literal, Optional

from httpx import AsyncClient
from stac_auth_proxy.utils.cache import MemoryCache


@dataclasses.dataclass
class ApprovedCollectionsFilter:
    """Restrict results to the collections an external API approves for the caller."""

    api_url: str
    kind: Literal["item", "collection"] = "item"
    client: AsyncClient = dataclasses.field(init=False)
    cache: MemoryCache = dataclasses.field(init=False)

    def __post_init__(self):
        """Create the long-lived HTTP client and the result cache."""
        # We keep the client in the class instance to avoid creating a new client for
        # each request, taking advantage of the client's connection pooling.
        self.client = AsyncClient(base_url=self.api_url)
        self.cache = MemoryCache(ttl=30)

    async def __call__(self, context: dict[str, Any]) -> dict[str, Any]:
        """Return a cql2-json filter limiting results to the approved collections."""
        # Raw Authorization header value of the incoming request (None if absent).
        authorization = context["req"]["headers"].get("authorization")

        try:
            # Check cache for a previously generated filter
            approved_collections = self.cache[authorization]
        except (KeyError, MemoryCache.Expired):
            # A stale entry raises MemoryCache.Expired rather than KeyError, so
            # both must be handled to refresh after the TTL elapses.
            approved_collections = await self.lookup(authorization)
            self.cache[authorization] = approved_collections

        # Build CQL2 filter
        return {
            "op": "a_containedby",
            "args": [
                {"property": "collection" if self.kind == "item" else "id"},
                approved_collections,
            ],
        }

    async def lookup(self, token: Optional[str]) -> list[str]:
        """Fetch the approved collection ids from the external API."""
        # NOTE(review): `token` is the incoming Authorization header value, which
        # by HTTP convention already includes its scheme (e.g. "Bearer ..."), so it
        # is forwarded unchanged instead of being prefixed a second time. Confirm
        # against what the proxy actually places in the request context.
        headers = {"Authorization": token} if token else {}
        response = await self.client.get(
            "/get-approved-collections",
            headers=headers,
        )
        response.raise_for_status()
        return response.json()["collections"]
+ [^21]: https://github.com/developmentseed/stac-auth-proxy/issues/21 [^22]: https://github.com/developmentseed/stac-auth-proxy/issues/22 [^23]: https://github.com/developmentseed/stac-auth-proxy/issues/23 diff --git a/examples/custom-integration/.python-version b/examples/custom-integration/.python-version new file mode 100644 index 00000000..24ee5b1b --- /dev/null +++ b/examples/custom-integration/.python-version @@ -0,0 +1 @@ +3.13 diff --git a/examples/custom-integration/Dockerfile b/examples/custom-integration/Dockerfile new file mode 100644 index 00000000..8960de8a --- /dev/null +++ b/examples/custom-integration/Dockerfile @@ -0,0 +1,6 @@ +ARG STAC_AUTH_PROXY_VERSION +FROM ghcr.io/developmentseed/stac-auth-proxy:${STAC_AUTH_PROXY_VERSION} + +ADD . /opt/stac-auth-proxy-integration + +RUN pip install /opt/stac-auth-proxy-integration diff --git a/examples/custom-integration/README.md b/examples/custom-integration/README.md new file mode 100644 index 00000000..bf1ef848 --- /dev/null +++ b/examples/custom-integration/README.md @@ -0,0 +1,11 @@ +# Custom Integration Example + +This example demonstrates how to integrate with a custom filter generator. + +## Running the Example + +From the root directory, run: + +```sh +docker compose -f docker-compose.yaml -f examples/custom-integration/docker-compose.yaml up +``` diff --git a/examples/custom-integration/docker-compose.yaml b/examples/custom-integration/docker-compose.yaml new file mode 100644 index 00000000..0af7c118 --- /dev/null +++ b/examples/custom-integration/docker-compose.yaml @@ -0,0 +1,12 @@ +# This compose file is intended to be run alongside the `docker-compose.yaml` file in the +# root directory. 
def cql2_builder(admin_user: str):
    """
    CQL2 builder integration filter.

    Args:
        admin_user: The `sub` claim value that is granted unrestricted access.

    Returns:
        An async callable that maps a request context to a cql2-text expression.
    """
    # NOTE: This is where you would set up things like connection pools.
    # NOTE: args/kwargs are passed in via environment variables.

    async def custom_integration_filter(ctx: dict[str, Any]) -> str:
        """
        Generate CQL2 expressions based on the request context.

        Returns a CQL2 expression, either as a string (cql2-text) or as a dict (cql2-json).
        """
        # NOTE: This is where you would perform a lookup from a database, API, etc.
        # NOTE: ctx is the request context, which includes the payload, headers, etc.

        # The payload is None/absent for anonymous requests, and a token payload
        # is not guaranteed to carry a "sub" claim; treat all of those as
        # non-admin instead of raising KeyError.
        payload = ctx.get("payload") or {}
        subject = payload.get("sub")

        # Require an actual subject so a missing claim can never match a
        # missing/None admin_user.
        if subject is not None and subject == admin_user:
            return "1=1"
        return "private = true"

    return custom_integration_filter
@dataclass
class Opa:
    """
    Call Open Policy Agent (OPA) to generate CQL2 filters from request context.

    Attributes
    ----------
    host: Base URL of the OPA server.
    decision: Path of the OPA decision document, appended to `/v1/data/`.
    cache_key: Dot-notation path into the request context whose value is used as
        the cache key.
    cache_ttl: Seconds for which a fetched filter is reused.
    """

    host: str
    decision: str

    client: httpx.AsyncClient = field(init=False)
    cache: MemoryCache = field(init=False)
    cache_key: str = "req.headers.authorization"
    cache_ttl: float = 5.0

    def __post_init__(self):
        """Initialize the client."""
        # One client per instance so connections are reused between requests.
        self.client = httpx.AsyncClient(base_url=self.host)
        self.cache = MemoryCache(ttl=self.cache_ttl)

    async def __call__(self, context: dict[str, Any]) -> str:
        """
        Generate a CQL2 filter for the request.

        Returns the cached filter for this request's cache key when one is still
        fresh; otherwise asks OPA and caches the answer.
        """
        # NOTE(review): the cache key covers only `cache_key` (by default the
        # Authorization header) while OPA receives the whole context. Policies
        # that depend on other parts of the context would need a different
        # `cache_key` - confirm this matches the intended policies.
        token = get_value_by_path(context, self.cache_key)
        try:
            expr_str = self.cache[token]
        except (KeyError, MemoryCache.Expired):
            # A missing entry raises KeyError; a stale one raises
            # MemoryCache.Expired. Both mean "ask OPA again".
            expr_str = await self._fetch(context)
            self.cache[token] = expr_str
        return expr_str

    async def _fetch(self, context: dict[str, Any]) -> str:
        """
        Fetch the CQL2 filter from OPA.

        Posts the context as the OPA `input` document and returns the `result`
        member of the JSON response. Raises if OPA answers with an error status.
        """
        response = await self.client.post(
            f"/v1/data/{self.decision}",
            json={"input": context},
        )
        return response.raise_for_status().json()["result"]
@dataclass
class MemoryCache:
    """
    In-memory key/value cache whose entries expire `ttl` seconds after being set.

    Attributes
    ----------
    ttl: Lifetime of an entry, in seconds.
    cache: Backing dict mapping key -> (value, time the value was stored).
    """

    ttl: float = 5.0
    cache: dict[Any, tuple[Any, float]] = field(default_factory=dict)
    _last_pruned: float = field(default_factory=time)

    class Expired(KeyError):
        """
        Exception raised when a cache entry has expired.

        Subclasses KeyError so that callers treating the cache like a dict
        (`except KeyError`) handle an expired entry the same way as a miss.
        """

    def __getitem__(self, key: Any) -> Any:
        """
        Get a value from the cache if it is not expired.

        Raises
        ------
        KeyError: If the key was never stored (or has already been removed).
        MemoryCache.Expired: If the entry exists but is older than `ttl`; the
            entry is removed as a side effect.
        """
        try:
            result, timestamp = self.cache[key]
        except KeyError:
            msg = f"{self._key_str(key)!r} not in cache."
            logger.debug(msg)
            raise KeyError(msg) from None

        if (time() - timestamp) > self.ttl:
            msg = f"{self._key_str(key)} in cache, but expired."
            del self.cache[key]
            logger.debug(msg)
            # Use the truncated key: keys may be credentials (e.g. auth headers)
            # and must not leak in full through exception messages.
            raise self.Expired(f"{self._key_str(key)} expired")

        logger.debug("%s in cache, returning cached result.", self._key_str(key))
        return result

    def __setitem__(self, key: Any, value: Any):
        """Set a value in the cache, then opportunistically prune stale entries."""
        self.cache[key] = (value, time())
        self._prune()

    def __contains__(self, key: Any) -> bool:
        """Check if a key is in the cache and is not expired."""
        try:
            self[key]
        except KeyError:  # also covers Expired
            return False
        return True

    def get(self, key: Any) -> Any:
        """Get a value from the cache, or None if it is missing or expired."""
        try:
            return self[key]
        except KeyError:  # also covers Expired
            return None

    def _prune(self):
        """Drop expired entries, at most once per `ttl` seconds."""
        now = time()
        if now - self._last_pruned < self.ttl:
            return
        # Same freshness rule as __getitem__: an entry is stale once its age
        # exceeds ttl.
        self.cache = {
            key: entry
            for key, entry in self.cache.items()
            if (now - entry[1]) <= self.ttl
        }
        self._last_pruned = now

    @staticmethod
    def _key_str(key: Any) -> str:
        """Get a short string representation of a key, suitable for log messages."""
        text = str(key)
        return text if len(text) < 10 else f"{text[:9]}..."


def get_value_by_path(obj: dict, path: str, default: Any = None) -> Any:
    """
    Get a value from a dictionary using dot notation.

    Args:
        obj: The dictionary to search in
        path: The dot notation path (e.g. "payload.sub")
        default: Default value to return if path doesn't exist

    Returns:
        The value at the specified path, or `default` if any step of the path is
        missing, is None, or cannot be looked up.
    """
    # Private marker to tell "key missing" apart from any real value. This
    # prevents the traversal from continuing *into* `default` when `default`
    # happens to be a dict.
    missing = object()
    try:
        for key in path.split("."):
            if obj is None:
                return default
            obj = obj.get(key, missing)
            if obj is missing:
                return default
        return obj
    except (AttributeError, KeyError, TypeError):
        return default
def test_memory_cache_pruning():
    """Test that setting an item prunes expired entries from the backing dict."""
    cache = MemoryCache(ttl=5.0)  # 5 second TTL
    key1 = "key1"
    key2 = "key2"
    value = "test_value"

    with patch("stac_auth_proxy.utils.cache.time") as mock_time:
        # Set initial time
        mock_time.return_value = 1000.0
        # `_last_pruned` was taken from the real clock at construction. Move it
        # onto the mocked timeline, otherwise _prune() sees a negative elapsed
        # time and returns early without pruning anything.
        cache._last_pruned = 1000.0
        cache[key1] = value
        cache[key2] = value

        # Advance time past TTL
        mock_time.return_value = 1006.0  # 6 seconds later

        # Force pruning by adding a new item
        cache["key3"] = value

        # Inspect the backing dict directly: `key in cache` also reports expired
        # entries as absent, so on its own it cannot prove pruning happened.
        assert key1 not in cache.cache
        assert key2 not in cache.cache
        assert "key3" in cache.cache

        # The public view agrees with the backing dict.
        assert key1 not in cache
        assert key2 not in cache
        assert "key3" in cache
@pytest.mark.asyncio
async def test_opa_initialization(opa_filter_generator):
    """Check that a freshly built Opa instance carries its configured and default values."""
    opa = opa_filter_generator
    expected_attributes = {
        "host": "http://localhost:8181",
        "decision": "stac/filter",
        "cache_key": "req.headers.authorization",
        "cache_ttl": 5.0,
    }

    # Plain configuration values, compared one attribute at a time.
    for attribute, expected in expected_attributes.items():
        assert getattr(opa, attribute) == expected

    # Collaborators created in __post_init__.
    assert isinstance(opa.client, AsyncClient)
    assert opa.cache is not None
@pytest.mark.asyncio
async def test_opa_error_handling(opa_filter_generator):
    """Test that an OPA error status propagates and nothing is cached for the request."""
    # Local import: only this test needs these names.
    from httpx import HTTPStatusError, Request

    context = {"req": {"headers": {"authorization": "test-token"}}}

    # Create a mock response that raises the error httpx raises for error statuses
    error_response = MagicMock(spec=Response)
    error_response.raise_for_status.side_effect = HTTPStatusError(
        "Internal server error",
        request=Request("POST", "http://localhost:8181/v1/data/stac/filter"),
        response=error_response,
    )

    with patch.object(
        opa_filter_generator.client, "post", new_callable=AsyncMock
    ) as mock_post:
        mock_post.return_value = error_response

        # Expect the specific HTTP error: a bare `Exception` would also accept
        # unrelated failures (e.g. a broken mock) and hide real regressions.
        with pytest.raises(HTTPStatusError):
            await opa_filter_generator(context)

    # A failed lookup must not leave a cached filter behind.
    assert opa_filter_generator.cache.get("test-token") is None