From a0af2ac69edbf5b7049fd0188a76c433dc276bf1 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Wed, 11 Sep 2024 11:22:31 -0400 Subject: [PATCH 01/11] test from #9196 but on TreeNode --- xarray/tests/test_treenode.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xarray/tests/test_treenode.py b/xarray/tests/test_treenode.py index d9d581cc314..1baceed9ee5 100644 --- a/xarray/tests/test_treenode.py +++ b/xarray/tests/test_treenode.py @@ -64,6 +64,12 @@ def test_forbid_setting_parent_directly(self): ): mary.parent = john + def test_dont_modify_children_inplace(self): + # GH issue 9196 + child = TreeNode() + TreeNode(children={"child": child}) + assert child.parent is None + def test_multi_child_family(self): mary: TreeNode = TreeNode() kate: TreeNode = TreeNode() From 104d80efcee32f8c0377a2e36e779d5373353a8a Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Wed, 11 Sep 2024 13:45:41 -0400 Subject: [PATCH 02/11] move assignment and copying of children to TreeNode constructor --- xarray/core/datatree.py | 8 ++------ xarray/core/treenode.py | 6 ++++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 61c71917008..db8f1f6bae6 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -447,15 +447,11 @@ def __init__( -------- DataTree.from_dict """ - if children is None: - children = {} + # TODO set after setting node data as this will check for name conflicts? 
+ super().__init__(name=name, children=children) - super().__init__(name=name) self._set_node_data(_to_new_dataset(dataset)) - # shallow copy to avoid modifying arguments in-place (see GH issue #9196) - self.children = {name: child.copy() for name, child in children.items()} - def _set_node_data(self, dataset: Dataset): data_vars, coord_vars = _collect_data_and_coord_variables(dataset) self._data_variables = data_vars diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 84ce392ad32..770b6fdcd65 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -78,8 +78,10 @@ def __init__(self, children: Mapping[str, Tree] | None = None): """Create a parentless node.""" self._parent = None self._children = {} - if children is not None: - self.children = children + + if children: + # shallow copy to avoid modifying arguments in-place (see GH issue #9196) + self.children = {name: child.copy() for name, child in children.items()} @property def parent(self) -> Tree | None: From 9e02ae24493eaa8b35c01a5c8b680b0c230a5768 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Wed, 11 Sep 2024 14:29:48 -0400 Subject: [PATCH 03/11] move copy methods over to TreeNode --- xarray/core/datatree.py | 58 ++++------------------------------ xarray/core/treenode.py | 70 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 52 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index db8f1f6bae6..ca4ea5de1ed 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -771,67 +771,21 @@ def _replace_node( self.children = children - def copy( - self: DataTree, - deep: bool = False, - ) -> DataTree: - """ - Returns a copy of this subtree. - - Copies this node and all child nodes. - - If `deep=True`, a deep copy is made of each of the component variables. 
- Otherwise, a shallow copy of each of the component variable is made, so - that the underlying memory region of the new datatree is the same as in - the original datatree. - - Parameters - ---------- - deep : bool, default: False - Whether each component variable is loaded into memory and copied onto - the new object. Default is False. - - Returns - ------- - object : DataTree - New object with dimensions, attributes, coordinates, name, encoding, - and data of this node and all child nodes copied from original. - - See Also - -------- - xarray.Dataset.copy - pandas.DataFrame.copy - """ - return self._copy_subtree(deep=deep) - - def _copy_subtree( - self: DataTree, - deep: bool = False, - memo: dict[int, Any] | None = None, - ) -> DataTree: - """Copy entire subtree""" - new_tree = self._copy_node(deep=deep) - for node in self.descendants: - path = node.relative_to(self) - new_tree[path] = node._copy_node(deep=deep) - return new_tree - def _copy_node( self: DataTree, deep: bool = False, ) -> DataTree: """Copy just one node of a tree""" + + new_node = super()._copy_node() + new_node._name = self.name + data = self._to_dataset_view(rebuild_dims=False, inherited=False) if deep: data = data.copy(deep=True) - new_node = DataTree(data, name=self.name) - return new_node - - def __copy__(self: DataTree) -> DataTree: - return self._copy_subtree(deep=False) + new_node._set_node_data(data) - def __deepcopy__(self: DataTree, memo: dict[int, Any] | None = None) -> DataTree: - return self._copy_subtree(deep=True, memo=memo) + return new_node def get( # type: ignore[override] self: DataTree, key: str, default: DataTree | DataArray | None = None diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 770b6fdcd65..0bc35d4164d 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -5,6 +5,7 @@ from pathlib import PurePosixPath from typing import ( TYPE_CHECKING, + Any, Generic, TypeVar, ) @@ -237,6 +238,75 @@ def _post_attach_children(self: Tree, children: 
Mapping[str, Tree]) -> None: """Method call after attaching `children`.""" pass + def copy( + self: Tree, + deep: bool = False, + ) -> Tree: + """ + Returns a copy of this subtree. + + Copies this node and all child nodes. + + If `deep=True`, a deep copy is made of each of the component variables. + Otherwise, a shallow copy of each of the component variable is made, so + that the underlying memory region of the new datatree is the same as in + the original datatree. + + Parameters + ---------- + deep : bool, default: False + Whether each component variable is loaded into memory and copied onto + the new object. Default is False. + + Returns + ------- + object : DataTree + New object with dimensions, attributes, coordinates, name, encoding, + and data of this node and all child nodes copied from original. + + See Also + -------- + xarray.Dataset.copy + pandas.DataFrame.copy + """ + return self._copy_subtree(deep=deep) + + def _copy_subtree( + self: Tree, + deep: bool = False, + memo: dict[int, Any] | None = None, + ) -> Tree: + """Copy entire subtree""" + new_tree = self._copy_node(deep=deep) + + # TODO if I can write this to use recursion then it might not need to use .relative_to + for node in self.descendants: + # TODO these will currently fail because .relative_to is only defined on NamedNode, not TreeNode + path = node.relative_to(self) + # TODO and __setitem__ is only only defined on DataTree (though ._set_item is defined on TreeNode) + new_tree[path] = node._copy_node(deep=deep) + return new_tree + + def _copy_node( + self: Tree, + deep: bool = False, + ) -> Tree: + """Copy just one node of a tree""" + # TODO is this correct?? 
+ new_empty_node = type(self)() + return new_empty_node + + # TODO could I just do + # new_instance = type(self).__new__(type(self)) + # new_instance.__dict__.update(self.__dict__) + # return new_instance + + def __copy__(self: Tree) -> Tree: + return self._copy_subtree(deep=False) + + def __deepcopy__(self: Tree, memo: dict[int, Any] | None = None) -> Tree: + return self._copy_subtree(deep=True, memo=memo) + def _iter_parents(self: Tree) -> Iterator[Tree]: """Iterate up the tree, starting from the current node's parent.""" node: Tree | None = self.parent From d3629bb0fa3183818ca4631b95493ea9579304d5 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Wed, 11 Sep 2024 14:51:49 -0400 Subject: [PATCH 04/11] change copying behaviour to be in line with #9196 --- xarray/tests/test_treenode.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/xarray/tests/test_treenode.py b/xarray/tests/test_treenode.py index 1baceed9ee5..67e0f0c6542 100644 --- a/xarray/tests/test_treenode.py +++ b/xarray/tests/test_treenode.py @@ -71,17 +71,21 @@ def test_dont_modify_children_inplace(self): assert child.parent is None def test_multi_child_family(self): - mary: TreeNode = TreeNode() - kate: TreeNode = TreeNode() - john: TreeNode = TreeNode(children={"Mary": mary, "Kate": kate}) - assert john.children["Mary"] is mary - assert john.children["Kate"] is kate + john: TreeNode = TreeNode(children={"Mary": TreeNode(), "Kate": TreeNode()}) + + assert "Mary" in john.children + mary = john.children["Mary"] + assert isinstance(mary, TreeNode) assert mary.parent is john + + assert "Kate" in john.children + kate = john.children["Kate"] + assert isinstance(kate, TreeNode) assert kate.parent is john def test_disown_child(self): - mary: TreeNode = TreeNode() - john: TreeNode = TreeNode(children={"Mary": mary}) + john: TreeNode = TreeNode(children={"Mary": TreeNode()}) + mary = john.children["Mary"] mary.orphan() assert mary.parent is None assert "Mary" not in 
john.children @@ -102,12 +106,11 @@ def test_doppelganger_child(self): assert john.children["Kate"] is evil_kate def test_sibling_relationships(self): - mary: TreeNode = TreeNode() - kate: TreeNode = TreeNode() - ashley: TreeNode = TreeNode() - TreeNode(children={"Mary": mary, "Kate": kate, "Ashley": ashley}) - assert kate.siblings["Mary"] is mary - assert kate.siblings["Ashley"] is ashley + john = TreeNode( + children={"Mary": TreeNode(), "Kate": TreeNode(), "Ashley": TreeNode()} + ) + kate = john.children["Kate"] + assert list(kate.siblings) == ["Mary", "Ashley"] assert "Kate" not in kate.siblings def test_ancestors(self): From f8ee2a751de9bee81927b8faf9a453ae001ebc6b Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Wed, 11 Sep 2024 14:52:12 -0400 Subject: [PATCH 05/11] explicitly test that ._copy_subtree works for TreeNode --- xarray/tests/test_treenode.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/xarray/tests/test_treenode.py b/xarray/tests/test_treenode.py index 67e0f0c6542..24f3ee10da0 100644 --- a/xarray/tests/test_treenode.py +++ b/xarray/tests/test_treenode.py @@ -113,6 +113,15 @@ def test_sibling_relationships(self): assert list(kate.siblings) == ["Mary", "Ashley"] assert "Kate" not in kate.siblings + def test_copy_subtree(self): + tony: TreeNode = TreeNode() + michael: TreeNode = TreeNode(children={"Tony": tony}) + vito = TreeNode(children={"Michael": michael}) + + # check that children of assigned children are also copied (i.e. 
that ._copy_subtree works) + copied_tony = vito.children["Michael"].children["Tony"] + assert copied_tony is not tony + def test_ancestors(self): tony: TreeNode = TreeNode() michael: TreeNode = TreeNode(children={"Tony": tony}) From 6231c7ceddf1c96255b35cb0facad3af9763dc5b Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Wed, 11 Sep 2024 15:47:23 -0400 Subject: [PATCH 06/11] reimplement ._copy_subtree using recursion --- xarray/core/treenode.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 0bc35d4164d..07ea0183c0b 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -276,15 +276,13 @@ def _copy_subtree( deep: bool = False, memo: dict[int, Any] | None = None, ) -> Tree: - """Copy entire subtree""" + """Copy entire subtree recursively.""" + new_tree = self._copy_node(deep=deep) + for name, child in self.children.items(): + # TODO use `.children[name] = ...` once #9477 is implemented + new_tree._set(name, child._copy_subtree(deep=deep)) - # TODO if I can write this to use recursion then it might not need to use .relative_to - for node in self.descendants: - # TODO these will currently fail because .relative_to is only defined on NamedNode, not TreeNode - path = node.relative_to(self) - # TODO and __setitem__ is only only defined on DataTree (though ._set_item is defined on TreeNode) - new_tree[path] = node._copy_node(deep=deep) return new_tree def _copy_node( @@ -292,15 +290,15 @@ def _copy_node( deep: bool = False, ) -> Tree: """Copy just one node of a tree""" - # TODO is this correct?? - new_empty_node = type(self)() - return new_empty_node - # TODO could I just do + # TODO could I just do this and then not have to override this class? 
# new_instance = type(self).__new__(type(self)) # new_instance.__dict__.update(self.__dict__) # return new_instance + new_empty_node = type(self)() + return new_empty_node + def __copy__(self: Tree) -> Tree: return self._copy_subtree(deep=False) @@ -691,6 +689,15 @@ def _post_attach(self: AnyNamedNode, parent: AnyNamedNode, name: str) -> None: """Ensures child has name attribute corresponding to key under which it has been stored.""" self.name = name + def _copy_node( + self: NamedNode, + deep: bool = False, + ) -> NamedNode: + """Copy just one node of a tree""" + new_node = super()._copy_node() + new_node._name = self.name + return new_node + @property def path(self) -> str: """Return the file-like path from the root to this node.""" From 5ece6e868ca300161486c518a893a23705925f19 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Wed, 11 Sep 2024 16:52:21 -0400 Subject: [PATCH 07/11] change treenode.py tests to match expected non-in-place behaviour --- xarray/tests/test_treenode.py | 57 ++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/xarray/tests/test_treenode.py b/xarray/tests/test_treenode.py index 24f3ee10da0..0596af07c0c 100644 --- a/xarray/tests/test_treenode.py +++ b/xarray/tests/test_treenode.py @@ -122,21 +122,29 @@ def test_copy_subtree(self): copied_tony = vito.children["Michael"].children["Tony"] assert copied_tony is not tony - def test_ancestors(self): - tony: TreeNode = TreeNode() - michael: TreeNode = TreeNode(children={"Tony": tony}) - vito = TreeNode(children={"Michael": michael}) + def test_parents(self): + vito = TreeNode( + children={"Michael": TreeNode(children={"Tony": TreeNode()})}, + ) + michael = vito.children["Michael"] + tony = michael.children["Tony"] + assert tony.root is vito assert tony.parents == (michael, vito) - assert tony.ancestors == (vito, michael, tony) class TestGetNodes: def test_get_child(self): - steven: TreeNode = TreeNode() - sue = TreeNode(children={"Steven": steven}) - 
mary = TreeNode(children={"Sue": sue}) - john = TreeNode(children={"Mary": mary}) + john = TreeNode( + children={ + "Mary": TreeNode( + children={"Sue": TreeNode(children={"Steven": TreeNode()})} + ) + } + ) + mary = john.children["Mary"] + sue = mary.children["Sue"] + steven = sue.children["Steven"] # get child assert john._get_item("Mary") is mary @@ -156,10 +164,14 @@ def test_get_child(self): assert mary._get_item("Sue/Steven") is steven def test_get_upwards(self): - sue: TreeNode = TreeNode() - kate: TreeNode = TreeNode() - mary = TreeNode(children={"Sue": sue, "Kate": kate}) - john = TreeNode(children={"Mary": mary}) + john = TreeNode( + children={ + "Mary": TreeNode(children={"Sue": TreeNode(), "Kate": TreeNode()}) + } + ) + mary = john.children["Mary"] + sue = mary.children["Sue"] + kate = mary.children["Kate"] assert sue._get_item("../") is mary assert sue._get_item("../../") is john @@ -168,9 +180,9 @@ def test_get_upwards(self): assert sue._get_item("../Kate") is kate def test_get_from_root(self): - sue: TreeNode = TreeNode() - mary = TreeNode(children={"Sue": sue}) - john = TreeNode(children={"Mary": mary}) # noqa + john = TreeNode(children={"Mary": TreeNode(children={"Sue": TreeNode()})}) + mary = john.children["Mary"] + sue = mary.children["Sue"] assert sue._get_item("/Mary") is mary @@ -385,11 +397,14 @@ def test_levels(self): class TestRenderTree: def test_render_nodetree(self): - sam: NamedNode = NamedNode() - ben: NamedNode = NamedNode() - mary: NamedNode = NamedNode(children={"Sam": sam, "Ben": ben}) - kate: NamedNode = NamedNode() - john: NamedNode = NamedNode(children={"Mary": mary, "Kate": kate}) + john: NamedNode = NamedNode( + children={ + "Mary": NamedNode(children={"Sam": NamedNode(), "Ben": NamedNode()}), + "Kate": NamedNode(), + } + ) + mary = john.children["Mary"] + expected_nodes = [ "NamedNode()", "\tNamedNode('Mary')", From ea61a95540e2e9ed59f7498fede2940a13b234a5 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Wed, 11 Sep 2024 
16:55:47 -0400 Subject: [PATCH 08/11] fix bug created in DataTree.__init__ --- xarray/core/datatree.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index ca4ea5de1ed..1404bac1df7 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -447,11 +447,11 @@ def __init__( -------- DataTree.from_dict """ - # TODO set after setting node data as this will check for name conflicts? - super().__init__(name=name, children=children) - self._set_node_data(_to_new_dataset(dataset)) + # comes after setting node data as this will check for clashes between child names and existing variable names + super().__init__(name=name, children=children) + def _set_node_data(self, dataset: Dataset): data_vars, coord_vars = _collect_data_and_coord_variables(dataset) self._data_variables = data_vars From 9c328a02a082b4b6bc1621321d17d9738297157b Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Wed, 11 Sep 2024 17:05:55 -0400 Subject: [PATCH 09/11] add type hints for Generic TreeNode back in --- xarray/tests/test_treenode.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/xarray/tests/test_treenode.py b/xarray/tests/test_treenode.py index 0596af07c0c..22a6a97c3f5 100644 --- a/xarray/tests/test_treenode.py +++ b/xarray/tests/test_treenode.py @@ -66,7 +66,7 @@ def test_forbid_setting_parent_directly(self): def test_dont_modify_children_inplace(self): # GH issue 9196 - child = TreeNode() + child: TreeNode = TreeNode() TreeNode(children={"child": child}) assert child.parent is None @@ -106,7 +106,7 @@ def test_doppelganger_child(self): assert john.children["Kate"] is evil_kate def test_sibling_relationships(self): - john = TreeNode( + john: TreeNode = TreeNode( children={"Mary": TreeNode(), "Kate": TreeNode(), "Ashley": TreeNode()} ) kate = john.children["Kate"] @@ -123,7 +123,7 @@ def test_copy_subtree(self): assert copied_tony is not tony def test_parents(self): - vito = 
TreeNode( + vito: TreeNode = TreeNode( children={"Michael": TreeNode(children={"Tony": TreeNode()})}, ) michael = vito.children["Michael"] @@ -135,7 +135,7 @@ def test_parents(self): class TestGetNodes: def test_get_child(self): - john = TreeNode( + john: TreeNode = TreeNode( children={ "Mary": TreeNode( children={"Sue": TreeNode(children={"Steven": TreeNode()})} @@ -164,7 +164,7 @@ def test_get_child(self): assert mary._get_item("Sue/Steven") is steven def test_get_upwards(self): - john = TreeNode( + john: TreeNode = TreeNode( children={ "Mary": TreeNode(children={"Sue": TreeNode(), "Kate": TreeNode()}) } @@ -180,7 +180,9 @@ def test_get_upwards(self): assert sue._get_item("../Kate") is kate def test_get_from_root(self): - john = TreeNode(children={"Mary": TreeNode(children={"Sue": TreeNode()})}) + john: TreeNode = TreeNode( + children={"Mary": TreeNode(children={"Sue": TreeNode()})} + ) mary = john.children["Mary"] sue = mary.children["Sue"] From b1f66e1c3ceda48998c9a935e30db14fd38f45a9 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Wed, 11 Sep 2024 17:11:11 -0400 Subject: [PATCH 10/11] update typing of ._copy_node --- xarray/core/treenode.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 07ea0183c0b..d74c82178ea 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -290,12 +290,6 @@ def _copy_node( deep: bool = False, ) -> Tree: """Copy just one node of a tree""" - - # TODO could I just do this and then not have to override this class? 
- # new_instance = type(self).__new__(type(self)) - # new_instance.__dict__.update(self.__dict__) - # return new_instance - new_empty_node = type(self)() return new_empty_node @@ -690,9 +684,9 @@ def _post_attach(self: AnyNamedNode, parent: AnyNamedNode, name: str) -> None: self.name = name def _copy_node( - self: NamedNode, + self: AnyNamedNode, deep: bool = False, - ) -> NamedNode: + ) -> AnyNamedNode: """Copy just one node of a tree""" new_node = super()._copy_node() new_node._name = self.name From 3a1f372902fc0a522aef937c85cae7ebcb4f5804 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Thu, 12 Sep 2024 14:55:54 -0400 Subject: [PATCH 11/11] remove redundant setting of _name --- xarray/core/datatree.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 1404bac1df7..5715dca486f 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -778,7 +778,6 @@ def _copy_node( """Copy just one node of a tree""" new_node = super()._copy_node() - new_node._name = self.name data = self._to_dataset_view(rebuild_dims=False, inherited=False) if deep: