From 035f734206bff0ae1c317435572edc45f72f06b2 Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Tue, 17 Jan 2023 10:18:34 +0000
Subject: [PATCH] fix done checking

---
 torchrl/collectors/collectors.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 027b440e598..29b29850a0d 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -614,14 +614,18 @@ def _reset_if_necessary(self) -> None:
             steps = steps.clone()
             if len(self.env.batch_size):
                 self._tensordict.masked_fill_(done_or_terminated, 0)
-                self._tensordict.set("_reset", done_or_terminated)
+                _reset = done_or_terminated
+                self._tensordict.set("_reset", _reset)
             else:
+                _reset = None
                 self._tensordict.zero_()
             self.env.reset(self._tensordict)
 
-            if self._tensordict.get("done").any():
+            if (_reset is None and self._tensordict.get("done").any()) or (
+                _reset is not None and self._tensordict.get("done")[_reset].any()
+            ):
                 raise RuntimeError(
-                    f"Got {sum(self._tensordict.get('done'))} done envs after reset."
+                    f"Env {self.env} was done after reset on specified '_reset' dimensions. This is (currently) not allowed."
                 )
             traj_ids[done_or_terminated] = traj_ids.max() + torch.arange(
                 1, done_or_terminated.sum() + 1, device=traj_ids.device