diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 027b440e598..29b29850a0d 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -614,14 +614,18 @@ def _reset_if_necessary(self) -> None: steps = steps.clone() if len(self.env.batch_size): self._tensordict.masked_fill_(done_or_terminated, 0) - self._tensordict.set("_reset", done_or_terminated) + _reset = done_or_terminated + self._tensordict.set("_reset", _reset) else: + _reset = None self._tensordict.zero_() self.env.reset(self._tensordict) - if self._tensordict.get("done").any(): + if (_reset is None and self._tensordict.get("done").any()) or ( + _reset is not None and self._tensordict.get("done")[_reset].any() + ): raise RuntimeError( - f"Got {sum(self._tensordict.get('done'))} done envs after reset." + f"Env {self.env} was done after reset on specified '_reset' dimensions. This is (currently) not allowed." ) traj_ids[done_or_terminated] = traj_ids.max() + torch.arange( 1, done_or_terminated.sum() + 1, device=traj_ids.device