14 changes: 10 additions & 4 deletions test/mocking_classes.py
@@ -1127,10 +1127,15 @@ def __init__(
                 shape=self.batch_size,
             )

-    def _reset(self, td):
-        if self.nested_done and td is not None and "_reset" in td.keys():
-            td["_reset"] = td["_reset"].sum(-2, dtype=torch.bool)
-        td = super()._reset(td)
+    def _reset(self, tensordict):
+        if (
+            self.nested_done
+            and tensordict is not None
+            and "_reset" in tensordict.keys()
+        ):
+            tensordict = tensordict.clone()
+            tensordict["_reset"] = tensordict["_reset"].sum(-2, dtype=torch.bool)
+        td = super()._reset(tensordict)
         if self.nested_done:
             td[self.done_key] = (
                 td["done"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)
@@ -1149,6 +1154,7 @@ def _reset(self, td):

     def _step(self, td):
         if self.nested_obs_action:
+            td = td.clone()
             td["data"].batch_size = self.batch_size
             td[self.action_key] = td[self.action_key].max(-2)[0]
             td_root = super()._step(td)
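The clone() added to _reset and _step above keeps the mock env from writing into the caller's tensordict. A minimal sketch of the behavior being protected, assuming a (4,)-batched env with nested_dim=5 (shapes are illustrative only):

# Minimal sketch, assuming batch_size=(4,) and nested_dim=5.
import torch
from tensordict import TensorDict

caller_td = TensorDict(
    {"_reset": torch.ones(4, 5, 1, dtype=torch.bool)}, batch_size=[4]
)
td = caller_td.clone()  # work on a private copy, as _reset now does
td["_reset"] = td["_reset"].sum(-2, dtype=torch.bool)  # (4, 5, 1) -> (4, 1)
assert caller_td["_reset"].shape == (4, 5, 1)  # the caller's entry is untouched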
60 changes: 60 additions & 0 deletions test/test_env.py
@@ -943,6 +943,66 @@ def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3):
         assert (td_reset["done"][~_reset] == 1).all()
         assert (td_reset["observation"][~_reset] == max_steps + 1).all()

+    @pytest.mark.parametrize("nested_obs_action", [True, False])
+    @pytest.mark.parametrize("nested_done", [True, False])
+    @pytest.mark.parametrize("nested_reward", [True, False])
+    @pytest.mark.parametrize("env_type", ["serial", "parallel"])
+    def test_parallel_env_nested(
+        self,
+        nested_obs_action,
+        nested_done,
+        nested_reward,
+        env_type,
+        n_envs=2,
+        batch_size=(32,),
+        nested_dim=5,
+        rollout_length=3,
+        seed=1,
+    ):
+        env_fn = lambda: NestedCountingEnv(
+            nest_done=nested_done,
+            nest_reward=nested_reward,
+            nest_obs_action=nested_obs_action,
+            batch_size=batch_size,
+            nested_dim=nested_dim,
+        )
+        if env_type == "serial":
+            env = SerialEnv(n_envs, env_fn)
+        else:
+            env = ParallelEnv(n_envs, env_fn)
+        env.set_seed(seed)
+
+        batch_size = (n_envs, *batch_size)
+
+        td = env.reset()
+        assert td.batch_size == batch_size
+        if nested_done or nested_obs_action:
+            assert td["data"].batch_size == (*batch_size, nested_dim)
+        if not nested_done and not nested_reward and not nested_obs_action:
+            assert "data" not in td.keys()
+
+        policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
+        td = env.rollout(rollout_length, policy)
+        assert td.batch_size == (*batch_size, rollout_length)
+        if nested_done or nested_obs_action:
+            assert td["data"].batch_size == (*batch_size, rollout_length, nested_dim)
+        if nested_reward or nested_done or nested_obs_action:
+            assert td["next", "data"].batch_size == (
+                *batch_size,
+                rollout_length,
+                nested_dim,
+            )
+        if not nested_done and not nested_reward and not nested_obs_action:
+            assert "data" not in td.keys()
+            assert "data" not in td["next"].keys()
+
+        if nested_obs_action:
+            assert "observation" not in td.keys()
+            assert (td[..., -1]["data", "states"] == 2).all()
+        else:
+            assert ("data", "states") not in td.keys(True, True)
+            assert (td[..., -1]["observation"] == 2).all()
+

 @pytest.mark.parametrize("batch_size", [(), (2,), (32, 5)])
 def test_env_base_reset_flag(batch_size, max_steps=3):
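The shape assertions above compose mechanically: SerialEnv/ParallelEnv prepend the worker dimension, rollout appends the time dimension, and nested entries carry nested_dim innermost. A sketch with the test's default values:

# Sketch of the batch-size composition the test asserts (default values).
n_envs, batch_size, nested_dim, rollout_length = 2, (32,), 5, 3
root = (n_envs, *batch_size)       # env.reset():   (2, 32)
rollout = (*root, rollout_length)  # env.rollout(): (2, 32, 3)
nested = (*rollout, nested_dim)    # td["data"]:    (2, 32, 3, 5)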
8 changes: 7 additions & 1 deletion torchrl/envs/utils.py
@@ -223,7 +223,10 @@ def step_mdp(
     return out


-def _set_single_key(source, dest, key):
+def _set_single_key(source, dest, key, clone=False):
+    # key should be already unraveled
+    if isinstance(key, str):
+        key = (key,)
     for k in key:
         val = source.get(k)
         if is_tensor_collection(val):
@@ -234,6 +237,8 @@ def _set_single_key(source, dest, key):
             source = val
             dest = new_val
         else:
+            if clone:
+                val = val.clone()
             dest._set(k, val)

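A usage sketch of the extended helper; _set_single_key is private and the shapes here are hypothetical, so this illustrates the intended behavior rather than a public API:

# Hypothetical usage of the private _set_single_key helper above.
import torch
from tensordict import TensorDict
from torchrl.envs.utils import _set_single_key

source = TensorDict(
    {"next": {"data": {"reward": torch.zeros(3, 1)}}}, batch_size=[3]
)
dest = TensorDict({}, batch_size=[3])
# Walks ("next", "data", "reward"), creating the nesting in dest as it
# descends; clone=True copies the leaf tensor instead of aliasing it.
_set_single_key(source, dest, ("next", "data", "reward"), clone=True)
assert (dest["next", "data", "reward"] == 0).all()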
@@ -482,6 +487,7 @@ def __get__(self, owner_self, owner_cls):

 def _sort_keys(element):
     if isinstance(element, tuple):
+        element = unravel_keys(element)
         return "_-|-_".join(element)
     return element

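With unravel_keys applied first, nested tuples sort consistently with flat ones; previously the join would fail on a tuple element such as ("data", "states"). A small sketch (again using a private helper, for illustration only):

# Sketch: _sort_keys now normalizes nested tuples before joining.
from torchrl.envs.utils import _sort_keys

keys = [("next", ("data", "reward")), ("next", "done"), "observation"]
# ("next", ("data", "reward")) unravels to ("next", "data", "reward"),
# so "_-|-_".join() receives only strings.
print(sorted(keys, key=_sort_keys))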
51 changes: 26 additions & 25 deletions torchrl/envs/vec_env.py
@@ -20,6 +20,7 @@
 import torch
 from tensordict import TensorDict
 from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
+from tensordict.utils import unravel_keys
 from torch import multiprocessing as mp

 from torchrl._utils import _check_for_faulty_process, VERBOSE
@@ -33,8 +34,7 @@
 from torchrl.envs.common import _EnvWrapper, EnvBase
 from torchrl.envs.env_creator import get_env_metadata

-from torchrl.envs.utils import _sort_keys
-
+from torchrl.envs.utils import _set_single_key, _sort_keys

 _has_envpool = importlib.util.find_spec("envpool")

@@ -324,46 +324,47 @@ def _create_td(self) -> None:

         if self._single_task:
             self.env_input_keys = sorted(
-                list(self.input_spec["_action_spec"].keys(True))
-                + list(self.state_spec.keys(True)),
+                list(self.input_spec["_action_spec"].keys(True, True))
+                + list(self.state_spec.keys(True, True)),
                 key=_sort_keys,
             )
             self.env_output_keys = []
             self.env_obs_keys = []
-            for key in self.output_spec["_observation_spec"].keys(True):
-                if isinstance(key, str):
-                    key = (key,)
-                self.env_output_keys.append(("next", *key))
+            for key in self.output_spec["_observation_spec"].keys(True, True):
+                self.env_output_keys.append(unravel_keys(("next", key)))
                 self.env_obs_keys.append(key)
-            self.env_output_keys.append(("next", "reward"))
-            self.env_output_keys.append(("next", "done"))
+            self.env_output_keys.append(unravel_keys(("next", self.reward_key)))
+            self.env_output_keys.append(unravel_keys(("next", self.done_key)))
         else:
             env_input_keys = set()
             for meta_data in self.meta_data:
                 if meta_data.specs["input_spec", "_state_spec"] is not None:
                     env_input_keys = env_input_keys.union(
-                        meta_data.specs["input_spec", "_state_spec"].keys(True)
+                        meta_data.specs["input_spec", "_state_spec"].keys(True, True)
                     )
                 env_input_keys = env_input_keys.union(
-                    meta_data.specs["input_spec", "_action_spec"].keys(True)
+                    meta_data.specs["input_spec", "_action_spec"].keys(True, True)
                 )
             env_output_keys = set()
             env_obs_keys = set()
             for meta_data in self.meta_data:
                 env_obs_keys = env_obs_keys.union(
                     key
                     for key in meta_data.specs["output_spec"]["_observation_spec"].keys(
-                        True
+                        True, True
                     )
                 )
                 env_output_keys = env_output_keys.union(
-                    ("next", key) if isinstance(key, str) else ("next", *key)
+                    unravel_keys(("next", key))
                     for key in meta_data.specs["output_spec"]["_observation_spec"].keys(
-                        True
+                        True, True
                    )
                 )
             env_output_keys = env_output_keys.union(
-                {("next", "reward"), ("next", "done")}
+                {
+                    unravel_keys(("next", self.reward_key)),
+                    unravel_keys(("next", self.done_key)),
+                }
             )
             self.env_obs_keys = sorted(env_obs_keys, key=_sort_keys)
             self.env_input_keys = sorted(env_input_keys, key=_sort_keys)
@@ -374,10 +375,10 @@ def _create_td(self) -> None:
             .union(self.env_input_keys)
             .union(self.env_obs_keys)
         )
-        self._selected_keys.add("done")
+        self._selected_keys.add(self.done_key)
         self._selected_keys.add("_reset")

-        self._selected_reset_keys = self.env_obs_keys + ["done"] + ["_reset"]
+        self._selected_reset_keys = self.env_obs_keys + [self.done_key] + ["_reset"]
         self._selected_step_keys = self.env_output_keys

         if self._single_task:
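Two conventions are established above: keys(True, True) passes leaves_only=True so intermediate tensordict nodes are no longer collected as keys, and every ("next", key) tuple goes through unravel_keys so env_output_keys holds flat tuples whether key is a string or an already nested tuple. A sketch of the assumed unravel_keys behavior, per its use here:

# Assumed behavior of tensordict.utils.unravel_keys, per its use above.
from tensordict.utils import unravel_keys

assert unravel_keys(("next", "reward")) == ("next", "reward")
assert unravel_keys(("next", ("data", "reward"))) == ("next", "data", "reward")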
@@ -550,7 +551,7 @@ def _step(
         if self._single_task:
             out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape)
             for key in self._selected_step_keys:
-                out._set(key, self.shared_tensordict_parent.get(key).clone())
+                _set_single_key(self.shared_tensordict_parent, out, key, clone=True)
         else:
             # strict=False ensures that non-homogeneous keys are still there
             out = self.shared_tensordict_parent.select(
@@ -619,7 +620,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
             out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape)
             for key in self._selected_reset_keys:
                 if key != "_reset":
-                    out._set(key, self.shared_tensordict_parent.get(key).clone())
+                    _set_single_key(self.shared_tensordict_parent, out, key, clone=True)
             return out
         else:
             return self.shared_tensordict_parent.select(
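With nested keys such as ("data", "reward") now possible in the selected key lists, the old out._set(key, self.shared_tensordict_parent.get(key).clone()) pattern would presumably have no intermediate "data" node to write into on the freshly created out; _set_single_key (see torchrl/envs/utils.py above) builds that nesting as it descends and clones only the leaf tensor.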
@@ -790,7 +791,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         if self._single_task:
             out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape)
             for key in self._selected_step_keys:
-                out._set(key, self.shared_tensordict_parent.get(key).clone())
+                _set_single_key(self.shared_tensordict_parent, out, key, clone=True)
         else:
             # strict=False ensures that non-homogeneous keys are still there
             out = self.shared_tensordict_parent.select(
@@ -853,7 +854,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
             out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape)
             for key in self._selected_reset_keys:
                 if key != "_reset":
-                    out._set(key, self.shared_tensordict_parent.get(key).clone())
+                    _set_single_key(self.shared_tensordict_parent, out, key, clone=True)
             return out
         else:
             return self.shared_tensordict_parent.select(
@@ -1187,7 +1188,7 @@ def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:

     @torch.no_grad()
     def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
-        action = tensordict.get("action")
+        action = tensordict.get(self.action_key)
         # Action needs to be moved to CPU and converted to numpy before being passed to envpool
         action = action.to(torch.device("cpu"))
         step_output = self._env.step(action.numpy())
@@ -1285,7 +1286,7 @@ def _transform_reset_output(
         )

         obs = self.obs.clone(False)
-        obs.update({"done": self.done_spec.zero()})
+        obs.update({self.done_key: self.done_spec.zero()})
         return obs

     def _transform_step_output(
@@ -1295,7 +1296,7 @@ def _transform_step_output(
         obs, reward, done, *_ = envpool_output

         obs = self._treevalue_or_numpy_to_tensor_or_dict(obs)
-        obs.update({"reward": torch.tensor(reward), "done": done})
+        obs.update({self.reward_key: torch.tensor(reward), self.done_key: done})
         self.obs = tensordict_out = TensorDict(
             obs,
             batch_size=self.batch_size,
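As in the vec_env.py changes above, the envpool wrapper now resolves entry names through self.action_key, self.done_key, and self.reward_key rather than the hard-coded "action", "done", and "reward" strings, keeping it consistent with environments that relocate these entries under nested keys.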