diff --git a/datatree/datatree.py b/datatree/datatree.py
index 606c7935..0696c90b 100644
--- a/datatree/datatree.py
+++ b/datatree/datatree.py
@@ -628,6 +628,76 @@ def _replace(
         )
         return obj
 
+    def copy(
+        self: DataTree,
+        deep: bool = False,
+    ) -> DataTree:
+        """
+        Returns a copy of this subtree.
+
+        Copies this node and all child nodes.
+
+        If `deep=True`, a deep copy is made of each of the component variables.
+        Otherwise, a shallow copy of each of the component variables is made, so
+        that the underlying memory region of the new datatree is the same as in
+        the original datatree.
+
+        Parameters
+        ----------
+        deep : bool, default: False
+            Whether each component variable is loaded into memory and copied onto
+            the new object. Default is False.
+
+        Returns
+        -------
+        object : DataTree
+            New object with dimensions, attributes, coordinates, name, encoding,
+            and data of this node and all child nodes copied from original.
+
+        See Also
+        --------
+        xarray.Dataset.copy
+        pandas.DataFrame.copy
+        """
+        return self._copy_subtree(deep=deep)
+
+    def _copy_subtree(
+        self: DataTree,
+        deep: bool = False,
+        memo: dict[int, Any] | None = None,
+    ) -> DataTree:
+        """
+        Copy this node and all of its descendants into a new, detached tree.
+
+        ``memo`` is accepted so that ``__deepcopy__`` can forward it, but it is
+        not consulted: every node is copied via ``_copy_node``.
+        """
+        new_tree = self._copy_node(deep=deep)
+        # NOTE(review): assumes ``.path`` is a "/"-separated string measured from
+        # the root of the original tree. ``new_tree`` is itself a root, so strip
+        # this node's own path to place each descendant relative to the copy.
+        prefix = self.path
+        for node in self.descendants:
+            relative_path = node.path[len(prefix) :].lstrip("/")
+            new_tree[relative_path] = node._copy_node(deep=deep)
+        return new_tree
+
+    def _copy_node(
+        self: DataTree,
+        deep: bool = False,
+    ) -> DataTree:
+        """Copy just one node of a tree"""
+        new_node: DataTree = DataTree()
+        new_node.name = self.name
+        new_node.ds = self.to_dataset().copy(deep=deep)
+        return new_node
+
+    def __copy__(self: DataTree) -> DataTree:
+        return self._copy_subtree(deep=False)
+
+    def __deepcopy__(self: DataTree, memo: dict[int, Any] | None = None) -> DataTree:
+        return self._copy_subtree(deep=True, memo=memo)
+
     def get(
         self: DataTree, key: str, default: Optional[DataTree | DataArray] = None
     ) -> Optional[DataTree | DataArray]:
@@ -694,8 +764,11 @@ def _set(self, key: str, val: DataTree | CoercibleValue) -> None:
         Counterpart to the public .get method, and also only works on the immediate node, not other nodes in the tree.
         """
         if isinstance(val, DataTree):
-            val.name = key
-            val.parent = self
+            # TODO shallow copy here so as not to alter name of node in original tree?
+            # new_node = val.copy(deep=False)
+            new_node = val
+            new_node.name = key
+            new_node.parent = self
         else:
             if not isinstance(val, (DataArray, Variable)):
                 # accommodate other types that can be coerced into Variables
@@ -792,8 +865,7 @@ def from_dict(
                 # Create and set new node
                 node_name = NodePath(path).name
                 if isinstance(data, cls):
-                    # TODO ignoring type error only needed whilst .copy() method is copied from Dataset.copy().
-                    new_node = data.copy()  # type: ignore[attr-defined]
+                    new_node = data.copy()
                     new_node.orphan()
                 else:
                     new_node = cls(name=node_name, data=data)
diff --git a/datatree/ops.py b/datatree/ops.py
index bdc931c9..eabc1faf 100644
--- a/datatree/ops.py
+++ b/datatree/ops.py
@@ -31,9 +31,6 @@
 ]
 _DATASET_METHODS_TO_MAP = [
     "as_numpy",
-    "copy",
-    "__copy__",
-    "__deepcopy__",
     "set_coords",
     "reset_coords",
     "info",
diff --git a/docs/source/whats-new.rst b/docs/source/whats-new.rst
index 5e86ba1c..e13a7143 100644
--- a/docs/source/whats-new.rst
+++ b/docs/source/whats-new.rst
@@ -35,6 +35,9 @@ New Features
 Breaking changes
 ~~~~~~~~~~~~~~~~
 
+- :py:meth:`DataTree.copy` now only copies the subtree, not the parent nodes (:pull:`171`).
+  By `Tom Nicholas <https://github.com/TomNicholas>`_.
+
 Deprecations
 ~~~~~~~~~~~~
 