Describe the bug
Currently, in ParallelEnv, the method _run_worker_pipe_shared_mem contains the following lines to check whether the environment is done:

_td = env._step(_td)
...
if _td.get("done"):
    msg = "done"
else:
    msg = "step_result"
The problem is that _td.get("done") is a tensor of shape (*wrapped_env.batch_dim, 1), so the if statement is ambiguous for a tensor with more than one value.
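For context, the ambiguity can be reproduced with a standalone tensor of a similar shape (the shape here is only illustrative):

import torch

done = torch.zeros(4, 1, dtype=torch.bool)  # per-env done flags for a batch of 4
done[2] = True
try:
    if done:  # more than one element: the truth value is ambiguous
        pass
except RuntimeError as err:
    print(err)  # Boolean value of Tensor with more than one value is ambiguous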
Furthermore, ParallelEnv calls the functions all() and any() in several places without specifying a dimension. This means the reduction spans all dimensions (including the environment ones, which could be arbitrarily many), which risks introducing bias and bugs. This happens, for example, in the _reset() function:

self.shared_tensordict_parent.get("done").any()
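For illustration, reducing with and without an explicit dim behaves quite differently (the shapes below are made up):

import torch

done = torch.zeros(3, 5, 1, dtype=torch.bool)  # e.g. (num_workers, *env_batch_size, 1)
done[1, 2] = True
print(done.any())            # tensor(True): collapses every dimension into one global flag
print(done.any(-1))          # shape (3, 5): one flag per env, batch structure preserved
print(done.any(-1).any(-1))  # shape (3,): one flag per worker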
To Reproduce
Add the instruction parallel_env.rand_step() to the test introduced in #774:
import torch
from torchrl.envs import ParallelEnv
# MockBatchedLockedEnv is a helper from torchrl's test suite (mocking_classes.py)
from mocking_classes import MockBatchedLockedEnv

# env_batch_size and num_parallel_env come from the test parametrization in #774
env = MockBatchedLockedEnv(device="cpu", batch_size=torch.Size(env_batch_size))
env.set_seed(1)
parallel_env = ParallelEnv(num_parallel_env, lambda: env)
parallel_env.start()
parallel_env.rand_step()

This raises:
Traceback (most recent call last):
File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/vec_env.py", line 1032, in _run_worker_pipe_shared_mem
if _td.get("done"):
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
Reason and Possible fixes
Deciding how to handle the done flag in environments with non-empty batch sizes is not trivial.
My suggestion is to remove this logic from the function that is throwing the error and hand the done vector to the user as is.
If checks do have to be performed on the done vector, we must keep its arbitrary shape in mind.
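As one possible shape-aware variant (a sketch only, assuming the caller really does want a single per-worker message), the reduction could at least be made explicit instead of relying on the implicit bool conversion:

_td = env._step(_td)
done = _td.get("done")
# reduce explicitly: report "done" only when every env in this worker's batch is done
msg = "done" if done.all() else "step_result"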
I also suggest removing the checks that apply all() and any() over all dimensions, such as:

while self.shared_tensordict_parent.get("done").any():
    ...
    if check_count == 4:
        raise RuntimeError("Envs have just been reset but some are still done")
This is because users might need to leave some dimensions done while resetting others.
Instead, the check could verify more precisely that only the dimensions that were actually reset are no longer done.
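A hypothetical helper along these lines (check_reset and reset_mask are made-up names, not part of torchrl):

import torch

def check_reset(done: torch.Tensor, reset_mask: torch.Tensor) -> None:
    # only the entries that were asked to reset must come back not-done;
    # entries outside the mask may legitimately keep their done state
    if done[reset_mask].any():
        raise RuntimeError("Envs that were just reset are still done")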