From 676a2ebc6cb0233b084afa8fbdfd14e31fba6458 Mon Sep 17 00:00:00 2001
From: john0isaac
Date: Fri, 28 Jun 2024 19:52:25 +0300
Subject: [PATCH 01/31] add pytest dependency

psycopg2-binary is required by create_engine for postgresql database in setting up the tests
---
 .github/workflows/app-tests.yaml | 10 ++++++++--
 pyproject.toml                   | 22 ++++++++++++++++------
 requirements-dev.txt             |  6 +++++-
 3 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/.github/workflows/app-tests.yaml b/.github/workflows/app-tests.yaml
index af4a2d37..488e1f76 100755
--- a/.github/workflows/app-tests.yaml
+++ b/.github/workflows/app-tests.yaml
@@ -5,9 +5,13 @@ on:
   push:
     branches: [ main ]
   pull_request:
     branches: [ main ]
+  workflow_dispatch:
+
+permissions:
+  contents: read
 jobs:
-  test_package:
+  test-package:
     name: Test ${{ matrix.os }} Python ${{ matrix.python_version }}
     runs-on: ${{ matrix.os }}
     strategy:
@@ -65,4 +69,6 @@ jobs:
         run: |
           cd ./src/frontend
           npm install
-          npm run build
\ No newline at end of file
+          npm run build
+      - name: Run Pytest
+        run: python3 -m pytest
diff --git a/pyproject.toml b/pyproject.toml
index d6d7928b..a0f8ab4f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,10 +1,20 @@
 [tool.ruff]
 line-length = 120
-target-version = "py311"
+target-version = "py312"
+lint.select = ["E", "F", "I", "UP"]
+lint.ignore = ["D203"]
+lint.isort.known-first-party = ["fastapi_app"]
 
-[tool.ruff.lint]
-select = ["E", "F", "I", "UP"]
-ignore = ["D203"]
+[tool.mypy]
+check_untyped_defs = true
+python_version = 3.12
+exclude = [".venv/*"]
 
-[tool.ruff.lint.isort]
-known-first-party = ["fastapi_app"]
+[tool.pytest.ini_options]
+addopts = "-ra --cov"
+testpaths = ["tests"]
+pythonpath = ['src']
+filterwarnings = ["ignore::DeprecationWarning"]
+
+[tool.coverage.report]
+show_missing = true
diff --git a/requirements-dev.txt b/requirements-dev.txt
index d1c56a4b..1acba1ea 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -2,4 +2,8 @@
 ruff
 pre-commit
 pip-tools
-pip-compile-cross-platform
\ No newline at end of file
+pip-compile-cross-platform
+pytest
+pytest-cov
+pytest-asyncio
+psycopg2-binary

From 296db5c66cbc25840dfc0e2eae8ec8e52f98e963 Mon Sep 17 00:00:00 2001
From: john0isaac
Date: Fri, 28 Jun 2024 19:53:33 +0300
Subject: [PATCH 02/31] restrict workflow run and cache dependencies for faster reruns

There is no need to do a full installation to use ruff
this will make the workflow faster
---
 .github/workflows/python-code-quality.yaml | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/python-code-quality.yaml b/.github/workflows/python-code-quality.yaml
index 56191b1a..fce62c6f 100644
--- a/.github/workflows/python-code-quality.yaml
+++ b/.github/workflows/python-code-quality.yaml
@@ -3,11 +3,21 @@ name: Python code quality
 on:
   push:
     branches: [ main ]
+    paths:
+      - '**.py'
+
   pull_request:
     branches: [ main ]
+    paths:
+      - '**.py'
+
+  workflow_dispatch:
+
+permissions:
+  contents: read
 
 jobs:
-  build:
+  checks-format-and-lint:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
           python-version: "3.12"
+          cache: 'pip'
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
-          pip install -r requirements-dev.txt
+          python3 -m pip install --upgrade pip
+          python3 -m pip install ruff
       - name: Lint with ruff
         run: ruff check .
       - name: Check formatting with ruff
-        run: ruff format --check .
+        run: ruff format . --check

From d41a6dfcc02552bb47af80d4c5e6ebee68f6cb1a Mon Sep 17 00:00:00 2001
From: john0isaac
Date: Fri, 28 Jun 2024 19:54:09 +0300
Subject: [PATCH 03/31] mypy fixes

---
 src/fastapi_app/embeddings.py        | 2 +-
 src/fastapi_app/postgres_searcher.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/fastapi_app/embeddings.py b/src/fastapi_app/embeddings.py
index eb6942c4..769db29c 100644
--- a/src/fastapi_app/embeddings.py
+++ b/src/fastapi_app/embeddings.py
@@ -4,7 +4,7 @@ async def compute_text_embedding(
-    q: str, openai_client, embed_model: str, embed_deployment: str = None, embedding_dimensions: int = 1536
+    q: str, openai_client, embed_model: str, embed_deployment: str | None = None, embedding_dimensions: int = 1536
 ):
     SUPPORTED_DIMENSIONS_MODEL = {
         "text-embedding-ada-002": False,
diff --git a/src/fastapi_app/postgres_searcher.py b/src/fastapi_app/postgres_searcher.py
index 5d13650e..ab3197d2 100644
--- a/src/fastapi_app/postgres_searcher.py
+++ b/src/fastapi_app/postgres_searcher.py
@@ -103,7 +103,7 @@ async def search(
 
     async def search_and_embed(
         self,
-        query_text: str,
+        query_text: str | None = None,
         top: int = 5,
         enable_vector_search: bool = False,
         enable_text_search: bool = False,

From 44446bbf41233e46a5f9610290c2a682c7626da1 Mon Sep 17 00:00:00 2001
From: john0isaac
Date: Fri, 28 Jun 2024 19:55:28 +0300
Subject: [PATCH 04/31] add dataclasses for response for better typing and easier usage

this also improves the swagger api examples as it's typed now
---
 src/fastapi_app/api_models.py   | 15 ++++++++++++++
 src/fastapi_app/api_routes.py   |  6 ++++--
 src/fastapi_app/rag_advanced.py | 36 ++++++++++++++++-----------------
 src/fastapi_app/rag_simple.py   | 29 +++++++++++++-------------
 4 files changed, 51 insertions(+), 35 deletions(-)

diff --git a/src/fastapi_app/api_models.py b/src/fastapi_app/api_models.py
index 0945cb10..1c7a477f 100644
--- a/src/fastapi_app/api_models.py
+++ b/src/fastapi_app/api_models.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from typing import Any
 
 from pydantic import BaseModel
@@ -17,3 +18,17 @@ class ThoughtStep(BaseModel):
     title: str
     description: Any
     props: dict = {}
+
+
+@dataclass
+class RAGContext:
+    data_points: dict[int, dict[str, Any]]
+    thoughts: list[ThoughtStep]
+    followup_questions: list[str] | None = None
+
+
+@dataclass
+class RetrievalResponse:
+    message: Message
+    context: RAGContext
+    session_state: Any | None = None
diff --git a/src/fastapi_app/api_routes.py b/src/fastapi_app/api_routes.py
index 8c03dda7..c41ea010 100644
--- a/src/fastapi_app/api_routes.py
+++ b/src/fastapi_app/api_routes.py
@@ -9,6 +9,8 @@ from fastapi_app.rag_advanced import AdvancedRAGChat
 from fastapi_app.rag_simple import SimpleRAGChat
 
+from .api_models import RetrievalResponse
+
 router = fastapi.APIRouter()
 
 
@@ -52,7 +54,7 @@ async def search_handler(query: str, top: int = 5, enable_vector_search: bool =
     return [item.to_dict() for item in results]
 
 
-@router.post("/chat")
+@router.post("/chat", response_model=RetrievalResponse)
 async def chat_handler(chat_request: ChatRequest):
     messages = [message.model_dump() for message in chat_request.messages]
     overrides = chat_request.context.get("overrides", {})
@@ -79,5 +81,5 @@ async def chat_handler(chat_request: ChatRequest):
         chat_deployment=global_storage.openai_chat_deployment,
     )
 
-    response = await ragchat.run(messages, overrides=overrides)
+    response: RetrievalResponse = await ragchat.run(messages, overrides=overrides)
     return response
diff --git a/src/fastapi_app/rag_advanced.py
b/src/fastapi_app/rag_advanced.py index d603d997..1952688e 100644 --- a/src/fastapi_app/rag_advanced.py +++ b/src/fastapi_app/rag_advanced.py @@ -5,12 +5,10 @@ ) from openai import AsyncOpenAI -from openai.types.chat import ( - ChatCompletion, -) +from openai.types.chat import ChatCompletion, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit -from .api_models import ThoughtStep +from .api_models import RAGContext, RetrievalResponse, ThoughtStep from .postgres_searcher import PostgresSearcher from .query_rewriter import build_search_function, extract_search_arguments @@ -35,7 +33,7 @@ def __init__( async def run( self, messages: list[dict], overrides: dict[str, Any] = {} - ) -> dict[str, Any] | AsyncGenerator[dict[str, Any], None]: + ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]: text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] top = overrides.get("top", 3) @@ -45,7 +43,7 @@ async def run( # Generate an optimized keyword search query based on the chat history and the last question query_response_token_limit = 500 - query_messages = build_messages( + query_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, system_prompt=self.query_prompt_template, new_user_content=original_user_query, @@ -55,7 +53,7 @@ async def run( ) chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create( - messages=query_messages, # type: ignore + messages=query_messages, # Azure OpenAI takes the deployment name as the model name model=self.chat_deployment if self.chat_deployment else self.chat_model, temperature=0.0, # Minimize creativity for search query generation @@ -81,7 +79,7 @@ async def run( # Generate a contextual and content specific answer using the search results and chat history response_token_limit = 1024 - messages = build_messages( + contextual_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, system_prompt=overrides.get("prompt_template") or self.answer_prompt_template, new_user_content=original_user_query + "\n\nSources:\n" + content, @@ -90,21 +88,21 @@ async def run( fallback_to_default=True, ) - chat_completion_response = await self.openai_chat_client.chat.completions.create( + chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create( # Azure OpenAI takes the deployment name as the model name model=self.chat_deployment if self.chat_deployment else self.chat_model, - messages=messages, + messages=contextual_messages, temperature=overrides.get("temperature", 0.3), max_tokens=response_token_limit, n=1, stream=False, ) - first_choice = chat_completion_response.model_dump()["choices"][0] - return { - "message": first_choice["message"], - "context": { - "data_points": {item.id: item.to_dict() for item in results}, - "thoughts": [ + first_choice = chat_completion_response.choices[0] + return RetrievalResponse( + message=first_choice.message, + context=RAGContext( + data_points={item.id: item.to_dict() for item in results}, + thoughts=[ ThoughtStep( title="Prompt to generate search arguments", description=[str(message) for message in query_messages], @@ -130,7 +128,7 @@ async def run( ), ThoughtStep( title="Prompt to generate answer", - description=[str(message) for message in messages], + description=[str(message) for message in contextual_messages], props=( {"model": self.chat_model, "deployment": 
self.chat_deployment} if self.chat_deployment @@ -138,5 +136,5 @@ async def run( ), ), ], - }, - } + ), + ) diff --git a/src/fastapi_app/rag_simple.py b/src/fastapi_app/rag_simple.py index bf1613e2..e599c234 100644 --- a/src/fastapi_app/rag_simple.py +++ b/src/fastapi_app/rag_simple.py @@ -5,9 +5,10 @@ ) from openai import AsyncOpenAI +from openai.types.chat import ChatCompletion, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit -from .api_models import ThoughtStep +from .api_models import RAGContext, RetrievalResponse, ThoughtStep from .postgres_searcher import PostgresSearcher @@ -30,7 +31,7 @@ def __init__( async def run( self, messages: list[dict], overrides: dict[str, Any] = {} - ) -> dict[str, Any] | AsyncGenerator[dict[str, Any], None]: + ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]: text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] top = overrides.get("top", 3) @@ -48,7 +49,7 @@ async def run( # Generate a contextual and content specific answer using the search results and chat history response_token_limit = 1024 - messages = build_messages( + contextual_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, system_prompt=overrides.get("prompt_template") or self.answer_prompt_template, new_user_content=original_user_query + "\n\nSources:\n" + content, @@ -57,21 +58,21 @@ async def run( fallback_to_default=True, ) - chat_completion_response = await self.openai_chat_client.chat.completions.create( + chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create( # Azure OpenAI takes the deployment name as the model name model=self.chat_deployment if self.chat_deployment else self.chat_model, - messages=messages, + messages=contextual_messages, temperature=overrides.get("temperature", 0.3), max_tokens=response_token_limit, n=1, stream=False, ) - first_choice = chat_completion_response.model_dump()["choices"][0] - return { - "message": first_choice["message"], - "context": { - "data_points": {item.id: item.to_dict() for item in results}, - "thoughts": [ + first_choice = chat_completion_response.choices[0] + return RetrievalResponse( + message=first_choice.message, + context=RAGContext( + data_points={item.id: item.to_dict() for item in results}, + thoughts=[ ThoughtStep( title="Search query for database", description=original_user_query if text_search else None, @@ -87,7 +88,7 @@ async def run( ), ThoughtStep( title="Prompt to generate answer", - description=[str(message) for message in messages], + description=[str(message) for message in contextual_messages], props=( {"model": self.chat_model, "deployment": self.chat_deployment} if self.chat_deployment @@ -95,5 +96,5 @@ async def run( ), ), ], - }, - } + ), + ) From 8c58aac00ddb6210577ae95cc78f6567172d804b Mon Sep 17 00:00:00 2001 From: john0isaac Date: Fri, 28 Jun 2024 19:56:08 +0300 Subject: [PATCH 05/31] setup test client for database and without database with test coverage 100% for frontend routes --- tests/conftest.py | 67 ++++++++++++++++++++++++++++++++++++++ tests/test_endpoints.py | 71 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/test_endpoints.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..21df4428 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,67 @@ 
+import pytest +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool + +from fastapi_app import create_app +from fastapi_app.postgres_models import Base + +POSTGRESQL_DATABASE_URL = "postgresql://admin:postgres@localhost:5432/postgres" + + +# Create a SQLAlchemy engine +engine = create_engine( + POSTGRESQL_DATABASE_URL, + poolclass=StaticPool, +) + +# Create a sessionmaker to manage sessions +TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +@pytest.fixture(scope="session") +def setup_database(): + """Create tables in the database for all tests.""" + try: + Base.metadata.create_all(bind=engine) + yield + Base.metadata.drop_all(bind=engine) + except Exception as e: + pytest.skip(f"Unable to connect to the database: {e}") + + +@pytest.fixture(scope="function") +def db_session(setup_database): + """Create a new database session with a rollback at the end of the test.""" + connection = engine.connect() + transaction = connection.begin() + session = TestingSessionLocal(bind=connection) + yield session + session.close() + transaction.rollback() + connection.close() + + +@pytest.fixture(scope="function") +def test_db_client(db_session): + """Create a test client that uses the override_get_db fixture to return a session.""" + + def override_db_session(): + try: + yield db_session + finally: + db_session.close() + + app = create_app() + app.router.lifespan = override_db_session + with TestClient(app) as test_client: + yield test_client + + +@pytest.fixture(scope="function") +def test_client(): + """Create a test client.""" + app = create_app() + with TestClient(app) as test_client: + yield test_client diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py new file mode 100644 index 00000000..ac9e4cc9 --- /dev/null +++ b/tests/test_endpoints.py @@ -0,0 +1,71 @@ +import os + +import pytest + + +@pytest.mark.asyncio +async def test_index(test_client): + """test the index route""" + response = test_client.get("/") + + html_index_file_path = "src/static/index.html" + with open(html_index_file_path, "rb") as f: + html_index_file = f.read() + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "text/html; charset=utf-8" + assert response.headers["Content-Length"] == str(len(html_index_file)) + assert html_index_file == response.content + + +@pytest.mark.asyncio +async def test_favicon(test_client): + """test the favicon route""" + response = test_client.get("/favicon.ico") + + favicon_file_path = "src/static/favicon.ico" + with open(favicon_file_path, "rb") as f: + favicon_file = f.read() + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "image/vnd.microsoft.icon" + assert response.headers["Content-Length"] == str(len(favicon_file)) + assert favicon_file == response.content + + +@pytest.mark.asyncio +async def test_assets_non_existent_404(test_client): + """test the assets route with a non-existent file""" + response = test_client.get("/assets/manifest.json") + + assert response.status_code == 404 + assert response.headers["Content-Type"] == "application/json" + assert response.headers["Content-Length"] == "22" + assert b'{"detail":"Not Found"}' in response.content + + +@pytest.mark.asyncio +async def test_assets(test_client): + """test the assets route with an existing file""" + assets_dir_path = "src/static/assets" + assets_file_path = os.listdir(assets_dir_path)[0] + + with 
open(os.path.join(assets_dir_path, assets_file_path), "rb") as f: + assets_file = f.read() + + response = test_client.get(f"/assets/{assets_file_path}") + + assert response.status_code == 200 + assert response.headers["Content-Length"] == str(len(assets_file)) + assert assets_file == response.content + + +@pytest.mark.asyncio +async def test_chat_non_json_415(test_client): + """test the chat route with a non-json request""" + response = test_client.post("/chat") + + assert response.status_code == 422 + assert response.headers["Content-Type"] == "application/json" + assert response.headers["Content-Length"] == "82" + assert b'{"detail":[{"type":"missing"' in response.content From a0522e6ccd25aea2ce72c02cc6bcdd01b880085b Mon Sep 17 00:00:00 2001 From: john0isaac Date: Fri, 28 Jun 2024 20:06:00 +0300 Subject: [PATCH 06/31] fix tests for windows and macos --- tests/test_endpoints.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index ac9e4cc9..388776bb 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -28,7 +28,6 @@ async def test_favicon(test_client): favicon_file = f.read() assert response.status_code == 200 - assert response.headers["Content-Type"] == "image/vnd.microsoft.icon" assert response.headers["Content-Length"] == str(len(favicon_file)) assert favicon_file == response.content From 06c424b6409adbdacf302366816b29c4e511b99c Mon Sep 17 00:00:00 2001 From: john0isaac Date: Sat, 29 Jun 2024 15:20:01 +0300 Subject: [PATCH 07/31] use basemodel instead of dataclass to match the other models and for validation --- src/fastapi_app/api_models.py | 7 ++----- src/fastapi_app/rag_advanced.py | 4 ++-- src/fastapi_app/rag_simple.py | 4 ++-- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/fastapi_app/api_models.py b/src/fastapi_app/api_models.py index 1c7a477f..f96bd84f 100644 --- a/src/fastapi_app/api_models.py +++ b/src/fastapi_app/api_models.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Any from pydantic import BaseModel @@ -20,15 +19,13 @@ class ThoughtStep(BaseModel): props: dict = {} -@dataclass -class RAGContext: +class RAGContext(BaseModel): data_points: dict[int, dict[str, Any]] thoughts: list[ThoughtStep] followup_questions: list[str] | None = None -@dataclass -class RetrievalResponse: +class RetrievalResponse(BaseModel): message: Message context: RAGContext session_state: Any | None = None diff --git a/src/fastapi_app/rag_advanced.py b/src/fastapi_app/rag_advanced.py index 1952688e..85c720d5 100644 --- a/src/fastapi_app/rag_advanced.py +++ b/src/fastapi_app/rag_advanced.py @@ -8,7 +8,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit -from .api_models import RAGContext, RetrievalResponse, ThoughtStep +from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep from .postgres_searcher import PostgresSearcher from .query_rewriter import build_search_function, extract_search_arguments @@ -99,7 +99,7 @@ async def run( ) first_choice = chat_completion_response.choices[0] return RetrievalResponse( - message=first_choice.message, + message=Message(content=first_choice.message.content, role=first_choice.message.role), context=RAGContext( data_points={item.id: item.to_dict() for item in results}, thoughts=[ diff --git a/src/fastapi_app/rag_simple.py b/src/fastapi_app/rag_simple.py index e599c234..c3bea611 100644 --- a/src/fastapi_app/rag_simple.py +++ 
b/src/fastapi_app/rag_simple.py @@ -8,7 +8,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit -from .api_models import RAGContext, RetrievalResponse, ThoughtStep +from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep from .postgres_searcher import PostgresSearcher @@ -69,7 +69,7 @@ async def run( ) first_choice = chat_completion_response.choices[0] return RetrievalResponse( - message=first_choice.message, + message=Message(content=first_choice.message.content, role=first_choice.message.role), context=RAGContext( data_points={item.id: item.to_dict() for item in results}, thoughts=[ From cf3bc786d81671fef4119383bf4e35ce4fa09819 Mon Sep 17 00:00:00 2001 From: john0isaac Date: Sat, 29 Jun 2024 15:20:26 +0300 Subject: [PATCH 08/31] fix scopes and add app fixture --- tests/conftest.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 21df4428..f20af11f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,6 +31,12 @@ def setup_database(): pytest.skip(f"Unable to connect to the database: {e}") +@pytest.fixture(scope="session") +def app(): + """Create a FastAPI app.""" + return create_app() + + @pytest.fixture(scope="function") def db_session(setup_database): """Create a new database session with a rollback at the end of the test.""" @@ -44,7 +50,7 @@ def db_session(setup_database): @pytest.fixture(scope="function") -def test_db_client(db_session): +def test_db_client(app, db_session): """Create a test client that uses the override_get_db fixture to return a session.""" def override_db_session(): @@ -53,15 +59,13 @@ def override_db_session(): finally: db_session.close() - app = create_app() app.router.lifespan = override_db_session with TestClient(app) as test_client: yield test_client -@pytest.fixture(scope="function") -def test_client(): +@pytest.fixture(scope="session") +def test_client(app): """Create a test client.""" - app = create_app() with TestClient(app) as test_client: yield test_client From 4d9a53613114d1286ed292e7153617d9a3c6de59 Mon Sep 17 00:00:00 2001 From: john0isaac Date: Mon, 1 Jul 2024 09:38:35 +0000 Subject: [PATCH 09/31] fix tests to use existing setup --- requirements-dev.txt | 1 - src/fastapi_app/__init__.py | 5 +- tests/conftest.py | 92 +++++++++++++++++-------------------- 3 files changed, 46 insertions(+), 52 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 1acba1ea..4c00a59c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,4 +6,3 @@ pip-compile-cross-platform pytest pytest-cov pytest-asyncio -psycopg2-binary diff --git a/src/fastapi_app/__init__.py b/src/fastapi_app/__init__.py index de1d0fc8..ce7c780b 100644 --- a/src/fastapi_app/__init__.py +++ b/src/fastapi_app/__init__.py @@ -52,11 +52,12 @@ async def lifespan(app: FastAPI): await engine.dispose() -def create_app(): +def create_app(is_testing: bool = False): env = Env() if not os.getenv("RUNNING_IN_PRODUCTION"): - env.read_env(".env") + if not is_testing: + env.read_env(".env") logging.basicConfig(level=logging.INFO) else: logging.basicConfig(level=logging.WARNING) diff --git a/tests/conftest.py b/tests/conftest.py index f20af11f..47af097a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,71 +1,65 @@ +import os +from pathlib import Path +from unittest.mock import patch + import pytest from fastapi.testclient import TestClient -from sqlalchemy import 
create_engine -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import StaticPool +from sqlalchemy.ext.asyncio import async_sessionmaker from fastapi_app import create_app -from fastapi_app.postgres_models import Base - -POSTGRESQL_DATABASE_URL = "postgresql://admin:postgres@localhost:5432/postgres" - - -# Create a SQLAlchemy engine -engine = create_engine( - POSTGRESQL_DATABASE_URL, - poolclass=StaticPool, +from fastapi_app.globals import global_storage + +POSTGRES_HOST = "localhost" +POSTGRES_USERNAME = "admin" +POSTGRES_DATABASE = "postgres" +POSTGRES_PASSWORD = "postgres" +POSTGRES_SSL = "prefer" +POSTGRESQL_DATABASE_URL = ( + f"postgresql+asyncpg://{POSTGRES_USERNAME}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}/{POSTGRES_DATABASE}" ) -# Create a sessionmaker to manage sessions -TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +@pytest.fixture(scope="session") +def setup_env(): + os.environ["POSTGRES_HOST"] = POSTGRES_HOST + os.environ["POSTGRES_USERNAME"] = POSTGRES_USERNAME + os.environ["POSTGRES_DATABASE"] = POSTGRES_DATABASE + os.environ["POSTGRES_PASSWORD"] = POSTGRES_PASSWORD + os.environ["POSTGRES_SSL"] = POSTGRES_SSL + os.environ["POSTGRESQL_DATABASE_URL"] = POSTGRESQL_DATABASE_URL + os.environ["RUNNING_IN_PRODUCTION"] = "False" + os.environ["OPENAI_API_KEY"] = "fakekey" @pytest.fixture(scope="session") -def setup_database(): - """Create tables in the database for all tests.""" - try: - Base.metadata.create_all(bind=engine) +def mock_azure_credential(): + """Mock the Azure credential for testing.""" + with patch("azure.identity.DefaultAzureCredential", return_value=None): yield - Base.metadata.drop_all(bind=engine) - except Exception as e: - pytest.skip(f"Unable to connect to the database: {e}") @pytest.fixture(scope="session") -def app(): +def app(setup_env, mock_azure_credential): """Create a FastAPI app.""" - return create_app() - - -@pytest.fixture(scope="function") -def db_session(setup_database): - """Create a new database session with a rollback at the end of the test.""" - connection = engine.connect() - transaction = connection.begin() - session = TestingSessionLocal(bind=connection) - yield session - session.close() - transaction.rollback() - connection.close() + if not Path("src/static/").exists(): + pytest.skip("Please generate frontend files first!") + return create_app(is_testing=True) @pytest.fixture(scope="function") -def test_db_client(app, db_session): - """Create a test client that uses the override_get_db fixture to return a session.""" - - def override_db_session(): - try: - yield db_session - finally: - db_session.close() +def test_client(app): + """Create a test client.""" - app.router.lifespan = override_db_session with TestClient(app) as test_client: yield test_client -@pytest.fixture(scope="session") -def test_client(app): - """Create a test client.""" - with TestClient(app) as test_client: - yield test_client +@pytest.fixture(scope="function") +def db_session(): + """Create a new database session with a rollback at the end of the test.""" + async_sesion = async_sessionmaker(autocommit=False, autoflush=False, bind=global_storage.engine) + session = async_sesion() + session.begin() + yield session + session.rollback() + session.close() From 75732493594d964f52aaee0e1a36d51aaaa80795 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Fri, 5 Jul 2024 13:47:48 +0000 Subject: [PATCH 10/31] add mocks and use monkey patch for setting env vars --- src/fastapi_app/__init__.py | 4 +- tests/__init__.py | 0 tests/conftest.py | 53 
++++++++++++------ tests/mocks.py | 108 ++++++++++++++++++++++++++++++++++++ 4 files changed, 147 insertions(+), 18 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/mocks.py diff --git a/src/fastapi_app/__init__.py b/src/fastapi_app/__init__.py index ce7c780b..5105b694 100644 --- a/src/fastapi_app/__init__.py +++ b/src/fastapi_app/__init__.py @@ -52,11 +52,11 @@ async def lifespan(app: FastAPI): await engine.dispose() -def create_app(is_testing: bool = False): +def create_app(testing: bool = False): env = Env() if not os.getenv("RUNNING_IN_PRODUCTION"): - if not is_testing: + if not testing: env.read_env(".env") logging.basicConfig(level=logging.INFO) else: diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py index 47af097a..69ab7204 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from unittest.mock import patch +from unittest import mock import pytest from fastapi.testclient import TestClient @@ -8,6 +8,7 @@ from fastapi_app import create_app from fastapi_app.globals import global_storage +from tests.mocks import MockAzureCredential POSTGRES_HOST = "localhost" POSTGRES_USERNAME = "admin" @@ -20,34 +21,54 @@ @pytest.fixture(scope="session") -def setup_env(): - os.environ["POSTGRES_HOST"] = POSTGRES_HOST - os.environ["POSTGRES_USERNAME"] = POSTGRES_USERNAME - os.environ["POSTGRES_DATABASE"] = POSTGRES_DATABASE - os.environ["POSTGRES_PASSWORD"] = POSTGRES_PASSWORD - os.environ["POSTGRES_SSL"] = POSTGRES_SSL - os.environ["POSTGRESQL_DATABASE_URL"] = POSTGRESQL_DATABASE_URL - os.environ["RUNNING_IN_PRODUCTION"] = "False" - os.environ["OPENAI_API_KEY"] = "fakekey" +def monkeypatch_session(): + with pytest.MonkeyPatch.context() as monkeypatch_session: + yield monkeypatch_session @pytest.fixture(scope="session") -def mock_azure_credential(): - """Mock the Azure credential for testing.""" - with patch("azure.identity.DefaultAzureCredential", return_value=None): +def mock_session_env(monkeypatch_session): + """Mock the environment variables for testing.""" + with mock.patch.dict(os.environ, clear=True): + # Database + monkeypatch_session.setenv("POSTGRES_HOST", POSTGRES_HOST) + monkeypatch_session.setenv("POSTGRES_USERNAME", POSTGRES_USERNAME) + monkeypatch_session.setenv("POSTGRES_DATABASE", POSTGRES_DATABASE) + monkeypatch_session.setenv("POSTGRES_PASSWORD", POSTGRES_PASSWORD) + monkeypatch_session.setenv("POSTGRES_SSL", POSTGRES_SSL) + monkeypatch_session.setenv("POSTGRESQL_DATABASE_URL", POSTGRESQL_DATABASE_URL) + monkeypatch_session.setenv("RUNNING_IN_PRODUCTION", "False") + # Azure Subscription + monkeypatch_session.setenv("AZURE_SUBSCRIPTION_ID", "test-storage-subid") + # OpenAI + monkeypatch_session.setenv("AZURE_OPENAI_CHATGPT_MODEL", "gpt-35-turbo") + monkeypatch_session.setenv("OPENAI_API_KEY", "fakekey") + # Allowed Origin + monkeypatch_session.setenv("ALLOWED_ORIGIN", "https://frontend.com") + + if os.getenv("AZURE_USE_AUTHENTICATION") is not None: + monkeypatch_session.delenv("AZURE_USE_AUTHENTICATION") yield @pytest.fixture(scope="session") -def app(setup_env, mock_azure_credential): +def app(mock_session_env): """Create a FastAPI app.""" if not Path("src/static/").exists(): pytest.skip("Please generate frontend files first!") - return create_app(is_testing=True) + return create_app(testing=True) + + +@pytest.fixture(scope="function") +def mock_default_azure_credential(mock_session_env): + """Mock the Azure 
credential for testing.""" + with mock.patch("azure.identity.DefaultAzureCredential") as mock_default_azure_credential: + mock_default_azure_credential.return_value = MockAzureCredential() + yield mock_default_azure_credential @pytest.fixture(scope="function") -def test_client(app): +def test_client(monkeypatch, app, mock_default_azure_credential): """Create a test client.""" with TestClient(app) as test_client: diff --git a/tests/mocks.py b/tests/mocks.py new file mode 100644 index 00000000..f08a02d6 --- /dev/null +++ b/tests/mocks.py @@ -0,0 +1,108 @@ +import json +from collections import namedtuple + +import openai.types +from azure.core.credentials_async import AsyncTokenCredential + +MOCK_EMBEDDING_DIMENSIONS = 1536 +MOCK_EMBEDDING_MODEL_NAME = "text-embedding-ada-002" + +MockToken = namedtuple("MockToken", ["token", "expires_on", "value"]) + + +class MockAzureCredential(AsyncTokenCredential): + async def get_token(self, uri): + return MockToken("", 9999999999, "") + + +class MockAzureCredentialExpired(AsyncTokenCredential): + def __init__(self): + self.access_number = 0 + + async def get_token(self, uri): + self.access_number += 1 + if self.access_number == 1: + return MockToken("", 0, "") + else: + return MockToken("", 9999999999, "") + + +class MockAsyncPageIterator: + def __init__(self, data): + self.data = data + + def __aiter__(self): + return self + + async def __anext__(self): + if not self.data: + raise StopAsyncIteration + return self.data.pop(0) # This should be a list of dictionaries. + + +class MockCaption: + def __init__(self, text, highlights=None, additional_properties=None): + self.text = text + self.highlights = highlights or [] + self.additional_properties = additional_properties or {} + + +class MockResponse: + def __init__(self, text, status): + self.text = text + self.status = status + + async def text(self): + return self._text + + async def __aexit__(self, exc_type, exc, tb): + pass + + async def __aenter__(self): + return self + + async def json(self): + return json.loads(self.text) + + +class MockEmbeddingsClient: + def __init__(self, create_embedding_response: openai.types.CreateEmbeddingResponse): + self.create_embedding_response = create_embedding_response + + async def create(self, *args, **kwargs) -> openai.types.CreateEmbeddingResponse: + return self.create_embedding_response + + +class MockClient: + def __init__(self, embeddings_client): + self.embeddings = embeddings_client + + +def mock_computervision_response(): + return MockResponse( + status=200, + text=json.dumps( + { + "vector": [ + 0.011925711, + 0.023533698, + 0.010133852, + 0.0063544377, + -0.00038590943, + 0.0013952175, + 0.009054946, + -0.033573493, + -0.002028305, + ], + "modelVersion": "2022-04-11", + } + ), + ) + + +class MockSynthesisResult: + def __init__(self, result): + self.__result = result + + def get(self): + return self.__result From 9fb9e54b7cd91eec24a0c3d5e93b58050c0c05dd Mon Sep 17 00:00:00 2001 From: John Aziz Date: Fri, 5 Jul 2024 14:25:33 +0000 Subject: [PATCH 11/31] create database and seed data --- tests/conftest.py | 34 +++++++++++++------ tests/test_api_routes.py | 12 +++++++ ...t_endpoints.py => test_frontend_routes.py} | 11 ------ 3 files changed, 36 insertions(+), 21 deletions(-) create mode 100644 tests/test_api_routes.py rename tests/{test_endpoints.py => test_frontend_routes.py} (82%) diff --git a/tests/conftest.py b/tests/conftest.py index 69ab7204..e72048cb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,11 +3,14 @@ from unittest import mock import 
pytest +import pytest_asyncio from fastapi.testclient import TestClient from sqlalchemy.ext.asyncio import async_sessionmaker from fastapi_app import create_app -from fastapi_app.globals import global_storage +from fastapi_app.postgres_engine import create_postgres_engine_from_env +from fastapi_app.setup_postgres_database import create_db_schema +from fastapi_app.setup_postgres_seeddata import seed_data from tests.mocks import MockAzureCredential POSTGRES_HOST = "localhost" @@ -51,12 +54,22 @@ def mock_session_env(monkeypatch_session): yield -@pytest.fixture(scope="session") -def app(mock_session_env): +async def create_and_seed_db(): + """Create and seed the database.""" + engine = await create_postgres_engine_from_env() + await create_db_schema(engine) + await seed_data(engine) + await engine.dispose() + + +@pytest_asyncio.fixture(scope="session") +async def app(mock_session_env): """Create a FastAPI app.""" if not Path("src/static/").exists(): pytest.skip("Please generate frontend files first!") - return create_app(testing=True) + app = create_app(testing=True) + await create_and_seed_db() + return app @pytest.fixture(scope="function") @@ -67,20 +80,21 @@ def mock_default_azure_credential(mock_session_env): yield mock_default_azure_credential -@pytest.fixture(scope="function") -def test_client(monkeypatch, app, mock_default_azure_credential): +@pytest_asyncio.fixture(scope="function") +async def test_client(monkeypatch, app, mock_default_azure_credential): """Create a test client.""" - with TestClient(app) as test_client: yield test_client -@pytest.fixture(scope="function") -def db_session(): +@pytest_asyncio.fixture(scope="function") +async def db_session(): """Create a new database session with a rollback at the end of the test.""" - async_sesion = async_sessionmaker(autocommit=False, autoflush=False, bind=global_storage.engine) + engine = await create_postgres_engine_from_env() + async_sesion = async_sessionmaker(autocommit=False, autoflush=False, bind=engine) session = async_sesion() session.begin() yield session session.rollback() session.close() + await engine.dispose() diff --git a/tests/test_api_routes.py b/tests/test_api_routes.py new file mode 100644 index 00000000..99af3fdc --- /dev/null +++ b/tests/test_api_routes.py @@ -0,0 +1,12 @@ +import pytest + + +@pytest.mark.asyncio +async def test_chat_non_json_415(test_client): + """test the chat route with a non-json request""" + response = test_client.post("/chat") + + assert response.status_code == 422 + assert response.headers["Content-Type"] == "application/json" + assert response.headers["Content-Length"] == "82" + assert b'{"detail":[{"type":"missing"' in response.content diff --git a/tests/test_endpoints.py b/tests/test_frontend_routes.py similarity index 82% rename from tests/test_endpoints.py rename to tests/test_frontend_routes.py index 388776bb..10bf4ec4 100644 --- a/tests/test_endpoints.py +++ b/tests/test_frontend_routes.py @@ -57,14 +57,3 @@ async def test_assets(test_client): assert response.status_code == 200 assert response.headers["Content-Length"] == str(len(assets_file)) assert assets_file == response.content - - -@pytest.mark.asyncio -async def test_chat_non_json_415(test_client): - """test the chat route with a non-json request""" - response = test_client.post("/chat") - - assert response.status_code == 422 - assert response.headers["Content-Type"] == "application/json" - assert response.headers["Content-Length"] == "82" - assert b'{"detail":[{"type":"missing"' in response.content From 
61f9b8f758cc5844d54b6e4af4beb7ef96c67ce7 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Fri, 5 Jul 2024 15:43:32 +0000 Subject: [PATCH 12/31] add tests for items handler and similar --- src/fastapi_app/__init__.py | 2 +- src/fastapi_app/api_routes.py | 5 + src/fastapi_app/rag_advanced.py | 5 +- src/fastapi_app/rag_simple.py | 5 +- tests/conftest.py | 6 +- tests/data.py | 1563 +++++++++++++++++++++++++++++++ tests/mocks.py | 86 -- tests/test_api_routes.py | 78 ++ 8 files changed, 1656 insertions(+), 94 deletions(-) create mode 100644 tests/data.py diff --git a/src/fastapi_app/__init__.py b/src/fastapi_app/__init__.py index 5105b694..b915ac3e 100644 --- a/src/fastapi_app/__init__.py +++ b/src/fastapi_app/__init__.py @@ -18,7 +18,7 @@ async def lifespan(app: FastAPI): load_dotenv(override=True) - azure_credential = None + azure_credential: azure.identity.DefaultAzureCredential | azure.identity.ManagedIdentityCredential | None = None try: if client_id := os.getenv("APP_IDENTITY_ID"): # Authenticate using a user-assigned managed identity on Azure diff --git a/src/fastapi_app/api_routes.py b/src/fastapi_app/api_routes.py index c41ea010..93fd7ee8 100644 --- a/src/fastapi_app/api_routes.py +++ b/src/fastapi_app/api_routes.py @@ -1,4 +1,5 @@ import fastapi +from fastapi import HTTPException from sqlalchemy import select from sqlalchemy.ext.asyncio import async_sessionmaker @@ -20,6 +21,8 @@ async def item_handler(id: int): async_session_maker = async_sessionmaker(global_storage.engine, expire_on_commit=False) async with async_session_maker() as session: item = (await session.scalars(select(Item).where(Item.id == id))).first() + if not item: + raise HTTPException(detail=f"Item with ID {id} not found.", status_code=404) return item.to_dict() @@ -29,6 +32,8 @@ async def similar_handler(id: int, n: int = 5): async_session_maker = async_sessionmaker(global_storage.engine, expire_on_commit=False) async with async_session_maker() as session: item = (await session.scalars(select(Item).where(Item.id == id))).first() + if not item: + raise HTTPException(detail=f"Item with ID {id} not found.", status_code=404) closest = await session.execute( select(Item, Item.embedding.l2_distance(item.embedding)) .filter(Item.id != id) diff --git a/src/fastapi_app/rag_advanced.py b/src/fastapi_app/rag_advanced.py index 85c720d5..d5cef924 100644 --- a/src/fastapi_app/rag_advanced.py +++ b/src/fastapi_app/rag_advanced.py @@ -97,9 +97,10 @@ async def run( n=1, stream=False, ) - first_choice = chat_completion_response.choices[0] + first_choice_message = chat_completion_response.choices[0].message + return RetrievalResponse( - message=Message(content=first_choice.message.content, role=first_choice.message.role), + message=Message(content=str(first_choice_message.content), role=first_choice_message.role), context=RAGContext( data_points={item.id: item.to_dict() for item in results}, thoughts=[ diff --git a/src/fastapi_app/rag_simple.py b/src/fastapi_app/rag_simple.py index c3bea611..271dae86 100644 --- a/src/fastapi_app/rag_simple.py +++ b/src/fastapi_app/rag_simple.py @@ -67,9 +67,10 @@ async def run( n=1, stream=False, ) - first_choice = chat_completion_response.choices[0] + first_choice_message = chat_completion_response.choices[0].message + return RetrievalResponse( - message=Message(content=first_choice.message.content, role=first_choice.message.role), + message=Message(content=str(first_choice_message.content), role=first_choice_message.role), context=RAGContext( data_points={item.id: item.to_dict() for item in results}, 
thoughts=[ diff --git a/tests/conftest.py b/tests/conftest.py index e72048cb..f46852c6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -93,8 +93,8 @@ async def db_session(): engine = await create_postgres_engine_from_env() async_sesion = async_sessionmaker(autocommit=False, autoflush=False, bind=engine) session = async_sesion() - session.begin() + await session.begin() yield session - session.rollback() - session.close() + await session.rollback() + await session.close() await engine.dispose() diff --git a/tests/data.py b/tests/data.py new file mode 100644 index 00000000..58d9ad08 --- /dev/null +++ b/tests/data.py @@ -0,0 +1,1563 @@ +from dataclasses import dataclass + + +@dataclass +class TestData: + id: int + type: str + brand: str + name: str + description: str + price: float + embeddings: list[float] + + +test_data = TestData( + id=1, + type="Footwear", + brand="Daybird", + name="Wanderer Black Hiking Boots", + description="Daybird's Wanderer Hiking Boots in sleek black are perfect for all your " + "outdoor adventures. These boots are made with a waterproof leather upper and a durable " + "rubber sole for superior traction. With their cushioned insole and padded collar, " + "these boots will keep you comfortable all day long.", + price=109.99, + embeddings=[ + -0.010269113, + -0.01673832, + 0.0025070684, + -0.009927924, + 0.0075596725, + 0.0073790434, + -0.0090849865, + -0.05860419, + 0.013500371, + -0.050522696, + -0.022866337, + 0.011172259, + -0.011747598, + -0.011346199, + -0.009192026, + 0.022612117, + 0.01953473, + 0.022745917, + 0.018089695, + -0.013440161, + -0.006673251, + 0.025850065, + -0.0066765957, + -0.0056363046, + -0.020457946, + -0.020136828, + 0.01906643, + -0.01911995, + 0.010242353, + -0.022759298, + 0.0061715026, + 0.0006464189, + -0.013279602, + -0.009372656, + -0.010657132, + 0.0026927153, + 0.0042715496, + -0.0039403955, + 0.014303168, + 0.01376797, + 0.015841862, + 0.00022578667, + -0.0026609378, + 0.009386036, + -0.0010185488, + 0.010683891, + 0.00019714521, + -0.017835476, + -0.008623378, + 0.026933841, + -0.003545687, + -0.0003781927, + 0.004596013, + 0.007646642, + 0.011781047, + 0.0045525283, + 0.0055526798, + 0.013560581, + 0.029783772, + -0.029676732, + -0.033369597, + 0.002610763, + -0.01665804, + 0.008027971, + -0.0035122372, + -0.015493983, + -0.028659856, + -0.003579137, + 0.008509649, + 0.01917347, + 0.007673402, + 0.00095666654, + 0.015882002, + -0.0033784376, + 0.011319439, + -0.019722048, + -0.024391651, + 0.006041048, + -0.014182748, + -0.009278996, + 0.018022794, + -0.026304984, + -0.018798832, + 0.016470721, + 0.015601023, + 0.0061547775, + 0.011486689, + 0.0030723712, + -0.036875147, + -0.017260138, + 0.00838923, + 0.0048134374, + 0.0053018057, + 0.002868327, + 0.0031743934, + 0.008007901, + -0.0058938684, + 0.015734823, + -0.0009968064, + -0.02480643, + 0.015520743, + 0.010877901, + 0.002376614, + -0.013426781, + -0.027335241, + 0.0009993151, + -0.007459323, + 0.010683891, + 0.02448531, + -0.0013898425, + -0.0076198825, + 0.023508575, + 0.032004844, + -0.031897806, + 0.00077143783, + -0.010315943, + -0.011714147, + 0.016109461, + 0.016885499, + -0.031496406, + 0.024378272, + 0.015734823, + 0.014530627, + -0.0025137584, + 0.0008613344, + 0.023455055, + -0.011694077, + -0.012175756, + 0.011379649, + 0.0036794867, + 0.0070846844, + 0.018879112, + 0.01920023, + -0.032272443, + -0.008048041, + 0.014932025, + -0.018758692, + 0.005833659, + -0.020163586, + -0.0056797895, + -0.014811606, + 0.014329928, + -0.006536106, + 0.005803554, + 
0.012650744, + 0.025743026, + -0.0053653605, + 0.009499765, + 0.008536409, + -0.0059841834, + 0.03002461, + 0.005602855, + 0.039551135, + -0.020324146, + -0.005469055, + 0.0148651255, + 0.023521954, + 0.018169973, + -0.012242655, + -0.026478924, + -0.0023548715, + -0.010630371, + 0.01644396, + -0.02498037, + 0.04329752, + 0.03047953, + 0.005750034, + 0.0016641314, + -0.0012075406, + 0.0040106406, + 0.005442295, + 0.0166179, + 0.004920477, + 0.042200368, + 0.023307875, + 0.0023314564, + 0.018705172, + 0.0020337526, + -0.0220234, + 0.010643751, + -0.03526955, + -0.009774054, + 0.0118947765, + 0.005820279, + -0.01886573, + -0.031309087, + -0.009379346, + -0.004331759, + 0.002878362, + -0.00549916, + -0.0037731463, + 0.0029268644, + 0.014704566, + 0.008790628, + -0.6602203, + 0.0036761416, + -0.0056898245, + -0.024632491, + 0.032192163, + 0.00034892405, + 0.031014727, + 0.017701676, + -0.012811303, + 0.009546596, + 0.017019298, + 0.014570767, + -0.014383447, + 0.009646945, + -0.024471931, + -0.010202213, + 0.003331608, + -0.020819204, + 0.012135616, + 0.0021140324, + -0.026465544, + 0.010342702, + -0.007914241, + -0.01898615, + -0.013346502, + -0.0016173016, + 0.004060815, + -0.0014433622, + -0.00824205, + 0.009044847, + -0.039845496, + 0.04412708, + -0.0056363046, + -0.007941001, + 0.05477752, + -0.021421302, + -0.014289788, + 0.0436454, + 0.028954215, + 0.056570433, + -0.012971863, + -0.0042247195, + 0.0041711996, + -0.0010051689, + 0.009586735, + 0.016082702, + 0.03537659, + -0.0025137584, + 0.007713542, + -0.01388839, + 0.0328344, + 0.026024004, + 0.0059674582, + -0.010195523, + -0.0051111416, + -0.008750488, + 0.036473747, + -0.010790931, + 0.024244472, + 0.015092585, + -0.017527737, + 0.020029787, + -0.0019802328, + -0.006084533, + -0.014891886, + 0.0070177843, + -0.031710483, + -0.019976268, + 0.00828219, + -0.026144424, + 0.021407923, + 0.021795942, + -0.0002933136, + 0.009633565, + 0.019601628, + 0.018330533, + 0.005234906, + -0.0072853835, + -0.014985546, + 0.00075387664, + 0.013420091, + 0.011533518, + -0.013058833, + -0.0169524, + 0.028820416, + -0.007412493, + 0.019012911, + -0.0023582163, + 0.031496406, + -0.0019902678, + 0.012858133, + 0.020872723, + -0.025060648, + -0.027669739, + -0.018571373, + 0.010509952, + 0.0036661066, + -0.012958483, + 0.0076131923, + -0.009178647, + -0.011459928, + 0.0012769491, + -0.0023080416, + -0.00038718234, + 0.012637364, + 0.028419016, + -0.019909367, + 0.0043685543, + 0.0061547775, + -0.021327643, + -0.0027713224, + -0.0070646144, + -0.0027178025, + -0.0167517, + -0.0008149227, + -0.023669133, + 0.034199156, + -0.010162073, + 0.034627315, + -0.026759902, + 0.015333424, + 0.0034854773, + -0.006843845, + -0.022946617, + -0.013507061, + 0.0034419924, + -0.015627783, + -0.010302562, + -0.01097825, + -0.008723728, + -0.008636759, + 0.0025605883, + 0.020845965, + -0.0011255884, + 0.011580348, + 0.025823306, + -0.00087722304, + -0.010175453, + 0.043244004, + 0.0023732688, + -0.01086452, + 0.0006175684, + -0.010309253, + -0.01091804, + -0.00543226, + -0.032540042, + -0.041986287, + 0.007111444, + -0.032540042, + -0.00021575172, + 0.028097898, + -0.017099578, + -0.008998017, + -0.001103846, + 0.015601023, + -0.01907981, + -0.0037296615, + -0.012731024, + -0.010643751, + 0.016992537, + 0.019641768, + 0.015052445, + -0.00011425224, + -0.023414915, + -0.019695288, + -0.0109448, + -0.000100349636, + 0.025395148, + 0.0005076019, + -0.048060786, + -0.018745312, + -0.011961676, + -0.010516642, + 0.012256036, + 0.017327037, + -0.011272609, + -0.022438178, + 
-0.008402609, + -0.008964567, + 0.002167552, + -0.0013379952, + -0.006549486, + -0.026920462, + -0.017888995, + 0.0022193994, + -0.0052516307, + -0.0056195795, + 0.02511417, + 0.0050877263, + 0.019093191, + -0.019775568, + 0.015855242, + -0.0016474065, + 0.009760674, + 0.0040240204, + -0.036339946, + 0.011794427, + 0.018169973, + 0.009539905, + 0.021541722, + 0.028633095, + -0.034011837, + 0.010215593, + -0.0221572, + 0.005298461, + -0.010329322, + -0.0167517, + -0.012062026, + 0.036420226, + 0.014209508, + 0.022625498, + -0.02498037, + -0.004970652, + -0.009332516, + 0.014557387, + 0.008656829, + 0.0018062934, + 0.024190951, + -0.011386339, + -0.0138348695, + 0.0022210719, + -0.0048502325, + 0.008917738, + -0.007660022, + -0.00011874707, + 0.0012861479, + -0.021300882, + 0.013226082, + -0.003331608, + -0.0078072017, + 0.019735428, + 0.020350905, + 0.028927455, + 0.011560278, + 0.015788343, + -0.017741816, + -0.007954381, + -0.036875147, + 0.05234237, + -0.0046127383, + 0.039684936, + 0.009131817, + -0.0118747065, + -0.012657434, + 0.0070980643, + 0.023816314, + 0.0021140324, + 0.020524845, + -0.00059624406, + -0.00339349, + -0.0005109469, + -0.0048401975, + -0.0034955123, + 0.005549335, + 0.020779064, + -0.023174075, + 0.025850065, + 0.007399113, + 0.0051278663, + 0.014303168, + 0.017193237, + 0.003010489, + 0.006335407, + 0.0069308146, + 0.026117666, + -0.005696514, + 0.00045700895, + -0.014758087, + -0.0034118877, + -0.0012016869, + 0.0037497315, + -0.005810244, + 0.0025589156, + -0.009111747, + -0.011446549, + 0.011486689, + 0.005308496, + -0.0035590671, + -0.011312749, + 0.022304378, + -0.015226385, + -0.014985546, + 0.018584752, + 0.025970485, + 0.002864982, + 0.0011063548, + -0.0119549865, + 0.018009415, + -0.010550092, + 0.04693687, + 0.0019518004, + 0.021474821, + 0.008462819, + -0.015614403, + 0.031335846, + 0.002881707, + 0.025649367, + -0.022638878, + 0.017768575, + -0.014089089, + -0.012623984, + -0.0050910716, + -0.021595242, + -0.031014727, + 0.050174817, + -0.026693003, + -0.010048344, + -0.007833961, + 0.0013789713, + 0.007238554, + -0.025569087, + 0.00035665932, + 0.01374121, + -0.0032078433, + -0.00016975813, + 0.015895382, + -0.024003632, + -0.015413704, + 0.02726834, + 0.019347409, + -0.019802328, + -0.008850838, + 0.0042247195, + -0.002035425, + 0.09697789, + 0.015721442, + 0.0010670511, + 0.01902629, + 0.01116557, + -0.017139718, + -0.026987363, + -0.0009867714, + -0.0016206466, + 0.014129229, + 0.013527131, + -0.028204937, + -0.005786829, + -0.017888995, + 0.02745566, + 0.0059105936, + -0.01360072, + -0.024070533, + -0.003341643, + -0.028231697, + -0.006358822, + -0.010188833, + -0.015761582, + 0.015908763, + 0.010048344, + 0.0018330533, + -0.0024067187, + -0.014758087, + 0.032245684, + -0.020993143, + 0.001787896, + 0.0047900225, + 0.0005410518, + 0.022946617, + -0.008944497, + 5.4983237e-05, + -0.011734217, + -0.014396828, + 0.026933841, + -0.011640558, + 0.02986405, + 0.020779064, + 0.010443052, + -0.031228807, + -0.0042682043, + -0.018772071, + -0.008496269, + 0.026706383, + -0.010503261, + -0.029489413, + 0.01914671, + 0.025609227, + 0.0019451104, + -0.0139552895, + -0.0024535486, + 0.0017427386, + 0.013045453, + -0.0029937641, + 0.007412493, + -0.026920462, + -0.012316246, + -0.037223026, + 0.037330065, + 0.008904357, + 0.0137947295, + 0.013098972, + -0.021006523, + 0.023856454, + -0.0274289, + 0.0036995565, + -0.015333424, + -0.00045115524, + -0.003856771, + 0.0070779943, + 0.0018079659, + 0.014985546, + 0.0027579425, + -0.010489882, + 0.014008809, + 
0.0036460368, + 0.0139954295, + 0.0048401975, + 0.022170579, + 0.0009491403, + 0.005228216, + 0.002244487, + -0.0068405, + -0.014222888, + -0.0017527736, + 0.020257246, + 0.016176362, + -0.0019986301, + 0.0099747535, + -0.0091853365, + 0.015828483, + -0.00409092, + 0.029971091, + 0.017032677, + 0.006348787, + 0.019628389, + 0.008656829, + 0.0031392712, + -0.0059641134, + -0.0047465377, + 0.01616298, + -0.0020387701, + 0.0023180766, + 0.010329322, + 0.00822867, + 0.0148651255, + 0.005201456, + -0.017420696, + 0.0034520274, + 0.014263028, + -0.015039065, + 0.008897668, + 0.012496875, + 0.001629009, + 0.0074526328, + -0.047070667, + 0.0064391014, + -0.020069927, + 0.031737246, + 0.014370068, + -0.005495815, + 0.019641768, + 0.011372958, + -0.025809927, + -0.0274289, + -0.0009667015, + 0.013132422, + 0.03551039, + 0.004773298, + -0.020899484, + -0.004405349, + -0.032272443, + -0.020297386, + -0.003043939, + 0.0016097754, + 0.0014818296, + 0.0077670617, + 0.0219565, + 0.01616298, + -0.0022495042, + -0.017286897, + -0.015386944, + -0.028579576, + 0.009265617, + -0.022344518, + 0.059674583, + -0.008877598, + 0.012797924, + 0.0015980679, + -0.0025672782, + -0.013500371, + -0.019829087, + -0.009138507, + -0.007934311, + 0.018450953, + 0.026880322, + -0.017500976, + -0.03034573, + 0.028900694, + 0.0167517, + -0.004114335, + -0.0031074937, + -0.039042696, + 0.0020053203, + -0.019842468, + 0.010108553, + -0.003258018, + 0.013239462, + -0.005602855, + -0.009011397, + -0.0031392712, + -0.0029001045, + -0.013560581, + -0.008609999, + -0.018169973, + -0.03545687, + 0.0139552895, + -0.033717476, + -0.0109448, + 0.01907981, + -0.012630674, + -0.0040641604, + 0.03604559, + -0.0020738924, + 0.015092585, + 0.013292981, + 0.030987967, + -0.0021039974, + 0.025635988, + -0.0021324297, + -0.0019735429, + -0.009071607, + 0.00018177918, + -0.04091589, + -0.0021290847, + 0.03773146, + 0.0009365966, + -0.018036174, + 0.012296176, + -0.0032396207, + -0.01911995, + 0.00025672783, + 0.0036828315, + -0.025207829, + 0.0042581693, + -0.019548109, + -0.025395148, + -0.029088015, + -0.045572113, + -0.003142616, + 0.0077603715, + -0.005830314, + -0.020819204, + 0.008489579, + -0.024204332, + -0.015279904, + -0.009339206, + -0.012396525, + 0.024418412, + -0.0053787404, + 0.0328344, + 0.018089695, + -0.005489125, + -0.0043183793, + -0.003582482, + 0.0055560246, + 0.00037840175, + -0.0019885954, + -0.0101553835, + -0.028017618, + -0.027040882, + 0.017514355, + 0.0011807807, + -0.010275803, + -0.022946617, + 0.025702886, + 0.012851443, + 0.013861629, + 0.009800814, + 0.004381934, + 0.0028181523, + 0.0065294164, + -0.023588855, + 0.018959392, + -0.016270021, + -0.010202213, + 0.0014625959, + -0.0064758966, + -0.015922142, + 0.011346199, + 0.0031877735, + -0.011366269, + -0.018531233, + -0.002438496, + 0.020551605, + 0.011493378, + -0.018651651, + 0.026278224, + -0.008991327, + 0.0111053595, + 0.01663128, + 0.010677201, + 0.010262422, + 0.005194766, + 0.01105853, + 0.031175287, + -0.013607411, + -0.014396828, + 0.0035155823, + -0.011379649, + -0.014249648, + 0.0046194284, + -0.018263634, + -0.020886105, + 0.017219998, + -0.013413401, + 0.010302562, + -0.006683286, + 0.00018648307, + 0.0004666258, + -0.01640382, + -0.0033165554, + -0.020631885, + -0.023093795, + 0.0010168763, + -0.008061421, + -0.023040276, + -0.002732855, + -0.017674915, + 0.021046663, + -0.011774357, + -0.029221814, + -0.0016992538, + 0.02515431, + -0.013032072, + 0.004960617, + 0.015025686, + 0.01926713, + -0.010182143, + 0.0021223947, + -0.009479696, + 
-0.04682983, + -0.0028081173, + 0.004438799, + -0.02436489, + -0.029409133, + -0.014169369, + -0.009854334, + 0.021327643, + -0.014155989, + 0.008643448, + -0.0091251265, + 0.02464587, + -0.003152651, + -0.028258458, + 0.010376153, + 0.00279641, + -0.006004253, + 0.023495195, + -0.011506758, + 0.017848855, + -0.01636368, + 0.014544007, + -0.0034754423, + -0.0139954295, + -0.020658644, + -0.0120553365, + 0.020618506, + -0.011265919, + -0.048515704, + 0.0015930504, + 0.029516174, + -0.009693774, + 0.0046127383, + -0.0053419457, + -0.0080748005, + 0.026612723, + 0.028659856, + 0.010757481, + 0.022585358, + -0.011493378, + 0.00010202213, + -0.0015219694, + -0.009011397, + -0.014878506, + -0.007874101, + 0.0032897955, + 0.059299946, + -0.0020538226, + -0.0074660126, + 0.0012710954, + 0.014503867, + -0.056677476, + -0.010081793, + 0.02452545, + 0.004585978, + 0.014089089, + -0.009519835, + -0.007854031, + 0.035082232, + -0.0045659086, + 0.031442884, + -0.006124673, + -0.017072817, + -0.0110652195, + 0.023040276, + -0.0037698012, + -0.025488807, + -0.03585827, + -0.009419486, + 0.01651086, + -0.026037386, + 0.015427084, + 0.0070579243, + -0.037945542, + -0.03807934, + -0.006104603, + 0.0007037018, + 0.046321392, + 0.013440161, + 0.013038763, + -0.034948435, + -0.020551605, + -0.0045224237, + -0.011366269, + -0.0099346135, + -0.014329928, + 0.013045453, + 0.029034493, + -0.0018330533, + -0.020110067, + -0.01940093, + -0.005766759, + -0.012503564, + 0.010202213, + -0.018838972, + -0.006890675, + 0.011259229, + 0.0050375517, + -0.025140928, + -0.0011180622, + 0.01122578, + 0.0048401975, + -0.0027930648, + -0.009513145, + -0.02745566, + 0.004679638, + -0.014597527, + 0.020725545, + -0.0069977148, + -0.0015621093, + -0.0167517, + 0.0027445625, + -0.013466921, + 0.010269113, + -0.008536409, + -0.02515431, + -0.004207995, + 0.012858133, + 0.01364086, + -0.0015110982, + 0.0008688606, + -0.026866943, + -0.020203726, + 0.017541116, + 0.010289183, + -0.010302562, + 0.007372353, + 0.028606337, + -0.0016114479, + 0.0019384205, + 0.026800042, + 0.20048518, + -0.0075730523, + 0.011112049, + 0.02223748, + 0.0033299355, + 0.02211706, + 0.013045453, + 0.013420091, + 0.014998926, + -0.020043166, + 0.0013396676, + 0.018384052, + -0.03047953, + -0.0015880329, + 0.014396828, + -0.023495195, + -0.034948435, + -0.014222888, + -0.009198717, + -0.0030255415, + -0.006776945, + -0.0025371732, + 0.021474821, + -0.029917572, + 0.03856102, + 0.017393937, + -0.00026258154, + -0.027094401, + -0.00066481635, + 0.014570767, + -0.012021886, + -0.038186383, + -0.009352586, + 0.020364286, + 0.036821626, + 0.005422225, + 0.017032677, + 0.0018146558, + -0.001675839, + 0.011961676, + -0.012082096, + -0.011961676, + -0.015012305, + 0.0023933388, + 0.0063855816, + 0.0046528783, + -0.012831373, + 0.00081784953, + -0.00038864577, + 0.014544007, + 0.015882002, + -0.012215896, + 0.018798832, + -0.015922142, + -0.006241747, + -0.009285687, + 0.042521484, + -0.019695288, + 2.2147478e-05, + -0.0025204483, + -0.013339811, + 0.0059674582, + -0.017059438, + 0.01645734, + 0.0010436362, + 0.011988437, + 0.006081188, + -0.00545233, + 0.005218181, + 0.0050877263, + 0.021193843, + 0.012637364, + 0.0002535919, + 0.03029221, + -0.014544007, + -0.028258458, + 0.0047799875, + 0.016310161, + 0.012296176, + 0.03299496, + -0.009733914, + -0.015547504, + -0.014196129, + 0.014664426, + -0.017674915, + -0.014731326, + -0.0024769634, + -0.0003294813, + -0.019467829, + -0.009988134, + 0.012088786, + 0.011192329, + 0.0001456115, + 0.020712165, + 0.0045291134, + 
0.013647551, + -0.028231697, + 0.0139954295, + -0.018343912, + -0.018624892, + -0.017126339, + -0.017674915, + -0.0023214216, + 0.016350301, + -0.00839592, + 0.0071917237, + 0.018129835, + -0.006619731, + -0.0029703493, + -0.0070378543, + -0.007733612, + -0.0276965, + 0.003268053, + 0.0036895217, + 0.009687085, + -0.0017025988, + 0.007138204, + -0.009559975, + -0.01890587, + -0.019882608, + 0.03636671, + -0.0061848825, + 0.020083306, + 0.014089089, + -0.017420696, + -0.0099948235, + -0.0070244744, + 0.0028465847, + -0.0030556463, + -0.031148527, + 0.015159485, + -0.0055961646, + 0.0037564214, + -0.015132725, + -0.015105965, + 0.015748203, + 0.0034654073, + -0.023856454, + -0.02471277, + 0.006051083, + -0.014530627, + 0.0015646181, + 0.011560278, + 0.0218896, + 0.04982694, + -0.0162834, + 0.027830299, + -0.010697271, + -0.022210719, + -0.001117226, + 0.0034486824, + -0.0016649677, + -0.025488807, + -0.004365209, + 0.0074526328, + 0.01945445, + -0.05030862, + -0.008890978, + -0.0099546835, + 0.0041611646, + -0.026693003, + -0.009031467, + 0.031576686, + -0.04185249, + -0.014008809, + 0.014196129, + -0.17094226, + 0.004662913, + 0.010576852, + -0.015105965, + 0.02223748, + -0.014517247, + 0.017554495, + 0.015694683, + -0.025341628, + -0.0032329308, + 0.012235966, + 0.0059841834, + -0.018665032, + -0.008342399, + -0.010329322, + 0.031255566, + 0.006506001, + 0.02498037, + 0.014958786, + 0.010570162, + 0.009386036, + 0.0021424647, + 0.0005506686, + -0.015948903, + -0.007733612, + -0.00056237605, + 0.01085114, + -0.009299066, + -0.018330533, + -0.007091374, + -0.0101152435, + -0.007706852, + 0.006810395, + 0.004907097, + 0.017327037, + -0.00421803, + 0.011172259, + 0.0069843344, + -0.00211905, + 0.02494023, + 0.007238554, + 0.009319136, + 0.03023869, + -0.025836686, + 0.02237128, + -0.002617453, + -0.004973997, + 0.01945445, + -0.024351511, + -0.034466755, + 0.015413704, + 1.7737582e-06, + -0.023000136, + 0.0029251918, + 0.021207223, + 0.0069843344, + 0.014691186, + 0.032941442, + 0.0018096385, + -0.02986405, + 0.005482435, + -0.010657132, + -0.012115546, + -0.0070110946, + -0.004669603, + 0.005275046, + 0.0025722957, + 0.00818184, + -0.0018263634, + 0.00014456619, + -0.026304984, + -0.010877901, + -0.0075262226, + -0.01118564, + 0.0080145905, + 0.02231776, + 3.0627543e-05, + -0.035403352, + -0.017005919, + -0.018705172, + -0.0027562699, + 0.042039808, + -0.025194447, + 0.021595242, + 0.0009357603, + 0.011553588, + -0.013667621, + -0.0009140179, + -0.0028114624, + -0.024217712, + 0.009379346, + -0.023936734, + -0.00271613, + -0.025528947, + -0.0055760946, + 0.017888995, + 0.006592971, + -0.015788343, + -0.020926245, + -0.007900861, + 0.009379346, + 0.016551, + -0.030640088, + 0.013400021, + 0.026037386, + 0.014463727, + 0.0091251265, + -0.0055861296, + 0.026318364, + -0.0027278375, + 0.011774357, + 0.0013940237, + -0.0010745773, + -0.01086452, + 0.0119349165, + 0.0028114624, + 0.014008809, + -0.02156848, + 0.0118747065, + -0.04602703, + 0.014503867, + 0.00093994156, + -0.03021193, + 0.0035155823, + -0.020672025, + -0.014396828, + -0.08643448, + -0.0064993114, + 0.01941431, + -0.013406712, + -0.029088015, + -0.015253144, + 0.003162686, + 0.009580045, + 0.021260742, + -0.0009800814, + 0.014731326, + 0.0006915762, + 0.008576549, + -0.0036594167, + 0.03874834, + 0.009539905, + 0.01929389, + -0.00039930793, + -0.031656966, + 0.0059942184, + -0.018718552, + 0.00082955696, + -0.014891886, + -0.011767667, + -0.008088181, + 0.00411099, + -0.01916009, + -0.0059340084, + 0.0385075, + 0.019681908, + 
-0.034948435, + 0.014878506, + -0.010509952, + -0.0039638104, + -0.007900861, + 0.0019685254, + -0.025850065, + -0.0050542764, + 0.02718806, + -0.029328853, + 0.00035937713, + 0.011112049, + -0.000111325375, + -0.034172397, + 0.0058470387, + 0.0032379483, + 0.0043049995, + 0.00836916, + 0.008402609, + -0.0067736004, + -0.006325372, + -0.011821187, + -0.017393937, + -0.010315943, + 0.034145635, + -0.023495195, + -0.024886709, + -0.0065227263, + 0.009352586, + -0.0028984318, + -0.0081417, + 0.0067434954, + -0.00010526259, + 0.017929135, + -0.005482435, + -0.021381162, + -0.016805219, + 0.036313187, + 0.010275803, + -0.0069977148, + 0.0024518762, + 0.016551, + -0.016229881, + 0.028927455, + -0.0028348772, + 0.009673704, + -0.025435288, + 0.00079401647, + -0.0009600115, + -0.01616298, + -0.006566211, + -0.014824986, + 0.026626103, + -0.026331745, + 0.01357396, + 0.018477714, + 0.019012911, + -0.007706852, + 0.0039403955, + -0.030640088, + -0.0007739466, + 0.01633692, + 0.010670511, + -0.030506289, + -0.002204347, + -0.01103177, + 0.00070746493, + -0.02472615, + 0.006318682, + -0.0053720507, + -0.020872723, + 0.0075396025, + -0.0716095, + 0.025796546, + -0.012249346, + -0.021314263, + -0.0064223767, + 0.010289183, + -0.0053720507, + 0.0050475867, + -0.008810698, + 0.011821187, + -0.002915157, + -0.023040276, + 0.0055326098, + -0.0005305987, + 0.007512843, + -0.011540208, + -0.0032964854, + -0.0016081029, + 0.012891583, + 0.0042882743, + -0.023615614, + 0.0019484555, + -0.0001933821, + 0.014891886, + -0.04313696, + -0.01355389, + -0.012644054, + -0.0073589734, + 0.0218896, + 0.006345442, + 0.012135616, + -0.015761582, + 0.0216889, + 0.017046059, + -0.019815708, + -0.016765079, + 0.020364286, + 0.019922748, + 0.0021140324, + -0.020404426, + 0.0024535486, + -0.008429369, + 0.021140324, + -0.02503389, + -0.020792445, + -0.009988134, + 0.032379482, + 0.015467224, + 0.020966385, + 0.010168763, + 0.008362469, + 0.021113563, + 0.026572583, + -0.020578366, + -0.029542932, + -0.01096487, + -0.003000454, + -0.01906643, + -0.020150207, + -0.03559067, + 0.04102293, + 0.01091804, + 0.008509649, + 0.007874101, + 0.019347409, + 0.020886105, + 0.00550585, + -0.0024351513, + 0.04123701, + 0.010690581, + -0.025047269, + 0.018705172, + 0.017594635, + 0.01107191, + 0.00408423, + -0.01099832, + -0.007847342, + -0.022291, + -0.012844753, + 0.02487333, + -0.014035569, + 0.0099546835, + -0.048060786, + 0.0100550335, + -0.01936079, + -0.0029536244, + -0.016310161, + -0.0129451025, + 0.006843845, + 0.0110786, + -0.036634307, + 0.0078005116, + 0.0064156866, + 0.0033466604, + -0.009118437, + 0.00411099, + 0.007124824, + -0.009406106, + -0.008469509, + 0.023588855, + -0.0014592509, + 0.010369462, + -0.0064959666, + -0.037704702, + -0.008870908, + 0.017594635, + 0.025743026, + -0.010576852, + -0.03872158, + 0.010897971, + -0.008730418, + 0.010175453, + -0.019936128, + -0.0055961646, + 0.008737109, + -0.0044421437, + 0.011807807, + -0.01926713, + -0.023147317, + 0.022933237, + 0.00104949, + 0.025702886, + 0.0054489854, + -0.024311371, + -0.002759615, + 0.003043939, + 0.0053753955, + 0.0033031756, + 0.015333424, + 0.0010545074, + 0.017996034, + -0.018598132, + 0.0007463504, + 0.026666243, + -0.013346502, + 0.012764473, + 0.03591179, + 0.051111415, + -0.02156848, + 0.051486053, + 0.035082232, + -0.003030559, + 0.0067936704, + -0.034092117, + 0.003552377, + 0.010530022, + -0.01910657, + -0.01355389, + -0.026425404, + -0.0013789713, + -0.016136222, + -0.034600556, + 0.00680036, + -0.044501718, + -5.8798614e-05, + -0.016096082, + 
-0.008496269, + 0.005696514, + 0.007920931, + 0.03604559, + 0.0014550698, + 0.028445777, + 0.004743193, + -0.022344518, + -0.040300414, + 0.01667142, + 0.010563471, + -0.018062934, + -0.034092117, + 0.015333424, + 0.006817085, + -0.0055459896, + -0.01657776, + -0.001906643, + -0.018183354, + -0.019427689, + 0.029007735, + -0.00043150343, + 0.019347409, + 0.004733158, + 0.020096688, + -0.0045659086, + -0.01378135, + 0.016430581, + 0.0034219224, + -0.028044378, + -0.020484706, + 0.011727528, + ], +) diff --git a/tests/mocks.py b/tests/mocks.py index f08a02d6..89053918 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -1,12 +1,7 @@ -import json from collections import namedtuple -import openai.types from azure.core.credentials_async import AsyncTokenCredential -MOCK_EMBEDDING_DIMENSIONS = 1536 -MOCK_EMBEDDING_MODEL_NAME = "text-embedding-ada-002" - MockToken = namedtuple("MockToken", ["token", "expires_on", "value"]) @@ -25,84 +20,3 @@ async def get_token(self, uri): return MockToken("", 0, "") else: return MockToken("", 9999999999, "") - - -class MockAsyncPageIterator: - def __init__(self, data): - self.data = data - - def __aiter__(self): - return self - - async def __anext__(self): - if not self.data: - raise StopAsyncIteration - return self.data.pop(0) # This should be a list of dictionaries. - - -class MockCaption: - def __init__(self, text, highlights=None, additional_properties=None): - self.text = text - self.highlights = highlights or [] - self.additional_properties = additional_properties or {} - - -class MockResponse: - def __init__(self, text, status): - self.text = text - self.status = status - - async def text(self): - return self._text - - async def __aexit__(self, exc_type, exc, tb): - pass - - async def __aenter__(self): - return self - - async def json(self): - return json.loads(self.text) - - -class MockEmbeddingsClient: - def __init__(self, create_embedding_response: openai.types.CreateEmbeddingResponse): - self.create_embedding_response = create_embedding_response - - async def create(self, *args, **kwargs) -> openai.types.CreateEmbeddingResponse: - return self.create_embedding_response - - -class MockClient: - def __init__(self, embeddings_client): - self.embeddings = embeddings_client - - -def mock_computervision_response(): - return MockResponse( - status=200, - text=json.dumps( - { - "vector": [ - 0.011925711, - 0.023533698, - 0.010133852, - 0.0063544377, - -0.00038590943, - 0.0013952175, - 0.009054946, - -0.033573493, - -0.002028305, - ], - "modelVersion": "2022-04-11", - } - ), - ) - - -class MockSynthesisResult: - def __init__(self, result): - self.__result = result - - def get(self): - return self.__result diff --git a/tests/test_api_routes.py b/tests/test_api_routes.py index 99af3fdc..1ff43f92 100644 --- a/tests/test_api_routes.py +++ b/tests/test_api_routes.py @@ -1,5 +1,83 @@ import pytest +from tests.data import test_data + + +@pytest.mark.asyncio +async def test_item_handler(test_client): + """test the item_handler route""" + response = test_client.get(f"/items/{test_data.id}") + response_data = response.json() + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "application/json" + assert response.headers["Content-Length"] == "405" + assert response_data["id"] == test_data.id + assert response_data["name"] == test_data.name + assert response_data["description"] == test_data.description + assert response_data["price"] == test_data.price + assert response_data["type"] == test_data.type + assert response_data["brand"] == 
test_data.brand + + +@pytest.mark.asyncio +async def test_item_handler_404(test_client): + """test the item_handler route with a non-existent item""" + item_id = 10000000 + response = test_client.get(f"/items/{item_id}") + + assert response.status_code == 404 + assert response.headers["Content-Type"] == "application/json" + assert response.headers["Content-Length"] == "45" + assert bytes(f'{{"detail":"Item with ID {item_id} not found."}}', "utf-8") in response.content + + +@pytest.mark.asyncio +async def test_similar_handler(test_client): + """test the similar_handler route""" + response = test_client.get("/similar?id=1&n=1") + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "application/json" + assert response.headers["Content-Length"] == "428" + assert response.json() == [ + { + "id": 71, + "name": "Explorer Frost Boots", + "price": 149.99, + "distance": 0.47, + "type": "Footwear", + "brand": "Daybird", + "description": "The Explorer Frost Boots by Daybird are the perfect companion for " + "cold-weather adventures. These premium boots are designed with a waterproof and insulated " + "shell, keeping your feet warm and protected in icy conditions. The sleek black design " + "with blue accents adds a touch of style to your outdoor gear.", + } + ] + + +@pytest.mark.asyncio +async def test_similar_handler_422(test_client): + """test the similar_handler route with missing query parameters""" + response = test_client.get("/similar") + + assert response.status_code == 422 + assert response.headers["Content-Type"] == "application/json" + assert response.headers["Content-Length"] == "88" + assert b'{"detail":[{"type":"missing","loc":["query","id"]' in response.content + + +@pytest.mark.asyncio +async def test_similar_handler_404(test_client): + """test the similar_handler route with a non-existent item""" + item_id = 10000000 + response = test_client.get(f"/similar?id={item_id}&n=1") + + assert response.status_code == 404 + assert response.headers["Content-Type"] == "application/json" + assert response.headers["Content-Length"] == "45" + assert bytes(f'{{"detail":"Item with ID {item_id} not found."}}', "utf-8") in response.content + @pytest.mark.asyncio async def test_chat_non_json_415(test_client): From 86b79860877d1ccc355df960ea6234f6659e56e0 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Fri, 5 Jul 2024 19:37:33 +0000 Subject: [PATCH 13/31] add search_handler tests --- tests/conftest.py | 30 +++++++++++++++++++++++++++++- tests/test_api_routes.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index f46852c6..ed785cc2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,15 +2,20 @@ from pathlib import Path from unittest import mock +import openai +import openai.resources import pytest import pytest_asyncio from fastapi.testclient import TestClient +from openai.types import CreateEmbeddingResponse, Embedding +from openai.types.create_embedding_response import Usage from sqlalchemy.ext.asyncio import async_sessionmaker from fastapi_app import create_app from fastapi_app.postgres_engine import create_postgres_engine_from_env from fastapi_app.setup_postgres_database import create_db_schema from fastapi_app.setup_postgres_seeddata import seed_data +from tests.data import test_data from tests.mocks import MockAzureCredential POSTGRES_HOST = "localhost" @@ -80,9 +85,32 @@ def mock_default_azure_credential(mock_session_env): yield mock_default_azure_credential 
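+# Stubs out openai.resources.AsyncEmbeddings.create so embedding calls made during tests
+# return the canned test_data.embeddings vector instead of reaching a real OpenAI endpoint.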
+@pytest.fixture(autouse=True) +def mock_openai_embedding(monkeypatch): + async def mock_acreate(*args, **kwargs): + return CreateEmbeddingResponse( + object="list", + data=[ + Embedding( + embedding=test_data.embeddings, + index=0, + object="embedding", + ) + ], + model="text-embedding-ada-002", + usage=Usage(prompt_tokens=8, total_tokens=8), + ) + + def patch(): + monkeypatch.setattr(openai.resources.AsyncEmbeddings, "create", mock_acreate) + + return patch + + @pytest_asyncio.fixture(scope="function") -async def test_client(monkeypatch, app, mock_default_azure_credential): +async def test_client(monkeypatch, app, mock_default_azure_credential, mock_openai_embedding): """Create a test client.""" + mock_openai_embedding() with TestClient(app) as test_client: yield test_client diff --git a/tests/test_api_routes.py b/tests/test_api_routes.py index 1ff43f92..3f4cb681 100644 --- a/tests/test_api_routes.py +++ b/tests/test_api_routes.py @@ -79,6 +79,34 @@ async def test_similar_handler_404(test_client): assert bytes(f'{{"detail":"Item with ID {item_id} not found."}}', "utf-8") in response.content +@pytest.mark.asyncio +async def test_search_handler(test_client): + """test the search_handler route""" + response = test_client.get(f"/search?query={test_data.name}&top=1") + response_data = response.json()[0] + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "application/json" + assert response.headers["Content-Length"] == "407" + assert response_data["id"] == test_data.id + assert response_data["name"] == test_data.name + assert response_data["description"] == test_data.description + assert response_data["price"] == test_data.price + assert response_data["type"] == test_data.type + assert response_data["brand"] == test_data.brand + + +@pytest.mark.asyncio +async def test_search_handler_422(test_client): + """test the search_handler route with missing query parameters""" + response = test_client.get("/search") + + assert response.status_code == 422 + assert response.headers["Content-Type"] == "application/json" + assert response.headers["Content-Length"] == "91" + assert b'{"detail":[{"type":"missing","loc":["query","query"]' in response.content + + @pytest.mark.asyncio async def test_chat_non_json_415(test_client): """test the chat route with a non-json request""" From 610a4183b910135d9041867abadf7306256dd889 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Fri, 5 Jul 2024 20:14:48 +0000 Subject: [PATCH 14/31] add chat tests --- tests/conftest.py | 119 +++++++++++++++- tests/test_api_routes.py | 286 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 403 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index ed785cc2..f6839e9b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,11 @@ import pytest_asyncio from fastapi.testclient import TestClient from openai.types import CreateEmbeddingResponse, Embedding +from openai.types.chat import ChatCompletion, ChatCompletionChunk +from openai.types.chat.chat_completion import ( + ChatCompletionMessage, + Choice, +) from openai.types.create_embedding_response import Usage from sqlalchemy.ext.asyncio import async_sessionmaker @@ -107,10 +112,122 @@ def patch(): return patch +@pytest.fixture +def mock_openai_chatcompletion(monkeypatch): + class AsyncChatCompletionIterator: + def __init__(self, answer: str): + chunk_id = "test-id" + model = "gpt-35-turbo" + self.responses = [ + {"object": "chat.completion.chunk", "choices": [], "id": chunk_id, "model": model, "created": 1}, + { 
+ "object": "chat.completion.chunk", + "choices": [{"delta": {"role": "assistant"}, "index": 0, "finish_reason": None}], + "id": chunk_id, + "model": model, + "created": 1, + }, + ] + # Split at << to simulate chunked responses + if answer.find("<<") > -1: + parts = answer.split("<<") + self.responses.append( + { + "object": "chat.completion.chunk", + "choices": [ + { + "delta": {"role": "assistant", "content": parts[0] + "<<"}, + "index": 0, + "finish_reason": None, + } + ], + "id": chunk_id, + "model": model, + "created": 1, + } + ) + self.responses.append( + { + "object": "chat.completion.chunk", + "choices": [ + {"delta": {"role": "assistant", "content": parts[1]}, "index": 0, "finish_reason": None} + ], + "id": chunk_id, + "model": model, + "created": 1, + } + ) + self.responses.append( + { + "object": "chat.completion.chunk", + "choices": [{"delta": {"role": None, "content": None}, "index": 0, "finish_reason": "stop"}], + "id": chunk_id, + "model": model, + "created": 1, + } + ) + else: + self.responses.append( + { + "object": "chat.completion.chunk", + "choices": [{"delta": {"content": answer}, "index": 0, "finish_reason": None}], + "id": chunk_id, + "model": model, + "created": 1, + } + ) + + def __aiter__(self): + return self + + async def __anext__(self): + if self.responses: + return ChatCompletionChunk.model_validate(self.responses.pop(0)) + else: + raise StopAsyncIteration + + async def mock_acreate(*args, **kwargs): + messages = kwargs["messages"] + last_question = messages[-1]["content"] + if last_question == "Generate search query for: What is the capital of France?": + answer = "capital of France" + elif last_question == "Generate search query for: Are interest rates high?": + answer = "interest rates" + elif isinstance(last_question, list) and last_question[2].get("image_url"): + answer = "From the provided sources, the impact of interest rates and GDP growth on " + "financial markets can be observed through the line graph. [Financial Market Analysis Report 2023-7.png]" + else: + answer = "The capital of France is Paris. [Benefit_Options-2.pdf]." + if messages[0]["content"].find("Generate 3 very brief follow-up questions") > -1: + answer = "The capital of France is Paris. [Benefit_Options-2.pdf]. 
<>" + if "stream" in kwargs and kwargs["stream"] is True: + return AsyncChatCompletionIterator(answer) + else: + return ChatCompletion( + object="chat.completion", + choices=[ + Choice( + message=ChatCompletionMessage(role="assistant", content=answer), finish_reason="stop", index=0 + ) + ], + id="test-123", + created=0, + model="test-model", + ) + + def patch(): + monkeypatch.setattr(openai.resources.chat.completions.AsyncCompletions, "create", mock_acreate) + + return patch + + @pytest_asyncio.fixture(scope="function") -async def test_client(monkeypatch, app, mock_default_azure_credential, mock_openai_embedding): +async def test_client( + monkeypatch, app, mock_default_azure_credential, mock_openai_embedding, mock_openai_chatcompletion +): """Create a test client.""" mock_openai_embedding() + mock_openai_chatcompletion() with TestClient(app) as test_client: yield test_client diff --git a/tests/test_api_routes.py b/tests/test_api_routes.py index 3f4cb681..ee913c04 100644 --- a/tests/test_api_routes.py +++ b/tests/test_api_routes.py @@ -108,7 +108,291 @@ async def test_search_handler_422(test_client): @pytest.mark.asyncio -async def test_chat_non_json_415(test_client): +async def test_simple_chat_flow(test_client): + """test the simple chat flow route with hybrid retrieval mode""" + response = test_client.post( + "/chat", + json={ + "context": { + "overrides": {"top": 1, "use_advanced_flow": False, "retrieval_mode": "hybrid", "temperature": 0.3} + }, + "messages": [{"content": "What is the capital of France?", "role": "user"}], + }, + ) + response_data = response.json() + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "application/json" + assert response_data["message"]["content"] == "The capital of France is Paris. [Benefit_Options-2.pdf]." + assert response_data["message"]["role"] == "assistant" + assert response_data["context"]["data_points"] == { + "1": { + "id": 1, + "name": "Wanderer Black Hiking Boots", + "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all " + "your outdoor adventures. These boots are made with a waterproof " + "leather upper and a durable rubber sole for superior traction. With " + "their cushioned insole and padded collar, these boots will keep you " + "comfortable all day long.", + "brand": "Daybird", + "price": 109.99, + "type": "Footwear", + } + } + assert response_data["context"]["thoughts"] == [ + { + "description": "What is the capital of France?", + "props": {"text_search": True, "top": 1, "vector_search": True}, + "title": "Search query for database", + }, + { + "description": [ + { + "brand": "Daybird", + "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all your " + "outdoor adventures. These boots are made with a waterproof leather upper and a durable " + "rubber sole for superior traction. With their cushioned insole and padded collar, " + "these boots will keep you comfortable all day long.", + "id": 1, + "name": "Wanderer Black Hiking Boots", + "price": 109.99, + "type": "Footwear", + }, + ], + "props": {}, + "title": "Search results", + }, + { + "description": [ + "{'role': 'system', 'content': \"Assistant helps customers with questions about " + "products.\\nRespond as if you are a salesperson helping a customer in a store. 
" + "Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the " + "products.\\nIf there isn't enough information below, say you don't know.\\nDo not " + "generate answers that don't use the sources below.\\nEach product has an ID in brackets " + "followed by colon and the product details.\\nAlways include the product ID for each product " + "you use in the response.\\nUse square brackets to reference the source, " + "for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", + "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer " + "Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for " + "all your outdoor adventures. These boots are made with a waterproof leather upper and a durable " + "rubber sole for superior traction. With their cushioned insole and padded collar, " + "these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird " + 'Type:Footwear\\n\\n"}', + ], + "props": {"model": "gpt-35-turbo"}, + "title": "Prompt to generate answer", + }, + ] + assert response_data["context"]["thoughts"] == [ + { + "description": "What is the capital of France?", + "props": {"text_search": True, "top": 1, "vector_search": True}, + "title": "Search query for database", + }, + { + "description": [ + { + "brand": "Daybird", + "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all " + "your outdoor adventures. These boots are made with a waterproof leather upper and " + "a durable rubber sole for superior traction. With their cushioned insole and padded " + "collar, these boots will keep you comfortable all day long.", + "id": 1, + "name": "Wanderer Black Hiking Boots", + "price": 109.99, + "type": "Footwear", + } + ], + "props": {}, + "title": "Search results", + }, + { + "description": [ + "{'role': 'system', 'content': \"Assistant helps customers with questions about " + "products.\\nRespond as if you are a salesperson helping a customer in a store. " + "Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the " + "products.\\nIf there isn't enough information below, say you don't know.\\nDo not " + "generate answers that don't use the sources below.\\nEach product has an ID in brackets " + "followed by colon and the product details.\\nAlways include the product ID for each product " + "you use in the response.\\nUse square brackets to reference the source, " + "for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", + "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer " + "Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for " + "all your outdoor adventures. These boots are made with a waterproof leather upper and a durable " + "rubber sole for superior traction. With their cushioned insole and padded collar, " + "these boots will keep you comfortable all day long. 
Price:109.99 Brand:Daybird " + 'Type:Footwear\\n\\n"}', + ], + "props": {"model": "gpt-35-turbo"}, + "title": "Prompt to generate answer", + }, + ] + assert response_data["session_state"] is None + + +@pytest.mark.asyncio +async def test_advanced_chat_flow(test_client): + """test the advanced chat flow route with hybrid retrieval mode""" + response = test_client.post( + "/chat", + json={ + "context": { + "overrides": {"top": 1, "use_advanced_flow": True, "retrieval_mode": "hybrid", "temperature": 0.3} + }, + "messages": [{"content": "What is the capital of France?", "role": "user"}], + }, + ) + response_data = response.json() + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "application/json" + assert response_data["message"]["content"] == "The capital of France is Paris. [Benefit_Options-2.pdf]." + assert response_data["message"]["role"] == "assistant" + assert response_data["context"]["data_points"] == { + "1": { + "id": 1, + "name": "Wanderer Black Hiking Boots", + "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all " + "your outdoor adventures. These boots are made with a waterproof " + "leather upper and a durable rubber sole for superior traction. With " + "their cushioned insole and padded collar, these boots will keep you " + "comfortable all day long.", + "brand": "Daybird", + "price": 109.99, + "type": "Footwear", + } + } + assert response_data["context"]["thoughts"] == [ + { + "description": [ + "{'role': 'system', 'content': 'Below is a history of the " + "conversation so far, and a new question asked by the user that " + "needs to be answered by searching database rows.\\nYou have " + "access to an Azure PostgreSQL database with an items table that " + "has columns for title, description, brand, price, and " + "type.\\nGenerate a search query based on the conversation and the " + "new question.\\nIf the question is not in English, translate the " + "question to English before generating the search query.\\nIf you " + "cannot generate a search query, return the original user " + "question.\\nDO NOT return anything besides the query.'}", + "{'role': 'user', 'content': 'What is the capital of France?'}", + ], + "props": { + "model": "gpt-35-turbo", + }, + "title": "Prompt to generate search arguments", + }, + { + "description": "The capital of France is Paris. [Benefit_Options-2.pdf].", + "props": {"filters": [], "text_search": True, "top": 1, "vector_search": True}, + "title": "Search using generated search arguments", + }, + { + "description": [ + { + "brand": "Daybird", + "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all your " + "outdoor adventures. These boots are made with a waterproof leather upper and a durable " + "rubber sole for superior traction. With their cushioned insole and padded collar, " + "these boots will keep you comfortable all day long.", + "id": 1, + "name": "Wanderer Black Hiking Boots", + "price": 109.99, + "type": "Footwear", + }, + ], + "props": {}, + "title": "Search results", + }, + { + "description": [ + "{'role': 'system', 'content': \"Assistant helps customers with questions about " + "products.\\nRespond as if you are a salesperson helping a customer in a store. 
" + "Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the " + "products.\\nIf there isn't enough information below, say you don't know.\\nDo not " + "generate answers that don't use the sources below.\\nEach product has an ID in brackets " + "followed by colon and the product details.\\nAlways include the product ID for each product " + "you use in the response.\\nUse square brackets to reference the source, " + "for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", + "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer " + "Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for " + "all your outdoor adventures. These boots are made with a waterproof leather upper and a durable " + "rubber sole for superior traction. With their cushioned insole and padded collar, " + "these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird " + 'Type:Footwear\\n\\n"}', + ], + "props": {"model": "gpt-35-turbo"}, + "title": "Prompt to generate answer", + }, + ] + assert response_data["context"]["thoughts"] == [ + { + "description": [ + "{'role': 'system', 'content': 'Below is a history of the " + "conversation so far, and a new question asked by the user that " + "needs to be answered by searching database rows.\\nYou have " + "access to an Azure PostgreSQL database with an items table that " + "has columns for title, description, brand, price, and " + "type.\\nGenerate a search query based on the conversation and the " + "new question.\\nIf the question is not in English, translate the " + "question to English before generating the search query.\\nIf you " + "cannot generate a search query, return the original user " + "question.\\nDO NOT return anything besides the query.'}", + "{'role': 'user', 'content': 'What is the capital of France?'}", + ], + "props": { + "model": "gpt-35-turbo", + }, + "title": "Prompt to generate search arguments", + }, + { + "description": "The capital of France is Paris. [Benefit_Options-2.pdf].", + "props": {"filters": [], "text_search": True, "top": 1, "vector_search": True}, + "title": "Search using generated search arguments", + }, + { + "description": [ + { + "brand": "Daybird", + "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all " + "your outdoor adventures. These boots are made with a waterproof leather upper and " + "a durable rubber sole for superior traction. With their cushioned insole and padded " + "collar, these boots will keep you comfortable all day long.", + "id": 1, + "name": "Wanderer Black Hiking Boots", + "price": 109.99, + "type": "Footwear", + } + ], + "props": {}, + "title": "Search results", + }, + { + "description": [ + "{'role': 'system', 'content': \"Assistant helps customers with questions about " + "products.\\nRespond as if you are a salesperson helping a customer in a store. 
" + "Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the " + "products.\\nIf there isn't enough information below, say you don't know.\\nDo not " + "generate answers that don't use the sources below.\\nEach product has an ID in brackets " + "followed by colon and the product details.\\nAlways include the product ID for each product " + "you use in the response.\\nUse square brackets to reference the source, " + "for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", + "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer " + "Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for " + "all your outdoor adventures. These boots are made with a waterproof leather upper and a durable " + "rubber sole for superior traction. With their cushioned insole and padded collar, " + "these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird " + 'Type:Footwear\\n\\n"}', + ], + "props": {"model": "gpt-35-turbo"}, + "title": "Prompt to generate answer", + }, + ] + assert response_data["session_state"] is None + + +@pytest.mark.asyncio +async def test_chat_non_json_422(test_client): """test the chat route with a non-json request""" response = test_client.post("/chat") From eac85f0306930048242855c0799854656698224f Mon Sep 17 00:00:00 2001 From: John Aziz Date: Sat, 6 Jul 2024 12:52:33 +0000 Subject: [PATCH 15/31] remove content length assertion to allow for flexibility in response lenght --- tests/test_api_routes.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_api_routes.py b/tests/test_api_routes.py index ee913c04..dab506cc 100644 --- a/tests/test_api_routes.py +++ b/tests/test_api_routes.py @@ -11,7 +11,6 @@ async def test_item_handler(test_client): assert response.status_code == 200 assert response.headers["Content-Type"] == "application/json" - assert response.headers["Content-Length"] == "405" assert response_data["id"] == test_data.id assert response_data["name"] == test_data.name assert response_data["description"] == test_data.description @@ -39,7 +38,6 @@ async def test_similar_handler(test_client): assert response.status_code == 200 assert response.headers["Content-Type"] == "application/json" - assert response.headers["Content-Length"] == "428" assert response.json() == [ { "id": 71, @@ -87,7 +85,6 @@ async def test_search_handler(test_client): assert response.status_code == 200 assert response.headers["Content-Type"] == "application/json" - assert response.headers["Content-Length"] == "407" assert response_data["id"] == test_data.id assert response_data["name"] == test_data.name assert response_data["description"] == test_data.description From b36a2609c022c8ca58ca6fb8ae05c287eede91fb Mon Sep 17 00:00:00 2001 From: John Aziz Date: Sat, 6 Jul 2024 12:53:10 +0000 Subject: [PATCH 16/31] add azure openai env vars and use session scoped fixutres --- tests/conftest.py | 53 ++++++++++++++++++++++++----------------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index f6839e9b..c4764d6e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -53,9 +53,16 @@ def mock_session_env(monkeypatch_session): monkeypatch_session.setenv("RUNNING_IN_PRODUCTION", "False") # Azure Subscription monkeypatch_session.setenv("AZURE_SUBSCRIPTION_ID", "test-storage-subid") - # OpenAI - monkeypatch_session.setenv("AZURE_OPENAI_CHATGPT_MODEL", "gpt-35-turbo") - 
monkeypatch_session.setenv("OPENAI_API_KEY", "fakekey") + # Azure OpenAI + monkeypatch_session.setenv("OPENAI_CHAT_HOST", "azure") + monkeypatch_session.setenv("OPENAI_EMBED_HOST", "azure") + monkeypatch_session.setenv("AZURE_OPENAI_VERSION", "2024-03-01-preview") + monkeypatch_session.setenv("AZURE_OPENAI_CHAT_DEPLOYMENT", "gpt-35-turbo") + monkeypatch_session.setenv("AZURE_OPENAI_CHAT_MODEL", "gpt-35-turbo") + monkeypatch_session.setenv("AZURE_OPENAI_EMBED_DEPLOYMENT", "text-embedding-ada-002") + monkeypatch_session.setenv("AZURE_OPENAI_EMBED_MODEL", "text-embedding-ada-002") + monkeypatch_session.setenv("AZURE_OPENAI_EMBED_MODEL_DIMENSIONS", "1536") + monkeypatch_session.setenv("AZURE_OPENAI_KEY", "fakekey") # Allowed Origin monkeypatch_session.setenv("ALLOWED_ORIGIN", "https://frontend.com") @@ -82,16 +89,8 @@ async def app(mock_session_env): return app -@pytest.fixture(scope="function") -def mock_default_azure_credential(mock_session_env): - """Mock the Azure credential for testing.""" - with mock.patch("azure.identity.DefaultAzureCredential") as mock_default_azure_credential: - mock_default_azure_credential.return_value = MockAzureCredential() - yield mock_default_azure_credential - - -@pytest.fixture(autouse=True) -def mock_openai_embedding(monkeypatch): +@pytest.fixture(scope="session") +def mock_openai_embedding(monkeypatch_session): async def mock_acreate(*args, **kwargs): return CreateEmbeddingResponse( object="list", @@ -106,14 +105,13 @@ async def mock_acreate(*args, **kwargs): usage=Usage(prompt_tokens=8, total_tokens=8), ) - def patch(): - monkeypatch.setattr(openai.resources.AsyncEmbeddings, "create", mock_acreate) + monkeypatch_session.setattr(openai.resources.AsyncEmbeddings, "create", mock_acreate) - return patch + yield -@pytest.fixture -def mock_openai_chatcompletion(monkeypatch): +@pytest.fixture(scope="session") +def mock_openai_chatcompletion(monkeypatch_session): class AsyncChatCompletionIterator: def __init__(self, answer: str): chunk_id = "test-id" @@ -215,19 +213,22 @@ async def mock_acreate(*args, **kwargs): model="test-model", ) - def patch(): - monkeypatch.setattr(openai.resources.chat.completions.AsyncCompletions, "create", mock_acreate) + monkeypatch_session.setattr(openai.resources.chat.completions.AsyncCompletions, "create", mock_acreate) - return patch + yield + + +@pytest.fixture(scope="function") +def mock_default_azure_credential(mock_session_env): + """Mock the Azure credential for testing.""" + with mock.patch("azure.identity.DefaultAzureCredential") as mock_default_azure_credential: + mock_default_azure_credential.return_value = MockAzureCredential() + yield @pytest_asyncio.fixture(scope="function") -async def test_client( - monkeypatch, app, mock_default_azure_credential, mock_openai_embedding, mock_openai_chatcompletion -): +async def test_client(app, mock_default_azure_credential, mock_openai_embedding, mock_openai_chatcompletion): """Create a test client.""" - mock_openai_embedding() - mock_openai_chatcompletion() with TestClient(app) as test_client: yield test_client From ea722d7a3364e25f886e4aed7e0edc62bff7e94a Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Thu, 11 Jul 2024 16:29:08 -0700 Subject: [PATCH 17/31] Typing improvements --- requirements-dev.txt | 1 + src/fastapi_app/openai_clients.py | 51 ++++++++++++++++++++----------- src/fastapi_app/rag_advanced.py | 4 ++- src/fastapi_app/rag_simple.py | 8 ++--- 4 files changed, 41 insertions(+), 23 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 4c00a59c..5b2e11a0 
100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,3 +6,4 @@ pip-compile-cross-platform pytest pytest-cov pytest-asyncio +mypy diff --git a/src/fastapi_app/openai_clients.py b/src/fastapi_app/openai_clients.py index 73d7267a..affee39c 100644 --- a/src/fastapi_app/openai_clients.py +++ b/src/fastapi_app/openai_clients.py @@ -8,24 +8,31 @@ async def create_openai_chat_client(azure_credential): + openai_chat_client: openai.AsyncAzureOpenAI | openai.AsyncOpenAI OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST") if OPENAI_CHAT_HOST == "azure": - client_args = {} + api_version = os.environ["AZURE_OPENAI_VERSION"] + azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"] + azure_deployment = os.environ["AZURE_OPENAI_EMBED_DEPLOYMENT"] if api_key := os.getenv("AZURE_OPENAI_KEY"): logger.info("Authenticating to Azure OpenAI using API key...") - client_args["api_key"] = api_key + openai_chat_client = openai.AsyncAzureOpenAI( + api_version=api_version, + azure_endpoint=azure_endpoint, + azure_deployment=azure_deployment, + api_key=api_key, + ) else: logger.info("Authenticating to Azure OpenAI using Azure Identity...") token_provider = azure.identity.get_bearer_token_provider( azure_credential, "https://cognitiveservices.azure.com/.default" ) - client_args["azure_ad_token_provider"] = token_provider - openai_chat_client = openai.AsyncAzureOpenAI( - api_version=os.getenv("AZURE_OPENAI_VERSION"), - azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), - azure_deployment=os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT"), - **client_args, - ) + openai_chat_client = openai.AsyncAzureOpenAI( + api_version=api_version, + azure_endpoint=azure_endpoint, + azure_deployment=azure_deployment, + azure_ad_token_provider=token_provider, + ) openai_chat_model = os.getenv("AZURE_OPENAI_CHAT_MODEL") elif OPENAI_CHAT_HOST == "ollama": logger.info("Authenticating to OpenAI using Ollama...") @@ -43,24 +50,32 @@ async def create_openai_chat_client(azure_credential): async def create_openai_embed_client(azure_credential): + openai_embed_client: openai.AsyncAzureOpenAI | openai.AsyncOpenAI OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST") if OPENAI_EMBED_HOST == "azure": - client_args = {} + api_version = os.environ["AZURE_OPENAI_VERSION"] + azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"] + azure_deployment = os.environ["AZURE_OPENAI_EMBED_DEPLOYMENT"] if api_key := os.getenv("AZURE_OPENAI_KEY"): logger.info("Authenticating to Azure OpenAI using API key...") - client_args["api_key"] = api_key + openai_embed_client = openai.AsyncAzureOpenAI( + api_version=api_version, + azure_endpoint=azure_endpoint, + azure_deployment=azure_deployment, + api_key=api_key, + ) else: logger.info("Authenticating to Azure OpenAI using Azure Identity...") token_provider = azure.identity.get_bearer_token_provider( azure_credential, "https://cognitiveservices.azure.com/.default" ) - client_args["azure_ad_token_provider"] = token_provider - openai_embed_client = openai.AsyncAzureOpenAI( - api_version=os.getenv("AZURE_OPENAI_VERSION"), - azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), - azure_deployment=os.getenv("AZURE_OPENAI_EMBED_DEPLOYMENT"), - **client_args, - ) + openai_embed_client = openai.AsyncAzureOpenAI( + api_version=api_version, + azure_endpoint=azure_endpoint, + azure_deployment=azure_deployment, + azure_ad_token_provider=token_provider, + ) + openai_embed_model = os.getenv("AZURE_OPENAI_EMBED_MODEL") openai_embed_dimensions = os.getenv("AZURE_OPENAI_EMBED_DIMENSIONS") else: diff --git a/src/fastapi_app/rag_advanced.py 
b/src/fastapi_app/rag_advanced.py index d5cef924..a5220ba5 100644 --- a/src/fastapi_app/rag_advanced.py +++ b/src/fastapi_app/rag_advanced.py @@ -32,13 +32,15 @@ def __init__( self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read() async def run( - self, messages: list[dict], overrides: dict[str, Any] = {} + self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {} ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]: text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] top = overrides.get("top", 3) original_user_query = messages[-1]["content"] + if not isinstance(original_user_query, str): + raise ValueError("The most recent message content must be a string.") past_messages = messages[:-1] # Generate an optimized keyword search query based on the chat history and the last question diff --git a/src/fastapi_app/rag_simple.py b/src/fastapi_app/rag_simple.py index 271dae86..9a4beb96 100644 --- a/src/fastapi_app/rag_simple.py +++ b/src/fastapi_app/rag_simple.py @@ -1,8 +1,6 @@ import pathlib from collections.abc import AsyncGenerator -from typing import ( - Any, -) +from typing import Any from openai import AsyncOpenAI from openai.types.chat import ChatCompletion, ChatCompletionMessageParam @@ -30,13 +28,15 @@ def __init__( self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read() async def run( - self, messages: list[dict], overrides: dict[str, Any] = {} + self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {} ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]: text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] top = overrides.get("top", 3) original_user_query = messages[-1]["content"] + if not isinstance(original_user_query, str): + raise ValueError("The most recent message content must be a string.") past_messages = messages[:-1] # Retrieve relevant items from the database From addd3dafe6fadc4c299884f5930bf14c8d70b475 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Fri, 12 Jul 2024 00:27:27 +0000 Subject: [PATCH 18/31] fix mypy module error --- src/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/__init__.py diff --git a/src/__init__.py b/src/__init__.py deleted file mode 100644 index e69de29b..00000000 From 5261ab213490bd04698b105885f5e2a4042602f7 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Fri, 12 Jul 2024 01:02:36 +0000 Subject: [PATCH 19/31] typing improvements --- src/fastapi_app/api_routes.py | 10 +++++----- src/fastapi_app/setup_postgres_seeddata.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/fastapi_app/api_routes.py b/src/fastapi_app/api_routes.py index 93fd7ee8..e80d051f 100644 --- a/src/fastapi_app/api_routes.py +++ b/src/fastapi_app/api_routes.py @@ -72,19 +72,19 @@ async def chat_handler(chat_request: ChatRequest): embed_dimensions=global_storage.openai_embed_dimensions, ) if overrides.get("use_advanced_flow"): - ragchat = AdvancedRAGChat( + run_ragchat = AdvancedRAGChat( searcher=searcher, openai_chat_client=global_storage.openai_chat_client, chat_model=global_storage.openai_chat_model, chat_deployment=global_storage.openai_chat_deployment, - ) + ).run else: - ragchat = SimpleRAGChat( + run_ragchat = SimpleRAGChat( searcher=searcher, openai_chat_client=global_storage.openai_chat_client, 
chat_model=global_storage.openai_chat_model, chat_deployment=global_storage.openai_chat_deployment, - ) + ).run - response: RetrievalResponse = await ragchat.run(messages, overrides=overrides) + response = await run_ragchat(messages, overrides=overrides) return response diff --git a/src/fastapi_app/setup_postgres_seeddata.py b/src/fastapi_app/setup_postgres_seeddata.py index e8719c90..e7a80699 100644 --- a/src/fastapi_app/setup_postgres_seeddata.py +++ b/src/fastapi_app/setup_postgres_seeddata.py @@ -36,8 +36,8 @@ async def seed_data(engine): with open(os.path.join(current_dir, "seed_data.json")) as f: catalog_items = json.load(f) for catalog_item in catalog_items: - item = await session.execute(select(Item).filter(Item.id == catalog_item["Id"])) - if item.scalars().first(): + db_item = await session.execute(select(Item).filter(Item.id == catalog_item["Id"])) + if db_item.scalars().first(): continue item = Item( id=catalog_item["Id"], From 39abb0fa8c7fede9987843eeacd9a1870bfe15c0 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Fri, 12 Jul 2024 01:02:55 +0000 Subject: [PATCH 20/31] ignore pgvector from mypy checks --- pyproject.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a0f8ab4f..5905aa19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,5 +16,11 @@ testpaths = ["tests"] pythonpath = ['src'] filterwarnings = ["ignore::DeprecationWarning"] +[[tool.mypy.overrides]] +module = [ + "pgvector.*", +] +ignore_missing_imports = true + [tool.coverage.report] show_missing = true From 121b3417c509ff41f6aa893bf3bf348e40d5d4f9 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Fri, 12 Jul 2024 01:03:32 +0000 Subject: [PATCH 21/31] follow sqlalchmey example for using columns() --- src/fastapi_app/postgres_searcher.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/fastapi_app/postgres_searcher.py b/src/fastapi_app/postgres_searcher.py index ab3197d2..8d16da27 100644 --- a/src/fastapi_app/postgres_searcher.py +++ b/src/fastapi_app/postgres_searcher.py @@ -1,6 +1,6 @@ from openai import AsyncOpenAI from pgvector.utils import to_db -from sqlalchemy import Float, Integer, select, text +from sqlalchemy import Float, Integer, column, select, text from sqlalchemy.ext.asyncio import async_sessionmaker from fastapi_app.embeddings import compute_text_embedding @@ -78,11 +78,11 @@ async def search( """ if query_text is not None and len(query_vector) > 0: - sql = text(hybrid_query).columns(id=Integer, score=Float) + sql = text(hybrid_query).columns(column("id", Integer), column("score", Float)) elif len(query_vector) > 0: - sql = text(vector_query).columns(id=Integer, rank=Integer) + sql = text(vector_query).columns(column("id", Integer), column("rank", Integer)) elif query_text is not None: - sql = text(fulltext_query).columns(id=Integer, rank=Integer) + sql = text(fulltext_query).columns(column("id", Integer), column("rank", Integer)) else: raise ValueError("Both query text and query vector are empty") @@ -113,7 +113,7 @@ async def search_and_embed( Search items by query text. Optionally converts the query text to a vector if enable_vector_search is True. 
""" vector: list[float] = [] - if enable_vector_search: + if enable_vector_search and query_text is not None: vector = await compute_text_embedding( query_text, self.openai_embed_client, From d0a9a02b1983d0d3b5b946c43372522f16d82e43 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Fri, 12 Jul 2024 01:03:52 +0000 Subject: [PATCH 22/31] reimplement abstract functions --- tests/mocks.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/mocks.py b/tests/mocks.py index 89053918..eb96e26c 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -1,4 +1,5 @@ from collections import namedtuple +from types import TracebackType from azure.core.credentials_async import AsyncTokenCredential @@ -9,6 +10,17 @@ class MockAzureCredential(AsyncTokenCredential): async def get_token(self, uri): return MockToken("", 9999999999, "") + async def close(self) -> None: + pass + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + pass + class MockAzureCredentialExpired(AsyncTokenCredential): def __init__(self): From c5d99f513e3344e091eed7ce2fa65571b1c7d185 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Fri, 12 Jul 2024 10:21:54 +0000 Subject: [PATCH 23/31] fix typo in env var name --- src/fastapi_app/openai_clients.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastapi_app/openai_clients.py b/src/fastapi_app/openai_clients.py index affee39c..b69856c8 100644 --- a/src/fastapi_app/openai_clients.py +++ b/src/fastapi_app/openai_clients.py @@ -13,7 +13,7 @@ async def create_openai_chat_client(azure_credential): if OPENAI_CHAT_HOST == "azure": api_version = os.environ["AZURE_OPENAI_VERSION"] azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"] - azure_deployment = os.environ["AZURE_OPENAI_EMBED_DEPLOYMENT"] + azure_deployment = os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"] if api_key := os.getenv("AZURE_OPENAI_KEY"): logger.info("Authenticating to Azure OpenAI using API key...") openai_chat_client = openai.AsyncAzureOpenAI( From 5c17e9a404e772a40d24562c34aeb52f97441d13 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Fri, 12 Jul 2024 13:06:36 +0000 Subject: [PATCH 24/31] use fastapi dependency instead of global storage --- src/fastapi_app/__init__.py | 45 ++----- src/fastapi_app/api_routes.py | 90 ------------- src/fastapi_app/dependencies.py | 123 ++++++++++++++++++ src/fastapi_app/embeddings.py | 8 +- src/fastapi_app/globals.py | 13 -- src/fastapi_app/openai_clients.py | 19 ++- src/fastapi_app/postgres_searcher.py | 35 +++-- src/fastapi_app/rag_advanced.py | 4 +- src/fastapi_app/rag_simple.py | 4 +- src/fastapi_app/routes/api_routes.py | 99 ++++++++++++++ .../{ => routes}/frontend_routes.py | 2 +- src/fastapi_app/update_embeddings.py | 14 +- tests/test_api_routes.py | 16 +-- 13 files changed, 280 insertions(+), 192 deletions(-) delete mode 100644 src/fastapi_app/api_routes.py create mode 100644 src/fastapi_app/dependencies.py delete mode 100644 src/fastapi_app/globals.py create mode 100644 src/fastapi_app/routes/api_routes.py rename src/fastapi_app/{ => routes}/frontend_routes.py (90%) diff --git a/src/fastapi_app/__init__.py b/src/fastapi_app/__init__.py index b915ac3e..a11d2913 100644 --- a/src/fastapi_app/__init__.py +++ b/src/fastapi_app/__init__.py @@ -2,14 +2,12 @@ import logging import os -import azure.identity from dotenv import load_dotenv from environs import Env from fastapi import FastAPI -from .globals import global_storage -from .openai_clients 
import create_openai_chat_client, create_openai_embed_client -from .postgres_engine import create_postgres_engine_from_env +from fastapi_app.dependencies import get_azure_credentials +from fastapi_app.postgres_engine import create_postgres_engine_from_env logger = logging.getLogger("ragapp") @@ -18,34 +16,8 @@ async def lifespan(app: FastAPI): load_dotenv(override=True) - azure_credential: azure.identity.DefaultAzureCredential | azure.identity.ManagedIdentityCredential | None = None - try: - if client_id := os.getenv("APP_IDENTITY_ID"): - # Authenticate using a user-assigned managed identity on Azure - # See web.bicep for value of APP_IDENTITY_ID - logger.info( - "Using managed identity for client ID %s", - client_id, - ) - azure_credential = azure.identity.ManagedIdentityCredential(client_id=client_id) - else: - azure_credential = azure.identity.DefaultAzureCredential() - except Exception as e: - logger.warning("Failed to authenticate to Azure: %s", e) - + azure_credential = await get_azure_credentials() engine = await create_postgres_engine_from_env(azure_credential) - global_storage.engine = engine - - openai_chat_client, openai_chat_model = await create_openai_chat_client(azure_credential) - global_storage.openai_chat_client = openai_chat_client - global_storage.openai_chat_model = openai_chat_model - - openai_embed_client, openai_embed_model, openai_embed_dimensions = await create_openai_embed_client( - azure_credential - ) - global_storage.openai_embed_client = openai_embed_client - global_storage.openai_embed_model = openai_embed_model - global_storage.openai_embed_dimensions = openai_embed_dimensions yield @@ -55,17 +27,16 @@ async def lifespan(app: FastAPI): def create_app(testing: bool = False): env = Env() - if not os.getenv("RUNNING_IN_PRODUCTION"): + if os.getenv("RUNNING_IN_PRODUCTION"): + logging.basicConfig(level=logging.WARNING) + else: if not testing: - env.read_env(".env") + env.read_env(".env", override=True) logging.basicConfig(level=logging.INFO) - else: - logging.basicConfig(level=logging.WARNING) app = FastAPI(docs_url="/docs", lifespan=lifespan) - from . import api_routes # noqa - from . 
import frontend_routes # noqa + from fastapi_app.routes import api_routes, frontend_routes app.include_router(api_routes.router) app.mount("/", frontend_routes.router) diff --git a/src/fastapi_app/api_routes.py b/src/fastapi_app/api_routes.py deleted file mode 100644 index e80d051f..00000000 --- a/src/fastapi_app/api_routes.py +++ /dev/null @@ -1,90 +0,0 @@ -import fastapi -from fastapi import HTTPException -from sqlalchemy import select -from sqlalchemy.ext.asyncio import async_sessionmaker - -from fastapi_app.api_models import ChatRequest -from fastapi_app.globals import global_storage -from fastapi_app.postgres_models import Item -from fastapi_app.postgres_searcher import PostgresSearcher -from fastapi_app.rag_advanced import AdvancedRAGChat -from fastapi_app.rag_simple import SimpleRAGChat - -from .api_models import RetrievalResponse - -router = fastapi.APIRouter() - - -@router.get("/items/{id}") -async def item_handler(id: int): - """A simple API to get an item by ID.""" - async_session_maker = async_sessionmaker(global_storage.engine, expire_on_commit=False) - async with async_session_maker() as session: - item = (await session.scalars(select(Item).where(Item.id == id))).first() - if not item: - raise HTTPException(detail=f"Item with ID {id} not found.", status_code=404) - return item.to_dict() - - -@router.get("/similar") -async def similar_handler(id: int, n: int = 5): - """A similarity API to find items similar to items with given ID.""" - async_session_maker = async_sessionmaker(global_storage.engine, expire_on_commit=False) - async with async_session_maker() as session: - item = (await session.scalars(select(Item).where(Item.id == id))).first() - if not item: - raise HTTPException(detail=f"Item with ID {id} not found.", status_code=404) - closest = await session.execute( - select(Item, Item.embedding.l2_distance(item.embedding)) - .filter(Item.id != id) - .order_by(Item.embedding.l2_distance(item.embedding)) - .limit(n) - ) - return [item.to_dict() | {"distance": round(distance, 2)} for item, distance in closest] - - -@router.get("/search") -async def search_handler(query: str, top: int = 5, enable_vector_search: bool = True, enable_text_search: bool = True): - """A search API to find items based on a query.""" - searcher = PostgresSearcher( - global_storage.engine, - openai_embed_client=global_storage.openai_embed_client, - embed_deployment=global_storage.openai_embed_deployment, - embed_model=global_storage.openai_embed_model, - embed_dimensions=global_storage.openai_embed_dimensions, - ) - results = await searcher.search_and_embed( - query, top=top, enable_vector_search=enable_vector_search, enable_text_search=enable_text_search - ) - return [item.to_dict() for item in results] - - -@router.post("/chat", response_model=RetrievalResponse) -async def chat_handler(chat_request: ChatRequest): - messages = [message.model_dump() for message in chat_request.messages] - overrides = chat_request.context.get("overrides", {}) - - searcher = PostgresSearcher( - global_storage.engine, - openai_embed_client=global_storage.openai_embed_client, - embed_deployment=global_storage.openai_embed_deployment, - embed_model=global_storage.openai_embed_model, - embed_dimensions=global_storage.openai_embed_dimensions, - ) - if overrides.get("use_advanced_flow"): - run_ragchat = AdvancedRAGChat( - searcher=searcher, - openai_chat_client=global_storage.openai_chat_client, - chat_model=global_storage.openai_chat_model, - chat_deployment=global_storage.openai_chat_deployment, - ).run - else: - run_ragchat 
= SimpleRAGChat( - searcher=searcher, - openai_chat_client=global_storage.openai_chat_client, - chat_model=global_storage.openai_chat_model, - chat_deployment=global_storage.openai_chat_deployment, - ).run - - response = await run_ragchat(messages, overrides=overrides) - return response diff --git a/src/fastapi_app/dependencies.py b/src/fastapi_app/dependencies.py new file mode 100644 index 00000000..d3450ec8 --- /dev/null +++ b/src/fastapi_app/dependencies.py @@ -0,0 +1,123 @@ +import logging +import os +from typing import Annotated + +import azure.identity +from dotenv import load_dotenv +from fastapi import Depends +from openai import AsyncAzureOpenAI, AsyncOpenAI +from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker + +from fastapi_app.openai_clients import create_openai_chat_client, create_openai_embed_client +from fastapi_app.postgres_engine import create_postgres_engine_from_env + +logger = logging.getLogger("ragapp") + + +class OpenAIClient(BaseModel): + """ + OpenAI client + """ + + client: AsyncOpenAI | AsyncAzureOpenAI + model_config = {"arbitrary_types_allowed": True} + + +class FastAPIAppContext(BaseModel): + """ + Context for the FastAPI app + """ + + openai_chat_model: str + openai_embed_model: str + openai_embed_dimensions: int + openai_chat_deployment: str + openai_embed_deployment: str + + +async def common_parameters(): + """ + Get the common parameters for the FastAPI app + """ + load_dotenv(override=True) + OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST") + OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST") + if OPENAI_EMBED_HOST == "azure": + openai_embed_deployment = os.getenv("AZURE_OPENAI_EMBED_DEPLOYMENT", "text-embedding-ada-002") + openai_embed_model = os.getenv("AZURE_OPENAI_EMBED_MODEL", "text-embedding-ada-002") + openai_embed_dimensions = int(os.getenv("AZURE_OPENAI_EMBED_DIMENSIONS", 1536)) + else: + openai_embed_deployment = "text-embedding-ada-002" + openai_embed_model = os.getenv("OPENAICOM_EMBED_MODEL", "text-embedding-ada-002") + openai_embed_dimensions = int(os.getenv("OPENAICOM_EMBED_DIMENSIONS", 1536)) + if OPENAI_CHAT_HOST == "azure": + openai_chat_deployment = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT", "gpt-35-turbo") + openai_chat_model = os.getenv("AZURE_OPENAI_CHAT_MODEL", "gpt-35-turbo") + elif OPENAI_CHAT_HOST == "ollama": + openai_chat_deployment = "phi3:3.8b" + openai_chat_model = os.getenv("OLLAMA_CHAT_MODEL", "phi3:3.8b") + else: + openai_chat_deployment = "gpt-3.5-turbo" + openai_chat_model = os.getenv("OPENAICOM_CHAT_MODEL", "gpt-3.5-turbo") + return FastAPIAppContext( + openai_chat_model=openai_chat_model, + openai_embed_model=openai_embed_model, + openai_embed_dimensions=openai_embed_dimensions, + openai_chat_deployment=openai_chat_deployment, + openai_embed_deployment=openai_embed_deployment, + ) + + +async def get_azure_credentials() -> azure.identity.DefaultAzureCredential | azure.identity.ManagedIdentityCredential: + azure_credential: azure.identity.DefaultAzureCredential | azure.identity.ManagedIdentityCredential + try: + if client_id := os.getenv("APP_IDENTITY_ID"): + # Authenticate using a user-assigned managed identity on Azure + # See web.bicep for value of APP_IDENTITY_ID + logger.info( + "Using managed identity for client ID %s", + client_id, + ) + azure_credential = azure.identity.ManagedIdentityCredential(client_id=client_id) + else: + azure_credential = azure.identity.DefaultAzureCredential() + return azure_credential + except Exception as e: + 
logger.warning("Failed to authenticate to Azure: %s", e) + raise e + + +async def get_engine(): + """Get the agent database engine""" + load_dotenv(override=True) + azure_credentials = await get_azure_credentials() + engine = await create_postgres_engine_from_env(azure_credentials) + return engine + + +async def get_async_session(engine: Annotated[AsyncEngine, Depends(get_engine)]): + """Get the agent database""" + async_session_maker = async_sessionmaker(engine, expire_on_commit=False) + async with async_session_maker() as async_session: + yield async_session + + +async def get_openai_chat_client(): + """Get the OpenAI chat client""" + azure_credentials = await get_azure_credentials() + chat_client = await create_openai_chat_client(azure_credentials) + return OpenAIClient(client=chat_client) + + +async def get_openai_embed_client(): + """Get the OpenAI embed client""" + azure_credentials = await get_azure_credentials() + embed_client = await create_openai_embed_client(azure_credentials) + return OpenAIClient(client=embed_client) + + +CommonDeps = Annotated[FastAPIAppContext, Depends(common_parameters)] +DBSession = Annotated[AsyncSession, Depends(get_async_session)] +ChatClient = Annotated[OpenAIClient, Depends(get_openai_chat_client)] +EmbeddingsClient = Annotated[OpenAIClient, Depends(get_openai_embed_client)] diff --git a/src/fastapi_app/embeddings.py b/src/fastapi_app/embeddings.py index 769db29c..9d55812c 100644 --- a/src/fastapi_app/embeddings.py +++ b/src/fastapi_app/embeddings.py @@ -2,9 +2,15 @@ TypedDict, ) +from openai import AsyncAzureOpenAI, AsyncOpenAI + async def compute_text_embedding( - q: str, openai_client, embed_model: str, embed_deployment: str | None = None, embedding_dimensions: int = 1536 + q: str, + openai_client: AsyncOpenAI | AsyncAzureOpenAI, + embed_model: str, + embed_deployment: str | None = None, + embedding_dimensions: int = 1536, ): SUPPORTED_DIMENSIONS_MODEL = { "text-embedding-ada-002": False, diff --git a/src/fastapi_app/globals.py b/src/fastapi_app/globals.py deleted file mode 100644 index 09523c33..00000000 --- a/src/fastapi_app/globals.py +++ /dev/null @@ -1,13 +0,0 @@ -class Global: - def __init__(self): - self.engine = None - self.openai_chat_client = None - self.openai_embed_client = None - self.openai_chat_model = None - self.openai_embed_model = None - self.openai_embed_dimensions = None - self.openai_chat_deployment = None - self.openai_embed_deployment = None - - -global_storage = Global() diff --git a/src/fastapi_app/openai_clients.py b/src/fastapi_app/openai_clients.py index b69856c8..eaf4b690 100644 --- a/src/fastapi_app/openai_clients.py +++ b/src/fastapi_app/openai_clients.py @@ -7,7 +7,9 @@ logger = logging.getLogger("ragapp") -async def create_openai_chat_client(azure_credential): +async def create_openai_chat_client( + azure_credential: azure.identity.DefaultAzureCredential | azure.identity.ManagedIdentityCredential, +) -> openai.AsyncAzureOpenAI | openai.AsyncOpenAI: openai_chat_client: openai.AsyncAzureOpenAI | openai.AsyncOpenAI OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST") if OPENAI_CHAT_HOST == "azure": @@ -33,23 +35,22 @@ async def create_openai_chat_client(azure_credential): azure_deployment=azure_deployment, azure_ad_token_provider=token_provider, ) - openai_chat_model = os.getenv("AZURE_OPENAI_CHAT_MODEL") elif OPENAI_CHAT_HOST == "ollama": logger.info("Authenticating to OpenAI using Ollama...") openai_chat_client = openai.AsyncOpenAI( base_url=os.getenv("OLLAMA_ENDPOINT"), api_key="nokeyneeded", ) - openai_chat_model = 
os.getenv("OLLAMA_CHAT_MODEL") else: logger.info("Authenticating to OpenAI using OpenAI.com API key...") openai_chat_client = openai.AsyncOpenAI(api_key=os.getenv("OPENAICOM_KEY")) - openai_chat_model = os.getenv("OPENAICOM_CHAT_MODEL") - return openai_chat_client, openai_chat_model + return openai_chat_client -async def create_openai_embed_client(azure_credential): +async def create_openai_embed_client( + azure_credential: azure.identity.DefaultAzureCredential | azure.identity.ManagedIdentityCredential, +) -> openai.AsyncAzureOpenAI | openai.AsyncOpenAI: openai_embed_client: openai.AsyncAzureOpenAI | openai.AsyncOpenAI OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST") if OPENAI_EMBED_HOST == "azure": @@ -76,10 +77,6 @@ async def create_openai_embed_client(azure_credential): azure_ad_token_provider=token_provider, ) - openai_embed_model = os.getenv("AZURE_OPENAI_EMBED_MODEL") - openai_embed_dimensions = os.getenv("AZURE_OPENAI_EMBED_DIMENSIONS") else: openai_embed_client = openai.AsyncOpenAI(api_key=os.getenv("OPENAICOM_KEY")) - openai_embed_model = os.getenv("OPENAICOM_EMBED_MODEL") - openai_embed_dimensions = os.getenv("OPENAICOM_EMBED_DIMENSIONS") - return openai_embed_client, openai_embed_model, openai_embed_dimensions + return openai_embed_client diff --git a/src/fastapi_app/postgres_searcher.py b/src/fastapi_app/postgres_searcher.py index 8d16da27..81c01ce3 100644 --- a/src/fastapi_app/postgres_searcher.py +++ b/src/fastapi_app/postgres_searcher.py @@ -1,7 +1,7 @@ -from openai import AsyncOpenAI +from openai import AsyncAzureOpenAI, AsyncOpenAI from pgvector.utils import to_db from sqlalchemy import Float, Integer, column, select, text -from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncSession from fastapi_app.embeddings import compute_text_embedding from fastapi_app.postgres_models import Item @@ -10,13 +10,13 @@ class PostgresSearcher: def __init__( self, - engine, - openai_embed_client: AsyncOpenAI, + db_session: AsyncSession, + openai_embed_client: AsyncOpenAI | AsyncAzureOpenAI, embed_deployment: str | None, # Not needed for non-Azure OpenAI or for retrieval_mode="text" embed_model: str, embed_dimensions: int, ): - self.async_session_maker = async_sessionmaker(engine, expire_on_commit=False) + self.db_session = db_session self.openai_embed_client = openai_embed_client self.embed_model = embed_model self.embed_deployment = embed_deployment @@ -86,20 +86,19 @@ async def search( else: raise ValueError("Both query text and query vector are empty") - async with self.async_session_maker() as session: - results = ( - await session.execute( - sql, - {"embedding": to_db(query_vector), "query": query_text, "k": 60}, - ) - ).fetchall() + results = ( + await self.db_session.execute( + sql, + {"embedding": to_db(query_vector), "query": query_text, "k": 60}, + ) + ).fetchall() - # Convert results to Item models - items = [] - for id, _ in results[:top]: - item = await session.execute(select(Item).where(Item.id == id)) - items.append(item.scalar()) - return items + # Convert results to Item models + items = [] + for id, _ in results[:top]: + item = await self.db_session.execute(select(Item).where(Item.id == id)) + items.append(item.scalar()) + return items async def search_and_embed( self, diff --git a/src/fastapi_app/rag_advanced.py b/src/fastapi_app/rag_advanced.py index a5220ba5..024a5fbd 100644 --- a/src/fastapi_app/rag_advanced.py +++ b/src/fastapi_app/rag_advanced.py @@ -4,7 +4,7 @@ Any, ) -from openai import AsyncOpenAI +from openai import 
AsyncAzureOpenAI, AsyncOpenAI from openai.types.chat import ChatCompletion, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit @@ -18,7 +18,7 @@ def __init__( self, *, searcher: PostgresSearcher, - openai_chat_client: AsyncOpenAI, + openai_chat_client: AsyncOpenAI | AsyncAzureOpenAI, chat_model: str, chat_deployment: str | None, # Not needed for non-Azure OpenAI ): diff --git a/src/fastapi_app/rag_simple.py b/src/fastapi_app/rag_simple.py index 9a4beb96..f8db974e 100644 --- a/src/fastapi_app/rag_simple.py +++ b/src/fastapi_app/rag_simple.py @@ -2,7 +2,7 @@ from collections.abc import AsyncGenerator from typing import Any -from openai import AsyncOpenAI +from openai import AsyncAzureOpenAI, AsyncOpenAI from openai.types.chat import ChatCompletion, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit @@ -15,7 +15,7 @@ def __init__( self, *, searcher: PostgresSearcher, - openai_chat_client: AsyncOpenAI, + openai_chat_client: AsyncOpenAI | AsyncAzureOpenAI, chat_model: str, chat_deployment: str | None, # Not needed for non-Azure OpenAI ): diff --git a/src/fastapi_app/routes/api_routes.py b/src/fastapi_app/routes/api_routes.py new file mode 100644 index 00000000..44636554 --- /dev/null +++ b/src/fastapi_app/routes/api_routes.py @@ -0,0 +1,99 @@ +from typing import Any + +import fastapi +from fastapi import HTTPException +from sqlalchemy import select + +from fastapi_app.api_models import ChatRequest, RetrievalResponse +from fastapi_app.dependencies import ChatClient, CommonDeps, DBSession, EmbeddingsClient +from fastapi_app.postgres_models import Item +from fastapi_app.postgres_searcher import PostgresSearcher +from fastapi_app.rag_advanced import AdvancedRAGChat +from fastapi_app.rag_simple import SimpleRAGChat + +router = fastapi.APIRouter() + + +@router.get("/items/{id}", response_model=dict[str, Any]) +async def item_handler(id: int, database_session: DBSession) -> dict[str, Any]: + """A simple API to get an item by ID.""" + item = (await database_session.scalars(select(Item).where(Item.id == id))).first() + if not item: + raise HTTPException(detail=f"Item with ID {id} not found.", status_code=404) + return item.to_dict() + + +@router.get("/similar", response_model=list[dict[str, Any]]) +async def similar_handler(database_session: DBSession, id: int, n: int = 5) -> list[dict[str, Any]]: + """A similarity API to find items similar to items with given ID.""" + item = (await database_session.scalars(select(Item).where(Item.id == id))).first() + if not item: + raise HTTPException(detail=f"Item with ID {id} not found.", status_code=404) + closest = await database_session.execute( + select(Item, Item.embedding.l2_distance(item.embedding)) + .filter(Item.id != id) + .order_by(Item.embedding.l2_distance(item.embedding)) + .limit(n) + ) + return [item.to_dict() | {"distance": round(distance, 2)} for item, distance in closest] + + +@router.get("/search", response_model=list[dict[str, Any]]) +async def search_handler( + context: CommonDeps, + database_session: DBSession, + openai_embed: EmbeddingsClient, + query: str, + top: int = 5, + enable_vector_search: bool = True, + enable_text_search: bool = True, +) -> list[dict[str, Any]]: + """A search API to find items based on a query.""" + searcher = PostgresSearcher( + db_session=database_session, + openai_embed_client=openai_embed.client, + embed_deployment=context.openai_embed_deployment, + embed_model=context.openai_embed_model, + 
embed_dimensions=context.openai_embed_dimensions, + ) + results = await searcher.search_and_embed( + query, top=top, enable_vector_search=enable_vector_search, enable_text_search=enable_text_search + ) + return [item.to_dict() for item in results] + + +@router.post("/chat", response_model=RetrievalResponse) +async def chat_handler( + context: CommonDeps, + database_session: DBSession, + openai_embed: EmbeddingsClient, + openai_chat: ChatClient, + chat_request: ChatRequest, +): + messages = [message.model_dump() for message in chat_request.messages] + overrides = chat_request.context.get("overrides", {}) + + searcher = PostgresSearcher( + db_session=database_session, + openai_embed_client=openai_embed.client, + embed_deployment=context.openai_embed_deployment, + embed_model=context.openai_embed_model, + embed_dimensions=context.openai_embed_dimensions, + ) + if overrides.get("use_advanced_flow"): + run_ragchat = AdvancedRAGChat( + searcher=searcher, + openai_chat_client=openai_chat.client, + chat_model=context.openai_chat_model, + chat_deployment=context.openai_chat_deployment, + ).run + else: + run_ragchat = SimpleRAGChat( + searcher=searcher, + openai_chat_client=openai_chat.client, + chat_model=context.openai_chat_model, + chat_deployment=context.openai_chat_deployment, + ).run + + response = await run_ragchat(messages, overrides=overrides) + return response diff --git a/src/fastapi_app/frontend_routes.py b/src/fastapi_app/routes/frontend_routes.py similarity index 90% rename from src/fastapi_app/frontend_routes.py rename to src/fastapi_app/routes/frontend_routes.py index c6eccfb2..0dedfc85 100644 --- a/src/fastapi_app/frontend_routes.py +++ b/src/fastapi_app/routes/frontend_routes.py @@ -4,7 +4,7 @@ from fastapi.staticfiles import StaticFiles from starlette.routing import Mount, Route, Router -parent_dir = Path(__file__).resolve().parent.parent +parent_dir = Path(__file__).resolve().parent.parent.parent async def index(request) -> FileResponse: diff --git a/src/fastapi_app/update_embeddings.py b/src/fastapi_app/update_embeddings.py index 4a8d80d7..28149403 100644 --- a/src/fastapi_app/update_embeddings.py +++ b/src/fastapi_app/update_embeddings.py @@ -4,15 +4,15 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import async_sessionmaker +from fastapi_app.dependencies import common_parameters, get_engine, get_openai_embed_client from fastapi_app.embeddings import compute_text_embedding -from fastapi_app.openai_clients import create_openai_embed_client -from fastapi_app.postgres_engine import create_postgres_engine from fastapi_app.postgres_models import Item async def update_embeddings(): - engine = await create_postgres_engine() - openai_embed_client, openai_embed_model, openai_embed_dimensions = await create_openai_embed_client() + engine = await get_engine() + openai_embed = await get_openai_embed_client() + common_params = await common_parameters() async with async_sessionmaker(engine, expire_on_commit=False)() as session: async with session.begin(): @@ -21,9 +21,9 @@ async def update_embeddings(): for item in items: item.embedding = await compute_text_embedding( item.to_str_for_embedding(), - openai_client=openai_embed_client, - embed_model=openai_embed_model, - embedding_dimensions=openai_embed_dimensions, + openai_client=openai_embed.client, + embed_model=common_params.openai_embed_model, + embedding_dimensions=common_params.openai_embed_dimensions, ) await session.commit() diff --git a/tests/test_api_routes.py b/tests/test_api_routes.py index dab506cc..98e3ceeb 100644 
--- a/tests/test_api_routes.py +++ b/tests/test_api_routes.py @@ -176,7 +176,7 @@ async def test_simple_chat_flow(test_client): "these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird " 'Type:Footwear\\n\\n"}', ], - "props": {"model": "gpt-35-turbo"}, + "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, "title": "Prompt to generate answer", }, ] @@ -220,7 +220,7 @@ async def test_simple_chat_flow(test_client): "these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird " 'Type:Footwear\\n\\n"}', ], - "props": {"model": "gpt-35-turbo"}, + "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, "title": "Prompt to generate answer", }, ] @@ -274,9 +274,7 @@ async def test_advanced_chat_flow(test_client): "question.\\nDO NOT return anything besides the query.'}", "{'role': 'user', 'content': 'What is the capital of France?'}", ], - "props": { - "model": "gpt-35-turbo", - }, + "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, "title": "Prompt to generate search arguments", }, { @@ -318,7 +316,7 @@ async def test_advanced_chat_flow(test_client): "these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird " 'Type:Footwear\\n\\n"}', ], - "props": {"model": "gpt-35-turbo"}, + "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, "title": "Prompt to generate answer", }, ] @@ -337,9 +335,7 @@ async def test_advanced_chat_flow(test_client): "question.\\nDO NOT return anything besides the query.'}", "{'role': 'user', 'content': 'What is the capital of France?'}", ], - "props": { - "model": "gpt-35-turbo", - }, + "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, "title": "Prompt to generate search arguments", }, { @@ -381,7 +377,7 @@ async def test_advanced_chat_flow(test_client): "these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird " 'Type:Footwear\\n\\n"}', ], - "props": {"model": "gpt-35-turbo"}, + "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, "title": "Prompt to generate answer", }, ] From 3a6076c0996d703fd180fe267da2eac14552536f Mon Sep 17 00:00:00 2001 From: John Aziz Date: Fri, 12 Jul 2024 13:22:42 +0000 Subject: [PATCH 25/31] fix type and add mypy to tests --- .github/workflows/app-tests.yaml | 2 ++ src/fastapi_app/api_models.py | 3 ++- src/fastapi_app/routes/api_routes.py | 3 +-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/app-tests.yaml b/.github/workflows/app-tests.yaml index 488e1f76..3ca63843 100755 --- a/.github/workflows/app-tests.yaml +++ b/.github/workflows/app-tests.yaml @@ -70,5 +70,7 @@ jobs: cd ./src/frontend npm install npm run build + - name: Run MyPy + run: python3 -m mypy . 
- name: Run Pytest run: python3 -m pytest diff --git a/src/fastapi_app/api_models.py b/src/fastapi_app/api_models.py index f96bd84f..34506eaf 100644 --- a/src/fastapi_app/api_models.py +++ b/src/fastapi_app/api_models.py @@ -1,5 +1,6 @@ from typing import Any +from openai.types.chat import ChatCompletionMessageParam from pydantic import BaseModel @@ -9,7 +10,7 @@ class Message(BaseModel): class ChatRequest(BaseModel): - messages: list[Message] + messages: list[ChatCompletionMessageParam] context: dict = {} diff --git a/src/fastapi_app/routes/api_routes.py b/src/fastapi_app/routes/api_routes.py index 44636554..cece3650 100644 --- a/src/fastapi_app/routes/api_routes.py +++ b/src/fastapi_app/routes/api_routes.py @@ -70,7 +70,6 @@ async def chat_handler( openai_chat: ChatClient, chat_request: ChatRequest, ): - messages = [message.model_dump() for message in chat_request.messages] overrides = chat_request.context.get("overrides", {}) searcher = PostgresSearcher( @@ -95,5 +94,5 @@ async def chat_handler( chat_deployment=context.openai_chat_deployment, ).run - response = await run_ragchat(messages, overrides=overrides) + response = await run_ragchat(chat_request.messages, overrides=overrides) return response From da5220c9431c9c8ec637157f86fc0f18d4fe989e Mon Sep 17 00:00:00 2001 From: John Aziz Date: Fri, 12 Jul 2024 22:51:31 +0000 Subject: [PATCH 26/31] remove multiple env loading, use single azure_credentials --- src/fastapi_app/__init__.py | 23 ++--------------------- src/fastapi_app/dependencies.py | 11 ++++------- src/fastapi_app/openai_clients.py | 4 ++-- tests/conftest.py | 1 + 4 files changed, 9 insertions(+), 30 deletions(-) diff --git a/src/fastapi_app/__init__.py b/src/fastapi_app/__init__.py index a11d2913..eb777e50 100644 --- a/src/fastapi_app/__init__.py +++ b/src/fastapi_app/__init__.py @@ -1,40 +1,21 @@ -import contextlib import logging import os from dotenv import load_dotenv -from environs import Env from fastapi import FastAPI -from fastapi_app.dependencies import get_azure_credentials -from fastapi_app.postgres_engine import create_postgres_engine_from_env - logger = logging.getLogger("ragapp") -@contextlib.asynccontextmanager -async def lifespan(app: FastAPI): - load_dotenv(override=True) - - azure_credential = await get_azure_credentials() - engine = await create_postgres_engine_from_env(azure_credential) - - yield - - await engine.dispose() - - def create_app(testing: bool = False): - env = Env() - if os.getenv("RUNNING_IN_PRODUCTION"): logging.basicConfig(level=logging.WARNING) else: if not testing: - env.read_env(".env", override=True) + load_dotenv(override=True) logging.basicConfig(level=logging.INFO) - app = FastAPI(docs_url="/docs", lifespan=lifespan) + app = FastAPI(docs_url="/docs") from fastapi_app.routes import api_routes, frontend_routes diff --git a/src/fastapi_app/dependencies.py b/src/fastapi_app/dependencies.py index d3450ec8..b7e54503 100644 --- a/src/fastapi_app/dependencies.py +++ b/src/fastapi_app/dependencies.py @@ -3,7 +3,6 @@ from typing import Annotated import azure.identity -from dotenv import load_dotenv from fastapi import Depends from openai import AsyncAzureOpenAI, AsyncOpenAI from pydantic import BaseModel @@ -40,7 +39,6 @@ async def common_parameters(): """ Get the common parameters for the FastAPI app """ - load_dotenv(override=True) OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST") OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST") if OPENAI_EMBED_HOST == "azure": @@ -69,7 +67,7 @@ async def common_parameters(): ) -async def 
get_azure_credentials() -> azure.identity.DefaultAzureCredential | azure.identity.ManagedIdentityCredential: +def get_azure_credentials() -> azure.identity.DefaultAzureCredential | azure.identity.ManagedIdentityCredential: azure_credential: azure.identity.DefaultAzureCredential | azure.identity.ManagedIdentityCredential try: if client_id := os.getenv("APP_IDENTITY_ID"): @@ -88,10 +86,11 @@ async def get_azure_credentials() -> azure.identity.DefaultAzureCredential | azu raise e +azure_credentials = get_azure_credentials() + + async def get_engine(): """Get the agent database engine""" - load_dotenv(override=True) - azure_credentials = await get_azure_credentials() engine = await create_postgres_engine_from_env(azure_credentials) return engine @@ -105,14 +104,12 @@ async def get_async_session(engine: Annotated[AsyncEngine, Depends(get_engine)]) async def get_openai_chat_client(): """Get the OpenAI chat client""" - azure_credentials = await get_azure_credentials() chat_client = await create_openai_chat_client(azure_credentials) return OpenAIClient(client=chat_client) async def get_openai_embed_client(): """Get the OpenAI embed client""" - azure_credentials = await get_azure_credentials() embed_client = await create_openai_embed_client(azure_credentials) return OpenAIClient(client=embed_client) diff --git a/src/fastapi_app/openai_clients.py b/src/fastapi_app/openai_clients.py index eaf4b690..fa9ffd62 100644 --- a/src/fastapi_app/openai_clients.py +++ b/src/fastapi_app/openai_clients.py @@ -25,7 +25,7 @@ async def create_openai_chat_client( api_key=api_key, ) else: - logger.info("Authenticating to Azure OpenAI using Azure Identity...") + logger.info("Authenticating to Azure OpenAI Chat using Azure Identity...") token_provider = azure.identity.get_bearer_token_provider( azure_credential, "https://cognitiveservices.azure.com/.default" ) @@ -66,7 +66,7 @@ async def create_openai_embed_client( api_key=api_key, ) else: - logger.info("Authenticating to Azure OpenAI using Azure Identity...") + logger.info("Authenticating to Azure OpenAI Embedding using Azure Identity...") token_provider = azure.identity.get_bearer_token_provider( azure_credential, "https://cognitiveservices.azure.com/.default" ) diff --git a/tests/conftest.py b/tests/conftest.py index c4764d6e..2df9caae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -56,6 +56,7 @@ def mock_session_env(monkeypatch_session): # Azure OpenAI monkeypatch_session.setenv("OPENAI_CHAT_HOST", "azure") monkeypatch_session.setenv("OPENAI_EMBED_HOST", "azure") + monkeypatch_session.setenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com") monkeypatch_session.setenv("AZURE_OPENAI_VERSION", "2024-03-01-preview") monkeypatch_session.setenv("AZURE_OPENAI_CHAT_DEPLOYMENT", "gpt-35-turbo") monkeypatch_session.setenv("AZURE_OPENAI_CHAT_MODEL", "gpt-35-turbo") From 79a8a2dbe620dfd62f2c3d07c6d11af16058ea5f Mon Sep 17 00:00:00 2001 From: John Aziz Date: Sat, 13 Jul 2024 13:09:25 +0000 Subject: [PATCH 27/31] use app state to store global vars there is only one engine, sessionmaker, azure_credentials, context, chat_client, and embed_client created during the lifespan of the fastapi app --- src/fastapi_app/__init__.py | 36 +++++++++++++++- src/fastapi_app/dependencies.py | 64 ++++++++++++++++++---------- src/fastapi_app/update_embeddings.py | 11 +++-- 3 files changed, 83 insertions(+), 28 deletions(-) diff --git a/src/fastapi_app/__init__.py b/src/fastapi_app/__init__.py index eb777e50..78c19264 100644 --- a/src/fastapi_app/__init__.py +++ 
b/src/fastapi_app/__init__.py @@ -1,12 +1,46 @@ import logging import os +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import TypedDict from dotenv import load_dotenv from fastapi import FastAPI +from openai import AsyncAzureOpenAI, AsyncOpenAI +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from fastapi_app.dependencies import ( + FastAPIAppContext, + common_parameters, + create_async_sessionmaker, + get_azure_credentials, +) +from fastapi_app.openai_clients import create_openai_chat_client, create_openai_embed_client +from fastapi_app.postgres_engine import create_postgres_engine_from_env logger = logging.getLogger("ragapp") +class State(TypedDict): + sessionmaker: async_sessionmaker[AsyncSession] + context: FastAPIAppContext + chat_client: AsyncOpenAI | AsyncAzureOpenAI + embed_client: AsyncOpenAI | AsyncAzureOpenAI + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncIterator[State]: + context = await common_parameters() + azure_credential = await get_azure_credentials() + engine = await create_postgres_engine_from_env(azure_credential) + sessionmaker = await create_async_sessionmaker(engine) + chat_client = await create_openai_chat_client(azure_credential) + embed_client = await create_openai_embed_client(azure_credential) + + yield {"sessionmaker": sessionmaker, "context": context, "chat_client": chat_client, "embed_client": embed_client} + await engine.dispose() + + def create_app(testing: bool = False): if os.getenv("RUNNING_IN_PRODUCTION"): logging.basicConfig(level=logging.WARNING) @@ -15,7 +49,7 @@ def create_app(testing: bool = False): load_dotenv(override=True) logging.basicConfig(level=logging.INFO) - app = FastAPI(docs_url="/docs") + app = FastAPI(docs_url="/docs", lifespan=lifespan) from fastapi_app.routes import api_routes, frontend_routes diff --git a/src/fastapi_app/dependencies.py b/src/fastapi_app/dependencies.py index b7e54503..7650ac25 100644 --- a/src/fastapi_app/dependencies.py +++ b/src/fastapi_app/dependencies.py @@ -1,16 +1,14 @@ import logging import os +from collections.abc import AsyncGenerator from typing import Annotated import azure.identity -from fastapi import Depends +from fastapi import Depends, Request from openai import AsyncAzureOpenAI, AsyncOpenAI from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker -from fastapi_app.openai_clients import create_openai_chat_client, create_openai_embed_client -from fastapi_app.postgres_engine import create_postgres_engine_from_env - logger = logging.getLogger("ragapp") @@ -67,7 +65,7 @@ async def common_parameters(): ) -def get_azure_credentials() -> azure.identity.DefaultAzureCredential | azure.identity.ManagedIdentityCredential: +async def get_azure_credentials() -> azure.identity.DefaultAzureCredential | azure.identity.ManagedIdentityCredential: azure_credential: azure.identity.DefaultAzureCredential | azure.identity.ManagedIdentityCredential try: if client_id := os.getenv("APP_IDENTITY_ID"): @@ -86,35 +84,55 @@ def get_azure_credentials() -> azure.identity.DefaultAzureCredential | azure.ide raise e -azure_credentials = get_azure_credentials() +async def create_async_sessionmaker(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]: + """Get the agent database""" + return async_sessionmaker( + engine, + expire_on_commit=False, + autoflush=False, + ) -async def get_engine(): - """Get the agent database engine""" - engine = await 
create_postgres_engine_from_env(azure_credentials) - return engine +async def get_async_sessionmaker( + request: Request, +) -> AsyncGenerator[async_sessionmaker[AsyncSession], None]: + yield request.state.sessionmaker -async def get_async_session(engine: Annotated[AsyncEngine, Depends(get_engine)]): - """Get the agent database""" - async_session_maker = async_sessionmaker(engine, expire_on_commit=False) - async with async_session_maker() as async_session: - yield async_session +async def get_context( + request: Request, +) -> FastAPIAppContext: + return request.state.context + + +async def get_async_db_session( + sessionmaker: Annotated[async_sessionmaker[AsyncSession], Depends(get_async_sessionmaker)], +) -> AsyncGenerator[AsyncSession, None]: + async with sessionmaker() as session: + try: + yield session + except: + await session.rollback() + raise + else: + await session.commit() -async def get_openai_chat_client(): +async def get_openai_chat_client( + request: Request, +) -> OpenAIClient: """Get the OpenAI chat client""" - chat_client = await create_openai_chat_client(azure_credentials) - return OpenAIClient(client=chat_client) + return OpenAIClient(client=request.state.chat_client) -async def get_openai_embed_client(): +async def get_openai_embed_client( + request: Request, +) -> OpenAIClient: """Get the OpenAI embed client""" - embed_client = await create_openai_embed_client(azure_credentials) - return OpenAIClient(client=embed_client) + return OpenAIClient(client=request.state.embed_client) -CommonDeps = Annotated[FastAPIAppContext, Depends(common_parameters)] -DBSession = Annotated[AsyncSession, Depends(get_async_session)] +CommonDeps = Annotated[FastAPIAppContext, Depends(get_context)] +DBSession = Annotated[AsyncSession, Depends(get_async_db_session)] ChatClient = Annotated[OpenAIClient, Depends(get_openai_chat_client)] EmbeddingsClient = Annotated[OpenAIClient, Depends(get_openai_embed_client)] diff --git a/src/fastapi_app/update_embeddings.py b/src/fastapi_app/update_embeddings.py index 28149403..5a1ea447 100644 --- a/src/fastapi_app/update_embeddings.py +++ b/src/fastapi_app/update_embeddings.py @@ -4,14 +4,17 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import async_sessionmaker -from fastapi_app.dependencies import common_parameters, get_engine, get_openai_embed_client +from fastapi_app.dependencies import common_parameters, get_azure_credentials from fastapi_app.embeddings import compute_text_embedding +from fastapi_app.openai_clients import create_openai_embed_client +from fastapi_app.postgres_engine import create_postgres_engine_from_env from fastapi_app.postgres_models import Item async def update_embeddings(): - engine = await get_engine() - openai_embed = await get_openai_embed_client() + azure_credential = await get_azure_credentials() + engine = await create_postgres_engine_from_env(azure_credential) + openai_embed_client = await create_openai_embed_client(azure_credential) common_params = await common_parameters() async with async_sessionmaker(engine, expire_on_commit=False)() as session: @@ -21,7 +24,7 @@ async def update_embeddings(): for item in items: item.embedding = await compute_text_embedding( item.to_str_for_embedding(), - openai_client=openai_embed.client, + openai_client=openai_embed_client, embed_model=common_params.openai_embed_model, embedding_dimensions=common_params.openai_embed_dimensions, ) From 3f0286b3e56f82d142543e575d8ed1bc104ceaab Mon Sep 17 00:00:00 2001 From: John Aziz Date: Sat, 13 Jul 2024 13:22:09 +0000 Subject: [PATCH 
28/31] add pydatnic types for Item table --- src/fastapi_app/api_models.py | 13 +++++++++++++ src/fastapi_app/routes/api_routes.py | 24 ++++++++++++------------ tests/data.py | 11 ++--------- 3 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/fastapi_app/api_models.py b/src/fastapi_app/api_models.py index 34506eaf..2e214a5e 100644 --- a/src/fastapi_app/api_models.py +++ b/src/fastapi_app/api_models.py @@ -30,3 +30,16 @@ class RetrievalResponse(BaseModel): message: Message context: RAGContext session_state: Any | None = None + + +class ItemPublic(BaseModel): + id: int + type: str + brand: str + name: str + description: str + price: float + + +class ItemWithDistance(ItemPublic): + distance: float diff --git a/src/fastapi_app/routes/api_routes.py b/src/fastapi_app/routes/api_routes.py index cece3650..d0cf3de4 100644 --- a/src/fastapi_app/routes/api_routes.py +++ b/src/fastapi_app/routes/api_routes.py @@ -1,10 +1,8 @@ -from typing import Any - import fastapi from fastapi import HTTPException from sqlalchemy import select -from fastapi_app.api_models import ChatRequest, RetrievalResponse +from fastapi_app.api_models import ChatRequest, ItemPublic, ItemWithDistance, RetrievalResponse from fastapi_app.dependencies import ChatClient, CommonDeps, DBSession, EmbeddingsClient from fastapi_app.postgres_models import Item from fastapi_app.postgres_searcher import PostgresSearcher @@ -14,17 +12,17 @@ router = fastapi.APIRouter() -@router.get("/items/{id}", response_model=dict[str, Any]) -async def item_handler(id: int, database_session: DBSession) -> dict[str, Any]: +@router.get("/items/{id}", response_model=ItemPublic) +async def item_handler(id: int, database_session: DBSession) -> ItemPublic: """A simple API to get an item by ID.""" item = (await database_session.scalars(select(Item).where(Item.id == id))).first() if not item: raise HTTPException(detail=f"Item with ID {id} not found.", status_code=404) - return item.to_dict() + return ItemPublic.model_validate(item.to_dict()) -@router.get("/similar", response_model=list[dict[str, Any]]) -async def similar_handler(database_session: DBSession, id: int, n: int = 5) -> list[dict[str, Any]]: +@router.get("/similar", response_model=list[ItemWithDistance]) +async def similar_handler(database_session: DBSession, id: int, n: int = 5) -> list[ItemWithDistance]: """A similarity API to find items similar to items with given ID.""" item = (await database_session.scalars(select(Item).where(Item.id == id))).first() if not item: @@ -35,10 +33,12 @@ async def similar_handler(database_session: DBSession, id: int, n: int = 5) -> l .order_by(Item.embedding.l2_distance(item.embedding)) .limit(n) ) - return [item.to_dict() | {"distance": round(distance, 2)} for item, distance in closest] + return [ + ItemWithDistance.model_validate(item.to_dict() | {"distance": round(distance, 2)}) for item, distance in closest + ] -@router.get("/search", response_model=list[dict[str, Any]]) +@router.get("/search", response_model=list[ItemPublic]) async def search_handler( context: CommonDeps, database_session: DBSession, @@ -47,7 +47,7 @@ async def search_handler( top: int = 5, enable_vector_search: bool = True, enable_text_search: bool = True, -) -> list[dict[str, Any]]: +) -> list[ItemPublic]: """A search API to find items based on a query.""" searcher = PostgresSearcher( db_session=database_session, @@ -59,7 +59,7 @@ async def search_handler( results = await searcher.search_and_embed( query, top=top, enable_vector_search=enable_vector_search, 
enable_text_search=enable_text_search ) - return [item.to_dict() for item in results] + return [ItemPublic.model_validate(item.to_dict()) for item in results] @router.post("/chat", response_model=RetrievalResponse) diff --git a/tests/data.py b/tests/data.py index 58d9ad08..a9bb94d9 100644 --- a/tests/data.py +++ b/tests/data.py @@ -1,14 +1,7 @@ -from dataclasses import dataclass +from fastapi_app.api_models import ItemPublic -@dataclass -class TestData: - id: int - type: str - brand: str - name: str - description: str - price: float +class TestData(ItemPublic): embeddings: list[float] From 02d0b1b681170ca34f7f43b02dca1086c6f28383 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Mon, 15 Jul 2024 01:34:29 +0000 Subject: [PATCH 29/31] add more tests and fix azure credentials mocking AsyncTokenCredentials is not the correct class to inherit from as we are not using the async credentials --- src/fastapi_app/embeddings.py | 2 +- tests/conftest.py | 2 +- tests/mocks.py | 28 +++++----------- tests/test_dependencies.py | 21 ++++++++++++ tests/test_embeddings.py | 18 ++++++++++ tests/test_openai_clients.py | 24 +++++++++++++ tests/test_postgres_engine.py | 63 +++++++++++++++++++++++++++++++++++ 7 files changed, 136 insertions(+), 22 deletions(-) create mode 100644 tests/test_dependencies.py create mode 100644 tests/test_embeddings.py create mode 100644 tests/test_openai_clients.py create mode 100644 tests/test_postgres_engine.py diff --git a/src/fastapi_app/embeddings.py b/src/fastapi_app/embeddings.py index 9d55812c..39a97be8 100644 --- a/src/fastapi_app/embeddings.py +++ b/src/fastapi_app/embeddings.py @@ -11,7 +11,7 @@ async def compute_text_embedding( embed_model: str, embed_deployment: str | None = None, embedding_dimensions: int = 1536, -): +) -> list[float]: SUPPORTED_DIMENSIONS_MODEL = { "text-embedding-ada-002": False, "text-embedding-3-small": True, diff --git a/tests/conftest.py b/tests/conftest.py index 2df9caae..ac1e9e30 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -224,7 +224,7 @@ def mock_default_azure_credential(mock_session_env): """Mock the Azure credential for testing.""" with mock.patch("azure.identity.DefaultAzureCredential") as mock_default_azure_credential: mock_default_azure_credential.return_value = MockAzureCredential() - yield + yield mock_default_azure_credential @pytest_asyncio.fixture(scope="function") diff --git a/tests/mocks.py b/tests/mocks.py index eb96e26c..172a4ed1 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -1,34 +1,22 @@ from collections import namedtuple -from types import TracebackType -from azure.core.credentials_async import AsyncTokenCredential +from azure.core.credentials import TokenCredential -MockToken = namedtuple("MockToken", ["token", "expires_on", "value"]) +MockToken = namedtuple("MockToken", ["token", "expires_on"]) -class MockAzureCredential(AsyncTokenCredential): - async def get_token(self, uri): - return MockToken("", 9999999999, "") - - async def close(self) -> None: - pass - - async def __aexit__( - self, - exc_type: type[BaseException] | None = None, - exc_value: BaseException | None = None, - traceback: TracebackType | None = None, - ) -> None: - pass +class MockAzureCredential(TokenCredential): + def get_token(self, uri): + return MockToken("", 9999999999) -class MockAzureCredentialExpired(AsyncTokenCredential): +class MockAzureCredentialExpired(TokenCredential): def __init__(self): self.access_number = 0 async def get_token(self, uri): self.access_number += 1 if self.access_number == 1: - return MockToken("", 0, "") + 
return MockToken("", 0) else: - return MockToken("", 9999999999, "") + return MockToken("", 9999999999) diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py new file mode 100644 index 00000000..89a459db --- /dev/null +++ b/tests/test_dependencies.py @@ -0,0 +1,21 @@ +import pytest + +from fastapi_app.dependencies import common_parameters, get_azure_credentials + + +@pytest.mark.asyncio +async def test_get_common_parameters(mock_session_env): + result = await common_parameters() + assert result.openai_chat_model == "gpt-35-turbo" + assert result.openai_embed_model == "text-embedding-ada-002" + assert result.openai_embed_dimensions == 1536 + assert result.openai_chat_deployment == "gpt-35-turbo" + assert result.openai_embed_deployment == "text-embedding-ada-002" + + +@pytest.mark.asyncio +async def test_get_azure_credentials(mock_session_env, mock_default_azure_credential): + result = await get_azure_credentials() + token = result.get_token("https://vault.azure.net") + assert token.expires_on == 9999999999 + assert token.token == "" diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py new file mode 100644 index 00000000..f983c2e4 --- /dev/null +++ b/tests/test_embeddings.py @@ -0,0 +1,18 @@ +import pytest + +from fastapi_app.embeddings import compute_text_embedding +from fastapi_app.openai_clients import create_openai_embed_client +from tests.data import test_data + + +@pytest.mark.asyncio +async def test_compute_text_embedding(mock_default_azure_credential, mock_openai_embedding): + openai_embed_client = await create_openai_embed_client(mock_default_azure_credential) + result = await compute_text_embedding( + q="test", + openai_client=openai_embed_client, + embed_model="text-embedding-ada-002", + embed_deployment="text-embedding-ada-002", + embedding_dimensions=1536, + ) + assert result == test_data.embeddings diff --git a/tests/test_openai_clients.py b/tests/test_openai_clients.py new file mode 100644 index 00000000..722445a5 --- /dev/null +++ b/tests/test_openai_clients.py @@ -0,0 +1,24 @@ +import pytest + +from fastapi_app.openai_clients import create_openai_chat_client, create_openai_embed_client +from tests.data import test_data + + +@pytest.mark.asyncio +async def test_create_openai_embed_client(mock_default_azure_credential, mock_openai_embedding): + openai_embed_client = await create_openai_embed_client(mock_default_azure_credential) + assert openai_embed_client.embeddings.create is not None + embeddings = await openai_embed_client.embeddings.create( + model="text-embedding-ada-002", input="test", dimensions=1536 + ) + assert embeddings.data[0].embedding == test_data.embeddings + + +@pytest.mark.asyncio +async def test_create_openai_chat_client(mock_default_azure_credential, mock_openai_chatcompletion): + openai_chat_client = await create_openai_chat_client(mock_default_azure_credential) + assert openai_chat_client.chat.completions.create is not None + response = await openai_chat_client.chat.completions.create( + model="gpt-35-turbo", messages=[{"content": "test", "role": "user"}] + ) + assert response.choices[0].message.content == "The capital of France is Paris. [Benefit_Options-2.pdf]." 
diff --git a/tests/test_postgres_engine.py b/tests/test_postgres_engine.py new file mode 100644 index 00000000..098110b9 --- /dev/null +++ b/tests/test_postgres_engine.py @@ -0,0 +1,63 @@ +import os + +import pytest + +from fastapi_app.postgres_engine import ( + create_postgres_engine, + create_postgres_engine_from_args, + create_postgres_engine_from_env, +) +from tests.conftest import POSTGRES_DATABASE, POSTGRES_HOST, POSTGRES_PASSWORD, POSTGRES_SSL, POSTGRES_USERNAME + + +@pytest.mark.asyncio +async def test_create_postgres_engine(mock_session_env, mock_default_azure_credential): + engine = await create_postgres_engine( + host=os.environ["POSTGRES_HOST"], + username=os.environ["POSTGRES_USERNAME"], + database=os.environ["POSTGRES_DATABASE"], + password=os.environ.get("POSTGRES_PASSWORD"), + sslmode=os.environ.get("POSTGRES_SSL"), + azure_credential=mock_default_azure_credential, + ) + assert engine.url.host == "localhost" + assert engine.url.username == "admin" + assert engine.url.database == "postgres" + assert engine.url.password == "postgres" + assert engine.url.query["ssl"] == "prefer" + + +@pytest.mark.asyncio +async def test_create_postgres_engine_from_env(mock_session_env, mock_default_azure_credential): + engine = await create_postgres_engine_from_env( + azure_credential=mock_default_azure_credential, + ) + assert engine.url.host == "localhost" + assert engine.url.username == "admin" + assert engine.url.database == "postgres" + assert engine.url.password == "postgres" + assert engine.url.query["ssl"] == "prefer" + + +@pytest.mark.asyncio +async def test_create_postgres_engine_from_args(mock_default_azure_credential): + args = type( + "Args", + (), + { + "host": POSTGRES_HOST, + "username": POSTGRES_USERNAME, + "database": POSTGRES_DATABASE, + "password": POSTGRES_PASSWORD, + "sslmode": POSTGRES_SSL, + }, + ) + engine = await create_postgres_engine_from_args( + args=args, + azure_credential=mock_default_azure_credential, + ) + assert engine.url.host == "localhost" + assert engine.url.username == "admin" + assert engine.url.database == "postgres" + assert engine.url.password == "postgres" + assert engine.url.query["ssl"] == "prefer" From 17fc97f8493d49613a3a38d2747ef730234d21ad Mon Sep 17 00:00:00 2001 From: John Aziz Date: Wed, 17 Jul 2024 05:08:08 +0000 Subject: [PATCH 30/31] apply feedback from pr review --- src/fastapi_app/dependencies.py | 8 +------- src/fastapi_app/routes/api_routes.py | 2 +- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/fastapi_app/dependencies.py b/src/fastapi_app/dependencies.py index 7650ac25..00837111 100644 --- a/src/fastapi_app/dependencies.py +++ b/src/fastapi_app/dependencies.py @@ -109,13 +109,7 @@ async def get_async_db_session( sessionmaker: Annotated[async_sessionmaker[AsyncSession], Depends(get_async_sessionmaker)], ) -> AsyncGenerator[AsyncSession, None]: async with sessionmaker() as session: - try: - yield session - except: - await session.rollback() - raise - else: - await session.commit() + yield session async def get_openai_chat_client( diff --git a/src/fastapi_app/routes/api_routes.py b/src/fastapi_app/routes/api_routes.py index d0cf3de4..b0f02189 100644 --- a/src/fastapi_app/routes/api_routes.py +++ b/src/fastapi_app/routes/api_routes.py @@ -13,7 +13,7 @@ @router.get("/items/{id}", response_model=ItemPublic) -async def item_handler(id: int, database_session: DBSession) -> ItemPublic: +async def item_handler(database_session: DBSession, id: int) -> ItemPublic: """A simple API to get an item by ID.""" item = (await 
database_session.scalars(select(Item).where(Item.id == id))).first() if not item: From b2bb12175afc3909530242a4084e76a1279e9742 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Wed, 17 Jul 2024 18:14:31 +0000 Subject: [PATCH 31/31] add postgres searcher tests --- tests/conftest.py | 18 ++++++++++++++- tests/test_postgres_searcher.py | 41 +++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 tests/test_postgres_searcher.py diff --git a/tests/conftest.py b/tests/conftest.py index ac1e9e30..d503fd3d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,7 @@ from sqlalchemy.ext.asyncio import async_sessionmaker from fastapi_app import create_app +from fastapi_app.openai_clients import create_openai_embed_client from fastapi_app.postgres_engine import create_postgres_engine_from_env from fastapi_app.setup_postgres_database import create_db_schema from fastapi_app.setup_postgres_seeddata import seed_data @@ -235,7 +236,7 @@ async def test_client(app, mock_default_azure_credential, mock_openai_embedding, @pytest_asyncio.fixture(scope="function") -async def db_session(): +async def db_session(mock_session_env, mock_default_azure_credential): """Create a new database session with a rollback at the end of the test.""" engine = await create_postgres_engine_from_env() async_sesion = async_sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -245,3 +246,18 @@ async def db_session(): await session.rollback() await session.close() await engine.dispose() + + +@pytest_asyncio.fixture(scope="function") +async def postgres_searcher(mock_session_env, mock_default_azure_credential, db_session, mock_openai_embedding): + from fastapi_app.postgres_searcher import PostgresSearcher + + openai_embed_client = await create_openai_embed_client(mock_default_azure_credential) + + yield PostgresSearcher( + db_session=db_session, + openai_embed_client=openai_embed_client, + embed_deployment="text-embedding-ada-002", + embed_model="text-embedding-ada-002", + embed_dimensions=1536, + ) diff --git a/tests/test_postgres_searcher.py b/tests/test_postgres_searcher.py new file mode 100644 index 00000000..ee2992e0 --- /dev/null +++ b/tests/test_postgres_searcher.py @@ -0,0 +1,41 @@ +import pytest + +from fastapi_app.api_models import ItemPublic +from tests.data import test_data + + +def test_postgres_build_filter_clause_without_filters(postgres_searcher): + assert postgres_searcher.build_filter_clause(None) == ("", "") + assert postgres_searcher.build_filter_clause([]) == ("", "") + + +def test_postgres_build_filter_clause_with_filters(postgres_searcher): + assert postgres_searcher.build_filter_clause([{"column": "id", "comparison_operator": "=", "value": 1}]) == ( + "WHERE id = 1", + "AND id = 1", + ) + + +@pytest.mark.asyncio +async def test_postgres_searcher_search_empty_text_search(postgres_searcher): + assert await postgres_searcher.search("", [], 5, None) == [] + + +@pytest.mark.asyncio +async def test_postgres_searcher_search(postgres_searcher): + assert (await postgres_searcher.search(test_data.name, test_data.embeddings, 5, None))[0].to_dict() == ItemPublic( + **test_data.model_dump() + ).model_dump() + + +@pytest.mark.asyncio +async def test_postgres_searcher_search_and_embed_empty_text_search(postgres_searcher): + assert await postgres_searcher.search_and_embed("", 5, False, True) == [] + + +@pytest.mark.asyncio +async def test_postgres_searcher_search_and_embed(postgres_searcher): + assert await postgres_searcher.search_and_embed("", 5, False, True) 
== [] + assert (await postgres_searcher.search_and_embed(test_data.name, 5, True))[0].to_dict() == ItemPublic( + **test_data.model_dump() + ).model_dump()
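
Taken together, patches 26 and 27 replace the module-level global_storage object with resources that are created once in the FastAPI lifespan and handed to routes through Annotated dependencies. Below is a minimal sketch of that pattern, not the project's code: the greeting resource and the /hello route are invented placeholders standing in for the real engine, sessionmaker, and OpenAI clients, and the sketch assumes a FastAPI/Starlette release recent enough to expose lifespan state on request.state, which the patches above rely on.

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Annotated, TypedDict

from fastapi import Depends, FastAPI, Request


class State(TypedDict):
    # One shared resource, created a single time for the whole application.
    greeting: str


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[State]:
    # Whatever the lifespan yields becomes available as request.state.<key>
    # for every request handled while the app is running.
    yield {"greeting": "hello from lifespan state"}


app = FastAPI(lifespan=lifespan)


async def get_greeting(request: Request) -> str:
    # Dependencies read the shared resource from request.state instead of a global.
    return request.state.greeting


Greeting = Annotated[str, Depends(get_greeting)]


@app.get("/hello")
async def hello(greeting: Greeting) -> dict[str, str]:
    return {"message": greeting}

Run it with, for example, uvicorn sketch:app (assuming the file is saved as sketch.py). Because the shared objects live on lifespan state rather than in a module-level global, they are built exactly once per process and torn down when the app shuts down, which is what the engine.dispose() call in patch 27 achieves for the database engine.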
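
Patch 28 narrows the route responses from plain dicts to Pydantic models validated with model_validate. A small self-contained sketch of that conversion, using made-up field values rather than the real seed data and a trimmed-down ItemPublic:

from typing import Any

from pydantic import BaseModel, ValidationError


class ItemPublic(BaseModel):
    id: int
    name: str
    price: float


def to_public(row: dict[str, Any]) -> ItemPublic:
    # Pydantic v2: extra keys in the dict (say, an embedding column) are ignored
    # by default, while a missing or mis-typed declared field raises a
    # ValidationError instead of letting an arbitrary dict pass through the API.
    return ItemPublic.model_validate(row)


print(to_public({"id": 1, "name": "hiking boots", "price": 109.99, "embedding": [0.1, 0.2]}))

try:
    to_public({"id": "not-a-number", "name": "hiking boots"})
except ValidationError as exc:
    print(exc.error_count(), "validation errors")

Declaring response_model=ItemPublic on the route, as the patch does, lets FastAPI use the same model both for response serialization and for the generated OpenAPI schema.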