From 676a2ebc6cb0233b084afa8fbdfd14e31fba6458 Mon Sep 17 00:00:00 2001 From: john0isaac Date: Fri, 28 Jun 2024 19:52:25 +0300 Subject: [PATCH 1/9] add pytest dependency psycopg2-binary is required by create_engine for postgresql database in setting up the tests --- .github/workflows/app-tests.yaml | 10 ++++++++-- pyproject.toml | 22 ++++++++++++++++------ requirements-dev.txt | 6 +++++- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/.github/workflows/app-tests.yaml b/.github/workflows/app-tests.yaml index af4a2d37..488e1f76 100755 --- a/.github/workflows/app-tests.yaml +++ b/.github/workflows/app-tests.yaml @@ -5,9 +5,13 @@ on: branches: [ main ] pull_request: branches: [ main ] + workflow_dispatch: + +permissions: + contents: read jobs: - test_package: + test-package: name: Test ${{ matrix.os }} Python ${{ matrix.python_version }} runs-on: ${{ matrix.os }} strategy: @@ -65,4 +69,6 @@ jobs: run: | cd ./src/frontend npm install - npm run build \ No newline at end of file + npm run build + - name: Run Pytest + run: python3 -m pytest diff --git a/pyproject.toml b/pyproject.toml index d6d7928b..a0f8ab4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,20 @@ [tool.ruff] line-length = 120 -target-version = "py311" +target-version = "py312" +lint.select = ["E", "F", "I", "UP"] +lint.ignore = ["D203"] +lint.isort.known-first-party = ["fastapi_app"] -[tool.ruff.lint] -select = ["E", "F", "I", "UP"] -ignore = ["D203"] +[tool.mypy] +check_untyped_defs = true +python_version = 3.12 +exclude = [".venv/*"] -[tool.ruff.lint.isort] -known-first-party = ["fastapi_app"] +[tool.pytest.ini_options] +addopts = "-ra --cov" +testpaths = ["tests"] +pythonpath = ['src'] +filterwarnings = ["ignore::DeprecationWarning"] + +[tool.coverage.report] +show_missing = true diff --git a/requirements-dev.txt b/requirements-dev.txt index d1c56a4b..1acba1ea 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,4 +2,8 @@ ruff pre-commit pip-tools -pip-compile-cross-platform \ No newline at end of file +pip-compile-cross-platform +pytest +pytest-cov +pytest-asyncio +psycopg2-binary From 296db5c66cbc25840dfc0e2eae8ec8e52f98e963 Mon Sep 17 00:00:00 2001 From: john0isaac Date: Fri, 28 Jun 2024 19:53:33 +0300 Subject: [PATCH 2/9] restrict workflow run and cache dependencies for faster reruns There is no need to do a full installation to use ruff this will make the workflow faster --- .github/workflows/python-code-quality.yaml | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/.github/workflows/python-code-quality.yaml b/.github/workflows/python-code-quality.yaml index 56191b1a..fce62c6f 100644 --- a/.github/workflows/python-code-quality.yaml +++ b/.github/workflows/python-code-quality.yaml @@ -3,11 +3,21 @@ name: Python code quality on: push: branches: [ main ] + paths: + - '**.py' + pull_request: branches: [ main ] + paths: + - '**.py' + + workflow_dispatch: + +permissions: + contents: read jobs: - build: + checks-format-and-lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -15,11 +25,12 @@ jobs: uses: actions/setup-python@v5 with: python-version: "3.12" + cache: 'pip' - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install -r requirements-dev.txt + python3 -m pip install --upgrade pip + python3 -m pip install ruff - name: Lint with ruff run: ruff check . - name: Check formatting with ruff - run: ruff format --check . + run: ruff format . 
--check From d41a6dfcc02552bb47af80d4c5e6ebee68f6cb1a Mon Sep 17 00:00:00 2001 From: john0isaac Date: Fri, 28 Jun 2024 19:54:09 +0300 Subject: [PATCH 3/9] mypy fixes --- src/fastapi_app/embeddings.py | 2 +- src/fastapi_app/postgres_searcher.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fastapi_app/embeddings.py b/src/fastapi_app/embeddings.py index eb6942c4..769db29c 100644 --- a/src/fastapi_app/embeddings.py +++ b/src/fastapi_app/embeddings.py @@ -4,7 +4,7 @@ async def compute_text_embedding( - q: str, openai_client, embed_model: str, embed_deployment: str = None, embedding_dimensions: int = 1536 + q: str, openai_client, embed_model: str, embed_deployment: str | None = None, embedding_dimensions: int = 1536 ): SUPPORTED_DIMENSIONS_MODEL = { "text-embedding-ada-002": False, diff --git a/src/fastapi_app/postgres_searcher.py b/src/fastapi_app/postgres_searcher.py index 5d13650e..ab3197d2 100644 --- a/src/fastapi_app/postgres_searcher.py +++ b/src/fastapi_app/postgres_searcher.py @@ -103,7 +103,7 @@ async def search( async def search_and_embed( self, - query_text: str, + query_text: str | None = None, top: int = 5, enable_vector_search: bool = False, enable_text_search: bool = False, From 44446bbf41233e46a5f9610290c2a682c7626da1 Mon Sep 17 00:00:00 2001 From: john0isaac Date: Fri, 28 Jun 2024 19:55:28 +0300 Subject: [PATCH 4/9] add dataclasses for response for better typing and easier usage this also improves the swagger api examples as it's typed now --- src/fastapi_app/api_models.py | 15 ++++++++++++++ src/fastapi_app/api_routes.py | 6 ++++-- src/fastapi_app/rag_advanced.py | 36 ++++++++++++++++----------------- src/fastapi_app/rag_simple.py | 29 +++++++++++++------------- 4 files changed, 51 insertions(+), 35 deletions(-) diff --git a/src/fastapi_app/api_models.py b/src/fastapi_app/api_models.py index 0945cb10..1c7a477f 100644 --- a/src/fastapi_app/api_models.py +++ b/src/fastapi_app/api_models.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Any from pydantic import BaseModel @@ -17,3 +18,17 @@ class ThoughtStep(BaseModel): title: str description: Any props: dict = {} + + +@dataclass +class RAGContext: + data_points: dict[int, dict[str, Any]] + thoughts: list[ThoughtStep] + followup_questions: list[str] | None = None + + +@dataclass +class RetrievalResponse: + message: Message + context: RAGContext + session_state: Any | None = None diff --git a/src/fastapi_app/api_routes.py b/src/fastapi_app/api_routes.py index 8c03dda7..c41ea010 100644 --- a/src/fastapi_app/api_routes.py +++ b/src/fastapi_app/api_routes.py @@ -9,6 +9,8 @@ from fastapi_app.rag_advanced import AdvancedRAGChat from fastapi_app.rag_simple import SimpleRAGChat +from .api_models import RetrievalResponse + router = fastapi.APIRouter() @@ -52,7 +54,7 @@ async def search_handler(query: str, top: int = 5, enable_vector_search: bool = return [item.to_dict() for item in results] -@router.post("/chat") +@router.post("/chat", response_model=RetrievalResponse) async def chat_handler(chat_request: ChatRequest): messages = [message.model_dump() for message in chat_request.messages] overrides = chat_request.context.get("overrides", {}) @@ -79,5 +81,5 @@ async def chat_handler(chat_request: ChatRequest): chat_deployment=global_storage.openai_chat_deployment, ) - response = await ragchat.run(messages, overrides=overrides) + response: RetrievalResponse = await ragchat.run(messages, overrides=overrides) return response diff --git a/src/fastapi_app/rag_advanced.py 
b/src/fastapi_app/rag_advanced.py index d603d997..1952688e 100644 --- a/src/fastapi_app/rag_advanced.py +++ b/src/fastapi_app/rag_advanced.py @@ -5,12 +5,10 @@ ) from openai import AsyncOpenAI -from openai.types.chat import ( - ChatCompletion, -) +from openai.types.chat import ChatCompletion, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit -from .api_models import ThoughtStep +from .api_models import RAGContext, RetrievalResponse, ThoughtStep from .postgres_searcher import PostgresSearcher from .query_rewriter import build_search_function, extract_search_arguments @@ -35,7 +33,7 @@ def __init__( async def run( self, messages: list[dict], overrides: dict[str, Any] = {} - ) -> dict[str, Any] | AsyncGenerator[dict[str, Any], None]: + ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]: text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] top = overrides.get("top", 3) @@ -45,7 +43,7 @@ async def run( # Generate an optimized keyword search query based on the chat history and the last question query_response_token_limit = 500 - query_messages = build_messages( + query_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, system_prompt=self.query_prompt_template, new_user_content=original_user_query, @@ -55,7 +53,7 @@ async def run( ) chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create( - messages=query_messages, # type: ignore + messages=query_messages, # Azure OpenAI takes the deployment name as the model name model=self.chat_deployment if self.chat_deployment else self.chat_model, temperature=0.0, # Minimize creativity for search query generation @@ -81,7 +79,7 @@ async def run( # Generate a contextual and content specific answer using the search results and chat history response_token_limit = 1024 - messages = build_messages( + contextual_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, system_prompt=overrides.get("prompt_template") or self.answer_prompt_template, new_user_content=original_user_query + "\n\nSources:\n" + content, @@ -90,21 +88,21 @@ async def run( fallback_to_default=True, ) - chat_completion_response = await self.openai_chat_client.chat.completions.create( + chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create( # Azure OpenAI takes the deployment name as the model name model=self.chat_deployment if self.chat_deployment else self.chat_model, - messages=messages, + messages=contextual_messages, temperature=overrides.get("temperature", 0.3), max_tokens=response_token_limit, n=1, stream=False, ) - first_choice = chat_completion_response.model_dump()["choices"][0] - return { - "message": first_choice["message"], - "context": { - "data_points": {item.id: item.to_dict() for item in results}, - "thoughts": [ + first_choice = chat_completion_response.choices[0] + return RetrievalResponse( + message=first_choice.message, + context=RAGContext( + data_points={item.id: item.to_dict() for item in results}, + thoughts=[ ThoughtStep( title="Prompt to generate search arguments", description=[str(message) for message in query_messages], @@ -130,7 +128,7 @@ async def run( ), ThoughtStep( title="Prompt to generate answer", - description=[str(message) for message in messages], + description=[str(message) for message in contextual_messages], props=( {"model": self.chat_model, "deployment": 
self.chat_deployment} if self.chat_deployment @@ -138,5 +136,5 @@ async def run( ), ), ], - }, - } + ), + ) diff --git a/src/fastapi_app/rag_simple.py b/src/fastapi_app/rag_simple.py index bf1613e2..e599c234 100644 --- a/src/fastapi_app/rag_simple.py +++ b/src/fastapi_app/rag_simple.py @@ -5,9 +5,10 @@ ) from openai import AsyncOpenAI +from openai.types.chat import ChatCompletion, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit -from .api_models import ThoughtStep +from .api_models import RAGContext, RetrievalResponse, ThoughtStep from .postgres_searcher import PostgresSearcher @@ -30,7 +31,7 @@ def __init__( async def run( self, messages: list[dict], overrides: dict[str, Any] = {} - ) -> dict[str, Any] | AsyncGenerator[dict[str, Any], None]: + ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]: text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] top = overrides.get("top", 3) @@ -48,7 +49,7 @@ async def run( # Generate a contextual and content specific answer using the search results and chat history response_token_limit = 1024 - messages = build_messages( + contextual_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, system_prompt=overrides.get("prompt_template") or self.answer_prompt_template, new_user_content=original_user_query + "\n\nSources:\n" + content, @@ -57,21 +58,21 @@ async def run( fallback_to_default=True, ) - chat_completion_response = await self.openai_chat_client.chat.completions.create( + chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create( # Azure OpenAI takes the deployment name as the model name model=self.chat_deployment if self.chat_deployment else self.chat_model, - messages=messages, + messages=contextual_messages, temperature=overrides.get("temperature", 0.3), max_tokens=response_token_limit, n=1, stream=False, ) - first_choice = chat_completion_response.model_dump()["choices"][0] - return { - "message": first_choice["message"], - "context": { - "data_points": {item.id: item.to_dict() for item in results}, - "thoughts": [ + first_choice = chat_completion_response.choices[0] + return RetrievalResponse( + message=first_choice.message, + context=RAGContext( + data_points={item.id: item.to_dict() for item in results}, + thoughts=[ ThoughtStep( title="Search query for database", description=original_user_query if text_search else None, @@ -87,7 +88,7 @@ async def run( ), ThoughtStep( title="Prompt to generate answer", - description=[str(message) for message in messages], + description=[str(message) for message in contextual_messages], props=( {"model": self.chat_model, "deployment": self.chat_deployment} if self.chat_deployment @@ -95,5 +96,5 @@ async def run( ), ), ], - }, - } + ), + ) From 8c58aac00ddb6210577ae95cc78f6567172d804b Mon Sep 17 00:00:00 2001 From: john0isaac Date: Fri, 28 Jun 2024 19:56:08 +0300 Subject: [PATCH 5/9] setup test client for database and without database with test coverage 100% for frontend routes --- tests/conftest.py | 67 ++++++++++++++++++++++++++++++++++++++ tests/test_endpoints.py | 71 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/test_endpoints.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..21df4428 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,67 @@ +import 
pytest +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool + +from fastapi_app import create_app +from fastapi_app.postgres_models import Base + +POSTGRESQL_DATABASE_URL = "postgresql://admin:postgres@localhost:5432/postgres" + + +# Create a SQLAlchemy engine +engine = create_engine( + POSTGRESQL_DATABASE_URL, + poolclass=StaticPool, +) + +# Create a sessionmaker to manage sessions +TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +@pytest.fixture(scope="session") +def setup_database(): + """Create tables in the database for all tests.""" + try: + Base.metadata.create_all(bind=engine) + yield + Base.metadata.drop_all(bind=engine) + except Exception as e: + pytest.skip(f"Unable to connect to the database: {e}") + + +@pytest.fixture(scope="function") +def db_session(setup_database): + """Create a new database session with a rollback at the end of the test.""" + connection = engine.connect() + transaction = connection.begin() + session = TestingSessionLocal(bind=connection) + yield session + session.close() + transaction.rollback() + connection.close() + + +@pytest.fixture(scope="function") +def test_db_client(db_session): + """Create a test client that uses the override_get_db fixture to return a session.""" + + def override_db_session(): + try: + yield db_session + finally: + db_session.close() + + app = create_app() + app.router.lifespan = override_db_session + with TestClient(app) as test_client: + yield test_client + + +@pytest.fixture(scope="function") +def test_client(): + """Create a test client.""" + app = create_app() + with TestClient(app) as test_client: + yield test_client diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py new file mode 100644 index 00000000..ac9e4cc9 --- /dev/null +++ b/tests/test_endpoints.py @@ -0,0 +1,71 @@ +import os + +import pytest + + +@pytest.mark.asyncio +async def test_index(test_client): + """test the index route""" + response = test_client.get("/") + + html_index_file_path = "src/static/index.html" + with open(html_index_file_path, "rb") as f: + html_index_file = f.read() + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "text/html; charset=utf-8" + assert response.headers["Content-Length"] == str(len(html_index_file)) + assert html_index_file == response.content + + +@pytest.mark.asyncio +async def test_favicon(test_client): + """test the favicon route""" + response = test_client.get("/favicon.ico") + + favicon_file_path = "src/static/favicon.ico" + with open(favicon_file_path, "rb") as f: + favicon_file = f.read() + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "image/vnd.microsoft.icon" + assert response.headers["Content-Length"] == str(len(favicon_file)) + assert favicon_file == response.content + + +@pytest.mark.asyncio +async def test_assets_non_existent_404(test_client): + """test the assets route with a non-existent file""" + response = test_client.get("/assets/manifest.json") + + assert response.status_code == 404 + assert response.headers["Content-Type"] == "application/json" + assert response.headers["Content-Length"] == "22" + assert b'{"detail":"Not Found"}' in response.content + + +@pytest.mark.asyncio +async def test_assets(test_client): + """test the assets route with an existing file""" + assets_dir_path = "src/static/assets" + assets_file_path = os.listdir(assets_dir_path)[0] + + with 
open(os.path.join(assets_dir_path, assets_file_path), "rb") as f: + assets_file = f.read() + + response = test_client.get(f"/assets/{assets_file_path}") + + assert response.status_code == 200 + assert response.headers["Content-Length"] == str(len(assets_file)) + assert assets_file == response.content + + +@pytest.mark.asyncio +async def test_chat_non_json_415(test_client): + """test the chat route with a non-json request""" + response = test_client.post("/chat") + + assert response.status_code == 422 + assert response.headers["Content-Type"] == "application/json" + assert response.headers["Content-Length"] == "82" + assert b'{"detail":[{"type":"missing"' in response.content From a0522e6ccd25aea2ce72c02cc6bcdd01b880085b Mon Sep 17 00:00:00 2001 From: john0isaac Date: Fri, 28 Jun 2024 20:06:00 +0300 Subject: [PATCH 6/9] fix tests for windows and macos --- tests/test_endpoints.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index ac9e4cc9..388776bb 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -28,7 +28,6 @@ async def test_favicon(test_client): favicon_file = f.read() assert response.status_code == 200 - assert response.headers["Content-Type"] == "image/vnd.microsoft.icon" assert response.headers["Content-Length"] == str(len(favicon_file)) assert favicon_file == response.content From 06c424b6409adbdacf302366816b29c4e511b99c Mon Sep 17 00:00:00 2001 From: john0isaac Date: Sat, 29 Jun 2024 15:20:01 +0300 Subject: [PATCH 7/9] use basemodel instead of dataclass to match the other models and for validation --- src/fastapi_app/api_models.py | 7 ++----- src/fastapi_app/rag_advanced.py | 4 ++-- src/fastapi_app/rag_simple.py | 4 ++-- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/fastapi_app/api_models.py b/src/fastapi_app/api_models.py index 1c7a477f..f96bd84f 100644 --- a/src/fastapi_app/api_models.py +++ b/src/fastapi_app/api_models.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Any from pydantic import BaseModel @@ -20,15 +19,13 @@ class ThoughtStep(BaseModel): props: dict = {} -@dataclass -class RAGContext: +class RAGContext(BaseModel): data_points: dict[int, dict[str, Any]] thoughts: list[ThoughtStep] followup_questions: list[str] | None = None -@dataclass -class RetrievalResponse: +class RetrievalResponse(BaseModel): message: Message context: RAGContext session_state: Any | None = None diff --git a/src/fastapi_app/rag_advanced.py b/src/fastapi_app/rag_advanced.py index 1952688e..85c720d5 100644 --- a/src/fastapi_app/rag_advanced.py +++ b/src/fastapi_app/rag_advanced.py @@ -8,7 +8,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit -from .api_models import RAGContext, RetrievalResponse, ThoughtStep +from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep from .postgres_searcher import PostgresSearcher from .query_rewriter import build_search_function, extract_search_arguments @@ -99,7 +99,7 @@ async def run( ) first_choice = chat_completion_response.choices[0] return RetrievalResponse( - message=first_choice.message, + message=Message(content=first_choice.message.content, role=first_choice.message.role), context=RAGContext( data_points={item.id: item.to_dict() for item in results}, thoughts=[ diff --git a/src/fastapi_app/rag_simple.py b/src/fastapi_app/rag_simple.py index e599c234..c3bea611 100644 --- a/src/fastapi_app/rag_simple.py +++ 
b/src/fastapi_app/rag_simple.py @@ -8,7 +8,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit -from .api_models import RAGContext, RetrievalResponse, ThoughtStep +from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep from .postgres_searcher import PostgresSearcher @@ -69,7 +69,7 @@ async def run( ) first_choice = chat_completion_response.choices[0] return RetrievalResponse( - message=first_choice.message, + message=Message(content=first_choice.message.content, role=first_choice.message.role), context=RAGContext( data_points={item.id: item.to_dict() for item in results}, thoughts=[ From cf3bc786d81671fef4119383bf4e35ce4fa09819 Mon Sep 17 00:00:00 2001 From: john0isaac Date: Sat, 29 Jun 2024 15:20:26 +0300 Subject: [PATCH 8/9] fix scopes and add app fixture --- tests/conftest.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 21df4428..f20af11f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,6 +31,12 @@ def setup_database(): pytest.skip(f"Unable to connect to the database: {e}") +@pytest.fixture(scope="session") +def app(): + """Create a FastAPI app.""" + return create_app() + + @pytest.fixture(scope="function") def db_session(setup_database): """Create a new database session with a rollback at the end of the test.""" @@ -44,7 +50,7 @@ def db_session(setup_database): @pytest.fixture(scope="function") -def test_db_client(db_session): +def test_db_client(app, db_session): """Create a test client that uses the override_get_db fixture to return a session.""" def override_db_session(): @@ -53,15 +59,13 @@ def override_db_session(): finally: db_session.close() - app = create_app() app.router.lifespan = override_db_session with TestClient(app) as test_client: yield test_client -@pytest.fixture(scope="function") -def test_client(): +@pytest.fixture(scope="session") +def test_client(app): """Create a test client.""" - app = create_app() with TestClient(app) as test_client: yield test_client From 4d9a53613114d1286ed292e7153617d9a3c6de59 Mon Sep 17 00:00:00 2001 From: john0isaac Date: Mon, 1 Jul 2024 09:38:35 +0000 Subject: [PATCH 9/9] fix tests to use existing setup --- requirements-dev.txt | 1 - src/fastapi_app/__init__.py | 5 +- tests/conftest.py | 92 +++++++++++++++++-------------------- 3 files changed, 46 insertions(+), 52 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 1acba1ea..4c00a59c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,4 +6,3 @@ pip-compile-cross-platform pytest pytest-cov pytest-asyncio -psycopg2-binary diff --git a/src/fastapi_app/__init__.py b/src/fastapi_app/__init__.py index de1d0fc8..ce7c780b 100644 --- a/src/fastapi_app/__init__.py +++ b/src/fastapi_app/__init__.py @@ -52,11 +52,12 @@ async def lifespan(app: FastAPI): await engine.dispose() -def create_app(): +def create_app(is_testing: bool = False): env = Env() if not os.getenv("RUNNING_IN_PRODUCTION"): - env.read_env(".env") + if not is_testing: + env.read_env(".env") logging.basicConfig(level=logging.INFO) else: logging.basicConfig(level=logging.WARNING) diff --git a/tests/conftest.py b/tests/conftest.py index f20af11f..47af097a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,71 +1,65 @@ +import os +from pathlib import Path +from unittest.mock import patch + import pytest from fastapi.testclient import TestClient -from sqlalchemy import create_engine 
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy.pool import StaticPool
+from sqlalchemy.ext.asyncio import async_sessionmaker
 
 from fastapi_app import create_app
-from fastapi_app.postgres_models import Base
-
-POSTGRESQL_DATABASE_URL = "postgresql://admin:postgres@localhost:5432/postgres"
-
-
-# Create a SQLAlchemy engine
-engine = create_engine(
-    POSTGRESQL_DATABASE_URL,
-    poolclass=StaticPool,
+from fastapi_app.globals import global_storage
+
+POSTGRES_HOST = "localhost"
+POSTGRES_USERNAME = "admin"
+POSTGRES_DATABASE = "postgres"
+POSTGRES_PASSWORD = "postgres"
+POSTGRES_SSL = "prefer"
+POSTGRESQL_DATABASE_URL = (
+    f"postgresql+asyncpg://{POSTGRES_USERNAME}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}/{POSTGRES_DATABASE}"
 )
 
-# Create a sessionmaker to manage sessions
-TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+@pytest.fixture(scope="session")
+def setup_env():
+    os.environ["POSTGRES_HOST"] = POSTGRES_HOST
+    os.environ["POSTGRES_USERNAME"] = POSTGRES_USERNAME
+    os.environ["POSTGRES_DATABASE"] = POSTGRES_DATABASE
+    os.environ["POSTGRES_PASSWORD"] = POSTGRES_PASSWORD
+    os.environ["POSTGRES_SSL"] = POSTGRES_SSL
+    os.environ["POSTGRESQL_DATABASE_URL"] = POSTGRESQL_DATABASE_URL
+    os.environ["RUNNING_IN_PRODUCTION"] = "False"
+    os.environ["OPENAI_API_KEY"] = "fakekey"
 
 
 @pytest.fixture(scope="session")
-def setup_database():
-    """Create tables in the database for all tests."""
-    try:
-        Base.metadata.create_all(bind=engine)
+def mock_azure_credential():
+    """Mock the Azure credential for testing."""
+    with patch("azure.identity.DefaultAzureCredential", return_value=None):
         yield
-        Base.metadata.drop_all(bind=engine)
-    except Exception as e:
-        pytest.skip(f"Unable to connect to the database: {e}")
 
 
 @pytest.fixture(scope="session")
-def app():
+def app(setup_env, mock_azure_credential):
     """Create a FastAPI app."""
-    return create_app()
-
-
-@pytest.fixture(scope="function")
-def db_session(setup_database):
-    """Create a new database session with a rollback at the end of the test."""
-    connection = engine.connect()
-    transaction = connection.begin()
-    session = TestingSessionLocal(bind=connection)
-    yield session
-    session.close()
-    transaction.rollback()
-    connection.close()
+    if not Path("src/static/").exists():
+        pytest.skip("Please generate frontend files first!")
+    return create_app(is_testing=True)
 
 
 @pytest.fixture(scope="function")
-def test_db_client(app, db_session):
-    """Create a test client that uses the override_get_db fixture to return a session."""
-
-    def override_db_session():
-        try:
-            yield db_session
-        finally:
-            db_session.close()
+def test_client(app):
+    """Create a test client."""
 
-    app.router.lifespan = override_db_session
     with TestClient(app) as test_client:
         yield test_client
 
 
-@pytest.fixture(scope="session")
-def test_client(app):
-    """Create a test client."""
-    with TestClient(app) as test_client:
-        yield test_client
+@pytest.fixture(scope="function")
+def db_session():
+    """Create a new database session with a rollback at the end of the test."""
+    async_session = async_sessionmaker(autocommit=False, autoflush=False, bind=global_storage.engine)
+    session = async_session()
+    session.begin()
+    yield session
+    session.rollback()
+    session.close()
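
A note on the final db_session fixture above: it drives an AsyncSession from a synchronous generator, so session.begin() is never actually started and the rollback()/close() coroutines are never awaited. Below is a minimal sketch of an asyncio-friendly variant; it is not part of the patches above. It assumes pytest-asyncio (added to requirements-dev.txt in patch 1) and that global_storage.engine has already been populated by the app lifespan, for example because the test also uses the test_client fixture; the autocommit argument is simply omitted here since the sessionmaker defaults are enough for the sketch.

import pytest_asyncio
from sqlalchemy.ext.asyncio import async_sessionmaker

from fastapi_app.globals import global_storage


@pytest_asyncio.fixture(scope="function")
async def db_session():
    """Yield an AsyncSession and roll back whatever the test changed."""
    # Sketch only: assumes global_storage.engine is the async engine created by the app lifespan.
    async_session = async_sessionmaker(global_storage.engine, autoflush=False, expire_on_commit=False)
    async with async_session() as session:  # the session is closed (awaited) on exit
        yield session
        await session.rollback()  # undo any uncommitted work left by the test

A test would request it the same way as the original fixture, e.g. an async def test marked with @pytest.mark.asyncio that takes both test_client and db_session as arguments.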