Skip to content
10 changes: 7 additions & 3 deletions api/utils/pagination.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Optional
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, Query
from api.db.database import Base

from api.utils.success_response import success_response
Expand All @@ -12,6 +12,7 @@ def paginated_response(
skip: int,
limit: int,
join: Optional[Any] = None,
query: Optional[Query] = None,
filters: Optional[Dict[str, Any]]=None
):

Expand All @@ -24,6 +25,7 @@ def paginated_response(
* skip- this is the number of items to skip before fetching the next page of data. This would also
be a query parameter
* join- this is an optional argument to join a table to the query
* query- this is an optional custom query to use instead of querying all items from the model.
* filters- this is an optional dictionary of filters to apply to the query

Example use:
Expand Down Expand Up @@ -61,7 +63,8 @@ def paginated_response(
```
'''

query = db.query(model)
if query is None:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when you set this condition what happens is query is not None?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohkay, your test fails

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If query is none, then it gets all products;
if query is not none, it gets the product related to that category.

query = db.query(model)

if join is not None:
query = query.join(join)
Expand All @@ -82,7 +85,8 @@ def paginated_response(

total = int(query.count())
results = jsonable_encoder(query.offset(skip).limit(limit).all())
total_pages = int(total / limit) + (total % limit > 0)
# total_pages = int(total / limit) + (total % limit > 0)
total_pages = (total + limit - 1) // limit

return success_response(
status_code=200,
Expand Down
38 changes: 22 additions & 16 deletions api/v1/routes/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@
from api.utils.pagination import paginated_response
from api.utils.success_response import success_response
from api.db.database import get_db
from api.v1.models.product import (
Product,
ProductFilterStatusEnum,
ProductStatusEnum,
)
from api.v1.models.product import Product, ProductCategory, ProductFilterStatusEnum, ProductStatusEnum
from api.v1.services.product import product_service, ProductCategoryService
from api.v1.schemas.product import (
ProductCategoryCreate,
Expand All @@ -38,20 +34,30 @@
"", response_model=success_response, status_code=200
)
async def get_all_products(
current_user: Annotated[
User, Depends(user_service.get_current_super_admin)
],
limit: Annotated[
int, Query(ge=1, description="Number of products per page")
] = 10,
skip: Annotated[
int, Query(ge=1, description="Page number (starts from 1)")
] = 0,
current_user: Annotated[User, Depends(user_service.get_current_super_admin)],
limit: Annotated[int, Query(
ge=1, description="Number of products per page")] = 10,
skip: Annotated[int, Query(
ge=1, description="Page number (starts from 1)")] = 1,
category: Annotated[Optional[str], Query(
description="Filter products by category name")] = None,
db: Session = Depends(get_db),
):
"""Endpoint to get all products. Only accessible to superadmin"""
"""
Endpoint to get all products. Only accessible to superadmin.
Optionally filter products by category.
"""
# Base query
query = db.query(Product)

# Apply category filter if provided
if category:
query = query.join(Product.category).filter(ProductCategory.name.ilike(f"%{category}%"))

# Calculate the number of items to skip based on the page number
items_to_skip = (skip - 1) * limit

return paginated_response(db=db, model=Product, limit=limit, skip=skip)
return paginated_response(db=db, model=Product, limit=limit, skip=items_to_skip, query=query)


# categories
Expand Down
7 changes: 3 additions & 4 deletions api/v1/schemas/product_comment.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional
from datetime import datetime


class ProductCommentBase(BaseModel):
content: str = Field(..., example="This is a comment")
content: str = Field(..., json_schema_extra={"example": "This is a comment"})
author: str


Expand All @@ -21,8 +21,7 @@ class ProductCommentInDB(ProductCommentBase):
created_at: datetime
updated_at: datetime

class Config:
orm_mode = True
model_config = ConfigDict(from_attributes=True)


class ProductCommentResponse(BaseModel):
Expand Down
6 changes: 3 additions & 3 deletions api/v1/schemas/stripe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional

from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, field_validator


class PaymentInfo(BaseModel):
Expand All @@ -9,13 +9,13 @@ class PaymentInfo(BaseModel):
exp_year: int
cvc: str = Field(..., min_length=3, max_length=4)

@validator('card_number')
@field_validator('card_number')
def card_number_validator(cls, v):
if not v.isdigit() or len(v) != 16:
raise ValueError('Card number must be 16 digits')
return v

@validator('cvc')
@field_validator('cvc')
def cvc_validator(cls, v):
if not v.isdigit() or not (3 <= len(v) <= 4):
raise ValueError('CVC must be 3 or 4 digits')
Expand Down
2 changes: 1 addition & 1 deletion tests/v1/billing_plan/test_billing_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_create_exisiting_plans(mock_user_service, mock_db_session):
"currency": "Naira",
"features": ["Multiple team", "Premium support"],
}

print(data)
response = client.post(
"/api/v1/organisations/billing-plans",
headers={"Authorization": f"Bearer {access_token}"},
Expand Down
209 changes: 209 additions & 0 deletions tests/v1/product/test_get_product_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from unittest.mock import MagicMock
from uuid_extensions import uuid7
from datetime import datetime, timezone, timedelta

from api.v1.models.organisation import Organisation
from api.v1.models.product import Product, ProductCategory
from api.v1.models.user import User
from main import app
from api.v1.routes.blog import get_db
from api.v1.services.user import user_service


# Mock database dependency
@pytest.fixture
def db_session_mock():
db_session = MagicMock(spec=Session)
return db_session


@pytest.fixture
def client(db_session_mock):
app.dependency_overrides[get_db] = lambda: db_session_mock
client = TestClient(app)
yield client
app.dependency_overrides = {}


# Mock user service dependency

user_id = uuid7()
org_id = uuid7()
product_id = uuid7()
category_id = uuid7()
timezone_offset = -8.0
tzinfo = timezone(timedelta(hours=timezone_offset))
timeinfo = datetime.now(tzinfo)
created_at = timeinfo
updated_at = timeinfo
access_token = user_service.create_access_token(str(user_id))
access_token2 = user_service.create_access_token(str(uuid7()))

# Create test user

user = User(
id=str(user_id),
email="[email protected]",
password="password123",
created_at=created_at,
updated_at=updated_at,
)

# Create test organisation

org = Organisation(
id=str(org_id),
name="hng",
email=None,
industry=None,
type=None,
country=None,
state=None,
address=None,
description=None,
created_at=created_at,
updated_at=updated_at,
)

# Create test category

category = ProductCategory(id=category_id, name="Electronics")

# Create test product

product = Product(
id=str(product_id),
name="prod one",
description="Test product",
price=125.55,
org_id=str(org_id),
quantity=50,
image_url="http://img",
category_id=str(category_id),
status="in_stock",
archived=False,
)


# Mock data for multiple products
products = [
Product(
id=str(uuid7()),
name="Smartphone",
description="A smartphone",
price=500.00,
org_id=str(org_id),
quantity=10,
image_url="http://img1",
category_id=str(category_id),
status="in_stock",
archived=False,
),
Product(
id=str(uuid7()),
name="Laptop",
description="A laptop",
price=1200.00,
org_id=str(org_id),
quantity=5,
image_url="http://img2",
category_id=str(category_id),
status="in_stock",
archived=False,
),
Product(
id=str(uuid7()),
name="T-Shirt",
description="A T-Shirt",
price=20.00,
org_id=str(org_id),
quantity=100,
image_url="http://img3",
category_id=str(uuid7()), # Different category
status="in_stock",
archived=False,
),
]


def test_get_products_filtered_by_category(client, db_session_mock):
# Mock the database query to return filtered products
db_session_mock.query().join().filter().offset().limit().all.return_value = [
products[0], products[1]]
db_session_mock.query().join().filter().count.return_value = 2 # Return an integer

headers = {"authorization": f"Bearer {access_token}"}
response = client.get(
"/api/v1/products?category=Electronics",
headers=headers
)

assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert len(data["data"]["items"]) == 2


def test_get_all_products_without_filter(client, db_session_mock):
# Mock the database query to return all products
db_session_mock.query().offset().limit().all.return_value = products
db_session_mock.query().count.return_value = 3

headers = {"authorization": f"Bearer {access_token}"}
response = client.get(
"/api/v1/products",
headers=headers
)

assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert len(data["data"]["items"]) == 3


def test_unauthorized_access(client, db_session_mock):
# Test unauthorized access (missing or invalid token)
response = client.get("/api/v1/products")
assert response.status_code == 401
assert response.json() == {
"status": False,
"status_code": 401,
"message": "Not authenticated"
}


def test_invalid_category_name(client, db_session_mock):
# Mock the database query to return no products for an invalid category
db_session_mock.query().join().filter().offset().limit().all.return_value = []
db_session_mock.query().join().filter().count.return_value = 0 # Return an integer

headers = {"authorization": f"Bearer {access_token}"}
response = client.get(
"/api/v1/products?category=InvalidCategory",
headers=headers
)

assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert len(data["data"]["items"]) == 0


def test_empty_results_for_valid_category(client, db_session_mock):
# Mock the database query to return no products for a valid but unused category
db_session_mock.query().join().filter().offset().limit().all.return_value = []
db_session_mock.query().join().filter().count.return_value = 0 # Return an integer

headers = {"authorization": f"Bearer {access_token}"}
response = client.get(
"/api/v1/products?category=Furniture",
headers=headers
)

assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert len(data["data"]["items"]) == 0