Merged
Changes from all commits
27 commits
0e38ed4
Changed flag in EnvBase
matteobettini Jan 6, 2023
7ef1b47
Changed flag in Vec Envs
matteobettini Jan 6, 2023
4cf67cb
Fixed collectors
matteobettini Jan 6, 2023
197c354
Fixed transforms
matteobettini Jan 6, 2023
9d191e1
Fixed env tests and lint
matteobettini Jan 6, 2023
5320414
Fixed env tests and lint
matteobettini Jan 6, 2023
2cf04f9
Linting
matteobettini Jan 6, 2023
155efd6
Added "_reset" kay deletion after use
matteobettini Jan 6, 2023
01a7377
Linting
matteobettini Jan 6, 2023
23a82fc
Added "_reset" kay deletion after use
matteobettini Jan 6, 2023
f8e3d33
removed deletion of reset flag from collector as the flag is deleted …
matteobettini Jan 8, 2023
2c4eb10
Moved _reset check before calling wrapped env
matteobettini Jan 8, 2023
87228e4
refactor how to check if flag is present
matteobettini Jan 8, 2023
34317c0
added tests
matteobettini Jan 8, 2023
9b3d4e7
Linting
matteobettini Jan 8, 2023
a721446
vmas support for _reset flag
matteobettini Jan 8, 2023
e615a52
Merge branch 'main' into reset_flag
matteobettini Jan 8, 2023
ac34b08
refactor
matteobettini Jan 8, 2023
469b6d5
close ParallelEnv
matteobettini Jan 8, 2023
416b066
partial reset mirrored in step_count
matteobettini Jan 9, 2023
9d5754e
Added torch seeding
matteobettini Jan 10, 2023
e1dbaff
Modified tests
matteobettini Jan 10, 2023
d9dc888
Removed last dim of 1
matteobettini Jan 10, 2023
a244c95
removed last dim of 1 from "step_count"
matteobettini Jan 10, 2023
5daadee
removed another 1
matteobettini Jan 10, 2023
00c351b
change spec of StepCount
matteobettini Jan 10, 2023
6f99133
set the done properly in StepCount
matteobettini Jan 10, 2023
51 changes: 50 additions & 1 deletion test/mocking_classes.py
@@ -7,7 +7,6 @@
import torch
import torch.nn as nn
from tensordict.tensordict import TensorDict, TensorDictBase

from torchrl.data.tensor_specs import (
BinaryDiscreteTensorSpec,
BoundedTensorSpec,
@@ -718,3 +717,53 @@ def __init__(self, in_size, out_size):

def forward(self, observation, action):
return self.linear(torch.cat([observation, action], dim=-1))


class CountingEnv(EnvBase):
def __init__(self, max_steps: int = 5, **kwargs):
super().__init__(**kwargs)
self.max_steps = max_steps

self.observation_spec = CompositeSpec(
observation=UnboundedContinuousTensorSpec((1,))
)
self.reward_spec = UnboundedContinuousTensorSpec((1,))
self.input_spec = CompositeSpec(action=BinaryDiscreteTensorSpec(1))

self.count = torch.zeros(
(*self.batch_size, 1), device=self.device, dtype=torch.int
)

def _set_seed(self, seed: Optional[int]):
torch.manual_seed(seed)

def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
if tensordict is not None and "_reset" in tensordict.keys():
_reset = tensordict.get("_reset")
self.count[_reset] = 0
else:
self.count[:] = 0
return TensorDict(
source={
"observation": self.count.clone(),
"done": self.count > self.max_steps,
},
batch_size=self.batch_size,
device=self.device,
)

def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
action = tensordict.get("action")
self.count += action.to(torch.int)
return TensorDict(
source={
"observation": self.count,
"done": self.count > self.max_steps,
"reward": torch.zeros_like(self.count, dtype=torch.float),
},
batch_size=self.batch_size,
device=self.device,
)
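For orientation, a minimal sketch of how this mock is driven (the batch size, mask values, and import path are assumptions; the pattern mirrors the tests below):

# Sketch: driving CountingEnv with and without a partial "_reset" mask.
import torch
from tensordict.tensordict import TensorDict
from mocking_classes import CountingEnv

env = CountingEnv(max_steps=3, batch_size=(4,))
env.reset()  # no "_reset" entry: every counter is zeroed

# Partial reset: only the flagged sub-environments restart from 0;
# the others keep their running count and their "done" state.
mask = torch.tensor([True, False, False, True])
td = env.reset(TensorDict({"_reset": mask}, batch_size=env.batch_size))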
77 changes: 75 additions & 2 deletions test/test_env.py
@@ -20,6 +20,7 @@
)
from mocking_classes import (
ActionObsMergeLinear,
CountingEnv,
DiscreteActionConvMockEnv,
DiscreteActionVecMockEnv,
DummyModelBasedEnvBase,
@@ -511,7 +512,7 @@ def test_parallel_env(
_ = env_parallel.step(td)

td_reset = TensorDict(
source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()},
source={"_reset": torch.zeros(N, dtype=torch.bool).bernoulli_()},
batch_size=[
N,
],
@@ -595,7 +596,7 @@ def test_parallel_env_with_policy(
_ = env_parallel.step(td)

td_reset = TensorDict(
source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()},
source={"_reset": torch.zeros(N, dtype=torch.bool).bernoulli_()},
batch_size=[
N,
],
@@ -900,6 +901,78 @@ def env_fn2(seed):
env1.close()
env2.close()

@pytest.mark.parametrize("batch_size", [(), (1,), (4,), (32, 5)])
@pytest.mark.parametrize("n_workers", [1, 2])
def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3):
torch.manual_seed(1)
env = ParallelEnv(
n_workers, lambda: CountingEnv(max_steps=max_steps, batch_size=batch_size)
)
env.set_seed(1)
action = env.action_spec.rand(env.batch_size)
action[:] = 1

for i in range(max_steps):
td = env.step(
TensorDict(
{"action": action}, batch_size=env.batch_size, device=env.device
)
)
assert (td["done"] == 0).all()
assert (td["next"]["observation"] == i + 1).all()

td = env.step(
TensorDict({"action": action}, batch_size=env.batch_size, device=env.device)
)
assert (td["done"] == 1).all()
assert (td["next"]["observation"] == max_steps + 1).all()

_reset = torch.randint(low=0, high=2, size=env.batch_size, dtype=torch.bool)
while not _reset.any():
_reset = torch.randint(low=0, high=2, size=env.batch_size, dtype=torch.bool)

td_reset = env.reset(
TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device)
)
env.close()

assert (td_reset["done"][_reset] == 0).all()
assert (td_reset["observation"][_reset] == 0).all()
assert (td_reset["done"][~_reset] == 1).all()
assert (td_reset["observation"][~_reset] == max_steps + 1).all()


@pytest.mark.parametrize("batch_size", [(), (2,), (32, 5)])
def test_env_base_reset_flag(batch_size, max_steps=3):
env = CountingEnv(max_steps=max_steps, batch_size=batch_size)
env.set_seed(1)

action = env.action_spec.rand(env.batch_size)
action[:] = 1

for i in range(max_steps):
td = env.step(
TensorDict({"action": action}, batch_size=env.batch_size, device=env.device)
)
assert (td["done"] == 0).all()
assert (td["next"]["observation"] == i + 1).all()

td = env.step(
TensorDict({"action": action}, batch_size=env.batch_size, device=env.device)
)
assert (td["done"] == 1).all()
assert (td["next"]["observation"] == max_steps + 1).all()

_reset = torch.randint(low=0, high=2, size=env.batch_size, dtype=torch.bool)
td_reset = env.reset(
TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device)
)

assert (td_reset["done"][_reset] == 0).all()
assert (td_reset["observation"][_reset] == 0).all()
assert (td_reset["done"][~_reset] == 1).all()
assert (td_reset["observation"][~_reset] == max_steps + 1).all()


@pytest.mark.skipif(not _has_gym, reason="no gym")
def test_seed():
8 changes: 4 additions & 4 deletions test/test_transforms.py
@@ -832,7 +832,7 @@ def test_sum_reward(self, keys, device):
assert (td.get("episode_reward") == 2 * td.get("reward")).all()

# reset environments
td.set("reset_workers", torch.ones((batch, 1), dtype=torch.bool, device=device))
td.set("_reset", torch.ones(batch, dtype=torch.bool, device=device))
rs.reset(td)

# apply a third time, episode_reward should be equal to reward again
@@ -1724,7 +1724,7 @@ def test_step_counter(self, max_steps, device, batch, reset_workers):
{"done": torch.zeros(*batch, 1, dtype=torch.bool)}, batch, device=device
)
if reset_workers:
td.set("reset_workers", torch.randn(*batch, 1) < 0)
td.set("_reset", torch.randn(batch) < 0)
step_counter.reset(td)
assert not torch.all(td.get("step_count"))
i = 0
@@ -1740,10 +1740,10 @@ def test_step_counter(self, max_steps, device, batch, reset_workers):
step_counter.reset(td)
if reset_workers:
assert torch.all(
torch.masked_select(td.get("step_count"), td.get("reset_workers")) == 0
torch.masked_select(td.get("step_count"), td.get("_reset")) == 0
)
assert torch.all(
torch.masked_select(td.get("step_count"), ~td.get("reset_workers")) == i
torch.masked_select(td.get("step_count"), ~td.get("_reset")) == i
)
else:
assert torch.all(td.get("step_count") == 0)
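The same pattern in isolation: a minimal sketch of a partial reset through the StepCounter transform (the constructor call and the shapes are assumptions inferred from the test above):

# Sketch: StepCounter.reset with a partial "_reset" mask (constructor and
# shapes are assumptions based on the test above).
import torch
from tensordict.tensordict import TensorDict
from torchrl.envs.transforms import StepCounter

step_counter = StepCounter(max_steps=10)
td = TensorDict({"done": torch.zeros(4, 1, dtype=torch.bool)}, [4])
td.set("_reset", torch.tensor([True, False, True, False]))
step_counter.reset(td)
# "step_count" is zeroed only where "_reset" is True; the remaining
# entries keep their running count.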
20 changes: 10 additions & 10 deletions torchrl/collectors/collectors.py
@@ -22,13 +22,11 @@
from tensordict.tensordict import TensorDict, TensorDictBase
from torch import multiprocessing as mp
from torch.utils.data import IterableDataset

from torchrl._utils import _check_for_faulty_process, prod
from torchrl.collectors.utils import split_trajectories
from torchrl.data import TensorSpec
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import EnvBase

from torchrl.envs.transforms import TransformedEnv
from torchrl.envs.utils import set_exploration_mode, step_mdp
from torchrl.envs.vec_env import _BatchedEnv
@@ -615,7 +613,7 @@ def _reset_if_necessary(self) -> None:
steps = steps.clone()
if len(self.env.batch_size):
self._tensordict.masked_fill_(done_or_terminated, 0)
self._tensordict.set("reset_workers", done_or_terminated)
self._tensordict.set("_reset", done_or_terminated)
else:
self._tensordict.zero_()
self.env.reset(self._tensordict)
@@ -624,8 +622,6 @@ def _reset_if_necessary(self) -> None:
raise RuntimeError(
f"Got {sum(self._tensordict.get('done'))} done envs after reset."
)
if len(self.env.batch_size):
self._tensordict.del_("reset_workers")
traj_ids[done_or_terminated] = traj_ids.max() + torch.arange(
1, done_or_terminated.sum() + 1, device=traj_ids.device
)
@@ -683,23 +679,27 @@ def reset(self, index=None, **kwargs) -> None:
# check that the env supports partial reset
if prod(self.env.batch_size) == 0:
raise RuntimeError("resetting unique env with index is not permitted.")
reset_workers = torch.zeros(
*self.env.batch_size,
_reset = torch.zeros(
self.env.batch_size,
dtype=torch.bool,
device=self.env.device,
)
reset_workers[index] = 1
td_in = TensorDict({"reset_workers": reset_workers}, self.env.batch_size)
_reset[index] = 1
td_in = TensorDict({"_reset": _reset}, self.env.batch_size)
self._tensordict[index].zero_()
else:
_reset = None
td_in = None
self._tensordict.zero_()

if td_in:
self._tensordict.update(td_in, inplace=True)

self._tensordict.update(self.env.reset(**kwargs), inplace=True)
self._tensordict.fill_("step_count", 0)
if _reset is not None:
self._tensordict["step_count"][_reset] = 0
else:
self._tensordict.fill_("step_count", 0)

def shutdown(self) -> None:
"""Shuts down all workers and/or closes the local environment."""
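In short, the collector now builds the boolean mask itself and zeroes "step_count" only for the indexed workers. A condensed sketch of that flow (the batch size and chosen index are assumptions; names mirror the code above):

# Condensed sketch of the collector's partial-reset flow shown above.
import torch
from tensordict.tensordict import TensorDict

batch_size = (4,)  # assumption: stands in for self.env.batch_size
_reset = torch.zeros(batch_size, dtype=torch.bool)
_reset[1] = True  # reset only worker 1 (the index is an assumption)
td_in = TensorDict({"_reset": _reset}, batch_size)
# After env.reset(...), only the flagged workers are zeroed:
#     self._tensordict["step_count"][_reset] = 0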
17 changes: 14 additions & 3 deletions torchrl/envs/common.py
@@ -14,11 +14,11 @@
import torch
import torch.nn as nn
from tensordict.tensordict import TensorDict, TensorDictBase

from torchrl.data import CompositeSpec, TensorSpec

from .._utils import prod, seed_generator
from ..data.utils import DEVICE_TYPING

from .utils import get_available_libraries, step_mdp

LIBRARIES = get_available_libraries()
@@ -428,6 +428,12 @@ def reset(
a tensordict (or the input tensordict, if any), modified in place with the resulting observations.

"""
if tensordict is not None and "_reset" in tensordict.keys():
self._assert_tensordict_shape(tensordict)
_reset = tensordict.get("_reset")
else:
_reset = None

tensordict_reset = self._reset(tensordict, **kwargs)

done = tensordict_reset.get("done", None)
@@ -457,11 +463,16 @@
*tensordict_reset.batch_size, 1, dtype=torch.bool, device=self.device
),
)
if tensordict_reset.get("done").any():

if (_reset is None and tensordict_reset.get("done").any()) or (
_reset is not None and tensordict_reset.get("done")[_reset].any()
):
raise RuntimeError(
f"Env {self} was done after reset. This is (currently) not allowed."
f"Env {self} was done after reset on specified '_reset' dimensions. This is (currently) not allowed."
)
if tensordict is not None:
if "_reset" in tensordict.keys():
tensordict.del_("_reset")
tensordict.update(tensordict_reset)
else:
tensordict = tensordict_reset
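Two properties of the new contract are worth spelling out: reset consumes (deletes) the "_reset" entry, and the post-reset "done" check only inspects the flagged dimensions. A minimal sketch, reusing the CountingEnv mock from the test suite:

# Sketch of the "_reset" contract in EnvBase.reset (the env choice is an
# assumption; any batched EnvBase behaves the same way).
import torch
from tensordict.tensordict import TensorDict
from mocking_classes import CountingEnv

env = CountingEnv(max_steps=3, batch_size=(2,))
env.reset()
mask = torch.tensor([True, False])
td = env.reset(TensorDict({"_reset": mask}, batch_size=env.batch_size))

assert "_reset" not in td.keys()   # the flag is deleted after use
assert not td["done"][mask].any()  # flagged dims must not be done after reset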
14 changes: 12 additions & 2 deletions torchrl/envs/libs/vmas.py
@@ -2,7 +2,6 @@

import torch
from tensordict.tensordict import TensorDict, TensorDictBase

from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
@@ -203,7 +202,18 @@ def _set_seed(self, seed: Optional[int]):
def _reset(
self, tensordict: Optional[TensorDictBase] = None, **kwargs
) -> TensorDictBase:
obs, infos = self._env.reset(return_info=True)
if tensordict is not None and "_reset" in tensordict.keys():
envs_to_reset = tensordict.get("_reset").any(dim=0)
for env_index, to_reset in enumerate(envs_to_reset):
if to_reset:
self._env.reset_at(env_index)
obs = []
infos = []
for agent in self.agents:
obs.append(self.scenario.observation(agent))
infos.append(self.scenario.info(agent))
else:
obs, infos = self._env.reset(return_info=True)

agent_tds = []
for i in range(self.n_agents):
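The wrapper collapses the per-agent leading dimension of "_reset" before resetting individual vectorized environments. The same loop in isolation (a sketch; partial_reset is a hypothetical helper and vmas_env stands for the wrapped vmas simulator):

# Standalone sketch of the per-index reset loop above ("vmas_env" is an
# assumed handle to the wrapped vmas simulator).
def partial_reset(vmas_env, _reset_mask):
    # "_reset" carries the agent dimension first: a vectorized env is
    # reset as soon as any of its agents requests it.
    envs_to_reset = _reset_mask.any(dim=0)
    for env_index, to_reset in enumerate(envs_to_reset):
        if to_reset:
            vmas_env.reset_at(env_index)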