41 changes: 39 additions & 2 deletions test/test_transforms.py
@@ -32,11 +32,14 @@
IncrementingEnv,
MockBatchedLockedEnv,
MockBatchedUnLockedEnv,
MultiKeyCountingEnv,
MultiKeyCountingEnvPolicy,
NestedCountingEnv,
)
from tensordict import unravel_key
from tensordict.nn import TensorDictSequential
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import _unravel_key_to_tuple
from torch import multiprocessing as mp, nn, Tensor
from torchrl._utils import prod
from torchrl.data import (
@@ -104,7 +107,7 @@
from torchrl.envs.transforms.transforms import _has_tv
from torchrl.envs.transforms.vc1 import _has_vc
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
from torchrl.envs.utils import check_env_specs, step_mdp
from torchrl.envs.utils import _replace_last, check_env_specs, step_mdp
from torchrl.modules import LSTMModule, MLP, ProbabilisticActor, TanhNormal

TIMEOUT = 100.0
@@ -4527,6 +4530,36 @@ def test_trans_parallel_env_check(self):
r = env.rollout(4)
assert r["next", "episode_reward"].unique().numel() > 1

@pytest.mark.parametrize("has_in_keys,", [True, False])
def test_trans_multi_key(
self, has_in_keys, n_workers=2, batch_size=(3, 2), max_steps=5
):
torch.manual_seed(0)
env_fun = lambda: MultiKeyCountingEnv(batch_size=batch_size)
base_env = SerialEnv(n_workers, env_fun)
if has_in_keys:
t = RewardSum(in_keys=base_env.reward_keys, reset_keys=base_env.reset_keys)
else:
t = RewardSum()
env = TransformedEnv(
base_env,
Compose(t),
)
policy = MultiKeyCountingEnvPolicy(
full_action_spec=env.action_spec, deterministic=True
)

check_env_specs(env)
td = env.rollout(max_steps, policy=policy)
for reward_key in env.reward_keys:
reward_key = _unravel_key_to_tuple(reward_key)
assert (
td.get(
("next", _replace_last(reward_key, f"episode_{reward_key[-1]}"))
)[(0,) * (len(batch_size) + 1)][-1]
== max_steps
).all()

@pytest.mark.parametrize("in_key", ["reward", ("some", "nested")])
def test_transform_no_env(self, in_key):
t = RewardSum(in_keys=[in_key], out_keys=[("some", "nested_sum")])
@@ -4550,7 +4583,8 @@ def test_transform_no_env(self, in_key):
def test_transform_compose(
self,
):
t = Compose(RewardSum())
# reset keys should not be needed for offline run
t = Compose(RewardSum(in_keys=["reward"], out_keys=["episode_reward"]))
reward = torch.randn(10)
td = TensorDict({("next", "reward"): reward}, [])
with pytest.raises(
@@ -4649,6 +4683,9 @@ def test_sum_reward(self, keys, device):

# reset environments
td.set("_reset", torch.ones(batch, dtype=torch.bool, device=device))
with pytest.raises(TypeError, match="reset_keys not provided but parent"):
rs.reset(td)
rs._reset_keys = ["_reset"]
rs.reset(td)

# apply a third time, episode_reward should be equal to reward again
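For reference, the assertion in test_trans_multi_key maps each nested reward key to its episode counterpart using two tensordict helpers. Below is a minimal standalone sketch of that mapping; the helper names and the example key are hypothetical, not the library code.

from typing import Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]

def unravel_to_tuple(key: NestedKey) -> Tuple[str, ...]:
    # Normalize a key to tuple form: "reward" -> ("reward",)
    return (key,) if isinstance(key, str) else tuple(key)

def replace_last(key: NestedKey, new_leaf: str) -> Tuple[str, ...]:
    # Swap the leaf of a nested key while keeping its prefix:
    # ("agents", "reward") -> ("agents", "episode_reward")
    tup = unravel_to_tuple(key)
    return tup[:-1] + (new_leaf,)

reward_key = ("agents", "reward")
episode_key = replace_last(reward_key, f"episode_{unravel_to_tuple(reward_key)[-1]}")
assert episode_key == ("agents", "episode_reward")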
230 changes: 134 additions & 96 deletions torchrl/envs/transforms/transforms.py
@@ -40,7 +40,7 @@
from torchrl.envs.common import _EnvPostInit, EnvBase, make_tensordict
from torchrl.envs.transforms import functional as F
from torchrl.envs.transforms.utils import check_finite
from torchrl.envs.utils import _sort_keys, step_mdp
from torchrl.envs.utils import _replace_last, _sort_keys, step_mdp
from torchrl.objectives.value.functional import reward2go

try:
@@ -242,7 +242,7 @@ def _apply_transform(self, obs: torch.Tensor) -> None:

"""
raise NotImplementedError(
f"{self.__class__.__name__}_apply_transform is not coded. If the transform is coded in "
f"{self.__class__.__name__}._apply_transform is not coded. If the transform is coded in "
"transform._call, make sure that this method is called instead of"
"transform.forward, which is reserved for usage inside nn.Modules"
"or appended to a replay buffer."
@@ -4342,74 +4342,140 @@ class RewardSum(Transform):
"""Tracks episode cumulative rewards.

This transform accepts a list of tensordict reward keys (i.e. ´in_keys´) and tracks their cumulative
value along each episode. When called, the transform creates a new tensordict key for each in_key named
´episode_{in_key}´ where the cumulative values are written. All ´in_keys´ should be part of the env
reward and be present in the env reward_spec.
value along the time dimension for each episode.

If no in_keys are specified, this transform assumes ´reward´ to be the input key. However, multiple rewards
(e.g. reward1 and reward2) can also be specified. If ´in_keys´ are not present in the provided tensordict,
this transform has no effect.
When called, the transform writes a new tensordict entry for each ``in_key`` named
``episode_{in_key}`` where the cumulative values are written.

.. note:: :class:`~RewardSum` currently only supports a ``"done"`` signal at the root.
Nested ``"done"`` signals, such as those found in MARL settings, are currently not supported.
If this feature is needed, please raise an issue on the TorchRL repo.
Args:
in_keys (list of NestedKeys, optional): Input reward keys.
All ´in_keys´ should be part of the environment reward_spec.
If no ``in_keys`` are specified, this transform assumes ``"reward"`` to be the input key.
However, multiple rewards (e.g. ``"reward1"`` and ``"reward2"``) can also be specified.
out_keys (list of NestedKeys, optional): The output sum keys, one for each input key.
reset_keys (list of NestedKeys, optional): the list of reset_keys to be
    used, if the parent environment cannot be found. If provided, this
    value will prevail over the environment ``reset_keys``.

Contributor Author: Here I preferred having done keys rather than reset keys, because users are familiar with what a done key is and might not know about reset keys. Plus there is a 1:1 matching between the two.

Collaborator: That is not the way it was done: if I pass env.done_keys there are some duplicates (e.g. truncation / termination / done). Having a "done_keys" list is more dangerous because of this, and ultimately the only thing we are pointing to is the tree structure where the reset_keys should be found. I personally prefer passing reset_keys: it is what is needed, i.e. there is a lower risk that refactoring the done_keys mechanism in the future will break this transform. Per se, asking users to pass a list X when we interpolate a list Y that is already present within the env as env.Y seems a convoluted way of doing things.

Collaborator: Another thought about this: per se, most users won't need to pass reset keys. We just support it if someone really wants to do nasty things like summing part of the rewards but not all, etc. Advanced usage requires advanced understanding, so it's fine to ask for reset_keys even if this isn't something that is always user-facing.

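A usage note on the thread above: when the transform has no parent env to infer keys from, all three lists can be passed explicitly. A minimal sketch, with made-up key names (one reset key per reward key):

from torchrl.envs.transforms import RewardSum

t = RewardSum(
    in_keys=[("agents", "reward")],           # rewards to accumulate
    out_keys=[("agents", "episode_reward")],  # where the sums are written
    reset_keys=[("agents", "_reset")],        # masks that zero the sums
)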
Examples:
>>> from torchrl.envs.transforms import RewardSum, TransformedEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), RewardSum())
>>> td = env.reset()
>>> print(td["episode_reward"])
tensor([0.])
>>> td = env.rollout(3)
>>> print(td["next", "episode_reward"])
tensor([[-0.5926],
[-1.4578],
[-2.7885]])
"""

def __init__(
self,
in_keys: Optional[Sequence[NestedKey]] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
reset_keys: Optional[Sequence[NestedKey]] = None,
):
"""Initialises the transform. Filters out non-reward input keys and defines output keys."""
if in_keys is None:
in_keys = ["reward"]
if out_keys is None and in_keys == ["reward"]:
out_keys = ["episode_reward"]
elif out_keys is None:
raise RuntimeError(
"the out_keys must be specified for non-conventional in-keys in RewardSum."
super().__init__(in_keys=in_keys, out_keys=out_keys)
self._reset_keys = reset_keys

@property
def in_keys(self):
in_keys = self.__dict__.get("_in_keys", None)
if in_keys in (None, []):
# retrieve rewards from parent env
parent = self.parent
if parent is None:
in_keys = ["reward"]
else:
in_keys = copy(parent.reward_keys)
self._in_keys = in_keys
return in_keys

@in_keys.setter
def in_keys(self, value):
if value is not None:
if isinstance(value, (str, tuple)):
value = [value]
value = [unravel_key(val) for val in value]
self._in_keys = value

@property
def out_keys(self):
out_keys = self.__dict__.get("_out_keys", None)
if out_keys in (None, []):
out_keys = [
_replace_last(in_key, f"episode_{_unravel_key_to_tuple(in_key)[-1]}")
for in_key in self.in_keys
]
self._out_keys = out_keys
return out_keys

@out_keys.setter
def out_keys(self, value):
# we must access the private attribute because this check occurs before
# the parent env is defined
if value is not None and len(self._in_keys) != len(value):
raise ValueError(
"RewardSum expects the same number of input and output keys"
)
if value is not None:
if isinstance(value, (str, tuple)):
value = [value]
value = [unravel_key(val) for val in value]
self._out_keys = value

super().__init__(in_keys=in_keys, out_keys=out_keys)
@property
def reset_keys(self):
reset_keys = self.__dict__.get("_reset_keys", None)
if reset_keys is None:
parent = self.parent
if parent is None:
raise TypeError(
"reset_keys not provided but parent env not found. "
"Make sure that the reset_keys are provided during "
"construction if the transform does not have a container env."
)
reset_keys = copy(parent.reset_keys)
self._reset_keys = reset_keys
return reset_keys

@reset_keys.setter
def reset_keys(self, value):
if value is not None:
if isinstance(value, (str, tuple)):
value = [value]
value = [unravel_key(val) for val in value]
self._reset_keys = value
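The in_keys, out_keys and reset_keys properties above share one pattern: an explicitly passed value takes precedence; otherwise the value is resolved from the parent env on first access and cached. A self-contained sketch of that pattern (illustrative class, not the library code):

class LazyResetKeys:
    def __init__(self, reset_keys=None, parent=None):
        self._reset_keys = reset_keys  # explicit override, may be None
        self.parent = parent           # container that can supply defaults

    @property
    def reset_keys(self):
        if self._reset_keys is None:
            if self.parent is None:
                raise TypeError("reset_keys not provided and no parent to infer them from")
            self._reset_keys = list(self.parent.reset_keys)  # cache the inferred value
        return self._reset_keys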

def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Resets episode rewards."""
# Non-batched environments
_reset = tensordict.get("_reset", None)
if _reset is None:
_reset = torch.ones(
self.parent.done_spec.shape if self.parent else tensordict.batch_size,
dtype=torch.bool,
device=tensordict.device,
)
for in_key, reset_key, out_key in zip(
self.in_keys, self.reset_keys, self.out_keys
):
_reset = tensordict.get(reset_key, None)

if _reset.any():
_reset = _reset.sum(
tuple(range(tensordict.batch_dims, _reset.ndim)), dtype=torch.bool
)
reward_key = self.parent.reward_key if self.parent else "reward"
for in_key, out_key in zip(self.in_keys, self.out_keys):
if out_key in tensordict.keys(True, True):
value = tensordict[out_key]
tensordict[out_key] = value.masked_fill(
expand_as_right(_reset, value), 0.0
)
elif unravel_key(in_key) == unravel_key(reward_key):
if _reset is None or _reset.any():
value = tensordict.get(out_key, default=None)
if value is not None:
if _reset is None:
tensordict.set(out_key, torch.zeros_like(value))
else:
tensordict.set(
out_key,
value.masked_fill(
expand_as_right(_reset.squeeze(-1), value), 0.0
),
)
else:
# Since the episode reward is not in the tensordict, we need to allocate it
# with zeros entirely (regardless of the _reset mask)
tensordict[out_key] = self.parent.reward_spec.zero()
else:
try:
tensordict[out_key] = self.parent.observation_spec[
in_key
].zero()
except KeyError as err:
raise KeyError(
f"The key {in_key} was not found in the parent "
f"observation_spec with keys "
f"{list(self.parent.observation_spec.keys(True))}. "
) from err
tensordict.set(
out_key,
self.parent.full_reward_spec[in_key].zero(),
)
return tensordict
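In reset above, only the sub-environments flagged by a reset mask get their running sums zeroed; the others keep accumulating. A self-contained sketch of the same masking idea (shapes and names are illustrative):

import torch

def zero_where_reset(value: torch.Tensor, reset: torch.Tensor) -> torch.Tensor:
    # Unsqueeze the boolean mask on the right until it broadcasts
    # against value, then zero the masked entries (expand_as_right-style).
    while reset.ndim < value.ndim:
        reset = reset.unsqueeze(-1)
    return value.masked_fill(reset.expand_as(value), 0.0)

episode_reward = torch.ones(4, 1)                 # running sums for 4 envs
reset = torch.tensor([True, False, True, False])  # envs 0 and 2 reset
print(zero_where_reset(episode_reward, reset))
# tensor([[0.], [1.], [0.], [1.]])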

def _step(
@@ -4430,76 +4496,48 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
state_spec = input_spec["full_state_spec"]
if state_spec is None:
state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device)
reward_spec = self.parent.output_spec["full_reward_spec"]
reward_spec_keys = list(reward_spec.keys(True, True))
state_spec.update(self._generate_episode_reward_spec())
input_spec["full_state_spec"] = state_spec
return input_spec

def _generate_episode_reward_spec(self) -> CompositeSpec:
episode_reward_spec = CompositeSpec()
reward_spec = self.parent.full_reward_spec
reward_spec_keys = self.parent.reward_keys
# Define episode specs for all out_keys
for in_key, out_key in zip(self.in_keys, self.out_keys):
if (
in_key in reward_spec_keys
): # if this out_key has a corresponding key in reward_spec
out_key = _unravel_key_to_tuple(out_key)
temp_state_spec = state_spec
temp_episode_reward_spec = episode_reward_spec
temp_rew_spec = reward_spec
for sub_key in out_key[:-1]:
if (
not isinstance(temp_rew_spec, CompositeSpec)
or sub_key not in temp_rew_spec.keys()
):
break
if sub_key not in temp_state_spec.keys():
temp_state_spec[sub_key] = temp_rew_spec[sub_key].empty()
if sub_key not in temp_episode_reward_spec.keys():
temp_episode_reward_spec[sub_key] = temp_rew_spec[
sub_key
].empty()
temp_rew_spec = temp_rew_spec[sub_key]
temp_state_spec = temp_state_spec[sub_key]
state_spec[out_key] = reward_spec[in_key].clone()
temp_episode_reward_spec = temp_episode_reward_spec[sub_key]
episode_reward_spec[out_key] = reward_spec[in_key].clone()
else:
raise ValueError(
f"The in_key: {in_key} is not present in the reward spec {reward_spec}."
)
input_spec["full_state_spec"] = state_spec
return input_spec
return episode_reward_spec

def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
"""Transforms the observation spec, adding the new keys generated by RewardSum."""
# Retrieve parent reward spec
reward_spec = self.parent.reward_spec
reward_key = self.parent.reward_key if self.parent else "reward"

episode_specs = {}
if isinstance(reward_spec, CompositeSpec):
# If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec
if not all(k in reward_spec.keys(True, True) for k in self.in_keys):
raise KeyError("Not all in_keys are present in ´reward_spec´")

# Define episode specs for all out_keys
for out_key in self.out_keys:
episode_spec = UnboundedContinuousTensorSpec(
shape=reward_spec.shape,
device=reward_spec.device,
dtype=reward_spec.dtype,
)
episode_specs.update({out_key: episode_spec})

else:
# If reward_spec is not a CompositeSpec, the only in_key should be ´reward´
if set(unravel_key_list(self.in_keys)) != {unravel_key(reward_key)}:
raise KeyError(
"reward_spec is not a CompositeSpec class, in_keys should only include ´reward´"
)

# Define episode spec
episode_spec = UnboundedContinuousTensorSpec(
device=reward_spec.device,
dtype=reward_spec.dtype,
shape=reward_spec.shape,
)
episode_specs.update({self.out_keys[0]: episode_spec})

# Update observation_spec with episode_specs
if not isinstance(observation_spec, CompositeSpec):
observation_spec = CompositeSpec(
observation=observation_spec, shape=self.parent.batch_size
)
observation_spec.update(episode_specs)
observation_spec.update(self._generate_episode_reward_spec())
return observation_spec

def forward(self, tensordict: TensorDictBase) -> TensorDictBase: