Skip to content
100 changes: 93 additions & 7 deletions libs/knowledge-store/ragstack_knowledge_store/embedding_model.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,108 @@
from abc import ABC, abstractmethod
from typing import List

from abc import ABC
from typing import List, Any, Optional
from collections import defaultdict

class EmbeddingModel(ABC):
"""Embedding model."""

@abstractmethod
def __init__(self, embeddings: Any, method_map: Optional[dict] = None, other_methods: Optional[List[str]] = None):
self.embeddings = embeddings
self.method_name = {}
method_map = method_map if method_map else {}
other_methods = other_methods if other_methods else []

base_methods = ['embed_texts', 'aembed_texts', 'embed_query', 'aembed_query']
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should try to add all of these as methods, it's definitely pretty messy.

I think we should just have embed_mime(self, mime_type: str, content: Union[str, Bytes]) or something like that. Then there is only a single abstract method to use for any mime type and the names can be different, etc.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

100% but right now, LangChain doesn't have "embed_mime" :)

extended_methods = ['embed_images', 'aembed_images', 'embed_image', 'aembed_image']

# Combining all method names, including those mapped
all_methods = set(base_methods + extended_methods + other_methods + list(method_map.values()))

for method in all_methods:
mapped_method = method_map.get(method)
if hasattr(embeddings, method):
self.method_name[method] = method
elif hasattr(embeddings, mapped_method) if mapped_method else False:
self.method_name[method] = mapped_method
else:
self.method_name[method] = None

def does_implement(self, method_name: str) -> bool:
"""Check if the method is implemented."""
return self.method_name.get(method_name) is not None

def implements(self) -> List[str]:
"""List of methods that are implemented"""
return [method for method, impl in self.method_name.items() if impl is not None]

def invoke(self, method_name: str, *args, **kwargs):
"""Invoke a synchronous method if it's implemented."""
target_method = self.method_name.get(method_name)
if target_method and hasattr(self.embeddings, target_method):
return getattr(self.embeddings, target_method)(*args, **kwargs)
else:
raise NotImplementedError(f"{self.embeddings.__class__.__name__} does not implement {target_method}")

async def ainvoke(self, method_name: str, *args, **kwargs):
"""Invoke an asynchronous method if it's implemented."""
target_method = self.method_name.get(method_name)
if target_method and hasattr(self.embeddings, target_method):
return await getattr(self.embeddings, target_method)(*args, **kwargs)
else:
raise NotImplementedError(f"{self.embeddings.__class__.__name__} does not implement {target_method}")

def embed_mimes(self, texts: List[str], mime_types: List[str]) -> List[List[float]]:
"""Embed mime content."""

# Extract main MIME types
main_mime_types = [mime_type.split('/')[0] for mime_type in mime_types]

# Group texts by main MIME types
grouped_texts = defaultdict(list)
index_mapping = defaultdict(list)
for index, (text, main_mime_type) in enumerate(zip(texts, main_mime_types)):
grouped_texts[main_mime_type].append(text)
index_mapping[main_mime_type].append(index)

# Initialize result list with None to preserve order
embeddings = [None] * len(texts)

# Process each MIME type group
for mime_type, group_texts in grouped_texts.items():
method_name = f"embed_{mime_type}s"
if self.does_implement(method_name):
# Bulk embedding method exists
group_embeddings = self.invoke(method_name, group_texts)
for idx, emb in zip(index_mapping[mime_type], group_embeddings):
embeddings[idx] = emb
else:
# No bulk method, fall back to individual methods
singular_method_name = f"embed_{mime_type}"
for text, idx in zip(group_texts, index_mapping[mime_type]):
if self.does_implement(singular_method_name):
embedding = self.invoke(singular_method_name, text)
embeddings[idx] = embedding
else:
raise NotImplementedError(f"No embedding method available for MIME type: {mime_type}, implemented methods: {self.implements()}.")

# Ensure all embeddings are computed
if None in embeddings:
raise ValueError("Some embeddings were not computed correctly.")

return embeddings

