Skip to content

Commit 2c1c040

Browse files
authored
added ability to load edge attrs. (#25)
* added ability to load edge attrs. * also add to AdjListOuterDict * black fmt * fix init * updated phenolrs * remove now obsolete test * more tests, fixed load_all edge attr * added comment for clarity * better test name * move logic for edge attrs into one helper method, so it is only present in one location * fmt * import order * reformat msg, fix lang * applied suggested code changes * in fetch all for adjlist always load all edge attributes * add edge_values to coo representation and cache * fmt * fmt * remove not needed code anymore * added data definition for edge values * cleanup of unused imports * rm edge attrs of def args for adj
1 parent 9dc4cc2 commit 2c1c040

File tree

8 files changed

+165
-25
lines changed

8 files changed

+165
-25
lines changed

nx_arangodb/classes/dict.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,7 @@
2121
from .enum import DIRECTED_GRAPH_TYPES, MULTIGRAPH_TYPES, GraphType, TraversalDirection
2222
from .function import (
2323
aql,
24-
aql_as_list,
2524
aql_doc_get_key,
26-
aql_doc_get_keys,
27-
aql_doc_get_length,
2825
aql_doc_has_key,
2926
aql_edge_count_src,
3027
aql_edge_count_src_dst,
@@ -33,7 +30,6 @@
3330
aql_edge_id,
3431
aql_fetch_data,
3532
aql_fetch_data_edge,
36-
aql_single,
3733
create_collection,
3834
doc_delete,
3935
doc_get_or_insert,
@@ -45,7 +41,6 @@
4541
get_update_dict,
4642
json_serializable,
4743
key_is_adb_id_or_int,
48-
key_is_int,
4944
key_is_not_reserved,
5045
key_is_string,
5146
keys_are_not_reserved,
@@ -752,6 +747,7 @@ def _fetch_all(self):
752747
load_node_dict=True,
753748
load_adj_dict=False,
754749
load_coo=False,
750+
edge_collections_attributes=set(),
755751
load_all_vertex_attributes=True,
756752
load_all_edge_attributes=False, # not used
757753
is_directed=False, # not used
@@ -2254,6 +2250,7 @@ def set_edge_multigraph(
22542250
load_node_dict=False,
22552251
load_adj_dict=True,
22562252
load_coo=False,
2253+
edge_collections_attributes=set(),
22572254
load_all_vertex_attributes=False, # not used
22582255
load_all_edge_attributes=True,
22592256
is_directed=self.is_directed,

nx_arangodb/classes/digraph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
default_node_type: str | None = None,
2929
edge_type_key: str = "_edge_type",
3030
edge_type_func: Callable[[str, str], str] | None = None,
31+
edge_collections_attributes: set[str] | None = None,
3132
db: StandardDatabase | None = None,
3233
read_parallelism: int = 10,
3334
read_batch_size: int = 100000,
@@ -41,6 +42,7 @@ def __init__(
4142
default_node_type,
4243
edge_type_key,
4344
edge_type_func,
45+
edge_collections_attributes,
4446
db,
4547
read_parallelism,
4648
read_batch_size,

nx_arangodb/classes/function.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
DiGraphAdjDict,
2222
DstIndices,
2323
EdgeIndices,
24+
EdgeValuesDict,
2425
GraphAdjDict,
2526
MultiDiGraphAdjDict,
2627
MultiGraphAdjDict,
@@ -38,11 +39,19 @@
3839
)
3940

4041

42+
def do_load_all_edge_attributes(attributes: set[str]) -> bool:
43+
if len(attributes) == 0:
44+
return True
45+
46+
return False
47+
48+
4149
def get_arangodb_graph(
4250
adb_graph: Graph,
4351
load_node_dict: bool,
4452
load_adj_dict: bool,
4553
load_coo: bool,
54+
edge_collections_attributes: set[str],
4655
load_all_vertex_attributes: bool,
4756
load_all_edge_attributes: bool,
4857
is_directed: bool,
@@ -55,6 +64,7 @@ def get_arangodb_graph(
5564
DstIndices,
5665
EdgeIndices,
5766
ArangoIDtoIndex,
67+
EdgeValuesDict,
5868
]:
5969
"""Pulls the graph from the database, assuming the graph exists.
6070
@@ -71,7 +81,7 @@ def get_arangodb_graph(
7181

7282
metagraph: dict[str, dict[str, Any]] = {
7383
"vertexCollections": {col: set() for col in v_cols},
74-
"edgeCollections": {col: set() for col in e_cols},
84+
"edgeCollections": {col: edge_collections_attributes for col in e_cols},
7585
}
7686

7787
if not any((load_node_dict, load_adj_dict, load_coo)):
@@ -89,6 +99,21 @@ def get_arangodb_graph(
8999
assert config.username
90100
assert config.password
91101

102+
res_do_load_all_edge_attributes = do_load_all_edge_attributes(
103+
edge_collections_attributes
104+
)
105+
106+
if res_do_load_all_edge_attributes is not load_all_edge_attributes:
107+
if len(edge_collections_attributes) > 0:
108+
raise ValueError(
109+
"You have specified to load at least one specific edge attribute"
110+
" and at the same time set the parameter `load_all_vertex_attributes`"
111+
" to true. This combination is not allowed."
112+
)
113+
else:
114+
# We need this case as the user wants by purpose to not load any edge data
115+
res_do_load_all_edge_attributes = load_all_edge_attributes
116+
92117
(
93118
node_dict,
94119
adj_dict,
@@ -106,7 +131,7 @@ def get_arangodb_graph(
106131
load_adj_dict=load_adj_dict,
107132
load_coo=load_coo,
108133
load_all_vertex_attributes=load_all_vertex_attributes,
109-
load_all_edge_attributes=load_all_edge_attributes,
134+
load_all_edge_attributes=res_do_load_all_edge_attributes,
110135
is_directed=is_directed,
111136
is_multigraph=is_multigraph,
112137
symmetrize_edges_if_directed=symmetrize_edges_if_directed,
@@ -121,6 +146,7 @@ def get_arangodb_graph(
121146
dst_indices,
122147
edge_indices,
123148
vertex_ids_to_index,
149+
edge_values,
124150
)
125151

126152

nx_arangodb/classes/graph.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
default_node_type: str | None = None,
4949
edge_type_key: str = "_edge_type",
5050
edge_type_func: Callable[[str, str], str] | None = None,
51+
edge_collections_attributes: set[str] | None = None,
5152
db: StandardDatabase | None = None,
5253
read_parallelism: int = 10,
5354
read_batch_size: int = 100000,
@@ -69,6 +70,8 @@ def __init__(
6970
self.read_batch_size = read_batch_size
7071
self.write_batch_size = write_batch_size
7172

73+
self._set_edge_collections_attributes_to_fetch(edge_collections_attributes)
74+
7275
# NOTE: Need to revisit these...
7376
# self.maintain_node_dict_cache = False
7477
# self.maintain_adj_dict_cache = False
@@ -80,6 +83,7 @@ def __init__(
8083
self.dst_indices: npt.NDArray[np.int64] | None = None
8184
self.edge_indices: npt.NDArray[np.int64] | None = None
8285
self.vertex_ids_to_index: dict[str, int] | None = None
86+
self.edge_values: dict[str, list[int | float]] | None = None
8387

8488
# Does not apply to undirected graphs
8589
self.symmetrize_edges = symmetrize_edges
@@ -236,6 +240,17 @@ def _set_factory_methods(self) -> None:
236240
*adj_args, self.symmetrize_edges
237241
)
238242

243+
def _set_edge_collections_attributes_to_fetch(
244+
self, attributes: set[str] | None
245+
) -> None:
246+
if attributes is None:
247+
self._edge_collections_attributes = set()
248+
return
249+
if len(attributes) > 0:
250+
self._edge_collections_attributes = attributes
251+
if "_id" not in attributes:
252+
self._edge_collections_attributes.add("_id")
253+
239254
###########
240255
# Getters #
241256
###########
@@ -258,6 +273,10 @@ def graph_name(self) -> str:
258273
def graph_exists_in_db(self) -> bool:
259274
return self._graph_exists_in_db
260275

276+
@property
277+
def get_edge_attributes(self) -> set[str]:
278+
return self._edge_collections_attributes
279+
261280
###########
262281
# Setters #
263282
###########

nx_arangodb/classes/multidigraph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
default_node_type: str | None = None,
2828
edge_type_key: str = "_edge_type",
2929
edge_type_func: Callable[[str, str], str] | None = None,
30+
edge_collections_attributes: set[str] | None = None,
3031
db: StandardDatabase | None = None,
3132
read_parallelism: int = 10,
3233
read_batch_size: int = 100000,
@@ -40,6 +41,7 @@ def __init__(
4041
default_node_type,
4142
edge_type_key,
4243
edge_type_func,
44+
edge_collections_attributes,
4345
db,
4446
read_parallelism,
4547
read_batch_size,

nx_arangodb/classes/multigraph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
default_node_type: str | None = None,
2929
edge_type_key: str = "_edge_type",
3030
edge_type_func: Callable[[str, str], str] | None = None,
31+
edge_collections_attributes: set[str] | None = None,
3132
db: StandardDatabase | None = None,
3233
read_parallelism: int = 10,
3334
read_batch_size: int = 100000,
@@ -40,6 +41,7 @@ def __init__(
4041
default_node_type,
4142
edge_type_key,
4243
edge_type_func,
44+
edge_collections_attributes,
4345
db,
4446
read_parallelism,
4547
read_batch_size,

nx_arangodb/convert.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import networkx as nx
77

88
import nx_arangodb as nxadb
9+
from nx_arangodb.classes.function import do_load_all_edge_attributes
910
from nx_arangodb.logger import logger
1011

1112
try:
@@ -126,9 +127,9 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
126127
load_node_dict=True,
127128
load_adj_dict=True,
128129
load_coo=False,
130+
edge_collections_attributes=G.get_edge_attributes,
129131
load_all_vertex_attributes=False,
130-
# TODO: Only return the edge attributes that are needed
131-
load_all_edge_attributes=True,
132+
load_all_edge_attributes=do_load_all_edge_attributes(G.get_edge_attributes),
132133
is_directed=G.is_directed(),
133134
is_multigraph=G.is_multigraph(),
134135
symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False,
@@ -158,27 +159,37 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
158159
and G.dst_indices is not None
159160
and G.edge_indices is not None
160161
and G.vertex_ids_to_index is not None
162+
and G.edge_values is not None
161163
):
162164
m = "**use_coo_cache** is enabled. using cached COO data. no pull required."
163165
logger.debug(m)
164166

165167
else:
166168
start_time = time.time()
167169

168-
_, _, src_indices, dst_indices, edge_indices, vertex_ids_to_index = (
169-
nxadb.classes.function.get_arangodb_graph(
170-
adb_graph=G.adb_graph,
171-
load_node_dict=False,
172-
load_adj_dict=False,
173-
load_coo=True,
174-
load_all_vertex_attributes=False, # not used
175-
load_all_edge_attributes=False, # not used
176-
is_directed=G.is_directed(),
177-
is_multigraph=G.is_multigraph(),
178-
symmetrize_edges_if_directed=(
179-
G.symmetrize_edges if G.is_directed() else False
180-
),
181-
)
170+
(
171+
_,
172+
_,
173+
src_indices,
174+
dst_indices,
175+
edge_indices,
176+
vertex_ids_to_index,
177+
edge_values,
178+
) = nxadb.classes.function.get_arangodb_graph(
179+
adb_graph=G.adb_graph,
180+
load_node_dict=False,
181+
load_adj_dict=False,
182+
load_coo=True,
183+
edge_collections_attributes=G.get_edge_attributes,
184+
load_all_vertex_attributes=False, # not used
185+
load_all_edge_attributes=do_load_all_edge_attributes(
186+
G.get_edge_attributes
187+
),
188+
is_directed=G.is_directed(),
189+
is_multigraph=G.is_multigraph(),
190+
symmetrize_edges_if_directed=(
191+
G.symmetrize_edges if G.is_directed() else False
192+
),
182193
)
183194

184195
print(f"ADB -> COO load took {time.time() - start_time}s")
@@ -187,6 +198,7 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
187198
G.dst_indices = dst_indices
188199
G.edge_indices = edge_indices
189200
G.vertex_ids_to_index = vertex_ids_to_index
201+
G.edge_values = edge_values
190202

191203
N = len(G.vertex_ids_to_index) # type: ignore
192204
src_indices_cp = cp.array(G.src_indices)
@@ -204,7 +216,7 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
204216
src_indices=src_indices_cp,
205217
dst_indices=dst_indices_cp,
206218
edge_indices=edge_indices_cp,
207-
# edge_values,
219+
edge_values=G.edge_values,
208220
# edge_masks,
209221
# node_values,
210222
# node_masks,
@@ -222,7 +234,7 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
222234
N=N,
223235
src_indices=src_indices_cp,
224236
dst_indices=dst_indices_cp,
225-
# edge_values,
237+
edge_values=G.edge_values,
226238
# edge_masks,
227239
# node_values,
228240
# node_masks,

0 commit comments

Comments
 (0)