Skip to content

Commit 72b9107

Browse files
authored
fix: use NoDecode for cors origins (#313)
* fix: use NoDecode for cors origins Otherwise, the env var syntax needs to be an encoded json string, which is awkward. * chore: update changelog * feat: add json parsing, more values * chore: update changelog
1 parent 8874baa commit 72b9107

File tree

3 files changed

+76
-7
lines changed

3 files changed

+76
-7
lines changed

CHANGES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## [Unreleased]
44

5+
### Fixed
6+
7+
- Parsing of `CORS_ORIGINS`, `CORS_HEADERS`, and `CORS_METHODS` from environment variables ([#313](https://github.com/stac-utils/stac-fastapi-pgstac/pull/313))
8+
59
### Changed
610

711
- Docker container runs as non-root user

stac_fastapi/pgstac/config.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""Postgres API configuration."""
22

3+
import json
34
import warnings
45
from typing import Annotated, Any, List, Optional, Sequence, Type
56
from urllib.parse import quote_plus as quote
67

78
from pydantic import BaseModel, BeforeValidator, Field, model_validator
8-
from pydantic_settings import BaseSettings, SettingsConfigDict
9+
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
910
from stac_fastapi.types.config import ApiSettings
1011
from typing_extensions import Self
1112

@@ -158,8 +159,12 @@ def connection_string(self):
158159

159160
def str_to_list(value: Any) -> Any:
160161
if isinstance(value, str):
161-
return [v.strip() for v in value.split(",")]
162-
return value
162+
if value.startswith("["):
163+
return json.loads(value)
164+
else:
165+
return [v.strip() for v in value.split(",")]
166+
else:
167+
return value
163168

164169

165170
class Settings(ApiSettings):
@@ -201,15 +206,17 @@ class Settings(ApiSettings):
201206
Implies that the `Transactions` extension is enabled.
202207
"""
203208

204-
cors_origins: Annotated[Sequence[str], BeforeValidator(str_to_list)] = ("*",)
209+
cors_origins: Annotated[Sequence[str], BeforeValidator(str_to_list), NoDecode] = (
210+
"*",
211+
)
205212
cors_origin_regex: Optional[str] = None
206-
cors_methods: Annotated[Sequence[str], BeforeValidator(str_to_list)] = (
213+
cors_methods: Annotated[Sequence[str], BeforeValidator(str_to_list), NoDecode] = (
207214
"GET",
208215
"POST",
209216
"OPTIONS",
210217
)
211218
cors_credentials: bool = False
212-
cors_headers: Annotated[Sequence[str], BeforeValidator(str_to_list)] = (
219+
cors_headers: Annotated[Sequence[str], BeforeValidator(str_to_list), NoDecode] = (
213220
"Content-Type",
214221
)
215222

tests/test_config.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
import pytest
66
from pydantic import ValidationError
7+
from pytest import MonkeyPatch
78

8-
from stac_fastapi.pgstac.config import PostgresSettings
9+
from stac_fastapi.pgstac.config import PostgresSettings, Settings
910

1011

1112
async def test_pg_settings_with_env(monkeypatch):
@@ -74,3 +75,60 @@ async def test_pg_settings_attributes(monkeypatch):
7475
postgres_dbname="pgstac",
7576
_env_file=None,
7677
)
78+
79+
80+
@pytest.mark.parametrize(
81+
"cors_origins",
82+
[
83+
"http://stac-fastapi-pgstac.test,http://stac-fastapi.test",
84+
'["http://stac-fastapi-pgstac.test","http://stac-fastapi.test"]',
85+
],
86+
)
87+
def test_cors_origins(monkeypatch: MonkeyPatch, cors_origins: str) -> None:
88+
monkeypatch.setenv(
89+
"CORS_ORIGINS",
90+
cors_origins,
91+
)
92+
settings = Settings()
93+
assert settings.cors_origins == [
94+
"http://stac-fastapi-pgstac.test",
95+
"http://stac-fastapi.test",
96+
]
97+
98+
99+
@pytest.mark.parametrize(
100+
"cors_methods",
101+
[
102+
"GET,POST",
103+
'["GET","POST"]',
104+
],
105+
)
106+
def test_cors_methods(monkeypatch: MonkeyPatch, cors_methods: str) -> None:
107+
monkeypatch.setenv(
108+
"CORS_METHODS",
109+
cors_methods,
110+
)
111+
settings = Settings()
112+
assert settings.cors_methods == [
113+
"GET",
114+
"POST",
115+
]
116+
117+
118+
@pytest.mark.parametrize(
119+
"cors_headers",
120+
[
121+
"Content-Type,X-Foo",
122+
'["Content-Type","X-Foo"]',
123+
],
124+
)
125+
def test_cors_headers(monkeypatch: MonkeyPatch, cors_headers: str) -> None:
126+
monkeypatch.setenv(
127+
"CORS_HEADERS",
128+
cors_headers,
129+
)
130+
settings = Settings()
131+
assert settings.cors_headers == [
132+
"Content-Type",
133+
"X-Foo",
134+
]

0 commit comments

Comments
 (0)