[BUG] ParallelEnv handling of done flag when wrapped envs have non-empty batch_size #776

@matteobettini


Describe the bug

Currently, in ParallelEnv, the method _run_worker_pipe_shared_mem contains the following lines to check whether the environment is done:

_td = env._step(_td)

...

if _td.get("done"):
    msg = "done"
else:
    msg = "step_result"

The problem is that _td.get("done") is a tensor of shape (*wrapped_env.batch_size, 1), so the if statement is ambiguous for a tensor with more than one value.
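The ambiguity can be reproduced in isolation; the explicit .all() reduction below is only an illustrative workaround, not the library's actual fix:

```python
import torch

# "done" returned by a wrapped env with batch_size (3,) has shape (3, 1)
done = torch.tensor([[True], [False], [True]])

# Using the tensor directly as a condition raises a RuntimeError
try:
    if done:
        pass
except RuntimeError as exc:
    print(exc)  # Boolean value of Tensor with more than one value is ambiguous

# One illustrative workaround: reduce explicitly before branching
msg = "done" if done.all() else "step_result"
```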

Furthermore, ParallelEnv calls the functions all() and any() multiple times without specifying dimensions. This means they reduce over all dimensions (including the environment ones, which could be arbitrarily many), which risks injecting bias and bugs. This happens, for example, in the _reset() function:

self.shared_tensordict_parent.get("done").any()
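A dimension-unaware any() collapses all per-env information. A sketch of the difference, assuming a (num_workers, *env_batch_size, 1) layout for the shared done tensor:

```python
import torch

# Hypothetical layout: 2 workers, each wrapping an env with batch_size (3,)
done = torch.zeros(2, 3, 1, dtype=torch.bool)
done[0, 1, 0] = True  # a single sub-env of worker 0 is done

# Reducing over every dimension loses track of *which* env is done
collapsed = done.any()  # tensor(True)

# Reducing only over the wrapped env's batch dims keeps a per-worker view
per_worker = done.any(-1).any(-1)  # tensor([ True, False])
```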

To Reproduce

Add the instruction parallel_env.rand_step() to the test introduced in #774:

import torch
import torchrl
from torchrl.envs.vec_env import ParallelEnv
# MockBatchedLockedEnv comes from torchrl's test utilities (test/mocking_classes.py)
from mocking_classes import MockBatchedLockedEnv

num_parallel_env, env_batch_size = 2, (3,)  # example values; the test in #774 parametrizes these
env = MockBatchedLockedEnv(device="cpu", batch_size=torch.Size(env_batch_size))
env.set_seed(1)
parallel_env = ParallelEnv(num_parallel_env, lambda: env)
parallel_env.start()
parallel_env.rand_step()
Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/vec_env.py", line 1032, in _run_worker_pipe_shared_mem
    if _td.get("done"):
RuntimeError: Boolean value of Tensor with more than one value is ambiguous

Reason and Possible fixes

The decision on how to handle the done flag in environments with non-empty batch sizes is not trivial.
My suggestion is to remove this logic from the function that is throwing the error and give the done vector to the user as is.
If checks must be performed on the done vector, we have to keep its arbitrary shape in mind.

I suggest removing all the logic checks that use all() and any() over all dimensions, such as

 while self.shared_tensordict_parent.get("done").any():
          if check_count == 4:
              raise RuntimeError("Envs have just been reset but some are still done")

This is because users might need to leave some dimensions done while resetting others.

Instead, one can check more precisely that only the dimensions which were reset are actually not done.
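A sketch of such a targeted check, where reset_mask is a hypothetical boolean mask with the same shape as done, marking which entries were just reset:

```python
import torch

# done has shape (num_workers, *env_batch_size, 1); here (3, 1, 1)
done = torch.tensor([[[False]], [[False]], [[True]]])
# Hypothetical mask of the entries that were just reset
reset_mask = torch.tensor([[[True]], [[False]], [[False]]])

# Only the freshly-reset entries must not be done; the others are left alone
if done[reset_mask].any():
    raise RuntimeError("Envs have just been reset but some are still done")
```

The third entry is still done but was not reset, so the check passes instead of raising as the all-dimension any() would.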
