diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index a2c1962f7..d41fc6143 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -1249,18 +1249,20 @@ def pop( ) from err return out - def apply_(self, fn: Callable) -> TensorDictBase: + def apply_(self, fn: Callable, *others) -> TensorDictBase: """Applies a callable to all values stored in the tensordict and re-writes them in-place. Args: fn (Callable): function to be applied to the tensors in the tensordict. + *others (sequence of TensorDictBase, optional): the other + tensordicts to be used. Returns: self or a copy of self with the function applied """ - return self.apply(fn, inplace=True) + return self.apply(fn, *others, inplace=True) def apply( self, @@ -5844,6 +5846,38 @@ def entry_class(self, key: NestedKey) -> type: return LazyStackedTensorDict return data_type + def apply_(self, fn: Callable, *others): + for i, td in enumerate(self.tensordicts): + idx = (slice(None),) * self.stack_dim + (i,) + td.apply_(fn, *[other[idx] for other in others]) + return self + + def apply( + self, + fn: Callable, + *others: TensorDictBase, + batch_size: Sequence[int] | None = None, + device: torch.device | None = None, + names: Sequence[str] | None = None, + inplace: bool = False, + **constructor_kwargs, + ) -> TensorDictBase: + if inplace: + if any(arg for arg in (batch_size, device, names, constructor_kwargs)): + raise ValueError( + "Cannot pass other arguments to LazyStackedTensorDict.apply when inplace=True." + ) + return self.apply_(fn, *others) + else: + return super().apply( + fn, + *others, + batch_size=batch_size, + device=device, + names=names, + **constructor_kwargs, + ) + def select( self, *keys: str, inplace: bool = False, strict: bool = False ) -> LazyStackedTensorDict: