
Commit 84b165c

GA-154 | update phenolrs usage & use nx.config (#9)
* GA-147 | initial commit
* new: recursive `EdgeAttrDict`
* fix: `nested_keys` param
* update tests
* new: `AttrDict.root`
* fix: `FETCHED_ALL_DATA`
* checkpoint
* checkpoint 2 (use NetworkX Config)
* fix: lint
* cleanup: `__fetch_all()`
* fix: `self.clear()`
* fix: `FETCHED_ALL_DATA` usage
* fix: `logger_debug`
* remove: walrus operator `:=` is acting weird... not sure what's going on
* revert bccc1e6
* new: `load_adj_dict_as_multigraph`
* cleanup
* fix: `logger_debug`
1 parent f14ab7e commit 84b165c

File tree

6 files changed: +129 -233 lines
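
The "use nx.config" part of the title moves connection and loader settings onto NetworkX's backend configuration object (see the function.py diff below). As a rough sketch of what a caller sets up beforehand, assuming the nx-arangodb backend has already registered its config namespace and these keys; the values and the graph constructor argument are placeholders, not confirmed by this commit:

import networkx as nx
import nx_arangodb as nxadb

config = nx.config.backends.arangodb

# Connection settings asserted by get_arangodb_graph() in this commit
# (db_name, host, username, password); values here are placeholders.
config.db_name = "_system"
config.host = "http://localhost:8529"
config.username = "root"
config.password = "passwd"

# Optional loader tuning, read via config.get(...) in the diff below.
config.load_parallelism = 8
config.load_batch_size = 100_000

# Hypothetical usage: a full fetch triggered through an nxadb graph now
# reads the settings above instead of attributes on the graph object.
G = nxadb.Graph(graph_name="MyGraph")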

nx_arangodb/classes/dict.py

Lines changed: 51 additions & 114 deletions

@@ -33,6 +33,7 @@
     doc_get_or_insert,
     doc_insert,
     doc_update,
+    get_arangodb_graph,
     get_node_id,
     get_node_type_and_id,
     key_is_not_reserved,
@@ -321,56 +322,6 @@ def __delitem__(self, key: str) -> None:
         root_data = self.root.data if self.root else self.data
         root_data["_rev"] = doc_update(self.db, self.node_id, update_dict)
 
-    # @logger_debug
-    # def __iter__(self) -> Iterator[str]:
-    #     """for key in G._node['node/1']"""
-    #     yield from aql_doc_get_keys(self.db, self.node_id, self.parent_keys)
-
-    # @logger_debug
-    # def __len__(self) -> int:
-    #     """len(G._node['node/1'])"""
-    #     return aql_doc_get_length(self.db, self.node_id, self.parent_keys)
-
-    # @logger_debug
-    # def keys(self) -> Any:
-    #     """G._node['node/1'].keys()"""
-    #     yield from self.__iter__()
-
-    # @logger_debug
-    # # TODO: Revisit typing of return value
-    # def values(self) -> Any:
-    #     """G._node['node/1'].values()"""
-    #     self.data = self.db.document(self.node_id)
-    #     yield from self.data.values()
-
-    # @logger_debug
-    # # TODO: Revisit typing of return value
-    # def items(self) -> Any:
-    #     """G._node['node/1'].items()"""
-
-    #     # TODO: Revisit this lazy hack
-    #     if self.parent_keys:
-    #         yield from self.data.items()
-    #     else:
-    #         self.data = self.db.document(self.node_id)
-    #         yield from self.data.items()
-
-    # ?
-    # def pull():
-    #     pass
-
-    # ?
-    # def push():
-    #     pass
-
-    # @logger_debug
-    # def clear(self) -> None:
-    #     """G._node['node/1'].clear()"""
-    #     self.data.clear()
-
-    #     # if clear_remote:
-    #     #     doc_insert(self.db, self.node_id, silent=True, overwrite=True)
-
     @keys_are_strings
     @keys_are_not_reserved
     # @values_are_json_serializable  # TODO?
@@ -435,6 +386,9 @@ def __contains__(self, key: str) -> bool:
         if node_id in self.data:
             return True
 
+        if self.FETCHED_ALL_DATA:
+            return False
+
         return bool(self.graph.has_vertex(node_id))
 
     @key_is_string
@@ -446,6 +400,9 @@ def __getitem__(self, key: str) -> NodeAttrDict:
         if vertex := self.data.get(node_id):
             return vertex
 
+        if self.FETCHED_ALL_DATA:
+            raise KeyError(key)
+
         if vertex := self.graph.vertex(node_id):
             node_attr_dict: NodeAttrDict = self.node_attr_dict_factory()
             node_attr_dict.node_id = node_id
@@ -472,7 +429,7 @@ def __setitem__(self, key: str, value: NodeAttrDict) -> None:
 
         node_attr_dict = self.node_attr_dict_factory()
         node_attr_dict.node_id = node_id
-        node_attr_dict.data = result
+        node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, result)
 
         self.data[node_id] = node_attr_dict
 
@@ -570,16 +527,23 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any:
 
     @logger_debug
     def __fetch_all(self):
-        self.data.clear()
-        for collection in self.graph.vertex_collections():
-            for doc in self.graph.vertex_collection(collection).all():
-                node_id = doc["_id"]
+        self.clear()
 
-                node_attr_dict = self.node_attr_dict_factory()
-                node_attr_dict.node_id = node_id
-                node_attr_dict.data = doc
+        node_dict, _, _, _, _ = get_arangodb_graph(
+            self.graph,
+            load_node_dict=True,
+            load_adj_dict=False,
+            load_adj_dict_as_directed=False,  # not used
+            load_adj_dict_as_multigraph=False,  # not used
+            load_coo=False,
+        )
 
-                self.data[node_id] = node_attr_dict
+        for node_id, node_data in node_dict.items():
+            node_attr_dict = self.node_attr_dict_factory()
+            node_attr_dict.node_id = node_id
+            node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, node_data)
+
+            self.data[node_id] = node_attr_dict
 
         self.FETCHED_ALL_DATA = True
 
@@ -710,43 +674,6 @@ def __delitem__(self, key: str) -> None:
         root_data = self.root.data if self.root else self.data
         root_data["_rev"] = doc_update(self.db, self.edge_id, update_dict)
 
-    # @logger_debug
-    # def __iter__(self) -> Iterator[str]:
-    #     """for key in G._adj['node/1']['node/2']"""
-    #     assert self.edge_id
-    #     yield from aql_doc_get_keys(self.db, self.edge_id)
-
-    # @logger_debug
-    # def __len__(self) -> int:
-    #     """len(G._adj['node/1']['node/'2])"""
-    #     assert self.edge_id
-    #     return aql_doc_get_length(self.db, self.edge_id)
-
-    # # TODO: Revisit typing of return value
-    # @logger_debug
-    # def keys(self) -> Any:
-    #     """G._adj['node/1']['node/'2].keys()"""
-    #     return self.__iter__()
-
-    # # TODO: Revisit typing of return value
-    # @logger_debug
-    # def values(self) -> Any:
-    #     """G._adj['node/1']['node/'2].values()"""
-    #     self.data = self.db.document(self.edge_id)
-    #     yield from self.data.values()
-
-    # # TODO: Revisit typing of return value
-    # @logger_debug
-    # def items(self) -> Any:
-    #     """G._adj['node/1']['node/'2].items()"""
-    #     self.data = self.db.document(self.edge_id)
-    #     yield from self.data.items()
-
-    # @logger_debug
-    # def clear(self) -> None:
-    #     """G._adj['node/1']['node/'2].clear()"""
-    #     self.data.clear()
-
     @keys_are_strings
     @keys_are_not_reserved
     @logger_debug
@@ -836,6 +763,9 @@ def __contains__(self, key: str) -> bool:
         if dst_node_id in self.data:
             return True
 
+        if self.FETCHED_ALL_DATA:
+            return False
+
         result = aql_edge_exists(
             self.db,
             self.src_node_id,
@@ -859,6 +789,9 @@ def __getitem__(self, key: str) -> EdgeAttrDict:
             self.data[dst_node_id] = edge
             return edge  # type: ignore  # false positive
 
+        if self.FETCHED_ALL_DATA:
+            raise KeyError(key)
+
         assert self.src_node_id
         edge = aql_edge_get(
             self.db,
@@ -1022,8 +955,7 @@ def items(self) -> Any:
 
     @logger_debug
     def __fetch_all(self) -> None:
-        if self.FETCHED_ALL_DATA:
-            return
+        assert self.src_node_id
 
         self.clear()
 
@@ -1037,8 +969,7 @@ def __fetch_all(self) -> None:
         for edge in aql(self.db, query, bind_vars):
             edge_attr_dict = self.edge_attr_dict_factory()
             edge_attr_dict.edge_id = edge["_id"]
-            edge_attr_dict.data = edge
-
+            edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, edge)
             self.data[edge["_to"]] = edge_attr_dict
 
         self.FETCHED_ALL_DATA = True
@@ -1100,6 +1031,9 @@ def __contains__(self, key: str) -> bool:
         if node_id in self.data:
             return True
 
+        if self.FETCHED_ALL_DATA:
+            return False
+
         return bool(self.graph.has_vertex(node_id))
 
     @key_is_string
@@ -1114,7 +1048,6 @@ def __getitem__(self, key: str) -> AdjListInnerDict:
         if self.graph.has_vertex(node_id):
             adjlist_inner_dict: AdjListInnerDict = self.adjlist_inner_dict_factory()
             adjlist_inner_dict.src_node_id = node_id
-
             self.data[node_id] = adjlist_inner_dict
 
             return adjlist_inner_dict
@@ -1237,41 +1170,45 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any:
         result = aql_fetch_data_edge(self.db, e_cols, data, default)
         yield from result
 
-    # TODO: Revisit this logic
     @logger_debug
     def __fetch_all(self) -> None:
-        if self.FETCHED_ALL_DATA:
-            return
-
         self.clear()
-        # items = defaultdict(dict)
-        for ed in self.graph.edge_definitions():
-            collection = ed["edge_collection"]
 
-            for edge in self.graph.edge_collection(collection):
-                src_node_id = edge["_from"]
-                dst_node_id = edge["_to"]
+        _, adj_dict, _, _, _ = get_arangodb_graph(
+            self.graph,
+            load_node_dict=False,
+            load_adj_dict=True,
+            load_adj_dict_as_directed=False,  # TODO: Abstract based on Graph type
+            load_adj_dict_as_multigraph=False,  # TODO: Abstract based on Graph type
+            load_coo=False,
+        )
+
+        for src_node_id, inner_dict in adj_dict.items():
+            for dst_node_id, edge in inner_dict.items():
 
-                # items[src_node_id][dst_node_id] = edge
-                # items[dst_node_id][src_node_id] = edge
+                if src_node_id in self.data:
+                    if dst_node_id in self.data[src_node_id].data:
+                        continue
 
                 if src_node_id in self.data:
                     src_inner_dict = self.data[src_node_id]
                 else:
                     src_inner_dict = self.adjlist_inner_dict_factory()
                     src_inner_dict.src_node_id = src_node_id
+                    src_inner_dict.FETCHED_ALL_DATA = True
                     self.data[src_node_id] = src_inner_dict
 
                 if dst_node_id in self.data:
                     dst_inner_dict = self.data[dst_node_id]
                 else:
                     dst_inner_dict = self.adjlist_inner_dict_factory()
                     dst_inner_dict.src_node_id = dst_node_id
+                    src_inner_dict.FETCHED_ALL_DATA = True
                     self.data[dst_node_id] = dst_inner_dict
 
                 edge_attr_dict = src_inner_dict.edge_attr_dict_factory()
                 edge_attr_dict.edge_id = edge["_id"]
-                edge_attr_dict.data = edge
+                edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, edge)
 
                 self.data[src_node_id].data[dst_node_id] = edge_attr_dict
                 self.data[dst_node_id].data[src_node_id] = edge_attr_dict
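
Several hunks above add the same guard: once `__fetch_all()` has bulk-loaded every document and set `FETCHED_ALL_DATA`, membership tests and lookups are answered from the local `data` cache instead of issuing another query. A minimal, self-contained sketch of that pattern follows; it is illustrative only and not one of the library's actual dict classes (`_fetch_one` and `_fetch_everything` are hypothetical stand-ins for the ArangoDB lookups and `get_arangodb_graph()`).

from typing import Any


class CachedDocDict:
    """Illustrative stand-in for the NodeDict / AdjListInnerDict caching pattern."""

    def __init__(self) -> None:
        self.data: dict[str, Any] = {}
        self.FETCHED_ALL_DATA = False

    def _fetch_one(self, key: str) -> Any | None:
        return None  # placeholder for a single-document database lookup

    def _fetch_everything(self) -> dict[str, Any]:
        return {}  # placeholder for the bulk load done by get_arangodb_graph()

    def fetch_all(self) -> None:
        self.data = self._fetch_everything()
        self.FETCHED_ALL_DATA = True  # cache is now known to be complete

    def __contains__(self, key: str) -> bool:
        if key in self.data:
            return True
        if self.FETCHED_ALL_DATA:
            return False  # complete cache: a miss means the key is absent
        return self._fetch_one(key) is not None

    def __getitem__(self, key: str) -> Any:
        if (value := self.data.get(key)) is not None:
            return value
        if self.FETCHED_ALL_DATA:
            raise KeyError(key)  # complete cache: skip the remote round trip
        if (value := self._fetch_one(key)) is not None:
            self.data[key] = value
            return value
        raise KeyError(key)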

nx_arangodb/classes/digraph.py

Lines changed: 0 additions & 3 deletions

@@ -154,6 +154,3 @@ def __set_graph_name(self, graph_name: str | None = None) -> None:
 
     def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Cursor:
         return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs)
-
-    def pull(self, load_node_dict=True, load_adj_dict=True, load_coo=True):
-        raise NotImplementedError("nxadb.DiGraph.pull() is not implemented yet.")

nx_arangodb/classes/function.py

Lines changed: 22 additions & 17 deletions

@@ -8,11 +8,13 @@
 from collections import UserDict
 from typing import Any, Callable, Tuple
 
+import networkx as nx
 import numpy as np
 import numpy.typing as npt
 from arango.collection import StandardCollection
 from arango.cursor import Cursor
 from arango.database import StandardDatabase
+from arango.graph import Graph
 
 import nx_arangodb as nxadb
 from nx_arangodb.logger import logger
@@ -25,10 +27,11 @@
 
 
 def get_arangodb_graph(
-    G: nxadb.Graph | nxadb.DiGraph,
+    adb_graph: Graph,
     load_node_dict: bool,
     load_adj_dict: bool,
     load_adj_dict_as_directed: bool,
+    load_adj_dict_as_multigraph: bool,
     load_coo: bool,
 ) -> Tuple[
     dict[str, dict[str, Any]],
@@ -46,12 +49,6 @@ def get_arangodb_graph(
     - Destination Indices (COO)
     - Node-ID-to-index mapping (COO)
     """
-    if not G.graph_exists_in_db:
-        raise GraphDoesNotExist(
-            "Graph does not exist in the database. Can't load graph."
-        )
-
-    adb_graph = G.db.graph(G.graph_name)
     v_cols = adb_graph.vertex_collections()
     edge_definitions = adb_graph.edge_definitions()
     e_cols = {c["edge_collection"] for c in edge_definitions}
@@ -63,22 +60,30 @@ def get_arangodb_graph(
 
     from phenolrs.networkx_loader import NetworkXLoader
 
+    config = nx.config.backends.arangodb
+
     kwargs = {}
-    if G.graph_loader_parallelism is not None:
-        kwargs["parallelism"] = G.graph_loader_parallelism
-    if G.graph_loader_batch_size is not None:
-        kwargs["batch_size"] = G.graph_loader_batch_size
+    if parallelism := config.get("load_parallelism"):
+        kwargs["parallelism"] = parallelism
+    if batch_size := config.get("load_batch_size"):
+        kwargs["batch_size"] = batch_size
+
+    assert config.db_name
+    assert config.host
+    assert config.username
+    assert config.password
 
     # TODO: Remove ignore when phenolrs is published
     return NetworkXLoader.load_into_networkx(  # type: ignore
-        G.db.name,
-        metagraph,
-        [G._host],
-        username=G._username,
-        password=G._password,
+        config.db_name,
+        metagraph=metagraph,
+        hosts=[config.host],
+        username=config.username,
+        password=config.password,
         load_node_dict=load_node_dict,
         load_adj_dict=load_adj_dict,
         load_adj_dict_as_directed=load_adj_dict_as_directed,
+        load_adj_dict_as_multigraph=load_adj_dict_as_multigraph,
         load_coo=load_coo,
         **kwargs,
     )
@@ -103,7 +108,7 @@ def logger_debug(func: Callable[..., Any]) -> Any:
     """Decorator to log debug messages."""
 
     def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
-        logger.debug(f"{func.__name__} - {args} - {kwargs}")
+        logger.debug(f"{type(self)}.{func.__name__} - {args} - {kwargs}")
         return func(self, *args, **kwargs)
 
     return wrapper
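
With this change `get_arangodb_graph()` no longer takes an `nxadb.Graph`; it takes a python-arango `Graph` handle plus explicit load flags, reads credentials from `nx.config.backends.arangodb`, and returns the five-tuple unpacked by the `__fetch_all()` call sites in dict.py above. A hedged usage sketch with placeholder connection details (in the library itself the handle comes from the nxadb graph classes, not from a manual `ArangoClient`):

from arango import ArangoClient

from nx_arangodb.classes.function import get_arangodb_graph

# Placeholder connection; get_arangodb_graph() itself authenticates via
# nx.config.backends.arangodb, this client only provides the graph handle.
db = ArangoClient(hosts="http://localhost:8529").db(
    "_system", username="root", password="passwd"
)
adb_graph = db.graph("MyGraph")  # arango.graph.Graph handle

node_dict, adj_dict, src_indices, dst_indices, node_id_to_index = get_arangodb_graph(
    adb_graph,
    load_node_dict=True,
    load_adj_dict=True,
    load_adj_dict_as_directed=False,
    load_adj_dict_as_multigraph=False,
    load_coo=True,
)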
