Skip to content

Commit 5b72482

Browse files
hkernbach and aMahanna authored
[GA-153-1] Implement EdgeAttrDict update method (new) (#30)
* remove not needed imports, fix typos * moved over code from pr * fmt and lint * fix code, added test for graphs, added todo * adapt MultiGraph to old code * flake8 * removed auto added import * add update method to CustomNodeView * update_local_nodes as private method * use logger instead of warnings * remove assertion, raise in case wrong key is given * move test only func into tst, removed unused func * remove import * TODO WIP * fix typo * disabled this for now * fix mypy * py to 3.12 * py to 3.12.3 * py to 3.12.5 * py to 3.12.5 .............. * back to 3.12.2 * back to 3.10 * fixes after merge * fix use of method * added core view * Update nx_arangodb/classes/function.py Co-authored-by: Anthony Mahanna <[email protected]> * Update nx_arangodb/classes/function.py Co-authored-by: Anthony Mahanna <[email protected]> * Update tests/test.py Co-authored-by: Anthony Mahanna <[email protected]> * Update tests/test.py Co-authored-by: Anthony Mahanna <[email protected]> * optimize separate_edges_by_collections * fmt --------- Co-authored-by: Anthony Mahanna <[email protected]>
1 parent cdc6a20 commit 5b72482

File tree

5 files changed

+394
-63
lines changed

5 files changed

+394
-63
lines changed

nx_arangodb/classes/coreviews.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import networkx as nx
2+
3+
4+
class CustomAdjacencyView(nx.classes.coreviews.AdjacencyView):
    """Adjacency view that additionally exposes an ``update`` method."""

    def update(self, data):
        """Forward ``data`` to the ``update`` method of the wrapped ``_atlas``.

        Returns whatever that call returns.
        """
        atlas = self._atlas
        return atlas.update(data)

nx_arangodb/classes/dict/adj.py

Lines changed: 101 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import warnings
34
from collections import UserDict
45
from collections.abc import Iterator
56
from itertools import islice
@@ -8,12 +9,19 @@
89
from arango.database import StandardDatabase
910
from arango.exceptions import DocumentDeleteError
1011
from arango.graph import Graph
12+
from phenolrs.networkx.typings import (
13+
DiGraphAdjDict,
14+
GraphAdjDict,
15+
MultiDiGraphAdjDict,
16+
MultiGraphAdjDict,
17+
)
1118

1219
from nx_arangodb.exceptions import EdgeTypeAmbiguity, MultipleEdgesFound
1320
from nx_arangodb.logger import logger
1421

1522
from ..enum import DIRECTED_GRAPH_TYPES, MULTIGRAPH_TYPES, GraphType, TraversalDirection
1623
from ..function import (
24+
ArangoDBBatchError,
1725
aql,
1826
aql_doc_get_key,
1927
aql_doc_has_key,
@@ -23,6 +31,7 @@
2331
aql_edge_get,
2432
aql_edge_id,
2533
aql_fetch_data_edge,
34+
check_list_for_errors,
2635
doc_insert,
2736
doc_update,
2837
get_arangodb_graph,
@@ -36,6 +45,8 @@
3645
keys_are_not_reserved,
3746
keys_are_strings,
3847
logger_debug,
48+
separate_edges_by_collections,
49+
upsert_collection_edges,
3950
)
4051

4152
#############
@@ -169,7 +180,7 @@ def __init__(
169180
self.graph = graph
170181
self.edge_id: str | None = None
171182

172-
# NodeAttrDict may be a child of another NodeAttrDict
183+
# EdgeAttrDict may be a child of another EdgeAttrDict
173184
# e.g G._adj['node/1']['node/2']['object']['foo'] = 'bar'
174185
# In this case, **parent_keys** would be ['object']
175186
# and **root** would be G._adj['node/1']['node/2']
@@ -1482,8 +1493,31 @@ def clear(self) -> None:
14821493
@keys_are_strings
@logger_debug
def update(self, edges: Any) -> None:
    """Upsert ``edges`` into the database, then mirror them in the local cache.

    Example::

        g._adj.update({'node/1': {'node/2': {'_id': 'foo/bar', 'foo': "bar"}}})

    The edges are first grouped per edge collection and written with
    ``upsert_collection_edges``. The local cache is only touched when
    ``check_list_for_errors`` reports that every operation succeeded.

    :param edges: Adjacency-style mapping of source node id to a mapping of
        destination node id to edge data.
    :raises ArangoDBBatchError: If at least one write was reported as failed.
        The exception receives the flattened list of all per-document
        results, and the local cache is left untouched.
    """
    separated_by_edge_collection = separate_edges_by_collections(
        edges, graph_type=self.graph_type, default_node_type=self.default_node_type
    )
    result = upsert_collection_edges(self.db, separated_by_edge_collection)

    if not check_list_for_errors(result):
        # Some or all documents failed. For now the local cache is not
        # updated and an error is raised instead.
        # Reason: ``silent`` cannot be set to True for the insert, because
        # errors would then not be reported at all. Once the driver also
        # passes the errors back to the user, the behavior here can be
        # adjusted, which would also save network traffic and local
        # computation time.
        errors = [
            collection_result
            for collections_results in result
            for collection_result in collections_results
        ]
        warnings.warn(
            "Failed to insert at least one edge. Will not update local cache."
        )
        raise ArangoDBBatchError(errors)

    # No single operation failed, so it is safe to update the local cache.
    self.__set_adj_elements(edges)
14871521

14881522
@logger_debug
14891523
def values(self) -> Any:
@@ -1507,22 +1541,44 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any:
15071541
yield from result
15081542

15091543
@logger_debug
1510-
def _fetch_all(self) -> None:
1511-
self.clear()
1544+
def __set_adj_elements(
1545+
self,
1546+
edges_dict: (
1547+
GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict
1548+
),
1549+
) -> None:
1550+
def set_edge_graph(
1551+
src_node_id: str, dst_node_id: str, edge: dict[str, Any]
1552+
) -> EdgeAttrDict:
1553+
adjlist_inner_dict = self.data[src_node_id]
1554+
1555+
edge_attr_dict: EdgeAttrDict
1556+
edge_attr_dict = adjlist_inner_dict._create_edge_attr_dict(edge)
1557+
1558+
adjlist_inner_dict.data[dst_node_id] = edge_attr_dict
1559+
1560+
return edge_attr_dict
1561+
1562+
def set_edge_multigraph(
1563+
src_node_id: str, dst_node_id: str, edges: dict[int, dict[str, Any]]
1564+
) -> EdgeKeyDict:
1565+
adjlist_inner_dict = self.data[src_node_id]
1566+
1567+
edge_key_dict = adjlist_inner_dict.edge_key_dict_factory()
1568+
edge_key_dict.src_node_id = src_node_id
1569+
edge_key_dict.dst_node_id = dst_node_id
1570+
edge_key_dict.FETCHED_ALL_DATA = True
1571+
edge_key_dict.FETCHED_ALL_IDS = True
1572+
1573+
for edge in edges.values():
1574+
edge_attr_dict = adjlist_inner_dict._create_edge_attr_dict(edge)
1575+
edge_key_dict.data[edge["_id"]] = edge_attr_dict
15121576

1513-
def set_adj_inner_dict(
1514-
adj_outer_dict: AdjListOuterDict, node_id: str
1515-
) -> AdjListInnerDict:
1516-
if node_id in adj_outer_dict.data:
1517-
return adj_outer_dict.data[node_id]
1577+
adjlist_inner_dict.data[dst_node_id] = edge_key_dict
15181578

1519-
adj_inner_dict = self.adjlist_inner_dict_factory()
1520-
adj_inner_dict.src_node_id = node_id
1521-
adj_inner_dict.FETCHED_ALL_DATA = True
1522-
adj_inner_dict.FETCHED_ALL_IDS = True
1523-
adj_outer_dict.data[node_id] = adj_inner_dict
1579+
return edge_key_dict
15241580

1525-
return adj_inner_dict
1581+
set_edge_func = set_edge_multigraph if self.is_multigraph else set_edge_graph
15261582

15271583
def propagate_edge_undirected(
15281584
src_node_id: str,
@@ -1536,7 +1592,7 @@ def propagate_edge_directed(
15361592
dst_node_id: str,
15371593
edge_key_or_attr_dict: EdgeKeyDict | EdgeAttrDict,
15381594
) -> None:
1539-
set_adj_inner_dict(self.mirror, dst_node_id)
1595+
self.__set_adj_inner_dict(self.mirror, dst_node_id)
15401596
self.mirror.data[dst_node_id].data[src_node_id] = edge_key_or_attr_dict
15411597

15421598
def propagate_edge_directed_symmetric(
@@ -1546,7 +1602,7 @@ def propagate_edge_directed_symmetric(
15461602
) -> None:
15471603
propagate_edge_directed(src_node_id, dst_node_id, edge_key_or_attr_dict)
15481604
propagate_edge_undirected(src_node_id, dst_node_id, edge_key_or_attr_dict)
1549-
set_adj_inner_dict(self.mirror, src_node_id)
1605+
self.__set_adj_inner_dict(self.mirror, src_node_id)
15501606
self.mirror.data[src_node_id].data[dst_node_id] = edge_key_or_attr_dict
15511607

15521608
propagate_edge_func = (
@@ -1559,38 +1615,39 @@ def propagate_edge_directed_symmetric(
15591615
)
15601616
)
15611617

1562-
def set_edge_graph(
1563-
src_node_id: str, dst_node_id: str, edge: dict[str, Any]
1564-
) -> EdgeAttrDict:
1565-
adjlist_inner_dict = self.data[src_node_id]
1566-
1567-
edge_attr_dict: EdgeAttrDict
1568-
edge_attr_dict = adjlist_inner_dict._create_edge_attr_dict(edge)
1618+
for src_node_id, inner_dict in edges_dict.items():
1619+
for dst_node_id, edge_or_edges in inner_dict.items():
15691620

1570-
adjlist_inner_dict.data[dst_node_id] = edge_attr_dict
1621+
if not self.is_directed:
1622+
if src_node_id in self.data:
1623+
if dst_node_id in self.data[src_node_id].data:
1624+
continue # can skip due not directed
15711625

1572-
return edge_attr_dict
1626+
self.__set_adj_inner_dict(self, src_node_id)
1627+
self.__set_adj_inner_dict(self, dst_node_id)
1628+
edge_attr_or_key_dict = set_edge_func( # type: ignore[operator]
1629+
src_node_id, dst_node_id, edge_or_edges
1630+
)
15731631

1574-
def set_edge_multigraph(
1575-
src_node_id: str, dst_node_id: str, edges: dict[int, dict[str, Any]]
1576-
) -> EdgeKeyDict:
1577-
adjlist_inner_dict = self.data[src_node_id]
1632+
propagate_edge_func(src_node_id, dst_node_id, edge_attr_or_key_dict)
15781633

1579-
edge_key_dict = adjlist_inner_dict.edge_key_dict_factory()
1580-
edge_key_dict.src_node_id = src_node_id
1581-
edge_key_dict.dst_node_id = dst_node_id
1582-
edge_key_dict.FETCHED_ALL_DATA = True
1583-
edge_key_dict.FETCHED_ALL_IDS = True
1634+
def __set_adj_inner_dict(
    self, adj_outer_dict: AdjListOuterDict, node_id: str
) -> AdjListInnerDict:
    """Return the inner adjacency dict stored for ``node_id``, creating it if absent.

    :param adj_outer_dict: The outer dict whose ``data`` mapping is looked up
        and, when needed, extended.
    :param node_id: Key of the entry to fetch or create.
    :return: The existing entry, or a new one built by
        ``self.adjlist_inner_dict_factory`` and flagged as fully fetched.
    """
    store = adj_outer_dict.data

    if node_id not in store:
        # Note: the factory always comes from ``self``, even when
        # ``adj_outer_dict`` is a different object (e.g. ``self.mirror``).
        fresh_inner = self.adjlist_inner_dict_factory()
        fresh_inner.src_node_id = node_id
        fresh_inner.FETCHED_ALL_DATA = True
        fresh_inner.FETCHED_ALL_IDS = True
        store[node_id] = fresh_inner

    return store[node_id]
15921647

1593-
set_edge_func = set_edge_multigraph if self.is_multigraph else set_edge_graph
1648+
@logger_debug
1649+
def _fetch_all(self) -> None:
1650+
self.clear()
15941651

15951652
(
15961653
_,
@@ -1613,21 +1670,7 @@ def set_edge_multigraph(
16131670
if self.is_directed:
16141671
adj_dict = adj_dict["succ"]
16151672

1616-
for src_node_id, inner_dict in adj_dict.items():
1617-
for dst_node_id, edge_or_edges in inner_dict.items():
1618-
1619-
if not self.is_directed:
1620-
if src_node_id in self.data:
1621-
if dst_node_id in self.data[src_node_id].data:
1622-
continue # can skip due not directed
1623-
1624-
set_adj_inner_dict(self, src_node_id)
1625-
set_adj_inner_dict(self, dst_node_id)
1626-
edge_attr_or_key_dict = set_edge_func( # type: ignore[operator]
1627-
src_node_id, dst_node_id, edge_or_edges
1628-
)
1629-
1630-
propagate_edge_func(src_node_id, dst_node_id, edge_attr_or_key_dict)
1673+
self.__set_adj_elements(adj_dict)
16311674

16321675
self.FETCHED_ALL_DATA = True
16331676
self.FETCHED_ALL_IDS = True

nx_arangodb/classes/function.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from nx_arangodb.logger import logger
3131

3232
from ..exceptions import AQLMultipleResultsFound, InvalidTraversalDirection
33+
from .enum import GraphType
3334

3435

3536
def do_load_all_edge_attributes(attributes: set[str]) -> bool:
@@ -242,7 +243,7 @@ def wrapper(self: Any, data: Any, *args: Any, **kwargs: Any) -> Any:
242243
return wrapper
243244

244245

245-
RESERVED_KEYS = {"_id", "_key", "_rev"}
246+
RESERVED_KEYS = {"_id", "_key", "_rev", "_from", "_to"}
246247

247248

248249
def key_is_not_reserved(func: Callable[..., Any]) -> Any:
@@ -744,3 +745,107 @@ def upsert_collection_documents(db: StandardDatabase, separated: Any) -> Any:
744745
)
745746

746747
return results
748+
749+
750+
def separate_edges_by_collections_graph(edges: Any, default_node_type: str) -> Any:
    """
    Group the edge documents of a Graph / DiGraph adjacency dict by collection.

    Every edge document is modified in place: its ``_from`` and ``_to`` are set
    to the source and destination keys under which it was found.

    :param edges: Mapping ``{from_doc_id: {to_doc_id: edge_doc}}``. Every
        ``edge_doc`` must contain an ``_id``.
    :param default_node_type: Forwarded to ``get_node_type_and_id`` together
        with the edge ``_id``.
    :return: A dictionary where the keys are collection names and the values
        are lists of the edge documents belonging to those collections.
    """
    by_collection: Any = {}

    for src_id, neighbours in edges.items():
        for dst_id, doc in neighbours.items():
            # NOTE(review): this check disappears under ``python -O``;
            # consider an explicit raise if callers do not rely on
            # AssertionError.
            assert doc is not None and "_id" in doc

            type_and_id = get_node_type_and_id(doc["_id"], default_node_type)
            collection_name = type_and_id[0]
            bucket = by_collection.setdefault(collection_name, [])

            doc["_from"] = src_id
            doc["_to"] = dst_id
            bucket.append(doc)

    return by_collection
776+
777+
778+
def separate_edges_by_collections_multigraph(edges: Any, default_node_type: str) -> Any:
    """
    Group the edge documents of a MultiGraph / MultiDiGraph adjacency dict
    by collection.

    Every edge document is modified in place: its ``_from`` and ``_to`` are set
    to the source and destination keys under which it was found.

    :param edges: Mapping ``{from_doc_id: {to_doc_id: {edge_key: edge_doc}}}``.
        Every ``edge_doc`` must contain an ``_id``.
    :param default_node_type: Forwarded to ``get_node_type_and_id`` together
        with the edge ``_id``.
    :return: A dictionary where the keys are collection names and the values
        are lists of the edge documents belonging to those collections.
    """
    by_collection: Any = {}

    for src_id, neighbours in edges.items():
        for dst_id, parallel_edges in neighbours.items():
            # In a Multi(Di)Graph each node pair maps to a dict of edges;
            # only the edge documents (the values) are needed here.
            for doc in parallel_edges.values():
                # NOTE(review): this check disappears under ``python -O``;
                # consider an explicit raise if callers do not rely on
                # AssertionError.
                assert doc is not None and "_id" in doc

                type_and_id = get_node_type_and_id(doc["_id"], default_node_type)
                collection_name = type_and_id[0]
                bucket = by_collection.setdefault(collection_name, [])

                doc["_from"] = src_id
                doc["_to"] = dst_id
                bucket.append(doc)

    return by_collection
806+
807+
808+
def separate_edges_by_collections(
    edges: Any, graph_type: str, default_node_type: str
) -> Any:
    """
    Dispatch to the graph-type specific helper that groups edges by collection.

    :param edges: The adjacency dictionary to split up; its expected nesting
        depends on ``graph_type``.
    :param graph_type: Name of the graph type, compared against the ``name``
        of the ``GraphType`` members.
    :param default_node_type: Passed through unchanged to the selected helper.
    :return: The result of ``separate_edges_by_collections_graph`` for
        Graph / DiGraph, or of ``separate_edges_by_collections_multigraph``
        for MultiGraph / MultiDiGraph.
    :raises ValueError: If ``graph_type`` matches none of the four names.
    """
    simple_types = (GraphType.Graph.name, GraphType.DiGraph.name)
    if graph_type in simple_types:
        return separate_edges_by_collections_graph(edges, default_node_type)

    multi_types = (GraphType.MultiGraph.name, GraphType.MultiDiGraph.name)
    if graph_type in multi_types:
        return separate_edges_by_collections_multigraph(edges, default_node_type)

    raise ValueError(f"Unsupported graph type: {graph_type}")
825+
826+
827+
def upsert_collection_edges(db: StandardDatabase, separated: Any) -> Any:
    """
    Write every group of edge documents to its collection via ``insert_many``.

    :param db: The ArangoDB database object.
    :param separated: A dictionary where the keys are collection names and the
        values are the documents to write to that collection (presumably the
        lists built by ``separate_edges_by_collections`` - verify against caller).
    :return: A list with one ``insert_many`` result per collection, in the
        iteration order of ``separated``.
        If inserting a document fails, the exception is not raised but
        returned as an object in the result list.
    """
    # ``silent=False`` so that per-document results (including failures) are
    # returned; ``overwrite_mode="update"`` is what makes this an upsert.
    return [
        db.collection(collection_name).insert_many(
            documents_list,
            silent=False,
            overwrite_mode="update",
        )
        for collection_name, documents_list in separated.items()
    ]

0 commit comments

Comments
 (0)