diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index f15ca5d96b2..588f70ffe00 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -892,7 +892,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
         self.count[:] = self.start_val
         return TensorDict(
             source={
-                "observation": self.count.clone(),
+                "observation": self.count.float().clone(),
                 "done": self.count > self.max_steps,
             },
             batch_size=self.batch_size,
@@ -907,7 +907,7 @@ def _step(
         self.count += action.to(torch.int)
         return TensorDict(
             source={
-                "observation": self.count,
+                "observation": self.count.float(),
                 "done": self.count > self.max_steps,
                 "reward": torch.zeros_like(self.count, dtype=torch.float),
             },
diff --git a/test/test_collector.py b/test/test_collector.py
index 123965465c5..26157d997dc 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -9,6 +9,7 @@
 import numpy as np
 import pytest
 import torch
+
 from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED
 from mocking_classes import (
     ContinuousActionVecMockEnv,
@@ -23,7 +24,7 @@
 from tensordict.nn import TensorDictModule
 from tensordict.tensordict import assert_allclose_td, TensorDict
 from torch import nn
-from torchrl._utils import seed_generator
+from torchrl._utils import prod, seed_generator
 from torchrl.collectors import aSyncDataCollector, SyncDataCollector
 from torchrl.collectors.collectors import (
     MultiaSyncDataCollector,
@@ -318,16 +319,18 @@ def make_env():
     )
     for _data in collector:
         continue
-    steps = _data["collector", "step_count"][..., 1:]
-    done = _data["done"][..., :-1, :].squeeze(-1)
+    steps = _data["collector", "step_count"]
+    done = _data["done"].squeeze(-1)
+    traj_ids = _data["collector", "traj_ids"]
     # we don't want just one done
     assert done.sum() > 3
-    # check that after a done, the next step count is always 1
-    assert (steps[done] == 1).all()
-    # check that if the env is not done, the next step count is > 1
-    assert (steps[~done] > 1).all()
-    # check that if step is 1, then the env was done before
-    assert (steps == 1)[done].all()
+    for i in traj_ids.unique(sorted=False):
+        # check that the first step count of each trajectory is always 1
+        assert (steps[traj_ids == i][0] == 1).all()
+        # check that step counts are greater than 1 for non-first elements of a trajectory
+        assert (steps[traj_ids == i][1:] > 1).all()
+        # check that non-last elements of trajectories are not done
+        assert (done[traj_ids == i][:-1] == 0).all()
     # check that split traj has a minimum total reward of -21 (for pong only)
     _data = split_trajectories(_data)
     assert _data["reward"].sum(-2).min() == -21
@@ -375,9 +378,8 @@ def make_env(seed):
     )
     for _, d in enumerate(collector):  # noqa
         break
-
-    assert (d["done"].sum(-2) >= 1).all()
-    assert torch.unique(d["collector", "traj_ids"], dim=-1).shape[-1] == 1
+    assert (d["done"].sum() >= 1).all()
+    assert torch.unique(d["collector", "traj_ids"]).shape[0] == num_env
 
     del collector
 
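Review note (illustrative, not part of the patch): the per-trajectory assertions reworked above group frames by trajectory id instead of relying on positional alignment, since frames are now flattened along the leading batch dimension. A minimal sketch of the invariant with hand-made tensors (hypothetical values, no collector involved):

    import torch

    # Two trajectories laid out in collection order, as a collector would emit them.
    traj_ids = torch.tensor([0, 0, 0, 1, 1])
    steps = torch.tensor([1, 2, 3, 1, 2])  # step_count restarts at 1 for each trajectory
    done = torch.tensor([False, False, True, False, True])

    for i in traj_ids.unique(sorted=False):
        assert (steps[traj_ids == i][0] == 1).all()   # a trajectory starts with step count 1
        assert (steps[traj_ids == i][1:] > 1).all()   # later steps have counts > 1
        assert (done[traj_ids == i][:-1] == 0).all()  # only the last frame may be done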
@@ -536,6 +538,146 @@ def env_fn():
     ccollector.shutdown()
 
 
+@pytest.mark.parametrize("n_env_workers", [1, 3])
+@pytest.mark.parametrize("batch_size", [(), (2, 4)])
+@pytest.mark.parametrize("mask_env_batch_size", [None, (True, False, True)])
+def test_collector_batch_size_with_env_batch_size(
+    n_env_workers,
+    batch_size,
+    mask_env_batch_size,
+    max_steps=5,
+    n_collector_workers=4,
+    seed=100,
+):
+    if n_env_workers == 3 and _os_is_windows:
+        pytest.skip("Test timeout (> 10 min) on CI pipeline Windows machine with GPU")
+    if n_env_workers == 1:
+        env = lambda: CountingEnv(max_steps=max_steps, batch_size=batch_size)
+        if mask_env_batch_size is not None:
+            mask_env_batch_size = mask_env_batch_size[1:]
+    else:
+        env = lambda: ParallelEnv(
+            num_workers=n_env_workers,
+            create_env_fn=lambda: CountingEnv(
+                max_steps=max_steps, batch_size=batch_size
+            ),
+        )
+    new_batch_size = env().batch_size
+    policy = TensorDictModule(
+        nn.Linear(1, 1), in_keys=["observation"], out_keys=["action"]
+    )
+    torch.manual_seed(0)
+    np.random.seed(0)
+
+    env_unmasked_dims = [
+        dim
+        for i, dim in enumerate(new_batch_size)
+        if mask_env_batch_size is not None and not mask_env_batch_size[i]
+    ]
+    n_batch_envs = max(
+        1,
+        prod(
+            [
+                dim
+                for i, dim in enumerate(new_batch_size)
+                if mask_env_batch_size is None or mask_env_batch_size[i]
+            ]
+        ),
+    )
+    frames_per_batch = n_collector_workers * n_batch_envs * n_env_workers * 5
+
+    if mask_env_batch_size is not None and len(mask_env_batch_size) != len(
+        new_batch_size
+    ):
+        with pytest.raises(RuntimeError):
+            SyncDataCollector(
+                create_env_fn=env,
+                policy=policy,
+                frames_per_batch=frames_per_batch,
+                mask_env_batch_size=mask_env_batch_size,
+                pin_memory=False,
+            )
+        return
+
+    # Multi async no split traj
+    ccollector = MultiaSyncDataCollector(
+        create_env_fn=[env for _ in range(n_collector_workers)],
+        policy=policy,
+        frames_per_batch=frames_per_batch,
+        mask_env_batch_size=mask_env_batch_size,
+        pin_memory=False,
+        split_trajs=False,
+    )
+    ccollector.set_seed(seed)
+    for i, b in enumerate(ccollector):
+        assert b.batch_size == torch.Size([frames_per_batch, *env_unmasked_dims])
+        if i == 1:
+            break
+    ccollector.shutdown()
+
+    # Multi async split traj
+    ccollector = MultiaSyncDataCollector(
+        create_env_fn=[env for _ in range(n_collector_workers)],
+        policy=policy,
+        frames_per_batch=frames_per_batch,
+        max_frames_per_traj=max_steps,
+        mask_env_batch_size=mask_env_batch_size,
+        pin_memory=False,
+        split_trajs=True,
+    )
+    ccollector.set_seed(seed)
+    for i, b in enumerate(ccollector):
+        assert b.batch_size == torch.Size(
+            [
+                b["collector", "traj_ids"].unique(sorted=False).shape[0],
+                *env_unmasked_dims,
+                max_steps,
+            ]
+        )
+        if i == 1:
+            break
+    ccollector.shutdown()
+
+    # Multi sync no split traj
+    ccollector = MultiSyncDataCollector(
+        create_env_fn=[env for _ in range(n_collector_workers)],
+        policy=policy,
+        frames_per_batch=frames_per_batch,
+        mask_env_batch_size=mask_env_batch_size,
+        pin_memory=False,
+        split_trajs=False,
+    )
+    ccollector.set_seed(seed)
+    for i, b in enumerate(ccollector):
+        assert b.batch_size == torch.Size([frames_per_batch, *env_unmasked_dims])
+        if i == 1:
+            break
+    ccollector.shutdown()
+
+    # Multi sync split traj
+    ccollector = MultiSyncDataCollector(
+        create_env_fn=[env for _ in range(n_collector_workers)],
+        policy=policy,
+        frames_per_batch=frames_per_batch,
+        max_frames_per_traj=max_steps,
+        mask_env_batch_size=mask_env_batch_size,
+        pin_memory=False,
+        split_trajs=True,
+    )
+    ccollector.set_seed(seed)
+    for i, b in enumerate(ccollector):
+        assert b.batch_size == torch.Size(
+            [
+                b["collector", "traj_ids"].unique(sorted=False).shape[0],
+                *env_unmasked_dims,
+                max_steps,
+            ]
+        )
+        if i == 1:
+            break
+    ccollector.shutdown()
+
+
 @pytest.mark.parametrize("num_env", [1, 2])
 @pytest.mark.parametrize("env_name", ["vec", "conv"])
 def test_concurrent_collector_seed(num_env, env_name, seed=100):
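Review note (illustrative, not part of the patch): a sketch of the shape accounting asserted by test_collector_batch_size_with_env_batch_size above, assuming the parametrization batch_size=(2, 4) with mask_env_batch_size=(True, False, True) on a 3-worker ParallelEnv:

    import torch
    from torchrl._utils import prod

    new_batch_size = torch.Size([3, 2, 4])  # (parallel workers, *CountingEnv batch dims)
    mask = (True, False, True)              # dim 1 (e.g. an agent dim) is masked out

    env_unmasked_dims = [d for d, m in zip(new_batch_size, mask) if not m]       # [2]
    n_batch_envs = max(1, prod([d for d, m in zip(new_batch_size, mask) if m]))  # 3 * 4 = 12

    # With split_trajs=False, the collector is expected to yield flat batches:
    # b.batch_size == torch.Size([frames_per_batch, *env_unmasked_dims])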
diff --git a/test/test_postprocs.py b/test/test_postprocs.py
index be09fdd0c9e..ace2b9aa1eb 100644
--- a/test/test_postprocs.py
+++ b/test/test_postprocs.py
@@ -6,6 +6,7 @@
 import pytest
 import torch
+
 from _utils_internal import get_available_devices
 from tensordict.tensordict import assert_allclose_td, TensorDict
 from torchrl.collectors.utils import split_trajectories
@@ -121,7 +122,7 @@ def create_fake_trajs(
             traj_ids[done] = traj_ids.max() + torch.arange(1, done.sum() + 1)
             step_count[done] = 0
 
-        out = torch.stack(out, 1).contiguous()
+        out = torch.stack(out, 1).view(-1).contiguous()
         return out
 
     @pytest.mark.parametrize("num_workers", range(3, 34, 3))
@@ -129,8 +130,8 @@ def create_fake_trajs(
     def test_splits(self, num_workers, traj_len):
 
         trajs = TestSplits.create_fake_trajs(num_workers, traj_len)
-        assert trajs.shape[0] == num_workers
-        assert trajs.shape[1] == traj_len
+        assert trajs.shape[0] == num_workers * traj_len
+        assert len(trajs.shape) == 1
         split_trajs = split_trajectories(trajs)
         assert (
             split_trajs.shape[0] == split_trajs.get(("collector", "traj_ids")).max() + 1
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 7dbe3991e6e..55d41e6ec68 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -9,6 +9,7 @@
 import os
 import queue
 import time
+import warnings
 from collections import OrderedDict
 from copy import deepcopy
 from multiprocessing import connection, queues
@@ -22,9 +23,12 @@
 from tensordict.tensordict import TensorDictBase
 from torch import multiprocessing as mp
 from torch.utils.data import IterableDataset
-
-from torchrl._utils import _check_for_faulty_process, prod
-from torchrl.collectors.utils import split_trajectories
+from torchrl._utils import _check_for_faulty_process
+from torchrl.collectors.utils import (
+    bring_forward_and_squash_batch_sizes,
+    get_batch_size_masked,
+    split_trajectories,
+)
 from torchrl.data import TensorSpec
 from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
 from torchrl.envs.common import EnvBase
@@ -293,6 +297,11 @@ class SyncDataCollector(_DataCollector):
             updated. This feature should be used cautiously: if the same tensordict is
             added to a replay buffer for instance, the whole content of the buffer will
             be identical. Default is False.
+        mask_env_batch_size (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size.
+            A value of ``True`` indicates that the corresponding dimension of ``env.batch_size`` is to be included in
+            the computation of the number of frames collected. A value of ``False`` indicates NOT to consider this particular dimension
+            as part of the batch of environments used to collect frames (e.g. the agent dimension in multi-agent settings).
+            Defaults to ``True`` for all dims.
 
     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
@@ -369,6 +378,7 @@ def __init__(
         init_with_lag: bool = False,
         return_same_td: bool = False,
         reset_when_done: bool = True,
+        mask_env_batch_size: Optional[Sequence[bool]] = None,
     ):
         self.closed = True
         if seed is not None:
@@ -405,7 +415,52 @@ def __init__(
         self.env: EnvBase = env
         self.closed = False
         self.reset_when_done = reset_when_done
-        self.n_env = self.env.numel()
+
+        # Batch sizes and masks
+        if mask_env_batch_size is None:
+            mask_env_batch_size = [True for _ in self.env.batch_size]
+        else:
+            mask_env_batch_size = list(mask_env_batch_size)
+        self.mask_env_batch_size = mask_env_batch_size
+        self.mask_out_batch_size = mask_env_batch_size + [True]
+        # env.batch_size with masked dimensions set to 1.
+        # Also raises an error in case the input mask is malformed
+        self.env_batch_size_masked = get_batch_size_masked(
+            self.env.batch_size, self.mask_env_batch_size
+        )
+
+        # Indices of env.batch_size dims not in the batch
+        env_batch_size_unmasked_indices = [
+            i for i, is_batch in enumerate(self.mask_env_batch_size) if not is_batch
+        ]
+        # env.batch_size dims not in the batch
+        self.env_batch_size_unmasked = [
+            env.batch_size[i]
+            for i, is_batch in enumerate(self.mask_env_batch_size)
+            if not is_batch
+        ]
+        # Permutation indices: first all batch dims, then all masked-out dims
+        self.permute_out_batch_size = [
+            i for i, is_batch in enumerate(self.mask_out_batch_size) if is_batch
+        ] + env_batch_size_unmasked_indices
+        self.permute_env_batch_size = [
+            i for i, is_batch in enumerate(self.mask_env_batch_size) if is_batch
+        ] + env_batch_size_unmasked_indices
+
+        # Number of batched environments used for collection
+        self.n_env = max(1, self.env_batch_size_masked.numel())
+
+        if len(self.env_batch_size_unmasked):
+            # Mask used to only consider batch dimensions in trajectories
+            self.mask_tensor = torch.ones(
+                *self.env.batch_size,
+                dtype=torch.bool,
+                device=self.env.device,
+            )
+            for dim in env_batch_size_unmasked_indices:
+                self.mask_tensor.index_fill_(
+                    dim, torch.arange(1, self.env.batch_size[dim]), 0
+                )
 
         (self.policy, self.device, self.get_weights_fn,) = self._get_policy_and_device(
             policy=policy,
@@ -423,7 +478,13 @@ def __init__(
         if self.postproc is not None:
             self.postproc.to(self.storing_device)
         self.max_frames_per_traj = max_frames_per_traj
-        self.frames_per_batch = -(-frames_per_batch // self.n_env)
+        if frames_per_batch % self.n_env != 0:
+            warnings.warn(
+                f"frames_per_batch {frames_per_batch} is not exactly divisible by the number of batched environments {self.n_env},"
+                f" this results in more frames_per_batch per iteration than requested"
+            )
+        self.batched_frames_per_batch = -(-frames_per_batch // self.n_env)
+        self.frames_per_batch = self.batched_frames_per_batch * self.n_env
         self.pin_memory = pin_memory
         self.exploration_mode = (
             exploration_mode if exploration_mode else DEFAULT_EXPLORATION_MODE
@@ -455,7 +516,7 @@ def __init__(
                 self._tensordict_out.update(self.policy.spec.zero())
                 self._tensordict_out = (
                     self._tensordict_out.unsqueeze(-1)
-                    .expand(*env.batch_size, self.frames_per_batch)
+                    .expand(*env.batch_size, self.batched_frames_per_batch)
                     .to_tensordict()
                 )
             else:
@@ -467,7 +528,9 @@ def __init__(
                 self._tensordict_out = self._tensordict_out.to(self.device)
                 self._tensordict_out = self.policy(self._tensordict_out).unsqueeze(-1)
                 self._tensordict_out = (
-                    self._tensordict_out.expand(*env.batch_size, self.frames_per_batch)
+                    self._tensordict_out.expand(
+                        *env.batch_size, self.batched_frames_per_batch
+                    )
                     .to_tensordict()
                     .zero_()
                 )
@@ -490,6 +553,11 @@ def __init__(
                 device=self.storing_device,
             ),
         )
+        self._tensordict_out = bring_forward_and_squash_batch_sizes(
+            self._tensordict_out,
+            self.permute_out_batch_size,
+            self.env_batch_size_unmasked,
+        )
 
         if split_trajs is None:
             if not self.reset_when_done:
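Review note (illustrative, not part of the patch): a sketch of how the bookkeeping above composes, assuming env.batch_size = (3, 2, 4) and mask_env_batch_size = (True, False, True):

    import torch
    from torchrl.collectors.utils import get_batch_size_masked

    batch_size = torch.Size([3, 2, 4])
    masked = get_batch_size_masked(batch_size, (True, False, True))  # torch.Size([3, 1, 4])
    n_env = max(1, masked.numel())  # 12 batched envs contribute frames

    # frames_per_batch is rounded up to a multiple of n_env (with a warning when inexact):
    frames_per_batch = 100
    batched_frames_per_batch = -(-frames_per_batch // n_env)  # ceil(100 / 12) = 9
    assert batched_frames_per_batch * n_env == 108  # actual frames collected per batch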
@@ -533,17 +601,16 @@ def iterator(self) -> Iterator[TensorDictBase]:
 
         Yields: TensorDictBase objects containing (chunks of) trajectories
 
         """
-        total_frames = self.total_frames
         i = -1
         self._frames = 0
         while True:
             i += 1
             self._iter = i
             tensordict_out = self.rollout()
-            self._frames += tensordict_out.numel()
-            if self._frames >= total_frames:
-                self.env.close()
+            self._frames += tensordict_out.batch_size[0]
+            if self._frames >= self.total_frames:
+                self.env.close()
 
             if self.split_trajs:
                 tensordict_out = split_trajectories(tensordict_out)
             if self.postproc is not None:
@@ -614,10 +681,23 @@ def _step_and_maybe_reset(self) -> None:
                 raise RuntimeError(
                     f"Env {self.env} was done after reset on specified '_reset' dimensions. This is (currently) not allowed."
                 )
+            steps[done_or_terminated] = 0
+
+            if len(self.env_batch_size_unmasked):
+                traj_ids = traj_ids[self.mask_tensor]
+                done_or_terminated = done_or_terminated[self.mask_tensor]
+
             traj_ids[done_or_terminated] = traj_ids.max() + torch.arange(
                 1, done_or_terminated.sum() + 1, device=traj_ids.device
             )
-            steps[done_or_terminated] = 0
+
+            if len(self.env_batch_size_unmasked):
+                traj_ids = (
+                    traj_ids.view(self.env_batch_size_masked)
+                    .expand(self.env.batch_size)
+                    .clone()
+                )
+
             self._tensordict.set_(
                 ("collector", "traj_ids"), traj_ids
             )  # no ops if they already match
@@ -637,8 +717,16 @@ def rollout(self) -> TensorDictBase:
             self._tensordict.update(self.env.reset(), inplace=True)
             self._tensordict.fill_(("collector", "step_count"), 0)
 
+        self._tensordict.set(
+            ("collector", "traj_ids"),
+            torch.arange(self.n_env)
+            .view(self.env_batch_size_masked)
+            .expand(self.env.batch_size)
+            .clone(),
+        )
+
         with set_exploration_mode(self.exploration_mode):
-            for j in range(self.frames_per_batch):
+            for j in range(self.batched_frames_per_batch):
                 if self._frames < self.init_random_frames:
                     self.env.rand_step(self._tensordict)
                 else:
@@ -646,15 +734,25 @@ def rollout(self) -> TensorDictBase:
 
                 step_count = self._tensordict.get(("collector", "step_count"))
                 self._tensordict.set_(("collector", "step_count"), step_count + 1)
+
+                tensordict = bring_forward_and_squash_batch_sizes(
+                    self._tensordict,
+                    self.permute_env_batch_size,
+                    self.env_batch_size_unmasked,
+                )
                 # we must clone all the values, since the step / traj_id updates are done in-place
                 try:
-                    self._tensordict_out[..., j] = self._tensordict
+                    self._tensordict_out[
+                        j * self.n_env : (j + 1) * self.n_env
+                    ] = tensordict
                 except RuntimeError:
                     # unlock the output tensordict to allow for new keys to be written
                     # these will be missed during the sync but at least we won't get an error during the update
                     is_shared = self._tensordict_out.is_shared()
                     self._tensordict_out.unlock()
-                    self._tensordict_out[..., j] = self._tensordict
+                    self._tensordict_out[
+                        j * self.n_env : (j + 1) * self.n_env
+                    ] = tensordict
                     if is_shared:
                         self._tensordict_out.share_memory_()
 
@@ -666,7 +764,7 @@ def reset(self, index=None, **kwargs) -> None:
         """Resets the environments to a new initial state."""
         if index is not None:
             # check that the env supports partial reset
-            if prod(self.env.batch_size) == 0:
+            if self.n_env == 0:
                 raise RuntimeError("resetting unique env with index is not permitted.")
             _reset = torch.zeros(
                 self.env.batch_size,
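Review note (illustrative, not part of the patch): the traj_ids handling above keeps one trajectory id per batched env and broadcasts it across the masked-out dimensions; a minimal sketch with assumed shapes:

    import torch

    env_batch_size = (3, 2, 4)     # dim 1 masked out (e.g. agents)
    masked_batch_size = (3, 1, 4)  # what get_batch_size_masked would return
    n_env = 12

    traj_ids = torch.arange(n_env).view(masked_batch_size).expand(env_batch_size).clone()
    # All slices along the masked dimension share the same trajectory id:
    assert (traj_ids[:, 0] == traj_ids[:, 1]).all()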
default = "random" reset_when_done (bool, optional): if True, the contained environment will be reset @@ -819,6 +917,11 @@ class _MultiDataCollector(_DataCollector): in other words, if the env is a multi-agent env, all agents will be reset once one of them is done. Defaults to `True`. + mask_env_batch_size (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size. + A value of ``True`` indicates that the corresponding dimension of ``env.batch_size`` is to be included in + the computation of the number of frames collected. A value of ``False`` indicates NOT to consider this particular dimension + as part of the batch of environments used to collect frames (e.g. used for agent dimension in multi-agent settings). + Defaults to ``True`` for all dims. """ @@ -847,6 +950,7 @@ def __init__( init_with_lag: bool = False, exploration_mode: str = DEFAULT_EXPLORATION_MODE, reset_when_done: bool = True, + mask_env_batch_size: Optional[Sequence[bool]] = None, ): self.closed = True self.create_env_fn = create_env_fn @@ -937,6 +1041,7 @@ def device_err_msg(device_name, devices_list): f"Found {type(storing_devices)} instead." ) + self.mask_env_batch_size = mask_env_batch_size self.total_frames = total_frames if total_frames > 0 else float("inf") self.reset_at_each_iter = reset_at_each_iter self.postprocs = postproc @@ -1012,6 +1117,7 @@ def _run_processes(self) -> None: "exploration_mode": self.exploration_mode, "reset_when_done": self.reset_when_done, "idx": i, + "mask_env_batch_size": self.mask_env_batch_size, } proc = mp.Process(target=_main_async_collector, kwargs=kwargs) # proc.daemon can't be set as daemonic processes may be launched by the process itself @@ -1216,7 +1322,14 @@ class MultiSyncDataCollector(_MultiDataCollector): @property def frames_per_batch_worker(self): - return -(-self.frames_per_batch // self.num_workers) + if self.frames_per_batch % self.num_workers != 0: + warnings.warn( + f"frames_per_batch {self.frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers}," + f" this results in more frames_per_batch per iteration that requested" + ) + frames_per_batch_worker = -(-self.frames_per_batch // self.num_workers) + self.frames_per_batch = frames_per_batch_worker * self.num_workers + return frames_per_batch_worker @property def _queue_len(self) -> int: @@ -1252,7 +1365,7 @@ def iterator(self) -> Iterator[TensorDictBase]: else: idx = new_data workers_frames[idx] = ( - workers_frames[idx] + out_tensordicts_shared[idx].numel() + workers_frames[idx] + out_tensordicts_shared[idx].batch_size[0] ) if workers_frames[idx] >= self.total_frames: @@ -1285,12 +1398,11 @@ def iterator(self) -> Iterator[TensorDictBase]: out=out_buffer, ) + frames += out_buffer.batch_size[0] if self.split_trajs: out = split_trajectories(out_buffer) - frames += out.get(("collector", "mask")).sum().item() else: out = out_buffer.clone() - frames += prod(out.shape) if self.postprocs: self.postprocs = self.postprocs.to(out.device) out = self.postprocs(out) @@ -1415,7 +1527,7 @@ def iterator(self) -> Iterator[TensorDictBase]: i += 1 idx, j, out = self._get_from_queue() - worker_frames = out.numel() + worker_frames = out.batch_size[0] if self.split_trajs: out = split_trajectories(out) self._frames += worker_frames @@ -1521,6 +1633,11 @@ class aSyncDataCollector(MultiaSyncDataCollector): init_with_lag (bool, optional): if True, the first trajectory will be truncated earlier at a random step. 
@@ -1521,6 +1633,11 @@ class aSyncDataCollector(MultiaSyncDataCollector):
         init_with_lag (bool, optional): if True, the first trajectory will be truncated earlier at a random step.
             This is helpful to desynchronize the environments, such that steps do not match in all collected
             rollouts. default = True
+        mask_env_batch_size (Sequence[bool], optional): a sequence of bools of the same length as env.batch_size.
+            A value of ``True`` indicates that the corresponding dimension of ``env.batch_size`` is to be included in
+            the computation of the number of frames collected. A value of ``False`` indicates NOT to consider this particular dimension
+            as part of the batch of environments used to collect frames (e.g. the agent dimension in multi-agent settings).
+            Defaults to ``True`` for all dims.
 
     """
 
@@ -1544,6 +1661,7 @@ def __init__(
         device: Optional[Union[int, str, torch.device]] = None,
         storing_device: Optional[Union[int, str, torch.device]] = None,
         seed: Optional[int] = None,
+        mask_env_batch_size: Optional[Sequence[bool]] = None,
         pin_memory: bool = False,
         **kwargs,
     ):
@@ -1562,6 +1680,7 @@ def __init__(
             storing_devices=[storing_device] if storing_device is not None else None,
             seed=seed,
             pin_memory=pin_memory,
+            mask_env_batch_size=mask_env_batch_size,
             **kwargs,
         )
 
@@ -1585,6 +1704,7 @@ def _main_async_collector(
     init_with_lag: bool = False,
     exploration_mode: str = DEFAULT_EXPLORATION_MODE,
     reset_when_done: bool = True,
+    mask_env_batch_size: Optional[Sequence[bool]] = None,
     verbose: bool = False,
 ) -> None:
     pipe_parent.close()
@@ -1609,6 +1729,7 @@ def _main_async_collector(
         exploration_mode=exploration_mode,
         reset_when_done=reset_when_done,
         return_same_td=True,
+        mask_env_batch_size=mask_env_batch_size,
     )
     if verbose:
         print("Sync data collector created")
diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py
index 1a5afa42fcc..9af9a927e61 100644
--- a/torchrl/collectors/utils.py
+++ b/torchrl/collectors/utils.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable
+from typing import Callable, Optional, Sequence
 
 import torch
 from tensordict.tensordict import pad, TensorDictBase
@@ -25,49 +25,85 @@ def stacked_output_fun(*args, **kwargs):
     return stacked_output_fun
 
 
-def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase:
+def split_trajectories(
+    rollout_tensordict: TensorDictBase, sort: bool = True
+) -> TensorDictBase:
     """A util function for trajectory separation.
 
     Takes a tensordict with a key traj_ids that indicates the id of each trajectory.
+    The input tensordict has batch_size = (B x *masked_dims).
 
-    From there, builds a B x T x ... zero-padded tensordict with B batches on max duration T
+    From there, builds a (number_of_trajectories x *masked_dims x T) zero-padded tensordict
+    containing number_of_trajectories trajectories of shape (*masked_dims, T), where T is
+    the maximum trajectory duration.
+
+    If ``sort=True``, the trajectories are also sorted by their traj_id.
     """
+    # TODO: incorporate tensordict.split once it's implemented
+    mask = torch.ones(
+        rollout_tensordict.batch_size,
+        device=rollout_tensordict.device,
+        dtype=torch.bool,
+    )
+    for dim in range(1, len(rollout_tensordict.batch_size)):
+        mask.index_fill_(dim, torch.arange(1, rollout_tensordict.batch_size[dim]), 0)
+
     sep = ".-|-."
     rollout_tensordict = rollout_tensordict.flatten_keys(sep)
-    traj_ids = rollout_tensordict.get(sep.join(["collector", "traj_ids"]))
-    splits = traj_ids.view(-1)
-    splits = [(splits == i).sum().item() for i in splits.unique_consecutive()]
-    # if all splits are identical then we can skip this function
-    mask_key = sep.join(("collector", "mask"))
-    if len(set(splits)) == 1 and splits[0] == traj_ids.shape[-1]:
-        rollout_tensordict.set(
-            mask_key,
-            torch.ones(
-                rollout_tensordict.shape,
-                device=rollout_tensordict.device,
-                dtype=torch.bool,
-            ),
-        )
-        if rollout_tensordict.ndimension() == 1:
-            rollout_tensordict = rollout_tensordict.unsqueeze(0).to_tensordict()
-        return rollout_tensordict.unflatten_keys(sep)
-    out_splits = rollout_tensordict.view(-1).split(splits, 0)
+    traj_ids = rollout_tensordict.get(sep.join(["collector", "traj_ids"]))[mask].view(
+        -1
+    )
+    unique_traj_ids = traj_ids.unique(sorted=sort)
+    MAX = max([(traj_ids == i).count_nonzero() for i in unique_traj_ids])
 
-    for out_split in out_splits:
+    out_splits = []
+    for i in unique_traj_ids:
+        out_split = rollout_tensordict[traj_ids == i]
         out_split.set(
-            mask_key,
+            sep.join(("collector", "mask")),
             torch.ones(
                 out_split.shape,
                 dtype=torch.bool,
                 device=out_split.get("done").device,
             ),
         )
-    if len(out_splits) > 1:
-        MAX = max(*[out_split.shape[0] for out_split in out_splits])
-    else:
-        MAX = out_splits[0].shape[0]
-    td = torch.stack(
-        [pad(out_split, [0, MAX - out_split.shape[0]]) for out_split in out_splits], 0
-    ).contiguous()
+        out_split = pad(out_split, [0, MAX - out_split.shape[0]])
+        out_split = out_split.permute(-1, *range(len(out_split.batch_size) - 1))
+        out_splits.append(out_split)
+
+    td = torch.stack(out_splits, 0).contiguous()
     td = td.unflatten_keys(sep)
     return td
+
+
+def get_batch_size_masked(
+    batch_size: torch.Size, mask: Optional[Sequence[bool]] = None
+) -> torch.Size:
+    """Returns a size with the masked dimensions equal to 1."""
+    if mask is None:
+        return batch_size
+    if len(mask) != len(batch_size):
+        raise RuntimeError(
+            f"Batch size mask and env batch size have different lengths: mask={mask}, env.batch_size={batch_size}"
+        )
+    return torch.Size(
+        [
+            (dim if is_in else 1)
+            for dim, is_in in zip(
+                batch_size,
+                mask,
+            )
+        ]
+    )
+
+
+def bring_forward_and_squash_batch_sizes(
+    tensordict: TensorDictBase,
+    permute: Sequence[int],
+    batch_size_unmasked: Sequence[int],
+) -> TensorDictBase:
+    """Permutes the batch dimensions according to the permute indices, then flattens all leading batch dimensions into one, leaving the ``batch_size_unmasked`` dimensions untouched."""
+    # Bring all batch dimensions to the front (a no-op if they are already leading)
+    tensordict = tensordict.permute(permute)
+    # Flatten all leading batch dimensions into the first one, leaving unmasked dimensions untouched
+    tensordict = tensordict.reshape(-1, *batch_size_unmasked)
+    return tensordict
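Review note (illustrative, not part of the patch): an end-to-end sketch of the two new helpers above, assuming a tensordict with batch_size (3, 2, 4) where dim 1 is masked out:

    import torch
    from tensordict.tensordict import TensorDict
    from torchrl.collectors.utils import (
        bring_forward_and_squash_batch_sizes,
        get_batch_size_masked,
    )

    batch_size = torch.Size([3, 2, 4])
    mask = (True, False, True)
    td = TensorDict({"obs": torch.zeros(3, 2, 4, 5)}, batch_size=batch_size)

    masked = get_batch_size_masked(batch_size, mask)  # torch.Size([3, 1, 4])
    unmasked = [d for d, m in zip(batch_size, mask) if not m]  # [2]
    # First the kept (batch) dims, then the masked-out dims:
    permute = [i for i, m in enumerate(mask) if m] + [
        i for i, m in enumerate(mask) if not m
    ]  # [0, 2, 1]

    flat = bring_forward_and_squash_batch_sizes(td, permute, unmasked)
    assert flat.batch_size == torch.Size([12, 2])  # batched dims squashed to the front
    assert flat["obs"].shape == torch.Size([12, 2, 5])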