Skip to content

[GA-153-1] Implement EdgeAttrDict update method (new) #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 41 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
f85d4bf
remove not needed imports, fix typos
hkernbach Aug 15, 2024
20b989b
moved over code from pr
hkernbach Aug 15, 2024
1fcbb10
fmt and lint
hkernbach Aug 15, 2024
9d8d1f6
Merge remote-tracking branch 'origin/main' into feature/node-dict
hkernbach Aug 15, 2024
22d0cd9
fix code, added test for graphs, added todo
hkernbach Aug 15, 2024
1173880
adapt MultiGraph to old code
hkernbach Aug 15, 2024
492a6c9
flake8
hkernbach Aug 15, 2024
0fa5db4
removed auto added import
hkernbach Aug 15, 2024
dce4e85
add update method to CustomNodeView
hkernbach Aug 16, 2024
73c194e
update_local_nodes as private method
hkernbach Aug 16, 2024
1c6b158
user logger instead of warnings
hkernbach Aug 16, 2024
a0de6e8
remove assertion, raise in case wrong key is given
hkernbach Aug 16, 2024
685802d
move test only func into tst, removed unused func
hkernbach Aug 16, 2024
bcc6ddf
remove import
hkernbach Aug 16, 2024
6cd34e5
TODO WIP
hkernbach Aug 16, 2024
17f4323
Merge branch 'feature/node-dict' into feature/edge-attr-dict
hkernbach Aug 16, 2024
abc65fb
fix typo
hkernbach Aug 16, 2024
e38dfac
Merge branch 'feature/node-dict' into feature/edge-attr-dict
hkernbach Aug 16, 2024
1ff708a
disabled this for now
hkernbach Aug 16, 2024
e9b14e1
Merge remote-tracking branch 'origin/main' into feature/node-dict
hkernbach Aug 16, 2024
fcd4b14
fix mypy
hkernbach Aug 16, 2024
54048aa
Merge remote-tracking branch 'origin/main' into feature/node-dict
hkernbach Aug 16, 2024
770b7a6
Merge remote-tracking branch 'origin/main' into feature/node-dict
hkernbach Aug 16, 2024
00e5fe8
py to 3.12
hkernbach Aug 16, 2024
73fa537
py to 3.12.3
hkernbach Aug 16, 2024
4c65ac4
py to 3.12.5
hkernbach Aug 16, 2024
938b30a
py to 3.12.5 ..............
hkernbach Aug 16, 2024
8cccca4
back to 3.12.2
hkernbach Aug 16, 2024
049d637
back to 3.10
hkernbach Aug 16, 2024
c98aa93
Merge branch 'feature/node-dict' into feature/edge-attr-dict
hkernbach Aug 16, 2024
e56ce4f
attempt to resolve merge conflicts
hkernbach Aug 16, 2024
f43a6d5
fixes after merge
hkernbach Aug 16, 2024
ffe843b
fix use of method
hkernbach Aug 16, 2024
e1d078c
added core view
hkernbach Aug 16, 2024
b8ccc75
Update nx_arangodb/classes/function.py
hkernbach Aug 20, 2024
3a87bfa
Update nx_arangodb/classes/function.py
hkernbach Aug 20, 2024
a4c69f1
Update tests/test.py
hkernbach Aug 20, 2024
52e7dfd
Update tests/test.py
hkernbach Aug 20, 2024
da332e0
optimize separate_edges_by_collections
hkernbach Aug 20, 2024
cc18d04
fmt
hkernbach Aug 20, 2024
a9dc300
Merge branch 'main' into feature/edge-attr-dict
hkernbach Aug 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions nx_arangodb/classes/coreviews.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import networkx as nx


class CustomAdjacencyView(nx.classes.coreviews.AdjacencyView):

def update(self, data):
return self._atlas.update(data)
159 changes: 101 additions & 58 deletions nx_arangodb/classes/dict/adj.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from collections import UserDict
from collections.abc import Iterator
from itertools import islice
Expand All @@ -8,12 +9,19 @@
from arango.database import StandardDatabase
from arango.exceptions import DocumentDeleteError
from arango.graph import Graph
from phenolrs.networkx.typings import (
DiGraphAdjDict,
GraphAdjDict,
MultiDiGraphAdjDict,
MultiGraphAdjDict,
)

from nx_arangodb.exceptions import EdgeTypeAmbiguity, MultipleEdgesFound
from nx_arangodb.logger import logger

