|
36 | 36 |
|
37 | 37 | __all__ = ["Graph"]
|
38 | 38 |
|
| 39 | +try: |
| 40 | + from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain |
| 41 | + from langchain_community.graphs import ArangoGraph |
| 42 | + from langchain_core.language_models import BaseLanguageModel |
| 43 | + from langchain_openai import ChatOpenAI |
| 44 | + |
| 45 | + LLM_AVAILABLE = True |
| 46 | +except Exception: |
| 47 | + LLM_AVAILABLE = False |
| 48 | + |
| 49 | + class BaseLanguageModel: # type: ignore[no-redef] |
| 50 | + pass |
| 51 | + |
39 | 52 |
|
40 | 53 | class Graph(nx.Graph):
|
41 | 54 | __networkx_backend__: ClassVar[str] = "arangodb" # nx >=3.2
|
@@ -85,8 +98,6 @@ def __init__(
|
85 | 98 | self.use_nxcg_cache = True
|
86 | 99 | self.nxcg_graph = None
|
87 | 100 |
|
88 |
| - # self.__qa_chain = None |
89 |
| - |
90 | 101 | # Does not apply to undirected graphs
|
91 | 102 | self.symmetrize_edges = symmetrize_edges
|
92 | 103 |
|
@@ -370,33 +381,34 @@ def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Curs
|
370 | 381 | return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs)
|
371 | 382 |
|
372 | 383 | # def pull(self) -> None:
|
373 |
| - # self._node._fetch_all() |
374 |
| - # self._adj._fetch_all() |
| 384 | + # TODO: what would this look like? |
375 | 385 |
|
376 |
| - # NOTE: OUT OF SERVICE |
377 |
| - # def chat(self, prompt: str) -> str: |
378 |
| - # if self.__qa_chain is None: |
379 |
| - # if not self.graph_exists_in_db: |
380 |
| - # return "Could not initialize QA chain: Graph does not exist" |
| 386 | + # def push(self) -> None: |
| 387 | + # TODO: what would this look like? |
381 | 388 |
|
382 |
| - # # try: |
383 |
| - # from langchain.chains import ArangoGraphQAChain |
384 |
| - # from langchain_community.graphs import ArangoGraph |
385 |
| - # from langchain_openai import ChatOpenAI |
| 389 | + def chat( |
| 390 | + self, prompt: str, verbose: bool = False, llm: BaseLanguageModel | None = None |
| 391 | + ) -> str: |
| 392 | + if not LLM_AVAILABLE: |
| 393 | + m = "LLM dependencies not installed. Install with **pip install nx-arangodb[llm]**" # noqa: E501 |
| 394 | + raise ModuleNotFoundError(m) |
386 | 395 |
|
387 |
| - # model = ChatOpenAI(temperature=0, model_name="gpt-4") |
| 396 | + if not self._graph_exists_in_db: |
| 397 | + m = "Cannot chat without a graph in the database" |
| 398 | + raise GraphNameNotSet(m) |
388 | 399 |
|
389 |
| - # self.__qa_chain = ArangoGraphQAChain.from_llm( |
390 |
| - # llm=model, graph=ArangoGraph(self.db), verbose=True |
391 |
| - # ) |
| 400 | + if llm is None: |
| 401 | + llm = ChatOpenAI(temperature=0, model_name="gpt-4") |
392 | 402 |
|
393 |
| - # # except Exception as e: |
394 |
| - # # return f"Could not initialize QA chain: {e}" |
| 403 | + chain = ArangoGraphQAChain.from_llm( |
| 404 | + llm=llm, |
| 405 | + graph=ArangoGraph(self.db), |
| 406 | + verbose=verbose, |
| 407 | + ) |
395 | 408 |
|
396 |
| - # self.__qa_chain.graph.set_schema() |
397 |
| - # result = self.__qa_chain.invoke(prompt) |
| 409 | + response = chain.invoke(prompt) |
398 | 410 |
|
399 |
| - # print(result["result"]) |
| 411 | + return str(response["result"]) |
400 | 412 |
|
401 | 413 | #####################
|
402 | 414 | # nx.Graph Overides #
|
|
0 commit comments