def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Embed texts."""
return self.invoke('embed_texts', texts)

@abstractmethod
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
return self.invoke('embed_query', text)

@abstractmethod
async def aembed_texts(self, texts: List[str]) -> List[List[float]]:
"""Embed texts."""
return await self.ainvoke('aembed_texts', texts)

@abstractmethod
async def aembed_query(self, text: str) -> List[float]:
"""Embed query text."""
return await self.ainvoke('aembed_query', text)

10 changes: 9 additions & 1 deletion libs/knowledge-store/ragstack_knowledge_store/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,19 @@ class Node:

id: Optional[str] = None
"""Unique ID for the node. Will be generated by the GraphStore if not set."""

text: str = None
"""Text contained by the node."""
metadata: dict = field(default_factory=dict)
"""Metadata for the node."""
links: Set[Link] = field(default_factory=set)
"""Links for the node."""

mime_type: str = "text/plain"
"""Type of content, e.g. text/plain or image/png."""

mime_encoding: str = None
"""Encoding format"""

class SetupMode(Enum):
SYNC = 1
Expand Down Expand Up @@ -338,6 +344,7 @@ def add_nodes(
node_ids = []
texts = []
metadatas = []
mime_types = []
nodes_links: List[Set[Link]] = []
for node in nodes:
if not node.id:
Expand All @@ -346,9 +353,10 @@ def add_nodes(
node_ids.append(node.id)
texts.append(node.text)
metadatas.append(node.metadata)
mime_types.append(node.mime_type)
nodes_links.append(node.links)

text_embeddings = self._embedding.embed_texts(texts)
text_embeddings = self._embedding.embed_mimes(texts,mime_types)

with self._concurrent_queries() as cq:
tuples = zip(node_ids, texts, text_embeddings, metadatas, nodes_links)
Expand Down
16 changes: 14 additions & 2 deletions libs/langchain/ragstack_langchain/graph_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from langchain_core.pydantic_v1 import Field

from ragstack_langchain.graph_store.links import METADATA_LINKS_KEY, Link


def _has_next(iterator: Iterator) -> bool:
"""Checks if the iterator has more elements.
Warning: consumes an element from the iterator"""
Expand All @@ -41,13 +39,21 @@ class Node(Serializable):

id: Optional[str] = None
"""Unique ID for the node. Will be generated by the GraphStore if not set."""

text: str
"""Text contained by the node."""

metadata: dict = Field(default_factory=dict)
"""Metadata for the node."""

links: Set[Link] = Field(default_factory=set)
"""Links associated with the node."""

mime_type: str = "text/plain"
"""Type of content, e.g. text/plain or image/png."""

mime_encoding: str = None
"""Encoding format"""

def _texts_to_nodes(
texts: Iterable[str],
Expand All @@ -67,13 +73,15 @@ def _texts_to_nodes(
except StopIteration:
raise ValueError("texts iterable longer than ids")

mime_type = _metadata.get("mime_type", "text/plain")
links = _metadata.pop(METADATA_LINKS_KEY, set())
if not isinstance(links, Set):
links = set(links)
yield Node(
id=_id,
metadata=_metadata,
text=text,
mime_type=mime_type,
links=links,
)
if ids_it and _has_next(ids_it):
Expand All @@ -94,12 +102,16 @@ def _documents_to_nodes(
raise ValueError("documents iterable longer than ids")
metadata = doc.metadata.copy()
links = metadata.pop(METADATA_LINKS_KEY, set())
mime_type = metadata.get("mime_type","text/plain")
mime_encoding = metadata.get("mime_encoding")
if not isinstance(links, Set):
links = set(links)
yield Node(
id=_id,
metadata=metadata,
text=doc.page_content,
mime_type=mime_type,
mime_encoding=mime_encoding,
links=links,
)
if ids_it and _has_next(ids_it):
Expand Down
29 changes: 5 additions & 24 deletions libs/langchain/ragstack_langchain/graph_store/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,8 @@
from langchain_core.embeddings import Embeddings

from .base import GraphStore, Node, nodes_to_documents
from ragstack_knowledge_store import EmbeddingModel, graph_store


class _EmbeddingModelAdapter(EmbeddingModel):
def __init__(self, embeddings: Embeddings):
self.embeddings = embeddings

def embed_texts(self, texts: List[str]) -> List[List[float]]:
return self.embeddings.embed_documents(texts)

def embed_query(self, text: str) -> List[float]:
return self.embeddings.embed_query(text)

async def aembed_texts(self, texts: List[str]) -> List[List[float]]:
return await self.embeddings.aembed_documents(texts)

async def aembed_query(self, text: str) -> List[float]:
return await self.embeddings.aembed_query(text)

from .embedding_adapter import EmbeddingAdapter
from ragstack_knowledge_store import graph_store

class CassandraGraphStore(GraphStore):
def __init__(
Expand Down Expand Up @@ -60,7 +43,7 @@ def __init__(
_setup_mode = getattr(graph_store.SetupMode, setup_mode.name)

self.store = graph_store.GraphStore(
embedding=_EmbeddingModelAdapter(embedding),
embedding=EmbeddingAdapter(embedding),
node_table=node_table,
targets_table=targets_table,
session=session,
Expand All @@ -80,10 +63,8 @@ def add_nodes(
_nodes = []
for node in nodes:
_nodes.append(
graph_store.Node(
id=node.id, text=node.text, metadata=node.metadata, links=node.links
)
)
graph_store.Node(id=node.id, text=node.text, mime_type=node.mime_type, mime_encoding=node.mime_encoding, metadata=node.metadata)
)
return self.store.add_nodes(_nodes)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import List
from ragstack_knowledge_store import EmbeddingModel

class EmbeddingAdapter(EmbeddingModel):
def __init__(self, embeddings):
super().__init__(embeddings,
method_map={'embed_texts': 'embed_documents',
'aembed_texts': 'aembed_documents'})