Skip to content

[GA-153-1] Implement EdgeAttrDict update method #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 87 additions & 19 deletions nx_arangodb/classes/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,26 @@

from __future__ import annotations

import warnings
from collections import UserDict, defaultdict
from collections.abc import Iterator
from typing import Any, Callable, Generator

from arango.database import StandardDatabase
from arango.exceptions import DocumentInsertError
from arango.exceptions import ArangoError, DocumentInsertError
from arango.graph import Graph

from nx_arangodb.logger import logger

from ..typing import AdjDict
from ..utils.arangodb import (
ArangoDBBatchError,
check_list_for_errors,
separate_edges_by_collections,
separate_nodes_by_collections,
upsert_collection_documents,
upsert_collection_edges,
)
from .function import (
aql,
aql_as_list,
Expand Down Expand Up @@ -496,11 +506,45 @@ def clear(self) -> None:
# for collection in self.graph.vertex_collections():
# self.graph.vertex_collection(collection).truncate()

@keys_are_strings
@logger_debug
def update_local_nodes(self, nodes: Any) -> None:
for node_id, node_data in nodes.items():
node_attr_dict = self.node_attr_dict_factory()
node_attr_dict.node_id = node_id
node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, node_data)

self.data[node_id] = node_attr_dict

@keys_are_strings
@logger_debug
def update(self, nodes: Any) -> None:
"""g._node.update({'node/1': {'foo': 'bar'}, 'node/2': {'baz': 'qux'}})"""
raise NotImplementedError("NodeDict.update()")
separated_by_collection = separate_nodes_by_collections(
nodes, self.default_node_type
)

result = upsert_collection_documents(self.db, separated_by_collection)

all_good = check_list_for_errors(result)
if all_good:
# Means no single operation failed, in this case we update the local cache
self.update_local_nodes(nodes)
else:
# In this case some or all documents failed. Right now we will not
# update the local cache, but raise an error instead.
# Reason: We cannot set silent to True, because we need as it does
# not report errors then. We need to update the driver to also pass
# the errors back to the user, then we can adjust the behavior here.
# This will also save network traffic and local computation time.
errors = []
for collections_results in result:
for collection_result in collections_results:
errors.append(collection_result)
warnings.warn(
"Failed to insert at least one node. Will not update local cache."
)
raise ArangoDBBatchError(errors)

# TODO: Revisit typing of return value
@logger_debug
Expand Down Expand Up @@ -614,7 +658,7 @@ def __init__(
self.graph = graph
self.edge_id: str | None = None

# NodeAttrDict may be a child of another NodeAttrDict
# EdgeAttrDict may be a child of another EdgeAttrDict
# e.g G._adj['node/1']['node/2']['object']['foo'] = 'bar'
# In this case, **parent_keys** would be ['object']
# and **root** would be G._adj['node/1']['node/2']
Expand Down Expand Up @@ -1138,8 +1182,29 @@ def clear(self) -> None:
@keys_are_strings
@logger_debug
def update(self, edges: Any) -> None:
"""g._adj.update({'node/1': {'node/2': {'foo': 'bar'}})"""
raise NotImplementedError("AdjListOuterDict.update()")
"""g._adj.update({'node/1': {'node/2': {'_id': 'foo/bar', 'foo': "bar"}})"""
separated_by_edge_collection = separate_edges_by_collections(edges)
result = upsert_collection_edges(self.db, separated_by_edge_collection)

all_good = check_list_for_errors(result)
if all_good:
# Means no single operation failed, in this case we update the local cache
self.__set_adj_elements(edges)
else:
# In this case some or all documents failed. Right now we will not
# update the local cache, but raise an error instead.
# Reason: We cannot set silent to True, because we need as it does
# not report errors then. We need to update the driver to also pass
# the errors back to the user, then we can adjust the behavior here.
# This will also save network traffic and local computation time.
errors = []
for collections_results in result:
for collection_result in collections_results:
errors.append(collection_result)
warnings.warn(
"Failed to insert at least one node. Will not update local cache."
)
raise ArangoDBBatchError(errors)

# TODO: Revisit typing of return value
@logger_debug
Expand Down Expand Up @@ -1171,25 +1236,15 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any:
yield from result

@logger_debug
def __fetch_all(self) -> None:
self.clear()

_, adj_dict, _, _, _ = get_arangodb_graph(
self.graph,
load_node_dict=False,
load_adj_dict=True,
is_directed=False, # TODO: Abstract based on Graph type
is_multigraph=False, # TODO: Abstract based on Graph type
load_coo=False,
)

for src_node_id, inner_dict in adj_dict.items():
def __set_adj_elements(self, edges_dict: AdjDict) -> None:
for src_node_id, inner_dict in edges_dict.items():
for dst_node_id, edge in inner_dict.items():

if src_node_id in self.data:
if dst_node_id in self.data[src_node_id].data:
continue

# TODO: Clean up those two if/else statements later
if src_node_id in self.data:
src_inner_dict = self.data[src_node_id]
else:
Expand All @@ -1209,8 +1264,21 @@ def __fetch_all(self) -> None:
edge_attr_dict = src_inner_dict.edge_attr_dict_factory()
edge_attr_dict.edge_id = edge["_id"]
edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, edge)

