From 5a2f4d796ca0f670771bc56f7b750a0478231093 Mon Sep 17 00:00:00 2001
From: FallenDeity <61227305+FallenDeity@users.noreply.github.com>
Date: Sun, 24 Aug 2025 14:46:37 +0530
Subject: [PATCH] feat: Migrate to new unified google-genai SDK and switch to
 Gemini 2.5 Flash models

---
 requirements.txt      |  2 +-
 tests/test_llm_api.py | 67 ++++++++++++++++++++++++++++++-------------
 tools/llm_api.py      | 51 ++++++++++++++++++--------------
 3 files changed, 77 insertions(+), 43 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 92ab6d7..ada9dbe 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,7 +16,7 @@ pytest>=8.0.0
 pytest-asyncio>=0.23.5
 
 # Google Generative AI
-google-generativeai
+google-genai
 
 # gRPC, for Google Generative AI preventing WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
 grpcio==1.70.0
diff --git a/tests/test_llm_api.py b/tests/test_llm_api.py
index c0104d1..fb86f96 100644
--- a/tests/test_llm_api.py
+++ b/tests/test_llm_api.py
@@ -1,10 +1,8 @@
 import unittest
-from unittest.mock import patch, MagicMock, mock_open
+from unittest.mock import patch, MagicMock
 from tools.llm_api import create_llm_client, query_llm, load_environment
 import os
-import google.generativeai as genai
 import io
-import sys
 
 def is_llm_configured():
     """Check if LLM is configured by trying to connect to the server"""
@@ -105,17 +103,20 @@ def setUp(self):
         self.mock_anthropic_response.content = [self.mock_anthropic_content]
         self.mock_anthropic_client.messages.create.return_value = self.mock_anthropic_response
 
-        # Set up Gemini-style response - Updated for Chat Session
-        self.mock_gemini_chat_session = MagicMock()  # Mock for the chat session
+        # Set up Gemini-style response - Updated for new genai.Client API
+        self.mock_gemini_chat_session = MagicMock()
         self.mock_gemini_response = MagicMock()
         self.mock_gemini_response.text = "Test Gemini response"
-        self.mock_gemini_chat_session.send_message.return_value = self.mock_gemini_response  # Mock send_message
+        self.mock_gemini_chat_session.send_message.return_value = self.mock_gemini_response
 
-        self.mock_gemini_model = MagicMock()  # Mock for the GenerativeModel
-        self.mock_gemini_model.start_chat.return_value = self.mock_gemini_chat_session  # Mock start_chat
+        self.mock_gemini_client = MagicMock()
+        self.mock_gemini_client.chats.create.return_value = self.mock_gemini_chat_session
 
-        self.mock_gemini_client = MagicMock()  # Mock for the genai module itself
-        self.mock_gemini_client.GenerativeModel.return_value = self.mock_gemini_model
+        # Set up file upload mock for image testing
+        self.mock_gemini_file = MagicMock()
+        self.mock_gemini_file.uri = "test-file-uri"
+        self.mock_gemini_file.mime_type = "image/png"
+        self.mock_gemini_client.files.upload.return_value = self.mock_gemini_file
 
         # Set up SiliconFlow-style response
         self.mock_siliconflow_response = MagicMock()
@@ -205,11 +206,12 @@ def test_create_anthropic_client(self, mock_anthropic):
         self.assertEqual(client, self.mock_anthropic_client)
 
     @unittest.skipIf(skip_llm_tests, skip_message)
-    @patch('tools.llm_api.genai')
-    def test_create_gemini_client(self, mock_genai):
+    @patch('tools.llm_api.genai.Client')
+    def test_create_gemini_client(self, mock_genai_client):
+        mock_genai_client.return_value = self.mock_gemini_client
         client = create_llm_client("gemini")
-        mock_genai.configure.assert_called_once_with(api_key='test-google-key')
-        self.assertEqual(client, mock_genai)
+        mock_genai_client.assert_called_once_with(api_key='test-google-key')
+        self.assertEqual(client, self.mock_gemini_client)
 
     @unittest.skipIf(skip_llm_tests, skip_message)
     @patch('tools.llm_api.OpenAI')
@@ -292,15 +294,40 @@ def test_query_anthropic(self, mock_create_client):
     @unittest.skipIf(skip_llm_tests, skip_message)
     @patch('tools.llm_api.create_llm_client')
     def test_query_gemini(self, mock_create_client):
-        mock_create_client.return_value = self.mock_gemini_client  # Use the updated mock from setUp
+        mock_create_client.return_value = self.mock_gemini_client
         response = query_llm("Test prompt", provider="gemini")
         self.assertEqual(response, "Test Gemini response")
-        # Update assertions to check chat flow
-        self.mock_gemini_client.GenerativeModel.assert_called_once_with("gemini-2.0-flash-exp")
-        self.mock_gemini_model.start_chat.assert_called_once_with(
-            history=[{'role': 'user', 'parts': ["Test prompt"]}]
+        # Verify the new genai.Client API calls
+        self.mock_gemini_client.chats.create.assert_called_once_with(model="gemini-2.5-flash")
+        self.mock_gemini_chat_session.send_message.assert_called_once_with(message="Test prompt")
+
+    @unittest.skipIf(skip_llm_tests, skip_message)
+    @patch('tools.llm_api.create_llm_client')
+    @patch('tools.llm_api.encode_image_file')
+    @patch('tools.llm_api.genai')
+    def test_query_gemini_with_image(self, mock_genai, mock_encode_image, mock_create_client):
+        # Setup mocks
+        mock_create_client.return_value = self.mock_gemini_client
+        mock_encode_image.return_value = ("base64_data", "image/png")
+
+        # Test query with image
+        response = query_llm("Describe this image", provider="gemini", image_path="test_image.png")
+        self.assertEqual(response, "Test Gemini response")
+
+        # Verify file upload was called
+        self.mock_gemini_client.files.upload.assert_called_once_with(
+            file="test_image.png",
+            config=mock_genai.types.UploadFileConfig(mime_type="image/png")
         )
-        self.mock_gemini_chat_session.send_message.assert_called_once_with("Test prompt")
+
+        # Verify chat session was created with history
+        self.mock_gemini_client.chats.create.assert_called_once()
+        create_call_args = self.mock_gemini_client.chats.create.call_args
+        self.assertEqual(create_call_args[1]["model"], "gemini-2.5-flash")
+        self.assertIn("history", create_call_args[1])
+
+        # Verify send_message was called
+        self.mock_gemini_chat_session.send_message.assert_called_once_with(message="Describe this image")
 
     @unittest.skipIf(skip_llm_tests, skip_message)
     @patch('tools.llm_api.create_llm_client')
diff --git a/tools/llm_api.py b/tools/llm_api.py
index 4f70eb1..d59ecb4 100644
--- a/tools/llm_api.py
+++ b/tools/llm_api.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 
-import google.generativeai as genai
+from google import genai
 from openai import OpenAI, AzureOpenAI
 from anthropic import Anthropic
 import argparse
@@ -9,9 +9,13 @@
 from pathlib import Path
 import sys
 import base64
-from typing import Optional, Union, List
+from typing import Any, Literal, Optional, cast
 import mimetypes
 
+
+Provider = Literal["openai", "azure", "deepseek", "siliconflow", "anthropic", "gemini", "local"]
+
+
 def load_environment():
     """Load environment variables from .env files in order of precedence"""
     # Order of precedence:
@@ -65,7 +69,7 @@ def encode_image_file(image_path: str) -> tuple[str, str]:
 
     return encoded_string, mime_type
 
-def create_llm_client(provider="openai"):
+def create_llm_client(provider: Provider = "openai"):
     if provider == "openai":
         api_key = os.getenv('OPENAI_API_KEY')
         base_url = os.getenv('OPENAI_BASE_URL', "https://api.openai.com/v1")
@@ -111,8 +115,7 @@ def create_llm_client(provider="openai"):
         api_key = os.getenv('GOOGLE_API_KEY')
         if not api_key:
             raise ValueError("GOOGLE_API_KEY not found in environment variables")
-        genai.configure(api_key=api_key)
-        return genai
+        return genai.Client(api_key=api_key)
     elif provider == "local":
         return OpenAI(
             base_url="http://192.168.180.137:8006/v1",
@@ -121,7 +124,7 @@ def create_llm_client(provider="openai"):
     else:
         raise ValueError(f"Unsupported provider: {provider}")
 
-def query_llm(prompt: str, client=None, model=None, provider="openai", image_path: Optional[str] = None) -> Optional[str]:
+def query_llm(prompt: str, client: Optional[Any] = None, model: Optional[str] = None, provider: Provider = "openai", image_path: Optional[str] = None) -> Optional[str]:
     """
     Query an LLM with a prompt and optional image attachment.
 
@@ -152,7 +155,7 @@ def query_llm(prompt: str, client=None, model=None, provider="openai", image_pat
         elif provider == "anthropic":
             model = "claude-3-7-sonnet-20250219"
         elif provider == "gemini":
-            model = "gemini-2.0-flash-exp"
+            model = "gemini-2.5-flash"
         elif provider == "local":
             model = "Qwen/Qwen2.5-32B-Instruct-AWQ"
 
@@ -218,23 +221,27 @@ def query_llm(prompt: str, client=None, model=None, provider="openai", image_pat
             return response.content[0].text
 
         elif provider == "gemini":
-            model = client.GenerativeModel(model)
+            gemini_client = cast(genai.Client, client)
+
             if image_path:
-                file = genai.upload_file(image_path, mime_type="image/png")
-                chat_session = model.start_chat(
-                    history=[{
-                        "role": "user",
-                        "parts": [file, prompt]
-                    }]
+                file = gemini_client.files.upload(
+                    file=image_path,
+                    config=genai.types.UploadFileConfig(mime_type="image/png")
                 )
-            else:
-                chat_session = model.start_chat(
-                    history=[{
-                        "role": "user",
-                        "parts": [prompt]
-                    }]
+                chat_session = gemini_client.chats.create(
+                    model=model,
+                    history=[
+                        genai.types.Content(
+                            role="user",
+                            parts=[
+                                genai.types.Part.from_uri(file_uri=str(file.uri), mime_type=file.mime_type),
+                            ]
+                        )
+                    ]
                 )
-            response = chat_session.send_message(prompt)
+            else:
+                chat_session = gemini_client.chats.create(model=model)
+            response = chat_session.send_message(message=prompt)
 
             return response.text
     except Exception as e:
@@ -259,7 +266,7 @@ def main():
         elif args.provider == 'anthropic':
             args.model = "claude-3-7-sonnet-20250219"
         elif args.provider == 'gemini':
-            args.model = "gemini-2.0-flash-exp"
+            args.model = "gemini-2.5-flash"
         elif args.provider == 'azure':
             args.model = os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms')  # Get from env with fallback
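
Reviewer note (not part of the patch): the sketch below walks the migrated call
flow end to end, using only calls this patch itself introduces (genai.Client,
chats.create, files.upload, send_message). It assumes google-genai is installed
and GOOGLE_API_KEY is set; the image path and prompts are placeholders.

    import os

    from google import genai

    # Client construction replaces the old module-level genai.configure().
    client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])

    # Text-only flow: chats.create() replaces GenerativeModel().start_chat().
    chat = client.chats.create(model="gemini-2.5-flash")
    print(chat.send_message(message="Test prompt").text)

    # Image flow: client.files.upload() replaces genai.upload_file(), and the
    # uploaded file is referenced from the chat history via Part.from_uri().
    image = client.files.upload(
        file="example.png",  # placeholder path
        config=genai.types.UploadFileConfig(mime_type="image/png"),
    )
    chat = client.chats.create(
        model="gemini-2.5-flash",
        history=[
            genai.types.Content(
                role="user",
                parts=[
                    genai.types.Part.from_uri(
                        file_uri=str(image.uri), mime_type=image.mime_type
                    ),
                ],
            )
        ],
    )
    print(chat.send_message(message="Describe this image").text)

One design consequence worth noting in review: because the client is now an
instance (genai.Client) rather than the configured genai module, the tests can
inject self.mock_gemini_client directly instead of patching module state.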