diff --git a/.gitignore b/.gitignore index 7c3c0c7a2..453245ef6 100644 --- a/.gitignore +++ b/.gitignore @@ -171,8 +171,9 @@ jeff.py # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ alembic/versions -**/.DS_Store \ No newline at end of file +**/.DS_Store.idea/ +create_superadmin.py diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index 5f61a68fb..000000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,10 +0,0 @@ -# Default ignored files -/shelf/ -/workspace.xml -# Editor-based HTTP Client requests -/httpRequests/ - -# Datasource local storage ignored files -/dataSources/ -/dataSources.local.xml - diff --git a/.idea/git_toolbox_blame.xml b/.idea/git_toolbox_blame.xml deleted file mode 100644 index 7dc124965..000000000 --- a/.idea/git_toolbox_blame.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - \ No newline at end of file diff --git a/.idea/hng_boilerplate_python_fastapi.iml b/.idea/hng_boilerplate_python_fastapi.iml deleted file mode 100644 index aad402c4e..000000000 --- a/.idea/hng_boilerplate_python_fastapi.iml +++ /dev/null @@ -1,10 +0,0 @@ - - - - - - - \ No newline at end of file diff --git a/.idea/hng_boilerplate_python_fastapi_web.iml b/.idea/hng_boilerplate_python_fastapi_web.iml deleted file mode 100644 index 48d2ec6ad..000000000 --- a/.idea/hng_boilerplate_python_fastapi_web.iml +++ /dev/null @@ -1,17 +0,0 @@ - - - - - - - - - - - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml deleted file mode 100644 index aa251da3f..000000000 --- a/.idea/inspectionProfiles/Project_Default.xml +++ /dev/null @@ -1,54 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml deleted file mode 100644 index 105ce2da2..000000000 --- a/.idea/inspectionProfiles/profiles_settings.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/material_theme_project_new.xml b/.idea/material_theme_project_new.xml deleted file mode 100644 index 7514ff5f4..000000000 --- a/.idea/material_theme_project_new.xml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml deleted file mode 100644 index 21609fef7..000000000 --- a/.idea/misc.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml deleted file mode 100644 index 6c4a82298..000000000 --- a/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 35eb1ddfb..000000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/api/utils/settings.py b/api/utils/settings.py index 9b064f608..b9d9b10f7 100644 --- a/api/utils/settings.py +++ b/api/utils/settings.py @@ -34,8 +34,8 @@ class Settings(BaseSettings): TWILIO_AUTH_TOKEN: str = config("TWILIO_AUTH_TOKEN") TWILIO_PHONE_NUMBER: str = config("TWILIO_PHONE_NUMBER") - APP_NAME: str = config("APP_NAME") - + APP_NAME: str = config("APP_NAME", default="HNG Boilerplate") + # Base URLs ANCHOR_PYTHON_BASE_URL: str = config( "ANCHOR_PYTHON_BASE_URL", default="https://anchor-python.teams.hng.tech" diff --git a/api/v1/routes/user.py b/api/v1/routes/user.py index cec064283..a8113d8a9 100644 --- a/api/v1/routes/user.py +++ b/api/v1/routes/user.py @@ -117,33 +117,24 @@ def delete_user( async def get_users( current_user: Annotated[User, Depends(user_service.get_current_super_admin)], db: Annotated[Session, Depends(get_db)], - page: int = 1, per_page: int = 10, - is_active: Optional[bool] = Query(None), - is_deleted: Optional[bool] = Query(None), - is_verified: Optional[bool] = Query(None), - is_superadmin: Optional[bool] = Query(None) + page: int = Query(1, ge=1, description="Page number, starting from 1"), + limit: int = Query(20, ge=1, le=50, description="Users per page, max 50"), + search: Optional[str] = Query(None, description="Search term for first_name, last_name, or email"), + is_active: Optional[bool] = Query(None, description="Filter by active status"), ): """ - Retrieves all users. + Retrieves all users with search, filtering, and pagination. Args: - current_user: The current user(admin) making the request + current_user: The current superadmin making the request db: database Session object - page: the page number - per_page: the maximum size of users for each page - is_active: boolean to filter active users - is_deleted: boolean to filter deleted users - is_verified: boolean to filter verified users - is_superadmin: boolean to filter users that are super admin + page: page number (default: 1) + limit: maximum users per page (default: 20, max: 50) + search: term to search in first_name, last_name, or email + is_active: filter by active status Returns: - UserData + AllUsersResponse """ - query_params = { - 'is_active': is_active, - 'is_deleted': is_deleted, - 'is_verified': is_verified, - 'is_superadmin': is_superadmin, - } - return user_service.fetch_all(db, page, per_page, **query_params) + return user_service.fetch_all(db, page, limit, search, is_active) @user_router.post("", status_code=status.HTTP_201_CREATED, response_model=AdminCreateUserResponse) def admin_registers_user( diff --git a/api/v1/schemas/user.py b/api/v1/schemas/user.py index 095135e11..d8c456527 100644 --- a/api/v1/schemas/user.py +++ b/api/v1/schemas/user.py @@ -1,23 +1,18 @@ from email_validator import validate_email, EmailNotValidError import dns.resolver from datetime import datetime -from typing import (Optional, Union, - List, Annotated, Dict, - Literal) +from typing import Optional, Union, List, Annotated, Dict, Literal -from pydantic import (BaseModel, EmailStr, - field_validator, ConfigDict, - StringConstraints, - model_validator) - -from pydantic import Field # Added this import +from pydantic import ( + BaseModel, EmailStr, field_validator, ConfigDict, + StringConstraints, model_validator, Field +) def validate_mx_record(domain: str): """ Validate mx records for email """ try: - # Try to resolve the MX record for the domain mx_records = dns.resolver.resolve(domain, 'MX') return True if mx_records else False except dns.resolver.NoAnswer: @@ -29,7 +24,6 @@ def validate_mx_record(domain: str): class UserBase(BaseModel): """Base user schema""" - id: str first_name: str last_name: str @@ -39,17 +33,15 @@ class UserBase(BaseModel): class UserEmailSender(BaseModel): email: EmailStr - class UserCreate(BaseModel): """Schema to create a user""" - email: EmailStr password: Annotated[ - str, StringConstraints( - min_length=8, - max_length=64, - strip_whitespace=True - ) + str, StringConstraints(min_length=8, max_length=64, strip_whitespace=True) + ] + confirm_password: Annotated[ + str, StringConstraints(min_length=8, max_length=64, strip_whitespace=True), + Field(exclude=True) ] """Added the confirm_password field to UserCreate Model""" confirm_password: Annotated[ @@ -62,18 +54,10 @@ class UserCreate(BaseModel): Field(exclude=True) # exclude confirm_password field ] first_name: Annotated[ - str, StringConstraints( - min_length=3, - max_length=30, - strip_whitespace=True - ) + str, StringConstraints(min_length=3, max_length=30, strip_whitespace=True) ] last_name: Annotated[ - str, StringConstraints( - min_length=3, - max_length=30, - strip_whitespace=True - ) + str, StringConstraints(min_length=3, max_length=30, strip_whitespace=True) ] @model_validator(mode='before') @@ -83,10 +67,10 @@ def validate_password(cls, values: dict): Validates passwords """ password = values.get('password') + confirm_password = values.get('confirm_password') # gets the confirm password email = values.get("email") - # constraints for password if not any(c.islower() for c in password): raise ValueError("password must include at least one lowercase character") if not any(c.isupper() for c in password): @@ -96,13 +80,14 @@ def validate_password(cls, values: dict): if not any(c in ['!','@','#','$','%','&','*','?','_','-'] for c in password): raise ValueError("password must include at least one special character") + """Confirm Password Validation""" if not confirm_password: raise ValueError("Confirm password field is required") elif password != confirm_password: raise ValueError("Passwords do not match") - + try: email = validate_email(email, check_deliverability=True) if email.domain.count(".com") > 1: @@ -117,9 +102,8 @@ def validate_password(cls, values: dict): return values class UserUpdate(BaseModel): - - first_name : Optional[str] = None - last_name : Optional[str] = None + first_name: Optional[str] = None + last_name: Optional[str] = None class UserData(BaseModel): """ @@ -167,7 +151,7 @@ class ProfileData(BaseModel): bio: Optional[str] = None phone_number: Optional[str] = None avatar_url: Optional[str] = None - recovery_email: Optional[EmailStr] + recovery_email: Optional[EmailStr] = None model_config = ConfigDict(from_attributes=True) @@ -197,22 +181,9 @@ class AuthMeResponse(BaseModel): data: Dict[Literal["user", "organisations", "profile"], Union[UserData2, List[OrganisationData], ProfileData]] - -class AllUsersResponse(BaseModel): - """ - Schema for all users - """ - message: str - status_code: int - status: str - page: int - per_page: int - total: int - data: Union[List[UserData], List[None]] - class AdminCreateUser(BaseModel): """ - Schema for admin to create a users + Schema for admin to create users """ email: EmailStr first_name: str @@ -225,7 +196,6 @@ class AdminCreateUser(BaseModel): model_config = ConfigDict(from_attributes=True) - class AdminCreateUserResponse(BaseModel): """ Schema response for user created by admin @@ -238,7 +208,8 @@ class AdminCreateUserResponse(BaseModel): class LoginRequest(BaseModel): email: EmailStr password: str - totp_code: str | None = None + + totp_code: Optional[str] = None @model_validator(mode='before') @classmethod @@ -252,7 +223,6 @@ def validate_password(cls, values: dict): email = values.get("email") totp_code = values.get("totp_code") - # constraints for password if not any(c.islower() for c in password): raise ValueError("password must include at least one lowercase character") if not any(c.isupper() for c in password): @@ -275,13 +245,12 @@ def validate_password(cls, values: dict): if totp_code: from api.v1.schemas.totp_device import TOTPTokenSchema - + if not TOTPTokenSchema.validate_totp_code(totp_code): raise ValueError("totp code must be a 6-digit number") return values - class EmailRequest(BaseModel): email: EmailStr @@ -304,46 +273,31 @@ def validate_email(cls, values: dict): raise ValueError(exc) from exc return values - class Token(BaseModel): token: str - class TokenData(BaseModel): """Schema to structure token data""" - id: Optional[str] - class DeactivateUserSchema(BaseModel): """Schema for deactivating a user""" - reason: Optional[str] = None confirmation: bool - class ChangePasswordSchema(BaseModel): """Schema for changing password of a user""" - old_password: Annotated[ Optional[str], - StringConstraints(min_length=8, - max_length=64, - strip_whitespace=True) + StringConstraints(min_length=8, max_length=64, strip_whitespace=True) ] = None - new_password: Annotated[ str, - StringConstraints(min_length=8, - max_length=64, - strip_whitespace=True) + StringConstraints(min_length=8, max_length=64, strip_whitespace=True) ] - confirm_new_password: Annotated[ str, - StringConstraints(min_length=8, - max_length=64, - strip_whitespace=True) + StringConstraints(min_length=8, max_length=64, strip_whitespace=True) ] @model_validator(mode='before') @@ -358,7 +312,6 @@ def validate_password(cls, values: dict): if (old_password and old_password.strip() == '') or old_password == '': values['old_password'] = None - # constraints for old_password if old_password and old_password.strip(): if not any(c.islower() for c in old_password): raise ValueError("Old password must include at least one lowercase character") @@ -369,7 +322,6 @@ def validate_password(cls, values: dict): if not any(c in ['!','@','#','$','%','&','*','?','_','-'] for c in old_password): raise ValueError("Old password must include at least one special character") - # constraints for new_password if not any(c.islower() for c in new_password): raise ValueError("New password must include at least one lowercase character") if not any(c.isupper() for c in new_password): @@ -384,17 +336,13 @@ def validate_password(cls, values: dict): return values - class ChangePwdRet(BaseModel): - """schema for returning change password response""" - + """Schema for returning change password response""" status_code: int message: str - class MagicLinkRequest(BaseModel): """Schema for magic link creation""" - email: EmailStr @model_validator(mode='before') @@ -416,15 +364,12 @@ def validate_email(cls, values: dict): raise ValueError(exc) from exc return values - class MagicLinkResponse(BaseModel): - """Schema for magic link respone""" - + """Schema for magic link response""" message: str class UserRoleSchema(BaseModel): """Schema for user role""" - role: str user_id: str org_id: str @@ -437,4 +382,24 @@ def role_validator(cls, value): """ if value not in ["admin", "user", "guest", "owner"]: raise ValueError("Role has to be one of admin, guest, user, or owner") + return value + +class Pagination(BaseModel): + """Schema for pagination details""" + page: int + limit: int + total_pages: int + total_users: int + +class AllUsersResponse(BaseModel): + """ + Schema for all users + """ + message: str + status_code: int + status: str + data: Dict[str, Union[List[UserData], Pagination]] + + model_config = ConfigDict(from_attributes=True) + diff --git a/api/v1/services/user.py b/api/v1/services/user.py index 48c665c80..575625a2f 100644 --- a/api/v1/services/user.py +++ b/api/v1/services/user.py @@ -8,7 +8,7 @@ from jose import JWTError, jwt from fastapi import Depends, HTTPException, Request from sqlalchemy.orm import Session -from sqlalchemy import desc +from sqlalchemy import desc, or_, func from passlib.context import CryptContext from datetime import datetime, timedelta @@ -36,68 +36,101 @@ class UserService(Service): def fetch_all( self, db: Session, - page: int, - per_page: int, + page: int = 1, + limit: int = 20, + search: Optional[str] = None, + is_active: Optional[bool] = None, **query_params: Optional[Any], ): """ - Fetch all users + Retrieves all users with optional search, filtering, and pagination. + Args: - db: database Session object - page: page number - per_page: max number of users in a page - query_params: params to filter by + db: SQLAlchemy database session. + page: Page number to retrieve (default: 1, minimum: 1). + limit: Number of users per page (default: 20, range: 1-50). + search: Term to filter users by first_name, last_name, or email (case-insensitive). + is_active: Boolean to filter users by active status. + **query_params: Additional query parameters (currently unused). + + Returns: + AllUsersResponse: Object containing filtered users and pagination metadata. + + Raises: + HTTPException: If is_active is not a boolean value (422 Unprocessable Entity). """ - per_page = min(per_page, 10) - - # Enable filter by query parameter - filters = [] - if all(query_params): - # Validate boolean query parameters - for param, value in query_params.items(): - if value is not None and not isinstance(value, bool): - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=f"Invalid value for '{param}'. Must be a boolean.", - ) - if value == None: - continue - if hasattr(User, param): - filters.append(getattr(User, param) == value) + # Restrict pagination parameters to valid ranges + limit = min(max(limit, 1), 50) # Caps limit between 1 and 50 + page = max(page, 1) # Ensures page is at least 1 + + # Initialize base query for User table query = db.query(User) + + # Filter by search term across first_name, last_name, and email + if search: + search_term = f"%{search.strip().lower()}%" + query = query.filter( + or_( + func.lower(User.first_name).like(search_term), + func.lower(User.last_name).like(search_term), + func.lower(User.email).like(search_term), + ) + ) + + # Filter by is_active status if provided + if is_active is not None: + if not isinstance(is_active, bool): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="is_active must be a boolean", + ) + query = query.filter(User.is_active == is_active) + + # Get total user count before applying pagination total_users = query.count() - if filters: - query = query.filter(*filters) - total_users = query.count() - all_users: list = ( + # Apply pagination and order by creation date (descending) + users = ( query.order_by(desc(User.created_at)) - .limit(per_page) - .offset((page - 1) * per_page) + .limit(limit) + .offset((page - 1) * limit) .all() ) - return self.all_users_response(all_users, total_users, page, per_page) + # Compute total pages based on user count and limit + total_pages = (total_users + limit - 1) // limit + # Return structured response with users and pagination details + return self.all_users_response(users, total_users, page, limit, total_pages) + def all_users_response( - self, users: list, total_users: int, page: int, per_page: int + self, users: list, total_users: int, page: int, limit: int, total_pages: int ): """ Generates a response for all users Args: - users: a list containing user objects + users: list of user objects total_users: total number of users + page: current page + limit: users per page + total_pages: total number of pages """ if not users or len(users) == 0: return user.AllUsersResponse( - message="No User(s) for this query", + message="No User(s) found for this query", status="success", status_code=200, - page=page, - per_page=per_page, - total=0, - data=[], + data={ + "users": [], + "pagination": { + "page": page, + "limit": limit, + "total_pages": total_pages, + "total_users": total_users, + }, + }, ) + all_users = [ user.UserData.model_validate(usr, from_attributes=True) for usr in users @@ -106,10 +139,15 @@ def all_users_response( message="Users successfully retrieved", status="success", status_code=200, - page=page, - per_page=per_page, - total=total_users, - data=all_users, + data={ + "users": all_users, + "pagination": { + "page": page, + "limit": limit, + "total_pages": total_pages, + "total_users": total_users, + }, + }, ) def fetch(self, db: Session, id): diff --git a/tests/v1/user/test_get_all_users.py b/tests/v1/user/test_get_all_users.py index 84166cfb6..00becca04 100644 --- a/tests/v1/user/test_get_all_users.py +++ b/tests/v1/user/test_get_all_users.py @@ -3,36 +3,30 @@ from datetime import datetime from sqlalchemy.orm import Session from unittest.mock import MagicMock, patch -from main import app # Adjust this import according to your project structure +from main import app from api.db.database import get_db - -from api.v1.schemas.user import AllUsersResponse, UserData +from api.v1.schemas.user import AllUsersResponse, UserData, Pagination from api.v1.models.user import User from api.v1.services.user import UserService - client = TestClient(app) - @pytest.fixture def mock_db_session(): session = MagicMock(spec=Session) yield session - @pytest.fixture def user_service_mock(): return MagicMock() - -# Overriding the dependency +# Override dependencies @pytest.fixture(autouse=True) def override_get_db(mock_db_session): app.dependency_overrides[get_db] = lambda: mock_db_session - @pytest.fixture(autouse=True) -def override_User_services(user_service_mock): +def override_user_service(user_service_mock): app.dependency_overrides[UserService] = lambda: user_service_mock @pytest.fixture @@ -44,7 +38,7 @@ def mock_superadmin(): @pytest.fixture def mock_token_verification(): with patch("api.v1.services.user.UserService.verify_access_token") as mock: - mock.return_value = MagicMock(id="superadmin_id", is_superadmin=True) + mock.return_value = MagicMock(id="superadmin_id") yield mock def test_get_all_users(mock_db_session, user_service_mock, mock_superadmin, mock_token_verification): @@ -54,86 +48,127 @@ def test_get_all_users(mock_db_session, user_service_mock, mock_superadmin, mock created_at = datetime.now() updated_at = datetime.now() page = 1 - per_page = 10 + limit = 10 # Updated from per_page to limit mock_users = [ - User(id='admin_id', email='admin@email.com', first_name='admin', - last_name='admin', password='super_admin', created_at=created_at, - updated_at=updated_at, is_active=True, is_deleted=False, - is_verified=True, is_superadmin=False), - User(id='user_id', email='user@email.com', first_name='admin', - last_name='admin', password='my_password', created_at=created_at, updated_at=updated_at, is_active=True, is_deleted=False, - is_verified=True, is_superadmin=False) + User( + id='admin_id', + email='admin@email.com', + first_name='admin', + last_name='admin', + password='super_admin', + created_at=created_at, + updated_at=updated_at, + is_active=True, + is_deleted=False, + is_verified=True, + is_superadmin=False + ), + User( + id='user_id', + email='user@email.com', + first_name='admin', + last_name='admin', + password='my_password', + created_at=created_at, + updated_at=updated_at, + is_active=True, + is_deleted=False, + is_verified=True, + is_superadmin=False + ) ] - + (mock_db_session .query.return_value .order_by.return_value .limit.return_value - .offset.return_value. - all.return_value) = mock_users + .offset.return_value + .all.return_value) = mock_users mock_db_session.query.return_value.count.return_value = len(mock_users) - + + # Updated to match new AllUsersResponse structure user_service_mock.fetch_all.return_value = AllUsersResponse( message='Users successfully retrieved', status='success', - page=page, - per_page=per_page, status_code=200, - total=len(mock_users), - data=[UserData( - id=user.id, - email=user.email, - first_name=user.first_name, - last_name=user.last_name, - is_active=True, - is_deleted=False, - is_verified=True, - is_superadmin=False, - created_at=user.created_at, - updated_at=updated_at - ) for user in mock_users] + data={ + "users": [ + UserData( + id=user.id, + email=user.email, + first_name=user.first_name, + last_name=user.last_name, + is_active=user.is_active, + is_deleted=user.is_deleted, + is_verified=user.is_verified, + is_superadmin=user.is_superadmin, + created_at=user.created_at, + updated_at=user.updated_at + ) for user in mock_users + ], + "pagination": Pagination( + page=page, + limit=limit, + total_pages=1, + total_users=len(mock_users) + ) + } ) - headers = { - 'Authorization': f'Bearer fake_token' - } - response = client.get(f"/api/v1/users?page={page}&per_page={per_page}", headers=headers) - print(response.json()) - - assert response.json().get('status_code') == 200 + + headers = {'Authorization': 'Bearer fake_token'} + response = client.get(f"/api/v1/users?page={page}&limit={limit}", headers=headers) - assert response.json() == { - 'message': 'Users successfully retrieved', - 'status': 'success', - 'status_code': 200, - 'page': page, - 'per_page': per_page, - 'total': len(mock_users), - 'data': [ - { - 'id': mock_users[0].id, - 'email': mock_users[0].email, - 'first_name': mock_users[0].first_name, - 'last_name': mock_users[0].last_name, - 'is_active': True, - 'is_deleted': False, - 'is_verified': True, - 'is_superadmin': False, - 'created_at': mock_users[0].created_at.isoformat(), - 'updated_at': updated_at.isoformat() - }, - { - 'id': mock_users[1].id, - 'email': mock_users[1].email, - 'first_name': mock_users[1].first_name, - 'last_name': mock_users[1].last_name, - 'is_active': True, - 'is_deleted': False, - 'is_verified': True, - 'is_superadmin': False, - 'created_at': mock_users[1].created_at.isoformat(), - 'updated_at': updated_at.isoformat() + assert response.status_code == 200 + response_json = response.json() + assert response_json["status_code"] == 200 + assert response_json["status"] == "success" + assert response_json["message"] == "Users successfully retrieved" + assert len(response_json["data"]["users"]) == len(mock_users) + assert response_json["data"]["pagination"]["page"] == page + assert response_json["data"]["pagination"]["limit"] == limit + assert response_json["data"]["pagination"]["total_users"] == len(mock_users) + assert response_json["data"]["pagination"]["total_pages"] == 1 + + # Optional: Detailed JSON comparison + expected_json = { + "message": "Users successfully retrieved", + "status": "success", + "status_code": 200, + "data": { + "users": [ + { + "id": mock_users[0].id, + "email": mock_users[0].email, + "first_name": mock_users[0].first_name, + "last_name": mock_users[0].last_name, + "is_active": mock_users[0].is_active, + "is_deleted": mock_users[0].is_deleted, + "is_verified": mock_users[0].is_verified, + "is_superadmin": mock_users[0].is_superadmin, + "created_at": mock_users[0].created_at.isoformat(), + "updated_at": mock_users[0].updated_at.isoformat() + }, + { + "id": mock_users[1].id, + "email": mock_users[1].email, + "first_name": mock_users[1].first_name, + "last_name": mock_users[1].last_name, + "is_active": mock_users[1].is_active, + "is_deleted": mock_users[1].is_deleted, + "is_verified": mock_users[1].is_verified, + "is_superadmin": mock_users[1].is_superadmin, + "created_at": mock_users[1].created_at.isoformat(), + "updated_at": mock_users[1].updated_at.isoformat() + } + ], + "pagination": { + "page": page, + "limit": limit, + "total_pages": 1, + "total_users": len(mock_users) } - ] + } } + assert response_json == expected_json \ No newline at end of file diff --git a/tests/v1/user/test_user_service.py b/tests/v1/user/test_user_service.py new file mode 100644 index 000000000..9a160f334 --- /dev/null +++ b/tests/v1/user/test_user_service.py @@ -0,0 +1,90 @@ +import pytest +from sqlalchemy.orm import Session +from api.v1.models.user import User +from api.v1.services.user import UserService +from unittest.mock import MagicMock +from datetime import datetime + +@pytest.fixture +def db_session(): + return MagicMock(spec=Session) + +@pytest.fixture +def user_service(): + return UserService() + +def test_fetch_all_search(db_session, user_service): + # Mock users with all required fields + created_at = datetime.now() + updated_at = datetime.now() + user1 = User( + id="1", + first_name="John", + last_name="Doe", + email="john.doe@example.com", + is_active=True, + is_deleted=False, + is_verified=True, + is_superadmin=False, + created_at=created_at, + updated_at=updated_at + ) + db_session.query().filter().order_by().limit().offset().all.return_value = [user1] + db_session.query().filter().count.return_value = 1 + + response = user_service.fetch_all(db_session, page=1, limit=20, search="john") + assert response.status_code == 200 + assert len(response.data["users"]) == 1 + assert response.data["users"][0].first_name == "John" + assert response.data["pagination"].total_users == 1 # Access as attribute + assert response.data["pagination"].total_pages == 1 # Access as attribute + +def test_fetch_all_is_active_filter(db_session, user_service): + created_at = datetime.now() + updated_at = datetime.now() + user1 = User( + id="1", + first_name="Jane", + last_name="Doe", + email="jane.doe@example.com", + is_active=True, + is_deleted=False, + is_verified=True, + is_superadmin=False, + created_at=created_at, + updated_at=updated_at + ) + db_session.query().filter().order_by().limit().offset().all.return_value = [user1] + db_session.query().filter().count.return_value = 1 + + response = user_service.fetch_all(db_session, page=1, limit=20, is_active=True) + assert response.status_code == 200 + assert len(response.data["users"]) == 1 + assert response.data["users"][0].is_active is True + +def test_fetch_all_pagination(db_session, user_service): + created_at = datetime.now() + updated_at = datetime.now() + users = [ + User( + id=str(i), + first_name=f"User{i}", + last_name="Test", + email=f"user{i}@example.com", + is_active=True, + is_deleted=False, + is_verified=True, + is_superadmin=False, + created_at=created_at, + updated_at=updated_at + ) for i in range(25) + ] + db_session.query().order_by().limit().offset().all.return_value = users[:20] + db_session.query().count.return_value = 25 + + response = user_service.fetch_all(db_session, page=1, limit=20) + assert response.data["pagination"].page == 1 + assert response.data["pagination"].limit == 20 + assert response.data["pagination"].total_users == 25 + assert response.data["pagination"].total_pages == 2 + assert len(response.data["users"]) == 20 \ No newline at end of file