self.data[src_node_id].data[dst_node_id] = edge_attr_dict
self.data[dst_node_id].data[src_node_id] = edge_attr_dict

@logger_debug
def __fetch_all(self) -> None:
self.clear()

_, adj_dict, _, _, _ = get_arangodb_graph(
self.graph,
load_node_dict=False,
load_adj_dict=True,
is_directed=False, # TODO: Abstract based on Graph type
is_multigraph=False, # TODO: Abstract based on Graph type
load_coo=False,
)

self.__set_adj_elements(adj_dict)
self.FETCHED_ALL_DATA = True
5 changes: 3 additions & 2 deletions nx_arangodb/classes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
GraphDoesNotExist,
InvalidTraversalDirection,
)
from ..typing import AdjDict


def get_arangodb_graph(
Expand All @@ -35,7 +36,7 @@ def get_arangodb_graph(
load_coo: bool,
) -> Tuple[
dict[str, dict[str, Any]],
dict[str, dict[str, dict[str, Any]]],
AdjDict,
npt.NDArray[np.int64],
npt.NDArray[np.int64],
dict[str, int],
Expand Down Expand Up @@ -152,7 +153,7 @@ def wrapper(
return wrapper


RESERVED_KEYS = {"_id", "_key", "_rev"}
RESERVED_KEYS = {"_id", "_key", "_rev", "_from", "_to"}


def key_is_not_reserved(func: Callable[..., Any]) -> Any:
Expand Down
37 changes: 34 additions & 3 deletions nx_arangodb/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@
from __future__ import annotations

from collections.abc import Hashable
from typing import TypeVar
from typing import Any, Dict, TypeVar

import cupy as cp
import numpy as np
import numpy.typing as npt

from nx_arangodb.logger import logger

try:
import cupy as cp
except ModuleNotFoundError as e:
GPU_ENABLED = False
logger.info(f"NXCG is disabled. {e}.")


AttrKey = TypeVar("AttrKey", bound=Hashable)
EdgeKey = TypeVar("EdgeKey", bound=Hashable)
NodeKey = TypeVar("NodeKey", bound=Hashable)
Expand All @@ -18,6 +25,30 @@
IndexValue = TypeVar("IndexValue")
Dtype = TypeVar("Dtype")

# AdjDict is a dictionary of dictionaries of dictionaries
# The outer dict is holding _from_id(s) as keys
# - It may or may not hold valid ArangoDB document _id(s)
# The inner dict is holding _to_id(s) as keys
# - It may or may not hold valid ArangoDB document _id(s)
# The next inner dict contains then the actual edges data (key, val)
# Example
# {
# 'person/1': {
# 'person/32': {
# '_id': 'knows/16',
# 'extraValue': '16'
# },
# 'person/33': {
# '_id': 'knows/17',
# 'extraValue': '17'
# }
# ...
# }
# ...
# }
# The above example is a graph with 2 edges from person/1 to person/32 and person/33
AdjDict = Dict[str, Dict[str, Dict[str, Any]]]


class any_ndarray:
def __class_getitem__(cls, item):
Expand Down
Loading