2 changes: 1 addition & 1 deletion requirements.txt
@@ -16,7 +16,7 @@ pytest>=8.0.0
pytest-asyncio>=0.23.5

# Google Generative AI
google-generativeai
google-genai

# gRPC, pinned so Google Generative AI does not emit "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR"
grpcio==1.70.0
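
Note that swapping google-generativeai for google-genai is not a drop-in rename: as the source changes below show, the new SDK replaces module-level configuration with an explicit client object. A minimal sketch of the difference (the api_key value is a placeholder):

# Old SDK (google-generativeai): configure the module globally
# import google.generativeai as genai
# genai.configure(api_key="...")

# New SDK (google-genai): construct a client that owns all state
from google import genai
client = genai.Client(api_key="...")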
67 changes: 47 additions & 20 deletions tests/test_llm_api.py
@@ -1,10 +1,8 @@
import unittest
from unittest.mock import patch, MagicMock, mock_open
from unittest.mock import patch, MagicMock
from tools.llm_api import create_llm_client, query_llm, load_environment
import os
import google.generativeai as genai
import io
import sys

def is_llm_configured():
"""Check if LLM is configured by trying to connect to the server"""
@@ -105,17 +103,20 @@ def setUp(self):
self.mock_anthropic_response.content = [self.mock_anthropic_content]
self.mock_anthropic_client.messages.create.return_value = self.mock_anthropic_response

# Set up Gemini-style response - Updated for Chat Session
self.mock_gemini_chat_session = MagicMock() # Mock for the chat session
# Set up Gemini-style response - Updated for new genai.Client API
self.mock_gemini_chat_session = MagicMock()
self.mock_gemini_response = MagicMock()
self.mock_gemini_response.text = "Test Gemini response"
self.mock_gemini_chat_session.send_message.return_value = self.mock_gemini_response # Mock send_message
self.mock_gemini_chat_session.send_message.return_value = self.mock_gemini_response

self.mock_gemini_model = MagicMock() # Mock for the GenerativeModel
self.mock_gemini_model.start_chat.return_value = self.mock_gemini_chat_session # Mock start_chat
self.mock_gemini_client = MagicMock()
self.mock_gemini_client.chats.create.return_value = self.mock_gemini_chat_session

self.mock_gemini_client = MagicMock() # Mock for the genai module itself
self.mock_gemini_client.GenerativeModel.return_value = self.mock_gemini_model
# Set up file upload mock for image testing
self.mock_gemini_file = MagicMock()
self.mock_gemini_file.uri = "test-file-uri"
self.mock_gemini_file.mime_type = "image/png"
self.mock_gemini_client.files.upload.return_value = self.mock_gemini_file

# Set up SiliconFlow-style response
self.mock_siliconflow_response = MagicMock()
@@ -205,11 +206,12 @@ def test_create_anthropic_client(self, mock_anthropic):
self.assertEqual(client, self.mock_anthropic_client)

@unittest.skipIf(skip_llm_tests, skip_message)
@patch('tools.llm_api.genai')
def test_create_gemini_client(self, mock_genai):
@patch('tools.llm_api.genai.Client')
def test_create_gemini_client(self, mock_genai_client):
mock_genai_client.return_value = self.mock_gemini_client
client = create_llm_client("gemini")
mock_genai.configure.assert_called_once_with(api_key='test-google-key')
self.assertEqual(client, mock_genai)
mock_genai_client.assert_called_once_with(api_key='test-google-key')
self.assertEqual(client, self.mock_gemini_client)

@unittest.skipIf(skip_llm_tests, skip_message)
@patch('tools.llm_api.OpenAI')
@@ -292,15 +294,40 @@ def test_query_anthropic(self, mock_create_client):
@unittest.skipIf(skip_llm_tests, skip_message)
@patch('tools.llm_api.create_llm_client')
def test_query_gemini(self, mock_create_client):
mock_create_client.return_value = self.mock_gemini_client # Use the updated mock from setUp
mock_create_client.return_value = self.mock_gemini_client
response = query_llm("Test prompt", provider="gemini")
self.assertEqual(response, "Test Gemini response")
# Update assertions to check chat flow
self.mock_gemini_client.GenerativeModel.assert_called_once_with("gemini-2.0-flash-exp")
self.mock_gemini_model.start_chat.assert_called_once_with(
history=[{'role': 'user', 'parts': ["Test prompt"]}]
# Verify the new genai.Client API calls
self.mock_gemini_client.chats.create.assert_called_once_with(model="gemini-2.5-flash")
self.mock_gemini_chat_session.send_message.assert_called_once_with(message="Test prompt")

@unittest.skipIf(skip_llm_tests, skip_message)
@patch('tools.llm_api.create_llm_client')
@patch('tools.llm_api.encode_image_file')
@patch('tools.llm_api.genai')
def test_query_gemini_with_image(self, mock_genai, mock_encode_image, mock_create_client):
# Setup mocks
mock_create_client.return_value = self.mock_gemini_client
mock_encode_image.return_value = ("base64_data", "image/png")

# Test query with image
response = query_llm("Describe this image", provider="gemini", image_path="test_image.png")
self.assertEqual(response, "Test Gemini response")

# Verify file upload was called
self.mock_gemini_client.files.upload.assert_called_once_with(
file="test_image.png",
config=mock_genai.types.UploadFileConfig(mime_type="image/png")
)

# Verify chat session was created with history
self.mock_gemini_client.chats.create.assert_called_once()
create_call_args = self.mock_gemini_client.chats.create.call_args
self.assertEqual(create_call_args[1]["model"], "gemini-2.5-flash")
self.assertIn("history", create_call_args[1])

# Verify send_message was called
self.mock_gemini_chat_session.send_message.assert_called_once_with(message="Describe this image")

@unittest.skipIf(skip_llm_tests, skip_message)
@patch('tools.llm_api.create_llm_client')
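The assertions above pin down the whole mocked call surface of the new client. Condensed, the stub wiring from setUp looks roughly like this (a sketch that mirrors the test file, not a copy of it):

from unittest.mock import MagicMock

mock_client = MagicMock()
mock_chat = MagicMock()
mock_chat.send_message.return_value.text = "Test Gemini response"
mock_client.chats.create.return_value = mock_chat
mock_client.files.upload.return_value.uri = "test-file-uri"
mock_client.files.upload.return_value.mime_type = "image/png"

# Text flow asserted by test_query_gemini:
#   client.chats.create(model="gemini-2.5-flash") -> chat
#   chat.send_message(message="Test prompt") -> response.text
# The image flow additionally asserts files.upload(...) and a
# history entry in chats.create before send_message.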
51 changes: 29 additions & 22 deletions tools/llm_api.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

import google.generativeai as genai
from google import genai
from openai import OpenAI, AzureOpenAI
from anthropic import Anthropic
import argparse
@@ -9,9 +9,13 @@
from pathlib import Path
import sys
import base64
from typing import Optional, Union, List
from typing import Any, Literal, Optional, cast
import mimetypes


Provider = Literal["openai", "azure", "deepseek", "siliconflow", "anthropic", "gemini", "local"]


def load_environment():
"""Load environment variables from .env files in order of precedence"""
# Order of precedence:
@@ -65,7 +69,7 @@ def encode_image_file(image_path: str) -> tuple[str, str]:

return encoded_string, mime_type

def create_llm_client(provider="openai"):
def create_llm_client(provider: Provider = "openai"):
if provider == "openai":
api_key = os.getenv('OPENAI_API_KEY')
base_url = os.getenv('OPENAI_BASE_URL', "https://api.openai.com/v1")
@@ -111,8 +115,7 @@ def create_llm_client(provider="openai"):
api_key = os.getenv('GOOGLE_API_KEY')
if not api_key:
raise ValueError("GOOGLE_API_KEY not found in environment variables")
genai.configure(api_key=api_key)
return genai
return genai.Client(api_key=api_key)
elif provider == "local":
return OpenAI(
base_url="http://192.168.180.137:8006/v1",
@@ -121,7 +124,7 @@
else:
raise ValueError(f"Unsupported provider: {provider}")

def query_llm(prompt: str, client=None, model=None, provider="openai", image_path: Optional[str] = None) -> Optional[str]:
def query_llm(prompt: str, client: Optional[Any] = None, model: Optional[str] = None, provider: Provider = "openai", image_path: Optional[str] = None) -> Optional[str]:
"""
Query an LLM with a prompt and optional image attachment.

@@ -152,7 +155,7 @@ def query_llm(prompt: str, client=None, model=None, provider="openai", image_path: Optional[str] = None) -> Optional[str]:
elif provider == "anthropic":
model = "claude-3-7-sonnet-20250219"
elif provider == "gemini":
model = "gemini-2.0-flash-exp"
model = "gemini-2.5-flash"
elif provider == "local":
model = "Qwen/Qwen2.5-32B-Instruct-AWQ"

@@ -218,23 +221,27 @@ def query_llm(prompt: str, client=None, model=None, provider="openai", image_path: Optional[str] = None) -> Optional[str]:
return response.content[0].text

elif provider == "gemini":
model = client.GenerativeModel(model)
gemini_client = cast(genai.Client, client)

if image_path:
file = genai.upload_file(image_path, mime_type="image/png")
chat_session = model.start_chat(
history=[{
"role": "user",
"parts": [file, prompt]
}]
file = gemini_client.files.upload(
file=image_path,
config=genai.types.UploadFileConfig(mime_type="image/png")
)
else:
chat_session = model.start_chat(
history=[{
"role": "user",
"parts": [prompt]
}]
chat_session = gemini_client.chats.create(
model=model,
history=[
genai.types.Content(
role="user",
parts=[
genai.types.Part.from_uri(file_uri=str(file.uri), mime_type=file.mime_type),
]
)
]
)
response = chat_session.send_message(prompt)
else:
chat_session = gemini_client.chats.create(model=model)
response = chat_session.send_message(message=prompt)
return response.text

except Exception as e:
@@ -259,7 +266,7 @@ def main():
elif args.provider == 'anthropic':
args.model = "claude-3-7-sonnet-20250219"
elif args.provider == 'gemini':
args.model = "gemini-2.0-flash-exp"
args.model = "gemini-2.5-flash"
elif args.provider == 'azure':
args.model = os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms') # Get from env with fallback

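For readers following the migration, the new gemini branch reduces to this end-to-end flow (a condensed sketch of the code above; it assumes GOOGLE_API_KEY is set and, as in the PR, that any attachment is a PNG):

import os
from google import genai

client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))

# Text-only query: open a chat session and send the prompt
chat = client.chats.create(model="gemini-2.5-flash")
print(chat.send_message(message="Hello").text)

# Query with an image: upload the file, seed the chat history with a
# Part referencing the uploaded URI, then send the prompt as a message
file = client.files.upload(
    file="image.png",
    config=genai.types.UploadFileConfig(mime_type="image/png"),
)
chat = client.chats.create(
    model="gemini-2.5-flash",
    history=[
        genai.types.Content(
            role="user",
            parts=[genai.types.Part.from_uri(file_uri=str(file.uri), mime_type=file.mime_type)],
        )
    ],
)
print(chat.send_message(message="Describe this image").text)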