From d3ffdd200c5ed05b82ef1a818c6c31918eef1fef Mon Sep 17 00:00:00 2001 From: Peter Oyelegbin Date: Sat, 1 Mar 2025 20:01:24 +0100 Subject: [PATCH 1/6] enhanced the get product endpoint with category filter --- api/utils/pagination.py | 10 +- api/v1/routes/product.py | 17 +- tests/v1/product/test_get_product_filter.py | 209 ++++++++++++++++++++ 3 files changed, 230 insertions(+), 6 deletions(-) create mode 100644 tests/v1/product/test_get_product_filter.py diff --git a/api/utils/pagination.py b/api/utils/pagination.py index f10b4bf3f..d46074e0a 100644 --- a/api/utils/pagination.py +++ b/api/utils/pagination.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional from fastapi.encoders import jsonable_encoder -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, Query from api.db.database import Base from api.utils.success_response import success_response @@ -12,6 +12,7 @@ def paginated_response( skip: int, limit: int, join: Optional[Any] = None, + query: Optional[Query] = None, filters: Optional[Dict[str, Any]]=None ): @@ -24,6 +25,7 @@ def paginated_response( * skip- this is the number of items to skip before fetching the next page of data. This would also be a query parameter * join- this is an optional argument to join a table to the query + * query- this is an optional custom query to use instead of querying all items from the model. * filters- this is an optional dictionary of filters to apply to the query Example use: @@ -61,7 +63,8 @@ def paginated_response( ``` ''' - query = db.query(model) + if query is None: + query = db.query(model) if join is not None: query = query.join(join) @@ -82,7 +85,8 @@ def paginated_response( total = query.count() results = jsonable_encoder(query.offset(skip).limit(limit).all()) - total_pages = int(total / limit) + (total % limit > 0) + # total_pages = int(total / limit) + (total % limit > 0) + total_pages = (total + limit - 1) // limit return success_response( status_code=200, diff --git a/api/v1/routes/product.py b/api/v1/routes/product.py index 0efa71a9a..abf4f3c65 100644 --- a/api/v1/routes/product.py +++ b/api/v1/routes/product.py @@ -8,7 +8,7 @@ from api.utils.pagination import paginated_response from api.utils.success_response import success_response from api.db.database import get_db -from api.v1.models.product import Product, ProductFilterStatusEnum, ProductStatusEnum +from api.v1.models.product import Product, ProductCategory, ProductFilterStatusEnum, ProductStatusEnum from api.v1.services.product import product_service, ProductCategoryService from api.v1.schemas.product import ( ProductCategoryCreate, @@ -37,11 +37,22 @@ async def get_all_products( ge=1, description="Number of products per page")] = 10, skip: Annotated[int, Query( ge=1, description="Page number (starts from 1)")] = 0, + category: Annotated[Optional[str], Query( + description="Filter products by category name")] = None, db: Session = Depends(get_db), ): - """Endpoint to get all products. Only accessible to superadmin""" + """ + Endpoint to get all products. Only accessible to superadmin. + Optionally filter products by category. + """ + # Base query + query = db.query(Product) + + # Apply category filter if provided + if category: + query = query.join(Product.category).filter(ProductCategory.name.ilike(f"%{category}%")) - return paginated_response(db=db, model=Product, limit=limit, skip=skip) + return paginated_response(db=db, model=Product, limit=limit, skip=skip, query=query) # categories diff --git a/tests/v1/product/test_get_product_filter.py b/tests/v1/product/test_get_product_filter.py new file mode 100644 index 000000000..a92ac62eb --- /dev/null +++ b/tests/v1/product/test_get_product_filter.py @@ -0,0 +1,209 @@ +import pytest +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session +from unittest.mock import MagicMock +from uuid_extensions import uuid7 +from datetime import datetime, timezone, timedelta + +from api.v1.models.organisation import Organisation +from api.v1.models.product import Product, ProductCategory +from api.v1.models.user import User +from main import app +from api.v1.routes.blog import get_db +from api.v1.services.user import user_service + + +# Mock database dependency +@pytest.fixture +def db_session_mock(): + db_session = MagicMock(spec=Session) + return db_session + + +@pytest.fixture +def client(db_session_mock): + app.dependency_overrides[get_db] = lambda: db_session_mock + client = TestClient(app) + yield client + app.dependency_overrides = {} + + +# Mock user service dependency + +user_id = uuid7() +org_id = uuid7() +product_id = uuid7() +category_id = uuid7() +timezone_offset = -8.0 +tzinfo = timezone(timedelta(hours=timezone_offset)) +timeinfo = datetime.now(tzinfo) +created_at = timeinfo +updated_at = timeinfo +access_token = user_service.create_access_token(str(user_id)) +access_token2 = user_service.create_access_token(str(uuid7())) + +# Create test user + +user = User( + id=str(user_id), + email="testuser@test.com", + password="password123", + created_at=created_at, + updated_at=updated_at, +) + +# Create test organisation + +org = Organisation( + id=str(org_id), + name="hng", + email=None, + industry=None, + type=None, + country=None, + state=None, + address=None, + description=None, + created_at=created_at, + updated_at=updated_at, +) + +# Create test category + +category = ProductCategory(id=category_id, name="Electronics") + +# Create test product + +product = Product( + id=str(product_id), + name="prod one", + description="Test product", + price=125.55, + org_id=str(org_id), + quantity=50, + image_url="http://img", + category_id=str(category_id), + status="in_stock", + archived=False, +) + + +# Mock data for multiple products +products = [ + Product( + id=str(uuid7()), + name="Smartphone", + description="A smartphone", + price=500.00, + org_id=str(org_id), + quantity=10, + image_url="http://img1", + category_id=str(category_id), + status="in_stock", + archived=False, + ), + Product( + id=str(uuid7()), + name="Laptop", + description="A laptop", + price=1200.00, + org_id=str(org_id), + quantity=5, + image_url="http://img2", + category_id=str(category_id), + status="in_stock", + archived=False, + ), + Product( + id=str(uuid7()), + name="T-Shirt", + description="A T-Shirt", + price=20.00, + org_id=str(org_id), + quantity=100, + image_url="http://img3", + category_id=str(uuid7()), # Different category + status="in_stock", + archived=False, + ), +] + + +def test_get_products_filtered_by_category(client, db_session_mock): + # Mock the database query to return filtered products + db_session_mock.query().join().filter().offset().limit().all.return_value = [ + products[0], products[1]] + db_session_mock.query().join().filter().count.return_value = 2 # Return an integer + + headers = {"authorization": f"Bearer {access_token}"} + response = client.get( + "/api/v1/products?category=Electronics", + headers=headers + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert len(data["data"]["items"]) == 2 + + +def test_get_all_products_without_filter(client, db_session_mock): + # Mock the database query to return all products + db_session_mock.query().offset().limit().all.return_value = products + db_session_mock.query().count.return_value = 3 + + headers = {"authorization": f"Bearer {access_token}"} + response = client.get( + "/api/v1/products", + headers=headers + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert len(data["data"]["items"]) == 3 + + +def test_unauthorized_access(client, db_session_mock): + # Test unauthorized access (missing or invalid token) + response = client.get("/api/v1/products") + assert response.status_code == 401 + assert response.json() == { + "status": False, + "status_code": 401, + "message": "Not authenticated" + } + + +def test_invalid_category_name(client, db_session_mock): + # Mock the database query to return no products for an invalid category + db_session_mock.query().join().filter().offset().limit().all.return_value = [] + db_session_mock.query().join().filter().count.return_value = 0 # Return an integer + + headers = {"authorization": f"Bearer {access_token}"} + response = client.get( + "/api/v1/products?category=InvalidCategory", + headers=headers + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert len(data["data"]["items"]) == 0 + + +def test_empty_results_for_valid_category(client, db_session_mock): + # Mock the database query to return no products for a valid but unused category + db_session_mock.query().join().filter().offset().limit().all.return_value = [] + db_session_mock.query().join().filter().count.return_value = 0 # Return an integer + + headers = {"authorization": f"Bearer {access_token}"} + response = client.get( + "/api/v1/products?category=Furniture", + headers=headers + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert len(data["data"]["items"]) == 0 From 67392e2fee23efe3e369a3fba934aca9052fe26e Mon Sep 17 00:00:00 2001 From: Peter Oyelegbin Date: Sun, 2 Mar 2025 11:14:39 +0100 Subject: [PATCH 2/6] updated test_get_product_filer --- api/v1/schemas/plans.py | 6 +++--- api/v1/schemas/product_comment.py | 7 +++---- api/v1/schemas/stripe.py | 6 +++--- tests/v1/product/test_get_product_filter.py | 8 ++++---- 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/api/v1/schemas/plans.py b/api/v1/schemas/plans.py index 18bddb4a3..e6244a61c 100644 --- a/api/v1/schemas/plans.py +++ b/api/v1/schemas/plans.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, validator +from pydantic import BaseModel, field_validator from typing import List, Optional from datetime import datetime @@ -14,14 +14,14 @@ class CreateBillingPlanSchema(BaseModel): organisation_id: str features: List[str] - @validator("price") + @field_validator("price") def adjust_price(cls, value, values): duration = values.get("duration") if duration == "yearly": value = value * 12 * 0.8 # Multiply by 12 and apply a 20% discount return value - @validator("duration") + @field_validator("duration") def validate_duration(cls, value): v = value.lower() if v not in ["monthly", "yearly"]: diff --git a/api/v1/schemas/product_comment.py b/api/v1/schemas/product_comment.py index 6448f3c06..a1b9486d7 100644 --- a/api/v1/schemas/product_comment.py +++ b/api/v1/schemas/product_comment.py @@ -1,10 +1,10 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing import Optional from datetime import datetime class ProductCommentBase(BaseModel): - content: str = Field(..., example="This is a comment") + content: str = Field(..., json_schema_extra={"example": "This is a comment"}) author: str @@ -21,8 +21,7 @@ class ProductCommentInDB(ProductCommentBase): created_at: datetime updated_at: datetime - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) class ProductCommentResponse(BaseModel): diff --git a/api/v1/schemas/stripe.py b/api/v1/schemas/stripe.py index 7bded814e..7f1e1e744 100644 --- a/api/v1/schemas/stripe.py +++ b/api/v1/schemas/stripe.py @@ -1,6 +1,6 @@ from typing import List, Optional -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, field_validator class PaymentInfo(BaseModel): @@ -9,13 +9,13 @@ class PaymentInfo(BaseModel): exp_year: int cvc: str = Field(..., min_length=3, max_length=4) - @validator('card_number') + @field_validator('card_number') def card_number_validator(cls, v): if not v.isdigit() or len(v) != 16: raise ValueError('Card number must be 16 digits') return v - @validator('cvc') + @field_validator('cvc') def cvc_validator(cls, v): if not v.isdigit() or not (3 <= len(v) <= 4): raise ValueError('CVC must be 3 or 4 digits') diff --git a/tests/v1/product/test_get_product_filter.py b/tests/v1/product/test_get_product_filter.py index a92ac62eb..3400889a8 100644 --- a/tests/v1/product/test_get_product_filter.py +++ b/tests/v1/product/test_get_product_filter.py @@ -142,8 +142,8 @@ def test_get_products_filtered_by_category(client, db_session_mock): ) assert response.status_code == 200 + assert response.is_success is True data = response.json() - assert data["success"] is True assert len(data["data"]["items"]) == 2 @@ -159,8 +159,8 @@ def test_get_all_products_without_filter(client, db_session_mock): ) assert response.status_code == 200 + assert response.is_success is True data = response.json() - assert data["success"] is True assert len(data["data"]["items"]) == 3 @@ -187,8 +187,8 @@ def test_invalid_category_name(client, db_session_mock): ) assert response.status_code == 200 + assert response.is_success is True data = response.json() - assert data["success"] is True assert len(data["data"]["items"]) == 0 @@ -204,6 +204,6 @@ def test_empty_results_for_valid_category(client, db_session_mock): ) assert response.status_code == 200 + assert response.is_success is True data = response.json() - assert data["success"] is True assert len(data["data"]["items"]) == 0 From 58c4bbe9744b90474399c7d392ae9055ae09dc40 Mon Sep 17 00:00:00 2001 From: Peter Oyelegbin Date: Sun, 2 Mar 2025 13:24:29 +0100 Subject: [PATCH 3/6] updated test_get_product_filter --- api/v1/routes/product.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/v1/routes/product.py b/api/v1/routes/product.py index abf4f3c65..ece1cbe76 100644 --- a/api/v1/routes/product.py +++ b/api/v1/routes/product.py @@ -36,7 +36,7 @@ async def get_all_products( limit: Annotated[int, Query( ge=1, description="Number of products per page")] = 10, skip: Annotated[int, Query( - ge=1, description="Page number (starts from 1)")] = 0, + ge=0, description="Page number (starts from 0)")] = 0, category: Annotated[Optional[str], Query( description="Filter products by category name")] = None, db: Session = Depends(get_db), From 0040c9e069fdbf9a7ca750d61b1180034f83e5f4 Mon Sep 17 00:00:00 2001 From: Peter Oyelegbin Date: Sun, 2 Mar 2025 17:15:18 +0100 Subject: [PATCH 4/6] updated product.py --- api/v1/routes/product.py | 7 +++++-- tests/v1/product/test_get_product_filter.py | 8 ++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/api/v1/routes/product.py b/api/v1/routes/product.py index ece1cbe76..0e394c733 100644 --- a/api/v1/routes/product.py +++ b/api/v1/routes/product.py @@ -36,7 +36,7 @@ async def get_all_products( limit: Annotated[int, Query( ge=1, description="Number of products per page")] = 10, skip: Annotated[int, Query( - ge=0, description="Page number (starts from 0)")] = 0, + ge=1, description="Page number (starts from 1)")] = 1, category: Annotated[Optional[str], Query( description="Filter products by category name")] = None, db: Session = Depends(get_db), @@ -52,7 +52,10 @@ async def get_all_products( if category: query = query.join(Product.category).filter(ProductCategory.name.ilike(f"%{category}%")) - return paginated_response(db=db, model=Product, limit=limit, skip=skip, query=query) + # Calculate the number of items to skip based on the page number + items_to_skip = (skip - 1) * limit + + return paginated_response(db=db, model=Product, limit=limit, skip=items_to_skip, query=query) # categories diff --git a/tests/v1/product/test_get_product_filter.py b/tests/v1/product/test_get_product_filter.py index 3400889a8..a92ac62eb 100644 --- a/tests/v1/product/test_get_product_filter.py +++ b/tests/v1/product/test_get_product_filter.py @@ -142,8 +142,8 @@ def test_get_products_filtered_by_category(client, db_session_mock): ) assert response.status_code == 200 - assert response.is_success is True data = response.json() + assert data["success"] is True assert len(data["data"]["items"]) == 2 @@ -159,8 +159,8 @@ def test_get_all_products_without_filter(client, db_session_mock): ) assert response.status_code == 200 - assert response.is_success is True data = response.json() + assert data["success"] is True assert len(data["data"]["items"]) == 3 @@ -187,8 +187,8 @@ def test_invalid_category_name(client, db_session_mock): ) assert response.status_code == 200 - assert response.is_success is True data = response.json() + assert data["success"] is True assert len(data["data"]["items"]) == 0 @@ -204,6 +204,6 @@ def test_empty_results_for_valid_category(client, db_session_mock): ) assert response.status_code == 200 - assert response.is_success is True data = response.json() + assert data["success"] is True assert len(data["data"]["items"]) == 0 From dbd86aec6252e3f18873cd27107c66d4da20d4a8 Mon Sep 17 00:00:00 2001 From: Peter Oyelegbin Date: Sun, 2 Mar 2025 18:04:39 +0100 Subject: [PATCH 5/6] fix repo conflict --- api/v1/routes/product.py | 20 -------------------- tests/v1/product/test_get_product_filter.py | 8 ++++---- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/api/v1/routes/product.py b/api/v1/routes/product.py index ec8c0bff8..e98f8716b 100644 --- a/api/v1/routes/product.py +++ b/api/v1/routes/product.py @@ -8,15 +8,7 @@ from api.utils.pagination import paginated_response from api.utils.success_response import success_response from api.db.database import get_db -<<<<<<< HEAD from api.v1.models.product import Product, ProductCategory, ProductFilterStatusEnum, ProductStatusEnum -======= -from api.v1.models.product import ( - Product, - ProductFilterStatusEnum, - ProductStatusEnum, -) ->>>>>>> d1c5cb10636bc5483eabcbb20b2cbc5180405baf from api.v1.services.product import product_service, ProductCategoryService from api.v1.schemas.product import ( ProductCategoryCreate, @@ -42,7 +34,6 @@ "", response_model=success_response, status_code=200 ) async def get_all_products( -<<<<<<< HEAD current_user: Annotated[User, Depends(user_service.get_current_super_admin)], limit: Annotated[int, Query( ge=1, description="Number of products per page")] = 10, @@ -50,17 +41,6 @@ async def get_all_products( ge=1, description="Page number (starts from 1)")] = 1, category: Annotated[Optional[str], Query( description="Filter products by category name")] = None, -======= - current_user: Annotated[ - User, Depends(user_service.get_current_super_admin) - ], - limit: Annotated[ - int, Query(ge=1, description="Number of products per page") - ] = 10, - skip: Annotated[ - int, Query(ge=1, description="Page number (starts from 1)") - ] = 0, ->>>>>>> d1c5cb10636bc5483eabcbb20b2cbc5180405baf db: Session = Depends(get_db), ): """ diff --git a/tests/v1/product/test_get_product_filter.py b/tests/v1/product/test_get_product_filter.py index a92ac62eb..eb014594c 100644 --- a/tests/v1/product/test_get_product_filter.py +++ b/tests/v1/product/test_get_product_filter.py @@ -143,7 +143,7 @@ def test_get_products_filtered_by_category(client, db_session_mock): assert response.status_code == 200 data = response.json() - assert data["success"] is True + assert data["status"] == "success" assert len(data["data"]["items"]) == 2 @@ -160,7 +160,7 @@ def test_get_all_products_without_filter(client, db_session_mock): assert response.status_code == 200 data = response.json() - assert data["success"] is True + assert data["status"] == "success" assert len(data["data"]["items"]) == 3 @@ -188,7 +188,7 @@ def test_invalid_category_name(client, db_session_mock): assert response.status_code == 200 data = response.json() - assert data["success"] is True + assert data["status"] == "success" assert len(data["data"]["items"]) == 0 @@ -205,5 +205,5 @@ def test_empty_results_for_valid_category(client, db_session_mock): assert response.status_code == 200 data = response.json() - assert data["success"] is True + assert data["status"] == "success" assert len(data["data"]["items"]) == 0 From 6ecc2e2891a0315080b058740646e5dec1fd6871 Mon Sep 17 00:00:00 2001 From: Peter Oyelegbin Date: Sun, 2 Mar 2025 18:38:54 +0100 Subject: [PATCH 6/6] fix repo conflict --- api/v1/schemas/plans.py | 6 +++--- tests/v1/billing_plan/test_billing_plan.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/v1/schemas/plans.py b/api/v1/schemas/plans.py index e6244a61c..18bddb4a3 100644 --- a/api/v1/schemas/plans.py +++ b/api/v1/schemas/plans.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, validator from typing import List, Optional from datetime import datetime @@ -14,14 +14,14 @@ class CreateBillingPlanSchema(BaseModel): organisation_id: str features: List[str] - @field_validator("price") + @validator("price") def adjust_price(cls, value, values): duration = values.get("duration") if duration == "yearly": value = value * 12 * 0.8 # Multiply by 12 and apply a 20% discount return value - @field_validator("duration") + @validator("duration") def validate_duration(cls, value): v = value.lower() if v not in ["monthly", "yearly"]: diff --git a/tests/v1/billing_plan/test_billing_plan.py b/tests/v1/billing_plan/test_billing_plan.py index 1a4d82ce1..131344f05 100644 --- a/tests/v1/billing_plan/test_billing_plan.py +++ b/tests/v1/billing_plan/test_billing_plan.py @@ -78,7 +78,7 @@ def test_create_exisiting_plans(mock_user_service, mock_db_session): "currency": "Naira", "features": ["Multiple team", "Premium support"], } - + print(data) response = client.post( "/api/v1/organisations/billing-plans", headers={"Authorization": f"Bearer {access_token}"},