Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## Next

### Added

- Added automatic rate limiting, with retry logic and exponential backoff implemented using tenacity, for all embedding providers. The `RateLimitHandler` interface allows custom rate limiting strategies and can also be used to disable rate limiting entirely.

## 1.10.0

### Added
Expand Down
31 changes: 31 additions & 0 deletions docs/source/user_guide_rag.rst
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,37 @@ The `OpenAIEmbeddings` was illustrated previously. Here is how to use the `Sente

If another embedder is desired, a custom embedder can be created, using the `Embedder` interface.

Embedder Rate Limiting
----------------------

All embedder implementations include automatic rate limiting that uses retry logic with exponential backoff by default, similar to LLM implementations. This feature helps handle API rate limits from embedding providers gracefully.

.. code:: python

    from neo4j_graphrag.embeddings import OpenAIEmbeddings
    from neo4j_graphrag.llm.rate_limit import RetryRateLimitHandler, NoOpRateLimitHandler

    # Default rate limiting (automatically enabled)
    embedder = OpenAIEmbeddings(model="text-embedding-3-large")

    # Custom rate limiting configuration
    embedder = OpenAIEmbeddings(
        model="text-embedding-3-large",
        rate_limit_handler=RetryRateLimitHandler(
            max_attempts=5,
            min_wait=2.0,
            max_wait=120.0
        )
    )

    # Disable rate limiting
    embedder = OpenAIEmbeddings(
        model="text-embedding-3-large",
        rate_limit_handler=NoOpRateLimitHandler()
    )

The rate limiting configuration works the same way as for LLMs. See the :ref:`Rate Limit Handling <Rate Limit Handling>` section above for more details on customization options.


Other Vector Retriever Configuration
----------------------------------------
Expand Down
15 changes: 15 additions & 0 deletions src/neo4j_graphrag/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,29 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional

from neo4j_graphrag.llm.rate_limit import (
DEFAULT_RATE_LIMIT_HANDLER,
RateLimitHandler,
)


class Embedder(ABC):
"""
Interface for embedding models.
An embedder passed into a retriever must implement this interface.
Args:
rate_limit_handler (Optional[RateLimitHandler]): Handler for rate limiting. Defaults to retry with exponential backoff.
"""

def __init__(self, rate_limit_handler: Optional[RateLimitHandler] = None):
    """Store the rate limit handler this embedder will use.

    Args:
        rate_limit_handler (Optional[RateLimitHandler]): Handler to keep on
            the instance. When ``None``, ``DEFAULT_RATE_LIMIT_HANDLER`` is
            stored instead.
    """
    self._rate_limit_handler = (
        DEFAULT_RATE_LIMIT_HANDLER
        if rate_limit_handler is None
        else rate_limit_handler
    )

@abstractmethod
def embed_query(self, text: str) -> list[float]:
"""Embed query text.
Expand Down
30 changes: 22 additions & 8 deletions src/neo4j_graphrag/embeddings/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
# limitations under the License.
from __future__ import annotations

from typing import Any
from typing import Any, Optional

from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler

try:
import cohere
Expand All @@ -25,19 +27,31 @@


class CohereEmbeddings(Embedder):
def __init__(self, model: str = "", **kwargs: Any) -> None:
def __init__(
self,
model: str = "",
rate_limit_handler: Optional[RateLimitHandler] = None,
**kwargs: Any,
) -> None:
if cohere is None:
raise ImportError(
"""Could not import cohere python client.
Please install it with `pip install "neo4j-graphrag[cohere]"`."""
)
super().__init__(rate_limit_handler)
self.model = model
self.client = cohere.Client(**kwargs)

@rate_limit_handler
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
response = self.client.embed(
texts=[text],
model=self.model,
**kwargs,
)
return response.embeddings[0] # type: ignore
try:
response = self.client.embed(
texts=[text],
model=self.model,
**kwargs,
)
return response.embeddings[0] # type: ignore
except Exception as e:
raise EmbeddingsGenerationError(
f"Failed to generate embedding with Cohere: {e}"
) from e
12 changes: 10 additions & 2 deletions src/neo4j_graphrag/embeddings/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
from __future__ import annotations

import os
from typing import Any
from typing import Any, Optional

from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler

try:
from mistralai import Mistral
Expand All @@ -36,18 +37,25 @@ class MistralAIEmbeddings(Embedder):
model (str): The name of the Mistral AI text embedding model to use. Defaults to "mistral-embed".
"""

def __init__(self, model: str = "mistral-embed", **kwargs: Any) -> None:
def __init__(
self,
model: str = "mistral-embed",
rate_limit_handler: Optional[RateLimitHandler] = None,
**kwargs: Any,
) -> None:
if Mistral is None:
raise ImportError(
"""Could not import mistralai.
Please install it with `pip install "neo4j-graphrag[mistralai]"`."""
)
super().__init__(rate_limit_handler)
api_key = kwargs.pop("api_key", None)
if api_key is None:
api_key = os.getenv("MISTRAL_API_KEY", "")
self.model = model
self.mistral_client = Mistral(api_key=api_key, **kwargs)

@rate_limit_handler
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
"""
Generate embeddings for a given query using a Mistral AI text embedding model.
Expand Down
12 changes: 10 additions & 2 deletions src/neo4j_graphrag/embeddings/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@

from __future__ import annotations

from typing import Any
from typing import Any, Optional

from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler


class OllamaEmbeddings(Embedder):
Expand All @@ -30,17 +31,24 @@ class OllamaEmbeddings(Embedder):
model (str): The name of the Ollama text embedding model to use.
"""

def __init__(self, model: str, **kwargs: Any) -> None:
def __init__(
self,
model: str,
rate_limit_handler: Optional[RateLimitHandler] = None,
**kwargs: Any,
) -> None:
try:
import ollama
except ImportError:
raise ImportError(
"""Could not import ollama python client.
Please install it with `pip install "neo4j_graphrag[ollama]"`."""
)
super().__init__(rate_limit_handler)
self.model = model
self.client = ollama.Client(**kwargs)

@rate_limit_handler
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
"""
Generate embeddings for a given query using an Ollama text embedding model.
Expand Down
26 changes: 21 additions & 5 deletions src/neo4j_graphrag/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
from __future__ import annotations

import abc
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional

from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler

if TYPE_CHECKING:
import openai
Expand All @@ -31,14 +33,20 @@ class BaseOpenAIEmbeddings(Embedder, abc.ABC):

client: openai.OpenAI

def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
def __init__(
self,
model: str = "text-embedding-ada-002",
rate_limit_handler: Optional[RateLimitHandler] = None,
**kwargs: Any,
) -> None:
try:
import openai
except ImportError:
raise ImportError(
"""Could not import openai python client.
Please install it with `pip install "neo4j-graphrag[openai]"`."""
)
super().__init__(rate_limit_handler)
self.openai = openai
self.model = model
self.client = self._initialize_client(**kwargs)
Expand All @@ -51,6 +59,7 @@ def _initialize_client(self, **kwargs: Any) -> Any:
"""
pass

@rate_limit_handler
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
"""
Generate embeddings for a given query using an OpenAI text embedding model.
Expand All @@ -59,9 +68,16 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]:
text (str): The text to generate an embedding for.
**kwargs (Any): Additional arguments to pass to the OpenAI embedding generation function.
"""
response = self.client.embeddings.create(input=text, model=self.model, **kwargs)
embedding: list[float] = response.data[0].embedding
return embedding
try:
response = self.client.embeddings.create(
input=text, model=self.model, **kwargs
)
embedding: list[float] = response.data[0].embedding
return embedding
except Exception as e:
raise EmbeddingsGenerationError(
f"Failed to generate embedding with OpenAI: {e}"
) from e


class OpenAIEmbeddings(BaseOpenAIEmbeddings):
Expand Down
38 changes: 27 additions & 11 deletions src/neo4j_graphrag/embeddings/sentence_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
from typing import Any, Optional

from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler


class SentenceTransformerEmbeddings(Embedder):
def __init__(
self, model: str = "all-MiniLM-L6-v2", *args: Any, **kwargs: Any
self,
model: str = "all-MiniLM-L6-v2",
rate_limit_handler: Optional[RateLimitHandler] = None,
*args: Any,
**kwargs: Any,
) -> None:
try:
import numpy as np
Expand All @@ -31,17 +37,27 @@ def __init__(
"""Could not import sentence_transformers python package.
Please install it with `pip install "neo4j-graphrag[sentence-transformers]"`."""
)
super().__init__(rate_limit_handler)
self.torch = torch
self.np = np
self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs)

@rate_limit_handler
def embed_query(self, text: str) -> Any:
    """Embed a single text with the wrapped SentenceTransformer model.

    The text is passed to ``self.model.encode`` as a one-element list and
    the result is flattened into a plain Python list.

    Args:
        text (str): The text to generate an embedding for.

    Returns:
        Any: The flattened embedding values as a list.

    Raises:
        EmbeddingsGenerationError: If encoding fails or the model returns a
            type other than a tensor, a NumPy array, or a list of tensors.
            The underlying exception is chained as the cause and its text
            is included in the message.
    """
    try:
        result = self.model.encode([text])

        # A single tensor / ndarray: flatten straight to a list.
        if isinstance(result, (self.torch.Tensor, self.np.ndarray)):
            return result.flatten().tolist()

        # A list of tensors: flatten each one and concatenate the values.
        if isinstance(result, list) and all(
            isinstance(x, self.torch.Tensor) for x in result
        ):
            return [item for tensor in result for item in tensor.flatten().tolist()]

        raise ValueError("Unexpected return type from model encoding")
    except Exception as e:
        # Include the cause in the message, as the other embedders do, so
        # the failure reason is visible without walking the exception chain.
        raise EmbeddingsGenerationError(
            f"Failed to generate embedding with SentenceTransformer: {e}"
        ) from e
27 changes: 21 additions & 6 deletions src/neo4j_graphrag/embeddings/vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
# limitations under the License.
from __future__ import annotations

from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional

from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler

try:
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
Expand All @@ -37,14 +39,20 @@ class VertexAIEmbeddings(Embedder):
model (str): The name of the Vertex AI text embedding model to use. Defaults to "text-embedding-004".
"""

def __init__(self, model: str = "text-embedding-004") -> None:
def __init__(
self,
model: str = "text-embedding-004",
rate_limit_handler: Optional[RateLimitHandler] = None,
) -> None:
if TextEmbeddingModel is None:
raise ImportError(
"""Could not import Vertex AI Python client.
Please install it with `pip install "neo4j-graphrag[google]"`."""
)
super().__init__(rate_limit_handler)
self.model = TextEmbeddingModel.from_pretrained(model)

@rate_limit_handler
def embed_query(
self, text: str, task_type: str = "RETRIEVAL_QUERY", **kwargs: Any
) -> list[float]:
Expand All @@ -56,7 +64,14 @@ def embed_query(
task_type (str): The type of the text embedding task. Defaults to "RETRIEVAL_QUERY". See https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#tasktype for a full list.
**kwargs (Any): Additional keyword arguments to pass to the Vertex AI client's get_embeddings method.
"""
# type annotation needed for mypy
inputs: list[str | TextEmbeddingInput] = [TextEmbeddingInput(text, task_type)]
embeddings = self.model.get_embeddings(inputs, **kwargs)
return embeddings[0].values
try:
# type annotation needed for mypy
inputs: list[str | TextEmbeddingInput] = [
TextEmbeddingInput(text, task_type)
]
embeddings = self.model.get_embeddings(inputs, **kwargs)
return list(embeddings[0].values)
except Exception as e:
raise EmbeddingsGenerationError(
f"Failed to generate embedding with VertexAI: {e}"
) from e
Loading