From 6c002d7492defda3df8af08f069b959fad2a9160 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 13 Jan 2023 12:33:34 +0100 Subject: [PATCH 01/50] Mask in collectors --- torchrl/collectors/collectors.py | 34 +++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 193d60dba76..75bfa307967 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -292,6 +292,11 @@ class SyncDataCollector(_DataCollector): updated. This feature should be used cautiously: if the same tensordict is added to a replay buffer for instance, the whole content of the buffer will be identical. Default is False. + env_batch_size_mask (Tuple[bool, ...], optional): a list of bools of the same length as env.batch_size, + with a value of True it indicates to consider the corresponding dimension of env.batch_size as part of the + batch of environemnts used to collect frames, with a value of False it indicates NOT to consider that dimension + as part of the batch of environemnts used to collect frames (used for agent dimension in multi-agent settings). + Default is None (corresponding to all True). Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -363,6 +368,7 @@ def __init__( init_with_lag: bool = False, return_same_td: bool = False, reset_when_done: bool = True, + env_batch_size_mask: Optional[Tuple[bool, ...]] = None, ): self.closed = True if seed is not None: @@ -399,7 +405,24 @@ def __init__( self.env: EnvBase = env.to(self.passing_device) self.closed = False self.reset_when_done = reset_when_done - self.n_env = self.env.numel() + + if env_batch_size_mask is not None and len(env_batch_size_mask) != len( + self.env.batch_size + ): + raise RuntimeError( + f"Batch size mask and env batch size have different lengths: mask={env_batch_size_mask}, env.batch_size={self.env.batch_size}" + ) + self.env_batch_size_masked = ( + env.batch_size + if env_batch_size_mask is None + else torch.Size( + [ + (dim if is_in else 1) + for dim, is_in in zip(env.batch_size, env_batch_size_mask) + ] + ) + ) + self.n_env = prod(self.env_batch_size_masked) (self.policy, self.device, self.get_weights_fn,) = self._get_policy_and_device( policy=policy, @@ -641,8 +664,13 @@ def rollout(self) -> TensorDictBase: self._tensordict.update(self.env.reset(), inplace=True) self._tensordict.fill_("step_count", 0) - n = self.env.batch_size[0] if len(self.env.batch_size) else 1 - self._tensordict.set("traj_ids", torch.arange(n).view(self.env.batch_size[:1])) + n = max(1, self.env_batch_size_masked.numel()) + self._tensordict.set( + "traj_ids", + torch.arange(n) + .view(self.env_batch_size_masked) + .expand(self.env.batch_size), + ) with set_exploration_mode(self.exploration_mode): for j in range(self.frames_per_batch): From 7245b587aa7876557835a2885d74fd3bd3858ded Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 15 Jan 2023 16:34:31 +0100 Subject: [PATCH 02/50] Works without traj_split --- torchrl/collectors/collectors.py | 45 +++++++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 75bfa307967..90a387aaea9 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -292,10 +292,10 @@ class SyncDataCollector(_DataCollector): updated. 
This feature should be used cautiously: if the same tensordict is added to a replay buffer for instance, the whole content of the buffer will be identical. Default is False. - env_batch_size_mask (Tuple[bool, ...], optional): a list of bools of the same length as env.batch_size, + env_batch_size_mask (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size, with a value of True it indicates to consider the corresponding dimension of env.batch_size as part of the - batch of environemnts used to collect frames, with a value of False it indicates NOT to consider that dimension - as part of the batch of environemnts used to collect frames (used for agent dimension in multi-agent settings). + batch of environments used to collect frames. A value of False it indicates NOT to consider that dimension + as part of the batch of environments used to collect frames (used for agent dimension in multi-agent settings). Default is None (corresponding to all True). Examples: @@ -368,7 +368,7 @@ def __init__( init_with_lag: bool = False, return_same_td: bool = False, reset_when_done: bool = True, - env_batch_size_mask: Optional[Tuple[bool, ...]] = None, + env_batch_size_mask: Optional[Sequence[bool]] = None, ): self.closed = True if seed is not None: @@ -418,7 +418,7 @@ def __init__( else torch.Size( [ (dim if is_in else 1) - for dim, is_in in zip(env.batch_size, env_batch_size_mask) + for dim, is_in in zip(self.env.batch_size, env_batch_size_mask) ] ) ) @@ -848,7 +848,7 @@ class _MultiDataCollector(_DataCollector): init_with_lag (bool, optional): if True, the first trajectory will be truncated earlier at a random step. This is helpful to desynchronize the environments, such that steps do no match in all collected rollouts. default = True - exploration_mode (str, optional): interaction mode to be used when collecting data. Must be one of "random", + exploration_mode (str, optional): interaction mode to be used when collecting data. Must be one of "random", "mode" or "mean". default = "random" reset_when_done (bool, optional): if True, the contained environment will be reset @@ -858,6 +858,12 @@ class _MultiDataCollector(_DataCollector): in other words, if the env is a multi-agent env, all agents will be reset once one of them is done. Defaults to `True`. + env_batch_size_mask ((list of) Sequence[bool], optional): can be a list of sequences, one for each environment, or + one sequence, shared by all environments. Each sequence contains bool values and is of the same length as env.batch_size. + A value of True it indicates to consider the corresponding dimension of env.batch_size as part of the batch of environments + used to collect frames, with a value of False it indicates NOT to consider that dimension as part of the + batch of environments used to collect frames (used for agent dimension in multi-agent settings). + Default is None (corresponding to all True). """ @@ -886,6 +892,8 @@ def __init__( init_with_lag: bool = False, exploration_mode: str = DEFAULT_EXPLORATION_MODE, reset_when_done: bool = True, + env_batch_size_mask: + Union[Sequence[Sequence[bool]], Sequence[bool], None] = None, ): self.closed = True self.create_env_fn = create_env_fn @@ -976,6 +984,21 @@ def device_err_msg(device_name, devices_list): f"Found {type(passing_devices)} instead." 
) + if env_batch_size_mask is not None: + if isinstance(env_batch_size_mask[0], Sequence): + if len(env_batch_size_mask) != self.num_workers: + raise RuntimeError( + f"Number of batch_size masks provided {len(env_batch_size_mask)} does not match" + f" number of collector workers {self.num_workers}" + ) + self.env_batch_size_masks = list(env_batch_size_mask) + else: + self.env_batch_size_masks = [ + env_batch_size_mask for _ in range(self.num_workers) + ] + else: + self.env_batch_size_masks = [None for _ in range(self.num_workers)] + self.total_frames = total_frames if total_frames > 0 else float("inf") self.reset_at_each_iter = reset_at_each_iter self.postprocs = postproc @@ -1051,6 +1074,7 @@ def _run_processes(self) -> None: "exploration_mode": self.exploration_mode, "reset_when_done": self.reset_when_done, "idx": i, + "env_batch_size_mask": self.env_batch_size_masks[i], } proc = mp.Process(target=_main_async_collector, kwargs=kwargs) # proc.daemon can't be set as daemonic processes may be launched by the process itself @@ -1551,6 +1575,11 @@ class aSyncDataCollector(MultiaSyncDataCollector): init_with_lag (bool, optional): if True, the first trajectory will be truncated earlier at a random step. This is helpful to desynchronize the environments, such that steps do no match in all collected rollouts. default = True + env_batch_size_mask (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size, + with a value of True it indicates to consider the corresponding dimension of env.batch_size as part of the + batch of environments used to collect frames. A value of False it indicates NOT to consider that dimension + as part of the batch of environments used to collect frames (used for agent dimension in multi-agent settings). + Default is None (corresponding to all True). 
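+        Examples:
+            >>> # A minimal, hypothetical sketch of the new argument. ``make_multiagent_env`` is a
+            >>> # placeholder for an env constructor whose batch_size is [n_envs, n_agents]; it is
+            >>> # not defined in this PR, and ``policy`` stands for any compatible policy module.
+            >>> collector = aSyncDataCollector(
+            ...     create_env_fn=make_multiagent_env,
+            ...     policy=policy,
+            ...     total_frames=1000,
+            ...     frames_per_batch=100,
+            ...     env_batch_size_mask=[True, False],  # n_envs is a collection dim, n_agents is not
+            ... )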
""" @@ -1574,6 +1603,7 @@ def __init__( device: Optional[Union[int, str, torch.device]] = None, passing_device: Optional[Union[int, str, torch.device]] = None, seed: Optional[int] = None, + env_batch_size_mask: Optional[Sequence[bool]] = None, pin_memory: bool = False, **kwargs, ): @@ -1592,6 +1622,7 @@ def __init__( passing_devices=[passing_device] if passing_device is not None else None, seed=seed, pin_memory=pin_memory, + env_batch_size_mask=env_batch_size_mask, **kwargs, ) @@ -1615,6 +1646,7 @@ def _main_async_collector( init_with_lag: bool = False, exploration_mode: str = DEFAULT_EXPLORATION_MODE, reset_when_done: bool = True, + env_batch_size_mask: Optional[Sequence[bool]] = None, verbose: bool = False, ) -> None: pipe_parent.close() @@ -1639,6 +1671,7 @@ def _main_async_collector( exploration_mode=exploration_mode, reset_when_done=reset_when_done, return_same_td=True, + env_batch_size_mask=env_batch_size_mask, ) if verbose: print("Sync data collector created") From a772ef549237db254b85fd602b0e9000dff980bc Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 15 Jan 2023 18:44:20 +0100 Subject: [PATCH 03/50] Temp --- torchrl/collectors/collectors.py | 67 +++++++++----------------------- torchrl/collectors/utils.py | 25 +++++++++++- 2 files changed, 43 insertions(+), 49 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 90a387aaea9..2b369392114 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -22,8 +22,9 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torch import multiprocessing as mp from torch.utils.data import IterableDataset + from torchrl._utils import _check_for_faulty_process, prod -from torchrl.collectors.utils import split_trajectories +from torchrl.collectors.utils import split_trajectories, numel_with_mask, get_batch_size_masked from torchrl.data import TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import EnvBase @@ -406,23 +407,10 @@ def __init__( self.closed = False self.reset_when_done = reset_when_done - if env_batch_size_mask is not None and len(env_batch_size_mask) != len( - self.env.batch_size - ): - raise RuntimeError( - f"Batch size mask and env batch size have different lengths: mask={env_batch_size_mask}, env.batch_size={self.env.batch_size}" - ) - self.env_batch_size_masked = ( - env.batch_size - if env_batch_size_mask is None - else torch.Size( - [ - (dim if is_in else 1) - for dim, is_in in zip(self.env.batch_size, env_batch_size_mask) - ] - ) - ) - self.n_env = prod(self.env_batch_size_masked) + self.env_batch_size_mask = env_batch_size_mask + self.out_batch_size_mask = None if env_batch_size_mask is None else list(env_batch_size_mask) + [True] + self.env_batch_size_masked = get_batch_size_masked(self.env.batch_size, self.env_batch_size_mask) + self.n_env = max(1, self.env_batch_size_masked.numel()) (self.policy, self.device, self.get_weights_fn,) = self._get_policy_and_device( policy=policy, @@ -557,7 +545,7 @@ def iterator(self) -> Iterator[TensorDictBase]: i += 1 self._iter = i tensordict_out = self.rollout() - self._frames += tensordict_out.numel() + self._frames += numel_with_mask(tensordict_out.batch_size, self.out_batch_size_mask) if self._frames >= total_frames: self.env.close() @@ -664,10 +652,9 @@ def rollout(self) -> TensorDictBase: self._tensordict.update(self.env.reset(), inplace=True) self._tensordict.fill_("step_count", 0) - n = max(1, self.env_batch_size_masked.numel()) 
self._tensordict.set( "traj_ids", - torch.arange(n) + torch.arange(self.n_env) .view(self.env_batch_size_masked) .expand(self.env.batch_size), ) @@ -695,7 +682,6 @@ def rollout(self) -> TensorDictBase: self._tensordict_out[..., j] = self._tensordict if is_shared: self._tensordict_out.share_memory_() - self._reset_if_necessary() self._tensordict.update(step_mdp(self._tensordict), inplace=True) @@ -858,11 +844,10 @@ class _MultiDataCollector(_DataCollector): in other words, if the env is a multi-agent env, all agents will be reset once one of them is done. Defaults to `True`. - env_batch_size_mask ((list of) Sequence[bool], optional): can be a list of sequences, one for each environment, or - one sequence, shared by all environments. Each sequence contains bool values and is of the same length as env.batch_size. - A value of True it indicates to consider the corresponding dimension of env.batch_size as part of the batch of environments - used to collect frames, with a value of False it indicates NOT to consider that dimension as part of the - batch of environments used to collect frames (used for agent dimension in multi-agent settings). + env_batch_size_mask (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size, + with a value of True it indicates to consider the corresponding dimension of env.batch_size as part of the + batch of environments used to collect frames. A value of False it indicates NOT to consider that dimension + as part of the batch of environments used to collect frames (used for agent dimension in multi-agent settings). Default is None (corresponding to all True). """ @@ -892,8 +877,7 @@ def __init__( init_with_lag: bool = False, exploration_mode: str = DEFAULT_EXPLORATION_MODE, reset_when_done: bool = True, - env_batch_size_mask: - Union[Sequence[Sequence[bool]], Sequence[bool], None] = None, + env_batch_size_mask: Optional[Sequence[bool]] = None, ): self.closed = True self.create_env_fn = create_env_fn @@ -984,21 +968,8 @@ def device_err_msg(device_name, devices_list): f"Found {type(passing_devices)} instead." 
) - if env_batch_size_mask is not None: - if isinstance(env_batch_size_mask[0], Sequence): - if len(env_batch_size_mask) != self.num_workers: - raise RuntimeError( - f"Number of batch_size masks provided {len(env_batch_size_mask)} does not match" - f" number of collector workers {self.num_workers}" - ) - self.env_batch_size_masks = list(env_batch_size_mask) - else: - self.env_batch_size_masks = [ - env_batch_size_mask for _ in range(self.num_workers) - ] - else: - self.env_batch_size_masks = [None for _ in range(self.num_workers)] - + self.env_batch_size_mask = env_batch_size_mask + self.out_batch_size_mask = None if env_batch_size_mask is None else list(env_batch_size_mask) + [True] self.total_frames = total_frames if total_frames > 0 else float("inf") self.reset_at_each_iter = reset_at_each_iter self.postprocs = postproc @@ -1074,7 +1045,7 @@ def _run_processes(self) -> None: "exploration_mode": self.exploration_mode, "reset_when_done": self.reset_when_done, "idx": i, - "env_batch_size_mask": self.env_batch_size_masks[i], + "env_batch_size_mask": self.env_batch_size_mask, } proc = mp.Process(target=_main_async_collector, kwargs=kwargs) # proc.daemon can't be set as daemonic processes may be launched by the process itself @@ -1311,7 +1282,7 @@ def iterator(self) -> Iterator[TensorDictBase]: else: idx = new_data workers_frames[idx] = ( - workers_frames[idx] + out_tensordicts_shared[idx].numel() + workers_frames[idx] + numel_with_mask(out_tensordicts_shared[idx].batch_size, self.out_batch_size_mask) ) if workers_frames[idx] >= self.total_frames: @@ -1349,7 +1320,7 @@ def iterator(self) -> Iterator[TensorDictBase]: frames += out.get("mask").sum().item() else: out = out_buffer.clone() - frames += prod(out.shape) + frames += numel_with_mask(out.batch_size, self.out_batch_size_mask) if self.postprocs: self.postprocs = self.postprocs.to(out.device) out = self.postprocs(out) @@ -1469,7 +1440,7 @@ def iterator(self) -> Iterator[TensorDictBase]: i += 1 idx, j, out = self._get_from_queue() - worker_frames = out.numel() + worker_frames = numel_with_mask(out.batch_size, self.out_batch_size_mask) if self.split_trajs: out = split_trajectories(out) self._frames += worker_frames diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 6f4efe4d96f..de700688be5 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -3,7 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
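+# A quick illustration of the masking helpers added below (hypothetical shapes, not part of the
+# original patch): for an env with batch_size [32, 4] and mask [True, False] (e.g. 32 envs x 4 agents):
+#     get_batch_size_masked(torch.Size([32, 4]), [True, False])  -> torch.Size([32, 1])
+#     numel_with_mask(torch.Size([32, 4]), [True, False])        -> 32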
-from typing import Callable +from typing import Callable, Optional +from typing import Sequence import torch from tensordict.tensordict import pad, TensorDictBase @@ -68,3 +69,25 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: ).contiguous() td = td.unflatten_keys(sep) return td + + +def numel_with_mask(batch_size: torch.Size, mask: Optional[Sequence[bool]] = None): + return max(1, get_batch_size_masked(batch_size, mask).numel()) + + +def get_batch_size_masked(batch_size: torch.Size, mask: Optional[Sequence[bool]] = None): + if mask is None: + return batch_size + if mask is not None and len(mask) != len(batch_size): + raise RuntimeError( + f"Batch size mask and env batch size have different lengths: mask={mask}, env.batch_size={batch_size}" + ) + return torch.Size( + [ + (dim if is_in else 1) + for dim, is_in in zip( + batch_size, + mask, + ) + ] + ) From 67c87d6b7c14f7fd0f16e67a35608abcdbca4158 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 15 Jan 2023 18:55:30 +0100 Subject: [PATCH 04/50] step count --- torchrl/collectors/collectors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 2b369392114..0527e429099 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -670,7 +670,7 @@ def rollout(self) -> TensorDictBase: self._tensordict = self.env.step(self._tensordict) step_count = self._tensordict.get("step_count") - step_count += 1 + self._tensordict.set_("step_count", step_count + 1) # we must clone all the values, since the step / traj_id updates are done in-place try: self._tensordict_out[..., j] = self._tensordict From bb76d9e261e8267309eb727ad154cb79cdbf9d96 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 15 Jan 2023 19:08:52 +0100 Subject: [PATCH 05/50] clone the expand --- torchrl/collectors/collectors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 0527e429099..90abd94a84d 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -656,7 +656,7 @@ def rollout(self) -> TensorDictBase: "traj_ids", torch.arange(self.n_env) .view(self.env_batch_size_masked) - .expand(self.env.batch_size), + .expand(self.env.batch_size).clone(), ) with set_exploration_mode(self.exploration_mode): From e1c0df050424214c65aac35950bafa1afe1a463b Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 15 Jan 2023 19:10:57 +0100 Subject: [PATCH 06/50] link and docs --- torchrl/collectors/collectors.py | 29 +++++++++++++++++++++-------- torchrl/collectors/utils.py | 9 ++++++--- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 90abd94a84d..68aa476447a 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -24,7 +24,11 @@ from torch.utils.data import IterableDataset from torchrl._utils import _check_for_faulty_process, prod -from torchrl.collectors.utils import split_trajectories, numel_with_mask, get_batch_size_masked +from torchrl.collectors.utils import ( + get_batch_size_masked, + numel_with_mask, + split_trajectories, +) from torchrl.data import TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import EnvBase @@ -408,8 +412,12 @@ def __init__( self.reset_when_done = reset_when_done self.env_batch_size_mask = env_batch_size_mask - 
self.out_batch_size_mask = None if env_batch_size_mask is None else list(env_batch_size_mask) + [True] - self.env_batch_size_masked = get_batch_size_masked(self.env.batch_size, self.env_batch_size_mask) + self.out_batch_size_mask = ( + None if env_batch_size_mask is None else list(env_batch_size_mask) + [True] + ) + self.env_batch_size_masked = get_batch_size_masked( + self.env.batch_size, self.env_batch_size_mask + ) self.n_env = max(1, self.env_batch_size_masked.numel()) (self.policy, self.device, self.get_weights_fn,) = self._get_policy_and_device( @@ -545,7 +553,9 @@ def iterator(self) -> Iterator[TensorDictBase]: i += 1 self._iter = i tensordict_out = self.rollout() - self._frames += numel_with_mask(tensordict_out.batch_size, self.out_batch_size_mask) + self._frames += numel_with_mask( + tensordict_out.batch_size, self.out_batch_size_mask + ) if self._frames >= total_frames: self.env.close() @@ -656,7 +666,8 @@ def rollout(self) -> TensorDictBase: "traj_ids", torch.arange(self.n_env) .view(self.env_batch_size_masked) - .expand(self.env.batch_size).clone(), + .expand(self.env.batch_size) + .clone(), ) with set_exploration_mode(self.exploration_mode): @@ -969,7 +980,9 @@ def device_err_msg(device_name, devices_list): ) self.env_batch_size_mask = env_batch_size_mask - self.out_batch_size_mask = None if env_batch_size_mask is None else list(env_batch_size_mask) + [True] + self.out_batch_size_mask = ( + None if env_batch_size_mask is None else list(env_batch_size_mask) + [True] + ) self.total_frames = total_frames if total_frames > 0 else float("inf") self.reset_at_each_iter = reset_at_each_iter self.postprocs = postproc @@ -1281,8 +1294,8 @@ def iterator(self) -> Iterator[TensorDictBase]: out_tensordicts_shared[idx] = data else: idx = new_data - workers_frames[idx] = ( - workers_frames[idx] + numel_with_mask(out_tensordicts_shared[idx].batch_size, self.out_batch_size_mask) + workers_frames[idx] = workers_frames[idx] + numel_with_mask( + out_tensordicts_shared[idx].batch_size, self.out_batch_size_mask ) if workers_frames[idx] >= self.total_frames: diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index de700688be5..40533d415d7 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -3,8 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Callable, Optional -from typing import Sequence +from typing import Callable, Optional, Sequence import torch from tensordict.tensordict import pad, TensorDictBase @@ -72,10 +71,14 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: def numel_with_mask(batch_size: torch.Size, mask: Optional[Sequence[bool]] = None): + """Performs numel() with a given mask.""" return max(1, get_batch_size_masked(batch_size, mask).numel()) -def get_batch_size_masked(batch_size: torch.Size, mask: Optional[Sequence[bool]] = None): +def get_batch_size_masked( + batch_size: torch.Size, mask: Optional[Sequence[bool]] = None +): + """Returns a size with the masked dimensions equal to 1.""" if mask is None: return batch_size if mask is not None and len(mask) != len(batch_size): From 630f4dec43ccbcce4c7dab1cfde92b64e18c5aeb Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 20 Jan 2023 09:34:01 +0000 Subject: [PATCH 07/50] added errors for frames overflow --- torchrl/collectors/collectors.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 68aa476447a..29e895a3ccf 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -22,7 +22,6 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torch import multiprocessing as mp from torch.utils.data import IterableDataset - from torchrl._utils import _check_for_faulty_process, prod from torchrl.collectors.utils import ( get_batch_size_masked, @@ -436,6 +435,10 @@ def __init__( if self.postproc is not None: self.postproc.to(self.passing_device) self.max_frames_per_traj = max_frames_per_traj + if frames_per_batch % self.n_env != 0: + raise RuntimeError( + f"frames_per_batch {frames_per_batch} is not exactly divisible by the number of batched environments {self.n_env}, this is currently not allowed" + ) self.frames_per_batch = -(-frames_per_batch // self.n_env) self.pin_memory = pin_memory self.exploration_mode = ( @@ -1259,6 +1262,10 @@ class MultiSyncDataCollector(_MultiDataCollector): @property def frames_per_batch_worker(self): + if self.frames_per_batch % self.num_workers != 0: + raise RuntimeError( + f"frames_per_batch {self.frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers}, this is currently not allowed" + ) return -(-self.frames_per_batch // self.num_workers) @property From cc99b2ab07b0e6bbd14645612292dbae87411e51 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 20 Jan 2023 11:25:24 +0000 Subject: [PATCH 08/50] now all batch dimensions are squashed in first one --- torchrl/collectors/collectors.py | 70 ++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 27 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 28aed8f0f0e..2fba5dfc000 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -22,7 +22,6 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torch import multiprocessing as mp from torch.utils.data import IterableDataset - from torchrl._utils import _check_for_faulty_process, prod from torchrl.collectors.utils import ( get_batch_size_masked, @@ -297,7 +296,7 @@ class SyncDataCollector(_DataCollector): updated. This feature should be used cautiously: if the same tensordict is added to a replay buffer for instance, the whole content of the buffer will be identical. Default is False. 
- env_batch_size_mask (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size, + mask_env_batch_size (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size, with a value of True it indicates to consider the corresponding dimension of env.batch_size as part of the batch of environments used to collect frames. A value of False it indicates NOT to consider that dimension as part of the batch of environments used to collect frames (used for agent dimension in multi-agent settings). @@ -373,7 +372,7 @@ def __init__( init_with_lag: bool = False, return_same_td: bool = False, reset_when_done: bool = True, - env_batch_size_mask: Optional[Sequence[bool]] = None, + mask_env_batch_size: Optional[Sequence[bool]] = None, ): self.closed = True if seed is not None: @@ -411,12 +410,23 @@ def __init__( self.closed = False self.reset_when_done = reset_when_done - self.env_batch_size_mask = env_batch_size_mask - self.out_batch_size_mask = ( - None if env_batch_size_mask is None else list(env_batch_size_mask) + [True] - ) + # Batch sizes and masks + if mask_env_batch_size is None: + mask_env_batch_size = [True for _ in self.env.batch_size] + else: + mask_env_batch_size = list(mask_env_batch_size) + self.mask_env_batch_size = mask_env_batch_size + self.mask_out_batch_size = mask_env_batch_size + [True] + self.permute_out_batch_size = [ + i for i, is_batch in enumerate(self.mask_out_batch_size) if is_batch + ] + [i for i, is_batch in enumerate(self.mask_out_batch_size) if not is_batch] + self.env_batch_size_umasked = [ + env.batch_size[i] + for i, is_batch in enumerate(self.mask_env_batch_size) + if not is_batch + ] self.env_batch_size_masked = get_batch_size_masked( - self.env.batch_size, self.env_batch_size_mask + self.env.batch_size, self.mask_env_batch_size ) self.n_env = max(1, self.env_batch_size_masked.numel()) @@ -550,7 +560,6 @@ def iterator(self) -> Iterator[TensorDictBase]: Yields: TensorDictBase objects containing (chunks of) trajectories """ - total_frames = self.total_frames i = -1 self._frames = 0 while True: @@ -558,11 +567,21 @@ def iterator(self) -> Iterator[TensorDictBase]: self._iter = i tensordict_out = self.rollout() self._frames += numel_with_mask( - tensordict_out.batch_size, self.out_batch_size_mask + tensordict_out.batch_size, self.mask_out_batch_size ) - if self._frames >= total_frames: + if self._frames >= self.total_frames: self.env.close() + # Bring all batch dimensions to the front (only performs computation if it is not already the case) + tensordict_out = tensordict_out.permute(self.permute_out_batch_size) + # Flatten all batch dimensions into first one and leave unmasked dimensions untouched + if len(self.env_batch_size_umasked) > 0: + tensordict_out = tensordict_out.reshape( + self.frames_per_batch * self.n_env, *self.env_batch_size_umasked + ) + else: + tensordict_out = tensordict_out.view(-1).to_tensordict() + if self.split_trajs: tensordict_out = split_trajectories(tensordict_out) if self.postproc is not None: @@ -863,7 +882,7 @@ class _MultiDataCollector(_DataCollector): in other words, if the env is a multi-agent env, all agents will be reset once one of them is done. Defaults to `True`. 
- env_batch_size_mask (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size, + mask_env_batch_size (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size, with a value of True it indicates to consider the corresponding dimension of env.batch_size as part of the batch of environments used to collect frames. A value of False it indicates NOT to consider that dimension as part of the batch of environments used to collect frames (used for agent dimension in multi-agent settings). @@ -896,7 +915,7 @@ def __init__( init_with_lag: bool = False, exploration_mode: str = DEFAULT_EXPLORATION_MODE, reset_when_done: bool = True, - env_batch_size_mask: Optional[Sequence[bool]] = None, + mask_env_batch_size: Optional[Sequence[bool]] = None, ): self.closed = True self.create_env_fn = create_env_fn @@ -987,10 +1006,7 @@ def device_err_msg(device_name, devices_list): f"Found {type(passing_devices)} instead." ) - self.env_batch_size_mask = env_batch_size_mask - self.out_batch_size_mask = ( - None if env_batch_size_mask is None else list(env_batch_size_mask) + [True] - ) + self.mask_env_batch_size = mask_env_batch_size self.total_frames = total_frames if total_frames > 0 else float("inf") self.reset_at_each_iter = reset_at_each_iter self.postprocs = postproc @@ -1066,7 +1082,7 @@ def _run_processes(self) -> None: "exploration_mode": self.exploration_mode, "reset_when_done": self.reset_when_done, "idx": i, - "env_batch_size_mask": self.env_batch_size_mask, + "mask_env_batch_size": self.mask_env_batch_size, } proc = mp.Process(target=_main_async_collector, kwargs=kwargs) # proc.daemon can't be set as daemonic processes may be launched by the process itself @@ -1306,8 +1322,8 @@ def iterator(self) -> Iterator[TensorDictBase]: out_tensordicts_shared[idx] = data else: idx = new_data - workers_frames[idx] = workers_frames[idx] + numel_with_mask( - out_tensordicts_shared[idx].batch_size, self.out_batch_size_mask + workers_frames[idx] = ( + workers_frames[idx] + out_tensordicts_shared[idx].batch_size[0] ) if workers_frames[idx] >= self.total_frames: @@ -1345,7 +1361,7 @@ def iterator(self) -> Iterator[TensorDictBase]: frames += out.get("mask").sum().item() else: out = out_buffer.clone() - frames += numel_with_mask(out.batch_size, self.out_batch_size_mask) + frames += out.batch_size[0] if self.postprocs: self.postprocs = self.postprocs.to(out.device) out = self.postprocs(out) @@ -1465,7 +1481,7 @@ def iterator(self) -> Iterator[TensorDictBase]: i += 1 idx, j, out = self._get_from_queue() - worker_frames = numel_with_mask(out.batch_size, self.out_batch_size_mask) + worker_frames = out.batch_size[0] if self.split_trajs: out = split_trajectories(out) self._frames += worker_frames @@ -1571,7 +1587,7 @@ class aSyncDataCollector(MultiaSyncDataCollector): init_with_lag (bool, optional): if True, the first trajectory will be truncated earlier at a random step. This is helpful to desynchronize the environments, such that steps do no match in all collected rollouts. default = True - env_batch_size_mask (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size, + mask_env_batch_size (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size, with a value of True it indicates to consider the corresponding dimension of env.batch_size as part of the batch of environments used to collect frames. 
A value of False it indicates NOT to consider that dimension as part of the batch of environments used to collect frames (used for agent dimension in multi-agent settings). @@ -1599,7 +1615,7 @@ def __init__( device: Optional[Union[int, str, torch.device]] = None, passing_device: Optional[Union[int, str, torch.device]] = None, seed: Optional[int] = None, - env_batch_size_mask: Optional[Sequence[bool]] = None, + mask_env_batch_size: Optional[Sequence[bool]] = None, pin_memory: bool = False, **kwargs, ): @@ -1618,7 +1634,7 @@ def __init__( passing_devices=[passing_device] if passing_device is not None else None, seed=seed, pin_memory=pin_memory, - env_batch_size_mask=env_batch_size_mask, + mask_env_batch_size=mask_env_batch_size, **kwargs, ) @@ -1642,7 +1658,7 @@ def _main_async_collector( init_with_lag: bool = False, exploration_mode: str = DEFAULT_EXPLORATION_MODE, reset_when_done: bool = True, - env_batch_size_mask: Optional[Sequence[bool]] = None, + mask_env_batch_size: Optional[Sequence[bool]] = None, verbose: bool = False, ) -> None: pipe_parent.close() @@ -1667,7 +1683,7 @@ def _main_async_collector( exploration_mode=exploration_mode, reset_when_done=reset_when_done, return_same_td=True, - env_batch_size_mask=env_batch_size_mask, + mask_env_batch_size=mask_env_batch_size, ) if verbose: print("Sync data collector created") From 352c9c01da06311e2898d08b76b2dedd1517b2d4 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 20 Jan 2023 11:28:17 +0000 Subject: [PATCH 09/50] amend --- torchrl/collectors/collectors.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 2fba5dfc000..58b919bec12 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -25,7 +25,6 @@ from torchrl._utils import _check_for_faulty_process, prod from torchrl.collectors.utils import ( get_batch_size_masked, - numel_with_mask, split_trajectories, ) from torchrl.data import TensorSpec @@ -566,11 +565,6 @@ def iterator(self) -> Iterator[TensorDictBase]: i += 1 self._iter = i tensordict_out = self.rollout() - self._frames += numel_with_mask( - tensordict_out.batch_size, self.mask_out_batch_size - ) - if self._frames >= self.total_frames: - self.env.close() # Bring all batch dimensions to the front (only performs computation if it is not already the case) tensordict_out = tensordict_out.permute(self.permute_out_batch_size) @@ -582,6 +576,10 @@ def iterator(self) -> Iterator[TensorDictBase]: else: tensordict_out = tensordict_out.view(-1).to_tensordict() + self._frames += tensordict_out.batch_size[0] + if self._frames >= self.total_frames: + self.env.close() + if self.split_trajs: tensordict_out = split_trajectories(tensordict_out) if self.postproc is not None: From 14f8e69eaa032dff130abed64c41b09ec25be478 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 20 Jan 2023 11:29:48 +0000 Subject: [PATCH 10/50] refector --- torchrl/collectors/collectors.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 58b919bec12..df55b6f9d2f 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -419,7 +419,7 @@ def __init__( self.permute_out_batch_size = [ i for i, is_batch in enumerate(self.mask_out_batch_size) if is_batch ] + [i for i, is_batch in enumerate(self.mask_out_batch_size) if not is_batch] - self.env_batch_size_umasked = [ + self.env_batch_size_unmasked = [ 
env.batch_size[i] for i, is_batch in enumerate(self.mask_env_batch_size) if not is_batch @@ -569,9 +569,9 @@ def iterator(self) -> Iterator[TensorDictBase]: # Bring all batch dimensions to the front (only performs computation if it is not already the case) tensordict_out = tensordict_out.permute(self.permute_out_batch_size) # Flatten all batch dimensions into first one and leave unmasked dimensions untouched - if len(self.env_batch_size_umasked) > 0: + if len(self.env_batch_size_unmasked) > 0: tensordict_out = tensordict_out.reshape( - self.frames_per_batch * self.n_env, *self.env_batch_size_umasked + self.frames_per_batch * self.n_env, *self.env_batch_size_unmasked ) else: tensordict_out = tensordict_out.view(-1).to_tensordict() From 571851b913cc82eb327ea88e4bf2d643932c034c Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 20 Jan 2023 11:32:40 +0000 Subject: [PATCH 11/50] removed numel with mask --- torchrl/collectors/utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 34ea4a43524..e840e789010 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -69,12 +69,6 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: td = td.unflatten_keys(sep) return td - -def numel_with_mask(batch_size: torch.Size, mask: Optional[Sequence[bool]] = None): - """Performs numel() with a given mask.""" - return max(1, get_batch_size_masked(batch_size, mask).numel()) - - def get_batch_size_masked( batch_size: torch.Size, mask: Optional[Sequence[bool]] = None ): From e22d1ed6aff06d41a00a9d9bede7eb55ba872bc7 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 20 Jan 2023 15:13:19 +0000 Subject: [PATCH 12/50] fixed traj_ids --- torchrl/collectors/collectors.py | 34 ++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index df55b6f9d2f..907b8ca3aa1 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -23,10 +23,7 @@ from torch import multiprocessing as mp from torch.utils.data import IterableDataset from torchrl._utils import _check_for_faulty_process, prod -from torchrl.collectors.utils import ( - get_batch_size_masked, - split_trajectories, -) +from torchrl.collectors.utils import get_batch_size_masked, split_trajectories from torchrl.data import TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import EnvBase @@ -416,9 +413,13 @@ def __init__( mask_env_batch_size = list(mask_env_batch_size) self.mask_env_batch_size = mask_env_batch_size self.mask_out_batch_size = mask_env_batch_size + [True] + + self.env_batch_size_unmasked_indeces = [ + i for i, is_batch in enumerate(self.mask_out_batch_size) if not is_batch + ] self.permute_out_batch_size = [ i for i, is_batch in enumerate(self.mask_out_batch_size) if is_batch - ] + [i for i, is_batch in enumerate(self.mask_out_batch_size) if not is_batch] + ] + self.env_batch_size_unmasked_indeces self.env_batch_size_unmasked = [ env.batch_size[i] for i, is_batch in enumerate(self.mask_env_batch_size) @@ -429,6 +430,16 @@ def __init__( ) self.n_env = max(1, self.env_batch_size_masked.numel()) + self.mask_tensor = torch.ones( + *self.env.batch_size, + dtype=torch.bool, + device=self.env.device, + ) + for dim in self.env_batch_size_unmasked_indeces: + self.mask_tensor.index_fill_( + dim, torch.arange(1, self.env.batch_size[dim]), 0 + ) + (self.policy, 
self.device, self.get_weights_fn,) = self._get_policy_and_device( policy=policy, device=device, @@ -668,10 +679,21 @@ def _reset_if_necessary(self) -> None: raise RuntimeError( f"Env {self.env} was done after reset on specified '_reset' dimensions. This is (currently) not allowed." ) + steps[done_or_terminated] = 0 + + traj_ids = traj_ids[self.mask_tensor] + done_or_terminated = done_or_terminated[self.mask_tensor] + traj_ids[done_or_terminated] = traj_ids.max() + torch.arange( 1, done_or_terminated.sum() + 1, device=traj_ids.device ) - steps[done_or_terminated] = 0 + + traj_ids = ( + traj_ids.view(self.env_batch_size_masked) + .expand(self.env.batch_size) + .clone() + ) + self._tensordict.set_("traj_ids", traj_ids) # no ops if they already match self._tensordict.set_("step_count", steps) From 5fa83ffbd6fd2160f27bc48547f7d89f84761aa5 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 20 Jan 2023 15:24:37 +0000 Subject: [PATCH 13/50] amend --- torchrl/collectors/collectors.py | 33 ++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 907b8ca3aa1..31a73744ee6 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -9,6 +9,7 @@ import os import queue import time +import warnings from collections import OrderedDict from copy import deepcopy from multiprocessing import connection, queues @@ -430,15 +431,16 @@ def __init__( ) self.n_env = max(1, self.env_batch_size_masked.numel()) - self.mask_tensor = torch.ones( - *self.env.batch_size, - dtype=torch.bool, - device=self.env.device, - ) - for dim in self.env_batch_size_unmasked_indeces: - self.mask_tensor.index_fill_( - dim, torch.arange(1, self.env.batch_size[dim]), 0 + if len(env.batch_size): + self.mask_tensor = torch.ones( + *self.env.batch_size, + dtype=torch.bool, + device=self.env.device, ) + for dim in self.env_batch_size_unmasked_indeces: + self.mask_tensor.index_fill_( + dim, torch.arange(1, self.env.batch_size[dim]), 0 + ) (self.policy, self.device, self.get_weights_fn,) = self._get_policy_and_device( policy=policy, @@ -457,8 +459,9 @@ def __init__( self.postproc.to(self.passing_device) self.max_frames_per_traj = max_frames_per_traj if frames_per_batch % self.n_env != 0: - raise RuntimeError( - f"frames_per_batch {frames_per_batch} is not exactly divisible by the number of batched environments {self.n_env}, this is currently not allowed" + warnings.warn( + f"frames_per_batch {frames_per_batch} is not exactly divisible by the number of batched environments {self.n_env}, " + f" this results in more frames_per_batch per iteration that requeste" ) self.frames_per_batch = -(-frames_per_batch // self.n_env) self.pin_memory = pin_memory @@ -681,8 +684,9 @@ def _reset_if_necessary(self) -> None: ) steps[done_or_terminated] = 0 - traj_ids = traj_ids[self.mask_tensor] - done_or_terminated = done_or_terminated[self.mask_tensor] + if len(self.env.batch_size): + traj_ids = traj_ids[self.mask_tensor] + done_or_terminated = done_or_terminated[self.mask_tensor] traj_ids[done_or_terminated] = traj_ids.max() + torch.arange( 1, done_or_terminated.sum() + 1, device=traj_ids.device @@ -1304,8 +1308,9 @@ class MultiSyncDataCollector(_MultiDataCollector): @property def frames_per_batch_worker(self): if self.frames_per_batch % self.num_workers != 0: - raise RuntimeError( - f"frames_per_batch {self.frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers}, this is 
currently not allowed" + warnings.warn( + f"frames_per_batch {self.frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers}," + f" this results in more frames_per_batch per iteration that requested" ) return -(-self.frames_per_batch // self.num_workers) From 9e22fe1c4c1dd72c09536ffd7d409b1806f2d214 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 20 Jan 2023 15:28:28 +0000 Subject: [PATCH 14/50] amend --- torchrl/collectors/collectors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 31a73744ee6..65d0218fc4e 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -416,7 +416,7 @@ def __init__( self.mask_out_batch_size = mask_env_batch_size + [True] self.env_batch_size_unmasked_indeces = [ - i for i, is_batch in enumerate(self.mask_out_batch_size) if not is_batch + i for i, is_batch in enumerate(self.mask_env_batch_size) if not is_batch ] self.permute_out_batch_size = [ i for i, is_batch in enumerate(self.mask_out_batch_size) if is_batch From 8f60534f825193c7b4294289ec462c63e396fe5d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 10:21:20 +0000 Subject: [PATCH 15/50] fix split traj --- torchrl/collectors/collectors.py | 3 +-- torchrl/collectors/utils.py | 43 ++++++++++++++++++++++++-------- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 65d0218fc4e..13775c96bf0 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1381,12 +1381,11 @@ def iterator(self) -> Iterator[TensorDictBase]: out=out_buffer, ) + frames += out_buffer.batch_size[0] if self.split_trajs: out = split_trajectories(out_buffer) - frames += out.get("mask").sum().item() else: out = out_buffer.clone() - frames += out.batch_size[0] if self.postprocs: self.postprocs = self.postprocs.to(out.device) out = self.postprocs(out) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index e840e789010..e6e99408e45 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -29,17 +29,27 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: """A util function for trajectory separation. Takes a tensordict with a key traj_ids that indicates the id of each trajectory. + The input tensordict has batch_size = B x *other_dims - From there, builds a B x T x ... zero-padded tensordict with B batches on max duration T + From there, builds a B/T x *other_dims x T x ... zero-padded tensordict with B batches on max duration T """ # TODO: incorporate tensordict.split once it's implemented + env_batch_size_unmasked = rollout_tensordict.batch_size[1:] + mask = torch.ones_like( + rollout_tensordict.get("traj_ids"), + device=rollout_tensordict.device, + dtype=torch.bool, + ) + for dim in range(1, len(rollout_tensordict.batch_size)): + mask.index_fill_(dim, torch.arange(1, rollout_tensordict.batch_size[dim]), 0) + sep = ".-|-." 
rollout_tensordict = rollout_tensordict.flatten_keys(sep) - traj_ids = rollout_tensordict.get("traj_ids") + traj_ids = rollout_tensordict.get("traj_ids")[mask] splits = traj_ids.view(-1) splits = [(splits == i).sum().item() for i in splits.unique_consecutive()] # if all splits are identical then we can skip this function - if len(set(splits)) == 1 and splits[0] == traj_ids.shape[-1]: + if len(set(splits)) == 1: rollout_tensordict.set( "mask", torch.ones( @@ -48,12 +58,20 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: dtype=torch.bool, ), ) - if rollout_tensordict.ndimension() == 1: - rollout_tensordict = rollout_tensordict.unsqueeze(0).to_tensordict() + rollout_tensordict = rollout_tensordict.view( + -1, *env_batch_size_unmasked, splits[0] + ).to_tensordict() return rollout_tensordict.unflatten_keys(sep) - out_splits = rollout_tensordict.view(-1).split(splits, 0) + out_splits = rollout_tensordict.view(-1, *env_batch_size_unmasked).split(splits, 0) - for out_split in out_splits: + for i in range(len(out_splits)): + assert ( + out_splits[i]["traj_ids"] + == rollout_tensordict.get("traj_ids")[mask].unique_consecutive()[i] + ).all() + + MAX = max(*[out_split.shape[0] for out_split in out_splits]) + for i, out_split in enumerate(out_splits): out_split.set( "mask", torch.ones( @@ -62,13 +80,16 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: device=out_split.get("done").device, ), ) - MAX = max(*[out_split.shape[0] for out_split in out_splits]) - td = torch.stack( - [pad(out_split, [0, MAX - out_split.shape[0]]) for out_split in out_splits], 0 - ).contiguous() + out_splits[i] = pad(out_split, [0, MAX - out_split.shape[0]]) + out_splits[i] = out_splits[i].permute( + -1, *range(len(out_splits[i].batch_size) - 1) + ) + + td = torch.stack(out_splits, 0).contiguous() td = td.unflatten_keys(sep) return td + def get_batch_size_masked( batch_size: torch.Size, mask: Optional[Sequence[bool]] = None ): From eac0e26f7aebeac1fa933e5a556eae2e81846d28 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 10:28:11 +0000 Subject: [PATCH 16/50] doc --- torchrl/collectors/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index e6e99408e45..d9ece0a2c10 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -31,7 +31,7 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: Takes a tensordict with a key traj_ids that indicates the id of each trajectory. The input tensordict has batch_size = B x *other_dims - From there, builds a B/T x *other_dims x T x ... zero-padded tensordict with B batches on max duration T + From there, builds a B / T x *other_dims x T x ... 
zero-padded tensordict with B / T batches on max duration T """ # TODO: incorporate tensordict.split once it's implemented env_batch_size_unmasked = rollout_tensordict.batch_size[1:] From 43e093d87da5ba210d7c1ed8ddf916deb2f2e55d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 10:29:00 +0000 Subject: [PATCH 17/50] refactor --- torchrl/collectors/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index d9ece0a2c10..f3619eb3f46 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -35,8 +35,8 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: """ # TODO: incorporate tensordict.split once it's implemented env_batch_size_unmasked = rollout_tensordict.batch_size[1:] - mask = torch.ones_like( - rollout_tensordict.get("traj_ids"), + mask = torch.ones( + rollout_tensordict.batch_size, device=rollout_tensordict.device, dtype=torch.bool, ) From df2e660189dcb5e95761f9f9b035a540882b9d6a Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 10:49:08 +0000 Subject: [PATCH 18/50] refactor --- torchrl/collectors/collectors.py | 40 ++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 13775c96bf0..1538f8ba59a 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -23,7 +23,7 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torch import multiprocessing as mp from torch.utils.data import IterableDataset -from torchrl._utils import _check_for_faulty_process, prod +from torchrl._utils import _check_for_faulty_process from torchrl.collectors.utils import get_batch_size_masked, split_trajectories from torchrl.data import TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING @@ -415,29 +415,36 @@ def __init__( self.mask_env_batch_size = mask_env_batch_size self.mask_out_batch_size = mask_env_batch_size + [True] - self.env_batch_size_unmasked_indeces = [ + # Indices of env.batch_size dims not in the batch + env_batch_size_unmasked_indeces = [ i for i, is_batch in enumerate(self.mask_env_batch_size) if not is_batch ] - self.permute_out_batch_size = [ - i for i, is_batch in enumerate(self.mask_out_batch_size) if is_batch - ] + self.env_batch_size_unmasked_indeces + # env.batch_size dims not in the batch self.env_batch_size_unmasked = [ env.batch_size[i] for i, is_batch in enumerate(self.mask_env_batch_size) if not is_batch ] + # Permutation indices: fist all batch dims and then all masked out dims + self.permute_out_batch_size = [ + i for i, is_batch in enumerate(self.mask_out_batch_size) if is_batch + ] + env_batch_size_unmasked_indeces + # env.batch_size with masked dimensions set to 1. 
+ # Also returns error in case the input mask is malformed self.env_batch_size_masked = get_batch_size_masked( self.env.batch_size, self.mask_env_batch_size ) + # Number of batched environments used for collection self.n_env = max(1, self.env_batch_size_masked.numel()) - if len(env.batch_size): + if len(self.env_batch_size_unmasked): + # Mask used to only consider batch dimensions in trajectories self.mask_tensor = torch.ones( *self.env.batch_size, dtype=torch.bool, device=self.env.device, ) - for dim in self.env_batch_size_unmasked_indeces: + for dim in env_batch_size_unmasked_indeces: self.mask_tensor.index_fill_( dim, torch.arange(1, self.env.batch_size[dim]), 0 ) @@ -684,7 +691,7 @@ def _reset_if_necessary(self) -> None: ) steps[done_or_terminated] = 0 - if len(self.env.batch_size): + if len(self.env_batch_size_unmasked): traj_ids = traj_ids[self.mask_tensor] done_or_terminated = done_or_terminated[self.mask_tensor] @@ -692,11 +699,12 @@ def _reset_if_necessary(self) -> None: 1, done_or_terminated.sum() + 1, device=traj_ids.device ) - traj_ids = ( - traj_ids.view(self.env_batch_size_masked) - .expand(self.env.batch_size) - .clone() - ) + if len(self.env_batch_size_unmasked): + traj_ids = ( + traj_ids.view(self.env_batch_size_masked) + .expand(self.env.batch_size) + .clone() + ) self._tensordict.set_("traj_ids", traj_ids) # no ops if they already match self._tensordict.set_("step_count", steps) @@ -753,7 +761,7 @@ def reset(self, index=None, **kwargs) -> None: """Resets the environments to a new initial state.""" if index is not None: # check that the env supports partial reset - if prod(self.env.batch_size) == 0: + if self.n_env == 0: raise RuntimeError("resetting unique env with index is not permitted.") _reset = torch.zeros( self.env.batch_size, @@ -771,7 +779,9 @@ def reset(self, index=None, **kwargs) -> None: if td_in: self._tensordict.update(td_in, inplace=True) - self._tensordict.update(self.env.reset(**kwargs), inplace=True) + self._tensordict.update( + self.env.reset(tensordict=td_in, **kwargs), inplace=True + ) if _reset is not None: self._tensordict["step_count"][_reset] = 0 else: From b13778e12e88e1b7a13fdce83f3e890a13c656a7 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 15:01:46 +0000 Subject: [PATCH 19/50] refactor --- torchrl/collectors/collectors.py | 39 ++++++++++++++++++++------------ torchrl/collectors/utils.py | 18 ++++++++++++--- 2 files changed, 40 insertions(+), 17 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index db2cb808883..6cac6560ae4 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -24,7 +24,11 @@ from torch import multiprocessing as mp from torch.utils.data import IterableDataset from torchrl._utils import _check_for_faulty_process -from torchrl.collectors.utils import get_batch_size_masked, split_trajectories +from torchrl.collectors.utils import ( + bring_forward_and_squash_batch_sizes, + get_batch_size_masked, + split_trajectories, +) from torchrl.data import TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import EnvBase @@ -429,6 +433,9 @@ def __init__( self.permute_out_batch_size = [ i for i, is_batch in enumerate(self.mask_out_batch_size) if is_batch ] + env_batch_size_unmasked_indeces + self.permute_env_batch_size = [ + i for i, is_batch in enumerate(self.mask_env_batch_size) if is_batch + ] + env_batch_size_unmasked_indeces # env.batch_size with masked dimensions set to 1. 
# Also returns error in case the input mask is malformed self.env_batch_size_masked = get_batch_size_masked( @@ -533,6 +540,11 @@ def __init__( device=self.env_device, ), ) + self._tensordict_out = bring_forward_and_squash_batch_sizes( + self._tensordict_out, + self.permute_out_batch_size, + self.env_batch_size_unmasked, + ) if split_trajs is None: if not self.reset_when_done: @@ -585,20 +597,9 @@ def iterator(self) -> Iterator[TensorDictBase]: self._iter = i tensordict_out = self.rollout() - # Bring all batch dimensions to the front (only performs computation if it is not already the case) - tensordict_out = tensordict_out.permute(self.permute_out_batch_size) - # Flatten all batch dimensions into first one and leave unmasked dimensions untouched - if len(self.env_batch_size_unmasked) > 0: - tensordict_out = tensordict_out.reshape( - self.frames_per_batch * self.n_env, *self.env_batch_size_unmasked - ) - else: - tensordict_out = tensordict_out.view(-1).to_tensordict() - self._frames += tensordict_out.batch_size[0] if self._frames >= self.total_frames: self.env.close() - if self.split_trajs: tensordict_out = split_trajectories(tensordict_out) if self.postproc is not None: @@ -739,15 +740,25 @@ def rollout(self) -> TensorDictBase: step_count = self._tensordict.get("step_count") self._tensordict.set_("step_count", step_count + 1) + + tensordict = bring_forward_and_squash_batch_sizes( + self._tensordict, + self.permute_env_batch_size, + self.env_batch_size_unmasked, + ) # we must clone all the values, since the step / traj_id updates are done in-place try: - self._tensordict_out[..., j] = self._tensordict + self._tensordict_out[ + j * self.n_env : (j + 1) * self.n_env + ] = tensordict except RuntimeError: # unlock the output tensordict to allow for new keys to be written # these will be missed during the sync but at least we won't get an error during the update is_shared = self._tensordict_out.is_shared() self._tensordict_out.unlock() - self._tensordict_out[..., j] = self._tensordict + self._tensordict_out[ + j * self.n_env : (j + 1) * self.n_env + ] = tensordict if is_shared: self._tensordict_out.share_memory_() self._reset_if_necessary() diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index f3619eb3f46..9991f916bef 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -58,9 +58,9 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: dtype=torch.bool, ), ) - rollout_tensordict = rollout_tensordict.view( + rollout_tensordict = rollout_tensordict.reshape( -1, *env_batch_size_unmasked, splits[0] - ).to_tensordict() + ) return rollout_tensordict.unflatten_keys(sep) out_splits = rollout_tensordict.view(-1, *env_batch_size_unmasked).split(splits, 0) @@ -92,7 +92,7 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: def get_batch_size_masked( batch_size: torch.Size, mask: Optional[Sequence[bool]] = None -): +) -> torch.Size: """Returns a size with the masked dimensions equal to 1.""" if mask is None: return batch_size @@ -109,3 +109,15 @@ def get_batch_size_masked( ) ] ) + + +def bring_forward_and_squash_batch_sizes( + tensordict: TensorDictBase, + permute: Sequence[int], + batch_size_unmasked: Sequence[int], +) -> TensorDictBase: + # Bring all batch dimensions to the front (only performs computation if it is not already the case) + tensordict = tensordict.permute(permute) + # Flatten all batch dimensions into first one and leave unmasked dimensions untouched + tensordict = tensordict.reshape(-1, 
*batch_size_unmasked) + return tensordict From dec7895869a0b0e33a68f8f1ac73872f400242be Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 15:03:08 +0000 Subject: [PATCH 20/50] docs --- torchrl/collectors/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 9991f916bef..3cff7059d3c 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -116,6 +116,7 @@ def bring_forward_and_squash_batch_sizes( permute: Sequence[int], batch_size_unmasked: Sequence[int], ) -> TensorDictBase: + """Permutes the batch dimesnions attording to the permute indeces and then squashes all leadning dimesnions apart from batch_size_unmasked.""" # Bring all batch dimensions to the front (only performs computation if it is not already the case) tensordict = tensordict.permute(permute) # Flatten all batch dimensions into first one and leave unmasked dimensions untouched From 7c46c023de2cbf7fe42bc074e12f7d525199bd15 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 15:18:49 +0000 Subject: [PATCH 21/50] fix test --- test/test_collector.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 8d51fa2f703..328205e7266 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -373,9 +373,8 @@ def make_env(seed): ) for _, d in enumerate(collector): # noqa break - - assert (d["done"].sum(-2) >= 1).all() - assert torch.unique(d["traj_ids"], dim=-1).shape[-1] == 1 + assert (d["done"].sum() >= 1).all() + assert torch.unique(d["traj_ids"]).shape[0] == num_env del collector From 06ebee8f30cedff197f5aca70a0da6b3c4b5adca Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 16:11:34 +0000 Subject: [PATCH 22/50] new split_traj more efficient --- torchrl/collectors/utils.py | 47 ++++++++++++------------------------- 1 file changed, 15 insertions(+), 32 deletions(-) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 3cff7059d3c..fca3b277354 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -34,7 +34,6 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: From there, builds a B / T x *other_dims x T x ... zero-padded tensordict with B / T batches on max duration T """ # TODO: incorporate tensordict.split once it's implemented - env_batch_size_unmasked = rollout_tensordict.batch_size[1:] mask = torch.ones( rollout_tensordict.batch_size, device=rollout_tensordict.device, @@ -45,33 +44,18 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: sep = ".-|-." 
rollout_tensordict = rollout_tensordict.flatten_keys(sep) - traj_ids = rollout_tensordict.get("traj_ids")[mask] - splits = traj_ids.view(-1) - splits = [(splits == i).sum().item() for i in splits.unique_consecutive()] - # if all splits are identical then we can skip this function - if len(set(splits)) == 1: - rollout_tensordict.set( - "mask", - torch.ones( - rollout_tensordict.shape, - device=rollout_tensordict.device, - dtype=torch.bool, - ), - ) - rollout_tensordict = rollout_tensordict.reshape( - -1, *env_batch_size_unmasked, splits[0] - ) - return rollout_tensordict.unflatten_keys(sep) - out_splits = rollout_tensordict.view(-1, *env_batch_size_unmasked).split(splits, 0) - - for i in range(len(out_splits)): - assert ( - out_splits[i]["traj_ids"] - == rollout_tensordict.get("traj_ids")[mask].unique_consecutive()[i] - ).all() - - MAX = max(*[out_split.shape[0] for out_split in out_splits]) - for i, out_split in enumerate(out_splits): + traj_ids = rollout_tensordict.get("traj_ids")[mask].view(-1) + + traj_masks = [] + MAX = 0 + for i in traj_ids.unique(): + traj_mask = traj_ids == i + MAX = max(MAX, traj_mask.count_nonzero(0)) + traj_masks.append(traj_mask) + + out_splits = [] + for traj_mask in traj_masks: + out_split = rollout_tensordict[traj_mask] out_split.set( "mask", torch.ones( @@ -80,10 +64,9 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: device=out_split.get("done").device, ), ) - out_splits[i] = pad(out_split, [0, MAX - out_split.shape[0]]) - out_splits[i] = out_splits[i].permute( - -1, *range(len(out_splits[i].batch_size) - 1) - ) + out_split = pad(out_split, [0, MAX - out_split.shape[0]]) + out_split = out_split.permute(-1, *range(len(out_split.batch_size) - 1)) + out_splits.append(out_split) td = torch.stack(out_splits, 0).contiguous() td = td.unflatten_keys(sep) From 2785b55e43c297fcd542e923aa533aab2e8bf843 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 16:13:49 +0000 Subject: [PATCH 23/50] lint --- torchrl/collectors/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index fca3b277354..57a9701517b 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -50,7 +50,7 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: MAX = 0 for i in traj_ids.unique(): traj_mask = traj_ids == i - MAX = max(MAX, traj_mask.count_nonzero(0)) + MAX = max(MAX, traj_mask.count_nonzero()) traj_masks.append(traj_mask) out_splits = [] From 20331b999fc962601ae7bc081011a7197559e16b Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 16:27:24 +0000 Subject: [PATCH 24/50] fix tests --- test/test_collector.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 328205e7266..a7cb2bcc9c0 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -9,6 +9,10 @@ import numpy as np import pytest import torch +from tensordict.nn import TensorDictModule +from tensordict.tensordict import assert_allclose_td, TensorDict +from torch import nn + from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED from mocking_classes import ( ContinuousActionVecMockEnv, @@ -18,9 +22,6 @@ DiscreteActionVecPolicy, MockSerialEnv, ) -from tensordict.nn import TensorDictModule -from tensordict.tensordict import assert_allclose_td, TensorDict -from torch import nn from torchrl._utils import seed_generator from 
torchrl.collectors import aSyncDataCollector, SyncDataCollector from torchrl.collectors.collectors import ( @@ -316,16 +317,10 @@ def make_env(): ) for _data in collector: continue - steps = _data["step_count"][..., 1:] - done = _data["done"][..., :-1, :].squeeze(-1) + steps = _data["step_count"] + done = _data["done"].squeeze(-1) # we don't want just one done assert done.sum() > 3 - # check that after a done, the next step count is always 1 - assert (steps[done] == 1).all() - # check that if the env is not done, the next step count is > 1 - assert (steps[~done] > 1).all() - # check that if step is 1, then the env was done before - assert (steps == 1)[done].all() # check that split traj has a minimum total reward of -21 (for pong only) _data = split_trajectories(_data) assert _data["reward"].sum(-2).min() == -21 From 455c606197bd31f9c5faca5802e53d1b4d2d520b Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 16:40:04 +0000 Subject: [PATCH 25/50] fix tests --- test/test_collector.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/test_collector.py b/test/test_collector.py index a7cb2bcc9c0..94921d6061d 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -319,8 +319,16 @@ def make_env(): continue steps = _data["step_count"] done = _data["done"].squeeze(-1) + traj_ids = _data["traj_ids"] # we don't want just one done assert done.sum() > 3 + for i in traj_ids.unique(): + # check that after a done, the next step count is always 1 + assert (steps[traj_ids == i][0] == 1).all() + # check that step counts are positive for not first elements of traj + assert (steps[traj_ids == i][1:] > 1).all() + # check that non-last elements of trajectories are not done + assert (done[traj_ids == i][:-1] == 0).all() # check that split traj has a minimum total reward of -21 (for pong only) _data = split_trajectories(_data) assert _data["reward"].sum(-2).min() == -21 From 4b0310259e1235c0d4bf6e29b5a385417435bbcc Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 16:42:09 +0000 Subject: [PATCH 26/50] Lint --- test/test_collector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 94921d6061d..00bb8d31019 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -9,9 +9,6 @@ import numpy as np import pytest import torch -from tensordict.nn import TensorDictModule -from tensordict.tensordict import assert_allclose_td, TensorDict -from torch import nn from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED from mocking_classes import ( @@ -22,6 +19,9 @@ DiscreteActionVecPolicy, MockSerialEnv, ) +from tensordict.nn import TensorDictModule +from tensordict.tensordict import assert_allclose_td, TensorDict +from torch import nn from torchrl._utils import seed_generator from torchrl.collectors import aSyncDataCollector, SyncDataCollector from torchrl.collectors.collectors import ( From c21654f3613f3652d5f147ae372ed60bcdc8b81d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 17:18:37 +0000 Subject: [PATCH 27/50] refactor 4 efficiency --- torchrl/collectors/utils.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 57a9701517b..bed1b3074df 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -45,17 +45,12 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: sep = ".-|-." 
rollout_tensordict = rollout_tensordict.flatten_keys(sep) traj_ids = rollout_tensordict.get("traj_ids")[mask].view(-1) - - traj_masks = [] - MAX = 0 - for i in traj_ids.unique(): - traj_mask = traj_ids == i - MAX = max(MAX, traj_mask.count_nonzero()) - traj_masks.append(traj_mask) + unique_traj_ids = traj_ids.unique() + MAX = max([(traj_ids == i).count_nonzero() for i in unique_traj_ids]) out_splits = [] - for traj_mask in traj_masks: - out_split = rollout_tensordict[traj_mask] + for i in unique_traj_ids: + out_split = rollout_tensordict[traj_ids == i] out_split.set( "mask", torch.ones( From 5e6afa130bda8a6b783fda91b7821e96559dbf18 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 17:24:21 +0000 Subject: [PATCH 28/50] no sorting --- test/test_collector.py | 8 ++++---- torchrl/collectors/utils.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 00bb8d31019..56419ffd1ab 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -9,6 +9,9 @@ import numpy as np import pytest import torch +from tensordict.nn import TensorDictModule +from tensordict.tensordict import assert_allclose_td, TensorDict +from torch import nn from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED from mocking_classes import ( @@ -19,9 +22,6 @@ DiscreteActionVecPolicy, MockSerialEnv, ) -from tensordict.nn import TensorDictModule -from tensordict.tensordict import assert_allclose_td, TensorDict -from torch import nn from torchrl._utils import seed_generator from torchrl.collectors import aSyncDataCollector, SyncDataCollector from torchrl.collectors.collectors import ( @@ -322,7 +322,7 @@ def make_env(): traj_ids = _data["traj_ids"] # we don't want just one done assert done.sum() > 3 - for i in traj_ids.unique(): + for i in traj_ids.unique(sorted=False): # check that after a done, the next step count is always 1 assert (steps[traj_ids == i][0] == 1).all() # check that step counts are positive for not first elements of traj diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index bed1b3074df..546de99d1ee 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -45,7 +45,7 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: sep = ".-|-." rollout_tensordict = rollout_tensordict.flatten_keys(sep) traj_ids = rollout_tensordict.get("traj_ids")[mask].view(-1) - unique_traj_ids = traj_ids.unique() + unique_traj_ids = traj_ids.unique(sorted=False) MAX = max([(traj_ids == i).count_nonzero() for i in unique_traj_ids]) out_splits = [] From 8fe0b9c1f7321f6e7fd6b7a83ff943bad3fec980 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 19:23:57 +0000 Subject: [PATCH 29/50] typo --- torchrl/collectors/collectors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 6cac6560ae4..bd5bd034990 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -393,7 +393,6 @@ def __init__( f"on environment of type {type(create_env_fn)}." ) env.update_kwargs(create_env_kwargs) - if passing_device is None: if device is not None: passing_device = device @@ -438,6 +437,7 @@ def __init__( ] + env_batch_size_unmasked_indeces # env.batch_size with masked dimensions set to 1. 
# Also returns error in case the input mask is malformed + self.env_batch_size_masked = get_batch_size_masked( self.env.batch_size, self.mask_env_batch_size ) @@ -475,7 +475,7 @@ def __init__( if frames_per_batch % self.n_env != 0: warnings.warn( f"frames_per_batch {frames_per_batch} is not exactly divisible by the number of batched environments {self.n_env}, " - f" this results in more frames_per_batch per iteration that requeste" + f" this results in more frames_per_batch per iteration that requested" ) self.frames_per_batch = -(-frames_per_batch // self.n_env) self.pin_memory = pin_memory From d4c6ba967f0d208dce1f04b4717dcc93e04d4181 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 19:54:09 +0000 Subject: [PATCH 30/50] tests --- test/mocking_classes.py | 4 +- test/test_collector.py | 149 +++++++++++++++++++++++++++++-- torchrl/collectors/collectors.py | 10 +-- 3 files changed, 151 insertions(+), 12 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index ae0512205b4..5050fc0c46a 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -883,7 +883,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: self.count[:] = 0 return TensorDict( source={ - "observation": self.count.clone(), + "observation": self.count.float().clone(), "done": self.count > self.max_steps, }, batch_size=self.batch_size, @@ -898,7 +898,7 @@ def _step( self.count += action.to(torch.int) return TensorDict( source={ - "observation": self.count, + "observation": self.count.float(), "done": self.count > self.max_steps, "reward": torch.zeros_like(self.count, dtype=torch.float), }, diff --git a/test/test_collector.py b/test/test_collector.py index 56419ffd1ab..7eac9d8577e 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -9,20 +9,20 @@ import numpy as np import pytest import torch -from tensordict.nn import TensorDictModule -from tensordict.tensordict import assert_allclose_td, TensorDict -from torch import nn - from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED from mocking_classes import ( ContinuousActionVecMockEnv, + CountingEnv, DiscreteActionConvMockEnv, DiscreteActionConvPolicy, DiscreteActionVecMockEnv, DiscreteActionVecPolicy, MockSerialEnv, ) -from torchrl._utils import seed_generator +from tensordict.nn import TensorDictModule +from tensordict.tensordict import assert_allclose_td, TensorDict +from torch import nn +from torchrl._utils import prod, seed_generator from torchrl.collectors import aSyncDataCollector, SyncDataCollector from torchrl.collectors.collectors import ( MultiaSyncDataCollector, @@ -533,6 +533,145 @@ def env_fn(): ccollector.shutdown() +@pytest.mark.parametrize("n_env_workers", [1, 3]) +@pytest.mark.parametrize("batch_size", [(), (2, 4)]) +@pytest.mark.parametrize("mask_env_batch_size", [None, (True, False, True)]) +def test_collector_batch_size_with_env_batch_size( + n_env_workers, + batch_size, + mask_env_batch_size, + max_steps=5, + n_collector_workers=4, + seed=100, +): + if n_env_workers == 3 and _os_is_windows: + pytest.skip("Test timeout (> 10 min) on CI pipeline Windows machine with GPU") + if n_env_workers == 1: + env = lambda: CountingEnv(max_steps=max_steps, batch_size=batch_size) + if mask_env_batch_size is not None: + mask_env_batch_size = mask_env_batch_size[1:] + else: + env = lambda: ParallelEnv( + num_workers=n_env_workers, + create_env_fn=lambda: CountingEnv( + max_steps=max_steps, batch_size=batch_size + ), + ) + new_batch_size = env().batch_size + 
policy = TensorDictModule( + nn.Linear(1, 1), in_keys=["observation"], out_keys=["action"] + ) + torch.manual_seed(0) + np.random.seed(0) + + env_unmasked_dims = [ + dim + for i, dim in enumerate(new_batch_size) + if mask_env_batch_size is not None and not mask_env_batch_size[i] + ] + n_batch_envs = max( + 1, + prod( + [ + dim + for i, dim in enumerate(new_batch_size) + if mask_env_batch_size is None or mask_env_batch_size[i] + ] + ), + ) + frames_per_batch = n_collector_workers * n_batch_envs * n_env_workers * 5 + + if mask_env_batch_size is not None and len(mask_env_batch_size) != len( + new_batch_size + ): + with pytest.raises( + RuntimeError, + match=( + f"Batch size mask and env batch size have different" + f" lengths: mask={mask_env_batch_size}, env.batch_size={new_batch_size}" + ), + ): + ccollector = MultiaSyncDataCollector( + create_env_fn=[env for _ in range(n_collector_workers)], + policy=policy, + frames_per_batch=frames_per_batch, + mask_env_batch_size=mask_env_batch_size, + pin_memory=False, + split_trajs=False, + ) + return + + # Multi async no split traj + ccollector = MultiaSyncDataCollector( + create_env_fn=[env for _ in range(n_collector_workers)], + policy=policy, + frames_per_batch=frames_per_batch, + mask_env_batch_size=mask_env_batch_size, + pin_memory=False, + split_trajs=False, + ) + ccollector.set_seed(seed) + for i, b in enumerate(ccollector): + assert b.batch_size == torch.Size([frames_per_batch, *env_unmasked_dims]) + if i == 1: + break + ccollector.shutdown() + + # Multi async split traj + ccollector = MultiaSyncDataCollector( + create_env_fn=[env for _ in range(n_collector_workers)], + policy=policy, + frames_per_batch=frames_per_batch, + max_frames_per_traj=max_steps, + mask_env_batch_size=mask_env_batch_size, + pin_memory=False, + split_trajs=True, + ) + ccollector.set_seed(seed) + for i, b in enumerate(ccollector): + assert b.batch_size[1:] == torch.Size([*env_unmasked_dims, max_steps]) + if i == 1: + break + ccollector.shutdown() + + # Multi sync no split traj + ccollector = MultiSyncDataCollector( + create_env_fn=[env for _ in range(n_collector_workers)], + policy=policy, + frames_per_batch=frames_per_batch, + mask_env_batch_size=mask_env_batch_size, + pin_memory=False, + split_trajs=False, + ) + ccollector.set_seed(seed) + for i, b in enumerate(ccollector): + assert b.batch_size == torch.Size([frames_per_batch, *env_unmasked_dims]) + if i == 1: + break + ccollector.shutdown() + + # Multi sync split traj + ccollector = MultiSyncDataCollector( + create_env_fn=[env for _ in range(n_collector_workers)], + policy=policy, + frames_per_batch=frames_per_batch, + max_frames_per_traj=max_steps, + mask_env_batch_size=mask_env_batch_size, + pin_memory=False, + split_trajs=True, + ) + ccollector.set_seed(seed) + for i, b in enumerate(ccollector): + assert b.batch_size[1:] == torch.Size([*env_unmasked_dims, max_steps]) + if i == 1: + break + ccollector.shutdown() + + +def test_collector_batch_size_advanced(): + pass + + @pytest.mark.parametrize("num_env", [1, 3]) @pytest.mark.parametrize("env_name", ["vec", "conv"]) def test_concurrent_collector_seed(num_env, env_name, seed=100): diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index bd5bd034990..eedb52db354 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -417,6 +417,11 @@ def __init__( mask_env_batch_size = list(mask_env_batch_size) self.mask_env_batch_size = mask_env_batch_size self.mask_out_batch_size = mask_env_batch_size + [True] + # 
env.batch_size with masked dimensions set to 1. + # Also returns error in case the input mask is malformed + self.env_batch_size_masked = get_batch_size_masked( + self.env.batch_size, self.mask_env_batch_size + ) # Indices of env.batch_size dims not in the batch env_batch_size_unmasked_indeces = [ @@ -435,12 +440,7 @@ def __init__( self.permute_env_batch_size = [ i for i, is_batch in enumerate(self.mask_env_batch_size) if is_batch ] + env_batch_size_unmasked_indeces - # env.batch_size with masked dimensions set to 1. - # Also returns error in case the input mask is malformed - self.env_batch_size_masked = get_batch_size_masked( - self.env.batch_size, self.mask_env_batch_size - ) # Number of batched environments used for collection self.n_env = max(1, self.env_batch_size_masked.numel()) From 459545077e449c2d00510a151074498a43cbea32 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 20:57:46 +0000 Subject: [PATCH 31/50] tests --- test/test_collector.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 7eac9d8577e..cc88c97c3f7 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -584,21 +584,16 @@ def test_collector_batch_size_with_env_batch_size( if mask_env_batch_size is not None and len(mask_env_batch_size) != len( new_batch_size ): - with pytest.raises( - RuntimeError, - match=( - f"Batch size mask and env batch size have different" - f" lengths: mask={mask_env_batch_size}, env.batch_size={new_batch_size}" - ), - ): - ccollector = MultiaSyncDataCollector( - create_env_fn=[env for _ in range(n_collector_workers)], + try: + ccollector = SyncDataCollector( + create_env_fn=env, policy=policy, frames_per_batch=frames_per_batch, mask_env_batch_size=mask_env_batch_size, pin_memory=False, - split_trajs=False, ) + assert False + except RuntimeError: return # Multi async no split traj From 8c582feed7a7c63d546a9ad63294fad451f10259 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 22 Jan 2023 20:59:04 +0000 Subject: [PATCH 32/50] tests --- test/test_collector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_collector.py b/test/test_collector.py index cc88c97c3f7..6986f52b17f 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -592,7 +592,7 @@ def test_collector_batch_size_with_env_batch_size( mask_env_batch_size=mask_env_batch_size, pin_memory=False, ) - assert False + raise AssertionError except RuntimeError: return From 7e15d48d716e931e586ed439c5228a587e9642df Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 23 Jan 2023 08:04:56 +0000 Subject: [PATCH 33/50] tests --- test/test_collector.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 6986f52b17f..45748ef6035 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -9,6 +9,10 @@ import numpy as np import pytest import torch +from tensordict.nn import TensorDictModule +from tensordict.tensordict import assert_allclose_td, TensorDict +from torch import nn + from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED from mocking_classes import ( ContinuousActionVecMockEnv, @@ -19,9 +23,6 @@ DiscreteActionVecPolicy, MockSerialEnv, ) -from tensordict.nn import TensorDictModule -from tensordict.tensordict import assert_allclose_td, TensorDict -from torch import nn from torchrl._utils import prod, seed_generator from torchrl.collectors import aSyncDataCollector, 
SyncDataCollector from torchrl.collectors.collectors import ( @@ -663,10 +664,6 @@ def test_collector_batch_size_with_env_batch_size( ccollector.shutdown() -def test_collector_batch_size_advanced(): - pass - - @pytest.mark.parametrize("num_env", [1, 3]) @pytest.mark.parametrize("env_name", ["vec", "conv"]) def test_concurrent_collector_seed(num_env, env_name, seed=100): From 632447e4a0bb5cdc928a5bdb42c4fca67302564d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 23 Jan 2023 08:19:23 +0000 Subject: [PATCH 34/50] tests --- torchrl/collectors/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 546de99d1ee..8a5d06abdd0 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -25,13 +25,17 @@ def stacked_output_fun(*args, **kwargs): return stacked_output_fun -def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: +def split_trajectories( + rollout_tensordict: TensorDictBase, sort: bool = True +) -> TensorDictBase: """A util function for trajectory separation. Takes a tensordict with a key traj_ids that indicates the id of each trajectory. The input tensordict has batch_size = B x *other_dims From there, builds a B / T x *other_dims x T x ... zero-padded tensordict with B / T batches on max duration T + + If sorted=True the trajectories are also sorted based on traj_id. """ # TODO: incorporate tensordict.split once it's implemented mask = torch.ones( @@ -45,7 +49,7 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: sep = ".-|-." rollout_tensordict = rollout_tensordict.flatten_keys(sep) traj_ids = rollout_tensordict.get("traj_ids")[mask].view(-1) - unique_traj_ids = traj_ids.unique(sorted=False) + unique_traj_ids = traj_ids.unique(sorted=sort) MAX = max([(traj_ids == i).count_nonzero() for i in unique_traj_ids]) out_splits = [] From 62701cfaf2bc38535456abbdb95cac9e014ff074 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 23 Jan 2023 08:20:18 +0000 Subject: [PATCH 35/50] Lint --- test/test_collector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 45748ef6035..6e4c72cd812 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -9,9 +9,6 @@ import numpy as np import pytest import torch -from tensordict.nn import TensorDictModule -from tensordict.tensordict import assert_allclose_td, TensorDict -from torch import nn from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED from mocking_classes import ( @@ -23,6 +20,9 @@ DiscreteActionVecPolicy, MockSerialEnv, ) +from tensordict.nn import TensorDictModule +from tensordict.tensordict import assert_allclose_td, TensorDict +from torch import nn from torchrl._utils import prod, seed_generator from torchrl.collectors import aSyncDataCollector, SyncDataCollector from torchrl.collectors.collectors import ( From 5af90c1f029f61c981f475e3097e501b2a6b9bbd Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 23 Jan 2023 08:28:41 +0000 Subject: [PATCH 36/50] tests --- test/test_collector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 6e4c72cd812..2b6f526ba6b 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -625,7 +625,7 @@ def test_collector_batch_size_with_env_batch_size( ) ccollector.set_seed(seed) for i, b in enumerate(ccollector): - assert b.batch_size[1:] == 
torch.Size([*env_unmasked_dims, max_steps]) + assert b.batch_size == torch.Size([b["traj_ids"].unique(sorted=False).shape[0],*env_unmasked_dims, max_steps]) if i == 1: break ccollector.shutdown() @@ -658,7 +658,7 @@ def test_collector_batch_size_with_env_batch_size( ) ccollector.set_seed(seed) for i, b in enumerate(ccollector): - assert b.batch_size[1:] == torch.Size([*env_unmasked_dims, max_steps]) + assert b.batch_size == torch.Size([b["traj_ids"].unique(sorted=False).shape[0],*env_unmasked_dims, max_steps]) if i == 1: break ccollector.shutdown() From 6b7bdc0bddf606e86f8cb1aa079c08368bd22326 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 23 Jan 2023 08:29:38 +0000 Subject: [PATCH 37/50] tests final --- test/test_collector.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 2b6f526ba6b..b8defb82d1d 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -9,6 +9,9 @@ import numpy as np import pytest import torch +from tensordict.nn import TensorDictModule +from tensordict.tensordict import assert_allclose_td, TensorDict +from torch import nn from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED from mocking_classes import ( @@ -20,9 +23,6 @@ DiscreteActionVecPolicy, MockSerialEnv, ) -from tensordict.nn import TensorDictModule -from tensordict.tensordict import assert_allclose_td, TensorDict -from torch import nn from torchrl._utils import prod, seed_generator from torchrl.collectors import aSyncDataCollector, SyncDataCollector from torchrl.collectors.collectors import ( @@ -585,17 +585,15 @@ def test_collector_batch_size_with_env_batch_size( if mask_env_batch_size is not None and len(mask_env_batch_size) != len( new_batch_size ): - try: - ccollector = SyncDataCollector( + with pytest.raises(RuntimeError): + SyncDataCollector( create_env_fn=env, policy=policy, frames_per_batch=frames_per_batch, mask_env_batch_size=mask_env_batch_size, pin_memory=False, ) - raise AssertionError - except RuntimeError: - return + return # Multi async no split traj ccollector = MultiaSyncDataCollector( @@ -625,7 +623,9 @@ def test_collector_batch_size_with_env_batch_size( ) ccollector.set_seed(seed) for i, b in enumerate(ccollector): - assert b.batch_size == torch.Size([b["traj_ids"].unique(sorted=False).shape[0],*env_unmasked_dims, max_steps]) + assert b.batch_size == torch.Size( + [b["traj_ids"].unique(sorted=False).shape[0], *env_unmasked_dims, max_steps] + ) if i == 1: break ccollector.shutdown() @@ -658,7 +658,9 @@ def test_collector_batch_size_with_env_batch_size( ) ccollector.set_seed(seed) for i, b in enumerate(ccollector): - assert b.batch_size == torch.Size([b["traj_ids"].unique(sorted=False).shape[0],*env_unmasked_dims, max_steps]) + assert b.batch_size == torch.Size( + [b["traj_ids"].unique(sorted=False).shape[0], *env_unmasked_dims, max_steps] + ) if i == 1: break ccollector.shutdown() From 9cb6492b5d3ff8bf82713d181416309765c992f9 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 23 Jan 2023 08:33:53 +0000 Subject: [PATCH 38/50] docs --- torchrl/collectors/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 8a5d06abdd0..f74fc20883f 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -31,9 +31,10 @@ def split_trajectories( """A util function for trajectory separation. 
Takes a tensordict with a key traj_ids that indicates the id of each trajectory. - The input tensordict has batch_size = B x *other_dims + The input tensordict has batch_size = (B x *masked_dims) - From there, builds a B / T x *other_dims x T x ... zero-padded tensordict with B / T batches on max duration T + From there, builds a (number_of_trajectories x *masked_dims x T) zero-padded tensordict + with number_of_trajectories batches of shape ( *masked_dims, T) with max duration T If sorted=True the trajectories are also sorted based on traj_id. """ From 6702f1bf29a1e506be3dc707129bc148f109bc99 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 23 Jan 2023 10:29:19 +0000 Subject: [PATCH 39/50] Lint --- test/test_collector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index b8defb82d1d..613f3d4ec03 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -9,9 +9,6 @@ import numpy as np import pytest import torch -from tensordict.nn import TensorDictModule -from tensordict.tensordict import assert_allclose_td, TensorDict -from torch import nn from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED from mocking_classes import ( @@ -23,6 +20,9 @@ DiscreteActionVecPolicy, MockSerialEnv, ) +from tensordict.nn import TensorDictModule +from tensordict.tensordict import assert_allclose_td, TensorDict +from torch import nn from torchrl._utils import prod, seed_generator from torchrl.collectors import aSyncDataCollector, SyncDataCollector from torchrl.collectors.collectors import ( From e78cd2b8c5770bd2c79cbfc57c1558ca78769456 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 23 Jan 2023 10:59:26 +0000 Subject: [PATCH 40/50] tests --- test/test_postprocs.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/test_postprocs.py b/test/test_postprocs.py index d684793670d..f6012f12d73 100644 --- a/test/test_postprocs.py +++ b/test/test_postprocs.py @@ -6,8 +6,9 @@ import pytest import torch -from _utils_internal import get_available_devices from tensordict.tensordict import assert_allclose_td, TensorDict + +from _utils_internal import get_available_devices from torchrl.collectors.utils import split_trajectories from torchrl.data.postprocs.postprocs import MultiStep @@ -121,7 +122,7 @@ def create_fake_trajs( traj_ids[done] = traj_ids.max() + torch.arange(1, done.sum() + 1) steps_count[done] = 0 - out = torch.stack(out, 1).contiguous() + out = torch.stack(out, 1).view(-1).contiguous() return out @pytest.mark.parametrize("num_workers", range(3, 34, 3)) @@ -129,8 +130,8 @@ def create_fake_trajs( def test_splits(self, num_workers, traj_len): trajs = TestSplits.create_fake_trajs(num_workers, traj_len) - assert trajs.shape[0] == num_workers - assert trajs.shape[1] == traj_len + assert trajs.shape[0] == num_workers * traj_len + assert len(trajs.shape) == 1 split_trajs = split_trajectories(trajs) assert split_trajs.shape[0] == split_trajs.get("traj_ids").max() + 1 assert split_trajs.shape[1] == split_trajs.get("steps_count").max() + 1 From 7ec62a9eb4defe8398646a9eb47820cd4dc08753 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 23 Jan 2023 11:00:14 +0000 Subject: [PATCH 41/50] Lint --- test/test_postprocs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_postprocs.py b/test/test_postprocs.py index f6012f12d73..249d14aa487 100644 --- a/test/test_postprocs.py +++ b/test/test_postprocs.py @@ -6,9 +6,9 @@ import pytest import torch 
-from tensordict.tensordict import assert_allclose_td, TensorDict from _utils_internal import get_available_devices +from tensordict.tensordict import assert_allclose_td, TensorDict from torchrl.collectors.utils import split_trajectories from torchrl.data.postprocs.postprocs import MultiStep From 3b2646c6a0b42f8eb4ad6fe0b2fb0b9fd98897ae Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 25 Jan 2023 20:50:51 +0000 Subject: [PATCH 42/50] merge main --- test/test_collector.py | 18 +++++++++++++----- torchrl/collectors/utils.py | 4 +++- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 8bb022a770c..16222a39207 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -9,6 +9,9 @@ import numpy as np import pytest import torch +from tensordict.nn import TensorDictModule +from tensordict.tensordict import assert_allclose_td, TensorDict +from torch import nn from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED from mocking_classes import ( @@ -20,9 +23,6 @@ DiscreteActionVecPolicy, MockSerialEnv, ) -from tensordict.nn import TensorDictModule -from tensordict.tensordict import assert_allclose_td, TensorDict -from torch import nn from torchrl._utils import prod, seed_generator from torchrl.collectors import aSyncDataCollector, SyncDataCollector from torchrl.collectors.collectors import ( @@ -627,7 +627,11 @@ def test_collector_batch_size_with_env_batch_size( ccollector.set_seed(seed) for i, b in enumerate(ccollector): assert b.batch_size == torch.Size( - [b["traj_ids"].unique(sorted=False).shape[0], *env_unmasked_dims, max_steps] + [ + b["collector", "traj_ids"].unique(sorted=False).shape[0], + *env_unmasked_dims, + max_steps, + ] ) if i == 1: break @@ -662,7 +666,11 @@ def test_collector_batch_size_with_env_batch_size( ccollector.set_seed(seed) for i, b in enumerate(ccollector): assert b.batch_size == torch.Size( - [b["traj_ids"].unique(sorted=False).shape[0], *env_unmasked_dims, max_steps] + [ + b["collector", "traj_ids"].unique(sorted=False).shape[0], + *env_unmasked_dims, + max_steps, + ] ) if i == 1: break diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index fa44c3984e6..c8cb572d9e6 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -49,7 +49,9 @@ def split_trajectories( sep = ".-|-." 
rollout_tensordict = rollout_tensordict.flatten_keys(sep) - traj_ids = rollout_tensordict.get(sep.join(["collector", "traj_ids"]))[mask].view(-1) + traj_ids = rollout_tensordict.get(sep.join(["collector", "traj_ids"]))[mask].view( + -1 + ) unique_traj_ids = traj_ids.unique(sorted=sort) MAX = max([(traj_ids == i).count_nonzero() for i in unique_traj_ids]) From d778c1779d5c73fc39033c160808849292115602 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 25 Jan 2023 20:53:28 +0000 Subject: [PATCH 43/50] lint --- test/test_collector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 16222a39207..cd6c03f0189 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -9,9 +9,6 @@ import numpy as np import pytest import torch -from tensordict.nn import TensorDictModule -from tensordict.tensordict import assert_allclose_td, TensorDict -from torch import nn from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED from mocking_classes import ( @@ -23,6 +20,9 @@ DiscreteActionVecPolicy, MockSerialEnv, ) +from tensordict.nn import TensorDictModule +from tensordict.tensordict import assert_allclose_td, TensorDict +from torch import nn from torchrl._utils import prod, seed_generator from torchrl.collectors import aSyncDataCollector, SyncDataCollector from torchrl.collectors.collectors import ( From 26de6182668c3a5d0affbc21a8aae8eb4a4133fb Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 25 Jan 2023 21:05:48 +0000 Subject: [PATCH 44/50] refactor --- torchrl/collectors/collectors.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index b729e12e352..d0c1d05746d 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -23,6 +23,7 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torch import multiprocessing as mp from torch.utils.data import IterableDataset + from torchrl._utils import _check_for_faulty_process from torchrl.collectors.utils import ( bring_forward_and_squash_batch_sizes, @@ -478,12 +479,13 @@ def __init__( if self.postproc is not None: self.postproc.to(self.passing_device) self.max_frames_per_traj = max_frames_per_traj + self.frames_per_batch = frames_per_batch if frames_per_batch % self.n_env != 0: warnings.warn( f"frames_per_batch {frames_per_batch} is not exactly divisible by the number of batched environments {self.n_env}, " f" this results in more frames_per_batch per iteration that requested" ) - self.frames_per_batch = -(-frames_per_batch // self.n_env) + self.batched_frames_per_batch = -(-self.frames_per_batch // self.n_env) self.pin_memory = pin_memory self.exploration_mode = ( exploration_mode if exploration_mode else DEFAULT_EXPLORATION_MODE @@ -511,7 +513,7 @@ def __init__( self._tensordict_out = self._tensordict_out.to(env.device) self._tensordict_out = ( self._tensordict_out.unsqueeze(-1) - .expand(*env.batch_size, self.frames_per_batch) + .expand(*env.batch_size, self.batched_frames_per_batch) .to_tensordict() ) else: @@ -524,7 +526,9 @@ def __init__( self._tensordict_out = self.policy(self._tensordict_out).unsqueeze(-1) self._tensordict_out = self._tensordict_out.to(self.env_device) self._tensordict_out = ( - self._tensordict_out.expand(*env.batch_size, self.frames_per_batch) + self._tensordict_out.expand( + *env.batch_size, self.batched_frames_per_batch + ) .to_tensordict() .zero_() ) @@ -737,7 +741,7 @@ def 
rollout(self) -> TensorDictBase: ) with set_exploration_mode(self.exploration_mode): - for j in range(self.frames_per_batch): + for j in range(self.batched_frames_per_batch): if self._frames < self.init_random_frames: self.env.rand_step(self._tensordict) else: From 1b2a3b4a189862d107d30320474c34773fe76de9 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 25 Jan 2023 21:14:54 +0000 Subject: [PATCH 45/50] adjusted self.frames_per_batch to always give the correct frames --- torchrl/collectors/collectors.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index d0c1d05746d..061853efa33 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -479,13 +479,13 @@ def __init__( if self.postproc is not None: self.postproc.to(self.passing_device) self.max_frames_per_traj = max_frames_per_traj - self.frames_per_batch = frames_per_batch if frames_per_batch % self.n_env != 0: warnings.warn( f"frames_per_batch {frames_per_batch} is not exactly divisible by the number of batched environments {self.n_env}, " f" this results in more frames_per_batch per iteration that requested" ) - self.batched_frames_per_batch = -(-self.frames_per_batch // self.n_env) + self.batched_frames_per_batch = -(-frames_per_batch // self.n_env) + self.frames_per_batch = self.batched_frames_per_batch * self.n_env self.pin_memory = pin_memory self.exploration_mode = ( exploration_mode if exploration_mode else DEFAULT_EXPLORATION_MODE @@ -1350,7 +1350,9 @@ def frames_per_batch_worker(self): f"frames_per_batch {self.frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers}," f" this results in more frames_per_batch per iteration that requested" ) - return -(-self.frames_per_batch // self.num_workers) + frames_per_batch_worker = -(-self.frames_per_batch // self.num_workers) + self.frames_per_batch = frames_per_batch_worker * self.num_workers + return @property def _queue_len(self) -> int: From b4c8ba291685fab5bf1c08501f18a5ae574ab341 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 25 Jan 2023 21:25:17 +0000 Subject: [PATCH 46/50] typo --- torchrl/collectors/collectors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 061853efa33..3e77f76b665 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1352,7 +1352,7 @@ def frames_per_batch_worker(self): ) frames_per_batch_worker = -(-self.frames_per_batch // self.num_workers) self.frames_per_batch = frames_per_batch_worker * self.num_workers - return + return frames_per_batch_worker @property def _queue_len(self) -> int: From 1b1d2eac67d7ddd5fbb2a654a970214d0a32b02b Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Tue, 28 Feb 2023 10:07:36 +0000 Subject: [PATCH 47/50] Update torchrl/collectors/collectors.py Co-authored-by: Vincent Moens --- torchrl/collectors/collectors.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index f32c15a9a0a..052dea63ae9 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -911,11 +911,10 @@ class _MultiDataCollector(_DataCollector): in other words, if the env is a multi-agent env, all agents will be reset once one of them is done. Defaults to `True`. 
- mask_env_batch_size (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size, - with a value of True it indicates to consider the corresponding dimension of env.batch_size as part of the - batch of environments used to collect frames. A value of False it indicates NOT to consider that dimension - as part of the batch of environments used to collect frames (used for agent dimension in multi-agent settings). - Default is None (corresponding to all True). + mask_env_batch_size (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size. + A value of ``True`` indicates that the corresponding dimension of ``env.batch_size`` is to be included in the computation of the number of frames collected. A value of ``False`` indicates NOT to consider this particular dimension + as part of the batch of environments used to collect frames (e.g. used for agent dimension in multi-agent settings). + Defaults to ``True`` for all dims. """ From f450b399cdf4f64a1a54923a703be32dc5a32b85 Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Tue, 28 Feb 2023 10:07:47 +0000 Subject: [PATCH 48/50] Update torchrl/collectors/collectors.py Co-authored-by: Vincent Moens --- torchrl/collectors/collectors.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 052dea63ae9..48f925007bf 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1626,11 +1626,10 @@ class aSyncDataCollector(MultiaSyncDataCollector): init_with_lag (bool, optional): if True, the first trajectory will be truncated earlier at a random step. This is helpful to desynchronize the environments, such that steps do no match in all collected rollouts. default = True - mask_env_batch_size (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size, - with a value of True it indicates to consider the corresponding dimension of env.batch_size as part of the - batch of environments used to collect frames. A value of False it indicates NOT to consider that dimension - as part of the batch of environments used to collect frames (used for agent dimension in multi-agent settings). - Default is None (corresponding to all True). + mask_env_batch_size (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size. + A value of ``True`` indicates that the corresponding dimension of ``env.batch_size`` is to be included in the computation of the number of frames collected. A value of ``False`` indicates NOT to consider this particular dimension + as part of the batch of environments used to collect frames (e.g. used for agent dimension in multi-agent settings). + Defaults to ``True`` for all dims. 
""" From 80f2b55da0a907e8f6b7fab2a90d8c5c08010ac0 Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Tue, 28 Feb 2023 10:07:58 +0000 Subject: [PATCH 49/50] Update torchrl/collectors/utils.py Co-authored-by: Vincent Moens --- torchrl/collectors/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index c8cb572d9e6..9af9a927e61 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -101,7 +101,7 @@ def bring_forward_and_squash_batch_sizes( permute: Sequence[int], batch_size_unmasked: Sequence[int], ) -> TensorDictBase: - """Permutes the batch dimesnions attording to the permute indeces and then squashes all leadning dimesnions apart from batch_size_unmasked.""" + """Permutes the batch dimensions according to the permute indices, then squashes all the leading dimensions except ``batch_size_unmasked``.""" # Bring all batch dimensions to the front (only performs computation if it is not already the case) tensordict = tensordict.permute(permute) # Flatten all batch dimensions into first one and leave unmasked dimensions untouched From 3545dd75f1d9bf160bfc84b70d6ec6bce15e244f Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 28 Feb 2023 10:25:32 +0000 Subject: [PATCH 50/50] refactor doc --- torchrl/collectors/collectors.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index cbb2cff84ef..55d41e6ec68 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -23,7 +23,6 @@ from tensordict.tensordict import TensorDictBase from torch import multiprocessing as mp from torch.utils.data import IterableDataset - from torchrl._utils import _check_for_faulty_process from torchrl.collectors.utils import ( bring_forward_and_squash_batch_sizes, @@ -298,11 +297,11 @@ class SyncDataCollector(_DataCollector): updated. This feature should be used cautiously: if the same tensordict is added to a replay buffer for instance, the whole content of the buffer will be identical. Default is False. - mask_env_batch_size (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size, - with a value of True it indicates to consider the corresponding dimension of env.batch_size as part of the - batch of environments used to collect frames. A value of False it indicates NOT to consider that dimension - as part of the batch of environments used to collect frames (used for agent dimension in multi-agent settings). - Default is None (corresponding to all True). + mask_env_batch_size (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size. + A value of ``True`` indicates that the corresponding dimension of ``env.batch_size`` is to be included in + the computation of the number of frames collected. A value of ``False`` indicates NOT to consider this particular dimension + as part of the batch of environments used to collect frames (e.g. used for agent dimension in multi-agent settings). + Defaults to ``True`` for all dims. Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -919,7 +918,8 @@ class _MultiDataCollector(_DataCollector): reset once one of them is done. Defaults to `True`. mask_env_batch_size (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size. 
- A value of ``True`` indicates that the corresponding dimension of ``env.batch_size`` is to be included in the computation of the number of frames collected. A value of ``False`` indicates NOT to consider this particular dimension + A value of ``True`` indicates that the corresponding dimension of ``env.batch_size`` is to be included in + the computation of the number of frames collected. A value of ``False`` indicates NOT to consider this particular dimension as part of the batch of environments used to collect frames (e.g. used for agent dimension in multi-agent settings). Defaults to ``True`` for all dims. @@ -1634,7 +1634,8 @@ class aSyncDataCollector(MultiaSyncDataCollector): This is helpful to desynchronize the environments, such that steps do no match in all collected rollouts. default = True mask_env_batch_size (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size. - A value of ``True`` indicates that the corresponding dimension of ``env.batch_size`` is to be included in the computation of the number of frames collected. A value of ``False`` indicates NOT to consider this particular dimension + A value of ``True`` indicates that the corresponding dimension of ``env.batch_size`` is to be included in + the computation of the number of frames collected. A value of ``False`` indicates NOT to consider this particular dimension as part of the batch of environments used to collect frames (e.g. used for agent dimension in multi-agent settings). Defaults to ``True`` for all dims.
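
---

For reference, below is a minimal usage sketch of the `mask_env_batch_size` argument introduced by this patch series. It is illustrative only and not one of the patches above: `make_multi_agent_env` is a placeholder environment factory, the batch dimensions and frame counts are assumptions, and the expected output shapes follow the assertions in `test_collector_batch_size_with_env_batch_size`.

```python
# Illustrative sketch only -- not part of the patches above.
# `make_multi_agent_env` is a hypothetical factory returning an EnvBase whose
# batch_size is (n_envs, n_agents) and whose "observation" entry has one feature.
import torch
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.collectors import SyncDataCollector
from torchrl.collectors.utils import split_trajectories

n_envs, n_agents = 4, 3
frames_per_batch = 200  # divisible by n_envs, so no rounding warning

policy = TensorDictModule(
    nn.Linear(1, 1), in_keys=["observation"], out_keys=["action"]
)

collector = SyncDataCollector(
    create_env_fn=make_multi_agent_env,  # placeholder, batch_size = (n_envs, n_agents)
    policy=policy,
    frames_per_batch=frames_per_batch,
    total_frames=1_000,
    # The first batch dimension (n_envs) is treated as the batch of environments
    # used for collection; the agent dimension is masked out and kept as an
    # extra, untouched dimension of the collected data.
    mask_env_batch_size=(True, False),
    split_trajs=False,
)

for batch in collector:
    # Frames are flattened over the env dimension only, so the collected batch
    # keeps the agent dimension: (frames_per_batch, n_agents).
    assert batch.batch_size == torch.Size([frames_per_batch, n_agents])
    # Optionally split into zero-padded trajectories with the reworked utility.
    padded = split_trajectories(batch)
    break

collector.shutdown()
```

With `split_trajs=True` (or after calling `split_trajectories` as above), the tests in this series expect the leading dimension to become the number of collected trajectories and each trajectory to be zero-padded along the last dimension to the longest trajectory length, i.e. a batch size of `(num_trajectories, n_agents, max_traj_len)` for the placeholder environment used here.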