From 3aa7d2ea5ed86da26851083e8816fa6f8c8b973b Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 9 Jan 2024 09:06:45 +0000 Subject: [PATCH 01/17] init --- torchrl/data/datasets/atari_dqn.py | 53 ++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 torchrl/data/datasets/atari_dqn.py diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py new file mode 100644 index 00000000000..ed588d89471 --- /dev/null +++ b/torchrl/data/datasets/atari_dqn.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import subprocess +import tempfile + +from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer +import os +import gzip +import tqdm +import numpy as np +import torch +import io +from pathlib import Path + +class AtariDQNExperienceReplay(TensorDictReplayBuffer): + available_datasets = ["Pong/1",] + def __init__(self, dataset_id): + self.dataset_id = dataset_id + storage = LazyMemmapStorage(1_000_000) + super().__init__(storage=storage) + + + def _download_dataset(self): + # with tempfile.TemporaryDirectory() as tempdir: + tempdir = "/Users/vmoens/Downloads/Pong/1" + # command = f"gsutil -m cp -R gs://atari-replay-datasets/dqn/{self.dataset_id} {tempdir}" + # subprocess.run(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + gz_files = [] + for root, dirs, files in os.walk(tempdir): + for file in files: + if file.endswith('.gz'): + gz_files.append(os.path.join(root, file)) + + def _count_files(pattern): + return sum(pattern in filename for filename in gz_files) + + pbar = tqdm.tqdm(gz_files) + for file in pbar: + name = str(Path(file).parts[-1]).split(".")[0] + # with open(file, "r") as fopen: + if "obs" in file: + print(name, file) + print("count", _count_files(name)) + with gzip.GzipFile(file) as f: + file_content = f.read() + t = torch.as_tensor(np.load(io.BytesIO(file_content))) + print(t.shape, t.dtype) + break + +AtariDQNExperienceReplay(AtariDQNExperienceReplay.available_datasets[0])._download_dataset() From 32e2523ccfe66f0d5121a8a5fd8b082a734c4ea8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 9 Jan 2024 10:55:34 +0000 Subject: [PATCH 02/17] amend --- torchrl/data/datasets/atari_dqn.py | 65 +++++++++++++++++++----------- 1 file changed, 42 insertions(+), 23 deletions(-) diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py index ed588d89471..3603975dd15 100644 --- a/torchrl/data/datasets/atari_dqn.py +++ b/torchrl/data/datasets/atari_dqn.py @@ -10,44 +10,63 @@ import gzip import tqdm import numpy as np +from tensordict import TensorDict import torch import io from pathlib import Path +from collections import defaultdict + +tempdir = "/Users/vmoens/Downloads/Pong/1" + class AtariDQNExperienceReplay(TensorDictReplayBuffer): - available_datasets = ["Pong/1",] + available_datasets = ["Pong/1", ] + def __init__(self, dataset_id): self.dataset_id = dataset_id storage = LazyMemmapStorage(1_000_000) super().__init__(storage=storage) + # def _download_dataset(self): + # # with tempfile.TemporaryDirectory() as tempdir: + # # command = f"gsutil -m cp -R gs://atari-replay-datasets/dqn/{self.dataset_id} {tempdir}" + # # subprocess.run(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - def _download_dataset(self): - # with tempfile.TemporaryDirectory() as tempdir: - tempdir = "/Users/vmoens/Downloads/Pong/1" - # command = f"gsutil -m cp -R gs://atari-replay-datasets/dqn/{self.dataset_id} {tempdir}" - # subprocess.run(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + def _get_episode(self, episode, path): + gz_files = self._list_episodes(path) + files = gz_files[episode] + td = {} + for file in files: + name = str(Path(file).parts[-1]).split(".")[0] + with gzip.GzipFile(file) as f: + file_content = f.read() + t = torch.as_tensor(np.load(io.BytesIO(file_content))) + td[self._process_name(name)] = t + td = TensorDict.from_dict(td) + return td + + @staticmethod + def _process_name(name): + if "store" in name: + return ("data", name.split("_")[1]) + if name.endswith("_ckpt"): + return name[:-5] + def _list_episodes(self, path): gz_files = [] - for root, dirs, files in os.walk(tempdir): + for root, dirs, files in os.walk(path): for file in files: if file.endswith('.gz'): gz_files.append(os.path.join(root, file)) + episodes = defaultdict(list) + for file in gz_files: + filename = Path(file).parts[-1] + name, episode, extension = str(filename).split(".") + episode = int(episode) + episodes[episode].append(file) + return episodes - def _count_files(pattern): - return sum(pattern in filename for filename in gz_files) - pbar = tqdm.tqdm(gz_files) - for file in pbar: - name = str(Path(file).parts[-1]).split(".")[0] - # with open(file, "r") as fopen: - if "obs" in file: - print(name, file) - print("count", _count_files(name)) - with gzip.GzipFile(file) as f: - file_content = f.read() - t = torch.as_tensor(np.load(io.BytesIO(file_content))) - print(t.shape, t.dtype) - break - -AtariDQNExperienceReplay(AtariDQNExperienceReplay.available_datasets[0])._download_dataset() +AtariDQNExperienceReplay( + AtariDQNExperienceReplay.available_datasets[0] + )._get_episode(0, tempdir) From 3cff3f3570667c818459a02b010965874a0e0b24 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 9 Jan 2024 13:20:04 +0000 Subject: [PATCH 03/17] amend --- torchrl/data/datasets/atari_dqn.py | 71 +++++++++++++++++++++++------- 1 file changed, 55 insertions(+), 16 deletions(-) diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py index 3603975dd15..45fb71a7956 100644 --- a/torchrl/data/datasets/atari_dqn.py +++ b/torchrl/data/datasets/atari_dqn.py @@ -5,14 +5,17 @@ import subprocess import tempfile -from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer +from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer, Storage import os import gzip import tqdm +import time +from concurrent.futures import ThreadPoolExecutor import numpy as np -from tensordict import TensorDict +from tensordict import TensorDict, NonTensorData import torch import io +import mmap from pathlib import Path from collections import defaultdict @@ -24,27 +27,64 @@ class AtariDQNExperienceReplay(TensorDictReplayBuffer): def __init__(self, dataset_id): self.dataset_id = dataset_id - storage = LazyMemmapStorage(1_000_000) - super().__init__(storage=storage) + storage = _AtariStorage(tempdir) + super().__init__(storage=storage, collate_fn=lambda x: x) - # def _download_dataset(self): - # # with tempfile.TemporaryDirectory() as tempdir: - # # command = f"gsutil -m cp -R gs://atari-replay-datasets/dqn/{self.dataset_id} {tempdir}" - # # subprocess.run(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - def _get_episode(self, episode, path): - gz_files = self._list_episodes(path) +class _AtariStorage(Storage): + def __init__(self, path): + self.path = path + self.gz_files = self._list_episodes(self.path) + + def __len__(self): + return len(self.gz_files) + + def _get_episode(self, episode): + gz_files = self.gz_files files = gz_files[episode] td = {} for file in files: name = str(Path(file).parts[-1]).split(".")[0] - with gzip.GzipFile(file) as f: + with gzip.GzipFile(file, mode="rb") as f: + t0 = time.time() file_content = f.read() - t = torch.as_tensor(np.load(io.BytesIO(file_content))) + t1 = time.time() + file_content = io.BytesIO(file_content) + t2 = time.time() + file_content = np.load(file_content) + t3 = time.time() + print(t1 - t0, t2 - t1, t3 - t2) + t = torch.as_tensor(file_content) td[self._process_name(name)] = t td = TensorDict.from_dict(td) + td = td["data"].set( + "metadata", + NonTensorData( + td.exclude("data").to_dict(), + batch_size=td["data"].batch_size + ) + ) return td + def get(self, index): + if isinstance(index, int): + return self._get_episode(index) + if isinstance(index, tuple): + if len(index) == 1: + return self.get(index[0]) + return self.get(index[0])[..., index[1:]] + if isinstance(index, torch.Tensor): + if index.ndim == 0: + return self[int(index)] + if index.ndim > 1: + raise RuntimeError("Only 1d tensors are accepted") + # with ThreadPoolExecutor(16) as pool: + results = map(self.__getitem__, index.tolist()) + return torch.stack(list(results)) + if isinstance(index, (range, list)): + return self[torch.tensor(index)] + return self[torch.arange(len(self))[index]] + @staticmethod def _process_name(name): if "store" in name: @@ -66,7 +106,6 @@ def _list_episodes(self, path): episodes[episode].append(file) return episodes - -AtariDQNExperienceReplay( - AtariDQNExperienceReplay.available_datasets[0] - )._get_episode(0, tempdir) +t0 = time.time() +AtariDQNExperienceReplay(AtariDQNExperienceReplay.available_datasets[0])[:3] +time.time()-t0 From 5405888fdebd0c10572f97dc52902c32371bb921 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 18 Jan 2024 20:48:28 +0000 Subject: [PATCH 04/17] amend --- torchrl/data/datasets/atari_dqn.py | 252 ++++++++++++++++++++++------- 1 file changed, 191 insertions(+), 61 deletions(-) diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py index 45fb71a7956..bb15477a4bf 100644 --- a/torchrl/data/datasets/atari_dqn.py +++ b/torchrl/data/datasets/atari_dqn.py @@ -2,69 +2,218 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import subprocess -import tempfile +import gzip +import io +import pathlib +import shutil -from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer, Storage +import mmap import os -import gzip -import tqdm +import subprocess +import tempfile import time +from collections import defaultdict from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + import numpy as np -from tensordict import TensorDict, NonTensorData import torch -import io -import mmap -from pathlib import Path -from collections import defaultdict +import tqdm +from tensordict import NonTensorData, TensorDict, MemoryMappedTensor + +from torchrl.data import LazyMemmapStorage, Storage, TensorDictReplayBuffer +from torchrl.envs.utils import _classproperty tempdir = "/Users/vmoens/Downloads/Pong/1" class AtariDQNExperienceReplay(TensorDictReplayBuffer): - available_datasets = ["Pong/1", ] + @_classproperty + def available_datasets(cls): + games = [ + "AirRaid", + "Alien", + "Amidar", + "Assault", + "Asterix", + "Asteroids", + "Atlantis", + "BankHeist", + "BattleZone", + "BeamRider", + "Berzerk", + "Bowling", + "Boxing", + "Breakout", + "Carnival", + "Centipede", + "ChopperCommand", + "CrazyClimber", + "DemonAttack", + "DoubleDunk", + "ElevatorAction", + "Enduro", + "FishingDerby", + "Freeway", + "Frostbite", + "Gopher", + "Gravitar", + "Hero", + "IceHockey", + "Jamesbond", + "JourneyEscape", + "Kangaroo", + "Krull", + "KungFuMaster", + "MontezumaRevenge", + "MsPacman", + "NameThisGame", + "Phoenix", + "Pitfall", + "Pong", + "Pooyan", + "PrivateEye", + "Qbert", + "Riverraid", + "RoadRunner", + "Robotank", + "Seaquest", + "Skiing", + "Solaris", + "SpaceInvaders", + ] + return ["/".join((game, str(loop))) for game in games for loop in range(1, 6)] def __init__(self, dataset_id): self.dataset_id = dataset_id - storage = _AtariStorage(tempdir) + from torchrl.data.datasets.utils import _get_root_dir + self.root = Path(_get_root_dir("atari")) + self._download_and_preproc() + storage = _AtariStorage(self._root) super().__init__(storage=storage, collate_fn=lambda x: x) + @property + def root(self): + return self._root + @root.setter + def root(self, value): + self._root = Path(value) + @property + def dataset_path(self): + return self._root / self.dataset_id + def _download_and_preproc(self): + if os.path.exists(self.dataset_path): + # TODO: better check + return + with tempfile.TemporaryDirectory() as tempdir: + command = f"gsutil -m cp -R gs://atari-replay-datasets/dqn/{self.dataset_id} {tempdir}" + subprocess.run(command, shell=True) #, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + gz_files = self._list_episodes(tempdir) + for episode in gz_files: + try: + path = self._root / str(episode) + self._preproc_episode(path, gz_files, episode) + except Exception: + shutil.rmtree(path) + raise -class _AtariStorage(Storage): - def __init__(self, path): - self.path = path - self.gz_files = self._list_episodes(self.path) - - def __len__(self): - return len(self.gz_files) - - def _get_episode(self, episode): - gz_files = self.gz_files + def _preproc_episode(self, path, gz_files, episode): + print("preproc", episode) files = gz_files[episode] - td = {} + td = TensorDict({}, []) + path = Path(path) for file in files: name = str(Path(file).parts[-1]).split(".")[0] with gzip.GzipFile(file, mode="rb") as f: - t0 = time.time() file_content = f.read() - t1 = time.time() file_content = io.BytesIO(file_content) - t2 = time.time() file_content = np.load(file_content) - t3 = time.time() - print(t1 - t0, t2 - t1, t3 - t2) t = torch.as_tensor(file_content) - td[self._process_name(name)] = t - td = TensorDict.from_dict(td) - td = td["data"].set( - "metadata", - NonTensorData( - td.exclude("data").to_dict(), - batch_size=td["data"].batch_size - ) - ) - return td + # Create the memmap file + key = self._process_name(name) + if key == ("data", "observation"): + shape = t.shape + shape = [shape[0] + 1] + list(shape[1:]) + filename = path / "data" / "observation.memmap" + os.makedirs(filename.parent, exist_ok=True) + mmap = MemoryMappedTensor.empty(shape, dtype=t.dtype, filename=filename) + print('copying') + mmap[:-1].copy_(t) + td[key] = mmap + # td["data", "next", key[1:]] = mmap[1:] + else: + if key in (("data", "reward"), ("data", "done"), ("data", "terminated")): + filename = path / "data" / "next" / (key[-1] + ".memmap") + os.makedirs(filename.parent, exist_ok=True) + mmap = MemoryMappedTensor.from_tensor(t, filename=filename) + td["data", "next", key[1:]] = mmap + else: + filename = path + for i, _key in enumerate(key): + if i == len(key) - 1: + _key = _key + ".memmap" + filename = filename / _key + os.makedirs(filename.parent, exist_ok=True) + mmap = MemoryMappedTensor.from_tensor(t, filename=filename) + td[key] = mmap + td.set_non_tensor("info", {"episode": episode, "path": path}) + td.memmap_(path, copy_existing=False) + + @staticmethod + def _process_name(name): + if name.endswith("_ckpt"): + name = name[:-5] + if "store" in name: + key = ("data", name.split("_")[1]) + else: + key = (name,) + if key[-1] == "terminal": + key = (*key[:-1], "terminated") + return key + + def _list_episodes(self, download_path): + path = download_path + gz_files = [] + for root, dirs, files in os.walk(path): + for file in files: + if file.endswith(".gz"): + gz_files.append(os.path.join(root, file)) + episodes = defaultdict(list) + for file in gz_files: + filename = Path(file).parts[-1] + name, episode, extension = str(filename).split(".") + episode = int(episode) + episodes[episode].append(file) + return dict(sorted(episodes.items(), key=lambda x: x[0])) + + +class _AtariStorage(Storage): + def __init__(self, path): + self.path = Path(path) + + def __len__(self): + return len(self.gz_files) + + def _get_episode(self, episode: int): + path = self.path / str(episode) + if os.path.exists(path): + return self._load_episode(path) + else: + raise RuntimeError + + def _load_episode(self, path): + return self._proc_td(TensorDict.load_memmap(path)) + + def _proc_td(self, td): + with td.unlock_(): + td["data", "next", "observation"] = td["data", "observation"][1:] + td["data", "observation"] = td["data", "observation"][:-1] + non_tensor = td.exclude("data").to_dict() + td = td["data"] + td.auto_batch_size_(1) + td.set_non_tensor("metadata", non_tensor) + return td + def get(self, index): if isinstance(index, int): @@ -85,27 +234,8 @@ def get(self, index): return self[torch.tensor(index)] return self[torch.arange(len(self))[index]] - @staticmethod - def _process_name(name): - if "store" in name: - return ("data", name.split("_")[1]) - if name.endswith("_ckpt"): - return name[:-5] - - def _list_episodes(self, path): - gz_files = [] - for root, dirs, files in os.walk(path): - for file in files: - if file.endswith('.gz'): - gz_files.append(os.path.join(root, file)) - episodes = defaultdict(list) - for file in gz_files: - filename = Path(file).parts[-1] - name, episode, extension = str(filename).split(".") - episode = int(episode) - episodes[episode].append(file) - return episodes - +import logging +logging.getLogger().setLevel(logging.INFO) t0 = time.time() -AtariDQNExperienceReplay(AtariDQNExperienceReplay.available_datasets[0])[:3] -time.time()-t0 +print(AtariDQNExperienceReplay("Pong/5")[:3]) +time.time() - t0 From 8b8867a8093cf49efe52857ea692318d0c768398 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 08:35:54 +0000 Subject: [PATCH 05/17] amend --- torchrl/data/datasets/atari_dqn.py | 198 +++++++++++++++++++++-------- 1 file changed, 145 insertions(+), 53 deletions(-) diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py index bb15477a4bf..22ec8f4f2e6 100644 --- a/torchrl/data/datasets/atari_dqn.py +++ b/torchrl/data/datasets/atari_dqn.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import gzip import io import pathlib @@ -24,9 +26,6 @@ from torchrl.data import LazyMemmapStorage, Storage, TensorDictReplayBuffer from torchrl.envs.utils import _classproperty -tempdir = "/Users/vmoens/Downloads/Pong/1" - - class AtariDQNExperienceReplay(TensorDictReplayBuffer): @_classproperty def available_datasets(cls): @@ -84,12 +83,15 @@ def available_datasets(cls): ] return ["/".join((game, str(loop))) for game in games for loop in range(1, 6)] + tmpdir = "/Users/vmoens/.cache/atari_root" + max_ep = 3 + def __init__(self, dataset_id): self.dataset_id = dataset_id from torchrl.data.datasets.utils import _get_root_dir self.root = Path(_get_root_dir("atari")) self._download_and_preproc() - storage = _AtariStorage(self._root) + storage = _AtariStorage(self.dataset_path) super().__init__(storage=storage, collate_fn=lambda x: x) @property @@ -106,19 +108,53 @@ def _download_and_preproc(self): # TODO: better check return with tempfile.TemporaryDirectory() as tempdir: - command = f"gsutil -m cp -R gs://atari-replay-datasets/dqn/{self.dataset_id} {tempdir}" - subprocess.run(command, shell=True) #, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - gz_files = self._list_episodes(tempdir) - for episode in gz_files: - try: - path = self._root / str(episode) - self._preproc_episode(path, gz_files, episode) - except Exception: - shutil.rmtree(path) - raise - - def _preproc_episode(self, path, gz_files, episode): - print("preproc", episode) + if self.tmpdir is not None: + tempdir = self.tmpdir + try: + shutil.rmtree(tempdir) + os.makedirs(tempdir, exist_ok=True) + except: + os.makedirs(tempdir, exist_ok=True) + if not os.listdir(tempdir): + os.makedirs(tempdir, exist_ok=True) + # get the list of episodes + command = f"gsutil -m ls -R gs://atari-replay-datasets/dqn/Pong/1/replay_logs" + output = subprocess.run( + command, + shell=True, capture_output=True + ) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + files = [file.decode("utf-8").replace('$', '\$') for file in output.stdout.splitlines() if + file.endswith(b'.gz')] + self.remote_gz_files = self._list_episodes(None, files) + for episode, episode_files in self.remote_gz_files.items(): + self._download_and_proc_episode(episode, episode_files, tempdir, self.dataset_path) + + @classmethod + def _download_and_proc_episode(cls, episode, episode_files, tempdir, dataset_path): + if episode >= 3: + return + tempdir = Path(tempdir) + os.makedirs(tempdir/str(episode)) + files_str = ' '.join(episode_files) # .decode("utf-8") + print("downloading", files_str) + command = f"gsutil -m cp {files_str} {tempdir}/{episode}" + subprocess.run( + command, + shell=True + ) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + local_gz_files = cls._list_episodes(tempdir/str(episode)) + # we iterate over the dict but this one has length 1 + for episode in local_gz_files: + path = dataset_path / str(episode) + try: + cls._preproc_episode(path, local_gz_files, episode) + except Exception: + shutil.rmtree(path) + raise + shutil.rmtree(tempdir / str(episode)) + + @classmethod + def _preproc_episode(cls, path, gz_files, episode): files = gz_files[episode] td = TensorDict({}, []) path = Path(path) @@ -130,14 +166,13 @@ def _preproc_episode(self, path, gz_files, episode): file_content = np.load(file_content) t = torch.as_tensor(file_content) # Create the memmap file - key = self._process_name(name) + key = cls._process_name(name) if key == ("data", "observation"): shape = t.shape shape = [shape[0] + 1] + list(shape[1:]) filename = path / "data" / "observation.memmap" os.makedirs(filename.parent, exist_ok=True) mmap = MemoryMappedTensor.empty(shape, dtype=t.dtype, filename=filename) - print('copying') mmap[:-1].copy_(t) td[key] = mmap # td["data", "next", key[1:]] = mmap[1:] @@ -171,13 +206,15 @@ def _process_name(name): key = (*key[:-1], "terminated") return key - def _list_episodes(self, download_path): + @classmethod + def _list_episodes(cls, download_path, gz_files=None): path = download_path - gz_files = [] - for root, dirs, files in os.walk(path): - for file in files: - if file.endswith(".gz"): - gz_files.append(os.path.join(root, file)) + if gz_files is None: + gz_files = [] + for root, dirs, files in os.walk(path): + for file in files: + if file.endswith(".gz"): + gz_files.append(os.path.join(root, file)) episodes = defaultdict(list) for file in gz_files: filename = Path(file).parts[-1] @@ -191,29 +228,62 @@ class _AtariStorage(Storage): def __init__(self, path): self.path = Path(path) + def get_folders(path): + return [name for name in os.listdir(path) if + os.path.isdir(os.path.join(path, name))] + + # Usage + self.episodes = [] + folders = get_folders(path) + for folder in folders: + self.episodes.append(int(Path(folder).parts[-1])) + self.frames_per_ep = 1000000 + self._episode_tds = [] + for episode in self.episodes: + path = self.path / str(episode) + self._episode_tds.append(self._load_episode(path)) + def __len__(self): - return len(self.gz_files) + return len(self.episodes) * self.frames_per_ep - def _get_episode(self, episode: int): - path = self.path / str(episode) - if os.path.exists(path): - return self._load_episode(path) + def _get_episode(self, item: int | torch.Tensor): + # print('get episode', item) + episode = item // self.frames_per_ep + item = item % self.frames_per_ep + if isinstance(item, int): + unique_episodes = (episode,) + episode_inverse = None else: - raise RuntimeError + unique_episodes, episode_inverse = torch.unique(episode, return_inverse=True) + # print('unique_episodes, episode_inverse', unique_episodes, episode_inverse) + out = [] + for i, episode in enumerate(unique_episodes): + episode = int(episode) + _item = item[episode_inverse == i] if episode_inverse is not None else item + # print('_item', _item) + path = self.path / str(episode) + if os.path.exists(path): + out.append( self._proc_td(self._episode_tds[episode], _item)) + else: + raise RuntimeError + # print('out', out) + return torch.cat(out, 0) def _load_episode(self, path): - return self._proc_td(TensorDict.load_memmap(path)) - - def _proc_td(self, td): - with td.unlock_(): - td["data", "next", "observation"] = td["data", "observation"][1:] - td["data", "observation"] = td["data", "observation"][:-1] - non_tensor = td.exclude("data").to_dict() - td = td["data"] - td.auto_batch_size_(1) - td.set_non_tensor("metadata", non_tensor) - return td + return TensorDict.load_memmap(path) + def _proc_td(self, td, index): + obs_ = td["data", "observation"][index + 1] + done = td["data", "next", "terminated"][index].bool() + if done.ndim and done.any(): + obs_ = torch.masked_fill(obs_, done, 0) + td_idx = td.empty() + td_idx["next", "observation"] = obs_ + non_tensor = td.exclude("data").to_dict() + td_idx.update(td["data"].apply(lambda x: x[index])) + td_idx.auto_batch_size_(1) + td_idx.set_non_tensor("metadata", non_tensor) + return td_idx def get(self, index): if isinstance(index, int): @@ -221,21 +291,43 @@ def get(self, index): if isinstance(index, tuple): if len(index) == 1: return self.get(index[0]) - return self.get(index[0])[..., index[1:]] + return self.get(index[0])[(Ellipsis, *index[1:])] if isinstance(index, torch.Tensor): - if index.ndim == 0: - return self[int(index)] - if index.ndim > 1: + if index.ndim <= 1: + return self._get_episode(index) + else: raise RuntimeError("Only 1d tensors are accepted") # with ThreadPoolExecutor(16) as pool: - results = map(self.__getitem__, index.tolist()) - return torch.stack(list(results)) + # results = map(self.__getitem__, index.tolist()) + # return torch.stack(list(results)) if isinstance(index, (range, list)): return self[torch.tensor(index)] + if isinstance(index, slice): + start = index.start if index.start is not None else 0 + stop = index.stop if index.stop is not None else len(self) + step = index.step if index.step is not None else 1 + return self.get(torch.arange(start, stop, step)) return self[torch.arange(len(self))[index]] -import logging -logging.getLogger().setLevel(logging.INFO) -t0 = time.time() -print(AtariDQNExperienceReplay("Pong/5")[:3]) -time.time() - t0 +if __name__ == '__main__': + # command = f"gsutil -m ls -R gs://atari-replay-datasets/dqn/Pong/1/replay_logs" + # output = subprocess.run( + # command, + # shell=True, capture_output=True + # ) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + # print(output.stdout.splitlines(), type(output.stdout)) + # files = [file for file in output.stdout.splitlines() if file.endswith(b'.gz') and int(file.split(b'.')[-2]) <= 3] + # files_str = b' '.join(files) + # command = f"gsutil -m cp -R {files_str} {tempdir}" + # subprocess.run( + # command, + # shell=True + # ) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + import logging + logging.getLogger().setLevel(logging.INFO) + dataset = AtariDQNExperienceReplay("Pong/5") + # t0 = time.time() + for _ in range(200): + dataset[slice(0, 3000000, 50000)] + # print(time.time() - t0) From 2f72e556050e63f674add30da4b507e3f5067d38 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 11:31:10 +0000 Subject: [PATCH 06/17] amend --- torchrl/data/datasets/__init__.py | 1 + torchrl/data/datasets/atari_dqn.py | 198 +++++++++++++++++++-------- torchrl/data/datasets/d4rl.py | 4 +- torchrl/data/datasets/minari_data.py | 8 +- torchrl/data/datasets/openml.py | 4 +- torchrl/data/datasets/openx.py | 12 +- torchrl/data/datasets/roboset.py | 8 +- torchrl/data/datasets/vd4rl.py | 4 +- 8 files changed, 161 insertions(+), 78 deletions(-) diff --git a/torchrl/data/datasets/__init__.py b/torchrl/data/datasets/__init__.py index 1cef4f3ffea..b1a071e18b5 100644 --- a/torchrl/data/datasets/__init__.py +++ b/torchrl/data/datasets/__init__.py @@ -4,3 +4,4 @@ from .openx import OpenXExperienceReplay from .roboset import RobosetExperienceReplay from .vd4rl import VD4RLExperienceReplay +from .atari_dqn import AtariDQNExperienceReplay diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py index 22ec8f4f2e6..1d99b6fb051 100644 --- a/torchrl/data/datasets/atari_dqn.py +++ b/torchrl/data/datasets/atari_dqn.py @@ -4,14 +4,17 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import functools import gzip import io +import json import pathlib import shutil import mmap import os import subprocess +from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter import tempfile import time from collections import defaultdict @@ -22,11 +25,68 @@ import torch import tqdm from tensordict import NonTensorData, TensorDict, MemoryMappedTensor +from tensordict.utils import expand_right from torchrl.data import LazyMemmapStorage, Storage, TensorDictReplayBuffer from torchrl.envs.utils import _classproperty +from torch import multiprocessing as mp class AtariDQNExperienceReplay(TensorDictReplayBuffer): + """Atari DQN Experience replay class. + + The Atari DQN dataset (https://offline-rl.github.io/) is a collection of 5 training + iterations of DQN over each of the Arari 2600 games for a total of 200 million frames. + The sub-sampling rate (frame-skip) is equal to 4, meaning that each game dataset + has 50 million steps in total. + + The data format follows the TED convention. Since the dataset is quite heavy, + the data formatting is done on-line, at sampling time. + + To make training more modular, we split the dataset in each of the Atari games + and separate each training round. Consequently, each dataset is presented as + a Storage of length 50x10^6 elements. Under the hood, this dataset is split + in 50 memory-mapped tensordicts of length 1 million each. + + Args: + dataset_id (str): The dataset to be downloaded. + Must be part of ``AtariDQNExperienceReplay.available_datasets``. + batch_size (int): Batch-size used during sampling. + Can be overridden by `data.sample(batch_size)` if necessary. + + Keyword Args: + root (Path or str, optional): The AtariDQN dataset root directory. + The actual dataset memory-mapped files will be saved under + `/`. If none is provided, it defaults to + ``~/.cache/torchrl/atari`. + download (bool or str, optional): Whether the dataset should be downloaded if + not found. Defaults to ``True``. Download can also be passed as "force", + in which case the downloaded data will be overwritten. + sampler (Sampler, optional): the sampler to be used. If none is provided + a default RandomSampler() will be used. + writer (Writer, optional): the writer to be used. If none is provided + a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used. + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s)/outputs. Used when using batched + loading from a map-style dataset. + pin_memory (bool): whether pin_memory() should be called on the rb + samples. + prefetch (int, optional): number of next batches to be prefetched + using multithreading. + transform (Transform, optional): Transform to be executed when sample() is called. + To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class. + + Examples: + >>> from torchrl.data.datasets import AtariDQNExperienceReplay + >>> from torchrl.data.replay_buffers import SliceSampler + >>> sampler = SliceSampler() + >>> dataset = AtariDQNExperienceReplay("Pong/5", batch_size=128, sampler=sampler) + >>> for data in dataset: + ... print(data) + ... break + + As always, datasets should be composed using :class:`~torchrl.data.replay_buffers.ReplayBufferEnsemble` + + """ @_classproperty def available_datasets(cls): games = [ @@ -84,29 +144,54 @@ def available_datasets(cls): return ["/".join((game, str(loop))) for game in games for loop in range(1, 6)] tmpdir = "/Users/vmoens/.cache/atari_root" - max_ep = 3 - def __init__(self, dataset_id): + # use _max_episodes for debugging, avoids downloading the entire dataset + _max_episodes = 4 + + def __init__(self, dataset_id:str, batch_size:int|None=None, *, root: str | Path | None = None, download: bool|str=True, sampler=None, writer=None, transform: "Transform" | None=None, num_procs: int=0, **kwargs): + if dataset_id not in self.available_datasets: + raise ValueError("The dataseet_id is not part of the available datasets. The dataset should be named / " + "where is one of the Atari 2600 games and the run is a number betweeen 1 and 5. " + "The full list of accepted dataset_ids is available under AtariDQNExperienceReplay.available_datasets.") self.dataset_id = dataset_id from torchrl.data.datasets.utils import _get_root_dir - self.root = Path(_get_root_dir("atari")) - self._download_and_preproc() + if root is None: + root = _get_root_dir("atari") + self.root = root + self.num_procs = num_procs + if download == "force" or (download and not self._is_downloaded): + try: + self._download_and_preproc() + except Exception: + # remove temporary data + if os.path.exists(self.dataset_path): + shutil.rmtree(self.dataset_path) + raise storage = _AtariStorage(self.dataset_path) - super().__init__(storage=storage, collate_fn=lambda x: x) + if writer is None: + writer = ImmutableDatasetWriter() + super().__init__(storage=storage, batch_size=batch_size, writer=writer, sampler=sampler, collate_fn=lambda x: x, transform=transform, **kwargs) @property - def root(self): + def root(self)->Path: return self._root @root.setter def root(self, value): self._root = Path(value) @property - def dataset_path(self): + def dataset_path(self) -> Path: return self._root / self.dataset_id + @property + def _is_downloaded(self): + if os.path.exists(self.dataset_path / "processed.json"): + with open(self.dataset_path / "processed.json", "r") as jsonfile: + return json.load(jsonfile).get("processed", False) + return False + def _download_and_preproc(self): + logging.info(f"Downloading and preprocessing dataset {self.dataset_id} with {self.num_procs} processes. This may take a while...") if os.path.exists(self.dataset_path): - # TODO: better check - return + shutil.rmtree(self.dataset_path) with tempfile.TemporaryDirectory() as tempdir: if self.tmpdir is not None: tempdir = self.tmpdir @@ -126,12 +211,21 @@ def _download_and_preproc(self): files = [file.decode("utf-8").replace('$', '\$') for file in output.stdout.splitlines() if file.endswith(b'.gz')] self.remote_gz_files = self._list_episodes(None, files) - for episode, episode_files in self.remote_gz_files.items(): - self._download_and_proc_episode(episode, episode_files, tempdir, self.dataset_path) + total_episodes = list(self.remote_gz_files)[-1] + if self.num_procs == 0: + for episode, episode_files in self.remote_gz_files.items(): + self._download_and_proc_episode(episode, episode_files, tempdir=tempdir, dataset_path=self.dataset_path, total_episodes=total_episodes) + else: + func = functools.partial(self._download_and_proc_episode, tempdir=tempdir, dataset_path=self.dataset_path, total_episodes=total_episodes) + args = [(episode, episode_files) for (episode, episode_files) in self.remote_gz_files.items()] + with mp.Pool(self.num_procs) as pool: + pool.starmap(func, args) + with open(self.dataset_path / "processed.json", "w") as file: + json.dump({"processed": True}, file) @classmethod - def _download_and_proc_episode(cls, episode, episode_files, tempdir, dataset_path): - if episode >= 3: + def _download_and_proc_episode(cls, episode, episode_files, *, tempdir, dataset_path, total_episodes): + if cls._max_episodes is not None and episode >= cls._max_episodes: return tempdir = Path(tempdir) os.makedirs(tempdir/str(episode)) @@ -152,6 +246,7 @@ def _download_and_proc_episode(cls, episode, episode_files, tempdir, dataset_pat shutil.rmtree(path) raise shutil.rmtree(tempdir / str(episode)) + print(f'Concluded episode {episode} out of {total_episodes}') @classmethod def _preproc_episode(cls, path, gz_files, episode): @@ -237,64 +332,73 @@ def get_folders(path): folders = get_folders(path) for folder in folders: self.episodes.append(int(Path(folder).parts[-1])) - self.frames_per_ep = 1000000 self._episode_tds = [] + frames_per_ep = {} for episode in self.episodes: path = self.path / str(episode) self._episode_tds.append(self._load_episode(path)) + # take away 1 because we padded with 1 empty val + frames_per_ep[episode] = self._episode_tds[-1].get(("data", "observation")).shape[0] - 1 + + frames_per_ep = torch.tensor([[episode, length] for (episode, length) in frames_per_ep.items()]) + frames_per_ep[:, 1] = frames_per_ep[:, 1].cumsum(0) + self.frames_per_ep = torch.cat([torch.tensor([[-1, 0]]), frames_per_ep], 0) def __len__(self): - return len(self.episodes) * self.frames_per_ep + return self.frames_per_ep[-1, 1].item() - def _get_episode(self, item: int | torch.Tensor): - # print('get episode', item) - episode = item // self.frames_per_ep - item = item % self.frames_per_ep + def _read_from_episodes(self, item: int | torch.Tensor): + # We need to allocate each item to its storage. + # We don't assume each storage has the same size (too expensive to test) + # so we keep a map of each storage cumulative length and retrieve the + # storages one after the other. + episode = (item < self.frames_per_ep[1:, 1].unsqueeze(1)) & (item >= self.frames_per_ep[:-1, 1].unsqueeze(1)) + episode = episode.squeeze().nonzero()[:, 0] + episode = self.frames_per_ep[episode+1, 0] + item = item - self.frames_per_ep[episode, 1] if isinstance(item, int): unique_episodes = (episode,) episode_inverse = None else: unique_episodes, episode_inverse = torch.unique(episode, return_inverse=True) - # print('unique_episodes, episode_inverse', unique_episodes, episode_inverse) + unique_episodes = unique_episodes.tolist() out = [] for i, episode in enumerate(unique_episodes): - episode = int(episode) _item = item[episode_inverse == i] if episode_inverse is not None else item - # print('_item', _item) - path = self.path / str(episode) - if os.path.exists(path): - out.append( self._proc_td(self._episode_tds[episode], _item)) - else: - raise RuntimeError - # print('out', out) + out.append( self._proc_td(self._episode_tds[episode], _item)) return torch.cat(out, 0) def _load_episode(self, path): return TensorDict.load_memmap(path) def _proc_td(self, td, index): - obs_ = td["data", "observation"][index + 1] - done = td["data", "next", "terminated"][index].bool() + td_data = td.get("data") + obs_ = td_data.get(("observation"))[index + 1] + done = td_data.get(("next", "terminated"))[index].squeeze(-1).bool() if done.ndim and done.any(): - obs_ = torch.masked_fill(obs_, done, 0) + obs_ = torch.index_fill(obs_, 0, done.nonzero().squeeze(), 0) + # obs_ = torch.masked_fill(obs_, done, 0) + # obs_ = torch.masked_fill(obs_, expand_right(done, obs_.shape), 0) + # obs_ = torch.where(~expand_right(done, obs_.shape), obs_, 0) td_idx = td.empty() - td_idx["next", "observation"] = obs_ + td_idx.set(("next", "observation"), obs_) non_tensor = td.exclude("data").to_dict() - td_idx.update(td["data"].apply(lambda x: x[index])) - td_idx.auto_batch_size_(1) + td_idx.update(td_data.apply(lambda x: x[index])) + if isinstance(index, torch.Tensor): + td_idx.batch_size = [len(index)] td_idx.set_non_tensor("metadata", non_tensor) return td_idx def get(self, index): if isinstance(index, int): - return self._get_episode(index) + return self._read_from_episodes(index) if isinstance(index, tuple): if len(index) == 1: return self.get(index[0]) return self.get(index[0])[(Ellipsis, *index[1:])] if isinstance(index, torch.Tensor): if index.ndim <= 1: - return self._get_episode(index) + return self._read_from_episodes(index) else: raise RuntimeError("Only 1d tensors are accepted") # with ThreadPoolExecutor(16) as pool: @@ -310,24 +414,8 @@ def get(self, index): return self[torch.arange(len(self))[index]] if __name__ == '__main__': - # command = f"gsutil -m ls -R gs://atari-replay-datasets/dqn/Pong/1/replay_logs" - # output = subprocess.run( - # command, - # shell=True, capture_output=True - # ) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - # print(output.stdout.splitlines(), type(output.stdout)) - # files = [file for file in output.stdout.splitlines() if file.endswith(b'.gz') and int(file.split(b'.')[-2]) <= 3] - # files_str = b' '.join(files) - # command = f"gsutil -m cp -R {files_str} {tempdir}" - # subprocess.run( - # command, - # shell=True - # ) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - import logging logging.getLogger().setLevel(logging.INFO) - dataset = AtariDQNExperienceReplay("Pong/5") - # t0 = time.time() - for _ in range(200): - dataset[slice(0, 3000000, 50000)] - # print(time.time() - t0) + dataset = AtariDQNExperienceReplay("Pong/5", num_procs=4) + for _ in range(100): + out = dataset[slice(0, 3000000, 10000)] \ No newline at end of file diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index 38fce4a6b7c..08a8d7c5edd 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -50,7 +50,7 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): sampler (Sampler, optional): the sampler to be used. If none is provided a default RandomSampler() will be used. writer (Writer, optional): the writer to be used. If none is provided - a default RoundRobinWriter() will be used. + a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -59,7 +59,7 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): prefetch (int, optional): number of next batches to be prefetched using multithreading. transform (Transform, optional): Transform to be executed when sample() is called. - To chain transforms use the :obj:`Compose` class. + To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class. split_trajs (bool, optional): if ``True``, the trajectories will be split along the first dimension and padded to have a matching shape. To split the trajectories, the ``"done"`` signal will be used, which diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 754d5da9865..efa225d9bc6 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -73,7 +73,7 @@ class MinariExperienceReplay(TensorDictReplayBuffer): sampler (Sampler, optional): the sampler to be used. If none is provided a default RandomSampler() will be used. writer (Writer, optional): the writer to be used. If none is provided - a default RoundRobinWriter() will be used. + a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -82,15 +82,13 @@ class MinariExperienceReplay(TensorDictReplayBuffer): prefetch (int, optional): number of next batches to be prefetched using multithreading. transform (Transform, optional): Transform to be executed when sample() is called. - To chain transforms use the :obj:`Compose` class. + To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class. split_trajs (bool, optional): if ``True``, the trajectories will be split along the first dimension and padded to have a matching shape. To split the trajectories, the ``"done"`` signal will be used, which is recovered via ``done = truncated | terminated``. In other words, it is assumed that any ``truncated`` or ``terminated`` signal is - equivalent to the end of a trajectory. For some datasets from - ``D4RL``, this may not be true. It is up to the user to make - accurate choices regarding this usage of ``split_trajs``. + equivalent to the end of a trajectory. Defaults to ``False``. Attributes: diff --git a/torchrl/data/datasets/openml.py b/torchrl/data/datasets/openml.py index 76ccb66f601..ec1538b46a2 100644 --- a/torchrl/data/datasets/openml.py +++ b/torchrl/data/datasets/openml.py @@ -38,7 +38,7 @@ class OpenMLExperienceReplay(TensorDictReplayBuffer): sampler (Sampler, optional): the sampler to be used. If none is provided a default RandomSampler() will be used. writer (Writer, optional): the writer to be used. If none is provided - a default RoundRobinWriter() will be used. + a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -47,7 +47,7 @@ class OpenMLExperienceReplay(TensorDictReplayBuffer): prefetch (int, optional): number of next batches to be prefetched using multithreading. transform (Transform, optional): Transform to be executed when sample() is called. - To chain transforms use the :obj:`Compose` class. + To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class. """ diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index aa78a92ff16..28a5b33226d 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -107,10 +107,10 @@ class for more information on how to interact with non-tensor data 0s. If another value is provided, it will be used for padding. If ``False`` or ``None`` (default) any encounter with a trajectory of insufficient length will raise an exception. - root (Path or str, optional): The Minari dataset root directory. + root (Path or str, optional): The OpenX dataset root directory. The actual dataset memory-mapped files will be saved under `/`. If none is provided, it defaults to - ``~/.cache/torchrl/minari`. + ``~/.cache/torchrl/openx`. streaming (bool, optional): if ``True``, the data won't be downloaded but read from a stream instead. @@ -132,7 +132,7 @@ class for more information on how to interact with non-tensor data sampler (Sampler, optional): the sampler to be used. If none is provided a default RandomSampler() will be used. writer (Writer, optional): the writer to be used. If none is provided - a default RoundRobinWriter() will be used. + a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -141,15 +141,13 @@ class for more information on how to interact with non-tensor data prefetch (int, optional): number of next batches to be prefetched using multithreading. transform (Transform, optional): Transform to be executed when sample() is called. - To chain transforms use the :obj:`Compose` class. + To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class. split_trajs (bool, optional): if ``True``, the trajectories will be split along the first dimension and padded to have a matching shape. To split the trajectories, the ``"done"`` signal will be used, which is recovered via ``done = truncated | terminated``. In other words, it is assumed that any ``truncated`` or ``terminated`` signal is - equivalent to the end of a trajectory. For some datasets from - ``D4RL``, this may not be true. It is up to the user to make - accurate choices regarding this usage of ``split_trajs``. + equivalent to the end of a trajectory. Defaults to ``False``. strict_length (bool, optional): if ``False``, trajectories of length shorter than `slice_len` (or `batch_size // num_slices`) will be diff --git a/torchrl/data/datasets/roboset.py b/torchrl/data/datasets/roboset.py index 6e9a9bb23f7..5ea864470f7 100644 --- a/torchrl/data/datasets/roboset.py +++ b/torchrl/data/datasets/roboset.py @@ -58,7 +58,7 @@ class RobosetExperienceReplay(TensorDictReplayBuffer): sampler (Sampler, optional): the sampler to be used. If none is provided a default RandomSampler() will be used. writer (Writer, optional): the writer to be used. If none is provided - a default RoundRobinWriter() will be used. + a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -67,15 +67,13 @@ class RobosetExperienceReplay(TensorDictReplayBuffer): prefetch (int, optional): number of next batches to be prefetched using multithreading. transform (Transform, optional): Transform to be executed when sample() is called. - To chain transforms use the :obj:`Compose` class. + To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class. split_trajs (bool, optional): if ``True``, the trajectories will be split along the first dimension and padded to have a matching shape. To split the trajectories, the ``"done"`` signal will be used, which is recovered via ``done = truncated | terminated``. In other words, it is assumed that any ``truncated`` or ``terminated`` signal is - equivalent to the end of a trajectory. For some datasets from - ``D4RL``, this may not be true. It is up to the user to make - accurate choices regarding this usage of ``split_trajs``. + equivalent to the end of a trajectory. Defaults to ``False``. Attributes: diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py index 815a00ca687..bb1a15ddba3 100644 --- a/torchrl/data/datasets/vd4rl.py +++ b/torchrl/data/datasets/vd4rl.py @@ -66,7 +66,7 @@ class VD4RLExperienceReplay(TensorDictReplayBuffer): sampler (Sampler, optional): the sampler to be used. If none is provided a default RandomSampler() will be used. writer (Writer, optional): the writer to be used. If none is provided - a default RoundRobinWriter() will be used. + a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -75,7 +75,7 @@ class VD4RLExperienceReplay(TensorDictReplayBuffer): prefetch (int, optional): number of next batches to be prefetched using multithreading. transform (Transform, optional): Transform to be executed when sample() is called. - To chain transforms use the :obj:`Compose` class. + To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class. split_trajs (bool, optional): if ``True``, the trajectories will be split along the first dimension and padded to have a matching shape. To split the trajectories, the ``"done"`` signal will be used, which From c6787e576125bca47214451512b308fdaf0637ff Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 14:16:13 +0000 Subject: [PATCH 07/17] amend --- torchrl/data/datasets/__init__.py | 2 +- torchrl/data/datasets/atari_dqn.py | 297 +++++++++++++----- torchrl/data/replay_buffers/replay_buffers.py | 45 ++- torchrl/data/replay_buffers/samplers.py | 123 ++++++-- 4 files changed, 349 insertions(+), 118 deletions(-) diff --git a/torchrl/data/datasets/__init__.py b/torchrl/data/datasets/__init__.py index b06ffe1c0a1..092b80083a1 100644 --- a/torchrl/data/datasets/__init__.py +++ b/torchrl/data/datasets/__init__.py @@ -1,3 +1,4 @@ +from .atari_dqn import AtariDQNExperienceReplay from .d4rl import D4RLExperienceReplay from .gen_dgrl import GenDGRLExperienceReplay from .minari_data import MinariExperienceReplay @@ -5,4 +6,3 @@ from .openx import OpenXExperienceReplay from .roboset import RobosetExperienceReplay from .vd4rl import VD4RLExperienceReplay -from .atari_dqn import AtariDQNExperienceReplay diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py index 1d99b6fb051..b6c189f862b 100644 --- a/torchrl/data/datasets/atari_dqn.py +++ b/torchrl/data/datasets/atari_dqn.py @@ -8,26 +8,23 @@ import gzip import io import json -import pathlib import shutil -import mmap import os import subprocess from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter import tempfile -import time from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor from pathlib import Path import numpy as np import torch -import tqdm -from tensordict import NonTensorData, TensorDict, MemoryMappedTensor -from tensordict.utils import expand_right +from tensordict import TensorDict, MemoryMappedTensor -from torchrl.data import LazyMemmapStorage, Storage, TensorDictReplayBuffer +from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer +from torchrl.data.replay_buffers.storages import Storage +from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement, \ + SliceSampler, SliceSamplerWithoutReplacement from torchrl.envs.utils import _classproperty from torch import multiprocessing as mp @@ -74,17 +71,133 @@ class AtariDQNExperienceReplay(TensorDictReplayBuffer): using multithreading. transform (Transform, optional): Transform to be executed when sample() is called. To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class. + num_slices (int, optional): the number of slices to be sampled. The batch-size + must be greater or equal to the ``num_slices`` argument. Exclusive + with ``slice_len``. Defaults to ``None`` (no slice sampling). + The ``sampler`` arg will override this value. + slice_len (int, optional): the length of the slices to be sampled. The batch-size + must be greater or equal to the ``slice_len`` argument and divisible + by it. Exclusive with ``num_slices``. Defaults to ``None`` (no slice sampling). + The ``sampler`` arg will override this value. + strict_length (bool, optional): if ``False``, trajectories of length + shorter than `slice_len` (or `batch_size // num_slices`) will be + allowed to appear in the batch. + Be mindful that this can result in effective `batch_size` shorter + than the one asked for! Trajectories can be split using + :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``. + The ``sampler`` arg will override this value. + replacement (bool, optional): if ``False``, sampling will occur without replacement. + The ``sampler`` arg will override this value. + + Attributes: + available_datasets: list of available datasets, formatted as `/`. Example: + `"Pong/5"`, `"Krull/2"`, ... + dataset_id (str): the name of the dataset. + episodes (torch.Tensor): a 1d tensor indicating to what run each of the + 1M frames belongs. To be used with :class:`~torchrl.data.replay_buffers.SliceSampler` + to cheaply sample slices of episodes. Examples: >>> from torchrl.data.datasets import AtariDQNExperienceReplay - >>> from torchrl.data.replay_buffers import SliceSampler - >>> sampler = SliceSampler() - >>> dataset = AtariDQNExperienceReplay("Pong/5", batch_size=128, sampler=sampler) + >>> dataset = AtariDQNExperienceReplay("Pong/5", batch_size=128) >>> for data in dataset: ... print(data) ... break + TensorDict( + fields={ + action: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False), + index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False), + metadata: NonTensorData( + data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'info': {'episode': 0, 'path': PosixPath('/Users/vmoens/.cache/torchrl/atari/Pong/5/0')}}, + batch_size=torch.Size([128]), + device=None, + is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False), + observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + reward: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False), + truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([128]), + device=None, + is_shared=False), + observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False), + truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([128]), + device=None, + is_shared=False) + + .. warning:: + Atari-DQN does not provide the next observation after a termination signal. + In other words, there is no way to obtain the ``("next", "observation")`` state + when ``("next", "done")`` is ``True``. This value is filled with 0s but should + not be used in practice. If TorchRL's value estimators (:class:`~torchrl.objectives.values.ValueEstimator`) + are used, this should not be an issue. + + .. note:: + Because the construction of the sampler for episode sampling is slightly + convoluted, we made it easy for users to pass the arguments of the + :class:`~torchrl.data.replay_buffers.SliceSampler` directly to the + ``AtariDQNExperienceReplay`` dataset: any of the ``num_slices`` or + ``slice_len`` arguments will make the sampler an instance of + :class:`~torchrl.data.replay_buffers.SliceSampler`. The ``strict_length`` + can also be passed. - As always, datasets should be composed using :class:`~torchrl.data.replay_buffers.ReplayBufferEnsemble` + >>> from torchrl.data.datasets import AtariDQNExperienceReplay + >>> from torchrl.data.replay_buffers import SliceSampler + >>> dataset = AtariDQNExperienceReplay("Pong/5", batch_size=128, slice_len=64) + >>> for data in dataset: + ... print(data) + ... print(data.get("index")) # indices are in 4 groups of consecutive values + ... break + TensorDict( + fields={ + action: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False), + index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False), + metadata: NonTensorData( + data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'info': {'episode': 0, 'path': PosixPath('/Users/vmoens/.cache/torchrl/atari/Pong/5/0')}}, + batch_size=torch.Size([128]), + device=None, + is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + reward: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([128]), + device=None, + is_shared=False), + observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False), + truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([128]), + device=None, + is_shared=False) + tensor([2657628, 2657629, 2657630, 2657631, 2657632, 2657633, 2657634, 2657635, + 2657636, 2657637, 2657638, 2657639, 2657640, 2657641, 2657642, 2657643, + 2657644, 2657645, 2657646, 2657647, 2657648, 2657649, 2657650, 2657651, + 2657652, 2657653, 2657654, 2657655, 2657656, 2657657, 2657658, 2657659, + 2657660, 2657661, 2657662, 2657663, 2657664, 2657665, 2657666, 2657667, + 2657668, 2657669, 2657670, 2657671, 2657672, 2657673, 2657674, 2657675, + 2657676, 2657677, 2657678, 2657679, 2657680, 2657681, 2657682, 2657683, + 2657684, 2657685, 2657686, 2657687, 2657688, 2657689, 2657690, 2657691, + 1995687, 1995688, 1995689, 1995690, 1995691, 1995692, 1995693, 1995694, + 1995695, 1995696, 1995697, 1995698, 1995699, 1995700, 1995701, 1995702, + 1995703, 1995704, 1995705, 1995706, 1995707, 1995708, 1995709, 1995710, + 1995711, 1995712, 1995713, 1995714, 1995715, 1995716, 1995717, 1995718, + 1995719, 1995720, 1995721, 1995722, 1995723, 1995724, 1995725, 1995726, + 1995727, 1995728, 1995729, 1995730, 1995731, 1995732, 1995733, 1995734, + 1995735, 1995736, 1995737, 1995738, 1995739, 1995740, 1995741, 1995742, + 1995743, 1995744, 1995745, 1995746, 1995747, 1995748, 1995749, 1995750]) + + .. note:: + As always, datasets should be composed using :class:`~torchrl.data.replay_buffers.ReplayBufferEnsemble` """ @_classproperty @@ -145,10 +258,10 @@ def available_datasets(cls): tmpdir = "/Users/vmoens/.cache/atari_root" - # use _max_episodes for debugging, avoids downloading the entire dataset - _max_episodes = 4 + # use _max_runs for debugging, avoids downloading the entire dataset + _max_runs = None - def __init__(self, dataset_id:str, batch_size:int|None=None, *, root: str | Path | None = None, download: bool|str=True, sampler=None, writer=None, transform: "Transform" | None=None, num_procs: int=0, **kwargs): + def __init__(self, dataset_id:str, batch_size:int|None=None, *, root: str | Path | None = None, download: bool|str=True, sampler=None, writer=None, transform: "Transform" | None=None, num_procs: int=0, num_slices: int|None=None, slice_len: int|None=None, strict_len: bool=True, replacement: bool=True, **kwargs): if dataset_id not in self.available_datasets: raise ValueError("The dataseet_id is not part of the available datasets. The dataset should be named / " "where is one of the Atari 2600 games and the run is a number betweeen 1 and 5. " @@ -170,8 +283,21 @@ def __init__(self, dataset_id:str, batch_size:int|None=None, *, root: str | Path storage = _AtariStorage(self.dataset_path) if writer is None: writer = ImmutableDatasetWriter() + if sampler is None: + if num_slices is not None or slice_len is not None: + if not replacement: + sampler = SliceSamplerWithoutReplacement(num_slices=num_slices, slice_len=slice_len, trajectories=storage.episodes) + else: + sampler = SliceSampler(num_slices=num_slices, slice_len=slice_len, trajectories=storage.episodes, cache_values=True) + elif not replacement: + sampler = SamplerWithoutReplacement() + super().__init__(storage=storage, batch_size=batch_size, writer=writer, sampler=sampler, collate_fn=lambda x: x, transform=transform, **kwargs) + @property + def episodes(self): + return self._storage.episodes + @property def root(self)->Path: return self._root @@ -185,7 +311,7 @@ def dataset_path(self) -> Path: def _is_downloaded(self): if os.path.exists(self.dataset_path / "processed.json"): with open(self.dataset_path / "processed.json", "r") as jsonfile: - return json.load(jsonfile).get("processed", False) + return json.load(jsonfile).get("processed", False) == self._max_runs return False def _download_and_preproc(self): @@ -202,7 +328,7 @@ def _download_and_preproc(self): os.makedirs(tempdir, exist_ok=True) if not os.listdir(tempdir): os.makedirs(tempdir, exist_ok=True) - # get the list of episodes + # get the list of runs command = f"gsutil -m ls -R gs://atari-replay-datasets/dqn/Pong/1/replay_logs" output = subprocess.run( command, @@ -210,47 +336,49 @@ def _download_and_preproc(self): ) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) files = [file.decode("utf-8").replace('$', '\$') for file in output.stdout.splitlines() if file.endswith(b'.gz')] - self.remote_gz_files = self._list_episodes(None, files) - total_episodes = list(self.remote_gz_files)[-1] + self.remote_gz_files = self._list_runs(None, files) + total_runs = list(self.remote_gz_files)[-1] if self.num_procs == 0: - for episode, episode_files in self.remote_gz_files.items(): - self._download_and_proc_episode(episode, episode_files, tempdir=tempdir, dataset_path=self.dataset_path, total_episodes=total_episodes) + for run, run_files in self.remote_gz_files.items(): + self._download_and_proc_split(run, run_files, tempdir=tempdir, dataset_path=self.dataset_path, total_episodes=total_runs) else: - func = functools.partial(self._download_and_proc_episode, tempdir=tempdir, dataset_path=self.dataset_path, total_episodes=total_episodes) - args = [(episode, episode_files) for (episode, episode_files) in self.remote_gz_files.items()] + func = functools.partial(self._download_and_proc_split, tempdir=tempdir, dataset_path=self.dataset_path, total_episodes=total_runs) + args = [(run, run_files) for (run, run_files) in self.remote_gz_files.items()] with mp.Pool(self.num_procs) as pool: pool.starmap(func, args) with open(self.dataset_path / "processed.json", "w") as file: - json.dump({"processed": True}, file) + # we save self._max_runs such that changing the number of runs to process + # forces the data to be re-downloaded + json.dump({"processed": self._max_runs}, file) @classmethod - def _download_and_proc_episode(cls, episode, episode_files, *, tempdir, dataset_path, total_episodes): - if cls._max_episodes is not None and episode >= cls._max_episodes: + def _download_and_proc_split(cls, run, run_files, *, tempdir, dataset_path, total_episodes): + if cls._max_runs is not None and run >= cls._max_runs: return tempdir = Path(tempdir) - os.makedirs(tempdir/str(episode)) - files_str = ' '.join(episode_files) # .decode("utf-8") - print("downloading", files_str) - command = f"gsutil -m cp {files_str} {tempdir}/{episode}" + os.makedirs(tempdir / str(run)) + files_str = ' '.join(run_files) # .decode("utf-8") + logging.info("downloading", files_str) + command = f"gsutil -m cp {files_str} {tempdir}/{run}" subprocess.run( command, shell=True ) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - local_gz_files = cls._list_episodes(tempdir/str(episode)) + local_gz_files = cls._list_runs(tempdir / str(run)) # we iterate over the dict but this one has length 1 - for episode in local_gz_files: - path = dataset_path / str(episode) + for run in local_gz_files: + path = dataset_path / str(run) try: - cls._preproc_episode(path, local_gz_files, episode) + cls._preproc_run(path, local_gz_files, run) except Exception: shutil.rmtree(path) raise - shutil.rmtree(tempdir / str(episode)) - print(f'Concluded episode {episode} out of {total_episodes}') + shutil.rmtree(tempdir / str(run)) + logging.info(f'Concluded run {run} out of {total_episodes}') @classmethod - def _preproc_episode(cls, path, gz_files, episode): - files = gz_files[episode] + def _preproc_run(cls, path, gz_files, run): + files = gz_files[run] td = TensorDict({}, []) path = Path(path) for file in files: @@ -286,7 +414,7 @@ def _preproc_episode(cls, path, gz_files, episode): os.makedirs(filename.parent, exist_ok=True) mmap = MemoryMappedTensor.from_tensor(t, filename=filename) td[key] = mmap - td.set_non_tensor("info", {"episode": episode, "path": path}) + td.set_non_tensor("dataset_id", "/".join(path.parts[-3:-1])) td.memmap_(path, copy_existing=False) @staticmethod @@ -302,7 +430,7 @@ def _process_name(name): return key @classmethod - def _list_episodes(cls, download_path, gz_files=None): + def _list_runs(cls, download_path, gz_files=None): path = download_path if gz_files is None: gz_files = [] @@ -310,13 +438,13 @@ def _list_episodes(cls, download_path, gz_files=None): for file in files: if file.endswith(".gz"): gz_files.append(os.path.join(root, file)) - episodes = defaultdict(list) + runs = defaultdict(list) for file in gz_files: filename = Path(file).parts[-1] name, episode, extension = str(filename).split(".") episode = int(episode) - episodes[episode].append(file) - return dict(sorted(episodes.items(), key=lambda x: x[0])) + runs[episode].append(file) + return dict(sorted(runs.items(), key=lambda x: x[0])) class _AtariStorage(Storage): @@ -328,47 +456,54 @@ def get_folders(path): os.path.isdir(os.path.join(path, name))] # Usage - self.episodes = [] + self.splits = [] folders = get_folders(path) for folder in folders: - self.episodes.append(int(Path(folder).parts[-1])) - self._episode_tds = [] - frames_per_ep = {} - for episode in self.episodes: - path = self.path / str(episode) - self._episode_tds.append(self._load_episode(path)) + self.splits.append(int(Path(folder).parts[-1])) + self.splits = sorted(self.splits) + self._split_tds = [] + frames_per_split = {} + for split in self.splits: + path = self.path / str(split) + self._split_tds.append(self._load_split(path)) # take away 1 because we padded with 1 empty val - frames_per_ep[episode] = self._episode_tds[-1].get(("data", "observation")).shape[0] - 1 + frames_per_split[split] = self._split_tds[-1].get(("data", "observation")).shape[0] - 1 - frames_per_ep = torch.tensor([[episode, length] for (episode, length) in frames_per_ep.items()]) - frames_per_ep[:, 1] = frames_per_ep[:, 1].cumsum(0) - self.frames_per_ep = torch.cat([torch.tensor([[-1, 0]]), frames_per_ep], 0) + frames_per_split = torch.tensor([[split, length] for (split, length) in frames_per_split.items()]) + frames_per_split[:, 1] = frames_per_split[:, 1].cumsum(0) + self.frames_per_split = torch.cat([torch.tensor([[-1, 0]]), frames_per_split], 0) + + # retrieve episodes + self.episodes = torch.cumsum(torch.cat([td.get(("data", "next", "terminated")) for td in self._split_tds], 0), 0) def __len__(self): - return self.frames_per_ep[-1, 1].item() + return self.frames_per_split[-1, 1].item() - def _read_from_episodes(self, item: int | torch.Tensor): + def _read_from_splits(self, item: int | torch.Tensor): # We need to allocate each item to its storage. # We don't assume each storage has the same size (too expensive to test) # so we keep a map of each storage cumulative length and retrieve the # storages one after the other. - episode = (item < self.frames_per_ep[1:, 1].unsqueeze(1)) & (item >= self.frames_per_ep[:-1, 1].unsqueeze(1)) - episode = episode.squeeze().nonzero()[:, 0] - episode = self.frames_per_ep[episode+1, 0] - item = item - self.frames_per_ep[episode, 1] + split = (item < self.frames_per_split[1:, 1].unsqueeze(1)) & (item >= self.frames_per_split[:-1, 1].unsqueeze(1)) + split_tmp, idx = split.squeeze().nonzero().unbind(-1) + split = torch.zeros_like(split_tmp) + split[idx] = split_tmp + split = self.frames_per_split[split + 1, 0] + item = item - self.frames_per_split[split, 1] + assert (item>=0).all() if isinstance(item, int): - unique_episodes = (episode,) - episode_inverse = None + unique_splits = (split,) + split_inverse = None else: - unique_episodes, episode_inverse = torch.unique(episode, return_inverse=True) - unique_episodes = unique_episodes.tolist() + unique_splits, split_inverse = torch.unique(split, return_inverse=True) + unique_splits = unique_splits.tolist() out = [] - for i, episode in enumerate(unique_episodes): - _item = item[episode_inverse == i] if episode_inverse is not None else item - out.append( self._proc_td(self._episode_tds[episode], _item)) + for i, split in enumerate(unique_splits): + _item = item[split_inverse == i] if split_inverse is not None else item + out.append( self._proc_td(self._split_tds[split], _item)) return torch.cat(out, 0) - def _load_episode(self, path): + def _load_split(self, path): return TensorDict.load_memmap(path) def _proc_td(self, td, index): @@ -377,9 +512,6 @@ def _proc_td(self, td, index): done = td_data.get(("next", "terminated"))[index].squeeze(-1).bool() if done.ndim and done.any(): obs_ = torch.index_fill(obs_, 0, done.nonzero().squeeze(), 0) - # obs_ = torch.masked_fill(obs_, done, 0) - # obs_ = torch.masked_fill(obs_, expand_right(done, obs_.shape), 0) - # obs_ = torch.where(~expand_right(done, obs_.shape), obs_, 0) td_idx = td.empty() td_idx.set(("next", "observation"), obs_) non_tensor = td.exclude("data").to_dict() @@ -387,18 +519,27 @@ def _proc_td(self, td, index): if isinstance(index, torch.Tensor): td_idx.batch_size = [len(index)] td_idx.set_non_tensor("metadata", non_tensor) + + terminated = td_idx.get(("next", "terminated")) + zterminated = torch.zeros_like(terminated) + td_idx.set(("next", "done"), terminated.clone()) + td_idx.set(("next", "truncated"), zterminated) + td_idx.set("terminated", zterminated) + td_idx.set("done", zterminated) + td_idx.set("truncated", zterminated) + return td_idx def get(self, index): if isinstance(index, int): - return self._read_from_episodes(index) + return self._read_from_splits(index) if isinstance(index, tuple): if len(index) == 1: return self.get(index[0]) return self.get(index[0])[(Ellipsis, *index[1:])] if isinstance(index, torch.Tensor): if index.ndim <= 1: - return self._read_from_episodes(index) + return self._read_from_splits(index) else: raise RuntimeError("Only 1d tensors are accepted") # with ThreadPoolExecutor(16) as pool: @@ -416,6 +557,10 @@ def get(self, index): if __name__ == '__main__': import logging logging.getLogger().setLevel(logging.INFO) - dataset = AtariDQNExperienceReplay("Pong/5", num_procs=4) - for _ in range(100): - out = dataset[slice(0, 3000000, 10000)] \ No newline at end of file + AtariDQNExperienceReplay._max_runs = 3 + dataset = AtariDQNExperienceReplay("Pong/5", num_procs=4, num_slices=4, batch_size=128, replacement=False) + torch.manual_seed(0) + for i, data in enumerate(dataset): + print(data) + if i == 10: + break diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index de9b13b8129..b17fcf9d2f5 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -157,13 +157,7 @@ def __init__( self._writer = writer if writer is not None else RoundRobinWriter() self._writer.register_storage(self._storage) - self._collate_fn = ( - collate_fn - if collate_fn is not None - else _get_default_collate( - self._storage, _is_tensordict=isinstance(self, TensorDictReplayBuffer) - ) - ) + self._get_collate_fn(collate_fn) self._pin_memory = pin_memory self._prefetch = bool(prefetch) @@ -201,6 +195,43 @@ def __init__( ) self._batch_size = batch_size + def _get_collate_fn(self, collate_fn): + self._collate_fn = ( + collate_fn + if collate_fn is not None + else _get_default_collate( + self._storage, _is_tensordict=isinstance(self, TensorDictReplayBuffer) + ) + ) + + def set_storage(self, storage: Storage, collate_fn: Callable | None = None): + """Sets a new storage in the replay buffer and returns the previous storage. + + Args: + collate_fn (callable, optional): if provided, the collate_fn is set to this + value. Otherwise it is reset to a default value. + + """ + + prev_storage = self._storage + self._storage = storage + self._get_collate_fn(collate_fn) + + return prev_storage + + def set_writer(self, writer: Writer): + """Sets a new writer in the replay buffer and returns the previous writer.""" + prev_writer = self._writer + self._writer = writer + self._writer.register_storage(self._storage) + return prev_writer + + def set_sampler(self, sampler: Sampler): + """Sets a new sampler in the replay buffer and returns the previous sampler.""" + prev_sampler = self._sampler + self._sampler = sampler + return prev_sampler + def __len__(self) -> int: with self._replay_lock: return len(self._storage) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 05baa2eaee1..be4ee9c6406 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -524,6 +524,14 @@ class SliceSampler(Sampler): trajectory (or episode). Defaults to ``("next", "done")``. traj_key (NestedKey, optional): the key indicating the trajectories. Defaults to ``"episode"`` (commonly used across datasets in TorchRL). + ends (torch.Tensor, optional): a 1d boolean tensor containing the end of run signals. + To be used whenever the ``end_key`` or ``traj_key`` is expensive to get, + or when this signal is readily available. Must be used with ``cache_values=True`` + and cannot be used in conjunction with ``end_key`` or ``traj_key``. + trajectories (torch.Tensor, optional): a 1d integer tensor containing the run ids. + To be used whenever the ``end_key`` or ``traj_key`` is expensive to get, + or when this signal is readily available. Must be used with ``cache_values=True`` + and cannot be used in conjunction with ``end_key`` or ``traj_key``. cache_values (bool, optional): to be used with static datasets. Will cache the start and end signal of the trajectory. truncated_key (NestedKey, optional): If not ``None``, this argument @@ -612,19 +620,12 @@ def __init__( slice_len: int = None, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, + ends: torch.Tensor | None = None, + trajectories: torch.Tensor | None = None, cache_values: bool = False, truncated_key: NestedKey | None = ("next", "truncated"), strict_length: bool = True, ) -> object: - if end_key is None: - end_key = ("next", "done") - if traj_key is None: - traj_key = "episode" - if not ((num_slices is None) ^ (slice_len is None)): - raise TypeError( - "Either num_slices or slice_len must be not None, and not both. " - f"Got num_slices={num_slices} and slice_len={slice_len}." - ) self.num_slices = num_slices self.slice_len = slice_len self.end_key = end_key @@ -635,6 +636,47 @@ def __init__( self._uses_data_prefix = False self.strict_length = strict_length self._cache = {} + if trajectories is not None: + if traj_key is not None or end_key: + raise RuntimeError( + "`trajectories` and `end_key` or `traj_key` are exclusive arguments." + ) + if ends is not None: + raise RuntimeError("trajectories and ends are exclusive arguments.") + if not cache_values: + raise RuntimeError( + "To be used, trajectories requires `cache_values` to be set to `True`." + ) + vals = self._find_start_stop_traj(trajectory=trajectories) + self._cache["stop-and-length"] = vals + + elif ends is not None: + if traj_key is not None or end_key: + raise RuntimeError( + "`ends` and `end_key` or `traj_key` are exclusive arguments." + ) + if trajectories is not None: + raise RuntimeError("trajectories and ends are exclusive arguments.") + if not cache_values: + raise RuntimeError( + "To be used, ends requires `cache_values` to be set to `True`." + ) + vals = self._find_start_stop_traj(end=ends) + self._cache["stop-and-length"] = vals + + else: + if end_key is None: + end_key = ("next", "done") + if traj_key is None: + traj_key = "run" + self.end_key = end_key + self.traj_key = traj_key + + if not ((num_slices is None) ^ (slice_len is None)): + raise TypeError( + "Either num_slices or slice_len must be not None, and not both. " + f"Got num_slices={num_slices} and slice_len={slice_len}." + ) @staticmethod def _find_start_stop_traj(*, trajectory=None, end=None): @@ -696,16 +738,19 @@ def _get_stop_and_length(self, storage, fallback=True): # In the future, this may be deprecated, and we don't want to mess # with the keys provided by the user so we fall back on a proxy to # the traj key. - try: - trajectory = storage._storage.get(self._used_traj_key) - except KeyError: - trajectory = storage._storage.get(("_data", self.traj_key)) - # cache that value for future use - self._used_traj_key = ("_data", self.traj_key) - self._uses_data_prefix = ( - isinstance(self._used_traj_key, tuple) - and self._used_traj_key[0] == "_data" - ) + if isinstance(storage, TensorStorage): + try: + trajectory = storage._storage.get(self._used_traj_key) + except KeyError: + trajectory = storage._storage.get(("_data", self.traj_key)) + # cache that value for future use + self._used_traj_key = ("_data", self.traj_key) + self._uses_data_prefix = ( + isinstance(self._used_traj_key, tuple) + and self._used_traj_key[0] == "_data" + ) + else: + trajectory = storage[:].get(self.traj_key) vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)]) if self.cache_values: self._cache["stop-and-length"] = vals @@ -722,16 +767,19 @@ def _get_stop_and_length(self, storage, fallback=True): # In the future, this may be deprecated, and we don't want to mess # with the keys provided by the user so we fall back on a proxy to # the traj key. - try: - done = storage._storage.get(self._used_end_key) - except KeyError: - done = storage._storage.get(("_data", self.end_key)) - # cache that value for future use - self._used_end_key = ("_data", self.end_key) - self._uses_data_prefix = ( - isinstance(self._used_end_key, tuple) - and self._used_end_key[0] == "_data" - ) + if isinstance(storage, TensorStorage): + try: + done = storage._storage.get(self._used_end_key) + except KeyError: + done = storage._storage.get(("_data", self.end_key)) + # cache that value for future use + self._used_end_key = ("_data", self.end_key) + self._uses_data_prefix = ( + isinstance(self._used_end_key, tuple) + and self._used_end_key[0] == "_data" + ) + else: + done = storage[:].get(self.end_key) vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)] if self.cache_values: self._cache["stop-and-length"] = vals @@ -760,11 +808,6 @@ def _adjusted_batch_size(self, batch_size): return seq_length, num_slices def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]: - if not isinstance(storage, TensorStorage): - raise RuntimeError( - f"{type(self)} can only sample from TensorStorage subclasses, got {type(storage)} instead." - ) - # pick up as many trajs as we need start_idx, stop_idx, lengths = self._get_stop_and_length(storage) seq_length, num_slices = self._adjusted_batch_size(batch_size) @@ -889,6 +932,14 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): trajectory (or episode). Defaults to ``("next", "done")``. traj_key (NestedKey, optional): the key indicating the trajectories. Defaults to ``"episode"`` (commonly used across datasets in TorchRL). + ends (torch.Tensor, optional): a 1d boolean tensor containing the end of run signals. + To be used whenever the ``end_key`` or ``traj_key`` is expensive to get, + or when this signal is readily available. Must be used with ``cache_values=True`` + and cannot be used in conjunction with ``end_key`` or ``traj_key``. + trajectories (torch.Tensor, optional): a 1d integer tensor containing the run ids. + To be used whenever the ``end_key`` or ``traj_key`` is expensive to get, + or when this signal is readily available. Must be used with ``cache_values=True`` + and cannot be used in conjunction with ``end_key`` or ``traj_key``. truncated_key (NestedKey, optional): If not ``None``, this argument indicates where a truncated signal should be written in the output data. This is used to indicate to value estimators where the provided @@ -973,6 +1024,8 @@ def __init__( drop_last: bool = False, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, + ends: torch.Tensor | None = None, + trajectories: torch.Tensor | None = None, truncated_key: NestedKey | None = ("next", "truncated"), strict_length: bool = True, shuffle: bool = True, @@ -986,6 +1039,8 @@ def __init__( cache_values=True, truncated_key=truncated_key, strict_length=strict_length, + ends=ends, + trajectories=trajectories, ) SamplerWithoutReplacement.__init__(self, drop_last=drop_last, shuffle=shuffle) From 943d85667c31cdc52870b78d3e26b7db64f880f6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 14:19:36 +0000 Subject: [PATCH 08/17] doc --- docs/source/reference/data.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 90dbe4f3d4e..2def1b4bfa8 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -280,7 +280,7 @@ Here's an example: :toctree: generated/ :template: rl_template.rst - + AtariDQNExperienceReplay D4RLExperienceReplay GenDGRLExperienceReplay MinariExperienceReplay From 23b36e5938a49c6824bc400c147ec65c43de5658 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 14:54:57 +0000 Subject: [PATCH 09/17] amend --- torchrl/data/datasets/atari_dqn.py | 334 ++++++++++++++++++++++++----- 1 file changed, 282 insertions(+), 52 deletions(-) diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py index b6c189f862b..b0fd3f83357 100644 --- a/torchrl/data/datasets/atari_dqn.py +++ b/torchrl/data/datasets/atari_dqn.py @@ -8,25 +8,30 @@ import gzip import io import json -import shutil +import logging import os +import shutil import subprocess -from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter import tempfile from collections import defaultdict from pathlib import Path import numpy as np import torch -from tensordict import TensorDict, MemoryMappedTensor +from tensordict import MemoryMappedTensor, TensorDict +from torch import multiprocessing as mp from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import ( + SamplerWithoutReplacement, + SliceSampler, + SliceSamplerWithoutReplacement, +) from torchrl.data.replay_buffers.storages import Storage -from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement, \ - SliceSampler, SliceSamplerWithoutReplacement +from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter from torchrl.envs.utils import _classproperty -from torch import multiprocessing as mp + class AtariDQNExperienceReplay(TensorDictReplayBuffer): """Atari DQN Experience replay class. @@ -55,6 +60,9 @@ class AtariDQNExperienceReplay(TensorDictReplayBuffer): The actual dataset memory-mapped files will be saved under `/`. If none is provided, it defaults to ``~/.cache/torchrl/atari`. + num_procs (int, optional): number of processes to launch for preprocessing. + Has no effect whenever the data is already downloaded. Defaults to 0 + (no multiprocessing used). download (bool or str, optional): Whether the dataset should be downloaded if not found. Defaults to ``True``. Download can also be passed as "force", in which case the downloaded data will be overwritten. @@ -109,7 +117,7 @@ class AtariDQNExperienceReplay(TensorDictReplayBuffer): done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False), index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False), metadata: NonTensorData( - data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'info': {'episode': 0, 'path': PosixPath('/Users/vmoens/.cache/torchrl/atari/Pong/5/0')}}, + data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'}}, batch_size=torch.Size([128]), device=None, is_shared=False), @@ -159,7 +167,7 @@ class AtariDQNExperienceReplay(TensorDictReplayBuffer): done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False), index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False), metadata: NonTensorData( - data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'info': {'episode': 0, 'path': PosixPath('/Users/vmoens/.cache/torchrl/atari/Pong/5/0')}}, + data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'}}, batch_size=torch.Size([128]), device=None, is_shared=False), @@ -197,9 +205,131 @@ class AtariDQNExperienceReplay(TensorDictReplayBuffer): 1995743, 1995744, 1995745, 1995746, 1995747, 1995748, 1995749, 1995750]) .. note:: - As always, datasets should be composed using :class:`~torchrl.data.replay_buffers.ReplayBufferEnsemble` + As always, datasets should be composed using :class:`~torchrl.data.replay_buffers.ReplayBufferEnsemble`: + + >>> from torchrl.data.datasets import AtariDQNExperienceReplay + >>> from torchrl.data.replay_buffers import ReplayBufferEnsemble + >>> # we change this parameter for quick experimentation, in practice it should be left untouched + >>> AtariDQNExperienceReplay._max_runs = 2 + >>> dataset_asterix = AtariDQNExperienceReplay("Asterix/5", batch_size=128, slice_len=64, num_procs=4) + >>> dataset_pong = AtariDQNExperienceReplay("Pong/5", batch_size=128, slice_len=64, num_procs=4) + >>> dataset = ReplayBufferEnsemble(dataset_pong, dataset_asterix, batch_size=128, sample_from_all=True) + >>> sample = dataset.sample() + >>> print("first sample, Asterix", sample[0]) + first sample, Asterix TensorDict( + fields={ + action: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False), + index: TensorDict( + fields={ + buffer_ids: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False), + index: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([64]), + device=None, + is_shared=False), + metadata: NonTensorData( + data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'}, + batch_size=torch.Size([64]), + device=None, + is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + reward: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([64]), + device=None, + is_shared=False), + observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + terminated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False), + truncated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([64]), + device=None, + is_shared=False) + >>> print("second sample, Pong", sample[1]) + second sample, Pong TensorDict( + fields={ + action: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False), + index: TensorDict( + fields={ + buffer_ids: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False), + index: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([64]), + device=None, + is_shared=False), + metadata: NonTensorData( + data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Asterix/5'}, + batch_size=torch.Size([64]), + device=None, + is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + reward: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([64]), + device=None, + is_shared=False), + observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + terminated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False), + truncated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([64]), + device=None, + is_shared=False) + >>> print("Aggregate (metadata hidden)", sample) + Aggregate (metadata hidden) LazyStackedTensorDict( + fields={ + action: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False), + index: LazyStackedTensorDict( + fields={ + buffer_ids: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False), + index: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([2, 64]), + device=None, + is_shared=False, + stack_dim=0), + metadata: LazyStackedTensorDict( + fields={ + }, + exclusive_fields={ + }, + batch_size=torch.Size([2, 64]), + device=None, + is_shared=False, + stack_dim=0), + next: LazyStackedTensorDict( + fields={ + done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([2, 64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + reward: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([2, 64]), + device=None, + is_shared=False, + stack_dim=0), + observation: Tensor(shape=torch.Size([2, 64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + terminated: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False), + truncated: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([2, 64]), + device=None, + is_shared=False, + stack_dim=0) """ + @_classproperty def available_datasets(cls): games = [ @@ -261,13 +391,32 @@ def available_datasets(cls): # use _max_runs for debugging, avoids downloading the entire dataset _max_runs = None - def __init__(self, dataset_id:str, batch_size:int|None=None, *, root: str | Path | None = None, download: bool|str=True, sampler=None, writer=None, transform: "Transform" | None=None, num_procs: int=0, num_slices: int|None=None, slice_len: int|None=None, strict_len: bool=True, replacement: bool=True, **kwargs): + def __init__( + self, + dataset_id: str, + batch_size: int | None = None, + *, + root: str | Path | None = None, + download: bool | str = True, + sampler=None, + writer=None, + transform: "Transform" | None = None, + num_procs: int = 0, + num_slices: int | None = None, + slice_len: int | None = None, + strict_len: bool = True, + replacement: bool = True, + **kwargs, + ): if dataset_id not in self.available_datasets: - raise ValueError("The dataseet_id is not part of the available datasets. The dataset should be named / " - "where is one of the Atari 2600 games and the run is a number betweeen 1 and 5. " - "The full list of accepted dataset_ids is available under AtariDQNExperienceReplay.available_datasets.") + raise ValueError( + "The dataseet_id is not part of the available datasets. The dataset should be named / " + "where is one of the Atari 2600 games and the run is a number betweeen 1 and 5. " + "The full list of accepted dataset_ids is available under AtariDQNExperienceReplay.available_datasets." + ) self.dataset_id = dataset_id from torchrl.data.datasets.utils import _get_root_dir + if root is None: root = _get_root_dir("atari") self.root = root @@ -286,27 +435,47 @@ def __init__(self, dataset_id:str, batch_size:int|None=None, *, root: str | Path if sampler is None: if num_slices is not None or slice_len is not None: if not replacement: - sampler = SliceSamplerWithoutReplacement(num_slices=num_slices, slice_len=slice_len, trajectories=storage.episodes) + sampler = SliceSamplerWithoutReplacement( + num_slices=num_slices, + slice_len=slice_len, + trajectories=storage.episodes, + ) else: - sampler = SliceSampler(num_slices=num_slices, slice_len=slice_len, trajectories=storage.episodes, cache_values=True) + sampler = SliceSampler( + num_slices=num_slices, + slice_len=slice_len, + trajectories=storage.episodes, + cache_values=True, + ) elif not replacement: sampler = SamplerWithoutReplacement() - super().__init__(storage=storage, batch_size=batch_size, writer=writer, sampler=sampler, collate_fn=lambda x: x, transform=transform, **kwargs) + super().__init__( + storage=storage, + batch_size=batch_size, + writer=writer, + sampler=sampler, + collate_fn=lambda x: x, + transform=transform, + **kwargs, + ) @property def episodes(self): return self._storage.episodes @property - def root(self)->Path: + def root(self) -> Path: return self._root + @root.setter def root(self, value): self._root = Path(value) + @property def dataset_path(self) -> Path: return self._root / self.dataset_id + @property def _is_downloaded(self): if os.path.exists(self.dataset_path / "processed.json"): @@ -315,7 +484,9 @@ def _is_downloaded(self): return False def _download_and_preproc(self): - logging.info(f"Downloading and preprocessing dataset {self.dataset_id} with {self.num_procs} processes. This may take a while...") + logging.info( + f"Downloading and preprocessing dataset {self.dataset_id} with {self.num_procs} processes. This may take a while..." + ) if os.path.exists(self.dataset_path): shutil.rmtree(self.dataset_path) with tempfile.TemporaryDirectory() as tempdir: @@ -329,21 +500,39 @@ def _download_and_preproc(self): if not os.listdir(tempdir): os.makedirs(tempdir, exist_ok=True) # get the list of runs - command = f"gsutil -m ls -R gs://atari-replay-datasets/dqn/Pong/1/replay_logs" + command = f"gsutil -m ls -R gs://atari-replay-datasets/dqn/{self.dataset_id}/replay_logs" output = subprocess.run( - command, - shell=True, capture_output=True + command, shell=True, capture_output=True ) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - files = [file.decode("utf-8").replace('$', '\$') for file in output.stdout.splitlines() if - file.endswith(b'.gz')] + files = [ + file.decode("utf-8").replace("$", "\$") + for file in output.stdout.splitlines() + if file.endswith(b".gz") + ] self.remote_gz_files = self._list_runs(None, files) total_runs = list(self.remote_gz_files)[-1] if self.num_procs == 0: for run, run_files in self.remote_gz_files.items(): - self._download_and_proc_split(run, run_files, tempdir=tempdir, dataset_path=self.dataset_path, total_episodes=total_runs) + self._download_and_proc_split( + run, + run_files, + tempdir=tempdir, + dataset_path=self.dataset_path, + total_episodes=total_runs, + max_runs=self._max_runs, + ) else: - func = functools.partial(self._download_and_proc_split, tempdir=tempdir, dataset_path=self.dataset_path, total_episodes=total_runs) - args = [(run, run_files) for (run, run_files) in self.remote_gz_files.items()] + func = functools.partial( + self._download_and_proc_split, + tempdir=tempdir, + dataset_path=self.dataset_path, + total_episodes=total_runs, + max_runs=self._max_runs, + ) + args = [ + (run, run_files) + for (run, run_files) in self.remote_gz_files.items() + ] with mp.Pool(self.num_procs) as pool: pool.starmap(func, args) with open(self.dataset_path / "processed.json", "w") as file: @@ -352,17 +541,18 @@ def _download_and_preproc(self): json.dump({"processed": self._max_runs}, file) @classmethod - def _download_and_proc_split(cls, run, run_files, *, tempdir, dataset_path, total_episodes): - if cls._max_runs is not None and run >= cls._max_runs: + def _download_and_proc_split( + cls, run, run_files, *, tempdir, dataset_path, total_episodes, max_runs + ): + if (max_runs is not None) and (run >= max_runs): return tempdir = Path(tempdir) os.makedirs(tempdir / str(run)) - files_str = ' '.join(run_files) # .decode("utf-8") + files_str = " ".join(run_files) # .decode("utf-8") logging.info("downloading", files_str) command = f"gsutil -m cp {files_str} {tempdir}/{run}" subprocess.run( - command, - shell=True + command, shell=True ) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) local_gz_files = cls._list_runs(tempdir / str(run)) # we iterate over the dict but this one has length 1 @@ -374,7 +564,7 @@ def _download_and_proc_split(cls, run, run_files, *, tempdir, dataset_path, tota shutil.rmtree(path) raise shutil.rmtree(tempdir / str(run)) - logging.info(f'Concluded run {run} out of {total_episodes}') + logging.info(f"Concluded run {run} out of {total_episodes}") @classmethod def _preproc_run(cls, path, gz_files, run): @@ -400,7 +590,11 @@ def _preproc_run(cls, path, gz_files, run): td[key] = mmap # td["data", "next", key[1:]] = mmap[1:] else: - if key in (("data", "reward"), ("data", "done"), ("data", "terminated")): + if key in ( + ("data", "reward"), + ("data", "done"), + ("data", "terminated"), + ): filename = path / "data" / "next" / (key[-1] + ".memmap") os.makedirs(filename.parent, exist_ok=True) mmap = MemoryMappedTensor.from_tensor(t, filename=filename) @@ -452,8 +646,11 @@ def __init__(self, path): self.path = Path(path) def get_folders(path): - return [name for name in os.listdir(path) if - os.path.isdir(os.path.join(path, name))] + return [ + name + for name in os.listdir(path) + if os.path.isdir(os.path.join(path, name)) + ] # Usage self.splits = [] @@ -467,14 +664,25 @@ def get_folders(path): path = self.path / str(split) self._split_tds.append(self._load_split(path)) # take away 1 because we padded with 1 empty val - frames_per_split[split] = self._split_tds[-1].get(("data", "observation")).shape[0] - 1 + frames_per_split[split] = ( + self._split_tds[-1].get(("data", "observation")).shape[0] - 1 + ) - frames_per_split = torch.tensor([[split, length] for (split, length) in frames_per_split.items()]) + frames_per_split = torch.tensor( + [[split, length] for (split, length) in frames_per_split.items()] + ) frames_per_split[:, 1] = frames_per_split[:, 1].cumsum(0) - self.frames_per_split = torch.cat([torch.tensor([[-1, 0]]), frames_per_split], 0) + self.frames_per_split = torch.cat( + [torch.tensor([[-1, 0]]), frames_per_split], 0 + ) # retrieve episodes - self.episodes = torch.cumsum(torch.cat([td.get(("data", "next", "terminated")) for td in self._split_tds], 0), 0) + self.episodes = torch.cumsum( + torch.cat( + [td.get(("data", "next", "terminated")) for td in self._split_tds], 0 + ), + 0, + ) def __len__(self): return self.frames_per_split[-1, 1].item() @@ -484,13 +692,15 @@ def _read_from_splits(self, item: int | torch.Tensor): # We don't assume each storage has the same size (too expensive to test) # so we keep a map of each storage cumulative length and retrieve the # storages one after the other. - split = (item < self.frames_per_split[1:, 1].unsqueeze(1)) & (item >= self.frames_per_split[:-1, 1].unsqueeze(1)) + split = (item < self.frames_per_split[1:, 1].unsqueeze(1)) & ( + item >= self.frames_per_split[:-1, 1].unsqueeze(1) + ) split_tmp, idx = split.squeeze().nonzero().unbind(-1) split = torch.zeros_like(split_tmp) split[idx] = split_tmp split = self.frames_per_split[split + 1, 0] item = item - self.frames_per_split[split, 1] - assert (item>=0).all() + assert (item >= 0).all() if isinstance(item, int): unique_splits = (split,) split_inverse = None @@ -500,7 +710,7 @@ def _read_from_splits(self, item: int | torch.Tensor): out = [] for i, split in enumerate(unique_splits): _item = item[split_inverse == i] if split_inverse is not None else item - out.append( self._proc_td(self._split_tds[split], _item)) + out.append(self._proc_td(self._split_tds[split], _item)) return torch.cat(out, 0) def _load_split(self, path): @@ -554,13 +764,33 @@ def get(self, index): return self.get(torch.arange(start, stop, step)) return self[torch.arange(len(self))[index]] -if __name__ == '__main__': - import logging - logging.getLogger().setLevel(logging.INFO) - AtariDQNExperienceReplay._max_runs = 3 - dataset = AtariDQNExperienceReplay("Pong/5", num_procs=4, num_slices=4, batch_size=128, replacement=False) - torch.manual_seed(0) - for i, data in enumerate(dataset): - print(data) - if i == 10: - break + +if __name__ == "__main__": + # import logging + # logging.getLogger().setLevel(logging.INFO) + # AtariDQNExperienceReplay._max_runs = 3 + # dataset = AtariDQNExperienceReplay("Pong/5", num_procs=4, num_slices=4, batch_size=128, replacement=False) + # torch.manual_seed(0) + # for i, data in enumerate(dataset): + # print(data) + # if i == 10: + # break + + from torchrl.data.datasets import AtariDQNExperienceReplay + from torchrl.data.replay_buffers import ReplayBufferEnsemble + + # we change this parameter for quick experimentation, in practice it should be left untouched + AtariDQNExperienceReplay._max_runs = 2 + dataset_asterix = AtariDQNExperienceReplay( + "Asterix/5", batch_size=128, slice_len=64, num_procs=4 + ) + dataset_pong = AtariDQNExperienceReplay( + "Pong/5", batch_size=128, slice_len=64, num_procs=4 + ) + dataset = ReplayBufferEnsemble( + dataset_pong, dataset_asterix, batch_size=128, sample_from_all=True + ) + sample = dataset.sample() + print("first sample, Asterix", sample[0]) + print("second sample, Pong", sample[1]) + print("Aggregate (metadata hidden)", sample) From fa8ed5a4bd7b469f361d2f1de03791f2836ce9c9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 14:58:46 +0000 Subject: [PATCH 10/17] amend --- torchrl/data/datasets/atari_dqn.py | 42 ++----------------- torchrl/data/replay_buffers/replay_buffers.py | 2 +- 2 files changed, 4 insertions(+), 40 deletions(-) diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py index b0fd3f83357..cde6db8fe80 100644 --- a/torchrl/data/datasets/atari_dqn.py +++ b/torchrl/data/datasets/atari_dqn.py @@ -400,7 +400,7 @@ def __init__( download: bool | str = True, sampler=None, writer=None, - transform: "Transform" | None = None, + transform: "Transform" | None = None, # noqa: F821 num_procs: int = 0, num_slices: int | None = None, slice_len: int | None = None, @@ -492,11 +492,6 @@ def _download_and_preproc(self): with tempfile.TemporaryDirectory() as tempdir: if self.tmpdir is not None: tempdir = self.tmpdir - try: - shutil.rmtree(tempdir) - os.makedirs(tempdir, exist_ok=True) - except: - os.makedirs(tempdir, exist_ok=True) if not os.listdir(tempdir): os.makedirs(tempdir, exist_ok=True) # get the list of runs @@ -505,7 +500,7 @@ def _download_and_preproc(self): command, shell=True, capture_output=True ) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) files = [ - file.decode("utf-8").replace("$", "\$") + file.decode("utf-8").replace("$", "\$") # noqa: W605 for file in output.stdout.splitlines() if file.endswith(b".gz") ] @@ -628,7 +623,7 @@ def _list_runs(cls, download_path, gz_files=None): path = download_path if gz_files is None: gz_files = [] - for root, dirs, files in os.walk(path): + for root, _, files in os.walk(path): for file in files: if file.endswith(".gz"): gz_files.append(os.path.join(root, file)) @@ -763,34 +758,3 @@ def get(self, index): step = index.step if index.step is not None else 1 return self.get(torch.arange(start, stop, step)) return self[torch.arange(len(self))[index]] - - -if __name__ == "__main__": - # import logging - # logging.getLogger().setLevel(logging.INFO) - # AtariDQNExperienceReplay._max_runs = 3 - # dataset = AtariDQNExperienceReplay("Pong/5", num_procs=4, num_slices=4, batch_size=128, replacement=False) - # torch.manual_seed(0) - # for i, data in enumerate(dataset): - # print(data) - # if i == 10: - # break - - from torchrl.data.datasets import AtariDQNExperienceReplay - from torchrl.data.replay_buffers import ReplayBufferEnsemble - - # we change this parameter for quick experimentation, in practice it should be left untouched - AtariDQNExperienceReplay._max_runs = 2 - dataset_asterix = AtariDQNExperienceReplay( - "Asterix/5", batch_size=128, slice_len=64, num_procs=4 - ) - dataset_pong = AtariDQNExperienceReplay( - "Pong/5", batch_size=128, slice_len=64, num_procs=4 - ) - dataset = ReplayBufferEnsemble( - dataset_pong, dataset_asterix, batch_size=128, sample_from_all=True - ) - sample = dataset.sample() - print("first sample, Asterix", sample[0]) - print("second sample, Pong", sample[1]) - print("Aggregate (metadata hidden)", sample) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index b17fcf9d2f5..1e1ce31bf96 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -208,11 +208,11 @@ def set_storage(self, storage: Storage, collate_fn: Callable | None = None): """Sets a new storage in the replay buffer and returns the previous storage. Args: + storage (Storage): the new storage for the buffer. collate_fn (callable, optional): if provided, the collate_fn is set to this value. Otherwise it is reset to a default value. """ - prev_storage = self._storage self._storage = storage self._get_collate_fn(collate_fn) From 583a04da026751d605c52a95da9cf0f7d0b2fac3 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 15:32:34 +0000 Subject: [PATCH 11/17] amend --- .../scripts_ataridqn/environment.yml | 24 ++ .../linux_libs/scripts_ataridqn/install.sh | 51 +++ .../scripts_ataridqn/post_process.sh | 6 + .../scripts_ataridqn/run-clang-format.py | 356 ++++++++++++++++++ .../linux_libs/scripts_ataridqn/run_test.sh | 24 ++ .../linux_libs/scripts_ataridqn/setup_env.sh | 50 +++ test/test_libs.py | 32 +- 7 files changed, 542 insertions(+), 1 deletion(-) create mode 100644 .github/unittest/linux_libs/scripts_ataridqn/environment.yml create mode 100755 .github/unittest/linux_libs/scripts_ataridqn/install.sh create mode 100755 .github/unittest/linux_libs/scripts_ataridqn/post_process.sh create mode 100755 .github/unittest/linux_libs/scripts_ataridqn/run-clang-format.py create mode 100755 .github/unittest/linux_libs/scripts_ataridqn/run_test.sh create mode 100755 .github/unittest/linux_libs/scripts_ataridqn/setup_env.sh diff --git a/.github/unittest/linux_libs/scripts_ataridqn/environment.yml b/.github/unittest/linux_libs/scripts_ataridqn/environment.yml new file mode 100644 index 00000000000..092251f6361 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_ataridqn/environment.yml @@ -0,0 +1,24 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - gsutil + - pip: + - hypothesis + - future + - cloudpickle + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - pytest-rerunfailures + - pytest-error-for-skips + - expecttest + - pyyaml + - scipy + - hydra-core + - tqdm + - h5py + - datasets + - pillow diff --git a/.github/unittest/linux_libs/scripts_ataridqn/install.sh b/.github/unittest/linux_libs/scripts_ataridqn/install.sh new file mode 100755 index 00000000000..1be476425a6 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_ataridqn/install.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. +apt-get update && apt-get install -y git wget gcc g++ +#apt-get update && apt-get install -y git wget freeglut3 freeglut3-dev + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu +else + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 +fi + +# install tensordict +pip install git+https://github.com/pytorch/tensordict.git + +# smoke test +python -c "import functorch;import tensordict" + +printf "* Installing torchrl\n" +python setup.py develop + +# smoke test +python -c "import torchrl" diff --git a/.github/unittest/linux_libs/scripts_ataridqn/post_process.sh b/.github/unittest/linux_libs/scripts_ataridqn/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.github/unittest/linux_libs/scripts_ataridqn/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.github/unittest/linux_libs/scripts_ataridqn/run-clang-format.py b/.github/unittest/linux_libs/scripts_ataridqn/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_ataridqn/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. + +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. + # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/unittest/linux_libs/scripts_ataridqn/run_test.sh b/.github/unittest/linux_libs/scripts_ataridqn/run_test.sh new file mode 100755 index 00000000000..ee7bf9b46b1 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_ataridqn/run_test.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 +ln -s /usr/bin/swig3.0 /usr/bin/swig + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +conda deactivate && conda activate ./env + +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestAtariDQN --error-for-skips --runslow +coverage combine +coverage xml -i diff --git a/.github/unittest/linux_libs/scripts_ataridqn/setup_env.sh b/.github/unittest/linux_libs/scripts_ataridqn/setup_env.sh new file mode 100755 index 00000000000..5b415112814 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_ataridqn/setup_env.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e +set -v + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +apt-get update && apt-get install -y git wget gcc g++ unzip curl + +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 3. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip3 install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune diff --git a/test/test_libs.py b/test/test_libs.py index e034cea84c7..35b787c2739 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -58,7 +58,7 @@ BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, - MultiDiscreteTensorSpec, + MultiDiscreteTensorSpec,ReplayBufferEnsemble, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, @@ -70,6 +70,7 @@ from torchrl.data.datasets.openx import OpenXExperienceReplay from torchrl.data.datasets.roboset import RobosetExperienceReplay from torchrl.data.datasets.vd4rl import VD4RLExperienceReplay +from torchrl.data.datasets.atari_dqn import AtariDQNExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement from torchrl.envs import ( CatTensors, @@ -2489,6 +2490,35 @@ def test_load(self, image_size): break +@pytest.mark.slow +class TestAtariDQN: + @pytest.fixture(scope="class") + def limit_max_runs(self): + prev_val = AtariDQNExperienceReplay._max_runs + AtariDQNExperienceReplay._max_runs = 3 + yield + AtariDQNExperienceReplay._max_runs = prev_val + + @pytest.mark.parametrize("dataset", ["Asterix/1", "Pong/4"]) + @pytest.mark.parametrize("num_slices,slice_len",[[None, None], [None, 8], [2, None]]) + def test_single_dataset(self, dataset, slice_len, num_slices, limit_max_runs): + dataset = AtariDQNExperienceReplay(dataset, slice_len=slice_len, num_slices=num_slices) + sample = dataset.sample(64) + for key in (("next", "observation"), ("next", "truncated"), ("next", "terminated"), ("next", "done"), ("next", "reward"), "observation", "action", "done", "truncated", "terminated"): + assert key in sample.keys(True) + assert sample.shape == (64,) + + @pytest.mark.parametrize("num_slices,slice_len",[[None, None], [None, 8], [2, None]]) + def test_double_dataset(self, slice_len, num_slices, limit_max_runs): + dataset_pong = AtariDQNExperienceReplay("Pong/4", slice_len=slice_len, num_slices=num_slices) + dataset_asterix = AtariDQNExperienceReplay("Asterix/1", slice_len=slice_len, num_slices=num_slices) + dataset = ReplayBufferEnsemble(dataset_pong, dataset_asterix, sample_from_all=True, batch_size=128) + sample = dataset.sample() + assert sample.shape == (2, 64) + assert sample[0]["metadata"]["dataset_id"] == "Pong/4" + assert sample[0]["metadata"]["dataset_id"] == "Asterix/1" + + @pytest.mark.slow class TestOpenX: @pytest.mark.parametrize( From 5bb4995948b6b0b4388196d086673dad2d1c70ca Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 15:41:53 +0000 Subject: [PATCH 12/17] amend --- .github/workflows/test-linux-libs.yml | 26 +++++++++++++++++ test/test_libs.py | 42 +++++++++++++++++++++------ 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index 3b090582e4f..abf78e5e19c 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -16,6 +16,32 @@ concurrency: cancel-in-progress: true jobs: + + unittests-atari-dqn: + strategy: + matrix: + python_version: ["3.9"] + cuda_arch_version: ["12.1"] + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + docker-image: "nvidia/cudagl:11.4.0-base" + timeout: 120 + script: | + set -euo pipefail + export PYTHON_VERSION="3.9" + export CU_VERSION="cu117" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + + bash .github/unittest/linux_libs/scripts_ataridqn/setup_env.sh + bash .github/unittest/linux_libs/scripts_ataridqn/install.sh + bash .github/unittest/linux_libs/scripts_ataridqn/run_test.sh + bash .github/unittest/linux_libs/scripts_ataridqn/post_process.sh + unittests-brax: strategy: matrix: diff --git a/test/test_libs.py b/test/test_libs.py index 35b787c2739..d9b3a200517 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -58,19 +58,20 @@ BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, - MultiDiscreteTensorSpec,ReplayBufferEnsemble, + MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, + ReplayBufferEnsemble, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, ) +from torchrl.data.datasets.atari_dqn import AtariDQNExperienceReplay from torchrl.data.datasets.d4rl import D4RLExperienceReplay from torchrl.data.datasets.minari_data import MinariExperienceReplay from torchrl.data.datasets.openml import OpenMLExperienceReplay from torchrl.data.datasets.openx import OpenXExperienceReplay from torchrl.data.datasets.roboset import RobosetExperienceReplay from torchrl.data.datasets.vd4rl import VD4RLExperienceReplay -from torchrl.data.datasets.atari_dqn import AtariDQNExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement from torchrl.envs import ( CatTensors, @@ -2500,19 +2501,42 @@ def limit_max_runs(self): AtariDQNExperienceReplay._max_runs = prev_val @pytest.mark.parametrize("dataset", ["Asterix/1", "Pong/4"]) - @pytest.mark.parametrize("num_slices,slice_len",[[None, None], [None, 8], [2, None]]) + @pytest.mark.parametrize( + "num_slices,slice_len", [[None, None], [None, 8], [2, None]] + ) def test_single_dataset(self, dataset, slice_len, num_slices, limit_max_runs): - dataset = AtariDQNExperienceReplay(dataset, slice_len=slice_len, num_slices=num_slices) + dataset = AtariDQNExperienceReplay( + dataset, slice_len=slice_len, num_slices=num_slices + ) sample = dataset.sample(64) - for key in (("next", "observation"), ("next", "truncated"), ("next", "terminated"), ("next", "done"), ("next", "reward"), "observation", "action", "done", "truncated", "terminated"): + for key in ( + ("next", "observation"), + ("next", "truncated"), + ("next", "terminated"), + ("next", "done"), + ("next", "reward"), + "observation", + "action", + "done", + "truncated", + "terminated", + ): assert key in sample.keys(True) assert sample.shape == (64,) - @pytest.mark.parametrize("num_slices,slice_len",[[None, None], [None, 8], [2, None]]) + @pytest.mark.parametrize( + "num_slices,slice_len", [[None, None], [None, 8], [2, None]] + ) def test_double_dataset(self, slice_len, num_slices, limit_max_runs): - dataset_pong = AtariDQNExperienceReplay("Pong/4", slice_len=slice_len, num_slices=num_slices) - dataset_asterix = AtariDQNExperienceReplay("Asterix/1", slice_len=slice_len, num_slices=num_slices) - dataset = ReplayBufferEnsemble(dataset_pong, dataset_asterix, sample_from_all=True, batch_size=128) + dataset_pong = AtariDQNExperienceReplay( + "Pong/4", slice_len=slice_len, num_slices=num_slices + ) + dataset_asterix = AtariDQNExperienceReplay( + "Asterix/1", slice_len=slice_len, num_slices=num_slices + ) + dataset = ReplayBufferEnsemble( + dataset_pong, dataset_asterix, sample_from_all=True, batch_size=128 + ) sample = dataset.sample() assert sample.shape == (2, 64) assert sample[0]["metadata"]["dataset_id"] == "Pong/4" From f71ccbb5151ac3aad049a3b806b6eca3e5a3a167 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 16:01:43 +0000 Subject: [PATCH 13/17] amend --- .github/unittest/linux_libs/scripts_ataridqn/environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/unittest/linux_libs/scripts_ataridqn/environment.yml b/.github/unittest/linux_libs/scripts_ataridqn/environment.yml index 092251f6361..b88860dddde 100644 --- a/.github/unittest/linux_libs/scripts_ataridqn/environment.yml +++ b/.github/unittest/linux_libs/scripts_ataridqn/environment.yml @@ -1,6 +1,7 @@ channels: - pytorch - defaults + - conda-forge dependencies: - pip - gsutil From 5961bd9c58d072ae6e02c77b2ebb6adcef9b3b39 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 16:26:52 +0000 Subject: [PATCH 14/17] amend --- torchrl/data/datasets/atari_dqn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py index cde6db8fe80..2e93d801783 100644 --- a/torchrl/data/datasets/atari_dqn.py +++ b/torchrl/data/datasets/atari_dqn.py @@ -386,8 +386,9 @@ def available_datasets(cls): ] return ["/".join((game, str(loop))) for game in games for loop in range(1, 6)] - tmpdir = "/Users/vmoens/.cache/atari_root" + # If we want to keep track of the original atari files + tmpdir = None # use _max_runs for debugging, avoids downloading the entire dataset _max_runs = None From 6790d363234d97954ff70d6e452fb8bdccaa84e6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 16:31:47 +0000 Subject: [PATCH 15/17] lint --- torchrl/data/datasets/atari_dqn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py index 2e93d801783..93950532026 100644 --- a/torchrl/data/datasets/atari_dqn.py +++ b/torchrl/data/datasets/atari_dqn.py @@ -386,7 +386,6 @@ def available_datasets(cls): ] return ["/".join((game, str(loop))) for game in games for loop in range(1, 6)] - # If we want to keep track of the original atari files tmpdir = None # use _max_runs for debugging, avoids downloading the entire dataset From 787f76ed4bbcc97262dbeaece551a6e74b480747 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 17:15:26 +0000 Subject: [PATCH 16/17] fixes --- test/test_libs.py | 5 +++-- test/test_rb.py | 3 ++- torchrl/data/replay_buffers/samplers.py | 14 ++++++++++++-- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index d9b3a200517..a6f7c91a659 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -2523,6 +2523,7 @@ def test_single_dataset(self, dataset, slice_len, num_slices, limit_max_runs): ): assert key in sample.keys(True) assert sample.shape == (64,) + assert sample.get_non_tensor("metadata")["dataset_id"] == dataset @pytest.mark.parametrize( "num_slices,slice_len", [[None, None], [None, 8], [2, None]] @@ -2539,8 +2540,8 @@ def test_double_dataset(self, slice_len, num_slices, limit_max_runs): ) sample = dataset.sample() assert sample.shape == (2, 64) - assert sample[0]["metadata"]["dataset_id"] == "Pong/4" - assert sample[0]["metadata"]["dataset_id"] == "Asterix/1" + assert sample[0].get_non_tensor("metadata")["dataset_id"] == "Pong/4" + assert sample[0].get_non_tensor("metadata")["dataset_id"] == "Asterix/1" @pytest.mark.slow diff --git a/test/test_rb.py b/test/test_rb.py index cf9deabb956..5d184c365e2 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -1808,7 +1808,8 @@ def test_slice_sampler_errors(self): storage.set(range(100), data) sampler = SliceSampler(num_slices=num_slices) with pytest.raises( - RuntimeError, match="can only sample from TensorStorage subclasses" + RuntimeError, + match="Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories.", ): index, _ = sampler.sample(storage, batch_size=batch_size) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index be4ee9c6406..3460f6ed51c 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -750,7 +750,12 @@ def _get_stop_and_length(self, storage, fallback=True): and self._used_traj_key[0] == "_data" ) else: - trajectory = storage[:].get(self.traj_key) + try: + trajectory = storage[:].get(self.traj_key) + except Exception: + raise RuntimeError( + "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories." + ) vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)]) if self.cache_values: self._cache["stop-and-length"] = vals @@ -779,7 +784,12 @@ def _get_stop_and_length(self, storage, fallback=True): and self._used_end_key[0] == "_data" ) else: - done = storage[:].get(self.end_key) + try: + done = storage[:].get(self.end_key) + except Exception: + raise RuntimeError( + "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories." + ) vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)] if self.cache_values: self._cache["stop-and-length"] = vals From 0e0982de8ebed5c028733ea3616f0bcf90ca6e2a Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 17:32:32 +0000 Subject: [PATCH 17/17] fixes --- test/test_libs.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index a6f7c91a659..13891331b05 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -2500,13 +2500,13 @@ def limit_max_runs(self): yield AtariDQNExperienceReplay._max_runs = prev_val - @pytest.mark.parametrize("dataset", ["Asterix/1", "Pong/4"]) + @pytest.mark.parametrize("dataset_id", ["Asterix/1", "Pong/4"]) @pytest.mark.parametrize( "num_slices,slice_len", [[None, None], [None, 8], [2, None]] ) - def test_single_dataset(self, dataset, slice_len, num_slices, limit_max_runs): + def test_single_dataset(self, dataset_id, slice_len, num_slices, limit_max_runs): dataset = AtariDQNExperienceReplay( - dataset, slice_len=slice_len, num_slices=num_slices + dataset_id, slice_len=slice_len, num_slices=num_slices ) sample = dataset.sample(64) for key in ( @@ -2523,7 +2523,7 @@ def test_single_dataset(self, dataset, slice_len, num_slices, limit_max_runs): ): assert key in sample.keys(True) assert sample.shape == (64,) - assert sample.get_non_tensor("metadata")["dataset_id"] == dataset + assert sample.get_non_tensor("metadata")["dataset_id"] == dataset_id @pytest.mark.parametrize( "num_slices,slice_len", [[None, None], [None, 8], [2, None]] @@ -2541,7 +2541,7 @@ def test_double_dataset(self, slice_len, num_slices, limit_max_runs): sample = dataset.sample() assert sample.shape == (2, 64) assert sample[0].get_non_tensor("metadata")["dataset_id"] == "Pong/4" - assert sample[0].get_non_tensor("metadata")["dataset_id"] == "Asterix/1" + assert sample[1].get_non_tensor("metadata")["dataset_id"] == "Asterix/1" @pytest.mark.slow