From 04e334a0eea719ce2722f1db79165e93d4369c43 Mon Sep 17 00:00:00 2001 From: matdev83 <211248003+matdev83@users.noreply.github.com> Date: Mon, 13 Oct 2025 00:40:53 +0200 Subject: [PATCH] Clamp recent usage limit requests --- src/constants.py | 6 ++ src/core/app/controllers/usage_controller.py | 61 ++++++++++++++++++- src/core/common/usage_limits.py | 31 ++++++++++ src/core/services/usage_tracking_service.py | 30 ++++++++- .../test_usage_controller_comprehensive.py | 32 +++++----- ...st_usage_tracking_service_comprehensive.py | 49 +++++++++++++++ 6 files changed, 190 insertions(+), 19 deletions(-) create mode 100644 src/core/common/usage_limits.py diff --git a/src/constants.py b/src/constants.py index 4422c0d8f..1642eb80a 100644 --- a/src/constants.py +++ b/src/constants.py @@ -1 +1,7 @@ +"""Project-wide constants used across multiple modules.""" + DEFAULT_COMMAND_PREFIX: str = "!/" +"""Default command prefix for interactive commands.""" + +MAX_RECENT_USAGE_RECORDS: int = 1000 +"""Maximum number of recent usage records that can be requested at once.""" diff --git a/src/core/app/controllers/usage_controller.py b/src/core/app/controllers/usage_controller.py index 0283155d4..ab95eae41 100644 --- a/src/core/app/controllers/usage_controller.py +++ b/src/core/app/controllers/usage_controller.py @@ -12,6 +12,8 @@ from src.core.di.services import get_or_build_service_provider from src.core.domain.usage_data import UsageData from src.core.interfaces.usage_tracking_interface import IUsageTrackingService +from src.constants import MAX_RECENT_USAGE_RECORDS +from src.core.common.usage_limits import normalize_recent_usage_limit logger = logging.getLogger(__name__) @@ -62,8 +64,31 @@ async def get_recent_usage( if not self.usage_service: return [] + try: + requested_limit = int(limit) + except (TypeError, ValueError): + requested_limit = 0 + + normalized_limit = normalize_recent_usage_limit(requested_limit) + + if normalized_limit == 0: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Recent usage requested with limit=%s; returning empty result", limit + ) + return [] + + if normalized_limit < requested_limit: + if logger.isEnabledFor(logging.INFO): + logger.info( + "Recent usage limit clamped from %s to %s (max=%s)", + limit, + normalized_limit, + MAX_RECENT_USAGE_RECORDS, + ) + result = await self.usage_service.get_recent_usage( - session_id=session_id, limit=limit + session_id=session_id, limit=normalized_limit ) return result # type: ignore[no-any-return] @@ -92,7 +117,12 @@ async def get_usage_stats( @router.get("/recent", response_model=list[UsageData]) async def get_recent_usage( session_id: str | None = Query(None, description="Filter by session ID"), - limit: int = Query(100, description="Maximum number of records to return"), + limit: int = Query( + 100, + description="Maximum number of records to return", + ge=0, + le=MAX_RECENT_USAGE_RECORDS, + ), service_provider: Any = Depends(get_or_build_service_provider), ) -> list[UsageData]: """Get recent usage data. @@ -106,5 +136,30 @@ async def get_recent_usage( List of usage data entities """ usage_service = service_provider.get_required_service(IUsageTrackingService) - result = await usage_service.get_recent_usage(session_id=session_id, limit=limit) + try: + requested_limit = int(limit) + except (TypeError, ValueError): + requested_limit = 0 + + normalized_limit = normalize_recent_usage_limit(requested_limit) + + if normalized_limit == 0: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "API recent usage requested with limit=%s; returning empty result", limit + ) + return [] + + if normalized_limit < requested_limit: + if logger.isEnabledFor(logging.INFO): + logger.info( + "API recent usage limit clamped from %s to %s (max=%s)", + limit, + normalized_limit, + MAX_RECENT_USAGE_RECORDS, + ) + + result = await usage_service.get_recent_usage( + session_id=session_id, limit=normalized_limit + ) return result # type: ignore[no-any-return] diff --git a/src/core/common/usage_limits.py b/src/core/common/usage_limits.py new file mode 100644 index 000000000..95461d9d0 --- /dev/null +++ b/src/core/common/usage_limits.py @@ -0,0 +1,31 @@ +"""Utilities for normalizing usage-related request parameters.""" + +from __future__ import annotations + +from typing import Any + +from src.constants import MAX_RECENT_USAGE_RECORDS + + +def normalize_recent_usage_limit(limit: Any) -> int: + """Normalize the recent usage limit value to a safe, bounded integer. + + Args: + limit: The requested limit value that may come from untrusted sources. + + Returns: + A non-negative integer that does not exceed :data:`MAX_RECENT_USAGE_RECORDS`. + Invalid or non-positive values yield ``0`` so callers can short-circuit expensive + repository lookups. + """ + + try: + numeric_limit = int(limit) + except (TypeError, ValueError): + return 0 + + if numeric_limit <= 0: + return 0 + + return min(numeric_limit, MAX_RECENT_USAGE_RECORDS) + diff --git a/src/core/services/usage_tracking_service.py b/src/core/services/usage_tracking_service.py index ec3ca7976..ce7a1a53b 100644 --- a/src/core/services/usage_tracking_service.py +++ b/src/core/services/usage_tracking_service.py @@ -30,6 +30,8 @@ def headers(self) -> dict[str, str]: ... def media_type(self) -> str: ... +from src.constants import MAX_RECENT_USAGE_RECORDS +from src.core.common.usage_limits import normalize_recent_usage_limit from src.core.domain.usage_data import UsageData from src.core.domain.usage_stats import ModelUsageStats, UsageStatsResponse from src.core.interfaces.repositories_interface import IUsageRepository @@ -341,11 +343,37 @@ async def get_recent_usage( Returns: List of usage data entities """ + try: + requested_limit = int(limit) + except (TypeError, ValueError): + requested_limit = 0 + + normalized_limit = normalize_recent_usage_limit(requested_limit) + + if normalized_limit == 0: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Recent usage requested with limit=%s; returning empty result", limit + ) + return [] + + if normalized_limit < requested_limit: + if logger.isEnabledFor(logging.INFO): + logger.info( + "Recent usage limit clamped from %s to %s (max=%s)", + limit, + normalized_limit, + MAX_RECENT_USAGE_RECORDS, + ) + if session_id: data = await self._repository.get_by_session_id(session_id) else: data = await self._repository.get_all() # Sort by timestamp (newest first) and limit + if not data: + return [] + sorted_data = sorted(data, key=lambda x: x.timestamp, reverse=True) - return sorted_data[:limit] + return sorted_data[:normalized_limit] diff --git a/tests/unit/core/app/controllers/test_usage_controller_comprehensive.py b/tests/unit/core/app/controllers/test_usage_controller_comprehensive.py index d102ce508..74887cb07 100644 --- a/tests/unit/core/app/controllers/test_usage_controller_comprehensive.py +++ b/tests/unit/core/app/controllers/test_usage_controller_comprehensive.py @@ -8,6 +8,7 @@ from unittest.mock import AsyncMock import pytest +from src.constants import MAX_RECENT_USAGE_RECORDS from src.core.app.controllers.usage_controller import UsageController from src.core.domain.usage_data import UsageData from src.core.interfaces.usage_tracking_interface import IUsageTrackingService @@ -228,7 +229,7 @@ async def test_get_recent_usage_large_limit( assert result == mock_usage_data mock_usage_service.get_recent_usage.assert_called_once_with( - session_id=None, limit=10000 + session_id=None, limit=MAX_RECENT_USAGE_RECORDS ) @pytest.mark.asyncio @@ -236,15 +237,21 @@ async def test_get_recent_usage_zero_limit( self, controller: UsageController, mock_usage_service: IUsageTrackingService ) -> None: """Test get_recent_usage with zero limit.""" - mock_usage_data = [] - mock_usage_service.get_recent_usage.return_value = mock_usage_data - result = await controller.get_recent_usage(limit=0) - assert result == mock_usage_data - mock_usage_service.get_recent_usage.assert_called_once_with( - session_id=None, limit=0 - ) + assert result == [] + mock_usage_service.get_recent_usage.assert_not_called() + + @pytest.mark.asyncio + async def test_get_recent_usage_negative_limit( + self, controller: UsageController, mock_usage_service: IUsageTrackingService + ) -> None: + """Test get_recent_usage with negative limit.""" + + result = await controller.get_recent_usage(limit=-5) + + assert result == [] + mock_usage_service.get_recent_usage.assert_not_called() @pytest.mark.asyncio async def test_service_error_handling_stats( @@ -384,15 +391,10 @@ async def test_get_recent_usage_negative_limit( self, controller: UsageController, mock_usage_service: IUsageTrackingService ) -> None: """Test get_recent_usage with negative limit value.""" - mock_usage_data = [] - mock_usage_service.get_recent_usage.return_value = mock_usage_data - result = await controller.get_recent_usage(limit=-10) - assert result == mock_usage_data - mock_usage_service.get_recent_usage.assert_called_once_with( - session_id=None, limit=-10 - ) + assert result == [] + mock_usage_service.get_recent_usage.assert_not_called() @pytest.mark.asyncio async def test_get_usage_stats_zero_days( diff --git a/tests/unit/core/services/test_usage_tracking_service_comprehensive.py b/tests/unit/core/services/test_usage_tracking_service_comprehensive.py index 4e8b14b58..b415756fb 100644 --- a/tests/unit/core/services/test_usage_tracking_service_comprehensive.py +++ b/tests/unit/core/services/test_usage_tracking_service_comprehensive.py @@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, patch import pytest +from src.constants import MAX_RECENT_USAGE_RECORDS from src.core.domain.usage_data import UsageData from src.core.interfaces.repositories_interface import IUsageRepository from src.core.services.usage_tracking_service import UsageTrackingService @@ -346,6 +347,28 @@ async def test_get_recent_usage( assert result == mock_usage_data mock_repository.get_by_session_id.assert_called_once_with("session1") + @pytest.mark.asyncio + async def test_get_recent_usage_zero_limit( + self, service: UsageTrackingService, mock_repository: IUsageRepository + ) -> None: + """Recent usage should return empty results for non-positive limits.""" + + result = await service.get_recent_usage(limit=0) + + assert result == [] + mock_repository.get_all.assert_not_called() + + @pytest.mark.asyncio + async def test_get_recent_usage_negative_limit( + self, service: UsageTrackingService, mock_repository: IUsageRepository + ) -> None: + """Negative limits should be treated as zero to avoid large responses.""" + + result = await service.get_recent_usage(limit=-10) + + assert result == [] + mock_repository.get_all.assert_not_called() + @pytest.mark.asyncio async def test_get_recent_usage_defaults( self, service: UsageTrackingService, mock_repository: IUsageRepository @@ -359,6 +382,32 @@ async def test_get_recent_usage_defaults( assert result == mock_usage_data mock_repository.get_all.assert_called_once() + @pytest.mark.asyncio + async def test_get_recent_usage_large_limit_is_clamped( + self, service: UsageTrackingService, mock_repository: IUsageRepository + ) -> None: + """Large limits should be clamped to protect against excessive workloads.""" + + mock_usage_data = [ + UsageData( + id=str(index), + session_id="session", + model="model", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost=0.0, + timestamp=datetime.now(timezone.utc) + timedelta(seconds=index), + ) + for index in range(MAX_RECENT_USAGE_RECORDS + 50) + ] + mock_repository.get_all.return_value = mock_usage_data + + result = await service.get_recent_usage(limit=10_000) + + assert len(result) == MAX_RECENT_USAGE_RECORDS + mock_repository.get_all.assert_called_once() + @pytest.mark.asyncio async def test_get_recent_usage_with_session_id( self, service: UsageTrackingService, mock_repository: IUsageRepository