Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,12 @@ class WrapModule(TensorDictModuleBase):
Keyword Args:
inplace (bool, optional): If ``True``, the input TensorDict will be modified in-place. Otherwise, a new TensorDict
will be returned (if the function does not modify it in-place and returns it). Defaults to ``False``.
in_keys (list of NestedKey, optional): if provided, indicates what entries are read by the module.
This will not be checked and is provided just for the purpose of informing :class:`~tensordict.nn.TensorDictSequential`
about the input keys of the wrapped module. Defaults to `[]`.
out_keys (list of NestedKey, optional): if provided, indicates what entries are written by the module.
This will not be checked and is provided just for the purpose of informing :class:`~tensordict.nn.TensorDictSequential`
about the output keys of the wrapped module. Defaults to `[]`.

Examples:
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod, WrapModule
Expand All @@ -1320,11 +1326,20 @@ class WrapModule(TensorDictModuleBase):
out_keys = []

def __init__(
self, func: Callable[[TensorDictBase], TensorDictBase], *, inplace: bool = False
self,
func: Callable[[TensorDictBase], TensorDictBase],
*,
inplace: bool = False,
in_keys: List[NestedKey] | None = None,
out_keys: List[NestedKey] | None = None,
) -> None:
super().__init__()
self.func = func
self.inplace = inplace
if in_keys is not None:
self.in_keys = in_keys
if out_keys is not None:
self.out_keys = out_keys

def forward(self, data: TensorDictBase) -> TensorDictBase:
result = self.func(data)
Expand Down
Loading