From 5c530de9a873651cbfe0aaf2639b117425c5033c Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 18 Dec 2023 12:42:42 +0000 Subject: [PATCH 01/18] amend --- docs/source/reference/data.rst | 1 + torchrl/data/__init__.py | 1 + torchrl/data/datasets/minari_data.py | 1 - torchrl/data/datasets/openx.py | 344 ++++++++++++++++++++++++ torchrl/data/replay_buffers/__init__.py | 1 + torchrl/data/replay_buffers/storages.py | 11 +- torchrl/data/replay_buffers/writers.py | 21 ++ 7 files changed, 377 insertions(+), 3 deletions(-) create mode 100644 torchrl/data/datasets/openx.py diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index c413f986499..3c487ace03b 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -41,6 +41,7 @@ We also give users the ability to compose a replay buffer using the following co LazyMemmapStorage TensorStorage Writer + ImmutableDatasetWriter RoundRobinWriter TensorDictRoundRobinWriter TensorDictMaxValueWriter diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 9a12749b482..d670bcbbea7 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -6,6 +6,7 @@ from . 
import datasets from .postprocs import MultiStep from .replay_buffers import ( + ImmutableDatasetWriter, LazyMemmapStorage, LazyTensorStorage, ListStorage, diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 1c89d1a869b..e9fea02ba06 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -170,7 +170,6 @@ def __init__( prefetch: int | None = None, transform: "torchrl.envs.Transform" | None = None, # noqa-F821 split_trajs: bool = False, - **env_kwargs, ): self.dataset_id = dataset_id if root is None: diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py new file mode 100644 index 00000000000..868078aa0d9 --- /dev/null +++ b/torchrl/data/datasets/openx.py @@ -0,0 +1,344 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import importlib.util + +import io +import os +import tempfile +from pathlib import Path +from typing import Any, Callable, Tuple + +import torch +import tqdm + +from tensordict import make_tensordict, pad, TensorDict + +from torchrl.data import ImmutableDatasetWriter, ReplayBuffer, Storage, Writer +from torchrl.data.datasets.utils import _get_root_dir +from torchrl.data.replay_buffers import Sampler +from torchrl.data.replay_buffers.storages import _collate_id, TensorStorage + +_has_datasets = importlib.util.find_spec("datasets", None) is not None +_has_tv = importlib.util.find_spec("torchvision", None) is not None + + +class OpenXExperienceReplay(ReplayBuffer): + available_datasets = [ + "fractal20220817_data", + "kuka", + "bridge", + "taco_play", + "jaco_play", + "berkeley_cable_routing", + "roboturk", + "nyu_door_opening_surprising_effectiveness", + "viola", + "berkeley_autolab_ur5", + "toto", + "language_table", + "columbia_cairlab_pusht_real", + 
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds", + "nyu_rot_dataset_converted_externally_to_rlds", + "stanford_hydra_dataset_converted_externally_to_rlds", + "austin_buds_dataset_converted_externally_to_rlds", + "nyu_franka_play_dataset_converted_externally_to_rlds", + "maniskill_dataset_converted_externally_to_rlds", + "furniture_bench_dataset_converted_externally_to_rlds", + "cmu_franka_exploration_dataset_converted_externally_to_rlds", + "ucsd_kitchen_dataset_converted_externally_to_rlds", + "ucsd_pick_and_place_dataset_converted_externally_to_rlds", + "austin_sailor_dataset_converted_externally_to_rlds", + "austin_sirius_dataset_converted_externally_to_rlds", + "bc_z", + "usc_cloth_sim_converted_externally_to_rlds", + "utokyo_pr2_opening_fridge_converted_externally_to_rlds", + "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds", + "utokyo_saytap_converted_externally_to_rlds", + "utokyo_xarm_pick_and_place_converted_externally_to_rlds", + "utokyo_xarm_bimanual_converted_externally_to_rlds", + "robo_net", + "berkeley_mvp_converted_externally_to_rlds", + "berkeley_rpt_converted_externally_to_rlds", + "kaist_nonprehensile_converted_externally_to_rlds", + "stanford_mask_vit_converted_externally_to_rlds", + "tokyo_u_lsmo_converted_externally_to_rlds", + "dlr_sara_pour_converted_externally_to_rlds", + "dlr_sara_grid_clamp_converted_externally_to_rlds", + "dlr_edan_shared_control_converted_externally_to_rlds", + "asu_table_top_converted_externally_to_rlds", + "stanford_robocook_converted_externally_to_rlds", + "eth_agent_affordances", + "imperialcollege_sawyer_wrist_cam", + "iamlab_cmu_pickup_insert_converted_externally_to_rlds", + "uiuc_d3field", + "utaustin_mutex", + "berkeley_fanuc_manipulation", + "cmu_playing_with_food", + "cmu_play_fusion", + "cmu_stretch", + "berkeley_gnm_recon", + "berkeley_gnm_cory_hall", + "berkeley_gnm_sac_son", + ] + + """Open X-Embodiment datasets experience replay. 
+ + The Open X-Embodiment Dataset contains 1M+ real robot trajectories + spanning 22 robot embodiments, collected through a collaboration between + 21 institutions, demonstrating 527 skills (160266 tasks). + + Args: + TODO + + Keyword Args: + TODO + + Examples: + TODO + + """ + + def __init__( + self, + dataset_id, + batch_size: int | None, + *, + streaming: bool = True, + root: str | Path | None = None, + download: bool = False, + sampler: Sampler | None = None, + writer: Writer | None = None, + collate_fn: Callable | None = None, + pin_memory: bool = False, + prefetch: int | None = None, + transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + split_trajs: bool = False, + ): + self.download = download + self.streaming = streaming + self.dataset_id = dataset_id + self.split_trajs = split_trajs + if split_trajs: + raise NotImplementedError + if not streaming: + if root is None: + root = _get_root_dir("openx") + os.makedirs(root, exist_ok=True) + self.root = Path(root) + if self.download == "force" or ( + self.download and not self._is_downloaded() + ): + storage = self._download_and_preproc() + else: + storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id)) + else: + self.root = None + if download: + raise ValueError( + "download and streaming cannot be set to ``True`` concomitantly." 
+ ) + storage = _StreamingStorage(dataset_id=dataset_id) + if sampler is None: + sampler = _StreamingSampler() + if writer is None: + writer = ImmutableDatasetWriter() + if collate_fn is None: + collate_fn = _collate_id + super().__init__( + storage=storage, + sampler=sampler, + writer=writer, + collate_fn=collate_fn, + pin_memory=pin_memory, + prefetch=prefetch, + batch_size=batch_size, + transform=transform, + ) + + @property + def data_path(self): + if self.streaming: + return None + if self.split_trajs: + return Path(self.root) / (self.dataset_id + "_split") + return self.data_path_root + + @property + def data_path_root(self): + if self.streaming: + return None + return self.root / self.dataset_id + + def _is_downloaded(self): + return os.path.exists(self.data_path_root) + + def _download_and_preproc(self): + if not _has_datasets: + raise ImportError( + f"the `datasets` library is required for the dataset {self.dataset_id}." + ) + import datasets + + with tempfile.TemporaryDirectory() as cache_dir: + dataset = datasets.load_dataset( + "jxu124/OpenX-Embodiment", + self.dataset_id, + streaming=False, + split="train", + cache_dir=cache_dir, + ) + # iterate over the dataset a first time to count elements + total_frames = 0 + pbar = tqdm.tqdm(dataset, desc="counting") + for data in pbar: + if total_frames == 0: + for step in data["data.pickle"]["steps"]: + td = _make_tensordict_image_conv(step).zero_() + total_frames += len(data["data.pickle"]["steps"]) + td_data = ( + td.expand(total_frames) + .memmap_like(self.root / self.dataset_id) + .unlock_() + ) + pbar = tqdm.tqdm(dataset, desc="preproc", total=total_frames) + idx0 = 0 + idx1 = 0 + episode = 0 + for data in pbar: + current_ep = torch.stack( + [ + _make_tensordict_image_conv(step) + for step in data["data.pickle"]["steps"] + ] + ).contiguous() + _format_data(current_ep, episode) + episode += 1 + idx1 += len(current_ep) + td_data[idx0:idx1] = current_ep + idx0 = idx1 + pbar.update(current_ep.shape[0]) + 
print("total episodes", td_data["next", "done"].sum()) + return TensorStorage(td_data.lock_()) + + +class _StreamingStorage(Storage): + def __init__( + self, + dataset_id: str, + repo: str = "jxu124/OpenX-Embodiment", + split="train", + base_path="data.pickle", + shuffle: bool = True, + truncate: bool = True, + ): + if not _has_datasets: + raise ImportError( + f"the `datasets` library is required for the dataset {dataset_id}." + ) + import datasets + + dataset = datasets.load_dataset(repo, dataset_id, streaming=True, split=split) + if shuffle: + dataset = dataset.shuffle() + self.dataset = iter(dataset) + self.base_path = base_path + self.truncate = truncate + + def get(self, index: int) -> Any: + if not isinstance(index, range): + # we use a range to indicate how much data we want + raise RuntimeError("iterable datasets do not support indexing.") + total = 0 + data_list = [] + episode = 0 + while total < index.stop: + data = next(self.dataset) + if self.base_path: + data = data[self.base_path] + data = torch.stack( + [_make_tensordict_image_conv(step) for step in data["steps"]] + ).contiguous() + _format_data(data, episode) + data_list.append(data) + total += data.numel() + episode += 1 + data = torch.cat(data_list) + if self.truncate: + return data[: index.stop] + return data + + def __len__(self): + raise RuntimeError( + f"{type(self)} does not have a length. Use a downloaded dataset to " + f"access this property." + ) + + +class _StreamingSampler(Sampler): + def __init__(self): + ... + + def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]: + return range(batch_size), {} + + def _empty(self): + return + + def dumps(self, path): + ... + + def loads(self, path): + ... 
+ + +OPENX_KEY_MAP = { + "is_first": "is_init", + "is_last": ("next", "done"), + "is_terminal": ("next", "terminated"), + "reward": ("next", "reward"), +} + + +def _format_data(data: TensorDict, episode: int): + observation_ = data.get("observation") + observation_pad = pad(observation_[1:], [0, 1]) + data.set(("next", "observation"), observation_pad) + for key, newkey in OPENX_KEY_MAP.items(): + data.rename_key_(key, newkey) + data.set( + ("next", "truncated"), + data.get(("next", "done")) ^ data.get(("next", "terminated")), + ) + + for key in ("done", "terminated", "truncated", "reward"): + data.set(("next", key), data.get(("next", key)).unsqueeze(-1)) + if key != "reward": + data.set(key, torch.zeros_like(data.get(("next", key)))) + + data.set( + "episode", torch.full(data.shape, episode, device=data.device, dtype=torch.int) + ) + + +def _make_tensordict_image_conv(data): + # in some datasets, the images are not well converted. + # before building the tensordict, we load the PIL image and convert it to a tensor + try: + img_bytes = data["observation"]["image"]["bytes"] + if not _has_tv: + raise ImportError( + f"the `torchvision` library is required to read the image observation." 
+ ) + import torchvision.transforms.v2.functional + from PIL import Image + + img = Image.open(io.BytesIO(img_bytes)) + tensor = torchvision.transforms.v2.functional.pil_to_tensor(img) + data["observation"]["image"] = tensor + except KeyError: + pass + return make_tensordict(data) diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index 6be80e26c1f..b83195a5102 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -24,6 +24,7 @@ TensorStorage, ) from .writers import ( + ImmutableDatasetWriter, RoundRobinWriter, TensorDictMaxValueWriter, TensorDictRoundRobinWriter, diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 95db36e7b4e..65be9a2b2b6 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -45,9 +45,16 @@ class Storage: def __init__(self, max_size: int) -> None: self.max_size = int(max_size) - # Prototype feature. RBs that use a given instance of Storage should add + + @property + def _attached_entities(self): + # RBs that use a given instance of Storage should add # themselves to this set. - self._attached_entities = set() + _attached_entities = self.__dict__.get("_attached_entities_set", None) + if _attached_entities is None: + _attached_entities = set() + self.__dict__["_attached_entities_set"] = _attached_entities + return _attached_entities @abc.abstractmethod def set(self, cursor: int, data: Any): diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index 702898b5292..718bfdfbac1 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -58,6 +58,27 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: return +class ImmutableDatasetWriter(Writer): + """A blocking writer for immutable datasets.""" + + WRITING_ERR = "This dataset doesn't allow writing." 
+ + def add(self, data: Any) -> int: + raise RuntimeError(self.WRITING_ERR) + + def extend(self, data: Sequence) -> torch.Tensor: + raise RuntimeError(self.WRITING_ERR) + + def _empty(self): + raise RuntimeError(self.WRITING_ERR) + + def dumps(self, path): + ... + + def loads(self, path): + ... + + class RoundRobinWriter(Writer): """A RoundRobin Writer class for composable replay buffers.""" From aafe640835f86e9115e55e1f896bfd10effbe073 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 18 Dec 2023 12:43:48 +0000 Subject: [PATCH 02/18] amend --- torchrl/data/datasets/openx.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index 868078aa0d9..19212acfce8 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -91,6 +91,12 @@ class OpenXExperienceReplay(ReplayBuffer): spanning 22 robot embodiments, collected through a collaboration between 21 institutions, demonstrating 527 skills (160266 tasks). + .. note:: + Images ... + + .. note:: + Text data ... 
+ Args: TODO From 2944851273c9ff9cc4feb80d7afa478c3c5a42a2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 18 Dec 2023 17:38:57 +0000 Subject: [PATCH 03/18] amend --- test/test_libs.py | 76 ++++++++++ torchrl/data/datasets/openx.py | 257 ++++++++++++++++++++++++++++++++- 2 files changed, 325 insertions(+), 8 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 996de85a8f7..40d360feb51 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -52,6 +52,7 @@ from torchrl.data.datasets.d4rl import D4RLExperienceReplay from torchrl.data.datasets.minari_data import MinariExperienceReplay from torchrl.data.datasets.openml import OpenMLExperienceReplay +from torchrl.data.datasets.openx import OpenXExperienceReplay from torchrl.data.datasets.roboset import RobosetExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement from torchrl.envs import ( @@ -2057,6 +2058,81 @@ def test_load(self): break +@pytest.mark.slow +class TestOpenX: + @pytest.mark.parametrize("padding", [None, 0, True, False]) + @pytest.mark.parametrize("download", [False, True]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize( + "batch_size,num_slices,slice_len", + [ + [32, 32, None], + [32, None, 1], + [3000, 2, None], + [3000, None, 1500], + [None, None, 32], + [None, None, 1500], + ], + ) + def test_openx(self, download, shuffle, padding, batch_size, num_slices, slice_len): + torch.manual_seed(0) + np.random.seed(0) + + dataset = OpenXExperienceReplay( + "cmu_stretch", + download=download, + streaming=not download, + batch_size=batch_size, + shuffle=shuffle, + num_slices=num_slices, + slice_len=slice_len, + pad=padding, + ) + # iterating + if padding is None and ((batch_size is not None and batch_size > 1000) or (slice_len is not None and slice_len > 1000)): + with pytest.raises(RuntimeError, match="The trajectory length (.*) is shorter than the slice length"): + for data in dataset: + break + else: + for data in dataset: + break 
+ # check data shape + if batch_size is not None: + assert data.shape[0] == batch_size + elif slice_len is not None: + assert data.shape[0] == slice_len + if batch_size is not None: + if num_slices is not None: + assert data.get(("next", "done")).sum(-2) == num_slices + else: + assert ( + data.get(("next", "done")).sum(-2) + == data.get("episode").unique().numel() + ) + + # sampling + if batch_size is None: + if slice_len is not None: + batch_size = 2 * slice_len + elif num_slices is not None: + batch_size = num_slices * 32 + sample = dataset.sample(batch_size) + else: + if padding is None and (batch_size > 1000): + with pytest.raises( + RuntimeError, + match="The trajectory length (.*) is shorter than the slice length" + ): + sample = dataset.sample() + return + else: + sample = dataset.sample() + assert sample.shape == (batch_size,) + if slice_len is not None: + assert sample.get(("next", "done")).sum() == int(batch_size // slice_len) + elif num_slices is not None: + assert sample.get(("next", "done")).sum() == num_slices + @pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found") @pytest.mark.parametrize( "dataset", diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index 19212acfce8..198582ad533 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -98,21 +98,157 @@ class OpenXExperienceReplay(ReplayBuffer): Text data ... Args: - TODO + dataset_id (str): The dataset to be downloaded. + Must be part of OpenXExperienceReplay.available_datasets + batch_size (int): Batch-size used during sampling. + Can be overridden by `data.sample(batch_size)` if necessary. + See `num_slices` and `slice_len` keyword arguments for a refined + sampling strategy. + If the ``batch_size`` is ``None`` (default), iterating over the + dataset will deliver trajectories one at a time __whereas__ calling + :meth:`~.sample` will __still__ require a batch-size to be provided. 
Keyword Args: - TODO + shuffle (bool, optional): if ``True``, trajectories are delivered in a + random order. If ``False``, the dataset is iterated over + in the pre-defined order. + num_slice (int, optional): the number of slices in a batch. This + corresponds to the number of trajectories present in a batch. + Once collected, the batch is presented as a concatenation of + sub-trajectories that can be recovered through `batch.reshape(num_slices, -1)`. + The `batch_size` must be divisible by `num_slices` if provided. + This argument is exclusive with ``slice_len``. + If the ``num_slices`` argument equates the ``batch_size``, each sample + will belong to a different trajectory. + If neither ``slice_len`` nor ``num_slice`` are provided: + whenever a trajectory has a length shorter than the + batch-size, a contiguous slice of it of length `batch_size` will be + sampled. If the trajectory length is insufficient, an exception will + be raised unless `pad` is not `None`. + slice_len (int, optional): the length of slices in a batch. This + corresponds to the length of trajectories present in a batch. + Once collected, the batch is presented as a concatenation of + sub-trajectories that can be recovered through `batch.reshape(-1, slice_len)`. + The `batch_size` must be divisible by `slice_len` if provided. + This argument is exclusive with ``num_slice``. + If the ``slice_len`` argument equates ``1``, each sample + will belong to a different trajectory. + If neither ``slice_len`` nor ``num_slice`` are provided: + whenever a trajectory has a length shorter than the + batch-size, a contiguous slice of it of length `batch_size` will be + sampled. If the trajectory length is insufficient, an exception will + be raised unless `pad` is not `None`. + + .. note:: + The ``slice_len`` (but not ``num_slices``) can be used when + iterating over a dataset without passing a batch-size in the, + constructor. In these cases, a random sub-sequence of the + trajectory will be chosen. 
+ + pad (bool, float or None): if ``True``, trajectories of insufficient length + given the `slice_len` or `num_slices` arguments will be padded with + 0s. If another value is provided, it will be used for padding. If + ``False`` or ``None`` (default) any encounter with a trajectory of + insufficient length will raise an exception. + root (Path or str, optional): The Minari dataset root directory. + The actual dataset memory-mapped files will be saved under + `/`. If none is provided, it defaults to + ``~/.cache/torchrl/minari`. + download (bool or str, optional): Whether the dataset should be downloaded if + not found. Defaults to ``True``. Download can also be passed as "force", + in which case the downloaded data will be overwritten. + sampler (Sampler, optional): the sampler to be used. If none is provided + a default RandomSampler() will be used. + writer (Writer, optional): the writer to be used. If none is provided + a default RoundRobinWriter() will be used. + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s)/outputs. Used when using batched + loading from a map-style dataset. + pin_memory (bool): whether pin_memory() should be called on the rb + samples. + prefetch (int, optional): number of next batches to be prefetched + using multithreading. + transform (Transform, optional): Transform to be executed when sample() is called. + To chain transforms use the :obj:`Compose` class. + split_trajs (bool, optional): if ``True``, the trajectories will be split + along the first dimension and padded to have a matching shape. + To split the trajectories, the ``"done"`` signal will be used, which + is recovered via ``done = truncated | terminated``. In other words, + it is assumed that any ``truncated`` or ``terminated`` signal is + equivalent to the end of a trajectory. For some datasets from + ``D4RL``, this may not be true. It is up to the user to make + accurate choices regarding this usage of ``split_trajs``. 
+ Defaults to ``False``. Examples: - TODO + >>> from torchrl.data.datasets import OpenXExperienceReplay + >>> # Download the data, and sample 128 elements in each batch out of two trajectories + >>> num_slices = 2 + >>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=128, + ... num_slices=num_slices, download=True, streaming=False, root=root) + >>> for batch in dataset: + ... print(data.reshape(num_slices, -1)) + ... break + TensorDict( + fields={ + action: Tensor(shape=torch.Size([2, 64, 8]), device=cpu, dtype=torch.float64, is_shared=False), + discount: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + episode: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int32, is_shared=False), + is_first: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.bool, is_shared=False), + is_init: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.bool, is_shared=False), + is_last: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.bool, is_shared=False), + is_terminal: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.bool, is_shared=False), + language_embedding: Tensor(shape=torch.Size([2, 64, 512]), device=cpu, dtype=torch.float64, is_shared=False), + language_instruction: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: TensorDict( + fields={ + image: Tensor(shape=torch.Size([2, 64, 3, 2, 64, 2, 64]), device=cpu, dtype=torch.uint8, is_shared=False), + state: Tensor(shape=torch.Size([2, 64, 4]), device=cpu, dtype=torch.float64, is_shared=False)}, + batch_size=torch.Size([2, 64]), + device=cpu, + is_shared=False), + reward: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: 
Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([2, 64]), + device=cpu, + is_shared=False), + observation: TensorDict( + fields={ + image: Tensor(shape=torch.Size([2, 64, 3, 2, 64, 2, 64]), device=cpu, dtype=torch.uint8, is_shared=False), + state: Tensor(shape=torch.Size([2, 64, 4]), device=cpu, dtype=torch.float64, is_shared=False)}, + batch_size=torch.Size([2, 64]), + device=cpu, + is_shared=False), + reward: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([2, 64]), + device=cpu, + is_shared=False) + >>> # Read data from a stream. Deliver entire trajectories when iterating + >>> dataset = OpenXExperienceReplay("cmu_stretch", + ... num_slices=num_slices, download=True, streaming=False, root=root) + >>> for data in dataset: # data does not have a consistent shape + ... 
break + >>> # Define batch-size dynamically + >>> data = dataset.sample(128) # delivers 2 sub-trajectories of length 64 """ def __init__( self, dataset_id, - batch_size: int | None, + batch_size: int | None = None, *, + shuffle: bool = True, + num_slices: int | None = None, + slice_len: int | None = None, + pad: float | bool | None = None, streaming: bool = True, root: str | Path | None = None, download: bool = False, @@ -128,6 +264,12 @@ def __init__( self.streaming = streaming self.dataset_id = dataset_id self.split_trajs = split_trajs + self.shuffle = shuffle + self.num_slices = num_slices + self.slice_len = slice_len + self.pad = pad + if (self.num_slices is not None) and (self.slice_len is not None): + raise ValueError("num_slices or slice_len can be not None, but not both.") if split_trajs: raise NotImplementedError if not streaming: @@ -147,7 +289,13 @@ def __init__( raise ValueError( "download and streaming cannot be set to ``True`` concomitantly." ) - storage = _StreamingStorage(dataset_id=dataset_id) + storage = _StreamingStorage( + dataset_id=dataset_id, + shuffle=self.shuffle, + num_slices=self.num_slices, + slice_len=self.slice_len, + pad=self.pad, + ) if sampler is None: sampler = _StreamingSampler() if writer is None: @@ -165,6 +313,16 @@ def __init__( transform=transform, ) + def __iter__(self): + if self._batch_size is None: + # we can still iterate over the dataset + if isinstance(self._storage, _StreamingStorage): + yield from self._storage + else: + raise NotImplementedError("TODO: A slice sampler shoudl be used here") + else: + yield from super().__iter__() + @property def data_path(self): if self.streaming: @@ -232,6 +390,8 @@ def _download_and_preproc(self): class _StreamingStorage(Storage): + SLICE_MISMATCH = "The batch_size {} must be divisible by num_slices {} or slice_len {} if provided." 
+ def __init__( self, dataset_id: str, @@ -240,6 +400,9 @@ def __init__( base_path="data.pickle", shuffle: bool = True, truncate: bool = True, + num_slices=None, + slice_len=None, + pad=None, ): if not _has_datasets: raise ImportError( @@ -250,9 +413,27 @@ def __init__( dataset = datasets.load_dataset(repo, dataset_id, streaming=True, split=split) if shuffle: dataset = dataset.shuffle() - self.dataset = iter(dataset) + self.dataset = dataset + self.dataset_iter = iter(dataset) self.base_path = base_path self.truncate = truncate + self.num_slices = num_slices + self.slice_len = slice_len + self.pad = pad + + def __iter__(self): + episode = 0 + for data in self.dataset: + if self.base_path: + data = data[self.base_path] + data = torch.stack( + [_make_tensordict_image_conv(step) for step in data["steps"]] + ).contiguous() + _format_data(data, episode) + if self.slice_len is not None: + yield _slice_data(data, slice_len=self.slice_len, pad_value=self.pad) + else: + yield data def get(self, index: int) -> Any: if not isinstance(index, range): @@ -261,14 +442,40 @@ def get(self, index: int) -> Any: total = 0 data_list = [] episode = 0 - while total < index.stop: - data = next(self.dataset) + batch_size = index.stop + if self.num_slices is not None: + if batch_size % self.num_slices != 0: + raise ValueError( + self.SLICE_MISMATCH.format( + batch_size, self.num_slices, self.slice_len + ) + ) + num_slices = self.num_slices + slice_len = batch_size // num_slices + else: + if batch_size % self.slice_len != 0: + raise ValueError( + self.SLICE_MISMATCH.format( + batch_size, self.num_slices, self.slice_len + ) + ) + slice_len = self.slice_len + # num_slices = batch_size // slice_len + + while total < batch_size: + try: + data = next(self.dataset_iter) + except StopIteration: + self.dataset_iter = iter(self.dataset) + data = next(self.dataset_iter) + if self.base_path: data = data[self.base_path] data = torch.stack( [_make_tensordict_image_conv(step) for step in data["steps"]] 
).contiguous() _format_data(data, episode) + data = _slice_data(data, slice_len=slice_len, pad_value=self.pad) data_list.append(data) total += data.numel() episode += 1 @@ -284,6 +491,40 @@ def __len__(self): ) +def _slice_data(data: TensorDict, slice_len, pad_value): + if data.shape[-1] < slice_len: + if pad_value is None: + raise RuntimeError( + f"The trajectory length ({data.shape[-1]}) is shorter than the slice length ({slice_len}). " + f"Decrease the slice length or provide a padding value." + ) + if pad_value is True: + pad_value = 0 + return pad(data, [0, slice_len - data.shape[-1]], value=pad_value) + + if data.ndim == 1: + random_range = ( + ((data.shape[-1] - slice_len) * torch.rand(())).floor().int().item() + ) + random_range = slice(random_range, random_range + slice_len) + else: + raise NotImplementedError(data) + data = data[..., random_range] + truncated = data.get(("next", "truncated")) + truncated = torch.index_fill( + truncated, + dim=data.ndim, + value=True, + index=torch.tensor(-1, device=truncated.device), + ) + done = data.get(("next", "done")) + data.set( + ("next", "truncated"), truncated + ) + data.set(("next", "done"), truncated | done) + return data + + class _StreamingSampler(Sampler): def __init__(self): ... 
From 33013b0c724dcb83509b79fae98f06001e588b2b Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 20 Dec 2023 16:02:25 +0000 Subject: [PATCH 04/18] fix --- .../linux_libs/scripts_openx/environment.yml | 22 ++ .../linux_libs/scripts_openx/install.sh | 51 +++ .../linux_libs/scripts_openx/post_process.sh | 6 + .../scripts_openx/run-clang-format.py | 356 ++++++++++++++++++ .../linux_libs/scripts_openx/run_test.sh | 24 ++ .../linux_libs/scripts_openx/setup_env.sh | 50 +++ .github/workflows/test-linux-openx.yml | 42 +++ test/test_collector.py | 3 +- test/test_exploration.py | 3 +- test/test_libs.py | 27 +- test/test_rb.py | 2 +- test/test_transforms.py | 4 +- torchrl/_utils.py | 10 +- torchrl/data/datasets/openx.py | 229 ++++++----- torchrl/data/replay_buffers/replay_buffers.py | 5 +- torchrl/data/replay_buffers/samplers.py | 44 ++- torchrl/data/replay_buffers/utils.py | 5 +- torchrl/envs/common.py | 3 +- torchrl/envs/transforms/transforms.py | 3 +- torchrl/envs/utils.py | 8 +- 20 files changed, 760 insertions(+), 137 deletions(-) create mode 100644 .github/unittest/linux_libs/scripts_openx/environment.yml create mode 100755 .github/unittest/linux_libs/scripts_openx/install.sh create mode 100755 .github/unittest/linux_libs/scripts_openx/post_process.sh create mode 100755 .github/unittest/linux_libs/scripts_openx/run-clang-format.py create mode 100755 .github/unittest/linux_libs/scripts_openx/run_test.sh create mode 100755 .github/unittest/linux_libs/scripts_openx/setup_env.sh create mode 100644 .github/workflows/test-linux-openx.yml diff --git a/.github/unittest/linux_libs/scripts_openx/environment.yml b/.github/unittest/linux_libs/scripts_openx/environment.yml new file mode 100644 index 00000000000..73018195e4a --- /dev/null +++ b/.github/unittest/linux_libs/scripts_openx/environment.yml @@ -0,0 +1,22 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - pip: + - hypothesis + - future + - cloudpickle + - pytest + - pytest-cov + - pytest-mock + - 
pytest-instafail + - pytest-rerunfailures + - pytest-error-for-skips + - expecttest + - pyyaml + - scipy + - hydra-core + - tqdm + - h5py + - datasets diff --git a/.github/unittest/linux_libs/scripts_openx/install.sh b/.github/unittest/linux_libs/scripts_openx/install.sh new file mode 100755 index 00000000000..2eb52b8f65e --- /dev/null +++ b/.github/unittest/linux_libs/scripts_openx/install.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. +apt-get update && apt-get install -y git wget gcc g++ +#apt-get update && apt-get install -y git wget freeglut3 freeglut3-dev + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall +fi + +# install tensordict +pip install git+https://github.com/pytorch/tensordict.git + +# smoke test +python -c "import functorch;import 
tensordict" + +printf "* Installing torchrl\n" +python setup.py develop + +# smoke test +python -c "import torchrl" diff --git a/.github/unittest/linux_libs/scripts_openx/post_process.sh b/.github/unittest/linux_libs/scripts_openx/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.github/unittest/linux_libs/scripts_openx/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.github/unittest/linux_libs/scripts_openx/run-clang-format.py b/.github/unittest/linux_libs/scripts_openx/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_openx/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. + +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. 
+ dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. 
+ # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: 
{DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + 
extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/unittest/linux_libs/scripts_openx/run_test.sh b/.github/unittest/linux_libs/scripts_openx/run_test.sh new file mode 100755 index 00000000000..00f9f2f4512 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_openx/run_test.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 +ln -s /usr/bin/swig3.0 /usr/bin/swig + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory 
'*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +conda deactivate && conda activate ./env + +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestOpenX --error-for-skips --runslow +coverage combine +coverage xml -i diff --git a/.github/unittest/linux_libs/scripts_openx/setup_env.sh b/.github/unittest/linux_libs/scripts_openx/setup_env.sh new file mode 100755 index 00000000000..5214617c2ac --- /dev/null +++ b/.github/unittest/linux_libs/scripts_openx/setup_env.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e +set -v + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +apt-get update && apt-get install -y git wget gcc g++ unzip + +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 3. 
Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip3 install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune diff --git a/.github/workflows/test-linux-openx.yml b/.github/workflows/test-linux-openx.yml new file mode 100644 index 00000000000..362626e4667 --- /dev/null +++ b/.github/workflows/test-linux-openx.yml @@ -0,0 +1,42 @@ +name: OpenX Tests on Linux + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + workflow_dispatch: + +concurrency: + # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. + # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} + cancel-in-progress: true + +jobs: + unittests: + strategy: + matrix: + python_version: ["3.9"] + cuda_arch_version: ["12.1"] + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + docker-image: "nvidia/cudagl:11.4.0-base" + timeout: 120 + script: | + set -euo pipefail + export PYTHON_VERSION="3.9" + export CU_VERSION="cu117" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + + bash .github/unittest/linux_libs/scripts_openx/setup_env.sh + bash .github/unittest/linux_libs/scripts_openx/install.sh + bash .github/unittest/linux_libs/scripts_openx/run_test.sh + bash .github/unittest/linux_libs/scripts_openx/post_process.sh diff --git a/test/test_collector.py 
b/test/test_collector.py index 8667ea24790..565bfe1a7fe 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -38,7 +38,7 @@ from tensordict.tensordict import assert_allclose_td, TensorDict from torch import nn -from torchrl._utils import prod, seed_generator +from torchrl._utils import _replace_last, prod, seed_generator from torchrl.collectors import aSyncDataCollector, SyncDataCollector from torchrl.collectors.collectors import ( _Interruptor, @@ -60,7 +60,6 @@ from torchrl.envs.transforms import TransformedEnv, VecNorm from torchrl.envs.utils import ( _aggregate_end_of_traj, - _replace_last, check_env_specs, PARTIAL_MISSING_ERR, ) diff --git a/test/test_exploration.py b/test/test_exploration.py index 24bb8c246d0..0d916e5d5e9 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -18,6 +18,7 @@ from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential from tensordict.tensordict import TensorDict from torch import nn +from torchrl._utils import _replace_last from torchrl.collectors import SyncDataCollector from torchrl.data import ( @@ -28,7 +29,7 @@ ) from torchrl.envs import SerialEnv from torchrl.envs.transforms.transforms import gSDENoise, InitTracker, TransformedEnv -from torchrl.envs.utils import _replace_last, set_exploration_type +from torchrl.envs.utils import set_exploration_type from torchrl.modules import SafeModule, SafeSequential from torchrl.modules.distributions import TanhNormal from torchrl.modules.distributions.continuous import ( diff --git a/test/test_libs.py b/test/test_libs.py index 40d360feb51..d6d8ace1903 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -2061,7 +2061,7 @@ def test_load(self): @pytest.mark.slow class TestOpenX: @pytest.mark.parametrize("padding", [None, 0, True, False]) - @pytest.mark.parametrize("download", [False, True]) + @pytest.mark.parametrize("download", [True, False]) @pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize( 
"batch_size,num_slices,slice_len", @@ -2078,10 +2078,11 @@ def test_openx(self, download, shuffle, padding, batch_size, num_slices, slice_l torch.manual_seed(0) np.random.seed(0) + streaming = not download dataset = OpenXExperienceReplay( "cmu_stretch", download=download, - streaming=not download, + streaming=streaming, batch_size=batch_size, shuffle=shuffle, num_slices=num_slices, @@ -2089,12 +2090,18 @@ def test_openx(self, download, shuffle, padding, batch_size, num_slices, slice_l pad=padding, ) # iterating - if padding is None and ((batch_size is not None and batch_size > 1000) or (slice_len is not None and slice_len > 1000)): - with pytest.raises(RuntimeError, match="The trajectory length (.*) is shorter than the slice length"): - for data in dataset: + if padding is None and ( + (batch_size is not None and batch_size > 1000) + or (slice_len is not None and slice_len > 1000) + ): + with pytest.raises( + RuntimeError, + match="The trajectory length (.*) is shorter than the slice length", + ): + for data in dataset: # noqa: B007 break else: - for data in dataset: + for data in dataset: # noqa: B007 break # check data shape if batch_size is not None: @@ -2104,7 +2111,7 @@ def test_openx(self, download, shuffle, padding, batch_size, num_slices, slice_l if batch_size is not None: if num_slices is not None: assert data.get(("next", "done")).sum(-2) == num_slices - else: + elif streaming: assert ( data.get(("next", "done")).sum(-2) == data.get("episode").unique().numel() @@ -2121,18 +2128,20 @@ def test_openx(self, download, shuffle, padding, batch_size, num_slices, slice_l if padding is None and (batch_size > 1000): with pytest.raises( RuntimeError, - match="The trajectory length (.*) is shorter than the slice length" - ): + match="The trajectory length (.*) is shorter than the slice length", + ): sample = dataset.sample() return else: sample = dataset.sample() assert sample.shape == (batch_size,) + print(sample) if slice_len is not None: assert sample.get(("next", 
"done")).sum() == int(batch_size // slice_len) elif num_slices is not None: assert sample.get(("next", "done")).sum() == num_slices + @pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found") @pytest.mark.parametrize( "dataset", diff --git a/test/test_rb.py b/test/test_rb.py index 8a5e191e5e4..e9d6eef0077 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -1621,7 +1621,7 @@ def test_slice_sampler( num_slices = batch_size // slice_len trajs_unique_id = set() too_short = False - for _ in range(5): + for _ in range(20): index, info = sampler.sample(storage, batch_size=batch_size) if _data_prefix: samples = storage._storage["_data"][index] diff --git a/test/test_transforms.py b/test/test_transforms.py index cff1d33b34a..b3fbae3261b 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -47,7 +47,7 @@ from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import _unravel_key_to_tuple from torch import multiprocessing as mp, nn, Tensor -from torchrl._utils import prod +from torchrl._utils import _replace_last, prod from torchrl.data import ( BoundedTensorSpec, CompositeSpec, @@ -113,7 +113,7 @@ from torchrl.envs.transforms.transforms import _has_tv, FORWARD_NOT_IMPLEMENTED from torchrl.envs.transforms.vc1 import _has_vc from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform -from torchrl.envs.utils import _replace_last, check_env_specs, step_mdp +from torchrl.envs.utils import check_env_specs, step_mdp from torchrl.modules import LSTMModule, MLP, ProbabilisticActor, TanhNormal TIMEOUT = 100.0 diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 3488e658ab8..b3d768f2d22 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -24,8 +24,9 @@ import numpy as np import torch from packaging.version import parse -from torch import multiprocessing as mp +from tensordict.utils import NestedKey +from torch import multiprocessing as mp VERBOSE = strtobool(os.environ.get("VERBOSE", "0")) _os_is_windows = 
sys.platform == "win32" @@ -662,3 +663,10 @@ def format_size(size): ) else: print(indent + os.path.basename(path)) + + +def _replace_last(key: NestedKey, new_ending: str) -> NestedKey: + if isinstance(key, str): + return new_ending + else: + return key[:-1] + (new_ending,) diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index 198582ad533..22788f055a7 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -5,7 +5,6 @@ from __future__ import annotations import importlib.util - import io import os import tempfile @@ -16,8 +15,7 @@ import tqdm from tensordict import make_tensordict, pad, TensorDict - -from torchrl.data import ImmutableDatasetWriter, ReplayBuffer, Storage, Writer +from torchrl.data import ImmutableDatasetWriter, Storage, TensorDictReplayBuffer, Writer from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers import Sampler from torchrl.data.replay_buffers.storages import _collate_id, TensorStorage @@ -26,93 +24,38 @@ _has_tv = importlib.util.find_spec("torchvision", None) is not None -class OpenXExperienceReplay(ReplayBuffer): - available_datasets = [ - "fractal20220817_data", - "kuka", - "bridge", - "taco_play", - "jaco_play", - "berkeley_cable_routing", - "roboturk", - "nyu_door_opening_surprising_effectiveness", - "viola", - "berkeley_autolab_ur5", - "toto", - "language_table", - "columbia_cairlab_pusht_real", - "stanford_kuka_multimodal_dataset_converted_externally_to_rlds", - "nyu_rot_dataset_converted_externally_to_rlds", - "stanford_hydra_dataset_converted_externally_to_rlds", - "austin_buds_dataset_converted_externally_to_rlds", - "nyu_franka_play_dataset_converted_externally_to_rlds", - "maniskill_dataset_converted_externally_to_rlds", - "furniture_bench_dataset_converted_externally_to_rlds", - "cmu_franka_exploration_dataset_converted_externally_to_rlds", - "ucsd_kitchen_dataset_converted_externally_to_rlds", - 
"ucsd_pick_and_place_dataset_converted_externally_to_rlds", - "austin_sailor_dataset_converted_externally_to_rlds", - "austin_sirius_dataset_converted_externally_to_rlds", - "bc_z", - "usc_cloth_sim_converted_externally_to_rlds", - "utokyo_pr2_opening_fridge_converted_externally_to_rlds", - "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds", - "utokyo_saytap_converted_externally_to_rlds", - "utokyo_xarm_pick_and_place_converted_externally_to_rlds", - "utokyo_xarm_bimanual_converted_externally_to_rlds", - "robo_net", - "berkeley_mvp_converted_externally_to_rlds", - "berkeley_rpt_converted_externally_to_rlds", - "kaist_nonprehensile_converted_externally_to_rlds", - "stanford_mask_vit_converted_externally_to_rlds", - "tokyo_u_lsmo_converted_externally_to_rlds", - "dlr_sara_pour_converted_externally_to_rlds", - "dlr_sara_grid_clamp_converted_externally_to_rlds", - "dlr_edan_shared_control_converted_externally_to_rlds", - "asu_table_top_converted_externally_to_rlds", - "stanford_robocook_converted_externally_to_rlds", - "eth_agent_affordances", - "imperialcollege_sawyer_wrist_cam", - "iamlab_cmu_pickup_insert_converted_externally_to_rlds", - "uiuc_d3field", - "utaustin_mutex", - "berkeley_fanuc_manipulation", - "cmu_playing_with_food", - "cmu_play_fusion", - "cmu_stretch", - "berkeley_gnm_recon", - "berkeley_gnm_cory_hall", - "berkeley_gnm_sac_son", - ] - +class OpenXExperienceReplay(TensorDictReplayBuffer): """Open X-Embodiment datasets experience replay. - - The Open X-Embodiment Dataset contains 1M+ real robot trajectories - spanning 22 robot embodiments, collected through a collaboration between + + The Open X-Embodiment Dataset contains 1M+ real robot trajectories + spanning 22 robot embodiments, collected through a collaboration between 21 institutions, demonstrating 527 skills (160266 tasks). - - .. note:: - Images ... .. note:: - Text data ... 
- + Non-tensor data will be written in the tensordict data using the + :class:`~tensordict.tensorclass.NonTensorData` primitive. + For instance, the `language_instruction` field in the data will + be stored in `data.get_non_tensor("language_instruction")` (or equivalently + `data.get("language_instruction").data`). See the documentation of this + class for more information on how to interact with non-tensor data + stored in a :class:`~tensordict.TensorDict`. + Args: - dataset_id (str): The dataset to be downloaded. + dataset_id (str): The dataset to be downloaded. Must be part of OpenXExperienceReplay.available_datasets - batch_size (int): Batch-size used during sampling. + batch_size (int): Batch-size used during sampling. Can be overridden by `data.sample(batch_size)` if necessary. See `num_slices` and `slice_len` keyword arguments for a refined sampling strategy. If the ``batch_size`` is ``None`` (default), iterating over the dataset will deliver trajectories one at a time __whereas__ calling - :meth:`~.sample` will __still__ require a batch-size to be provided. + :meth:`~.sample` will __still__ require a batch-size to be provided. Keyword Args: shuffle (bool, optional): if ``True``, trajectories are delivered in a - random order. If ``False``, the dataset is iterated over + random order. If ``False``, the dataset is iterated over in the pre-defined order. - num_slice (int, optional): the number of slices in a batch. This + num_slice (int, optional): the number of slices in a batch. This corresponds to the number of trajectories present in a batch. Once collected, the batch is presented as a concatenation of sub-trajectories that can be recovered through `batch.reshape(num_slices, -1)`. @@ -120,12 +63,12 @@ class OpenXExperienceReplay(ReplayBuffer): This argument is exclusive with ``slice_len``. If the ``num_slices`` argument equates the ``batch_size``, each sample will belong to a different trajectory. 
- If neither ``slice_len`` nor ``num_slice`` are provided: + If neither ``slice_len`` nor ``num_slice`` are provided: whenever a trajectory has a length shorter than the - batch-size, a contiguous slice of it of length `batch_size` will be + batch-size, a contiguous slice of it of length `batch_size` will be sampled. If the trajectory length is insufficient, an exception will be raised unless `pad` is not `None`. - slice_len (int, optional): the length of slices in a batch. This + slice_len (int, optional): the length of slices in a batch. This corresponds to the length of trajectories present in a batch. Once collected, the batch is presented as a concatenation of sub-trajectories that can be recovered through `batch.reshape(-1, slice_len)`. @@ -133,27 +76,44 @@ class OpenXExperienceReplay(ReplayBuffer): This argument is exclusive with ``num_slice``. If the ``slice_len`` argument equates ``1``, each sample will belong to a different trajectory. - If neither ``slice_len`` nor ``num_slice`` are provided: + If neither ``slice_len`` nor ``num_slice`` are provided: whenever a trajectory has a length shorter than the - batch-size, a contiguous slice of it of length `batch_size` will be + batch-size, a contiguous slice of it of length `batch_size` will be sampled. If the trajectory length is insufficient, an exception will be raised unless `pad` is not `None`. - .. note:: - The ``slice_len`` (but not ``num_slices``) can be used when + .. note:: + The ``slice_len`` (but not ``num_slices``) can be used when iterating over a dataset without passing a batch-size in the, constructor. In these cases, a random sub-sequence of the trajectory will be chosen. + with_replacement (bool, optional): if ``False``, sampling will be done + without replacement. Defaults to ``True``. pad (bool, float or None): if ``True``, trajectories of insufficient length given the `slice_len` or `num_slices` arguments will be padded with 0s. If another value is provided, it will be used for padding. 
If - ``False`` or ``None`` (default) any encounter with a trajectory of - insufficient length will raise an exception. + ``False`` or ``None`` (default) any encounter with a trajectory of + insufficient length will raise an exception. root (Path or str, optional): The Minari dataset root directory. The actual dataset memory-mapped files will be saved under `/`. If none is provided, it defaults to ``~/.cache/torchrl/minari`. + streaming (bool, optional): if ``True``, the data won't be downloaded but + read from a stream instead. + + .. note:: The formatting of the data __will change__ when `download=True` + compared to `streaming=True`. If the data is downloaded and + the sampler is left untouched (ie, `num_slices=None`, `slice_len=None` + and `sampler=None`, transitions will be sampled randomly from + the dataset. This isn't possible at a reasonable cost with + `streaming=True`: in this case, trajectories will be sampled + one at a time and delivered as such (with cropping to comply with + the batch-size etc). The behaviour of the two modalities is + much more similar when `num_slices` and `slice_len` are specified, + as in these cases, views of sub-episodes will be returned in both + cases. + download (bool or str, optional): Whether the dataset should be downloaded if not found. Defaults to ``True``. Download can also be passed as "force", in which case the downloaded data will be overwritten. @@ -184,7 +144,7 @@ class OpenXExperienceReplay(ReplayBuffer): >>> from torchrl.data.datasets import OpenXExperienceReplay >>> # Download the data, and sample 128 elements in each batch out of two trajectories >>> num_slices = 2 - >>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=128, + >>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=128, ... num_slices=num_slices, download=True, streaming=False, root=root) >>> for batch in dataset: ... 
print(data.reshape(num_slices, -1)) @@ -231,7 +191,7 @@ class OpenXExperienceReplay(ReplayBuffer): device=cpu, is_shared=False) >>> # Read data from a stream. Deliver entire trajectories when iterating - >>> dataset = OpenXExperienceReplay("cmu_stretch", + >>> dataset = OpenXExperienceReplay("cmu_stretch", ... num_slices=num_slices, download=True, streaming=False, root=root) >>> for data in dataset: # data does not have a consistent shape ... break @@ -240,6 +200,64 @@ class OpenXExperienceReplay(ReplayBuffer): """ + available_datasets = [ + "fractal20220817_data", + "kuka", + "bridge", + "taco_play", + "jaco_play", + "berkeley_cable_routing", + "roboturk", + "nyu_door_opening_surprising_effectiveness", + "viola", + "berkeley_autolab_ur5", + "toto", + "language_table", + "columbia_cairlab_pusht_real", + "stanford_kuka_multimodal_dataset_converted_externally_to_rlds", + "nyu_rot_dataset_converted_externally_to_rlds", + "stanford_hydra_dataset_converted_externally_to_rlds", + "austin_buds_dataset_converted_externally_to_rlds", + "nyu_franka_play_dataset_converted_externally_to_rlds", + "maniskill_dataset_converted_externally_to_rlds", + "furniture_bench_dataset_converted_externally_to_rlds", + "cmu_franka_exploration_dataset_converted_externally_to_rlds", + "ucsd_kitchen_dataset_converted_externally_to_rlds", + "ucsd_pick_and_place_dataset_converted_externally_to_rlds", + "austin_sailor_dataset_converted_externally_to_rlds", + "austin_sirius_dataset_converted_externally_to_rlds", + "bc_z", + "usc_cloth_sim_converted_externally_to_rlds", + "utokyo_pr2_opening_fridge_converted_externally_to_rlds", + "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds", + "utokyo_saytap_converted_externally_to_rlds", + "utokyo_xarm_pick_and_place_converted_externally_to_rlds", + "utokyo_xarm_bimanual_converted_externally_to_rlds", + "robo_net", + "berkeley_mvp_converted_externally_to_rlds", + "berkeley_rpt_converted_externally_to_rlds", + 
"kaist_nonprehensile_converted_externally_to_rlds", + "stanford_mask_vit_converted_externally_to_rlds", + "tokyo_u_lsmo_converted_externally_to_rlds", + "dlr_sara_pour_converted_externally_to_rlds", + "dlr_sara_grid_clamp_converted_externally_to_rlds", + "dlr_edan_shared_control_converted_externally_to_rlds", + "asu_table_top_converted_externally_to_rlds", + "stanford_robocook_converted_externally_to_rlds", + "eth_agent_affordances", + "imperialcollege_sawyer_wrist_cam", + "iamlab_cmu_pickup_insert_converted_externally_to_rlds", + "uiuc_d3field", + "utaustin_mutex", + "berkeley_fanuc_manipulation", + "cmu_playing_with_food", + "cmu_play_fusion", + "cmu_stretch", + "berkeley_gnm_recon", + "berkeley_gnm_cory_hall", + "berkeley_gnm_sac_son", + ] + def __init__( self, dataset_id, @@ -249,6 +267,7 @@ def __init__( num_slices: int | None = None, slice_len: int | None = None, pad: float | bool | None = None, + with_replacement: bool = True, streaming: bool = True, root: str | Path | None = None, download: bool = False, @@ -283,6 +302,29 @@ def __init__( storage = self._download_and_preproc() else: storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id)) + if num_slices is not None or slice_len is not None: + if sampler is not None: + raise ValueError( + "`num_slices` and `slice_len` are exclusive with the `sampler` argument." 
+ ) + from torchrl.data.replay_buffers.samplers import ( + SliceSampler, + SliceSamplerWithoutReplacement, + ) + + if with_replacement: + sampler = SliceSampler( + num_slices=num_slices, + slice_len=slice_len, + strict_length=pad is not None, + ) + else: + sampler = SliceSamplerWithoutReplacement( + num_slices=num_slices, + slice_len=slice_len, + strict_length=pad is not None, + ) + else: self.root = None if download: @@ -362,12 +404,12 @@ def _download_and_preproc(self): if total_frames == 0: for step in data["data.pickle"]["steps"]: td = _make_tensordict_image_conv(step).zero_() + # format td: requires td to have a non-null batch_size + td = td.expand(2, *td.shape) + _format_data(td, 0) + td = td[0] total_frames += len(data["data.pickle"]["steps"]) - td_data = ( - td.expand(total_frames) - .memmap_like(self.root / self.dataset_id) - .unlock_() - ) + td_data = td.expand(total_frames).memmap_like(self.root / self.dataset_id) pbar = tqdm.tqdm(dataset, desc="preproc", total=total_frames) idx0 = 0 idx1 = 0 @@ -385,7 +427,6 @@ def _download_and_preproc(self): td_data[idx0:idx1] = current_ep idx0 = idx1 pbar.update(current_ep.shape[0]) - print("total episodes", td_data["next", "done"].sum()) return TensorStorage(td_data.lock_()) @@ -518,9 +559,7 @@ def _slice_data(data: TensorDict, slice_len, pad_value): index=torch.tensor(-1, device=truncated.device), ) done = data.get(("next", "done")) - data.set( - ("next", "truncated"), truncated - ) + data.set(("next", "truncated"), truncated) data.set(("next", "done"), truncated | done) return data diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index b225d759664..c2979755746 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -882,7 +882,10 @@ def sample( if is_locked: data.unlock_() for k, v in info.items(): - data.set(k, expand_as_right(_to_torch(v, data.device), data)) + v = _to_torch(v, data.device) + if 
v.shape[: data.batch_dims] != data.batch_size: + v = expand_as_right(v, data) + data.set(k, v) if is_locked: data.lock_() if return_info: diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index b43718bb61c..9b58f651590 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -30,6 +30,7 @@ except ImportError: warnings.warn(EXTENSION_WARNING) +from torchrl._utils import _replace_last from torchrl.data.replay_buffers.storages import Storage, TensorStorage from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES @@ -732,11 +733,18 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict] # pick up as many trajs as we need start_idx, stop_idx, lengths = self._get_stop_and_length(storage) seq_length, num_slices = self._adjusted_batch_size(batch_size) - return self._sample_slices(lengths, start_idx, seq_length, num_slices) + return self._sample_slices(lengths, start_idx, stop_idx, seq_length, num_slices) def _sample_slices( - self, lengths, start_idx, seq_length, num_slices, traj_idx=None + self, lengths, start_idx, stop_idx, seq_length, num_slices, traj_idx=None ) -> Tuple[torch.Tensor, dict]: + if traj_idx is None: + traj_idx = torch.randint( + lengths.shape[0], (num_slices,), device=lengths.device + ) + else: + num_slices = traj_idx.shape[0] + if (lengths < seq_length).any(): if self.strict_length: raise RuntimeError( @@ -745,14 +753,8 @@ def _sample_slices( "in you batch." 
) # make seq_length a tensor with values clamped by lengths - seq_length = lengths.clamp_max(seq_length) + seq_length = lengths[traj_idx].clamp_max(seq_length) - if traj_idx is None: - traj_idx = torch.randint( - lengths.shape[0], (num_slices,), device=lengths.device - ) - else: - num_slices = traj_idx.shape[0] relative_starts = ( ( torch.rand(num_slices, device=lengths.device) @@ -765,13 +767,30 @@ def _sample_slices( index = self._tensor_slices_from_startend(seq_length, starts) if self.truncated_key is not None: truncated_key = self.truncated_key + done_key = _replace_last(truncated_key, "done") + terminated_key = _replace_last(truncated_key, "terminated") - truncated = torch.zeros(index.shape, dtype=torch.bool, device=index.device) + truncated = torch.zeros( + (*index.shape, 1), dtype=torch.bool, device=index.device + ) if isinstance(seq_length, int): truncated.view(num_slices, -1)[:, -1] = 1 else: truncated[seq_length.cumsum(0) - 1] = 1 - return index.to(torch.long), {truncated_key: truncated} + traj_terminated = stop_idx[traj_idx] == start_idx[traj_idx] + seq_length - 1 + terminated = torch.zeros_like(truncated) + if terminated.any(): + if isinstance(seq_length, int): + truncated.view(num_slices, -1)[traj_terminated] = 1 + else: + truncated[(seq_length.cumsum(0) - 1)[traj_terminated]] = 1 + truncated = truncated & ~terminated + done = terminated | truncated + return index.to(torch.long), { + truncated_key: truncated, + done_key: done, + terminated_key: terminated, + } return index.to(torch.long), {} @property @@ -936,11 +955,10 @@ def _storage_len(self, storage): def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]: start_idx, stop_idx, lengths = self._get_stop_and_length(storage) self._storage_len_buffer = len(start_idx) - print("self._storage_len_buffer", self._storage_len_buffer) # first get indices of the trajectories we want to retrieve seq_length, num_slices = self._adjusted_batch_size(batch_size) indices, _ = 
SamplerWithoutReplacement.sample(self, storage, num_slices) idx, info = self._sample_slices( - lengths, start_idx, seq_length, num_slices, traj_idx=indices + lengths, start_idx, stop_idx, seq_length, num_slices, traj_idx=indices ) return idx, info diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py index e480335ec84..cc1773ee3e8 100644 --- a/torchrl/data/replay_buffers/utils.py +++ b/torchrl/data/replay_buffers/utils.py @@ -27,9 +27,10 @@ def _to_torch( ) -> torch.Tensor: if isinstance(data, np.generic): return torch.tensor(data, device=device) - - if isinstance(data, np.ndarray): + elif isinstance(data, np.ndarray): data = torch.from_numpy(data) + elif not isinstance(data, Tensor): + data = torch.tensor(data, device=device) if pin_memory: data = data.pin_memory() diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 77422a73fdc..eda8c859692 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -16,7 +16,7 @@ from tensordict import unravel_key from tensordict.tensordict import TensorDictBase from tensordict.utils import NestedKey -from torchrl._utils import prod, seed_generator +from torchrl._utils import _replace_last, prod, seed_generator from torchrl.data.tensor_specs import ( CompositeSpec, @@ -26,7 +26,6 @@ ) from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.utils import ( - _replace_last, _repr_by_depth, _terminated_or_truncated, _update_during_reset, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index de8baf2e403..11bae620330 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -23,6 +23,7 @@ from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import expand_as_right, NestedKey from torch import nn, Tensor +from torchrl._utils import _replace_last from torchrl.data.tensor_specs import ( BinaryDiscreteTensorSpec, @@ -43,7 +44,7 @@ _set_missing_tolerance, 
check_finite, ) -from torchrl.envs.utils import _replace_last, _sort_keys, _update_during_reset, step_mdp +from torchrl.envs.utils import _sort_keys, _update_during_reset, step_mdp from torchrl.objectives.value.functional import reward2go try: diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 9a2a71f24bd..49711bfb19a 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -26,6 +26,7 @@ set_interaction_type as set_exploration_type, ) from tensordict.tensordict import LazyStackedTensorDict, NestedKey +from torchrl._utils import _replace_last __all__ = [ "exploration_mode", @@ -612,13 +613,6 @@ def clear_mpi_env_vars(): os.environ.update(removed_environment) -def _replace_last(key: NestedKey, new_ending: str) -> NestedKey: - if isinstance(key, str): - return new_ending - else: - return key[:-1] + (new_ending,) - - class MarlGroupMapType(Enum): """Marl Group Map Type. From 9c4630e10620a9bade59fba7b8a594624c7a0775 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 20 Dec 2023 16:12:08 +0000 Subject: [PATCH 05/18] lint --- torchrl/data/datasets/openx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index 22788f055a7..ef23e1c8ef9 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -617,7 +617,7 @@ def _make_tensordict_image_conv(data): img_bytes = data["observation"]["image"]["bytes"] if not _has_tv: raise ImportError( - f"the `torchvision` library is required to read the image observation." + "the `torchvision` library is required to read the image observation." 
) import torchvision.transforms.v2.functional from PIL import Image From 861dc9f1a72ed70360627730a71bce0399c1ea91 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 20 Dec 2023 16:33:42 +0000 Subject: [PATCH 06/18] amend --- torchrl/data/datasets/openx.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index ef23e1c8ef9..ea9fdd81d58 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -12,7 +12,6 @@ from typing import Any, Callable, Tuple import torch -import tqdm from tensordict import make_tensordict, pad, TensorDict from torchrl.data import ImmutableDatasetWriter, Storage, TensorDictReplayBuffer, Writer @@ -399,7 +398,16 @@ def _download_and_preproc(self): ) # iterate over the dataset a first time to count elements total_frames = 0 - pbar = tqdm.tqdm(dataset, desc="counting") + + try: + import tqdm + + _has_tqdm = True + pbar = tqdm.tqdm(dataset, desc="counting") + except ImportError: + _has_tqdm = False + pbar = dataset + for data in pbar: if total_frames == 0: for step in data["data.pickle"]["steps"]: @@ -410,7 +418,10 @@ def _download_and_preproc(self): td = td[0] total_frames += len(data["data.pickle"]["steps"]) td_data = td.expand(total_frames).memmap_like(self.root / self.dataset_id) - pbar = tqdm.tqdm(dataset, desc="preproc", total=total_frames) + if _has_tqdm: + pbar = tqdm.tqdm(dataset, desc="preproc", total=total_frames) + else: + pbar = dataset idx0 = 0 idx1 = 0 episode = 0 @@ -426,7 +437,8 @@ def _download_and_preproc(self): idx1 += len(current_ep) td_data[idx0:idx1] = current_ep idx0 = idx1 - pbar.update(current_ep.shape[0]) + if _has_tqdm: + pbar.update(current_ep.shape[0]) return TensorStorage(td_data.lock_()) From 5c20a160bb66f31895dced19b255156c4135e2b9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 20 Dec 2023 17:02:34 +0000 Subject: [PATCH 07/18] PIL --- .github/unittest/linux_libs/scripts_openx/environment.yml 
| 1 + setup.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/unittest/linux_libs/scripts_openx/environment.yml b/.github/unittest/linux_libs/scripts_openx/environment.yml index 73018195e4a..b186421c506 100644 --- a/.github/unittest/linux_libs/scripts_openx/environment.yml +++ b/.github/unittest/linux_libs/scripts_openx/environment.yml @@ -20,3 +20,4 @@ dependencies: - tqdm - h5py - datasets + - PIL diff --git a/setup.py b/setup.py index 36d84aa09e3..f0844d29833 100644 --- a/setup.py +++ b/setup.py @@ -218,7 +218,7 @@ def _main(argv): "tqdm", "scikit-learn", "pandas", - "h5py", + "h5py", "PIL", ], "marl": ["vmas>=1.2.10", "pettingzoo>=1.24.1"], } From cdad10c7626148d267be3e835a3139a323ebdb96 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 20 Dec 2023 17:20:43 +0000 Subject: [PATCH 08/18] amend --- .github/unittest/linux_libs/scripts_openx/environment.yml | 2 +- setup.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_openx/environment.yml b/.github/unittest/linux_libs/scripts_openx/environment.yml index b186421c506..30051ac5748 100644 --- a/.github/unittest/linux_libs/scripts_openx/environment.yml +++ b/.github/unittest/linux_libs/scripts_openx/environment.yml @@ -20,4 +20,4 @@ dependencies: - tqdm - h5py - datasets - - PIL + - pil diff --git a/setup.py b/setup.py index f0844d29833..f89d2bfbfa2 100644 --- a/setup.py +++ b/setup.py @@ -218,7 +218,8 @@ def _main(argv): "tqdm", "scikit-learn", "pandas", - "h5py", "PIL", + "h5py", + "pil", ], "marl": ["vmas>=1.2.10", "pettingzoo>=1.24.1"], } From 454078fc2dfb46777ec2823ddbd70a40e8008329 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 20 Dec 2023 17:55:26 +0000 Subject: [PATCH 09/18] amend --- .github/unittest/linux_libs/scripts_openx/environment.yml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_openx/environment.yml 
b/.github/unittest/linux_libs/scripts_openx/environment.yml index 30051ac5748..175c245a8ed 100644 --- a/.github/unittest/linux_libs/scripts_openx/environment.yml +++ b/.github/unittest/linux_libs/scripts_openx/environment.yml @@ -20,4 +20,4 @@ dependencies: - tqdm - h5py - datasets - - pil + - pillow diff --git a/setup.py b/setup.py index f89d2bfbfa2..3658809c7e2 100644 --- a/setup.py +++ b/setup.py @@ -219,7 +219,7 @@ def _main(argv): "scikit-learn", "pandas", "h5py", - "pil", + "pillow", ], "marl": ["vmas>=1.2.10", "pettingzoo>=1.24.1"], } From 56b01de1d12dfd76e3335d7c37fa4458c7a140e8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 21 Dec 2023 11:25:49 +0000 Subject: [PATCH 10/18] amend --- .github/unittest/linux_libs/scripts_openx/install.sh | 4 ++-- setup.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_openx/install.sh b/.github/unittest/linux_libs/scripts_openx/install.sh index 2eb52b8f65e..1be476425a6 100755 --- a/.github/unittest/linux_libs/scripts_openx/install.sh +++ b/.github/unittest/linux_libs/scripts_openx/install.sh @@ -33,9 +33,9 @@ if [ "${CU_VERSION:-}" == cpu ] ; then # conda install -y pytorch torchvision cpuonly -c pytorch-nightly # use pip to install pytorch as conda can frequently pick older release # conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 fi # install tensordict diff --git a/setup.py b/setup.py index 3658809c7e2..71654270889 100644 --- a/setup.py +++ b/setup.py @@ -216,6 +216,7 @@ def _main(argv): "huggingface_hub", # for roboset "minari", "tqdm", + "torchvision", 
"scikit-learn", "pandas", "h5py", From 1db9cdadb108be5faf5043d9d2646de8f2b04450 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 25 Dec 2023 16:22:00 +0100 Subject: [PATCH 11/18] amend --- torchrl/data/datasets/__init__.py | 1 + torchrl/data/datasets/openx.py | 19 ++++++++------- torchrl/data/replay_buffers/samplers.py | 32 +++++++++++++++++++++---- 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/torchrl/data/datasets/__init__.py b/torchrl/data/datasets/__init__.py index c1429b300fa..3e857c4de6a 100644 --- a/torchrl/data/datasets/__init__.py +++ b/torchrl/data/datasets/__init__.py @@ -3,3 +3,4 @@ from .openml import OpenMLExperienceReplay from .roboset import RobosetExperienceReplay from .vd4rl import VD4RLExperienceReplay +from .openx import OpenXExperienceReplay diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index ea9fdd81d58..2578be52cb4 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -14,9 +14,11 @@ import torch from tensordict import make_tensordict, pad, TensorDict -from torchrl.data import ImmutableDatasetWriter, Storage, TensorDictReplayBuffer, Writer +from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer +from torchrl.data.replay_buffers.storages import Storage +from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer from torchrl.data.datasets.utils import _get_root_dir -from torchrl.data.replay_buffers import Sampler +from torchrl.data.replay_buffers.samplers import Sampler, SliceSamplerWithoutReplacement, SliceSampler from torchrl.data.replay_buffers.storages import _collate_id, TensorStorage _has_datasets = importlib.util.find_spec("datasets", None) is not None @@ -146,7 +148,7 @@ class for more information on how to interact with non-tensor data >>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=128, ... num_slices=num_slices, download=True, streaming=False, root=root) >>> for batch in dataset: - ... 
print(data.reshape(num_slices, -1)) + ... print(batch.reshape(num_slices, -1)) ... break TensorDict( fields={ @@ -257,6 +259,9 @@ class for more information on how to interact with non-tensor data "berkeley_gnm_sac_son", ] + # some very high number that should be above all trajecory lengths in the dataset + _MAX_TRAJ_LEN = 1_000_000 + def __init__( self, dataset_id, @@ -306,10 +311,6 @@ def __init__( raise ValueError( "`num_slices` and `slice_len` are exclusive with the `sampler` argument." ) - from torchrl.data.replay_buffers.samplers import ( - SliceSampler, - SliceSamplerWithoutReplacement, - ) if with_replacement: sampler = SliceSampler( @@ -360,7 +361,9 @@ def __iter__(self): if isinstance(self._storage, _StreamingStorage): yield from self._storage else: - raise NotImplementedError("TODO: A slice sampler shoudl be used here") + sampler = SliceSamplerWithoutReplacement( + num_slices=self.num_slices, strict_length=False, shuffle=self.shuffle) + yield from TensorDictReplayBuffer(storage=self._storage, sampler=sampler, batch_size=self._MAX_TRAJ_LEN) else: yield from super().__iter__() diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 410bf1e6224..32fedd96402 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -121,6 +121,9 @@ class SamplerWithoutReplacement(Sampler): If ``False``, this last sample will be kept and (unlike with torch dataloaders) completed with other samples from a fresh indices permutation. Defaults to ``False``. + shuffle (bool, optional): if ``False``, the items are not randomly + permuted. This enables to iterate over the replay buffer in the + order the data was collected. Defaults to ``True``. *Caution*: If the size of the storage changes in between two calls, the samples will be re-shuffled (as we can't generally keep track of which samples have been sampled before and which haven't). 
@@ -134,11 +137,12 @@ class SamplerWithoutReplacement(Sampler): """ - def __init__(self, drop_last: bool = False): + def __init__(self, drop_last: bool = False, shuffle: bool = True): self._sample_list = None self.len_storage = 0 self.drop_last = drop_last self._ran_out = False + self.shuffle = shuffle def dumps(self, path): path = Path(path) @@ -163,6 +167,22 @@ def loads(self, path): self.drop_last = metadata["drop_last"] self._ran_out = metadata["_ran_out"] + def _get_sample_list(self, storage: Storage, len_storage: int): + if storage is None: + device = self._sample_list.device + else: + device = storage.device if hasattr(storage, "device") else None + if self.shuffle: + self._sample_list = torch.randperm( + len_storage, + device=device + ) + else: + self._sample_list = torch.arange( + len_storage, + device=device + ) + def _single_sample(self, len_storage, batch_size): index = self._sample_list[:batch_size] self._sample_list = self._sample_list[batch_size:] @@ -173,7 +193,7 @@ def _single_sample(self, len_storage, batch_size): self.drop_last and len(self._sample_list) < batch_size ): self.ran_out = True - self._sample_list = torch.randperm(len_storage) + self._get_sample_list(storage=None, len_storage=len_storage) else: self.ran_out = False return index @@ -188,7 +208,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]: if not len_storage: raise RuntimeError("An empty storage was passed") if self.len_storage != len_storage or self._sample_list is None: - self._sample_list = torch.randperm(len_storage) + self._get_sample_list(storage, len_storage) if len_storage < batch_size and self.drop_last: raise ValueError( f"The batch size ({batch_size}) is greater than the storage capacity ({len_storage}). " @@ -861,6 +881,8 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): Be mindful that this can result in effective `batch_size` shorter than the one asked for! 
Trajectories can be split using :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``. + shuffle (bool, optional): if ``False``, the order of the trajectories + is not shuffled. Defaults to ``True``. .. note:: To recover the trajectory splits in the storage, :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement` will first @@ -922,6 +944,7 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): """ + def __init__( self, *, @@ -932,6 +955,7 @@ def __init__( traj_key: NestedKey | None = None, truncated_key: NestedKey | None = ("next", "truncated"), strict_length: bool = True, + shuffle: bool = True, ) -> object: SliceSampler.__init__( self, @@ -943,7 +967,7 @@ def __init__( truncated_key=truncated_key, strict_length=strict_length, ) - SamplerWithoutReplacement.__init__(self, drop_last=drop_last) + SamplerWithoutReplacement.__init__(self, drop_last=drop_last, shuffle=shuffle) def _empty(self): self._cache = {} From c42eb98c40ea20dfc32e1261738c8cae6bf60460 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 5 Jan 2024 14:26:56 +0000 Subject: [PATCH 12/18] amend --- test/test_libs.py | 17 ++++++---- torchrl/data/datasets/__init__.py | 2 +- torchrl/data/datasets/minari_data.py | 4 ++- torchrl/data/datasets/openx.py | 43 +++++++++++++++++++------ torchrl/data/replay_buffers/samplers.py | 11 ++----- 5 files changed, 51 insertions(+), 26 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 8507ea9416c..d2b460a944a 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -2094,15 +2094,17 @@ def test_load(self, image_size): @pytest.mark.slow class TestOpenX: - @pytest.mark.parametrize("padding", [None, 0, True, False]) - @pytest.mark.parametrize("download", [True, False]) + @pytest.mark.parametrize( + "download,padding", + [[True, None], [False, None], [False, 0], [False, True], [False, False]], + ) @pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize( 
"batch_size,num_slices,slice_len", [ + [3000, 2, None], [32, 32, None], [32, None, 1], - [3000, 2, None], [3000, None, 1500], [None, None, 32], [None, None, 1500], @@ -2130,7 +2132,9 @@ def test_openx(self, download, shuffle, padding, batch_size, num_slices, slice_l ): with pytest.raises( RuntimeError, - match="The trajectory length (.*) is shorter than the slice length", + match="Some stored trajectories have a length shorter than the slice that was asked for" + if not streaming + else "The trajectory length (.*) is shorter than the slice length", ): for data in dataset: # noqa: B007 break @@ -2162,14 +2166,15 @@ def test_openx(self, download, shuffle, padding, batch_size, num_slices, slice_l if padding is None and (batch_size > 1000): with pytest.raises( RuntimeError, - match="The trajectory length (.*) is shorter than the slice length", + match="Some stored trajectories have a length shorter than the slice that was asked for" + if not streaming + else "The trajectory length (.*) is shorter than the slice length", ): sample = dataset.sample() return else: sample = dataset.sample() assert sample.shape == (batch_size,) - print(sample) if slice_len is not None: assert sample.get(("next", "done")).sum() == int(batch_size // slice_len) elif num_slices is not None: diff --git a/torchrl/data/datasets/__init__.py b/torchrl/data/datasets/__init__.py index 3e857c4de6a..1cef4f3ffea 100644 --- a/torchrl/data/datasets/__init__.py +++ b/torchrl/data/datasets/__init__.py @@ -1,6 +1,6 @@ from .d4rl import D4RLExperienceReplay from .minari_data import MinariExperienceReplay from .openml import OpenMLExperienceReplay +from .openx import OpenXExperienceReplay from .roboset import RobosetExperienceReplay from .vd4rl import VD4RLExperienceReplay -from .openx import OpenXExperienceReplay diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index e9fea02ba06..754d5da9865 100644 --- a/torchrl/data/datasets/minari_data.py +++ 
b/torchrl/data/datasets/minari_data.py @@ -31,6 +31,7 @@ DiscreteTensorSpec, UnboundedContinuousTensorSpec, ) +from torchrl.envs.utils import _classproperty _has_tqdm = importlib.util.find_spec("tqdm", None) is not None @@ -203,6 +204,7 @@ def __init__( transform=transform, ) + @_classproperty def available_datasets(self): import minari @@ -281,7 +283,7 @@ def _download_and_preproc(self): td_data = td_data.memmap_like(self.data_path_root) print("tensordict structure:", td_data) - print(f"Reading data from {max(*episode_dict)} episodes") + print(f"Reading data from {max(*episode_dict) + 1} episodes") index = 0 with tqdm(total=total_steps) if _has_tqdm else nullcontext() as pbar: # iterate over episodes and populate the tensordict diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index 2578be52cb4..1723b2b3f4c 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -14,12 +14,17 @@ import torch from tensordict import make_tensordict, pad, TensorDict -from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer -from torchrl.data.replay_buffers.storages import Storage -from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer +from torchrl.collectors.utils import split_trajectories from torchrl.data.datasets.utils import _get_root_dir -from torchrl.data.replay_buffers.samplers import Sampler, SliceSamplerWithoutReplacement, SliceSampler -from torchrl.data.replay_buffers.storages import _collate_id, TensorStorage +from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import ( + Sampler, + SliceSampler, + SliceSamplerWithoutReplacement, +) +from torchrl.data.replay_buffers.storages import _collate_id, Storage, TensorStorage +from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer +from torchrl.envs.transforms.transforms import Transform _has_datasets = importlib.util.find_spec("datasets", 
None) is not None _has_tv = importlib.util.find_spec("torchvision", None) is not None @@ -140,6 +145,12 @@ class for more information on how to interact with non-tensor data ``D4RL``, this may not be true. It is up to the user to make accurate choices regarding this usage of ``split_trajs``. Defaults to ``False``. + strict_length (bool, optional): if ``False``, trajectories of length + shorter than `slice_len` (or `batch_size // num_slices`) will be + allowed to appear in the batch. + Be mindful that this can result in effective `batch_size` shorter + than the one asked for! Trajectories can be split using + :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``. Examples: >>> from torchrl.data.datasets import OpenXExperienceReplay @@ -282,6 +293,7 @@ def __init__( prefetch: int | None = None, transform: "torchrl.envs.Transform" | None = None, # noqa-F821 split_trajs: bool = False, + strict_length: bool = True, ): self.download = download self.streaming = streaming @@ -296,6 +308,10 @@ def __init__( if split_trajs: raise NotImplementedError if not streaming: + if pad is not None: + raise RuntimeError( + "the `pad` argument is to be used only with streaming datasets." 
+ ) if root is None: root = _get_root_dir("openx") os.makedirs(root, exist_ok=True) @@ -316,13 +332,13 @@ def __init__( sampler = SliceSampler( num_slices=num_slices, slice_len=slice_len, - strict_length=pad is not None, + strict_length=strict_length, ) else: sampler = SliceSamplerWithoutReplacement( num_slices=num_slices, slice_len=slice_len, - strict_length=pad is not None, + strict_length=strict_length, ) else: @@ -362,8 +378,17 @@ def __iter__(self): yield from self._storage else: sampler = SliceSamplerWithoutReplacement( - num_slices=self.num_slices, strict_length=False, shuffle=self.shuffle) - yield from TensorDictReplayBuffer(storage=self._storage, sampler=sampler, batch_size=self._MAX_TRAJ_LEN) + num_slices=self.num_slices, + slice_len=self.slice_len, + strict_length=False, + shuffle=self.shuffle, + ) + yield from TensorDictReplayBuffer( + storage=self._storage, + sampler=sampler, + batch_size=self._MAX_TRAJ_LEN, + transform=self._transform, + ) else: yield from super().__iter__() diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 32fedd96402..e55aa1e1a5c 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -173,15 +173,9 @@ def _get_sample_list(self, storage: Storage, len_storage: int): else: device = storage.device if hasattr(storage, "device") else None if self.shuffle: - self._sample_list = torch.randperm( - len_storage, - device=device - ) + self._sample_list = torch.randperm(len_storage, device=device) else: - self._sample_list = torch.arange( - len_storage, - device=device - ) + self._sample_list = torch.arange(len_storage, device=device) def _single_sample(self, len_storage, batch_size): index = self._sample_list[:batch_size] @@ -944,7 +938,6 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): """ - def __init__( self, *, From 5ecaf1d1344351eab6bbec69fffc33460a7d4f55 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 5 Jan 
2024 17:49:53 +0000 Subject: [PATCH 13/18] install curl --- .github/unittest/linux_libs/scripts_openx/setup_env.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/unittest/linux_libs/scripts_openx/setup_env.sh b/.github/unittest/linux_libs/scripts_openx/setup_env.sh index 5214617c2ac..5b415112814 100755 --- a/.github/unittest/linux_libs/scripts_openx/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_openx/setup_env.sh @@ -10,7 +10,7 @@ set -v this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" # Avoid error: "fatal: unsafe repository" -apt-get update && apt-get install -y git wget gcc g++ unzip +apt-get update && apt-get install -y git wget gcc g++ unzip curl git config --global --add safe.directory '*' root_dir="$(git rev-parse --show-toplevel)" From a874d6e7d1c0c45d20dc76bc080779a0f9cb1e67 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 6 Jan 2024 06:29:05 +0000 Subject: [PATCH 14/18] fixes --- test/test_libs.py | 16 ++++++++++------ torchrl/data/datasets/openx.py | 29 ++++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index d2b460a944a..3b0f308d43c 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -2130,14 +2130,18 @@ def test_openx(self, download, shuffle, padding, batch_size, num_slices, slice_l (batch_size is not None and batch_size > 1000) or (slice_len is not None and slice_len > 1000) ): - with pytest.raises( + raises_cm = pytest.raises( RuntimeError, - match="Some stored trajectories have a length shorter than the slice that was asked for" - if not streaming - else "The trajectory length (.*) is shorter than the slice length", - ): + match="The trajectory length (.*) is shorter than the slice length|Some stored trajectories have a length shorter than the slice that was asked for" + ) + with raises_cm: for data in dataset: # noqa: B007 break + if batch_size is None and slice_len is not None: + with raises_cm: + 
dataset.sample(2 * slice_len) + return + else: for data in dataset: # noqa: B007 break @@ -2176,7 +2180,7 @@ def test_openx(self, download, shuffle, padding, batch_size, num_slices, slice_l sample = dataset.sample() assert sample.shape == (batch_size,) if slice_len is not None: - assert sample.get(("next", "done")).sum() == int(batch_size // slice_len) + assert sample.get(("next", "done")).sum() == int(batch_size // slice_len), sample.get(("next", "done")) elif num_slices is not None: assert sample.get(("next", "done")).sum() == num_slices diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index 1723b2b3f4c..c1993785377 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -303,6 +303,7 @@ def __init__( self.num_slices = num_slices self.slice_len = slice_len self.pad = pad + self.strict_length = strict_length if (self.num_slices is not None) and (self.slice_len is not None): raise ValueError("num_slices or slice_len can be not None, but not both.") if split_trajs: @@ -376,17 +377,35 @@ def __iter__(self): # we can still iterate over the dataset if isinstance(self._storage, _StreamingStorage): yield from self._storage + elif self.slice_len is not None and self.num_slices is None: + try: + # truncate the trajs with slice_len + self._batch_size = self.slice_len + self.num_slices = 1 + self.slice_len = None + yield from self + finally: + self.slice_len = self._batch_size + self._batch_size = None + self.num_slices = None else: + # if we don't have a batch size but we know how many trajectories + # we want in each batch, we can build that on the fly. + # The only time we can do this is if num_slices is given but not + # slice_len. 
+ num_slices = self.num_slices + if not num_slices: + num_slices = 1 sampler = SliceSamplerWithoutReplacement( - num_slices=self.num_slices, - slice_len=self.slice_len, + num_slices=num_slices, strict_length=False, shuffle=self.shuffle, ) + batch_size = self._MAX_TRAJ_LEN yield from TensorDictReplayBuffer( storage=self._storage, sampler=sampler, - batch_size=self._MAX_TRAJ_LEN, + batch_size=batch_size, transform=self._transform, ) else: @@ -594,7 +613,7 @@ def _slice_data(data: TensorDict, slice_len, pad_value): truncated = data.get(("next", "truncated")) truncated = torch.index_fill( truncated, - dim=data.ndim, + dim=data.ndim-1, value=True, index=torch.tensor(-1, device=truncated.device), ) @@ -637,7 +656,7 @@ def _format_data(data: TensorDict, episode: int): data.rename_key_(key, newkey) data.set( ("next", "truncated"), - data.get(("next", "done")) ^ data.get(("next", "terminated")), + data.get(("next", "done")) & ~data.get(("next", "terminated")), ) for key in ("done", "terminated", "truncated", "reward"): From 97478e4bbc7b9e85d6f3f4bbcb69394f38036735 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 6 Jan 2024 07:02:34 +0000 Subject: [PATCH 15/18] amend shuffle --- test/test_libs.py | 46 ++++++++++++++++++++++++---------- torchrl/data/datasets/openx.py | 33 ++++++++++++++++++------ 2 files changed, 59 insertions(+), 20 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 3b0f308d43c..14c1fcde0d6 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -2099,6 +2099,7 @@ class TestOpenX: [[True, None], [False, None], [False, 0], [False, True], [False, False]], ) @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("replacement", [True, False]) @pytest.mark.parametrize( "batch_size,num_slices,slice_len", [ @@ -2110,21 +2111,38 @@ class TestOpenX: [None, None, 1500], ], ) - def test_openx(self, download, shuffle, padding, batch_size, num_slices, slice_len): + def test_openx( + self, download, shuffle, replacement, 
padding, batch_size, num_slices, slice_len + ): torch.manual_seed(0) np.random.seed(0) streaming = not download - dataset = OpenXExperienceReplay( - "cmu_stretch", - download=download, - streaming=streaming, - batch_size=batch_size, - shuffle=shuffle, - num_slices=num_slices, - slice_len=slice_len, - pad=padding, - ) + cm = ( + pytest.raises(RuntimeError, match="shuffle=False") + if not streaming and not shuffle and replacement + else pytest.raises( + RuntimeError, + match="replacement=True is not available with streamed datasets", + ) + if streaming and replacement + else nullcontext() + ) + dataset = None + with cm: + dataset = OpenXExperienceReplay( + "cmu_stretch", + download=download, + streaming=streaming, + batch_size=batch_size, + shuffle=shuffle, + num_slices=num_slices, + slice_len=slice_len, + pad=padding, + replacement=replacement, + ) + if dataset is None: + return # iterating if padding is None and ( (batch_size is not None and batch_size > 1000) @@ -2132,7 +2150,7 @@ def test_openx(self, download, shuffle, padding, batch_size, num_slices, slice_l ): raises_cm = pytest.raises( RuntimeError, - match="The trajectory length (.*) is shorter than the slice length|Some stored trajectories have a length shorter than the slice that was asked for" + match="The trajectory length (.*) is shorter than the slice length|Some stored trajectories have a length shorter than the slice that was asked for", ) with raises_cm: for data in dataset: # noqa: B007 @@ -2180,7 +2198,9 @@ def test_openx(self, download, shuffle, padding, batch_size, num_slices, slice_l sample = dataset.sample() assert sample.shape == (batch_size,) if slice_len is not None: - assert sample.get(("next", "done")).sum() == int(batch_size // slice_len), sample.get(("next", "done")) + assert sample.get(("next", "done")).sum() == int( + batch_size // slice_len + ), sample.get(("next", "done")) elif num_slices is not None: assert sample.get(("next", "done")).sum() == num_slices diff --git 
a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index c1993785377..7a45720109f 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -59,8 +59,15 @@ class for more information on how to interact with non-tensor data Keyword Args: shuffle (bool, optional): if ``True``, trajectories are delivered in a - random order. If ``False``, the dataset is iterated over - in the pre-defined order. + random order when the dataset is iterated over. + If ``False``, the dataset is iterated over in the pre-defined order. + + .. warning:: + shuffle=False will also impact the sampling. We advise users to + create a copy of the dataset where the ``shuffle`` attribute of the + sampler is set to ``False`` if they wish to enjoy the two different + behaviours (shuffled and not) within the same code base. + num_slice (int, optional): the number of slices in a batch. This corresponds to the number of trajectories present in a batch. Once collected, the batch is presented as a concatenation of @@ -94,8 +101,9 @@ class for more information on how to interact with non-tensor data constructor. In these cases, a random sub-sequence of the trajectory will be chosen. - with_replacement (bool, optional): if ``False``, sampling will be done - without replacement. Defaults to ``True``. + replacement (bool, optional): if ``False``, sampling will be done + without replacement. Defaults to ``True`` for downloaded datasets, + ``False`` for streamed datasets. pad (bool, float or None): if ``True``, trajectories of insufficient length given the `slice_len` or `num_slices` arguments will be padded with 0s. If another value is provided, it will be used for padding.
If @@ -282,7 +290,7 @@ def __init__( num_slices: int | None = None, slice_len: int | None = None, pad: float | bool | None = None, - with_replacement: bool = True, + replacement: bool = None, streaming: bool = True, root: str | Path | None = None, download: bool = False, @@ -309,6 +317,7 @@ def __init__( if split_trajs: raise NotImplementedError if not streaming: + replacement = True if replacement is None else False if pad is not None: raise RuntimeError( "the `pad` argument is to be used only with streaming datasets." @@ -329,7 +338,11 @@ def __init__( "`num_slices` and `slice_len` are exclusive with the `sampler` argument." ) - if with_replacement: + if replacement: + if not self.shuffle: + raise RuntimeError( + "shuffle=False can only be used when replacement=False." + ) sampler = SliceSampler( num_slices=num_slices, slice_len=slice_len, @@ -340,9 +353,15 @@ def __init__( num_slices=num_slices, slice_len=slice_len, strict_length=strict_length, + shuffle=self.shuffle, ) else: + if replacement is True: + # replacement can be False or None + raise RuntimeError( + "replacement=True is not available with streamed datasets." 
+ ) self.root = None if download: raise ValueError( @@ -613,7 +632,7 @@ def _slice_data(data: TensorDict, slice_len, pad_value): truncated = data.get(("next", "truncated")) truncated = torch.index_fill( truncated, - dim=data.ndim-1, + dim=data.ndim - 1, value=True, index=torch.tensor(-1, device=truncated.device), ) From e50f1b49ddf7943f2977cae27a2af81f516ff3e7 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 6 Jan 2024 07:05:39 +0000 Subject: [PATCH 16/18] amend replacement --- torchrl/data/datasets/openx.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index 7a45720109f..5d5f1bf4474 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -317,7 +317,8 @@ def __init__( if split_trajs: raise NotImplementedError if not streaming: - replacement = True if replacement is None else False + if replacement is None: + replacement = True if pad is not None: raise RuntimeError( "the `pad` argument is to be used only with streaming datasets." From a2e3fa7f389b3b69b19f28f94c38c2ea178565e4 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 6 Jan 2024 13:03:27 +0000 Subject: [PATCH 17/18] amend --- torchrl/data/datasets/openx.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index 5d5f1bf4474..cb7bb71ce42 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -162,31 +162,37 @@ class for more information on how to interact with non-tensor data Examples: >>> from torchrl.data.datasets import OpenXExperienceReplay + >>> import tempfile >>> # Download the data, and sample 128 elements in each batch out of two trajectories >>> num_slices = 2 - >>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=128, - ... num_slices=num_slices, download=True, streaming=False, root=root) - >>> for batch in dataset: - ... 
print(batch.reshape(num_slices, -1)) - ... break + >>> with tempfile.TemporaryDirectory() as root: + ... dataset = OpenXExperienceReplay("cmu_stretch", batch_size=128, + ... num_slices=num_slices, download=True, streaming=False, + ... root=root, + ... ) + ... for batch in dataset: + ... print(batch.reshape(num_slices, -1)) + ... break TensorDict( fields={ action: Tensor(shape=torch.Size([2, 64, 8]), device=cpu, dtype=torch.float64, is_shared=False), discount: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False), episode: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int32, is_shared=False), - is_first: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.bool, is_shared=False), + index: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False), is_init: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.bool, is_shared=False), - is_last: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.bool, is_shared=False), - is_terminal: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.bool, is_shared=False), language_embedding: Tensor(shape=torch.Size([2, 64, 512]), device=cpu, dtype=torch.float64, is_shared=False), - language_instruction: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False), + language_instruction: NonTensorData( + data='lift open green garbage can lid', + batch_size=torch.Size([2, 64]), + device=cpu, + is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False), observation: TensorDict( fields={ - image: Tensor(shape=torch.Size([2, 64, 3, 2, 64, 2, 64]), device=cpu, dtype=torch.uint8, is_shared=False), + image: Tensor(shape=torch.Size([2, 64, 3, 128, 128]), device=cpu, dtype=torch.uint8, is_shared=False), state: Tensor(shape=torch.Size([2, 64, 4]), device=cpu, 
dtype=torch.float64, is_shared=False)}, batch_size=torch.Size([2, 64]), device=cpu, @@ -199,12 +205,11 @@ class for more information on how to interact with non-tensor data is_shared=False), observation: TensorDict( fields={ - image: Tensor(shape=torch.Size([2, 64, 3, 2, 64, 2, 64]), device=cpu, dtype=torch.uint8, is_shared=False), + image: Tensor(shape=torch.Size([2, 64, 3, 128, 128]), device=cpu, dtype=torch.uint8, is_shared=False), state: Tensor(shape=torch.Size([2, 64, 4]), device=cpu, dtype=torch.float64, is_shared=False)}, batch_size=torch.Size([2, 64]), device=cpu, is_shared=False), - reward: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([2, 64]), @@ -212,7 +217,7 @@ class for more information on how to interact with non-tensor data is_shared=False) >>> # Read data from a stream. Deliver entire trajectories when iterating >>> dataset = OpenXExperienceReplay("cmu_stretch", - ... num_slices=num_slices, download=True, streaming=False, root=root) + ... num_slices=num_slices, download=False, streaming=True) >>> for data in dataset: # data does not have a consistent shape ... 
break >>> # Define batch-size dynamically From e39638b2ae98a2b07452939c35d91ba8773058eb Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 6 Jan 2024 13:08:45 +0000 Subject: [PATCH 18/18] doc --- docs/source/reference/data.rst | 1 + torchrl/data/datasets/openx.py | 10 ++++------ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 1df0330dc21..586c9531334 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -280,6 +280,7 @@ Here's an example: D4RLExperienceReplay MinariExperienceReplay OpenMLExperienceReplay + OpenXExperienceReplay RobosetExperienceReplay VD4RLExperienceReplay diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index cb7bb71ce42..aa78a92ff16 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -14,7 +14,6 @@ import torch from tensordict import make_tensordict, pad, TensorDict -from torchrl.collectors.utils import split_trajectories from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import ( @@ -24,7 +23,6 @@ ) from torchrl.data.replay_buffers.storages import _collate_id, Storage, TensorStorage from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer -from torchrl.envs.transforms.transforms import Transform _has_datasets = importlib.util.find_spec("datasets", None) is not None _has_tv = importlib.util.find_spec("torchvision", None) is not None @@ -48,14 +46,14 @@ class for more information on how to interact with non-tensor data Args: dataset_id (str): The dataset to be downloaded. - Must be part of OpenXExperienceReplay.available_datasets + Must be part of ``OpenXExperienceReplay.available_datasets``. batch_size (int): Batch-size used during sampling. Can be overridden by `data.sample(batch_size)` if necessary. 
- See `num_slices` and `slice_len` keyword arguments for a refined + See ``num_slices`` and ``slice_len`` keyword arguments for a refined sampling strategy. If the ``batch_size`` is ``None`` (default), iterating over the - dataset will deliver trajectories one at a time __whereas__ calling - :meth:`~.sample` will __still__ require a batch-size to be provided. + dataset will deliver trajectories one at a time *whereas* calling + :meth:`~.sample` will *still* require a batch-size to be provided. Keyword Args: shuffle (bool, optional): if ``True``, trajectories are delivered in a