[BugFix] RewardSum transform for multiple reward keys #1544
Changes from all commits
@@ -40,7 +40,7 @@
from torchrl.envs.common import _EnvPostInit, EnvBase, make_tensordict
from torchrl.envs.transforms import functional as F
from torchrl.envs.transforms.utils import check_finite
from torchrl.envs.utils import _sort_keys, step_mdp
from torchrl.envs.utils import _replace_last, _sort_keys, step_mdp
from torchrl.objectives.value.functional import reward2go

try:
@@ -242,7 +242,7 @@ def _apply_transform(self, obs: torch.Tensor) -> None:

        """
        raise NotImplementedError(
            f"{self.__class__.__name__}_apply_transform is not coded. If the transform is coded in "
            f"{self.__class__.__name__}._apply_transform is not coded. If the transform is coded in "
            "transform._call, make sure that this method is called instead of"
            "transform.forward, which is reserved for usage inside nn.Modules"
            "or appended to a replay buffer."
@@ -4342,74 +4342,140 @@ class RewardSum(Transform):
    """Tracks episode cumulative rewards.

    This transform accepts a list of tensordict reward keys (i.e. ´in_keys´) and tracks their cumulative
    value along each episode. When called, the transform creates a new tensordict key for each in_key named
    ´episode_{in_key}´ where the cumulative values are written. All ´in_keys´ should be part of the env
    reward and be present in the env reward_spec.
    value along the time dimension for each episode.

    If no in_keys are specified, this transform assumes ´reward´ to be the input key. However, multiple rewards
    (e.g. reward1 and reward2) can also be specified. If ´in_keys´ are not present in the provided tensordict,
    this transform hos no effect.
    When called, the transform writes a new tensordict entry for each ``in_key`` named
    ``episode_{in_key}`` where the cumulative values are written.

    .. note:: :class:`~RewardSum` currently only supports ``"done"`` signal at the root.
        Nested ``"done"``, such as those found in MARL settings, are currently not supported.
        If this feature is needed, please raise an issue on TorchRL repo.
    Args:
        in_keys (list of NestedKeys, optional): Input reward keys.
            All ´in_keys´ should be part of the environment reward_spec.
            If no ``in_keys`` are specified, this transform assumes ``"reward"`` to be the input key.
            However, multiple rewards (e.g. ``"reward1"`` and ``"reward2"``) can also be specified.
        out_keys (list of NestedKeys, optional): The output sum keys, should be one per each input key.
        reset_keys (list of NestedKeys, optional): the list of reset_keys to be
            used, if the parent environment cannot be found. If provided, this
            value will prevail over the environment ``reset_keys``.

    Review comments on the ``reset_keys`` argument:
    - "Here I preferred having done keys rather than reset keys; this is because users are familiar with what a done key is and might not know about reset keys. Plus there is a 1:1 matching between the two."
    - "Not the way it was done: if I pass [...]"
    - "Another thought about this: per se, most users won't need to pass reset keys. We just support it if someone really wants to do nasty things like summing part of the rewards but not all, etc. Advanced usage requires advanced understanding, so it's fine to ask for reset_keys even if this isn't something that is always user-facing."

    Examples:
        >>> from torchrl.envs.transforms import RewardSum, TransformedEnv
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> env = TransformedEnv(GymEnv("Pendulum-v1"), RewardSum())
        >>> td = env.reset()
        >>> print(td["episode_reward"])
        tensor([0.])
        >>> td = env.rollout(3)
        >>> print(td["next", "episode_reward"])
        tensor([[-0.5926],
                [-1.4578],
                [-2.7885]])
    """
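For reference, a minimal sketch of the multi-reward use case this change targets. base_env and the reward keys "reward1"/"reward2" are hypothetical placeholders for an environment whose reward_spec actually exposes both entries:

    from torchrl.envs import TransformedEnv
    from torchrl.envs.transforms import RewardSum

    # base_env (hypothetical) is assumed to expose "reward1" and "reward2" in its reward_spec
    env = TransformedEnv(
        base_env,
        RewardSum(
            in_keys=["reward1", "reward2"],
            out_keys=["episode_reward1", "episode_reward2"],
        ),
    )
    td = env.rollout(3)
    # each step carries the running sum of the corresponding reward
    print(td["next", "episode_reward1"], td["next", "episode_reward2"])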

    def __init__(
        self,
        in_keys: Optional[Sequence[NestedKey]] = None,
        out_keys: Optional[Sequence[NestedKey]] = None,
        reset_keys: Optional[Sequence[NestedKey]] = None,
    ):
        """Initialises the transform. Filters out non-reward input keys and defines output keys."""
        if in_keys is None:
            in_keys = ["reward"]
        if out_keys is None and in_keys == ["reward"]:
            out_keys = ["episode_reward"]
        elif out_keys is None:
            raise RuntimeError(
                "the out_keys must be specified for non-conventional in-keys in RewardSum."
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self._reset_keys = reset_keys

    @property
    def in_keys(self):
        in_keys = self.__dict__.get("_in_keys", None)
        if in_keys in (None, []):
            # retrieve rewards from parent env
            parent = self.parent
            if parent is None:
                in_keys = ["reward"]
            else:
                in_keys = copy(parent.reward_keys)
            self._in_keys = in_keys
        return in_keys

    @in_keys.setter
    def in_keys(self, value):
        if value is not None:
            if isinstance(value, (str, tuple)):
                value = [value]
            value = [unravel_key(val) for val in value]
        self._in_keys = value

    @property
    def out_keys(self):
        out_keys = self.__dict__.get("_out_keys", None)
        if out_keys in (None, []):
            out_keys = [
                _replace_last(in_key, f"episode_{_unravel_key_to_tuple(in_key)[-1]}")
                for in_key in self.in_keys
            ]
            self._out_keys = out_keys
        return out_keys
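The default naming rule applied by this getter can be mirrored without the private helpers; a small sketch (default_episode_key is a hypothetical name used only for illustration, assuming tensordict's unravel_key helper):

    from tensordict.utils import unravel_key

    def default_episode_key(in_key):
        # keep the prefix of a nested key, rename the leaf to "episode_<leaf>"
        key = unravel_key(in_key)
        if isinstance(key, str):
            return f"episode_{key}"
        return key[:-1] + (f"episode_{key[-1]}",)

    assert default_episode_key("reward") == "episode_reward"
    assert default_episode_key(("agents", "reward")) == ("agents", "episode_reward")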

    @out_keys.setter
    def out_keys(self, value):
        # we must access the private attribute because this check occurs before
        # the parent env is defined
        if value is not None and len(self._in_keys) != len(value):
            raise ValueError(
                "RewardSum expects the same number of input and output keys"
            )
        if value is not None:
            if isinstance(value, (str, tuple)):
                value = [value]
            value = [unravel_key(val) for val in value]
        self._out_keys = value

        super().__init__(in_keys=in_keys, out_keys=out_keys)

    @property
    def reset_keys(self):
        reset_keys = self.__dict__.get("_reset_keys", None)
        if reset_keys is None:
            parent = self.parent
            if parent is None:
                raise TypeError(
                    "reset_keys not provided but parent env not found. "
                    "Make sure that the reset_keys are provided during "
                    "construction if the transform does not have a container env."
                )
            reset_keys = copy(parent.reset_keys)
            self._reset_keys = reset_keys
        return reset_keys

    @reset_keys.setter
    def reset_keys(self, value):
        if value is not None:
            if isinstance(value, (str, tuple)):
                value = [value]
            value = [unravel_key(val) for val in value]
        self._reset_keys = value
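Tying this to the review discussion above: most users can rely on the parent env's reset_keys being picked up automatically, but a standalone transform (no container env) has to receive them at construction. A sketch with hypothetical nested keys:

    from torchrl.envs.transforms import RewardSum

    # no parent env here, so reset keys must be supplied explicitly
    t = RewardSum(
        in_keys=[("agents", "reward")],
        out_keys=[("agents", "episode_reward")],
        reset_keys=[("agents", "_reset")],
    )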

    def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Resets episode rewards."""
        # Non-batched environments
        _reset = tensordict.get("_reset", None)
        if _reset is None:
            _reset = torch.ones(
                self.parent.done_spec.shape if self.parent else tensordict.batch_size,
                dtype=torch.bool,
                device=tensordict.device,
            )
        for in_key, reset_key, out_key in zip(
            self.in_keys, self.reset_keys, self.out_keys
        ):
            _reset = tensordict.get(reset_key, None)

        if _reset.any():
            _reset = _reset.sum(
                tuple(range(tensordict.batch_dims, _reset.ndim)), dtype=torch.bool
            )
            reward_key = self.parent.reward_key if self.parent else "reward"
            for in_key, out_key in zip(self.in_keys, self.out_keys):
                if out_key in tensordict.keys(True, True):
                    value = tensordict[out_key]
                    tensordict[out_key] = value.masked_fill(
                        expand_as_right(_reset, value), 0.0
                    )
                elif unravel_key(in_key) == unravel_key(reward_key):
            if _reset is None or _reset.any():
                value = tensordict.get(out_key, default=None)
                if value is not None:
                    if _reset is None:
                        tensordict.set(out_key, torch.zeros_like(value))
                    else:
                        tensordict.set(
                            out_key,
                            value.masked_fill(
                                expand_as_right(_reset.squeeze(-1), value), 0.0
                            ),
                        )
                else:
                    # Since the episode reward is not in the tensordict, we need to allocate it
                    # with zeros entirely (regardless of the _reset mask)
                    tensordict[out_key] = self.parent.reward_spec.zero()
                else:
                    try:
                        tensordict[out_key] = self.parent.observation_spec[
                            in_key
                        ].zero()
                    except KeyError as err:
                        raise KeyError(
                            f"The key {in_key} was not found in the parent "
                            f"observation_spec with keys "
                            f"{list(self.parent.observation_spec.keys(True))}. "
                        ) from err
                    tensordict.set(
                        out_key,
                        self.parent.full_reward_spec[in_key].zero(),
                    )
        return tensordict
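The zeroing above boils down to a masked fill of the running sum wherever the (possibly nested) reset flag is set; a self-contained sketch of that mechanism with arbitrary shapes, using the same expand_as_right helper as the diff:

    import torch
    from tensordict.utils import expand_as_right

    episode_reward = torch.tensor([[1.5], [2.0], [0.5]])  # running sums, shape [3, 1]
    _reset = torch.tensor([True, False, True])  # one reset flag per batch entry

    # broadcast the flag to the value's shape, then zero the finished episodes
    episode_reward = episode_reward.masked_fill(
        expand_as_right(_reset, episode_reward), 0.0
    )
    print(episode_reward)  # tensor([[0.0000], [2.0000], [0.5000]])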

    def _step(

@@ -4430,76 +4496,48 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
        state_spec = input_spec["full_state_spec"]
        if state_spec is None:
            state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device)
        reward_spec = self.parent.output_spec["full_reward_spec"]
        reward_spec_keys = list(reward_spec.keys(True, True))
        state_spec.update(self._generate_episode_reward_spec())
        input_spec["full_state_spec"] = state_spec
        return input_spec

    def _generate_episode_reward_spec(self) -> CompositeSpec:
        episode_reward_spec = CompositeSpec()
        reward_spec = self.parent.full_reward_spec
        reward_spec_keys = self.parent.reward_keys
        # Define episode specs for all out_keys
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            if (
                in_key in reward_spec_keys
            ):  # if this out_key has a corresponding key in reward_spec
                out_key = _unravel_key_to_tuple(out_key)
                temp_state_spec = state_spec
                temp_episode_reward_spec = episode_reward_spec
                temp_rew_spec = reward_spec
                for sub_key in out_key[:-1]:
                    if (
                        not isinstance(temp_rew_spec, CompositeSpec)
                        or sub_key not in temp_rew_spec.keys()
                    ):
                        break
                    if sub_key not in temp_state_spec.keys():
                        temp_state_spec[sub_key] = temp_rew_spec[sub_key].empty()
                    if sub_key not in temp_episode_reward_spec.keys():
                        temp_episode_reward_spec[sub_key] = temp_rew_spec[
                            sub_key
                        ].empty()
                    temp_rew_spec = temp_rew_spec[sub_key]
                    temp_state_spec = temp_state_spec[sub_key]
                state_spec[out_key] = reward_spec[in_key].clone()
                    temp_episode_reward_spec = temp_episode_reward_spec[sub_key]
                episode_reward_spec[out_key] = reward_spec[in_key].clone()
            else:
                raise ValueError(
                    f"The in_key: {in_key} is not present in the reward spec {reward_spec}."
                )
        input_spec["full_state_spec"] = state_spec
        return input_spec
        return episode_reward_spec

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        """Transforms the observation spec, adding the new keys generated by RewardSum."""
        # Retrieve parent reward spec
        reward_spec = self.parent.reward_spec
        reward_key = self.parent.reward_key if self.parent else "reward"

        episode_specs = {}
        if isinstance(reward_spec, CompositeSpec):
            # If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec
            if not all(k in reward_spec.keys(True, True) for k in self.in_keys):
                raise KeyError("Not all in_keys are present in ´reward_spec´")

            # Define episode specs for all out_keys
            for out_key in self.out_keys:
                episode_spec = UnboundedContinuousTensorSpec(
                    shape=reward_spec.shape,
                    device=reward_spec.device,
                    dtype=reward_spec.dtype,
                )
                episode_specs.update({out_key: episode_spec})

        else:
            # If reward_spec is not a CompositeSpec, the only in_key should be ´reward´
            if set(unravel_key_list(self.in_keys)) != {unravel_key(reward_key)}:
                raise KeyError(
                    "reward_spec is not a CompositeSpec class, in_keys should only include ´reward´"
                )

            # Define episode spec
            episode_spec = UnboundedContinuousTensorSpec(
                device=reward_spec.device,
                dtype=reward_spec.dtype,
                shape=reward_spec.shape,
            )
            episode_specs.update({self.out_keys[0]: episode_spec})

        # Update observation_spec with episode_specs
        if not isinstance(observation_spec, CompositeSpec):
            observation_spec = CompositeSpec(
                observation=observation_spec, shape=self.parent.batch_size
            )
        observation_spec.update(episode_specs)
        observation_spec.update(self._generate_episode_reward_spec())
        return observation_spec

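A quick sanity check of the spec plumbing for the single-reward default, mirroring the docstring example (shapes and dtypes depend on the wrapped env; the assertion reflects the intended behavior after this change):

    from torchrl.envs import TransformedEnv
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.transforms import RewardSum

    env = TransformedEnv(GymEnv("Pendulum-v1"), RewardSum())
    # the cumulative entry is advertised in the observation spec used by rollouts
    assert "episode_reward" in env.observation_spec.keys()
    print(env.observation_spec["episode_reward"])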

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase: