Skip to content

Commit 9761c0d

Browse files
authored
new: langchain plugin (#44)
1 parent 75a3660 commit 9761c0d

File tree

2 files changed

+39
-22
lines changed

2 files changed

+39
-22
lines changed

nx_arangodb/classes/graph.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,19 @@
3636

3737
__all__ = ["Graph"]
3838

39+
try:
40+
from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain
41+
from langchain_community.graphs import ArangoGraph
42+
from langchain_core.language_models import BaseLanguageModel
43+
from langchain_openai import ChatOpenAI
44+
45+
LLM_AVAILABLE = True
46+
except Exception:
47+
LLM_AVAILABLE = False
48+
49+
class BaseLanguageModel: # type: ignore[no-redef]
50+
pass
51+
3952

4053
class Graph(nx.Graph):
4154
__networkx_backend__: ClassVar[str] = "arangodb" # nx >=3.2
@@ -85,8 +98,6 @@ def __init__(
8598
self.use_nxcg_cache = True
8699
self.nxcg_graph = None
87100

88-
# self.__qa_chain = None
89-
90101
# Does not apply to undirected graphs
91102
self.symmetrize_edges = symmetrize_edges
92103

@@ -370,33 +381,34 @@ def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Curs
370381
return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs)
371382

372383
# def pull(self) -> None:
373-
# self._node._fetch_all()
374-
# self._adj._fetch_all()
384+
# TODO: what would this look like?
375385

376-
# NOTE: OUT OF SERVICE
377-
# def chat(self, prompt: str) -> str:
378-
# if self.__qa_chain is None:
379-
# if not self.graph_exists_in_db:
380-
# return "Could not initialize QA chain: Graph does not exist"
386+
# def push(self) -> None:
387+
# TODO: what would this look like?
381388

382-
# # try:
383-
# from langchain.chains import ArangoGraphQAChain
384-
# from langchain_community.graphs import ArangoGraph
385-
# from langchain_openai import ChatOpenAI
389+
def chat(
390+
self, prompt: str, verbose: bool = False, llm: BaseLanguageModel | None = None
391+
) -> str:
392+
if not LLM_AVAILABLE:
393+
m = "LLM dependencies not installed. Install with **pip install nx-arangodb[llm]**" # noqa: E501
394+
raise ModuleNotFoundError(m)
386395

387-
# model = ChatOpenAI(temperature=0, model_name="gpt-4")
396+
if not self._graph_exists_in_db:
397+
m = "Cannot chat without a graph in the database"
398+
raise GraphNameNotSet(m)
388399

389-
# self.__qa_chain = ArangoGraphQAChain.from_llm(
390-
# llm=model, graph=ArangoGraph(self.db), verbose=True
391-
# )
400+
if llm is None:
401+
llm = ChatOpenAI(temperature=0, model_name="gpt-4")
392402

393-
# # except Exception as e:
394-
# # return f"Could not initialize QA chain: {e}"
403+
chain = ArangoGraphQAChain.from_llm(
404+
llm=llm,
405+
graph=ArangoGraph(self.db),
406+
verbose=verbose,
407+
)
395408

396-
# self.__qa_chain.graph.set_schema()
397-
# result = self.__qa_chain.invoke(prompt)
409+
response = chain.invoke(prompt)
398410

399-
# print(result["result"])
411+
return str(response["result"])
400412

401413
#####################
402414
# nx.Graph Overides #

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ dev = [
5858
gpu = [
5959
"nx-cugraph-cu12 @ https://pypi.nvidia.com"
6060
]
61+
llm = [
62+
"langchain~=0.2.14",
63+
"langchain-openai~=0.1.22",
64+
"langchain-community~=0.2.12"
65+
]
6166

6267
[project.urls]
6368
Homepage = "https://github.com/aMahanna/nx-arangodb"

0 commit comments

Comments
 (0)