Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions {{project_name}}/app/config.py.jinja
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from functools import lru_cache
from pathlib import Path
from typing import Any

{%if project_type =="agent" %}
from pydantic import SecretStr, field_validator
{%else%}
from pydantic import SecretStr
{%endif%}
from pydantic import AnyHttpUrl, Field, SecretStr, ValidationInfo, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

from app.utils.config_utils import (
EncryptedField,
EnvironmentType,
FernetDecryptorField,
)


class Settings(BaseSettings):
model_config = SettingsConfigDict(
Expand All @@ -18,11 +21,18 @@ class Settings(BaseSettings):
# dotenv search .env when module is imported, without usecwd it starts from the file it was called
)

# CORE SETTINGS
fernet_decryptor: FernetDecryptorField = Field("MASTER_KEY")
debug: bool = False
environment: EnvironmentType = EnvironmentType.LOCAL

# API SETTINGS
api_name: str = f"{{project_name}} API"
api_v1: str = "/api/v1"
api_latest: str = api_v1
paging_limit: int = 100
cors_origins: list[AnyHttpUrl] = []
cors_allow_all: bool = False

{% if project_type in ["api-monolith", "api-microservice"] %}
# DATABASE SETTINGS
Expand Down Expand Up @@ -58,6 +68,25 @@ class Settings(BaseSettings):
return v
{% endif %}

@field_validator("cors_origins", mode="after")
@classmethod
def assemble_cors_origins(cls, v: str | list[str]) -> list[str] | str:
if isinstance(v, str) and not v.startswith("["):
return [i.strip() for i in v.split(",")]
if isinstance(v, (list, str)):
return v

# This should never be reached given the type annotation, but ensures type safety
raise ValueError(f"Unexpected type for cors_origins: {type(v)}")

@field_validator("*", mode="after")
@classmethod
def _decryptor(cls, v: Any, validation_info: ValidationInfo, *args, **kwargs) -> Any:
if isinstance(v, EncryptedField):
return v.get_decrypted_value(validation_info.data["fernet_decryptor"])
return v


{%if project_type == "mcp-server" %}
# MCP SETTINGS
mcp_server_name: str = f"MCP Server"
Expand Down
2 changes: 2 additions & 0 deletions {{project_name}}/app/main.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ from app.mcp import mcp_router
from app.database import engine
from app.integrations.sqladmin.views import add_admin_views
{% endif %}
from app.middlewares import add_cors_middleware

basicConfig(level=INFO, format="[%(asctime)s - %(name)s] (%(levelname)s) %(message)s")

Expand All @@ -35,6 +36,7 @@ admin = Admin(app=api, engine=engine)
add_admin_views(admin)
{% endif %}

add_cors_middleware(api)

@api.get("/")
async def root() -> dict[str, str]:
Expand Down
18 changes: 18 additions & 0 deletions {{project_name}}/app/middlewares.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from app.config import settings


def add_cors_middleware(app: FastAPI) -> None:
cors_origins = [str(origin).rstrip("/") for origin in settings.cors_origins]
if settings.cors_allow_all:
cors_origins = ["*"]

app.add_middleware(
CORSMiddleware,
allow_origins=cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
84 changes: 84 additions & 0 deletions {{project_name}}/app/utils/config_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os
from enum import Enum
from functools import wraps
from typing import Any, Callable, Generator, Protocol

from cryptography.fernet import Fernet
from pydantic import ValidationInfo

CallableGenerator = Generator[Callable[..., Any], None, None]


class EnvironmentType(str, Enum):
LOCAL = "local"
TEST = "test"
STAGING = "staging"
PRODUCTION = "production"


class Decryptor(Protocol):
def decrypt(self, value: bytes) -> bytes: ...


class FakeFernet:
def decrypt(self, value: bytes) -> bytes:
return value


class EncryptedField(str):
@classmethod
def __get_pydantic_json_schema__(cls, field_schema: dict[str, Any]) -> None:
field_schema.update(type="str", writeOnly=True)

@classmethod
def __get_validators__(cls) -> "CallableGenerator":
yield cls.validate

@classmethod
def validate(cls, value: str, _: ValidationInfo) -> "EncryptedField":
if isinstance(value, cls):
return value
return cls(value)

def __init__(self, value: str):
self._secret_value = "".join(value.splitlines()).strip().encode("utf-8")
self.decrypted = False

def get_decrypted_value(self, decryptor: Decryptor) -> str:
if not self.decrypted:
value = decryptor.decrypt(self._secret_value)
self._secret_value = value
self.decrypted = True
return self._secret_value.decode("utf-8")


class FernetDecryptorField(str):
def __get_pydantic_json_schema__(self, field_schema: dict[str, Any]) -> None:
field_schema.update(type="str", writeOnly=True)

@classmethod
def __get_validators__(cls) -> "CallableGenerator":
yield cls.validate

@classmethod
def validate(cls, value: str, _: ValidationInfo) -> Decryptor:
master_key = os.environ.get(value)
if not master_key:
return FakeFernet()
return Fernet(os.environ[value])


def set_env_from_settings(func: Callable[..., Any]) -> Callable[..., Any]:
"""
Decorator to set environment variables from settings.
This decorator is useful for encrypted fields and providers that
require API keys to be available as environment variables.
"""

@wraps(func)
def wrapper(*args, **kwargs) -> Any:
settings = func(*args, **kwargs)
# os.environ["EXAMPLE_API_KEY"] = settings.EXAMPLE_API_KEY
return settings # noqa: RET504

return wrapper
7 changes: 7 additions & 0 deletions {{project_name}}/config/.env.example.jinja
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
#--- APP ---#
ENVIRONMENT="local"
DEBUG=True
CORS_ORIGINS=["http://localhost:8000", "https://localhost:8000", "http://localhost", "https://localhost"]
SERVER_HOST="https://new_project_name.dev"

#--- DB ---#
DB_HOST=localhost
DP_PORT=5432
DB_NAME={{project_name}}
Expand Down