from ..enum import DIRECTED_GRAPH_TYPES, MULTIGRAPH_TYPES, GraphType, TraversalDirection
from ..function import (
ArangoDBBatchError,
aql,
aql_doc_get_key,
aql_doc_has_key,
Expand All @@ -23,6 +31,7 @@
aql_edge_get,
aql_edge_id,
aql_fetch_data_edge,
check_list_for_errors,
doc_insert,
doc_update,
get_arangodb_graph,
Expand All @@ -36,6 +45,8 @@
keys_are_not_reserved,
keys_are_strings,
logger_debug,
separate_edges_by_collections,
upsert_collection_edges,
)

#############
Expand Down Expand Up @@ -169,7 +180,7 @@ def __init__(
self.graph = graph
self.edge_id: str | None = None

# NodeAttrDict may be a child of another NodeAttrDict
# EdgeAttrDict may be a child of another EdgeAttrDict
# e.g G._adj['node/1']['node/2']['object']['foo'] = 'bar'
# In this case, **parent_keys** would be ['object']
# and **root** would be G._adj['node/1']['node/2']
Expand Down Expand Up @@ -1482,8 +1493,31 @@ def clear(self) -> None:
@keys_are_strings
@logger_debug
def update(self, edges: Any) -> None:
"""g._adj.update({'node/1': {'node/2': {'foo': 'bar'}})"""
raise NotImplementedError("AdjListOuterDict.update()")
"""g._adj.update({'node/1': {'node/2': {'_id': 'foo/bar', 'foo': "bar"}})"""
separated_by_edge_collection = separate_edges_by_collections(
edges, graph_type=self.graph_type, default_node_type=self.default_node_type
)
result = upsert_collection_edges(self.db, separated_by_edge_collection)

all_good = check_list_for_errors(result)
if all_good:
# Means no single operation failed, in this case we update the local cache
self.__set_adj_elements(edges)
else:
# In this case some or all documents failed. Right now we will not
# update the local cache, but raise an error instead.
# Reason: We cannot set silent to True, because we need as it does
# not report errors then. We need to update the driver to also pass
# the errors back to the user, then we can adjust the behavior here.
# This will also save network traffic and local computation time.
errors = []
for collections_results in result:
for collection_result in collections_results:
errors.append(collection_result)
warnings.warn(
"Failed to insert at least one node. Will not update local cache."
)
raise ArangoDBBatchError(errors)

@logger_debug
def values(self) -> Any:
Expand All @@ -1507,22 +1541,44 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any:
yield from result

@logger_debug
def _fetch_all(self) -> None:
self.clear()
def __set_adj_elements(
self,
edges_dict: (
GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict
),
) -> None:
def set_edge_graph(
src_node_id: str, dst_node_id: str, edge: dict[str, Any]
) -> EdgeAttrDict:
adjlist_inner_dict = self.data[src_node_id]

edge_attr_dict: EdgeAttrDict
edge_attr_dict = adjlist_inner_dict._create_edge_attr_dict(edge)

adjlist_inner_dict.data[dst_node_id] = edge_attr_dict

return edge_attr_dict

def set_edge_multigraph(
src_node_id: str, dst_node_id: str, edges: dict[int, dict[str, Any]]
) -> EdgeKeyDict:
adjlist_inner_dict = self.data[src_node_id]

edge_key_dict = adjlist_inner_dict.edge_key_dict_factory()
edge_key_dict.src_node_id = src_node_id
edge_key_dict.dst_node_id = dst_node_id
edge_key_dict.FETCHED_ALL_DATA = True
edge_key_dict.FETCHED_ALL_IDS = True

for edge in edges.values():
edge_attr_dict = adjlist_inner_dict._create_edge_attr_dict(edge)
edge_key_dict.data[edge["_id"]] = edge_attr_dict

def set_adj_inner_dict(
adj_outer_dict: AdjListOuterDict, node_id: str
) -> AdjListInnerDict:
if node_id in adj_outer_dict.data:
return adj_outer_dict.data[node_id]
adjlist_inner_dict.data[dst_node_id] = edge_key_dict

adj_inner_dict = self.adjlist_inner_dict_factory()
adj_inner_dict.src_node_id = node_id
adj_inner_dict.FETCHED_ALL_DATA = True
adj_inner_dict.FETCHED_ALL_IDS = True
adj_outer_dict.data[node_id] = adj_inner_dict
return edge_key_dict

return adj_inner_dict
set_edge_func = set_edge_multigraph if self.is_multigraph else set_edge_graph

def propagate_edge_undirected(
src_node_id: str,
Expand All @@ -1536,7 +1592,7 @@ def propagate_edge_directed(
dst_node_id: str,
edge_key_or_attr_dict: EdgeKeyDict | EdgeAttrDict,
) -> None:
set_adj_inner_dict(self.mirror, dst_node_id)
self.__set_adj_inner_dict(self.mirror, dst_node_id)
self.mirror.data[dst_node_id].data[src_node_id] = edge_key_or_attr_dict

def propagate_edge_directed_symmetric(
Expand All @@ -1546,7 +1602,7 @@ def propagate_edge_directed_symmetric(
) -> None:
propagate_edge_directed(src_node_id, dst_node_id, edge_key_or_attr_dict)
propagate_edge_undirected(src_node_id, dst_node_id, edge_key_or_attr_dict)
set_adj_inner_dict(self.mirror, src_node_id)
self.__set_adj_inner_dict(self.mirror, src_node_id)
self.mirror.data[src_node_id].data[dst_node_id] = edge_key_or_attr_dict

propagate_edge_func = (
Expand All @@ -1559,38 +1615,39 @@ def propagate_edge_directed_symmetric(
)
)

def set_edge_graph(
src_node_id: str, dst_node_id: str, edge: dict[str, Any]
) -> EdgeAttrDict:
adjlist_inner_dict = self.data[src_node_id]

edge_attr_dict: EdgeAttrDict
edge_attr_dict = adjlist_inner_dict._create_edge_attr_dict(edge)
for src_node_id, inner_dict in edges_dict.items():
for dst_node_id, edge_or_edges in inner_dict.items():

adjlist_inner_dict.data[dst_node_id] = edge_attr_dict
if not self.is_directed:
if src_node_id in self.data:
if dst_node_id in self.data[src_node_id].data:
continue # can skip due not directed

return edge_attr_dict
self.__set_adj_inner_dict(self, src_node_id)
self.__set_adj_inner_dict(self, dst_node_id)
edge_attr_or_key_dict = set_edge_func( # type: ignore[operator]
src_node_id, dst_node_id, edge_or_edges
)

def set_edge_multigraph(
src_node_id: str, dst_node_id: str, edges: dict[int, dict[str, Any]]
) -> EdgeKeyDict:
adjlist_inner_dict = self.data[src_node_id]
propagate_edge_func(src_node_id, dst_node_id, edge_attr_or_key_dict)

edge_key_dict = adjlist_inner_dict.edge_key_dict_factory()
edge_key_dict.src_node_id = src_node_id
edge_key_dict.dst_node_id = dst_node_id
edge_key_dict.FETCHED_ALL_DATA = True
edge_key_dict.FETCHED_ALL_IDS = True
def __set_adj_inner_dict(
self, adj_outer_dict: AdjListOuterDict, node_id: str
) -> AdjListInnerDict:
if node_id in adj_outer_dict.data:
return adj_outer_dict.data[node_id]

for edge in edges.values():
edge_attr_dict = adjlist_inner_dict._create_edge_attr_dict(edge)
edge_key_dict.data[edge["_id"]] = edge_attr_dict

adjlist_inner_dict.data[dst_node_id] = edge_key_dict
adj_inner_dict = self.adjlist_inner_dict_factory()
adj_inner_dict.src_node_id = node_id
adj_inner_dict.FETCHED_ALL_DATA = True
adj_inner_dict.FETCHED_ALL_IDS = True
adj_outer_dict.data[node_id] = adj_inner_dict

return edge_key_dict
return adj_inner_dict

set_edge_func = set_edge_multigraph if self.is_multigraph else set_edge_graph
@logger_debug
def _fetch_all(self) -> None:
self.clear()

(
_,
Expand All @@ -1613,21 +1670,7 @@ def set_edge_multigraph(
if self.is_directed:
adj_dict = adj_dict["succ"]

for src_node_id, inner_dict in adj_dict.items():
for dst_node_id, edge_or_edges in inner_dict.items():

if not self.is_directed:
if src_node_id in self.data:
if dst_node_id in self.data[src_node_id].data:
continue # can skip due not directed

set_adj_inner_dict(self, src_node_id)
set_adj_inner_dict(self, dst_node_id)
edge_attr_or_key_dict = set_edge_func( # type: ignore[operator]
src_node_id, dst_node_id, edge_or_edges
)

propagate_edge_func(src_node_id, dst_node_id, edge_attr_or_key_dict)
self.__set_adj_elements(adj_dict)

self.FETCHED_ALL_DATA = True
self.FETCHED_ALL_IDS = True
Expand Down
107 changes: 106 additions & 1 deletion nx_arangodb/classes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from nx_arangodb.logger import logger

from ..exceptions import AQLMultipleResultsFound, InvalidTraversalDirection
from .enum import GraphType


def do_load_all_edge_attributes(attributes: set[str]) -> bool:
Expand Down Expand Up @@ -242,7 +243,7 @@ def wrapper(self: Any, data: Any, *args: Any, **kwargs: Any) -> Any:
return wrapper


RESERVED_KEYS = {"_id", "_key", "_rev"}
RESERVED_KEYS = {"_id", "_key", "_rev", "_from", "_to"}


def key_is_not_reserved(func: Callable[..., Any]) -> Any:
Expand Down Expand Up @@ -744,3 +745,107 @@ def upsert_collection_documents(db: StandardDatabase, separated: Any) -> Any:
)

return results


def separate_edges_by_collections_graph(edges: Any, default_node_type: str) -> Any:
"""
Separate the dictionary into collections for Graph and DiGraph types.
:param edges: The input dictionary with keys that must contain the real doc id.
:param default_node_type: The name of the default collection for keys without '/'.
:return: A dictionary where the keys are collection names and the
values are dictionaries of key-value pairs belonging to those collections.
"""
separated: Any = {}

for from_doc_id, target_dict in edges.items():
for to_doc_id, edge_doc in target_dict.items():
assert edge_doc is not None and "_id" in edge_doc
edge_collection_name = get_node_type_and_id(
edge_doc["_id"], default_node_type
)[0]

if edge_collection_name not in separated:
separated[edge_collection_name] = []

edge_doc["_from"] = from_doc_id
edge_doc["_to"] = to_doc_id

separated[edge_collection_name].append(edge_doc)

return separated


def separate_edges_by_collections_multigraph(edges: Any, default_node_type: str) -> Any:
"""
Separate the dictionary into collections for MultiGraph and MultiDiGraph types.
:param edges: The input dictionary with keys that must contain the real doc id.
:param default_node_type: The name of the default collection for keys without '/'.
:return: A dictionary where the keys are collection names and the
values are dictionaries of key-value pairs belonging to those collections.
"""
separated: Any = {}

for from_doc_id, target_dict in edges.items():
for to_doc_id, edge_doc in target_dict.items():
# edge_doc is expected to be a list of edges in Multi(Di)Graph
for m_edge_id, m_edge_doc in edge_doc.items():
assert m_edge_doc is not None and "_id" in m_edge_doc
edge_collection_name = get_node_type_and_id(
m_edge_doc["_id"], default_node_type
)[0]

if edge_collection_name not in separated:
separated[edge_collection_name] = []

m_edge_doc["_from"] = from_doc_id
m_edge_doc["_to"] = to_doc_id

separated[edge_collection_name].append(m_edge_doc)

return separated


def separate_edges_by_collections(
edges: Any, graph_type: str, default_node_type: str
) -> Any:
"""
Wrapper function to separate the dictionary into collections based on graph type.
:param edges: The input dictionary with keys that must contain the real doc id.
:param graph_type: The type of graph to create.
:param default_node_type: The name of the default collection for keys without '/'.
:return: A dictionary where the keys are collection names and the
values are dictionaries of key-value pairs belonging to those collections.
"""
if graph_type in [GraphType.Graph.name, GraphType.DiGraph.name]:
return separate_edges_by_collections_graph(edges, default_node_type)
elif graph_type in [GraphType.MultiGraph.name, GraphType.MultiDiGraph.name]:
return separate_edges_by_collections_multigraph(edges, default_node_type)
else:
raise ValueError(f"Unsupported graph type: {graph_type}")


def upsert_collection_edges(db: StandardDatabase, separated: Any) -> Any:
"""
Process each collection in the separated dictionary.
:param db: The ArangoDB database object.
:param separated: A dictionary where the keys are collection names and the
values are dictionaries
of key-value pairs belonging to those collections.
:return: A list of results from the insert_many operation.
If inserting a document fails, the exception is not raised but
returned as an object in the result list.
"""

results = []

for collection_name, documents_list in separated.items():
collection = db.collection(collection_name)
results.append(
collection.insert_many(
documents_list,
silent=False,
overwrite_mode="update",
)
)

return results
Loading