1
1
from __future__ import annotations
2
2
3
+ import warnings
3
4
from collections import UserDict
4
5
from collections .abc import Iterator
5
6
from itertools import islice
8
9
from arango .database import StandardDatabase
9
10
from arango .exceptions import DocumentDeleteError
10
11
from arango .graph import Graph
12
+ from phenolrs .networkx .typings import (
13
+ DiGraphAdjDict ,
14
+ GraphAdjDict ,
15
+ MultiDiGraphAdjDict ,
16
+ MultiGraphAdjDict ,
17
+ )
11
18
12
19
from nx_arangodb .exceptions import EdgeTypeAmbiguity , MultipleEdgesFound
13
20
from nx_arangodb .logger import logger
14
21
15
22
from ..enum import DIRECTED_GRAPH_TYPES , MULTIGRAPH_TYPES , GraphType , TraversalDirection
16
23
from ..function import (
24
+ ArangoDBBatchError ,
17
25
aql ,
18
26
aql_doc_get_key ,
19
27
aql_doc_has_key ,
23
31
aql_edge_get ,
24
32
aql_edge_id ,
25
33
aql_fetch_data_edge ,
34
+ check_list_for_errors ,
26
35
doc_insert ,
27
36
doc_update ,
28
37
get_arangodb_graph ,
36
45
keys_are_not_reserved ,
37
46
keys_are_strings ,
38
47
logger_debug ,
48
+ separate_edges_by_collections ,
49
+ upsert_collection_edges ,
39
50
)
40
51
41
52
#############
@@ -169,7 +180,7 @@ def __init__(
169
180
self .graph = graph
170
181
self .edge_id : str | None = None
171
182
172
- # NodeAttrDict may be a child of another NodeAttrDict
183
+ # EdgeAttrDict may be a child of another EdgeAttrDict
173
184
# e.g G._adj['node/1']['node/2']['object']['foo'] = 'bar'
174
185
# In this case, **parent_keys** would be ['object']
175
186
# and **root** would be G._adj['node/1']['node/2']
@@ -1482,8 +1493,31 @@ def clear(self) -> None:
1482
1493
@keys_are_strings
1483
1494
@logger_debug
1484
1495
def update (self , edges : Any ) -> None :
1485
- """g._adj.update({'node/1': {'node/2': {'foo': 'bar'}})"""
1486
- raise NotImplementedError ("AdjListOuterDict.update()" )
1496
+ """g._adj.update({'node/1': {'node/2': {'_id': 'foo/bar', 'foo': "bar"}})"""
1497
+ separated_by_edge_collection = separate_edges_by_collections (
1498
+ edges , graph_type = self .graph_type , default_node_type = self .default_node_type
1499
+ )
1500
+ result = upsert_collection_edges (self .db , separated_by_edge_collection )
1501
+
1502
+ all_good = check_list_for_errors (result )
1503
+ if all_good :
1504
+ # Means no single operation failed, in this case we update the local cache
1505
+ self .__set_adj_elements (edges )
1506
+ else :
1507
+ # In this case some or all documents failed. Right now we will not
1508
+ # update the local cache, but raise an error instead.
1509
+ # Reason: We cannot set silent to True, because we need as it does
1510
+ # not report errors then. We need to update the driver to also pass
1511
+ # the errors back to the user, then we can adjust the behavior here.
1512
+ # This will also save network traffic and local computation time.
1513
+ errors = []
1514
+ for collections_results in result :
1515
+ for collection_result in collections_results :
1516
+ errors .append (collection_result )
1517
+ warnings .warn (
1518
+ "Failed to insert at least one node. Will not update local cache."
1519
+ )
1520
+ raise ArangoDBBatchError (errors )
1487
1521
1488
1522
@logger_debug
1489
1523
def values (self ) -> Any :
@@ -1507,22 +1541,44 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any:
1507
1541
yield from result
1508
1542
1509
1543
@logger_debug
1510
- def _fetch_all (self ) -> None :
1511
- self .clear ()
1544
+ def __set_adj_elements (
1545
+ self ,
1546
+ edges_dict : (
1547
+ GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict
1548
+ ),
1549
+ ) -> None :
1550
+ def set_edge_graph (
1551
+ src_node_id : str , dst_node_id : str , edge : dict [str , Any ]
1552
+ ) -> EdgeAttrDict :
1553
+ adjlist_inner_dict = self .data [src_node_id ]
1554
+
1555
+ edge_attr_dict : EdgeAttrDict
1556
+ edge_attr_dict = adjlist_inner_dict ._create_edge_attr_dict (edge )
1557
+
1558
+ adjlist_inner_dict .data [dst_node_id ] = edge_attr_dict
1559
+
1560
+ return edge_attr_dict
1561
+
1562
+ def set_edge_multigraph (
1563
+ src_node_id : str , dst_node_id : str , edges : dict [int , dict [str , Any ]]
1564
+ ) -> EdgeKeyDict :
1565
+ adjlist_inner_dict = self .data [src_node_id ]
1566
+
1567
+ edge_key_dict = adjlist_inner_dict .edge_key_dict_factory ()
1568
+ edge_key_dict .src_node_id = src_node_id
1569
+ edge_key_dict .dst_node_id = dst_node_id
1570
+ edge_key_dict .FETCHED_ALL_DATA = True
1571
+ edge_key_dict .FETCHED_ALL_IDS = True
1572
+
1573
+ for edge in edges .values ():
1574
+ edge_attr_dict = adjlist_inner_dict ._create_edge_attr_dict (edge )
1575
+ edge_key_dict .data [edge ["_id" ]] = edge_attr_dict
1512
1576
1513
- def set_adj_inner_dict (
1514
- adj_outer_dict : AdjListOuterDict , node_id : str
1515
- ) -> AdjListInnerDict :
1516
- if node_id in adj_outer_dict .data :
1517
- return adj_outer_dict .data [node_id ]
1577
+ adjlist_inner_dict .data [dst_node_id ] = edge_key_dict
1518
1578
1519
- adj_inner_dict = self .adjlist_inner_dict_factory ()
1520
- adj_inner_dict .src_node_id = node_id
1521
- adj_inner_dict .FETCHED_ALL_DATA = True
1522
- adj_inner_dict .FETCHED_ALL_IDS = True
1523
- adj_outer_dict .data [node_id ] = adj_inner_dict
1579
+ return edge_key_dict
1524
1580
1525
- return adj_inner_dict
1581
+ set_edge_func = set_edge_multigraph if self . is_multigraph else set_edge_graph
1526
1582
1527
1583
def propagate_edge_undirected (
1528
1584
src_node_id : str ,
@@ -1536,7 +1592,7 @@ def propagate_edge_directed(
1536
1592
dst_node_id : str ,
1537
1593
edge_key_or_attr_dict : EdgeKeyDict | EdgeAttrDict ,
1538
1594
) -> None :
1539
- set_adj_inner_dict (self .mirror , dst_node_id )
1595
+ self . __set_adj_inner_dict (self .mirror , dst_node_id )
1540
1596
self .mirror .data [dst_node_id ].data [src_node_id ] = edge_key_or_attr_dict
1541
1597
1542
1598
def propagate_edge_directed_symmetric (
@@ -1546,7 +1602,7 @@ def propagate_edge_directed_symmetric(
1546
1602
) -> None :
1547
1603
propagate_edge_directed (src_node_id , dst_node_id , edge_key_or_attr_dict )
1548
1604
propagate_edge_undirected (src_node_id , dst_node_id , edge_key_or_attr_dict )
1549
- set_adj_inner_dict (self .mirror , src_node_id )
1605
+ self . __set_adj_inner_dict (self .mirror , src_node_id )
1550
1606
self .mirror .data [src_node_id ].data [dst_node_id ] = edge_key_or_attr_dict
1551
1607
1552
1608
propagate_edge_func = (
@@ -1559,38 +1615,39 @@ def propagate_edge_directed_symmetric(
1559
1615
)
1560
1616
)
1561
1617
1562
- def set_edge_graph (
1563
- src_node_id : str , dst_node_id : str , edge : dict [str , Any ]
1564
- ) -> EdgeAttrDict :
1565
- adjlist_inner_dict = self .data [src_node_id ]
1566
-
1567
- edge_attr_dict : EdgeAttrDict
1568
- edge_attr_dict = adjlist_inner_dict ._create_edge_attr_dict (edge )
1618
+ for src_node_id , inner_dict in edges_dict .items ():
1619
+ for dst_node_id , edge_or_edges in inner_dict .items ():
1569
1620
1570
- adjlist_inner_dict .data [dst_node_id ] = edge_attr_dict
1621
+ if not self .is_directed :
1622
+ if src_node_id in self .data :
1623
+ if dst_node_id in self .data [src_node_id ].data :
1624
+ continue # can skip due not directed
1571
1625
1572
- return edge_attr_dict
1626
+ self .__set_adj_inner_dict (self , src_node_id )
1627
+ self .__set_adj_inner_dict (self , dst_node_id )
1628
+ edge_attr_or_key_dict = set_edge_func ( # type: ignore[operator]
1629
+ src_node_id , dst_node_id , edge_or_edges
1630
+ )
1573
1631
1574
- def set_edge_multigraph (
1575
- src_node_id : str , dst_node_id : str , edges : dict [int , dict [str , Any ]]
1576
- ) -> EdgeKeyDict :
1577
- adjlist_inner_dict = self .data [src_node_id ]
1632
+ propagate_edge_func (src_node_id , dst_node_id , edge_attr_or_key_dict )
1578
1633
1579
- edge_key_dict = adjlist_inner_dict . edge_key_dict_factory ()
1580
- edge_key_dict . src_node_id = src_node_id
1581
- edge_key_dict . dst_node_id = dst_node_id
1582
- edge_key_dict . FETCHED_ALL_DATA = True
1583
- edge_key_dict . FETCHED_ALL_IDS = True
1634
+ def __set_adj_inner_dict (
1635
+ self , adj_outer_dict : AdjListOuterDict , node_id : str
1636
+ ) -> AdjListInnerDict :
1637
+ if node_id in adj_outer_dict . data :
1638
+ return adj_outer_dict . data [ node_id ]
1584
1639
1585
- for edge in edges . values ():
1586
- edge_attr_dict = adjlist_inner_dict . _create_edge_attr_dict ( edge )
1587
- edge_key_dict . data [ edge [ "_id" ]] = edge_attr_dict
1588
-
1589
- adjlist_inner_dict .data [dst_node_id ] = edge_key_dict
1640
+ adj_inner_dict = self . adjlist_inner_dict_factory ()
1641
+ adj_inner_dict . src_node_id = node_id
1642
+ adj_inner_dict . FETCHED_ALL_DATA = True
1643
+ adj_inner_dict . FETCHED_ALL_IDS = True
1644
+ adj_outer_dict .data [node_id ] = adj_inner_dict
1590
1645
1591
- return edge_key_dict
1646
+ return adj_inner_dict
1592
1647
1593
- set_edge_func = set_edge_multigraph if self .is_multigraph else set_edge_graph
1648
+ @logger_debug
1649
+ def _fetch_all (self ) -> None :
1650
+ self .clear ()
1594
1651
1595
1652
(
1596
1653
_ ,
@@ -1613,21 +1670,7 @@ def set_edge_multigraph(
1613
1670
if self .is_directed :
1614
1671
adj_dict = adj_dict ["succ" ]
1615
1672
1616
- for src_node_id , inner_dict in adj_dict .items ():
1617
- for dst_node_id , edge_or_edges in inner_dict .items ():
1618
-
1619
- if not self .is_directed :
1620
- if src_node_id in self .data :
1621
- if dst_node_id in self .data [src_node_id ].data :
1622
- continue # can skip due not directed
1623
-
1624
- set_adj_inner_dict (self , src_node_id )
1625
- set_adj_inner_dict (self , dst_node_id )
1626
- edge_attr_or_key_dict = set_edge_func ( # type: ignore[operator]
1627
- src_node_id , dst_node_id , edge_or_edges
1628
- )
1629
-
1630
- propagate_edge_func (src_node_id , dst_node_id , edge_attr_or_key_dict )
1673
+ self .__set_adj_elements (adj_dict )
1631
1674
1632
1675
self .FETCHED_ALL_DATA = True
1633
1676
self .FETCHED_ALL_IDS = True
0 commit comments