diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index d640b5e05b5..4fdd77c9f00 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -46,6 +46,7 @@
 )
 from torchrl.envs.utils import _sort_keys, _update_during_reset, step_mdp
 from torchrl.objectives.value.functional import reward2go
+from torchrl.objectives.value.utils import _get_num_per_traj, _split_and_pad_sequence
 
 try:
     from torchvision.transforms.functional import center_crop
@@ -5600,7 +5601,6 @@ def reset_key(self, value):
     def _reset(
         self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
     ) -> TensorDictBase:
-
         _reset = _get_reset(self.reset_key, tensordict)
         for in_key in self.in_keys:
             buffer_name = self._buffer_name(in_key)
@@ -6686,7 +6686,6 @@ def _step(
         raise RuntimeError("BurnInTransform can only be appended to a ReplayBuffer.")
 
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-
         if self.burn_in == 0:
             return tensordict
 
@@ -6796,3 +6795,211 @@ def _reset(
         with _set_missing_tolerance(self, True):
             tensordict_reset = self._call(tensordict_reset)
         return tensordict_reset
+
+
+class HERTransform(Transform):
+    """Hindsight Experience Replay (HER) transform.
+
+    This transform implements Hindsight Experience Replay (HER), a technique that
+    lets an agent learn from failed trajectories by replaying them with different
+    (achieved) goals and recomputing the corresponding rewards.
+
+    The transform is inverse-only: it augments data as it is written to a replay
+    buffer and cannot be appended to an environment or a collector.
+
+    Args:
+        samples (Optional[Union[int, torch.Tensor]]): The number of augmented samples
+            to generate for each original sample. Defaults to 4.
+        generation_type (str): The goal-relabelling strategy to use. Can be one of
+            "future", "random", or "final". Defaults to "future".
+        achieved_goal_key (Optional[NestedKey]): The key to access the achieved goal
+            in the input tensordict. Defaults to "achieved_goal".
+        desired_goal_key (Optional[NestedKey]): The key to access the desired goal
+            in the input tensordict. Defaults to "desired_goal".
+        reward_key (Optional[NestedKey]): The key under which the recomputed reward
+            is written in the output tensordict. Defaults to "reward".
+        reward_function (Optional[callable]): The reward function used to compute the
+            rewards of the augmented samples. Defaults to None, in which case
+            :func:`distance_reward_function` is used.
+
+    Attributes:
+        ENV_ERR (str): The error message raised when the transform is applied to
+            the collector or the environment.
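+
+    Examples:
+        The snippet below is a minimal sketch: the keys, shapes and replay-buffer
+        type are illustrative and should be adapted to the actual environment.
+
+        >>> import torch
+        >>> from tensordict import TensorDict
+        >>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
+        >>> her = HERTransform(samples=4, generation_type="future")
+        >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(1000), transform=her)
+        >>> data = TensorDict(
+        ...     {
+        ...         "observation": torch.randn(8, 10, 4),
+        ...         "action": torch.randn(8, 10, 2),
+        ...         "achieved_goal": torch.randn(8, 10, 3),
+        ...         "desired_goal": torch.randn(8, 10, 3),
+        ...         "reward": torch.zeros(8, 10),
+        ...         "terminated": torch.zeros(8, 10, dtype=torch.bool),
+        ...         "truncated": torch.zeros(8, 10, dtype=torch.bool),
+        ...     },
+        ...     batch_size=[8, 10],
+        ... )
+        >>> # the HER augmentation runs on the inverse pass, i.e. when data is written
+        >>> index = rb.extend(data)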
+
+    """
+
+    ENV_ERR = (
+        "The HERTransform is only an inverse transform and can "
+        "only be applied to the replay buffer and not to the collector or the environment."
+    )
+
+    def __init__(
+        self,
+        samples: Optional[Union[int, torch.Tensor]] = 4,
+        generation_type: str = "future",
+        achieved_goal_key: Optional[NestedKey] = "achieved_goal",
+        desired_goal_key: Optional[NestedKey] = "desired_goal",
+        reward_key: Optional[NestedKey] = "reward",
+        reward_function: Optional[callable] = None,
+    ):
+        super().__init__(
+            in_keys=None,
+            in_keys_inv=None,
+            out_keys_inv=None,
+        )
+        self.achieved_goal_key = achieved_goal_key
+        self.desired_goal_key = desired_goal_key
+        self.reward_key = reward_key
+        self.generation_type = generation_type
+
+        if reward_function is None:
+            self.reward_function = distance_reward_function
+        else:
+            self.reward_function = reward_function
+
+        if not isinstance(samples, torch.Tensor):
+            samples = torch.tensor(samples)
+
+        self.register_buffer("samples", samples)
+
+    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        augmentation_td = self.her_augmentation(tensordict)
+        return torch.cat([tensordict, augmentation_td], dim=0)
+
+    def _inv_apply_transform(self, tensordict: TensorDictBase) -> torch.Tensor:
+        return self.her_augmentation(tensordict)
+
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        return tensordict
+
+    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        raise ValueError(self.ENV_ERR)
+
+    def her_augmentation(self, sampled_td: TensorDictBase):
+        if len(sampled_td.shape) == 1:
+            sampled_td = sampled_td.unsqueeze(0)
+        b, t = sampled_td.shape
+        trajectories = _get_num_per_traj(sampled_td.get("terminated"))
+        splitted_td = _split_and_pad_sequence(sampled_td, trajectories)
+        splitted_achieved_goals = splitted_td.get(self.achieved_goal_key)
+
+        # get indices for each trajectory
+        idxs = self.generate_sample_idxs(trajectories)
+
+        # create new goals based on idxs
+        new_goals = []
+        for i, ids in enumerate(idxs):
+            new_goals.append(splitted_achieved_goals[i][ids])
+
+        # calculate rewards given new desired goals and old achieved goals
+        vmap_rewards = torch.vmap(self.reward_function)
+        rewards = []
+        for ach, des in zip(splitted_achieved_goals, new_goals):
+            rewards.append(vmap_rewards(ach[: des.shape[0], :], des))
+
+        cat_rewards = torch.cat(rewards).reshape(b, t, self.samples, -1).squeeze(-1)
+        cat_new_goals = torch.cat(new_goals).reshape(b, t, self.samples, -1)
+
+        augmentation_td = TensorDict(
+            {
+                "observation": sampled_td.get("observation").repeat_interleave(
+                    self.samples, dim=0
+                ),
+                "action": sampled_td.get("action").repeat_interleave(
+                    self.samples, dim=0
+                ),
+                "terminated": sampled_td.get("terminated").repeat_interleave(
+                    self.samples, dim=0
+                ),
+                "truncated": sampled_td.get("truncated").repeat_interleave(
+                    self.samples, dim=0
+                ),
+                self.achieved_goal_key: sampled_td.get(
+                    self.achieved_goal_key
+                ).repeat_interleave(self.samples, dim=0),
+            },
+            batch_size=(b * self.samples, t),
+        )
+
+        augmentation_td.set(self.reward_key, cat_rewards.transpose(1, 2).flatten(0, 1))
+        augmentation_td.set(
+            self.desired_goal_key, cat_new_goals.transpose(1, 2).flatten(0, 1)
+        )
+
+        return augmentation_td
+
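+    # Index-generation helpers: each returns, per trajectory, a LongTensor of shape
+    # (traj_len, samples) selecting which achieved goals are replayed as new desired
+    # goals. "future" draws from later steps of the same trajectory, "random" from
+    # any step of the same trajectory, and "final" always uses the last step.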
+    def generate_future_idxs(self, traj_lens):
+        def generate_for_single_traj_len(traj_len):
+            idxs = []
+            for i in range(traj_len - 1):
+                idxs.append(
+                    torch.randint(low=i + 1, high=traj_len, size=(1, self.samples))
+                )
+            # the last step has no future step to sample from: reuse the final index
+            idxs.append(torch.full((1, self.samples), fill_value=traj_len - 1))
+            return torch.cat(idxs)
+
+        return [generate_for_single_traj_len(traj_len) for traj_len in traj_lens]
+
+    def generate_random_idxs(self, traj_lens):
+        def generate_for_single_traj_len(traj_len):
+            idxs = []
+            for _ in range(traj_len):
+                idxs.append(torch.randint(low=0, high=traj_len, size=(1, self.samples)))
+            return torch.cat(idxs)
+
+        return [generate_for_single_traj_len(traj_len) for traj_len in traj_lens]
+
+    def generate_final_idxs(self, traj_lens):
+        def generate_for_single_traj_len(traj_len):
+            return torch.full((traj_len, self.samples), fill_value=traj_len - 1)
+
+        return [generate_for_single_traj_len(traj_len) for traj_len in traj_lens]
+
+    def generate_sample_idxs(self, trajectories):
+        if self.generation_type == "future":
+            idxs = self.generate_future_idxs(trajectories)
+        elif self.generation_type == "random":
+            idxs = self.generate_random_idxs(trajectories)
+        elif self.generation_type == "final":
+            idxs = self.generate_final_idxs(trajectories)
+        else:
+            raise ValueError(
+                f"Invalid generation type {self.generation_type!r}; "
+                "expected one of 'future', 'random' or 'final'."
+            )
+        return idxs
+
+
+def distance_torch(a, b):
+    """Calculate the Euclidean distance between two tensors.
+
+    Args:
+        a (torch.Tensor): The first tensor.
+        b (torch.Tensor): The second tensor.
+
+    Returns:
+        torch.Tensor: The Euclidean distance between the two tensors.
+    """
+    return torch.linalg.vector_norm(a - b, dim=-1)
+
+
+def distance_reward_function(
+    achieved_goal: torch.Tensor,
+    desired_goal: torch.Tensor,
+    threshold: float = 0.05,
+    reward_type: str = "sparse",
+) -> torch.Tensor:
+    """Calculate the distance-based reward for a given achieved goal and desired goal.
+
+    Args:
+        achieved_goal (torch.Tensor): The achieved goal.
+        desired_goal (torch.Tensor): The desired goal.
+        threshold (float, optional): The distance threshold under which the goal is
+            considered reached. Defaults to 0.05.
+        reward_type (str, optional): The type of reward to use. Can be "sparse" or
+            "dense". Defaults to "sparse".
+
+    Returns:
+        torch.Tensor: The distance-based reward.
+
+    """
+    d = distance_torch(achieved_goal, desired_goal)
+    if reward_type == "sparse":
+        return -(d > threshold).float()
+    else:
+        return -d.float()