diff --git a/src/connectors/gemini_cloud_project.py b/src/connectors/gemini_cloud_project.py index bebf5b8f..c9bc5a6c 100644 --- a/src/connectors/gemini_cloud_project.py +++ b/src/connectors/gemini_cloud_project.py @@ -43,6 +43,7 @@ # mypy: disable-error-code="no-untyped-call,no-untyped-def,no-any-return,has-type,var-annotated" import asyncio +import contextlib import json import logging import os @@ -836,6 +837,9 @@ async def _validate_project_access(self) -> None: if logger.isEnabledFor(logging.ERROR): logger.error(f"Failed to validate project access: {e}", exc_info=True) raise + finally: + with contextlib.suppress(Exception): + auth_session.close() async def _resolve_gemini_api_config( self, @@ -870,6 +874,7 @@ async def _resolve_gemini_api_config( async def _perform_health_check(self) -> bool: """Perform a health check by testing API connectivity with project.""" + session = None try: # With ADC, token handling is internal; proceed to simple request @@ -922,6 +927,10 @@ async def _perform_health_check(self) -> bool: f"Health check failed - unexpected error: {e}", exc_info=True ) return False + finally: + if session is not None: + with contextlib.suppress(Exception): + session.close() def _generate_user_prompt_id(self, request_data: Any) -> str: """Generate a unique user_prompt_id for Code Assist requests.""" @@ -1045,9 +1054,10 @@ async def _chat_completions_standard( **kwargs: Any, ) -> ResponseEnvelope: """Handle non-streaming chat completions.""" + auth_session = None try: - # Use ADC for API calls (matches gemini CLI behavior for project-id auth) auth_session = self._get_adc_authorized_session() + # Use ADC for API calls (matches gemini CLI behavior for project-id auth) # Ensure project is onboarded for standard-tier project_id = await self._ensure_project_onboarded(auth_session) @@ -1194,6 +1204,10 @@ async def _chat_completions_standard( if logger.isEnabledFor(logging.ERROR): logger.error(f"Unexpected error during API call: {e}", exc_info=True) raise BackendError(f"Unexpected error during API call: {e}") + finally: + if auth_session is not None: + with contextlib.suppress(Exception): + auth_session.close() async def _chat_completions_streaming( self, @@ -1203,9 +1217,11 @@ async def _chat_completions_streaming( **kwargs: Any, ) -> StreamingResponseEnvelope: """Handle streaming chat completions.""" + auth_session = None + stream_prepared = False try: - # Use ADC for streaming API calls auth_session = self._get_adc_authorized_session() + # Use ADC for streaming API calls # Ensure project is onboarded for standard-tier project_id = await self._ensure_project_onboarded(auth_session) @@ -1387,9 +1403,14 @@ async def stream_generator() -> AsyncGenerator[ProcessedResponse, None]: finally: if response: # Ensure response is defined before closing response.close() # Use synchronous close + if auth_session is not None: + with contextlib.suppress(Exception): + auth_session.close() + generator = stream_generator() + stream_prepared = True return StreamingResponseEnvelope( - content=stream_generator(), + content=generator, media_type="text/event-stream", headers={}, ) @@ -1402,6 +1423,10 @@ async def stream_generator() -> AsyncGenerator[ProcessedResponse, None]: f"Unexpected error during streaming API call: {e}", exc_info=True ) raise BackendError(f"Unexpected error during streaming API call: {e}") + finally: + if not stream_prepared and auth_session is not None: + with contextlib.suppress(Exception): + auth_session.close() def _build_generation_config(self, request_data: Any) -> dict[str, Any]: cfg: dict[str, Any] = { diff --git a/tests/unit/connectors/test_gemini_cloud_project_resource_management.py b/tests/unit/connectors/test_gemini_cloud_project_resource_management.py new file mode 100644 index 00000000..0b7324c0 --- /dev/null +++ b/tests/unit/connectors/test_gemini_cloud_project_resource_management.py @@ -0,0 +1,94 @@ +import asyncio +from typing import Any + +import httpx +import pytest + +from src.connectors.gemini_cloud_project import GeminiCloudProjectConnector +from src.core.config.app_config import AppConfig +from src.core.services.translation_service import TranslationService + + +class _DummyResponse: + def __init__(self, status_code: int = 200, json_data: dict[str, Any] | None = None) -> None: + self.status_code = status_code + self._json_data = json_data or {} + self.text = "" + + def json(self) -> dict[str, Any]: + return self._json_data + + +class _DummySession: + def __init__(self) -> None: + self.closed = False + + def close(self) -> None: + self.closed = True + + +@pytest.fixture() +def connector() -> GeminiCloudProjectConnector: + cfg = AppConfig() + client = httpx.AsyncClient() + backend = GeminiCloudProjectConnector( + client, + cfg, + translation_service=TranslationService(), + gcp_project_id="test-project", + ) + backend.gemini_api_base_url = "https://example.com" + return backend + + +@pytest.mark.asyncio +async def test_validate_project_access_closes_session( + connector: GeminiCloudProjectConnector, monkeypatch: pytest.MonkeyPatch +) -> None: + class _Session(_DummySession): + def request(self, *args: Any, **kwargs: Any) -> _DummyResponse: + return _DummyResponse( + json_data={"cloudaicompanionProject": {"id": connector.gcp_project_id}} + ) + + session = _Session() + + async def _immediate_to_thread(func: Any, *args: Any, **kwargs: Any) -> Any: + return func(*args, **kwargs) + + monkeypatch.setattr(connector, "_get_adc_authorized_session", lambda: session) + monkeypatch.setattr(asyncio, "to_thread", _immediate_to_thread) + + await connector._validate_project_access() + + assert session.closed is True + + +@pytest.mark.asyncio +async def test_perform_health_check_closes_session( + connector: GeminiCloudProjectConnector, monkeypatch: pytest.MonkeyPatch +) -> None: + class _Credentials: + def __init__(self) -> None: + self.token = "token" + + def refresh(self, request: Any) -> None: # pragma: no cover - simple stub + self.token = "new-token" + + class _Session(_DummySession): + def __init__(self) -> None: + super().__init__() + self.credentials = _Credentials() + + async def _fake_get(url: str, headers: dict[str, str], timeout: float) -> Any: + return _DummyResponse(status_code=200) + + session = _Session() + + monkeypatch.setattr(connector, "_get_adc_authorized_session", lambda: session) + monkeypatch.setattr(connector.client, "get", _fake_get) + + result = await connector._perform_health_check() + + assert result is True + assert session.closed is True