Skip to content

Commit 9dc4cc2

Browse files
authored
GA-150 | MultiDiGraph Support (#26)
* GA-149 | initial commit
  failing for now
* checkpoint
  no new tests yet, just experimenting with `AdjListInnerDict`
* checkpoint 2
* checkpoint 3
  still no new tests, just brainstorming
* checkpoint 4
  starting to get messy...
* cleanup & comments
* comments
* cleanup: `__contains__`
* cleanup: `__getitem__`
* restructuring
* docstring updates
* checkpoint 5
* cleanup
* new helper functions
* checkpoint 6
* checkpoint 7
* cleanup
* add warning
* fix: conditional override
* fix: func name
* new: `FETCHED_ALL_IDS`
  Attribute used to establish if all ArangoDB IDs have been retrieved for the particular dict class. Not to be confused with `FETCHED_ALL_DATA`, which fetches both IDs & Documents
* fix: parameterize `EDGE_TYPE_KEY`
* cleanup: redundant code
* fix: `nodes` & `edges` properties
* new: `__process_int_edge_key`
* new: `test_multigraph_*_crud`
  minimal suite for now. need to revisit
* update: `test_algorithm` for `nxadb.MultiGraph`
* fix: `__get_mirrored_adjlist_inner_dict`
* extra docstring
* new: graph overrides
* fix: EdgeKeyDict docstring
* update `phenolrs` wheel
* fix: phenolrs
* remove unused import
* fix: except clause
* fix: logger info
* remove multigraph lock
* fix: typo
* cleanup: kwargs
* remove print
* fix: add `write_batch_size` to config
  this will be useful for bulk updates
* temp: `NodeDict.update` hack
  Just a temporary solution. Will be removed shortly
* revert ec1cbc8
* add custom exception
* update node & edge type logic for new vs existing graphs
* fix: `symmetrize_edges` logic
* GA-150 | initial commit
1 parent 82f5536 commit 9dc4cc2

File tree

5 files changed: +230 -9 lines changed

nx_arangodb/classes/dict.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2258,7 +2258,8 @@ def set_edge_multigraph(
22582258
load_all_edge_attributes=True,
22592259
is_directed=self.is_directed,
22602260
is_multigraph=self.is_multigraph,
2261-
symmetrize_edges_if_directed=self.symmetrize_edges_if_directed,
2261+
symmetrize_edges_if_directed=self.is_directed
2262+
and self.symmetrize_edges_if_directed,
22622263
)
22632264

22642265
if self.is_directed:

nx_arangodb/classes/digraph.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,11 +45,11 @@ def __init__(
4545
read_parallelism,
4646
read_batch_size,
4747
write_batch_size,
48+
symmetrize_edges,
4849
*args,
4950
**kwargs,
5051
)
5152

52-
self.symmetrize_edges = symmetrize_edges
5353
if self.graph_exists_in_db:
5454
assert isinstance(self._succ, AdjListOuterDict)
5555
assert isinstance(self._pred, AdjListOuterDict)

nx_arangodb/classes/graph.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@ def __init__(
5252
read_parallelism: int = 10,
5353
read_batch_size: int = 100000,
5454
write_batch_size: int = 50000,
55+
symmetrize_edges: bool = False,
5556
*args: Any,
5657
**kwargs: Any,
5758
):
@@ -80,7 +81,8 @@ def __init__(
8081
self.edge_indices: npt.NDArray[np.int64] | None = None
8182
self.vertex_ids_to_index: dict[str, int] | None = None
8283

83-
self.symmetrize_edges = False # Does not apply to undirected graphs
84+
# Does not apply to undirected graphs
85+
self.symmetrize_edges = symmetrize_edges
8486

8587
self.edge_type_key = edge_type_key
8688

nx_arangodb/classes/multidigraph.py

Lines changed: 29 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
1-
from typing import ClassVar
1+
from typing import Any, Callable, ClassVar
22

33
import networkx as nx
4+
from arango.database import StandardDatabase
45

56
import nx_arangodb as nxadb
67
from nx_arangodb.classes.digraph import DiGraph
@@ -20,10 +21,33 @@ class MultiDiGraph(MultiGraph, DiGraph, nx.MultiDiGraph):
2021
def to_networkx_class(cls) -> type[nx.MultiDiGraph]:
2122
return nx.MultiDiGraph # type: ignore[no-any-return]
2223

23-
def __init__(self, *args, **kwargs):
24-
super().__init__(*args, **kwargs)
25-
m = "nxadb.MultiDiGraph has not been implemented yet. This is a pass-through subclass of nx.MultiDiGraph for now." # noqa
26-
logger.warning(m)
24+
def __init__(
25+
self,
26+
graph_name: str | None = None,
27+
default_node_type: str | None = None,
28+
edge_type_key: str = "_edge_type",
29+
edge_type_func: Callable[[str, str], str] | None = None,
30+
db: StandardDatabase | None = None,
31+
read_parallelism: int = 10,
32+
read_batch_size: int = 100000,
33+
write_batch_size: int = 50000,
34+
symmetrize_edges: bool = False,
35+
*args: Any,
36+
**kwargs: Any,
37+
):
38+
super().__init__(
39+
graph_name,
40+
default_node_type,
41+
edge_type_key,
42+
edge_type_func,
43+
db,
44+
read_parallelism,
45+
read_batch_size,
46+
write_batch_size,
47+
symmetrize_edges,
48+
*args,
49+
**kwargs,
50+
)
2751

2852
#######################
2953
# Init helper methods #

tests/test.py

Lines changed: 195 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,22 +28,30 @@ def assert_same_dict_values(
2828

2929

3030
def assert_bc(d1: dict[str | int, float], d2: dict[str | int, float]) -> None:
31+
assert d1
32+
assert d2
3133
assert_same_dict_values(d1, d2, 14)
3234

3335

3436
def assert_pagerank(d1: dict[str | int, float], d2: dict[str | int, float]) -> None:
37+
assert d1
38+
assert d2
3539
assert_same_dict_values(d1, d2, 15)
3640

3741

3842
def assert_louvain(l1: list[set[Any]], l2: list[set[Any]]) -> None:
3943
# TODO: Implement some kind of comparison
4044
# Reason: Louvain returns different results on different runs
45+
assert l1
46+
assert l2
4147
pass
4248

4349

4450
def assert_k_components(
4551
d1: dict[int, list[set[Any]]], d2: dict[int, list[set[Any]]]
4652
) -> None:
53+
assert d1
54+
assert d2
4755
assert d1.keys() == d2.keys(), "Dictionaries have different keys"
4856
assert d1 == d2
4957

@@ -91,6 +99,8 @@ def test_algorithm(
9199
G_4 = nxadb.DiGraph(graph_name="KarateGraph", symmetrize_edges=True)
92100
G_5 = nxadb.DiGraph(graph_name="KarateGraph", symmetrize_edges=False)
93101
G_6 = nxadb.MultiGraph(graph_name="KarateGraph")
102+
G_7 = nxadb.MultiDiGraph(graph_name="KarateGraph", symmetrize_edges=True)
103+
G_8 = nxadb.MultiDiGraph(graph_name="KarateGraph", symmetrize_edges=False)
94104

95105
r_1 = algorithm_func(G_1)
96106
r_2 = algorithm_func(G_2)
@@ -121,7 +131,12 @@ def test_algorithm(
121131
r_11 = algorithm_func(G_6)
122132
r_11_orig = algorithm_func.orig_func(G_6) # type: ignore
123133

124-
assert all([r_7, r_7_orig, r_8, r_8_orig, r_9, r_9_orig, r_10, r_11, r_11_orig])
134+
r_12 = algorithm_func(G_7)
135+
r_12_orig = algorithm_func.orig_func(G_7) # type: ignore
136+
137+
r_13 = algorithm_func(G_8)
138+
r_13_orig = algorithm_func.orig_func(G_8) # type: ignore
139+
125140
assert_func(r_7, r_7_orig)
126141
assert_func(r_8, r_8_orig)
127142
assert_func(r_9, r_9_orig)
@@ -134,6 +149,14 @@ def test_algorithm(
134149
assert_func(r_7, r_11)
135150
assert_func(r_8, r_11)
136151
assert_func(r_11, r_11_orig)
152+
assert_func(r_12, r_12_orig)
153+
assert_func(r_13, r_13_orig)
154+
assert r_12 != r_13
155+
assert r_12_orig != r_13_orig
156+
assert_func(r_8, r_12)
157+
assert_func(r_8_orig, r_12_orig)
158+
assert_func(r_9, r_13)
159+
assert_func(r_9_orig, r_13_orig)
137160

138161

139162
def test_shortest_path_remote_algorithm(load_graph: Any) -> None:
@@ -157,6 +180,7 @@ def test_shortest_path_remote_algorithm(load_graph: Any) -> None:
157180
(nxadb.Graph),
158181
(nxadb.DiGraph),
159182
(nxadb.MultiGraph),
183+
(nxadb.MultiDiGraph),
160184
],
161185
)
162186
def test_nodes_crud(load_graph: Any, graph_cls: type[nxadb.Graph]) -> None:
@@ -741,6 +765,176 @@ def test_multigraph_edges_crud(load_graph: Any) -> None:
741765
assert db.document(edge_id)["object"]["sub_object"]["foo"] == "baz"
742766

743767

768+
def test_multidigraph_edges_crud(load_graph: Any) -> None:
769+
G_1 = nxadb.MultiDiGraph(graph_name="KarateGraph")
770+
G_2 = G_NX
771+
772+
assert len(G_1.adj) == len(G_2.adj)
773+
assert len(G_1.edges) == len(G_2.edges)
774+
assert G_1.number_of_edges() == G_2.number_of_edges()
775+
776+
for src, dst, w in G_1.edges.data("weight"):
777+
assert G_1.adj[src][dst][0]["weight"] == w
778+
779+
for src, dst, w in G_1.edges.data("bad_key", default="boom!"):
780+
assert "bad_key" not in G_1.adj[src][dst][0]
781+
assert w == "boom!"
782+
783+
for k, edge_key_dict in G_1.adj["person/1"].items():
784+
assert db.has_document(k)
785+
assert db.has_document(edge_key_dict[0]["_id"])
786+
787+
G_1.add_edge("person/1", "person/1", foo="bar", _edge_type="knows")
788+
edge_id = G_1.adj["person/1"]["person/1"][0]["_id"]
789+
doc = db.document(edge_id)
790+
assert doc["foo"] == "bar"
791+
assert G_1.adj["person/1"]["person/1"][0]["foo"] == "bar"
792+
793+
del G_1.adj["person/1"]["person/1"][0]["foo"]
794+
doc = db.document(edge_id)
795+
assert "foo" not in doc
796+
797+
G_1.adj["person/1"]["person/1"][0].update({"bar": "foo"})
798+
doc = db.document(edge_id)
799+
assert doc["bar"] == "foo"
800+
801+
assert len(G_1.adj["person/1"]["person/1"][0]) == len(doc)
802+
adj_count = len(G_1.adj["person/1"])
803+
G_1.remove_edge("person/1", "person/1")
804+
assert len(G_1.adj["person/1"]) == adj_count - 1
805+
assert not db.has_document(edge_id)
806+
assert "person/1" in G_1
807+
808+
assert not db.has_document("person/new_node_1")
809+
col_count = db.collection("knows").count()
810+
811+
G_1.add_edge("new_node_1", "new_node_2", foo="bar")
812+
assert db.document(G_1["new_node_1"]["new_node_2"][0]["_id"])["foo"] == "bar"
813+
G_1.add_edge("new_node_1", "new_node_2", foo="bar", bar="foo")
814+
doc = db.document(G_1["new_node_1"]["new_node_2"][1]["_id"])
815+
assert doc["foo"] == "bar"
816+
assert doc["bar"] == "foo"
817+
818+
bind_vars = {
819+
"src": f"{G_1.default_node_type}/new_node_1",
820+
"dst": f"{G_1.default_node_type}/new_node_2",
821+
}
822+
823+
result = list(
824+
db.aql.execute(
825+
f"FOR e IN knows FILTER e._from == @src AND e._to == @dst RETURN e", # noqa
826+
bind_vars=bind_vars,
827+
)
828+
)
829+
830+
assert len(result) == 2
831+
832+
result = list(
833+
db.aql.execute(
834+
f"FOR e IN knows FILTER e._from == @dst AND e._to == @src RETURN e", # noqa
835+
bind_vars=bind_vars,
836+
)
837+
)
838+
839+
assert len(result) == 0
840+
841+
assert db.collection("knows").count() == col_count + 2
842+
assert G_1.adj["new_node_1"]["new_node_2"][0]
843+
assert G_1.adj["new_node_1"]["new_node_2"][0]["foo"] == "bar"
844+
assert G_1.pred["new_node_2"]["new_node_1"][0]
845+
assert "new_node_1" not in G_1.adj["new_node_2"]
846+
assert (
847+
G_1.adj["new_node_1"]["new_node_2"][0]["_id"]
848+
== G_1.pred["new_node_2"]["new_node_1"][0]["_id"]
849+
)
850+
edge_id = G_1.adj["new_node_1"]["new_node_2"][0]["_id"]
851+
doc = db.document(edge_id)
852+
assert db.has_document(doc["_from"])
853+
assert db.has_document(doc["_to"])
854+
assert G_1.nodes["new_node_1"]
855+
assert G_1.nodes["new_node_2"]
856+
857+
assert len(G_1.adj["new_node_1"]["new_node_2"]) == 2
858+
G_1.remove_edge("new_node_1", "new_node_2")
859+
G_1.clear()
860+
assert "new_node_1" in G_1
861+
assert "new_node_2" in G_1
862+
assert "new_node_2" in G_1.adj["new_node_1"]
863+
assert len(G_1.adj["new_node_1"]["new_node_2"]) == 1
864+
865+
G_1.add_edges_from(
866+
[("new_node_1", "new_node_2"), ("new_node_1", "new_node_3")], foo="bar"
867+
)
868+
G_1.clear()
869+
assert "new_node_1" in G_1
870+
assert "new_node_2" in G_1
871+
assert "new_node_3" in G_1
872+
873+
for k in G_1.adj["new_node_1"]["new_node_2"]:
874+
assert G_1.adj["new_node_1"]["new_node_2"][k]["foo"] == "bar"
875+
assert G_1.pred["new_node_2"]["new_node_1"][k]["foo"] == "bar"
876+
877+
for k in G_1.adj["new_node_1"]["new_node_3"]:
878+
assert G_1.adj["new_node_1"]["new_node_3"][k]["foo"] == "bar"
879+
assert G_1.pred["new_node_3"]["new_node_1"][k]["foo"] == "bar"
880+
881+
assert len(G_1.adj["new_node_1"]["new_node_2"]) == 2
882+
assert len(G_1.adj["new_node_1"]["new_node_3"]) == 1
883+
G_1.remove_edges_from([("new_node_1", "new_node_2"), ("new_node_1", "new_node_3")])
884+
assert len(G_1.adj["new_node_1"]["new_node_2"]) == 1
885+
886+
assert "new_node_1" in G_1
887+
assert "new_node_2" in G_1
888+
assert "new_node_3" in G_1
889+
assert "new_node_2" in G_1.adj["new_node_1"]
890+
assert "new_node_3" not in G_1.adj["new_node_1"]
891+
892+
edge_id = "knows/1"
893+
assert "person/1" not in G_1["person/2"]
894+
assert (
895+
G_1.succ["person/1"]["person/2"][edge_id]
896+
== G_1.pred["person/2"]["person/1"][edge_id]
897+
)
898+
new_weight = 1000
899+
G_1["person/1"]["person/2"][edge_id]["weight"] = new_weight
900+
assert G_1.succ["person/1"]["person/2"][edge_id]["weight"] == new_weight
901+
assert G_1.pred["person/2"]["person/1"][edge_id]["weight"] == new_weight
902+
G_1.clear()
903+
assert G_1.succ["person/1"]["person/2"][edge_id]["weight"] == new_weight
904+
G_1.clear()
905+
assert G_1.pred["person/2"]["person/1"][edge_id]["weight"] == new_weight
906+
907+
edge_id = G_1["person/1"]["person/2"][edge_id]["_id"]
908+
G_1["person/1"]["person/2"][edge_id]["object"] = {"foo": "bar", "bar": "foo"}
909+
assert "_rev" not in G_1["person/1"]["person/2"][edge_id]["object"]
910+
assert isinstance(G_1["person/1"]["person/2"][edge_id]["object"], EdgeAttrDict)
911+
assert db.document(edge_id)["object"] == {"foo": "bar", "bar": "foo"}
912+
913+
G_1["person/1"]["person/2"][edge_id]["object"]["foo"] = "baz"
914+
assert db.document(edge_id)["object"]["foo"] == "baz"
915+
916+
del G_1["person/1"]["person/2"][edge_id]["object"]["foo"]
917+
assert "_rev" not in G_1["person/1"]["person/2"][edge_id]["object"]
918+
assert isinstance(G_1["person/1"]["person/2"][edge_id]["object"], EdgeAttrDict)
919+
assert "foo" not in db.document(edge_id)["object"]
920+
921+
G_1["person/1"]["person/2"][edge_id]["object"].update(
922+
{"sub_object": {"foo": "bar"}}
923+
)
924+
assert "_rev" not in G_1["person/1"]["person/2"][edge_id]["object"]["sub_object"]
925+
assert isinstance(
926+
G_1["person/1"]["person/2"][edge_id]["object"]["sub_object"], EdgeAttrDict
927+
)
928+
assert db.document(edge_id)["object"]["sub_object"]["foo"] == "bar"
929+
930+
G_1.clear()
931+
932+
assert G_1["person/1"]["person/2"][edge_id]["object"]["sub_object"]["foo"] == "bar"
933+
G_1["person/1"]["person/2"][edge_id]["object"]["sub_object"]["foo"] = "baz"
934+
assert "_rev" not in G_1["person/1"]["person/2"][edge_id]["object"]["sub_object"]
935+
assert db.document(edge_id)["object"]["sub_object"]["foo"] == "baz"
936+
937+
744938
def test_graph_dict_init(load_graph: Any) -> None:
745939
G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person")
746940
assert db.collection("_graphs").has("KarateGraph")

0 commit comments

Comments
 (0)