From 09d3d9cdd4662c37bdf21e4db4c04fb26dfa819c Mon Sep 17 00:00:00 2001 From: matdev83 <211248003+matdev83@users.noreply.github.com> Date: Mon, 13 Oct 2025 00:30:51 +0200 Subject: [PATCH 1/2] Ensure Gemini Cloud Project sessions are closed --- src/connectors/gemini_cloud_project.py | 27 +++++- ...emini_cloud_project_resource_management.py | 94 +++++++++++++++++++ 2 files changed, 118 insertions(+), 3 deletions(-) create mode 100644 tests/unit/connectors/test_gemini_cloud_project_resource_management.py diff --git a/src/connectors/gemini_cloud_project.py b/src/connectors/gemini_cloud_project.py index bebf5b8f..34c114ea 100644 --- a/src/connectors/gemini_cloud_project.py +++ b/src/connectors/gemini_cloud_project.py @@ -43,6 +43,7 @@ # mypy: disable-error-code="no-untyped-call,no-untyped-def,no-any-return,has-type,var-annotated" import asyncio +import contextlib import json import logging import os @@ -836,6 +837,9 @@ async def _validate_project_access(self) -> None: if logger.isEnabledFor(logging.ERROR): logger.error(f"Failed to validate project access: {e}", exc_info=True) raise + finally: + with contextlib.suppress(Exception): + auth_session.close() async def _resolve_gemini_api_config( self, @@ -870,6 +874,7 @@ async def _resolve_gemini_api_config( async def _perform_health_check(self) -> bool: """Perform a health check by testing API connectivity with project.""" + session = None try: # With ADC, token handling is internal; proceed to simple request @@ -922,6 +927,10 @@ async def _perform_health_check(self) -> bool: f"Health check failed - unexpected error: {e}", exc_info=True ) return False + finally: + if session is not None: + with contextlib.suppress(Exception): + session.close() def _generate_user_prompt_id(self, request_data: Any) -> str: """Generate a unique user_prompt_id for Code Assist requests.""" @@ -1045,9 +1054,9 @@ async def _chat_completions_standard( **kwargs: Any, ) -> ResponseEnvelope: """Handle non-streaming chat completions.""" + auth_session = self._get_adc_authorized_session() try: # Use ADC for API calls (matches gemini CLI behavior for project-id auth) - auth_session = self._get_adc_authorized_session() # Ensure project is onboarded for standard-tier project_id = await self._ensure_project_onboarded(auth_session) @@ -1194,6 +1203,9 @@ async def _chat_completions_standard( if logger.isEnabledFor(logging.ERROR): logger.error(f"Unexpected error during API call: {e}", exc_info=True) raise BackendError(f"Unexpected error during API call: {e}") + finally: + with contextlib.suppress(Exception): + auth_session.close() async def _chat_completions_streaming( self, @@ -1203,9 +1215,10 @@ async def _chat_completions_streaming( **kwargs: Any, ) -> StreamingResponseEnvelope: """Handle streaming chat completions.""" + auth_session = self._get_adc_authorized_session() + stream_prepared = False try: # Use ADC for streaming API calls - auth_session = self._get_adc_authorized_session() # Ensure project is onboarded for standard-tier project_id = await self._ensure_project_onboarded(auth_session) @@ -1387,9 +1400,13 @@ async def stream_generator() -> AsyncGenerator[ProcessedResponse, None]: finally: if response: # Ensure response is defined before closing response.close() # Use synchronous close + with contextlib.suppress(Exception): + auth_session.close() + generator = stream_generator() + stream_prepared = True return StreamingResponseEnvelope( - content=stream_generator(), + content=generator, media_type="text/event-stream", headers={}, ) @@ -1402,6 +1419,10 @@ async def stream_generator() -> AsyncGenerator[ProcessedResponse, None]: f"Unexpected error during streaming API call: {e}", exc_info=True ) raise BackendError(f"Unexpected error during streaming API call: {e}") + finally: + if not stream_prepared: + with contextlib.suppress(Exception): + auth_session.close() def _build_generation_config(self, request_data: Any) -> dict[str, Any]: cfg: dict[str, Any] = { diff --git a/tests/unit/connectors/test_gemini_cloud_project_resource_management.py b/tests/unit/connectors/test_gemini_cloud_project_resource_management.py new file mode 100644 index 00000000..0b7324c0 --- /dev/null +++ b/tests/unit/connectors/test_gemini_cloud_project_resource_management.py @@ -0,0 +1,94 @@ +import asyncio +from typing import Any + +import httpx +import pytest + +from src.connectors.gemini_cloud_project import GeminiCloudProjectConnector +from src.core.config.app_config import AppConfig +from src.core.services.translation_service import TranslationService + + +class _DummyResponse: + def __init__(self, status_code: int = 200, json_data: dict[str, Any] | None = None) -> None: + self.status_code = status_code + self._json_data = json_data or {} + self.text = "" + + def json(self) -> dict[str, Any]: + return self._json_data + + +class _DummySession: + def __init__(self) -> None: + self.closed = False + + def close(self) -> None: + self.closed = True + + +@pytest.fixture() +def connector() -> GeminiCloudProjectConnector: + cfg = AppConfig() + client = httpx.AsyncClient() + backend = GeminiCloudProjectConnector( + client, + cfg, + translation_service=TranslationService(), + gcp_project_id="test-project", + ) + backend.gemini_api_base_url = "https://example.com" + return backend + + +@pytest.mark.asyncio +async def test_validate_project_access_closes_session( + connector: GeminiCloudProjectConnector, monkeypatch: pytest.MonkeyPatch +) -> None: + class _Session(_DummySession): + def request(self, *args: Any, **kwargs: Any) -> _DummyResponse: + return _DummyResponse( + json_data={"cloudaicompanionProject": {"id": connector.gcp_project_id}} + ) + + session = _Session() + + async def _immediate_to_thread(func: Any, *args: Any, **kwargs: Any) -> Any: + return func(*args, **kwargs) + + monkeypatch.setattr(connector, "_get_adc_authorized_session", lambda: session) + monkeypatch.setattr(asyncio, "to_thread", _immediate_to_thread) + + await connector._validate_project_access() + + assert session.closed is True + + +@pytest.mark.asyncio +async def test_perform_health_check_closes_session( + connector: GeminiCloudProjectConnector, monkeypatch: pytest.MonkeyPatch +) -> None: + class _Credentials: + def __init__(self) -> None: + self.token = "token" + + def refresh(self, request: Any) -> None: # pragma: no cover - simple stub + self.token = "new-token" + + class _Session(_DummySession): + def __init__(self) -> None: + super().__init__() + self.credentials = _Credentials() + + async def _fake_get(url: str, headers: dict[str, str], timeout: float) -> Any: + return _DummyResponse(status_code=200) + + session = _Session() + + monkeypatch.setattr(connector, "_get_adc_authorized_session", lambda: session) + monkeypatch.setattr(connector.client, "get", _fake_get) + + result = await connector._perform_health_check() + + assert result is True + assert session.closed is True From a5e4c53466dca3a1dc1f913ac3d04c48409ed847 Mon Sep 17 00:00:00 2001 From: matdev83 <211248003+matdev83@users.noreply.github.com> Date: Wed, 15 Oct 2025 23:08:34 +0200 Subject: [PATCH 2/2] Fix ADC session acquisition error handling --- src/connectors/gemini_cloud_project.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/connectors/gemini_cloud_project.py b/src/connectors/gemini_cloud_project.py index 34c114ea..c9bc5a6c 100644 --- a/src/connectors/gemini_cloud_project.py +++ b/src/connectors/gemini_cloud_project.py @@ -1054,8 +1054,9 @@ async def _chat_completions_standard( **kwargs: Any, ) -> ResponseEnvelope: """Handle non-streaming chat completions.""" - auth_session = self._get_adc_authorized_session() + auth_session = None try: + auth_session = self._get_adc_authorized_session() # Use ADC for API calls (matches gemini CLI behavior for project-id auth) # Ensure project is onboarded for standard-tier @@ -1204,8 +1205,9 @@ async def _chat_completions_standard( logger.error(f"Unexpected error during API call: {e}", exc_info=True) raise BackendError(f"Unexpected error during API call: {e}") finally: - with contextlib.suppress(Exception): - auth_session.close() + if auth_session is not None: + with contextlib.suppress(Exception): + auth_session.close() async def _chat_completions_streaming( self, @@ -1215,9 +1217,10 @@ async def _chat_completions_streaming( **kwargs: Any, ) -> StreamingResponseEnvelope: """Handle streaming chat completions.""" - auth_session = self._get_adc_authorized_session() + auth_session = None stream_prepared = False try: + auth_session = self._get_adc_authorized_session() # Use ADC for streaming API calls # Ensure project is onboarded for standard-tier @@ -1400,8 +1403,9 @@ async def stream_generator() -> AsyncGenerator[ProcessedResponse, None]: finally: if response: # Ensure response is defined before closing response.close() # Use synchronous close - with contextlib.suppress(Exception): - auth_session.close() + if auth_session is not None: + with contextlib.suppress(Exception): + auth_session.close() generator = stream_generator() stream_prepared = True @@ -1420,7 +1424,7 @@ async def stream_generator() -> AsyncGenerator[ProcessedResponse, None]: ) raise BackendError(f"Unexpected error during streaming API call: {e}") finally: - if not stream_prepared: + if not stream_prepared and auth_session is not None: with contextlib.suppress(Exception): auth_session.close()