diff --git a/.github/unittest/linux_libs/scripts_d4rl/run_test.sh b/.github/unittest/linux_libs/scripts_d4rl/run_test.sh index 3723399a859..2f09be54eac 100755 --- a/.github/unittest/linux_libs/scripts_d4rl/run_test.sh +++ b/.github/unittest/linux_libs/scripts_d4rl/run_test.sh @@ -37,25 +37,6 @@ conda deactivate && conda activate ./env # this workflow only tests the libs python -c "import gym, d4rl" -python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestD4RL --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestD4RL --error-for-skips --runslow coverage combine coverage xml -i - -## check what happens if we update gym -#pip install gym -U -#python -c """ -#from torchrl.data.datasets import D4RLExperienceReplay -#data = D4RLExperienceReplay('halfcheetah-medium-v2', batch_size=10, from_env=False, direct_download=True) -#for batch in data: -# print(batch) -# break -# -#data = D4RLExperienceReplay('halfcheetah-medium-v2', batch_size=10, from_env=False, direct_download=False) -#for batch in data: -# print(batch) -# break -# -#import d4rl -#import gym -#gym.make('halfcheetah-medium-v2') -#""" diff --git a/.github/unittest/linux_libs/scripts_minari/run_test.sh b/.github/unittest/linux_libs/scripts_minari/run_test.sh index 7741a491f5b..0567e2be25d 100755 --- a/.github/unittest/linux_libs/scripts_minari/run_test.sh +++ b/.github/unittest/linux_libs/scripts_minari/run_test.sh @@ -22,6 +22,6 @@ conda deactivate && conda activate ./env # this workflow only tests the libs python -c "import minari" -python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestMinari --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestMinari --error-for-skips --runslow coverage combine coverage xml -i diff --git a/.github/unittest/linux_libs/scripts_roboset/environment.yml b/.github/unittest/linux_libs/scripts_roboset/environment.yml new file mode 100644 index 00000000000..472ea296769 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_roboset/environment.yml @@ -0,0 +1,22 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - pip: + - hypothesis + - future + - cloudpickle + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - pytest-rerunfailures + - pytest-error-for-skips + - expecttest + - pyyaml + - scipy + - hydra-core + - huggingface_hub + - tqdm + - h5py diff --git a/.github/unittest/linux_libs/scripts_roboset/install.sh b/.github/unittest/linux_libs/scripts_roboset/install.sh new file mode 100755 index 00000000000..2eb52b8f65e --- /dev/null +++ b/.github/unittest/linux_libs/scripts_roboset/install.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. +apt-get update && apt-get install -y git wget gcc g++ +#apt-get update && apt-get install -y git wget freeglut3 freeglut3-dev + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall +fi + +# install tensordict +pip install git+https://github.com/pytorch/tensordict.git + +# smoke test +python -c "import functorch;import tensordict" + +printf "* Installing torchrl\n" +python setup.py develop + +# smoke test +python -c "import torchrl" diff --git a/.github/unittest/linux_libs/scripts_roboset/post_process.sh b/.github/unittest/linux_libs/scripts_roboset/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.github/unittest/linux_libs/scripts_roboset/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.github/unittest/linux_libs/scripts_roboset/run-clang-format.py b/.github/unittest/linux_libs/scripts_roboset/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_roboset/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. + +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. + # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/unittest/linux_libs/scripts_roboset/run_test.sh b/.github/unittest/linux_libs/scripts_roboset/run_test.sh new file mode 100755 index 00000000000..67ae605a43e --- /dev/null +++ b/.github/unittest/linux_libs/scripts_roboset/run_test.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 +ln -s /usr/bin/swig3.0 /usr/bin/swig + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +conda deactivate && conda activate ./env + +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestRoboset --error-for-skips --runslow +coverage combine +coverage xml -i diff --git a/.github/unittest/linux_libs/scripts_roboset/setup_env.sh b/.github/unittest/linux_libs/scripts_roboset/setup_env.sh new file mode 100755 index 00000000000..5214617c2ac --- /dev/null +++ b/.github/unittest/linux_libs/scripts_roboset/setup_env.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e +set -v + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +apt-get update && apt-get install -y git wget gcc g++ unzip + +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 3. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip3 install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune diff --git a/.github/workflows/test-linux-roboset.yml b/.github/workflows/test-linux-roboset.yml new file mode 100644 index 00000000000..e77a69c05b9 --- /dev/null +++ b/.github/workflows/test-linux-roboset.yml @@ -0,0 +1,42 @@ +name: Roboset Tests on Linux + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + workflow_dispatch: + +concurrency: + # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. + # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} + cancel-in-progress: true + +jobs: + unittests: + strategy: + matrix: + python_version: ["3.9"] + cuda_arch_version: ["12.1"] + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + docker-image: "nvidia/cudagl:11.4.0-base" + timeout: 120 + script: | + set -euo pipefail + export PYTHON_VERSION="3.9" + export CU_VERSION="cu117" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + + bash .github/unittest/linux_libs/scripts_roboset/setup_env.sh + bash .github/unittest/linux_libs/scripts_roboset/install.sh + bash .github/unittest/linux_libs/scripts_roboset/run_test.sh + bash .github/unittest/linux_libs/scripts_roboset/post_process.sh diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index b507e716c3d..c413f986499 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -259,6 +259,7 @@ Here's an example: D4RLExperienceReplay MinariExperienceReplay OpenMLExperienceReplay + RobosetExperienceReplay TensorSpec ---------- diff --git a/setup.py b/setup.py index 07880010189..36d84aa09e3 100644 --- a/setup.py +++ b/setup.py @@ -212,6 +212,14 @@ def _main(argv): "checkpointing": [ "torchsnapshot", ], + "offline-data": [ + "huggingface_hub", # for roboset + "minari", + "tqdm", + "scikit-learn", + "pandas", + "h5py", + ], "marl": ["vmas>=1.2.10", "pettingzoo>=1.24.1"], } extra_requires["all"] = set() diff --git a/test/conftest.py b/test/conftest.py index 048b9e6c49e..f392cb7d4f1 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -86,3 +86,23 @@ def set_warnings() -> None: category=DeprecationWarning, message=r"jax.tree_util.register_keypaths is deprecated|jax.ShapedArray is deprecated", ) + + +def pytest_addoption(parser): + parser.addoption( + "--runslow", action="store_true", default=False, help="run slow tests" + ) + + +def pytest_configure(config): + config.addinivalue_line("markers", "slow: mark test as slow to run") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--runslow"): + # --runslow given in cli: do not skip slow tests + return + skip_slow = pytest.mark.skip(reason="need --runslow option to run") + for item in items: + if "slow" in item.keywords: + item.add_marker(skip_slow) diff --git a/test/test_libs.py b/test/test_libs.py index 7f42d52d63e..996de85a8f7 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -52,6 +52,7 @@ from torchrl.data.datasets.d4rl import D4RLExperienceReplay from torchrl.data.datasets.minari_data import MinariExperienceReplay from torchrl.data.datasets.openml import OpenMLExperienceReplay +from torchrl.data.datasets.roboset import RobosetExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement from torchrl.envs import ( Compose, @@ -1827,6 +1828,7 @@ def test_grouping(self, n_agents, scenario_name="dispersion", n_envs=2): @pytest.mark.skipif(not _has_d4rl, reason="D4RL not found") +@pytest.mark.slow class TestD4RL: @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) @pytest.mark.parametrize("use_truncated_as_done", [True, False]) @@ -2020,6 +2022,7 @@ def _minari_selected_datasets(): @pytest.mark.skipif(not _has_minari, reason="Minari not found") @pytest.mark.parametrize("split", [False, True]) @pytest.mark.parametrize("selected_dataset", _MINARI_DATASETS) +@pytest.mark.slow class TestMinari: def test_load(self, selected_dataset, split): print("dataset", selected_dataset) @@ -2037,6 +2040,23 @@ def test_load(self, selected_dataset, split): break +@pytest.mark.slow +class TestRoboset: + def test_load(self): + selected_dataset = RobosetExperienceReplay.available_datasets[0] + data = RobosetExperienceReplay( + selected_dataset, + batch_size=32, + ) + t0 = time.time() + for i, _ in enumerate(data): + t1 = time.time() + print(f"sampling time {1000 * (t1-t0): 4.4f}ms") + t0 = time.time() + if i == 10: + break + + @pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found") @pytest.mark.parametrize( "dataset", diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 514b884ee7e..90bc9f12398 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -617,3 +617,47 @@ def run(self, *args, **kwargs): warnings.simplefilter("ignore") return mp.Process.run(self, *args, **kwargs) return mp.Process.run(self, *args, **kwargs) + + +def print_directory_tree(path, indent="", display_metadata=True): + """Prints the directory tree starting from the specified path. + + Args: + path (str): The path of the directory to print. + indent (str): The current indentation level for formatting. + display_metadata (bool): if ``True``, metadata of the dir will be + displayed too. + + """ + if display_metadata: + + def get_directory_size(path="."): + total_size = 0 + + for dirpath, _, filenames in os.walk(path): + for filename in filenames: + file_path = os.path.join(dirpath, filename) + total_size += os.path.getsize(file_path) + + return total_size + + def format_size(size): + # Convert size to a human-readable format + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size < 1024.0: + return f"{size:.2f} {unit}" + size /= 1024.0 + + total_size_bytes = get_directory_size(path) + formatted_size = format_size(total_size_bytes) + print(f"Directory size: {formatted_size}") + + if os.path.isdir(path): + print(indent + os.path.basename(path) + "/") + indent += " " + for item in os.listdir(path): + print_directory_tree( + os.path.join(path, item), indent=indent, display_metadata=False + ) + else: + print(indent + os.path.basename(path)) diff --git a/torchrl/data/datasets/__init__.py b/torchrl/data/datasets/__init__.py index 85b8e064917..b4fbf9a54e0 100644 --- a/torchrl/data/datasets/__init__.py +++ b/torchrl/data/datasets/__init__.py @@ -1,3 +1,4 @@ from .d4rl import D4RLExperienceReplay from .minari_data import MinariExperienceReplay from .openml import OpenMLExperienceReplay +from .roboset import RobosetExperienceReplay diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 492ac0fff58..1c89d1a869b 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -57,8 +57,9 @@ class MinariExperienceReplay(TensorDictReplayBuffer): """Minari Experience replay dataset. Args: - dataset_id (str): - batch_size (int): + dataset_id (str): The dataset to be downloaded. Must be part of MinariExperienceReplay.available_datasets + batch_size (int): Batch-size used during sampling. Can be overridden by `data.sample(batch_size)` if + necessary. Keyword Args: root (Path or str, optional): The Minari dataset root directory. @@ -91,6 +92,9 @@ class MinariExperienceReplay(TensorDictReplayBuffer): accurate choices regarding this usage of ``split_trajs``. Defaults to ``False``. + Attributes: + available_datasets: a list of accepted entries to be downloaded. + .. note:: Text data is currenrtly discarded from the wrapped dataset, as there is not PyTorch native way of representing text data. @@ -197,6 +201,7 @@ def __init__( pin_memory=pin_memory, prefetch=prefetch, batch_size=batch_size, + transform=transform, ) def available_datasets(self): @@ -238,30 +243,32 @@ def _download_and_preproc(self): h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5") # populate the tensordict episode_dict = {} - for episode_key, episode in h5_data.items(): + for i, (episode_key, episode) in enumerate(h5_data.items()): episode_num = int(episode_key[len("episode_") :]) episode_len = episode["actions"].shape[0] episode_dict[episode_num] = (episode_key, episode_len) # Get the total number of steps for the dataset total_steps += episode_len - for key, val in episode.items(): - match = _NAME_MATCH[key] - if key in ("observations", "state", "infos"): - if ( - not val.shape - ): # no need for this, we don't need the proper length: or steps != val.shape[0] - 1: - if val.is_empty(): - continue - val = _patch_info(val) - td_data.set(("next", match), torch.zeros_like(val[0])) - td_data.set(match, torch.zeros_like(val[0])) - if key not in ("terminations", "truncations", "rewards"): - td_data.set(match, torch.zeros_like(val[0])) - else: - td_data.set( - ("next", match), - torch.zeros_like(val[0].unsqueeze(-1)), - ) + if i == 0: + td_data.set("episode", 0) + for key, val in episode.items(): + match = _NAME_MATCH[key] + if key in ("observations", "state", "infos"): + if ( + not val.shape + ): # no need for this, we don't need the proper length: or steps != val.shape[0] - 1: + if val.is_empty(): + continue + val = _patch_info(val) + td_data.set(("next", match), torch.zeros_like(val[0])) + td_data.set(match, torch.zeros_like(val[0])) + if key not in ("terminations", "truncations", "rewards"): + td_data.set(match, torch.zeros_like(val[0])) + else: + td_data.set( + ("next", match), + torch.zeros_like(val[0].unsqueeze(-1)), + ) # give it the proper size td_data["next", "done"] = ( @@ -284,6 +291,7 @@ def _download_and_preproc(self): episode = h5_data.get(episode_key) idx = slice(index, (index + steps)) data_view = td_data[idx] + data_view.fill_("episode", episode_num) for key, val in episode.items(): match = _NAME_MATCH[key] if key in ( diff --git a/torchrl/data/datasets/roboset.py b/torchrl/data/datasets/roboset.py new file mode 100644 index 00000000000..216b408aed3 --- /dev/null +++ b/torchrl/data/datasets/roboset.py @@ -0,0 +1,344 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import importlib.util +import os.path +import shutil +import tempfile + +from contextlib import nullcontext +from pathlib import Path +from typing import Callable + +import torch + +from tensordict import PersistentTensorDict, TensorDict +from torchrl._utils import KeyDependentDefaultDict, print_directory_tree +from torchrl.data.datasets.utils import _get_root_dir +from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import Sampler +from torchrl.data.replay_buffers.storages import TensorStorage +from torchrl.data.replay_buffers.writers import Writer + +_has_tqdm = importlib.util.find_spec("tqdm", None) is not None +_has_h5py = importlib.util.find_spec("h5py", None) is not None +_has_hf_hub = importlib.util.find_spec("huggingface_hub", None) is not None + +_NAME_MATCH = KeyDependentDefaultDict(lambda key: key) +_NAME_MATCH["observations"] = "observation" +_NAME_MATCH["rewards"] = "reward" +_NAME_MATCH["actions"] = "action" +_NAME_MATCH["env_infos"] = "info" + + +class RobosetExperienceReplay(TensorDictReplayBuffer): + """Roboset experience replay dataset. + + This class downloads the H5 data from roboset and processes it in a mmap + format, which makes indexing (and therefore sampling) faster. + + Learn more about roboset here: https://sites.google.com/view/robohive/roboset + + Args: + dataset_id (str): the dataset to be downloaded. Must be part of RobosetExperienceReplay.available_datasets. + batch_size (int): Batch-size used during sampling. Can be overridden by `data.sample(batch_size)` if + necessary. + + Keyword Args: + root (Path or str, optional): The Roboset dataset root directory. + The actual dataset memory-mapped files will be saved under + `/`. If none is provided, it defaults to + ``~/.cache/torchrl/minari`. + download (bool or str, optional): Whether the dataset should be downloaded if + not found. Defaults to ``True``. Download can also be passed as "force", + in which case the downloaded data will be overwritten. + sampler (Sampler, optional): the sampler to be used. If none is provided + a default RandomSampler() will be used. + writer (Writer, optional): the writer to be used. If none is provided + a default RoundRobinWriter() will be used. + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s)/outputs. Used when using batched + loading from a map-style dataset. + pin_memory (bool): whether pin_memory() should be called on the rb + samples. + prefetch (int, optional): number of next batches to be prefetched + using multithreading. + transform (Transform, optional): Transform to be executed when sample() is called. + To chain transforms use the :obj:`Compose` class. + split_trajs (bool, optional): if ``True``, the trajectories will be split + along the first dimension and padded to have a matching shape. + To split the trajectories, the ``"done"`` signal will be used, which + is recovered via ``done = truncated | terminated``. In other words, + it is assumed that any ``truncated`` or ``terminated`` signal is + equivalent to the end of a trajectory. For some datasets from + ``D4RL``, this may not be true. It is up to the user to make + accurate choices regarding this usage of ``split_trajs``. + Defaults to ``False``. + + Attributes: + available_datasets: a list of accepted entries to be downloaded. + + Examples: + >>> import torch + >>> torch.manual_seed(0) + >>> from torchrl.envs.transforms import ExcludeTransform + >>> from torchrl.data.datasets import RobosetExperienceReplay + >>> d = RobosetExperienceReplay("FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4", batch_size=32, + ... transform=ExcludeTransform("info", ("next", "info"))) # excluding info dict for conciseness + >>> for batch in d: + ... break + >>> # data is organised by seed and episode, but stored contiguously + >>> print(batch["seed"], batch["episode"]) + tensor([2, 1, 0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 2, 2, 2, 1, 1, 2, 0, 2, 0, 2, 2, 1, + 0, 2, 0, 0, 1, 1, 2, 1]) tensor([17, 20, 18, 9, 6, 1, 12, 6, 2, 6, 8, 15, 8, 21, 17, 3, 9, 20, + 23, 12, 3, 16, 19, 16, 16, 4, 4, 12, 1, 2, 15, 24]) + >>> print(batch) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([32, 9]), device=cpu, dtype=torch.float64, is_shared=False), + done: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.bool, is_shared=False), + episode: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False), + index: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([32, 75]), device=cpu, dtype=torch.float64, is_shared=False), + reward: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float64, is_shared=False), + terminated: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([32]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([32, 75]), device=cpu, dtype=torch.float64, is_shared=False), + seed: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False), + time: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.float64, is_shared=False)}, + batch_size=torch.Size([32]), + device=cpu, + is_shared=False) + + """ + + available_datasets = [ + "DAPG(expert)/door_v2d-v1", + "DAPG(expert)/relocate_v2d-v1", + "DAPG(expert)/hammer_v2d-v1", + "DAPG(expert)/pen_v2d-v1", + "DAPG(human)/door_v2d-v1", + "DAPG(human)/relocate_v2d-v1", + "DAPG(human)/hammer_v2d-v1", + "DAPG(human)/pen_v2d-v1", + "FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4", + "FK1-v4(expert)/FK1_Knob2OffRandom_v2d-v4", + "FK1-v4(expert)/FK1_LdoorOpenRandom_v2d-v4", + "FK1-v4(expert)/FK1_SdoorOpenRandom_v2d-v4", + "FK1-v4(expert)/FK1_Knob1OnRandom_v2d-v4", + "FK1-v4(human)/human_demos_by_playdata", + "FK1-v4(human)/human_demos_by_task/human_demo_singleTask_Fixed-v4", + "FK1-v4(human)/human_demos_by_task/FK1_SdoorOpenRandom_v2d-v4", + "FK1-v4(human)/human_demos_by_task/FK1_LdoorOpenRandom_v2d-v4", + "FK1-v4(human)/human_demos_by_task/FK1_Knob2OffRandom_v2d-v4", + "FK1-v4(human)/human_demos_by_task/FK1_Knob1OnRandom_v2d-v4", + "FK1-v4(human)/human_demos_by_task/FK1_MicroOpenRandom_v2d-v4", + ] + + def __init__( + self, + dataset_id, + batch_size: int, + *, + root: str | Path | None = None, + download: bool = True, + sampler: Sampler | None = None, + writer: Writer | None = None, + collate_fn: Callable | None = None, + pin_memory: bool = False, + prefetch: int | None = None, + transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + split_trajs: bool = False, + **env_kwargs, + ): + if not _has_h5py or not _has_hf_hub: + raise ImportError( + "h5py and huggingface_hub are required for Roboset datasets." + ) + if dataset_id not in self.available_datasets: + raise ValueError( + f"The dataset_id {dataset_id} isn't part of the accepted datasets. " + f"To check which dataset can be downloaded, call `{type(self)}.available_datasets`." + ) + self.dataset_id = dataset_id + if root is None: + root = _get_root_dir("roboset") + os.makedirs(root, exist_ok=True) + self.root = root + self.split_trajs = split_trajs + self.download = download + if self.download == "force" or (self.download and not self._is_downloaded()): + if self.download == "force": + try: + shutil.rmtree(self.data_path_root) + if self.data_path != self.data_path_root: + shutil.rmtree(self.data_path) + except FileNotFoundError: + pass + storage = self._download_and_preproc() + elif self.split_trajs and not os.path.exists(self.data_path): + storage = self._make_split() + else: + storage = self._load() + storage = TensorStorage(storage) + super().__init__( + storage=storage, + sampler=sampler, + writer=writer, + collate_fn=collate_fn, + pin_memory=pin_memory, + prefetch=prefetch, + transform=transform, + batch_size=batch_size, + ) + + def _download_from_huggingface(self, tempdir): + try: + from huggingface_hub import hf_hub_download, HfApi + except ImportError: + raise ImportError( + f"huggingface_hub is required for downloading {type(self)}'s datasets." + ) + dataset = HfApi().dataset_info("jdvakil/RoboSet_Sim") + h5_files = [] + datapath = Path(tempdir) / "data" + for sibling in dataset.siblings: + if sibling.rfilename.startswith( + self.dataset_id + ) and sibling.rfilename.endswith(".h5"): + path = Path(sibling.rfilename) + local_path = hf_hub_download( + "jdvakil/RoboSet_Sim", + subfolder=str(path.parent), + filename=str(path.parts[-1]), + repo_type="dataset", + cache_dir=str(datapath), + ) + h5_files.append(local_path) + + return sorted(h5_files) + + def _download_and_preproc(self): + + with tempfile.TemporaryDirectory() as tempdir: + h5_data_files = self._download_from_huggingface(tempdir) + return self._preproc_h5(h5_data_files) + + def _preproc_h5(self, h5_data_files): + td_data = TensorDict({}, []) + total_steps = 0 + print( + f"first read through data files {h5_data_files} to create data structure..." + ) + episode_dict = {} + h5_datas = [] + for seed, h5_data_name in enumerate(h5_data_files): + print("\nReading", h5_data_name) + h5_data = PersistentTensorDict.from_h5(h5_data_name) + h5_datas.append(h5_data) + for i, (episode_key, episode) in enumerate(h5_data.items()): + episode_num = int(episode_key[len("Trial") :]) + episode_len = episode["actions"].shape[0] + episode_dict[(seed, episode_num)] = (episode_key, episode_len) + # Get the total number of steps for the dataset + total_steps += episode_len + print("total_steps", total_steps, end="\t") + if i == 0 and seed == 0: + td_data.set("episode", 0) + td_data.set("seed", 0) + for key, val in episode.items(): + match = _NAME_MATCH[key] + if key in ("observations", "env_infos", "done"): + td_data.set(("next", match), torch.zeros_like(val[0])) + td_data.set(match, torch.zeros_like(val[0])) + elif key not in ("rewards",): + td_data.set(match, torch.zeros_like(val[0])) + else: + td_data.set( + ("next", match), + torch.zeros_like(val[0].unsqueeze(-1)), + ) + + # give it the proper size + td_data["next", "terminated"] = td_data["next", "done"] + td_data["next", "truncated"] = td_data["next", "done"] + + td_data = td_data.expand(total_steps) + # save to designated location + print(f"creating tensordict data in {self.data_path_root}: ", end="\t") + td_data = td_data.memmap_like(self.data_path_root) + # print("tensordict structure:", td_data) + print("Local dataset structure:", print_directory_tree(self.data_path_root)) + + print(f"Reading data from {len(episode_dict)} episodes") + index = 0 + if _has_tqdm: + from tqdm import tqdm + else: + tqdm = None + with tqdm(total=total_steps) if _has_tqdm else nullcontext() as pbar: + # iterate over episodes and populate the tensordict + for seed, episode_num in sorted(episode_dict, key=lambda key: key[1]): + h5_data = h5_datas[seed] + episode_key, steps = episode_dict[(seed, episode_num)] + episode = h5_data.get(episode_key) + idx = slice(index, (index + steps)) + data_view = td_data[idx] + data_view.fill_("episode", episode_num) + data_view.fill_("seed", seed) + for key, val in episode.items(): + match = _NAME_MATCH[key] + if steps != val.shape[0]: + raise RuntimeError( + f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}." + ) + if key in ( + "observations", + "env_infos", + ): + data_view["next", match][:-1].copy_(val[1:]) + data_view[match].copy_(val) + elif key not in ("rewards",): + data_view[match].copy_(val) + else: + data_view[("next", match)].copy_(val.unsqueeze(-1)) + data_view["next", "terminated"].copy_(data_view["next", "done"]) + if pbar is not None: + pbar.update(steps) + pbar.set_description( + f"index={index} - episode num {episode_num} - seed {seed}" + ) + index += steps + return td_data + + def _make_split(self): + from torchrl.collectors.utils import split_trajectories + + td_data = TensorDict.load_memmap(self.data_path_root) + td_data = split_trajectories(td_data).memmap_(self.data_path) + return td_data + + def _load(self): + return TensorDict.load_memmap(self.data_path) + + @property + def data_path(self): + if self.split_trajs: + return Path(self.root) / (self.dataset_id + "_split") + return self.data_path_root + + @property + def data_path_root(self): + return Path(self.root) / self.dataset_id + + def _is_downloaded(self): + return os.path.exists(self.data_path_root)