From 98c0e766000975beb084d33f4a5b27284b31bf21 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 29 May 2023 21:03:59 +0100 Subject: [PATCH 1/3] fix --- tensordict/tensordict.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index d3e6d35f2..8cef0448d 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -5805,6 +5805,41 @@ def entry_class(self, key: NestedKey) -> type: return LazyStackedTensorDict return data_type + def apply_(self, fn: Callable, *others): + if len(others): + raise NotImplementedError( + "LazyStackedTensorDict.apply_(*other) is not implemented yet." + ) + for td in self.tensordicts: + td.apply_(fn) + return self + + def apply( + self, + fn: Callable, + *others: TensorDictBase, + batch_size: Sequence[int] | None = None, + device: torch.device | None = None, + names: Sequence[str] | None = None, + inplace: bool = False, + **constructor_kwargs, + ) -> TensorDictBase: + if inplace: + if any(arg for arg in (batch_size, device, names, constructor_kwargs)): + raise ValueError( + "Cannot pass other arguments to LazyStackedTensorDict.apply when inplace=True." + ) + return self.apply_(fn, *others) + else: + return super().apply( + fn, + *others, + batch_size=batch_size, + device=device, + names=names, + **constructor_kwargs, + ) + def select( self, *keys: str, inplace: bool = False, strict: bool = False ) -> LazyStackedTensorDict: From cbe1e914cd69c7d8cfa5f53e9502fab5fbbe4df0 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 2 Jun 2023 18:09:48 +0100 Subject: [PATCH 2/3] amend --- tensordict/tensordict.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 5a5ad538e..3ecdddccf 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -5845,12 +5845,9 @@ def entry_class(self, key: NestedKey) -> type: return data_type def apply_(self, fn: Callable, *others): - if len(others): - raise NotImplementedError( - "LazyStackedTensorDict.apply_(*other) is not implemented yet." - ) - for td in self.tensordicts: - td.apply_(fn) + for i, td in enumerate(self.tensordicts): + idx = (slice(None),) * self.stack_dim + (i,) + td.apply_(fn, *[other[idx] for other in others]) return self def apply( From 8dab80297059053c149c9a1e1ee34271289508e3 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 2 Jun 2023 18:34:28 +0100 Subject: [PATCH 3/3] amend --- tensordict/tensordict.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 3ecdddccf..d41fc6143 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -1249,18 +1249,20 @@ def pop( ) from err return out - def apply_(self, fn: Callable) -> TensorDictBase: + def apply_(self, fn: Callable, *others) -> TensorDictBase: """Applies a callable to all values stored in the tensordict and re-writes them in-place. Args: fn (Callable): function to be applied to the tensors in the tensordict. + *others (sequence of TensorDictBase, optional): the other + tensordicts to be used. Returns: self or a copy of self with the function applied """ - return self.apply(fn, inplace=True) + return self.apply(fn, *others, inplace=True) def apply( self,