diff --git a/README.md b/README.md
index 3bfe4c7..e3e49d4 100644
--- a/README.md
+++ b/README.md
@@ -104,13 +104,18 @@ The MCP MariaDB Server provides **optional** embedding and vector store capabili
 - **OpenAI**
 - **Gemini**
 - **Open models from Huggingface**
+- **Open models from Ollama**
 
 ### Configuration
 
-- `EMBEDDING_PROVIDER`: Set to `openai`, `gemini`, `huggingface`, or leave unset to disable
+- `EMBEDDING_PROVIDER`: Set to `openai`, `gemini`, `huggingface`, `ollama`, or leave unset to disable
 - `OPENAI_API_KEY`: Required if using OpenAI embeddings
 - `GEMINI_API_KEY`: Required if using Gemini embeddings
 - `HF_MODEL`: Required if using HuggingFace embeddings (e.g., "intfloat/multilingual-e5-large-instruct" or "BAAI/bge-m3")
+- `OLLAMA_HOST`: Ollama server host if using Ollama embeddings (defaults to `localhost`)
+- `OLLAMA_PORT`: Ollama server port if using Ollama embeddings (defaults to `11434`)
+- `OLLAMA_MODEL`: Required if using Ollama embeddings (e.g., "nomic-embed-text" or "bge-m3")
+
 ### Model Selection
 
 - Default and allowed models are configurable in code (`DEFAULT_OPENAI_MODEL`, `ALLOWED_OPENAI_MODELS`)
@@ -130,22 +135,25 @@ A vector store table has the following columns:
 
 All configuration is via environment variables (typically set in a `.env` file):
 
-| Variable               | Description                                              | Required | Default      |
-|------------------------|----------------------------------------------------------|----------|--------------|
-| `DB_HOST`              | MariaDB host address                                     | Yes      | `localhost`  |
-| `DB_PORT`              | MariaDB port                                             | No       | `3306`       |
-| `DB_USER`              | MariaDB username                                         | Yes      |              |
-| `DB_PASSWORD`          | MariaDB password                                         | Yes      |              |
-| `DB_NAME`              | Default database (optional; can be set per query)        | No       |              |
-| `DB_CHARSET`           | Character set for database connection (e.g., `cp1251`)   | No       | MariaDB default |
-| `MCP_READ_ONLY`        | Enforce read-only SQL mode (`true`/`false`)              | No       | `true`       |
-| `MCP_MAX_POOL_SIZE`    | Max DB connection pool size                              | No       | `10`         |
-| `EMBEDDING_PROVIDER`   | Embedding provider (`openai`/`gemini`/`huggingface`)     | No       |`None`(Disabled)|
-| `OPENAI_API_KEY`       | API key for OpenAI embeddings                            | Yes (if EMBEDDING_PROVIDER=openai) | |
-| `GEMINI_API_KEY`       | API key for Gemini embeddings                            | Yes (if EMBEDDING_PROVIDER=gemini) | |
-| `HF_MODEL`             | Open models from Huggingface                             | Yes (if EMBEDDING_PROVIDER=huggingface) | |
-| `ALLOWED_ORIGINS`      | Comma-separated list of allowed origins                  | No       | Long list of allowed origins corresponding to local use of the server |
-| `ALLOWED_HOSTS`        | Comma-separated list of allowed hosts                    | No       | `localhost,127.0.0.1` |
+| Variable               | Description                                                      | Required | Default      |
+|------------------------|------------------------------------------------------------------|----------|--------------|
+| `DB_HOST`              | MariaDB host address                                             | Yes      | `localhost`  |
+| `DB_PORT`              | MariaDB port                                                     | No       | `3306`       |
+| `DB_USER`              | MariaDB username                                                 | Yes      |              |
+| `DB_PASSWORD`          | MariaDB password                                                 | Yes      |              |
+| `DB_NAME`              | Default database (optional; can be set per query)                | No       |              |
+| `DB_CHARSET`           | Character set for database connection (e.g., `cp1251`)           | No       | MariaDB default |
+| `MCP_READ_ONLY`        | Enforce read-only SQL mode (`true`/`false`)                      | No       | `true`       |
+| `MCP_MAX_POOL_SIZE`    | Max DB connection pool size                                      | No       | `10`         |
+| `EMBEDDING_PROVIDER`   | Embedding provider (`openai`/`gemini`/`huggingface`/`ollama`)    | No       | `None` (disabled) |
+| `OPENAI_API_KEY`       | API key for OpenAI embeddings                                    | Yes (if EMBEDDING_PROVIDER=openai) | |
+| `GEMINI_API_KEY`       | API key for Gemini embeddings                                    | Yes (if EMBEDDING_PROVIDER=gemini) | |
+| `HF_MODEL`             | HuggingFace embedding model (e.g., `BAAI/bge-m3`)                | Yes (if EMBEDDING_PROVIDER=huggingface) | |
+| `OLLAMA_HOST`          | Ollama host address                                              | No       | `localhost`  |
+| `OLLAMA_PORT`          | Ollama port                                                      | No       | `11434`      |
+| `OLLAMA_MODEL`         | Ollama embedding model (e.g., `nomic-embed-text`)                | Yes (if EMBEDDING_PROVIDER=ollama) | |
+| `ALLOWED_ORIGINS`      | Comma-separated list of allowed origins                          | No       | Long list of allowed origins corresponding to local use of the server |
+| `ALLOWED_HOSTS`        | Comma-separated list of allowed hosts                            | No       | `localhost,127.0.0.1` |
 
 Note that if using 'http' or 'sse' as the transport, configuring authentication is important for security if you allow connections outside of localhost. Because different organizations use different authentication methods, the server does not provide a default authentication method. You will need to configure your own authentication method. Thankfully FastMCP provides a simple way to do this starting with version 2.12.1. See the [FastMCP documentation](https://gofastmcp.com/servers/auth/authentication#environment-configuration) for more information. We have provided an example configuration below.
 
@@ -166,6 +174,9 @@ EMBEDDING_PROVIDER=openai
 OPENAI_API_KEY=sk-...
 GEMINI_API_KEY=AI...
 HF_MODEL="BAAI/bge-m3"
+OLLAMA_HOST=localhost
+OLLAMA_PORT=11434
+OLLAMA_MODEL="nomic-embed-text"
 ```
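Before starting the server, the Ollama settings above can be sanity-checked against a running instance. A minimal sketch, mirroring the client calls the server itself makes (assumes `ollama serve` is running and the model has already been pulled; host, port, and model are the example values from the `.env` above):

```python
# Quick sanity check for the Ollama embedding settings (illustrative values).
import ollama

client = ollama.Client(host="http://localhost:11434")
response = client.embeddings(model="nomic-embed-text", prompt="hello world")
print(len(response["embedding"]))  # 768 for nomic-embed-text
```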
 
 **Without Embedding Support:**
diff --git a/pyproject.toml b/pyproject.toml
index 4c22b8b..8af75da 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,6 +8,9 @@ dependencies = [
     "asyncmy>=0.2.10",
     "fastmcp[websockets]==2.12.1",
     "google-genai>=1.15.0",
+    "ollama>=0.6.0",
     "openai>=1.78.1",
     "python-dotenv>=1.1.0",
+    "requests>=2.31.0",  # now imported directly by src/embeddings.py
     "sentence-transformers>=4.1.0",
+    "tqdm>=4.66.0",      # now imported directly by src/embeddings.py
diff --git a/src/config.py b/src/config.py
index 85270c2..80460f2 100644
--- a/src/config.py
+++ b/src/config.py
@@ -79,7 +79,10 @@ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
 # Open models from Huggingface
 HF_MODEL = os.getenv("HF_MODEL")
-
+# Ollama Configuration
+OLLAMA_HOST = os.getenv("OLLAMA_HOST", "localhost")
+OLLAMA_PORT = os.getenv("OLLAMA_PORT", "11434")
+OLLAMA_MODEL = os.getenv("OLLAMA_MODEL")
 
 # --- Validation ---
 
 if not all([DB_USER, DB_PASSWORD]):
@@ -99,9 +102,13 @@
     if not HF_MODEL:
         logger.error("EMBEDDING_PROVIDER is 'huggingface' but HF_MODEL is missing.")
         raise ValueError("HuggingFace model is required when EMBEDDING_PROVIDER is 'huggingface'.")
+elif EMBEDDING_PROVIDER == "ollama":
+    if not OLLAMA_MODEL:
+        logger.error("EMBEDDING_PROVIDER is 'ollama' but OLLAMA_MODEL is missing.")
+        raise ValueError("Ollama model is required when EMBEDDING_PROVIDER is 'ollama'.")
 else:
     EMBEDDING_PROVIDER = None
-    logger.info(f"No EMBEDDING_PROVIDER selected or it is set to None. Disabling embedding features.")
+    logger.info("No EMBEDDING_PROVIDER selected or it is set to None. Disabling embedding features.")
 
 logger.info(f"Read-only mode: {MCP_READ_ONLY}")
 logger.info(f"Logging to console and to file: {LOG_FILE_PATH} (Level: {LOG_LEVEL}, MaxSize: {LOG_MAX_BYTES}B, Backups: {LOG_BACKUP_COUNT})")
\ No newline at end of file
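Since `config.py` validates at import time, a misconfigured provider aborts server startup instead of failing on the first embed call. A minimal sketch of that behavior (hypothetical demo values; assumes `src/` is on the import path and no `.env` supplies `OLLAMA_MODEL`):

```python
# Sketch: provider misconfiguration fails fast when config.py is imported.
import os

os.environ.update({"DB_USER": "demo", "DB_PASSWORD": "demo",  # satisfy the DB checks
                   "EMBEDDING_PROVIDER": "ollama"})
os.environ.pop("OLLAMA_MODEL", None)  # deliberately leave the model unset

try:
    import config
except ValueError as e:
    print(e)  # -> Ollama model is required when EMBEDDING_PROVIDER is 'ollama'.
```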
diff --git a/src/embeddings.py b/src/embeddings.py
index e16ea1c..d99d843 100644
--- a/src/embeddings.py
+++ b/src/embeddings.py
@@ -1,9 +1,12 @@
+import json
 import logging
 import sys
 import os
 import asyncio
 from typing import List, Optional, Dict, Any, Union, Awaitable
 import numpy as np
+import requests
+from tqdm import tqdm
 
 # Import configuration variables and the logger instance
 from config import (
@@ -11,6 +14,9 @@
     OPENAI_API_KEY,
     GEMINI_API_KEY,
     HF_MODEL,
+    OLLAMA_HOST,
+    OLLAMA_PORT,
+    OLLAMA_MODEL,
     logger
 )
 
@@ -43,6 +49,14 @@
     genai = None # type: ignore
     GoogleAPICoreExceptions = None # type: ignore
 
+# Import Ollama client library
+try:
+    import ollama
+    logger.info("Successfully imported ollama")
+except ImportError as e:
+    logger.warning(f"Ollama library not installed. Ollama provider will not be available. Error: {e}")
+    ollama = None # type: ignore
+
 # --- Model Definitions ---
 # Define allowed models and defaults for each provider
 # OpenAI Embedding Models
@@ -66,6 +80,24 @@
     "intfloat/multilingual-e5-large-instruct": 1024,
     "BAAI/bge-m3": 1024
 }
+# Open Embedding Models - Ollama
+ALLOWED_OLLAMA_MODELS: List[str] = ["nomic-embed-text", "embeddinggemma", "mxbai-embed-large",
+                                    "bge-m3", "all-minilm", "snowflake-arctic-embed", "snowflake-arctic-embed2",
+                                    "bge-large", "paraphrase-multilingual", "granite-embedding", "qwen3-embedding"]
+DEFAULT_OLLAMA_MODEL: str = "nomic-embed-text"
+OLLAMA_MODEL_DIMENSIONS = {
+    "nomic-embed-text": 768,
+    "embeddinggemma": 768,
+    "mxbai-embed-large": 1024,
+    "bge-m3": 1024,
+    "all-minilm": 384,
+    "snowflake-arctic-embed": 1024,
+    "snowflake-arctic-embed2": 1024,
+    "bge-large": 1024,
+    "paraphrase-multilingual": 384,
+    "granite-embedding": 1024,
+    "qwen3-embedding": 1536
+}
 
 class EmbeddingService:
     """
@@ -141,6 +173,47 @@ def __init__(self):
                 logger.error(f"Failed to initialize HuggingFace SentenceTransformer with model '{HF_MODEL}': {e}", exc_info=True)
                 self.huggingface_client = None # Ensure it's None if init fails
                 raise RuntimeError(f"HuggingFace SentenceTransformer (model: {HF_MODEL}) initialization failed: {e}")
+        elif self.provider == "ollama":
+            if not OLLAMA_MODEL: # From config.py
+                logger.error("EMBEDDING_PROVIDER is 'ollama' but OLLAMA_MODEL is missing in config.")
+                raise ValueError("Ollama model (OLLAMA_MODEL) is required in config for the Ollama provider.")
+            try:
+                response = requests.get(f"http://{OLLAMA_HOST}:{OLLAMA_PORT}/api/version", timeout=2)
+                if response.status_code == 200: # Check Ollama server is running
+                    self.default_model = OLLAMA_MODEL
+                    self.allowed_models = ALLOWED_OLLAMA_MODELS
+                    logger.info(f"Initializing ollama with configured OLLAMA_MODEL: {self.default_model}")
+                    response = requests.post(f"http://{OLLAMA_HOST}:{OLLAMA_PORT}/api/pull", json={"model": OLLAMA_MODEL}, stream=True) # Download the embedding model (no-op if already pulled)
+                    # Download progress bar: /api/pull streams one JSON status object per line
+                    pbar = None
+                    current_total = None
+                    for line in response.iter_lines():
+                        if not line:
+                            continue
+                        data = json.loads(line.decode("utf-8"))
+                        if "total" in data and data["total"] != current_total:
+                            if pbar:
+                                pbar.close()
+                            current_total = data["total"]
+                            pbar = tqdm(total=data["total"], unit="B", unit_scale=True, dynamic_ncols=True, leave=True)
+                        if pbar and "completed" in data:
+                            pbar.n = data["completed"]
+                            pbar.refresh()
+                        if data.get("status") == "success":
+                            if pbar:
+                                pbar.close()
+                            break
+                    # Initialize ollama Client
+                    self.ollama_client = ollama.Client(host=f"http://{OLLAMA_HOST}:{OLLAMA_PORT}")
+                    logger.info(f"Ollama provider initialized. Default model (from config.OLLAMA_MODEL): '{self.default_model}'. Client loaded.")
+                else:
+                    logger.error("Ollama server is NOT running. Start it with: 'ollama serve'")
+                    self.ollama_client = None # Ensure it's None if init fails
+                    raise ConnectionError(f"Ollama server at http://{OLLAMA_HOST}:{OLLAMA_PORT} is not reachable.")
+            except Exception as e:
+                logger.error(f"Failed to initialize ollama with model '{OLLAMA_MODEL}': {e}", exc_info=True)
+                self.ollama_client = None # Ensure it's None if init fails
+                raise RuntimeError(f"Ollama (model: {OLLAMA_MODEL}) initialization failed: {e}")
         else:
             logger.error(f"Unsupported embedding provider configured: {self.provider}")
             raise ValueError(f"Unsupported embedding provider: {self.provider}")
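For context on the progress-bar loop above: `/api/pull` responds with newline-delimited JSON, and the loop keys off `total`/`completed` to size and advance the bar, and off `status == "success"` to stop. A small illustration of parsing such a stream (the digest and byte counts are made up):

```python
import json

# Illustrative /api/pull stream lines (one JSON object per line).
stream = [
    b'{"status":"pulling manifest"}',
    b'{"status":"pulling 970aa74c0a90","total":274290656,"completed":104857600}',
    b'{"status":"success"}',
]
for line in stream:
    data = json.loads(line.decode("utf-8"))
    if "total" in data and "completed" in data:
        print(f'{data["completed"] / data["total"]:.0%} of {data["total"]} bytes')
    if data.get("status") == "success":
        print("pull complete")
```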
@@ -342,9 +415,21 @@ async def embed(self, text: Union[str, List[str]], model_name: Optional[str] = None
                 return embeddings_list[0] if embeddings_list and isinstance(embeddings_list, list) and embeddings_list[0] else embeddings_list
             else:
                 return embeddings_list
-        else:
-            logger.error(f"Embed called with unsupported provider: {self.provider}")
-            raise RuntimeError(f"Unsupported embedding provider: {self.provider}")
+        elif self.provider == "ollama":
+            if not self.ollama_client:
+                logger.critical("Ollama client not properly initialized.")
+                raise RuntimeError("Ollama client not initialized.")
+
+            embeddings = []
+            for item in texts: # 'item' avoids shadowing the 'text' parameter
+                response = self.ollama_client.embeddings(model=model_name or self.default_model, prompt=item)
+                embeddings.append(response["embedding"])
+
+            logger.debug(f"Ollama embedding(s) received. Count: {len(embeddings)}, Dimension: {len(embeddings[0]) if embeddings else 'N/A'}")
+            return embeddings[0] if single_input else embeddings
+        else:
+            logger.error(f"Embed called with unsupported provider: {self.provider}")
+            raise RuntimeError(f"Unsupported embedding provider: {self.provider}")
 
         except OpenAIError as e:
             logger.error(f"OpenAI API error during embedding: {e}", exc_info=True)
diff --git a/src/tests/test_embedding_service.py b/src/tests/test_embedding_service.py
index b86dcd7..3edffd2 100644
--- a/src/tests/test_embedding_service.py
+++ b/src/tests/test_embedding_service.py
@@ -42,5 +42,18 @@ def test_gemini_init_and_embed(self):
         self.assertIsInstance(result, list)
         self.assertEqual(len(result), 768)
 
+class TestEmbeddingServiceOllama(unittest.TestCase):
+    @patch("embeddings.EMBEDDING_PROVIDER", "ollama")
+    @patch("embeddings.OLLAMA_MODEL", "nomic-embed-text")
+    def test_ollama_init_and_embed(self):
+        service = EmbeddingService()
+        self.assertEqual(service.provider, "ollama")
+        self.assertIn("nomic-embed-text", service.allowed_models)
+        self.assertEqual(service.default_model, "nomic-embed-text")
+        # Test embed (requires a running Ollama server)
+        result = asyncio.run(service.embed("hello world"))
+        self.assertIsInstance(result, list)
+        self.assertEqual(len(result), 768)
+
 if __name__ == "__main__":
     unittest.main()
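For a quick end-to-end check outside the test runner, the service can also be exercised directly. A minimal sketch (assumes the Ollama `.env` settings shown earlier and a reachable Ollama server; run from `src/`):

```python
# Sketch: exercise EmbeddingService end-to-end, as the unit test does.
import asyncio
from embeddings import EmbeddingService

service = EmbeddingService()  # pings the server and pulls the model if needed
vector = asyncio.run(service.embed("hello world"))
print(len(vector))  # 768 for nomic-embed-text
```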