GA-154 | update phenolrs usage & use nx.config #9

Merged · 19 commits · Jul 5, 2024
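Summary: connection details and loader tuning move off the graph object (`G._host`, `G._username`, `G.graph_loader_parallelism`, …) and into NetworkX's backend configuration. A minimal setup sketch, assuming `nx_arangodb` registers the `arangodb` entry under `nx.config.backends` and that attribute-style assignment works on it; the key names match the `config` reads in the diff below:

```python
import networkx as nx
import nx_arangodb  # assumed to register the "arangodb" backend config

config = nx.config.backends.arangodb

# Required by get_arangodb_graph(), which asserts these are set
# (example values throughout):
config.db_name = "_system"
config.host = "http://localhost:8529"
config.username = "root"
config.password = "passwd"

# Optional phenolrs loader tuning, read via config.get() in the diff:
config.load_parallelism = 8
config.load_batch_size = 100_000
```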
165 changes: 51 additions & 114 deletions nx_arangodb/classes/dict.py
@@ -33,6 +33,7 @@
doc_get_or_insert,
doc_insert,
doc_update,
get_arangodb_graph,
get_node_id,
get_node_type_and_id,
key_is_not_reserved,
@@ -321,56 +322,6 @@ def __delitem__(self, key: str) -> None:
root_data = self.root.data if self.root else self.data
root_data["_rev"] = doc_update(self.db, self.node_id, update_dict)

# @logger_debug
# def __iter__(self) -> Iterator[str]:
# """for key in G._node['node/1']"""
# yield from aql_doc_get_keys(self.db, self.node_id, self.parent_keys)

# @logger_debug
# def __len__(self) -> int:
# """len(G._node['node/1'])"""
# return aql_doc_get_length(self.db, self.node_id, self.parent_keys)

# @logger_debug
# def keys(self) -> Any:
# """G._node['node/1'].keys()"""
# yield from self.__iter__()

# @logger_debug
# # TODO: Revisit typing of return value
# def values(self) -> Any:
# """G._node['node/1'].values()"""
# self.data = self.db.document(self.node_id)
# yield from self.data.values()

# @logger_debug
# # TODO: Revisit typing of return value
# def items(self) -> Any:
# """G._node['node/1'].items()"""

# # TODO: Revisit this lazy hack
# if self.parent_keys:
# yield from self.data.items()
# else:
# self.data = self.db.document(self.node_id)
# yield from self.data.items()

# ?
# def pull():
# pass

# ?
# def push():
# pass

# @logger_debug
# def clear(self) -> None:
# """G._node['node/1'].clear()"""
# self.data.clear()

# # if clear_remote:
# # doc_insert(self.db, self.node_id, silent=True, overwrite=True)

@keys_are_strings
@keys_are_not_reserved
# @values_are_json_serializable # TODO?
@@ -435,6 +386,9 @@ def __contains__(self, key: str) -> bool:
if node_id in self.data:
return True

if self.FETCHED_ALL_DATA:
return False

return bool(self.graph.has_vertex(node_id))

@key_is_string
@@ -446,6 +400,9 @@ def __getitem__(self, key: str) -> NodeAttrDict:
if vertex := self.data.get(node_id):
return vertex

if self.FETCHED_ALL_DATA:
raise KeyError(key)

if vertex := self.graph.vertex(node_id):
node_attr_dict: NodeAttrDict = self.node_attr_dict_factory()
node_attr_dict.node_id = node_id
@@ -472,7 +429,7 @@ def __setitem__(self, key: str, value: NodeAttrDict) -> None:

node_attr_dict = self.node_attr_dict_factory()
node_attr_dict.node_id = node_id
node_attr_dict.data = result
node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, result)

self.data[node_id] = node_attr_dict

@@ -570,16 +527,23 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any:

@logger_debug
def __fetch_all(self):
self.data.clear()
for collection in self.graph.vertex_collections():
for doc in self.graph.vertex_collection(collection).all():
node_id = doc["_id"]
self.clear()

node_attr_dict = self.node_attr_dict_factory()
node_attr_dict.node_id = node_id
node_attr_dict.data = doc
node_dict, _, _, _, _ = get_arangodb_graph(
self.graph,
load_node_dict=True,
load_adj_dict=False,
load_adj_dict_as_directed=False, # not used
load_adj_dict_as_multigraph=False, # not used
load_coo=False,
)

self.data[node_id] = node_attr_dict
for node_id, node_data in node_dict.items():
node_attr_dict = self.node_attr_dict_factory()
node_attr_dict.node_id = node_id
node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, node_data)

self.data[node_id] = node_attr_dict

self.FETCHED_ALL_DATA = True
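The `FETCHED_ALL_DATA` guards added in the hunks above all rest on one invariant: once `__fetch_all()` has run, `self.data` mirrors the remote collection completely, so a local miss is authoritative. A condensed sketch of the pattern (the `node_id` derivation is assumed from the unchanged lines above the hunk):

```python
def __contains__(self, key: str) -> bool:
    node_id = get_node_id(key, self.default_node_type)  # assumed helper usage
    if node_id in self.data:
        return True               # cached locally
    if self.FETCHED_ALL_DATA:
        return False              # cache is complete; skip the DB entirely
    return bool(self.graph.has_vertex(node_id))  # fall back to one lookup
```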

@@ -710,43 +674,6 @@ def __delitem__(self, key: str) -> None:
root_data = self.root.data if self.root else self.data
root_data["_rev"] = doc_update(self.db, self.edge_id, update_dict)

# @logger_debug
# def __iter__(self) -> Iterator[str]:
# """for key in G._adj['node/1']['node/2']"""
# assert self.edge_id
# yield from aql_doc_get_keys(self.db, self.edge_id)

# @logger_debug
# def __len__(self) -> int:
# """len(G._adj['node/1']['node/'2])"""
# assert self.edge_id
# return aql_doc_get_length(self.db, self.edge_id)

# # TODO: Revisit typing of return value
# @logger_debug
# def keys(self) -> Any:
# """G._adj['node/1']['node/'2].keys()"""
# return self.__iter__()

# # TODO: Revisit typing of return value
# @logger_debug
# def values(self) -> Any:
# """G._adj['node/1']['node/'2].values()"""
# self.data = self.db.document(self.edge_id)
# yield from self.data.values()

# # TODO: Revisit typing of return value
# @logger_debug
# def items(self) -> Any:
# """G._adj['node/1']['node/'2].items()"""
# self.data = self.db.document(self.edge_id)
# yield from self.data.items()

# @logger_debug
# def clear(self) -> None:
# """G._adj['node/1']['node/'2].clear()"""
# self.data.clear()

@keys_are_strings
@keys_are_not_reserved
@logger_debug
@@ -836,6 +763,9 @@ def __contains__(self, key: str) -> bool:
if dst_node_id in self.data:
return True

if self.FETCHED_ALL_DATA:
return False

result = aql_edge_exists(
self.db,
self.src_node_id,
@@ -859,6 +789,9 @@ def __getitem__(self, key: str) -> EdgeAttrDict:
self.data[dst_node_id] = edge
return edge # type: ignore # false positive

if self.FETCHED_ALL_DATA:
raise KeyError(key)

assert self.src_node_id
edge = aql_edge_get(
self.db,
@@ -1022,8 +955,7 @@ def items(self) -> Any:

@logger_debug
def __fetch_all(self) -> None:
if self.FETCHED_ALL_DATA:
return
assert self.src_node_id

self.clear()

@@ -1037,8 +969,7 @@ def __fetch_all(self) -> None:
for edge in aql(self.db, query, bind_vars):
edge_attr_dict = self.edge_attr_dict_factory()
edge_attr_dict.edge_id = edge["_id"]
edge_attr_dict.data = edge

edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, edge)
self.data[edge["_to"]] = edge_attr_dict

self.FETCHED_ALL_DATA = True
@@ -1100,6 +1031,9 @@ def __contains__(self, key: str) -> bool:
if node_id in self.data:
return True

if self.FETCHED_ALL_DATA:
return False

return bool(self.graph.has_vertex(node_id))

@key_is_string
@@ -1114,7 +1048,6 @@ def __getitem__(self, key: str) -> AdjListInnerDict:
if self.graph.has_vertex(node_id):
adjlist_inner_dict: AdjListInnerDict = self.adjlist_inner_dict_factory()
adjlist_inner_dict.src_node_id = node_id

self.data[node_id] = adjlist_inner_dict

return adjlist_inner_dict
@@ -1237,41 +1170,45 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any:
result = aql_fetch_data_edge(self.db, e_cols, data, default)
yield from result

# TODO: Revisit this logic
@logger_debug
def __fetch_all(self) -> None:
if self.FETCHED_ALL_DATA:
return

self.clear()
# items = defaultdict(dict)
for ed in self.graph.edge_definitions():
collection = ed["edge_collection"]

for edge in self.graph.edge_collection(collection):
src_node_id = edge["_from"]
dst_node_id = edge["_to"]
_, adj_dict, _, _, _ = get_arangodb_graph(
self.graph,
load_node_dict=False,
load_adj_dict=True,
load_adj_dict_as_directed=False, # TODO: Abstract based on Graph type
load_adj_dict_as_multigraph=False, # TODO: Abstract based on Graph type
load_coo=False,
)

for src_node_id, inner_dict in adj_dict.items():
for dst_node_id, edge in inner_dict.items():

# items[src_node_id][dst_node_id] = edge
# items[dst_node_id][src_node_id] = edge
if src_node_id in self.data:
if dst_node_id in self.data[src_node_id].data:
continue

if src_node_id in self.data:
src_inner_dict = self.data[src_node_id]
else:
src_inner_dict = self.adjlist_inner_dict_factory()
src_inner_dict.src_node_id = src_node_id
src_inner_dict.FETCHED_ALL_DATA = True
self.data[src_node_id] = src_inner_dict

if dst_node_id in self.data:
dst_inner_dict = self.data[dst_node_id]
else:
dst_inner_dict = self.adjlist_inner_dict_factory()
dst_inner_dict.src_node_id = dst_node_id
dst_inner_dict.FETCHED_ALL_DATA = True
self.data[dst_node_id] = dst_inner_dict

edge_attr_dict = src_inner_dict.edge_attr_dict_factory()
edge_attr_dict.edge_id = edge["_id"]
edge_attr_dict.data = edge
edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, edge)

self.data[src_node_id].data[dst_node_id] = edge_attr_dict
self.data[dst_node_id].data[src_node_id] = edge_attr_dict
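Net effect of the dict.py changes: one phenolrs bulk load replaces the per-collection cursors, and every `__fetch_all` now routes through `get_arangodb_graph`. A hedged behavioral sketch (the constructor argument and the lazy-load trigger are assumed, since neither appears in this diff):

```python
import nx_arangodb as nxadb

G = nxadb.Graph(graph_name="my_graph")  # hypothetical constructor argument

# Assuming a full traversal triggers __fetch_all(), this costs one
# bulk load through phenolrs rather than one cursor per collection:
for node_id in G.nodes:
    pass

# Afterwards FETCHED_ALL_DATA resolves lookups from the local cache:
"my_collection/doc_1" in G.nodes  # no database round-trip
```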
3 changes: 0 additions & 3 deletions nx_arangodb/classes/digraph.py
@@ -154,6 +154,3 @@ def __set_graph_name(self, graph_name: str | None = None) -> None:

def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Cursor:
return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs)

def pull(self, load_node_dict=True, load_adj_dict=True, load_coo=True):
raise NotImplementedError("nxadb.DiGraph.pull() is not implemented yet.")
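digraph.py only loses the unimplemented `pull()` stub; the `aql()` passthrough above it is unchanged. A small usage sketch (graph and collection names hypothetical):

```python
# G is an nxadb.DiGraph; @@col binds a collection name in AQL.
cursor = G.aql(
    "FOR v IN @@col LIMIT 5 RETURN v",
    bind_vars={"@col": "my_vertex_collection"},  # hypothetical collection
)
for doc in cursor:
    print(doc["_id"])
```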
39 changes: 22 additions & 17 deletions nx_arangodb/classes/function.py
@@ -8,11 +8,13 @@
from collections import UserDict
from typing import Any, Callable, Tuple

import networkx as nx
import numpy as np
import numpy.typing as npt
from arango.collection import StandardCollection
from arango.cursor import Cursor
from arango.database import StandardDatabase
from arango.graph import Graph

import nx_arangodb as nxadb
from nx_arangodb.logger import logger
@@ -25,10 +27,11 @@


def get_arangodb_graph(
G: nxadb.Graph | nxadb.DiGraph,
adb_graph: Graph,
load_node_dict: bool,
load_adj_dict: bool,
load_adj_dict_as_directed: bool,
load_adj_dict_as_multigraph: bool,
load_coo: bool,
) -> Tuple[
dict[str, dict[str, Any]],
@@ -46,12 +49,6 @@ def get_arangodb_graph(
- Destination Indices (COO)
- Node-ID-to-index mapping (COO)
"""
if not G.graph_exists_in_db:
raise GraphDoesNotExist(
"Graph does not exist in the database. Can't load graph."
)

adb_graph = G.db.graph(G.graph_name)
v_cols = adb_graph.vertex_collections()
edge_definitions = adb_graph.edge_definitions()
e_cols = {c["edge_collection"] for c in edge_definitions}
@@ -63,22 +60,30 @@

from phenolrs.networkx_loader import NetworkXLoader

config = nx.config.backends.arangodb

kwargs = {}
if G.graph_loader_parallelism is not None:
kwargs["parallelism"] = G.graph_loader_parallelism
if G.graph_loader_batch_size is not None:
kwargs["batch_size"] = G.graph_loader_batch_size
if parallelism := config.get("load_parallelism"):
kwargs["parallelism"] = parallelism
if batch_size := config.get("load_batch_size"):
kwargs["batch_size"] = batch_size

assert config.db_name
assert config.host
assert config.username
assert config.password

# TODO: Remove ignore when phenolrs is published
return NetworkXLoader.load_into_networkx( # type: ignore
G.db.name,
metagraph,
[G._host],
username=G._username,
password=G._password,
config.db_name,
metagraph=metagraph,
hosts=[config.host],
username=config.username,
password=config.password,
load_node_dict=load_node_dict,
load_adj_dict=load_adj_dict,
load_adj_dict_as_directed=load_adj_dict_as_directed,
load_adj_dict_as_multigraph=load_adj_dict_as_multigraph,
load_coo=load_coo,
**kwargs,
)
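With the new signature, callers hand `get_arangodb_graph` a python-arango `Graph` handle instead of an `nxadb` wrapper, and credentials come from `nx.config`. A hedged call sketch mirroring the `AdjListOuterDict.__fetch_all` call site; it assumes `nx.config.backends.arangodb` is already populated as in the earlier sketch:

```python
from arango import ArangoClient

# Assumed setup: an authenticated python-arango database handle.
db = ArangoClient(hosts="http://localhost:8529").db(
    "_system", username="root", password="passwd"
)
adb_graph = db.graph("my_graph")  # hypothetical graph name

# Load only the adjacency structure, as the dict.py call site does.
_, adj_dict, _, _, _ = get_arangodb_graph(
    adb_graph,
    load_node_dict=False,
    load_adj_dict=True,
    load_adj_dict_as_directed=False,   # TODO in diff: derive from graph type
    load_adj_dict_as_multigraph=False,
    load_coo=False,
)
```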
@@ -103,7 +108,7 @@ def logger_debug(func: Callable[..., Any]) -> Any:
"""Decorator to log debug messages."""

def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
logger.debug(f"{func.__name__} - {args} - {kwargs}")
logger.debug(f"{type(self)}.{func.__name__} - {args} - {kwargs}")
return func(self, *args, **kwargs)

return wrapper