From eb5dd98d5bd9e30eed0a27ca5173fbc029cd3970 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 20 Dec 2023 21:19:55 +0000 Subject: [PATCH 1/8] init --- torchrl/data/datasets/vd4rl.py | 119 +++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 torchrl/data/datasets/vd4rl.py diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py new file mode 100644 index 00000000000..380901daa1b --- /dev/null +++ b/torchrl/data/datasets/vd4rl.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import pathlib +from torchrl._utils import KeyDependentDefaultDict +from tensordict import TensorDict, PersistentTensorDict +import tempfile +from torchrl._utils import print_directory_tree +from collections import defaultdict +import numpy as np +import json +import datasets +from huggingface_hub import hf_hub_download, HfApi +from pathlib import Path +from torchrl.data import TensorDictReplayBuffer + +THIS_DIR = pathlib.Path(__file__).parent + +class VD4RLExperienceReplay(TensorDictReplayBuffer): + def __init__(self): + ... 
+ + @classmethod + def _parse_datasets(cls): + dataset = HfApi().dataset_info("conglu/vd4rl") + sibs = defaultdict(list) + for sib in dataset.siblings: + if sib.rfilename.endswith("npz") or sib.rfilename.endswith("hdf5"): + path = Path(sib.rfilename) + sibs[path.parent].append(path) + return sibs + + @classmethod + def _download_and_preproc(cls, dataset_id): + path = None + files = [] + with tempfile.TemporaryDirectory() as datapath: + sibs = cls._parse_datasets() + # files = [] + total_steps = 0 + for path in sibs: + if dataset_id not in str(path): + continue + for file in sibs[path]: + # print(path, file) + local_path = hf_hub_download( + "conglu/vd4rl", + subfolder=str(path), + filename=str(file.parts[-1]), + repo_type="dataset", + cache_dir=str(datapath), + ) + files.append(local_path) + # print_directory_tree(datapath) + if local_path.endswith("hdf5"): + td = PersistentTensorDict.from_h5(local_path) + else: + td = _from_npz(local_path) + if total_steps == 0: + td = td.to_tensordict() + cls._process_data(td) + td_save = td[0] + total_steps += td.shape[0] + td_save = td_save.expand(total_steps).memmap_like(path) + print(td_save) + idx0 = 0 + idx1 = 0 + while len(files): + local_path = files.pop(0) + if local_path.endswith("hdf5"): + td = PersistentTensorDict.from_h5(local_path) + else: + td = _from_npz(local_path) + td = td.to_tensordict() + cls._process_data(td) + idx1 += td.shape[0] + td_save[idx0:idx1] = td + idx0 = idx1 + return td_save + + @classmethod + def _process_data(cls, td: TensorDict): + print(td) + for name, val in list(td.items()): + if name != _NAME_MATCH[name]: + td.rename_key_(name, _NAME_MATCH[name]) + observation = td.get("observation") + td.get_sub_tensordict(slice(0, -1)).set(("next", "observation"), observation[1:]) + print(td) + + @property + def available_datasets(self): + return self.available_datasets + @classmethod + def _available_datasets(cls): + # try to gather paths from hf + try: + sibs = cls._parse_datasets() + return 
[str(path)[6:] for path in sibs] + except Exception: + # return the default datasets + with open(THIS_DIR / "vd4rl.json", "r") as file: + return json.load(file) + +def _from_npz(npz_path): + npz = np.load(npz_path) + npz_dict = { + file: npz[file] for file in npz.files + } + return TensorDict.from_dict(npz_dict) + +_NAME_MATCH = KeyDependentDefaultDict(lambda x: x) +_NAME_MATCH.update({ + "is_first": "is_init", + "is_last": ("next", "done"), + "is_terminal": ("next", "terminated"), + "reward": ("next", "reward"), +}) From 52346a2eb7039742d2b2fda8e93e6fff986d5297 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 21 Dec 2023 11:14:22 +0000 Subject: [PATCH 2/8] amend --- .../linux_libs/scripts_vd4rl/environment.yml | 22 + .../linux_libs/scripts_vd4rl/install.sh | 51 +++ .../linux_libs/scripts_vd4rl/post_process.sh | 6 + .../scripts_vd4rl/run-clang-format.py | 356 +++++++++++++++ .../linux_libs/scripts_vd4rl/run_test.sh | 24 + .../linux_libs/scripts_vd4rl/setup_env.sh | 50 +++ .github/workflows/test-linux-vd4rl.yml | 42 ++ test/test_libs.py | 34 ++ torchrl/data/datasets/__init__.py | 1 + torchrl/data/datasets/roboset.py | 2 +- torchrl/data/datasets/vd4rl.py | 411 +++++++++++++++--- torchrl/data/datasets/vd4rl_datasets.json | 1 + torchrl/data/rlhf/dataset.py | 7 +- torchrl/envs/libs/envpool.py | 2 +- torchrl/envs/libs/openml.py | 6 +- torchrl/envs/libs/pettingzoo.py | 2 +- torchrl/envs/libs/robohive.py | 2 +- torchrl/envs/libs/smacv2.py | 2 +- torchrl/envs/libs/vmas.py | 2 +- torchrl/envs/transforms/transforms.py | 18 +- torchrl/envs/utils.py | 12 +- torchrl/modules/models/multiagent.py | 4 +- torchrl/modules/tensordict_module/rnn.py | 2 +- torchrl/objectives/cql.py | 2 +- torchrl/objectives/deprecated.py | 2 +- torchrl/objectives/redq.py | 2 +- torchrl/objectives/sac.py | 2 +- torchrl/objectives/td3.py | 2 +- torchrl/trainers/helpers/replay_buffer.py | 5 +- torchrl/trainers/helpers/trainers.py | 2 +- torchrl/trainers/trainers.py | 5 +- 31 files changed, 987 
insertions(+), 94 deletions(-) create mode 100644 .github/unittest/linux_libs/scripts_vd4rl/environment.yml create mode 100755 .github/unittest/linux_libs/scripts_vd4rl/install.sh create mode 100755 .github/unittest/linux_libs/scripts_vd4rl/post_process.sh create mode 100755 .github/unittest/linux_libs/scripts_vd4rl/run-clang-format.py create mode 100755 .github/unittest/linux_libs/scripts_vd4rl/run_test.sh create mode 100755 .github/unittest/linux_libs/scripts_vd4rl/setup_env.sh create mode 100644 .github/workflows/test-linux-vd4rl.yml create mode 100644 torchrl/data/datasets/vd4rl_datasets.json diff --git a/.github/unittest/linux_libs/scripts_vd4rl/environment.yml b/.github/unittest/linux_libs/scripts_vd4rl/environment.yml new file mode 100644 index 00000000000..472ea296769 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_vd4rl/environment.yml @@ -0,0 +1,22 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - pip: + - hypothesis + - future + - cloudpickle + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - pytest-rerunfailures + - pytest-error-for-skips + - expecttest + - pyyaml + - scipy + - hydra-core + - huggingface_hub + - tqdm + - h5py diff --git a/.github/unittest/linux_libs/scripts_vd4rl/install.sh b/.github/unittest/linux_libs/scripts_vd4rl/install.sh new file mode 100755 index 00000000000..2eb52b8f65e --- /dev/null +++ b/.github/unittest/linux_libs/scripts_vd4rl/install.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. 
+apt-get update && apt-get install -y git wget gcc g++ +#apt-get update && apt-get install -y git wget freeglut3 freeglut3-dev + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall +fi + +# install tensordict +pip install git+https://github.com/pytorch/tensordict.git + +# smoke test +python -c "import functorch;import tensordict" + +printf "* Installing torchrl\n" +python setup.py develop + +# smoke test +python -c "import torchrl" diff --git a/.github/unittest/linux_libs/scripts_vd4rl/post_process.sh b/.github/unittest/linux_libs/scripts_vd4rl/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.github/unittest/linux_libs/scripts_vd4rl/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.github/unittest/linux_libs/scripts_vd4rl/run-clang-format.py b/.github/unittest/linux_libs/scripts_vd4rl/run-clang-format.py new file mode 100755 
index 00000000000..5783a885d86 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_vd4rl/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. 
+ +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, 
file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. + # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] 
in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + 
colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, 
use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/unittest/linux_libs/scripts_vd4rl/run_test.sh b/.github/unittest/linux_libs/scripts_vd4rl/run_test.sh new file mode 100755 index 00000000000..e0323047a16 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_vd4rl/run_test.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 +ln -s /usr/bin/swig3.0 /usr/bin/swig + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +conda deactivate && conda activate ./env + +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestVD4RL --error-for-skips --runslow +coverage combine +coverage xml -i diff --git a/.github/unittest/linux_libs/scripts_vd4rl/setup_env.sh b/.github/unittest/linux_libs/scripts_vd4rl/setup_env.sh new file mode 100755 index 00000000000..5214617c2ac --- /dev/null +++ b/.github/unittest/linux_libs/scripts_vd4rl/setup_env.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. 
+ +set -e +set -v + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +apt-get update && apt-get install -y git wget gcc g++ unzip + +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 3. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip3 install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune diff --git a/.github/workflows/test-linux-vd4rl.yml b/.github/workflows/test-linux-vd4rl.yml new file mode 100644 index 00000000000..7b8383b2c68 --- /dev/null +++ b/.github/workflows/test-linux-vd4rl.yml @@ -0,0 +1,42 @@ +name: V-D4RL Tests on Linux + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + workflow_dispatch: + +concurrency: + # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. + # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. 
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} + cancel-in-progress: true + +jobs: + unittests: + strategy: + matrix: + python_version: ["3.9"] + cuda_arch_version: ["12.1"] + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + docker-image: "nvidia/cudagl:11.4.0-base" + timeout: 120 + script: | + set -euo pipefail + export PYTHON_VERSION="3.9" + export CU_VERSION="cu117" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + + bash .github/unittest/linux_libs/scripts_vd4rl/setup_env.sh + bash .github/unittest/linux_libs/scripts_vd4rl/install.sh + bash .github/unittest/linux_libs/scripts_vd4rl/run_test.sh + bash .github/unittest/linux_libs/scripts_vd4rl/post_process.sh diff --git a/test/test_libs.py b/test/test_libs.py index 996de85a8f7..2a6aa5ef053 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -53,6 +53,7 @@ from torchrl.data.datasets.minari_data import MinariExperienceReplay from torchrl.data.datasets.openml import OpenMLExperienceReplay from torchrl.data.datasets.roboset import RobosetExperienceReplay +from torchrl.data.datasets.vd4rl import VD4RLExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement from torchrl.envs import ( Compose, @@ -2057,6 +2058,39 @@ def test_load(self): break +@pytest.mark.slow +class TestVD4RL: + @pytest.mark.parametrize("image_size", [None, (37, 33)]) + def test_load(self, image_size): + torch.manual_seed(0) + datasets = VD4RLExperienceReplay.available_datasets + for idx in torch.randperm(len(datasets)).tolist()[:4]: + selected_dataset = datasets[idx] + data = VD4RLExperienceReplay( + selected_dataset, + batch_size=32, + image_size=image_size, + ) + t0 = time.time() + 
for i, batch in enumerate(data): + if image_size: + assert batch.get("pixels").shape == (32, 3, *image_size) + assert batch.get(("next", "pixels")).shape == (32, 3, *image_size) + else: + assert batch.get("pixels").shape[:2] == (32, 3) + assert batch.get(("next", "pixels")).shape[:2] == (32, 3) + + assert batch.get("pixels").dtype is torch.float32 + assert batch.get(("next", "pixels")).dtype is torch.float32 + assert (batch.get("pixels") != 0).any() + assert (batch.get(("next", "pixels")) != 0).any() + t1 = time.time() + print(f"sampling time {1000 * (t1-t0): 4.4f}ms") + t0 = time.time() + if i == 10: + break + + @pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found") @pytest.mark.parametrize( "dataset", diff --git a/torchrl/data/datasets/__init__.py b/torchrl/data/datasets/__init__.py index b4fbf9a54e0..c1429b300fa 100644 --- a/torchrl/data/datasets/__init__.py +++ b/torchrl/data/datasets/__init__.py @@ -2,3 +2,4 @@ from .minari_data import MinariExperienceReplay from .openml import OpenMLExperienceReplay from .roboset import RobosetExperienceReplay +from .vd4rl import VD4RLExperienceReplay diff --git a/torchrl/data/datasets/roboset.py b/torchrl/data/datasets/roboset.py index 216b408aed3..6e9a9bb23f7 100644 --- a/torchrl/data/datasets/roboset.py +++ b/torchrl/data/datasets/roboset.py @@ -51,7 +51,7 @@ class RobosetExperienceReplay(TensorDictReplayBuffer): root (Path or str, optional): The Roboset dataset root directory. The actual dataset memory-mapped files will be saved under `/`. If none is provided, it defaults to - ``~/.cache/torchrl/minari`. + ``~/.cache/torchrl/roboset`. download (bool or str, optional): Whether the dataset should be downloaded if not found. Defaults to ``True``. Download can also be passed as "force", in which case the downloaded data will be overwritten. 
diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py index 380901daa1b..2fe1471ed5c 100644 --- a/torchrl/data/datasets/vd4rl.py +++ b/torchrl/data/datasets/vd4rl.py @@ -2,96 +2,351 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import importlib +import json +import os import pathlib -from torchrl._utils import KeyDependentDefaultDict -from tensordict import TensorDict, PersistentTensorDict +import shutil import tempfile -from torchrl._utils import print_directory_tree from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Callable, List + import numpy as np -import json -import datasets + +import torch +import tqdm from huggingface_hub import hf_hub_download, HfApi -from pathlib import Path -from torchrl.data import TensorDictReplayBuffer +from tensordict import PersistentTensorDict, TensorDict + +from torchrl._utils import KeyDependentDefaultDict +from torchrl.data.datasets.utils import _get_root_dir +from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import Sampler +from torchrl.data.replay_buffers.storages import TensorStorage +from torchrl.data.replay_buffers.writers import Writer + +from torchrl.envs.transforms import Compose, Resize, ToTensorImage +from torchrl.envs.utils import _classproperty + +_has_tqdm = importlib.util.find_spec("tqdm", None) is not None +_has_h5py = importlib.util.find_spec("h5py", None) is not None +_has_hf_hub = importlib.util.find_spec("huggingface_hub", None) is not None THIS_DIR = pathlib.Path(__file__).parent + class VD4RLExperienceReplay(TensorDictReplayBuffer): - def __init__(self): - ... + """V-D4RL experience replay dataset. 
+ + This class downloads the H5/npz data from V-D4RL and processes it in a mmap + format, which makes indexing (and therefore sampling) faster. + + Learn more about V-D4RL here: https://arxiv.org/abs/2206.04779 + + The `"pixels"` entry is located at the root of the data, and all the data + that is not reward, done-state, action or pixels is moved under a `"state"` + node. + + Args: + dataset_id (str): the dataset to be downloaded. Must be part of + VD4RLExperienceReplay.available_datasets. + batch_size (int): Batch-size used during sampling. Can be overridden by + `data.sample(batch_size)` if necessary. + + Keyword Args: + root (Path or str, optional): The V-D4RL dataset root directory. + The actual dataset memory-mapped files will be saved under + `/`. If none is provided, it defaults to + ``~/.cache/torchrl/vd4rl`. + download (bool or str, optional): Whether the dataset should be downloaded if + not found. Defaults to ``True``. Download can also be passed as "force", + in which case the downloaded data will be overwritten. + sampler (Sampler, optional): the sampler to be used. If none is provided + a default RandomSampler() will be used. + writer (Writer, optional): the writer to be used. If none is provided + a default RoundRobinWriter() will be used. + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s)/outputs. Used when using batched + loading from a map-style dataset. + pin_memory (bool): whether pin_memory() should be called on the rb + samples. + prefetch (int, optional): number of next batches to be prefetched + using multithreading. + transform (Transform, optional): Transform to be executed when sample() is called. + To chain transforms use the :obj:`Compose` class. + split_trajs (bool, optional): if ``True``, the trajectories will be split + along the first dimension and padded to have a matching shape. 
+ To split the trajectories, the ``"done"`` signal will be used, which + is recovered via ``done = truncated | terminated``. In other words, + it is assumed that any ``truncated`` or ``terminated`` signal is + equivalent to the end of a trajectory. For some datasets from + ``D4RL``, this may not be true. It is up to the user to make + accurate choices regarding this usage of ``split_trajs``. + Defaults to ``False``. + totensor (bool, optional): if ``True``, a :class:`~torchrl.envs.transforms.ToTensorImage` + transform will be included in the transform list (if not automatically + detected). Defaults to ``True``. + image_size (int, list of ints or None): if not ``None``, this argument + will be used to create a :class:`~torchrl.envs.transforms.Resize` + transform that will be appended to the transform list. Supports + `int` types (square resizing) or a list/tuple of `int` (rectangular + resizing). Defaults to ``None`` (no resizing). + + Attributes: + available_datasets: a list of accepted entries to be downloaded. These + names correspond to the directory path in the huggingface dataset + repository. If possible, the list will be dynamically retrieved from + huggingface. If no internet connection is available, it a cached + version will be used. + + .. note:: Since not all experience replay have start and stop signals, we + do not mark the episodes in the retrieved dataset. + + Examples: + >>> import torch + >>> torch.manual_seed(0) + >>> from torchrl.data.datasets import VD4RLExperienceReplay + >>> d = VD4RLExperienceReplay("main/walker_walk/random/64px", batch_size=32, + ... image_size=50) + >>> for batch in d: + ... 
break + >>> print(batch) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([32, 6]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False), + index: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False), + is_init: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: TensorDict( + fields={ + height: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.float32, is_shared=False), + orientations: Tensor(shape=torch.Size([32, 14]), device=cpu, dtype=torch.float32, is_shared=False), + velocity: Tensor(shape=torch.Size([32, 9]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32]), + device=cpu, + is_shared=False), + pixels: Tensor(shape=torch.Size([32, 3, 50, 50]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([32]), + device=cpu, + is_shared=False), + observation: TensorDict( + fields={ + height: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.float32, is_shared=False), + orientations: Tensor(shape=torch.Size([32, 14]), device=cpu, dtype=torch.float32, is_shared=False), + velocity: Tensor(shape=torch.Size([32, 9]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32]), + device=cpu, + is_shared=False), + pixels: Tensor(shape=torch.Size([32, 3, 50, 50]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: 
Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([32]), + device=cpu, + is_shared=False) + + """ + + def __init__( + self, + dataset_id, + batch_size: int, + *, + root: str | Path | None = None, + download: bool = True, + sampler: Sampler | None = None, + writer: Writer | None = None, + collate_fn: Callable | None = None, + pin_memory: bool = False, + prefetch: int | None = None, + transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + split_trajs: bool = False, + totensor: bool = True, + image_size: int | List[int] | None = None, + **env_kwargs, + ): + if not _has_h5py or not _has_hf_hub: + raise ImportError( + "h5py and huggingface_hub are required for V-D4RL datasets." + ) + if dataset_id not in self.available_datasets: + raise ValueError( + f"The dataset_id {dataset_id} isn't part of the accepted datasets. " + f"To check which dataset can be downloaded, call `{type(self)}.available_datasets`." + ) + self.dataset_id = dataset_id + if root is None: + root = _get_root_dir("vd4rl") + os.makedirs(root, exist_ok=True) + self.root = root + self.split_trajs = split_trajs + self.download = download + if self.download == "force" or (self.download and not self._is_downloaded()): + if self.download == "force": + try: + shutil.rmtree(self.data_path_root) + if self.data_path != self.data_path_root: + shutil.rmtree(self.data_path) + except FileNotFoundError: + pass + storage = self._download_and_preproc(dataset_id, data_path=self.data_path) + elif self.split_trajs and not os.path.exists(self.data_path): + storage = self._make_split() + else: + storage = self._load() + if totensor and transform is None: + transform = ToTensorImage(in_keys=["pixels", ("next", "pixels")]) + elif totensor and ( + not isinstance(transform, Compose) + or not any(isinstance(t, ToTensorImage) for t in transform) + ): + transform = Compose( + transform, ToTensorImage(in_keys=["pixels", ("next", "pixels")]) + ) + if image_size is not 
None: + transform = Compose( + transform, Resize(image_size, in_keys=["pixels", ("next", "pixels")]) + ) + storage = TensorStorage(storage) + super().__init__( + storage=storage, + sampler=sampler, + writer=writer, + collate_fn=collate_fn, + pin_memory=pin_memory, + prefetch=prefetch, + transform=transform, + batch_size=batch_size, + ) @classmethod def _parse_datasets(cls): dataset = HfApi().dataset_info("conglu/vd4rl") sibs = defaultdict(list) for sib in dataset.siblings: - if sib.rfilename.endswith("npz") or sib.rfilename.endswith("hdf5"): - path = Path(sib.rfilename) - sibs[path.parent].append(path) + if sib.rfilename.endswith("npz") or sib.rfilename.endswith("hdf5"): + path = Path(sib.rfilename) + sibs[path.parent].append(path) return sibs @classmethod - def _download_and_preproc(cls, dataset_id): - path = None + def _download_and_preproc(cls, dataset_id, data_path): files = [] - with tempfile.TemporaryDirectory() as datapath: + tds = [] + with tempfile.TemporaryDirectory() as tmpdir: sibs = cls._parse_datasets() # files = [] total_steps = 0 + + paths_to_proc = [] + files_to_proc = [] + for path in sibs: if dataset_id not in str(path): continue for file in sibs[path]: - # print(path, file) - local_path = hf_hub_download( + paths_to_proc.append(str(path)) + files_to_proc.append(str(file.parts[-1])) + + with ThreadPoolExecutor(32) as executor: + files = executor.map( + lambda path_file: hf_hub_download( "conglu/vd4rl", - subfolder=str(path), - filename=str(file.parts[-1]), + subfolder=path_file[0], + filename=path_file[1], repo_type="dataset", - cache_dir=str(datapath), - ) - files.append(local_path) - # print_directory_tree(datapath) - if local_path.endswith("hdf5"): - td = PersistentTensorDict.from_h5(local_path) - else: - td = _from_npz(local_path) - if total_steps == 0: - td = td.to_tensordict() - cls._process_data(td) - td_save = td[0] - total_steps += td.shape[0] - td_save = td_save.expand(total_steps).memmap_like(path) - print(td_save) - idx0 = 0 - idx1 = 
0 - while len(files): - local_path = files.pop(0) + cache_dir=str(tmpdir), + ), + zip(paths_to_proc, files_to_proc), + ) + files = list(files) + print("Downloaded, processing files") + if _has_tqdm: + pbar = tqdm.tqdm(files) + else: + pbar = files + for local_path in pbar: + if _has_tqdm: + pbar.set_description(f"file={local_path}") + # we temporarily memmap the files for faster access later if local_path.endswith("hdf5"): - td = PersistentTensorDict.from_h5(local_path) + td = ( + PersistentTensorDict.from_h5(local_path) + .to_tensordict() + .memmap(num_threads=32) + ) else: - td = _from_npz(local_path) - td = td.to_tensordict() - cls._process_data(td) - idx1 += td.shape[0] - td_save[idx0:idx1] = td - idx0 = idx1 - return td_save + td = _from_npz(local_path).memmap(num_threads=32) + td.unlock_() + if total_steps == 0: + tdc = cls._process_data(td.clone()) + td_save = tdc[0] + tds.append(td) + total_steps += td.shape[0] + + # From this point, the local paths are no longer needed + td_save = td_save.expand(total_steps).memmap_like(data_path, num_threads=32) + print("Saved tensordict:", td_save) + idx0 = 0 + idx1 = 0 + while len(files): + _ = files.pop(0) + td = tds.pop(0) + td = cls._process_data(td) + idx1 += td.shape[0] + td_save[idx0:idx1] = td + idx0 = idx1 + return td_save @classmethod def _process_data(cls, td: TensorDict): - print(td) - for name, val in list(td.items()): - if name != _NAME_MATCH[name]: + for name in list(td.keys()): + # move remaining data + if name not in _NAME_MATCH: + td.rename_key_(name, ("state", name)) + elif name != _NAME_MATCH[name]: td.rename_key_(name, _NAME_MATCH[name]) - observation = td.get("observation") - td.get_sub_tensordict(slice(0, -1)).set(("next", "observation"), observation[1:]) - print(td) + if ("next", "reward") in td.keys(True): + td.set(("next", "reward"), td.get(("next", "reward")).unsqueeze(-1)) + if ("next", "done") in td.keys(True) and ("next", "terminated") in td.keys( + True + ): + # first unsqueeze + 
td.set(("next", "done"), td.get(("next", "done")).unsqueeze(-1)) + td.set(("next", "terminated"), td.get(("next", "terminated")).unsqueeze(-1)) + # create root vals + td.set("done", torch.zeros_like(td.get(("next", "done")))) + td.set("terminated", torch.zeros_like(td.get(("next", "terminated")))) + # Add truncated + td.set( + ("next", "truncated"), + td.get(("next", "done")) & ~td.get(("next", "terminated")), + ) + + td.set("truncated", torch.zeros_like(td.get(("next", "truncated")))) + + pixels = td.get("pixels") + subtd = td.get_sub_tensordict(slice(0, -1)) + subtd.set(("next", "pixels"), pixels[1:], inplace=True) + state = td.get("state", None) + if state is not None: + subtd.set(("next", "state"), state[1:], inplace=True) + + return td + + @_classproperty + def available_datasets(cls): + return cls._available_datasets() - @property - def available_datasets(self): - return self.available_datasets @classmethod def _available_datasets(cls): # try to gather paths from hf @@ -103,17 +358,45 @@ def _available_datasets(cls): with open(THIS_DIR / "vd4rl.json", "r") as file: return json.load(file) + def _make_split(self): + from torchrl.collectors.utils import split_trajectories + + td_data = TensorDict.load_memmap(self.data_path_root) + td_data = split_trajectories(td_data).memmap_(self.data_path) + return td_data + + def _load(self): + return TensorDict.load_memmap(self.data_path) + + @property + def data_path(self): + if self.split_trajs: + return Path(self.root) / (self.dataset_id + "_split") + return self.data_path_root + + @property + def data_path_root(self): + return Path(self.root) / self.dataset_id + + def _is_downloaded(self): + return os.path.exists(self.data_path_root) + + def _from_npz(npz_path): npz = np.load(npz_path) - npz_dict = { - file: npz[file] for file in npz.files - } + npz_dict = {file: npz[file] for file in npz.files} return TensorDict.from_dict(npz_dict) + _NAME_MATCH = KeyDependentDefaultDict(lambda x: x) -_NAME_MATCH.update({ - "is_first": 
"is_init", - "is_last": ("next", "done"), - "is_terminal": ("next", "terminated"), - "reward": ("next", "reward"), -}) +_NAME_MATCH.update( + { + "is_first": "is_init", + "is_last": ("next", "done"), + "is_terminal": ("next", "terminated"), + "reward": ("next", "reward"), + "image": "pixels", + "observation": "pixels", + "action": "action", + } +) diff --git a/torchrl/data/datasets/vd4rl_datasets.json b/torchrl/data/datasets/vd4rl_datasets.json new file mode 100644 index 00000000000..eb45e317c3b --- /dev/null +++ b/torchrl/data/datasets/vd4rl_datasets.json @@ -0,0 +1 @@ +["distracting/walker_walk_random/64px/easy", "distracting/walker_walk_random/64px/hard", "distracting/walker_walk_random/64px/medium", "main/cheetah_run/expert/64px", "main/cheetah_run/medium/64px", "main/cheetah_run/medium_expert/64px", "main/cheetah_run/medium_replay/64px", "main/cheetah_run/random/64px", "main/humanoid_walk/expert/64px", "main/humanoid_walk/medium/64px", "main/humanoid_walk/medium_expert/64px", "main/humanoid_walk/medium_replay/64px", "main/humanoid_walk/random/64px", "main/walker_walk/expert/64px", "main/walker_walk/medium/64px", "main/walker_walk/medium_expert/64px", "main/walker_walk/medium_replay/64px", "main/walker_walk/random/64px", "multitask/cheetah_run_random/64px", "multitask/walker_walk_random/64px"] diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index 09086bfad65..3d8f7fa6de1 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -15,8 +15,11 @@ from tensordict import TensorDict, TensorDictBase from tensordict.tensordict import NestedKey -from torchrl.data import TensorDictReplayBuffer, TensorStorage -from torchrl.data.replay_buffers import SamplerWithoutReplacement +from torchrl.data.replay_buffers import ( + SamplerWithoutReplacement, + TensorDictReplayBuffer, + TensorStorage, +) _has_transformers = importlib.util.find_spec("transformers") is not None _has_datasets = importlib.util.find_spec("datasets") is not None 
diff --git a/torchrl/envs/libs/envpool.py b/torchrl/envs/libs/envpool.py index 9774e0627e0..d78eb3ca1b5 100644 --- a/torchrl/envs/libs/envpool.py +++ b/torchrl/envs/libs/envpool.py @@ -13,7 +13,7 @@ import torch from tensordict import TensorDict, TensorDictBase -from torchrl.data import ( +from torchrl.data.tensor_specs import ( CompositeSpec, DiscreteTensorSpec, TensorSpec, diff --git a/torchrl/envs/libs/openml.py b/torchrl/envs/libs/openml.py index 25b5dfbad30..6d42a101896 100644 --- a/torchrl/envs/libs/openml.py +++ b/torchrl/envs/libs/openml.py @@ -5,15 +5,15 @@ import torch from tensordict.tensordict import TensorDict, TensorDictBase +from torchrl.data.datasets.openml import OpenMLExperienceReplay +from torchrl.data.replay_buffers import SamplerWithoutReplacement -from torchrl.data import ( +from torchrl.data.tensor_specs import ( CompositeSpec, DiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, ) -from torchrl.data.datasets.openml import OpenMLExperienceReplay -from torchrl.data.replay_buffers import SamplerWithoutReplacement from torchrl.envs.common import EnvBase from torchrl.envs.transforms import Compose, DoubleToFloat, RenameTransform diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py index 5d17c246d42..17c6697cb24 100644 --- a/torchrl/envs/libs/pettingzoo.py +++ b/torchrl/envs/libs/pettingzoo.py @@ -11,7 +11,7 @@ import torch from tensordict.tensordict import TensorDictBase -from torchrl.data import ( +from torchrl.data.tensor_specs import ( CompositeSpec, DiscreteTensorSpec, OneHotDiscreteTensorSpec, diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index 7ce0938facb..62f96a4318f 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -14,7 +14,7 @@ from tensordict import TensorDict from tensordict.tensordict import make_tensordict from torchrl._utils import implement_for -from torchrl.data import UnboundedContinuousTensorSpec +from 
torchrl.data.tensor_specs import UnboundedContinuousTensorSpec from torchrl.envs.libs.gym import _AsyncMeta, _gym_to_torchrl_spec_transform, GymEnv from torchrl.envs.utils import _classproperty, make_composite_from_td diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index 906b9c456bd..447ee7b4b69 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -10,7 +10,7 @@ import torch from tensordict import TensorDict, TensorDictBase -from torchrl.data import ( +from torchrl.data.tensor_specs import ( BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index dbe097aa312..2d0b8448a6b 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -11,7 +11,7 @@ import torch from tensordict.tensordict import TensorDict, TensorDictBase -from torchrl.data import ( +from torchrl.data.tensor_specs import ( BoundedTensorSpec, CompositeSpec, DEVICE_TYPING, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index de8baf2e403..5b6e8c377f8 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1658,19 +1658,31 @@ class Resize(ObservationTransform): """Resizes a pixel observation. Args: - w (int): resulting width - h (int): resulting height + w (int): resulting width. + h (int, optional): resulting height. If not provided, the value of `w` + is taken. 
interpolation (str): interpolation method + + Examples: + >>> from torchrl.envs import GymEnv + >>> t = Resize(64, 84) + >>> base_env = GymEnv("HalfCheetah-v4", from_pixels=True) + >>> env = TransformedEnv(base_env, Compose(ToTensorImage(), t)) """ def __init__( self, w: int, - h: int, + h: int | None = None, interpolation: str = "bilinear", in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, ): + # we also allow lists or tuples + if isinstance(w, (list, tuple)): + w, h = w + if h is None: + h = w if not _has_tv: raise ImportError( "Torchvision not found. The Resize transform relies on " diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 9a2a71f24bd..2659a8a34b4 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -27,6 +27,13 @@ ) from tensordict.tensordict import LazyStackedTensorDict, NestedKey +from torchrl.data.tensor_specs import ( + CompositeSpec, + TensorSpec, + UnboundedContinuousTensorSpec, +) +from torchrl.data.utils import check_no_exclusive_keys + __all__ = [ "exploration_mode", "exploration_type", @@ -41,9 +48,6 @@ ] -from torchrl.data import CompositeSpec, TensorSpec -from torchrl.data.utils import check_no_exclusive_keys - ACTION_MASK_ERROR = RuntimeError( "An out-of-bounds actions has been provided to an env with an 'action_mask' output." " If you are using a custom policy, make sure to take the action mask into account when computing the output." 
@@ -566,8 +570,6 @@ def make_composite_from_td(data): shape=torch.Size([1]), space=ContinuousBox(low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])) >>> assert (spec.zero() == data.zero_()).all() """ - from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec - # custom funtion to convert a tensordict in a similar spec structure # of unbounded values. composite = CompositeSpec( diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index 8ebd97fdaa8..f6b80ead12c 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -10,9 +10,9 @@ import torch from torch import nn -from ...data import DEVICE_TYPING +from torchrl.data.utils import DEVICE_TYPING -from .models import ConvNet, MLP +from torchrl.modules.models import ConvNet, MLP class MultiAgentMLP(nn.Module): diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index e76ee043c4e..75c6110c413 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -17,7 +17,7 @@ from torch import nn, Tensor from torch.nn.modules.rnn import RNNCellBase -from torchrl.data import UnboundedContinuousTensorSpec +from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec from torchrl.objectives.value.functional import ( _inv_pad_sequence, _split_and_pad_sequence, diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 431f8503b3d..800629baf70 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -17,7 +17,7 @@ from tensordict.utils import NestedKey, unravel_key from torch import Tensor -from torchrl.data import CompositeSpec +from torchrl.data.tensor_specs import CompositeSpec from 
torchrl.data.utils import _find_action_space from torchrl.envs.utils import ExplorationType, set_exploration_type diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 947a7574967..ea329a2b726 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -17,7 +17,7 @@ from tensordict.utils import NestedKey from torch import Tensor -from torchrl.data import CompositeSpec +from torchrl.data.tensor_specs import CompositeSpec from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives import default_value_kwargs, distance_loss, ValueEstimators from torchrl.objectives.common import LossModule diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 347becc24ae..30918bae131 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -15,7 +15,7 @@ from tensordict.utils import NestedKey from torch import Tensor -from torchrl.data import CompositeSpec +from torchrl.data.tensor_specs import CompositeSpec from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives.common import LossModule diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 0ec675596cd..ff973e52f6e 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -16,7 +16,7 @@ from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import Tensor -from torchrl.data import CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import CompositeSpec, TensorSpec from torchrl.data.utils import _find_action_space from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ProbabilisticActor diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 86fd17a8392..a54fea1fb1e 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -11,7 +11,7 @@ from tensordict.tensordict import 
TensorDict, TensorDictBase from tensordict.utils import NestedKey -from torchrl.data import BoundedTensorSpec, CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec from torchrl.envs.utils import step_mdp from torchrl.objectives.common import LossModule diff --git a/torchrl/trainers/helpers/replay_buffer.py b/torchrl/trainers/helpers/replay_buffer.py index 229a22cbe8e..de5102e0613 100644 --- a/torchrl/trainers/helpers/replay_buffer.py +++ b/torchrl/trainers/helpers/replay_buffer.py @@ -7,7 +7,10 @@ import torch -from torchrl.data import ReplayBuffer, TensorDictReplayBuffer +from torchrl.data.replay_buffers.replay_buffers import ( + ReplayBuffer, + TensorDictReplayBuffer, +) from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.data.utils import DEVICE_TYPING diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index f93b1b7a8e4..a2764df2912 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -14,7 +14,7 @@ from torchrl._utils import VERBOSE from torchrl.collectors.collectors import DataCollectorBase -from torchrl.data import ReplayBuffer +from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer from torchrl.envs.common import EnvBase from torchrl.envs.utils import ExplorationType from torchrl.modules import reset_noise diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 669a16ca4cd..fead31f742d 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -23,7 +23,10 @@ from torchrl._utils import _CKPT_BACKEND, KeyDependentDefaultDict, VERBOSE from torchrl.collectors.collectors import DataCollectorBase from torchrl.collectors.utils import split_trajectories -from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer +from 
torchrl.data.replay_buffers import ( + TensorDictPrioritizedReplayBuffer, + TensorDictReplayBuffer, +) from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.utils import ExplorationType, set_exploration_type From 179ef8040507c1f9836fa40efffa814f34ad732c Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 21 Dec 2023 11:27:34 +0000 Subject: [PATCH 3/8] amend --- torchrl/data/datasets/vd4rl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py index 2fe1471ed5c..f5689f2d92e 100644 --- a/torchrl/data/datasets/vd4rl.py +++ b/torchrl/data/datasets/vd4rl.py @@ -18,7 +18,6 @@ import numpy as np import torch -import tqdm from huggingface_hub import hf_hub_download, HfApi from tensordict import PersistentTensorDict, TensorDict @@ -271,6 +270,8 @@ def _download_and_preproc(cls, dataset_id, data_path): files = list(files) print("Downloaded, processing files") if _has_tqdm: + import tqdm + pbar = tqdm.tqdm(files) else: pbar = files From 73d34751c405348481d223cfac9533d2164f297b Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 21 Dec 2023 11:36:21 +0000 Subject: [PATCH 4/8] amend --- torchrl/data/datasets/vd4rl.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py index f5689f2d92e..c25d11077cb 100644 --- a/torchrl/data/datasets/vd4rl.py +++ b/torchrl/data/datasets/vd4rl.py @@ -18,7 +18,6 @@ import numpy as np import torch -from huggingface_hub import hf_hub_download, HfApi from tensordict import PersistentTensorDict, TensorDict from torchrl._utils import KeyDependentDefaultDict @@ -229,6 +228,8 @@ def __init__( @classmethod def _parse_datasets(cls): + from huggingface_hub import HfApi + dataset = HfApi().dataset_info("conglu/vd4rl") sibs = defaultdict(list) for sib in dataset.siblings: @@ -239,6 +240,8 @@ def _parse_datasets(cls): @classmethod def 
_download_and_preproc(cls, dataset_id, data_path): + from huggingface_hub import hf_hub_download + files = [] tds = [] with tempfile.TemporaryDirectory() as tmpdir: From 11fc232f13229b16fc31db76596b56cef51c3f28 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 21 Dec 2023 11:51:10 +0000 Subject: [PATCH 5/8] amend --- .github/unittest/linux_libs/scripts_vd4rl/install.sh | 4 ++-- torchrl/data/datasets/vd4rl.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_vd4rl/install.sh b/.github/unittest/linux_libs/scripts_vd4rl/install.sh index 2eb52b8f65e..1be476425a6 100755 --- a/.github/unittest/linux_libs/scripts_vd4rl/install.sh +++ b/.github/unittest/linux_libs/scripts_vd4rl/install.sh @@ -33,9 +33,9 @@ if [ "${CU_VERSION:-}" == cpu ] ; then # conda install -y pytorch torchvision cpuonly -c pytorch-nightly # use pip to install pytorch as conda can frequently pick older release # conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 fi # install tensordict diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py index c25d11077cb..932bdc5a1d1 100644 --- a/torchrl/data/datasets/vd4rl.py +++ b/torchrl/data/datasets/vd4rl.py @@ -401,6 +401,7 @@ def _from_npz(npz_path): "reward": ("next", "reward"), "image": "pixels", "observation": "pixels", + "discount": "discount", "action": "action", } ) From a2edd94ebb8f12ea33b8d72fd94478fa7c2bb709 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 21 Dec 2023 12:07:55 +0000 Subject: [PATCH 6/8] amend --- torchrl/data/datasets/vd4rl.py | 5 +++- 
torchrl/envs/transforms/transforms.py | 36 ++++++++++++++++++--------- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py index 932bdc5a1d1..9f4bfa36dee 100644 --- a/torchrl/data/datasets/vd4rl.py +++ b/torchrl/data/datasets/vd4rl.py @@ -208,7 +208,10 @@ def __init__( or not any(isinstance(t, ToTensorImage) for t in transform) ): transform = Compose( - transform, ToTensorImage(in_keys=["pixels", ("next", "pixels")]) + transform, + ToTensorImage( + in_keys=["pixels", ("next", "pixels")], shape_tolerant=True + ), ) if image_size is not None: transform = Compose( diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 5b6e8c377f8..09fed393685 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1151,6 +1151,13 @@ class ToTensorImage(ObservationTransform): dtype (torch.dtype, optional): dtype to use for the resulting observations. + Keyword arguments: + in_keys (list of NestedKeys): keys to process. + out_keys (list of NestedKeys): keys to write. + shape_tolerant (bool, optional): if ``True``, the shape of the input + images will be checked. If the last channel is not `3`, the permutation + will be ignored. Defaults to ``False``. 
+ Examples: >>> transform = ToTensorImage(in_keys=["pixels"]) >>> ri = torch.randint(0, 255, (1 , 1, 10, 11, 3), dtype=torch.uint8) @@ -1168,8 +1175,10 @@ def __init__( from_int: Optional[bool] = None, unsqueeze: bool = False, dtype: Optional[torch.device] = None, + *, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, + shape_tolerant: bool = False, ): if in_keys is None: in_keys = IMAGE_KEYS # default @@ -1179,6 +1188,7 @@ def __init__( self.from_int = from_int self.unsqueeze = unsqueeze self.dtype = dtype if dtype is not None else torch.get_default_dtype() + self.shape_tolerant = shape_tolerant def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase @@ -1188,9 +1198,10 @@ def _reset( return tensordict_reset def _apply_transform(self, observation: torch.FloatTensor) -> torch.Tensor: - observation = observation.permute( - *list(range(observation.ndimension() - 3)), -1, -3, -2 - ) + if not self.shape_tolerant or observation.shape[-1] == 3: + observation = observation.permute( + *list(range(observation.ndimension() - 3)), -1, -3, -2 + ) if self.from_int or ( self.from_int is None and not torch.is_floating_point(observation) ): @@ -1204,15 +1215,16 @@ def _apply_transform(self, observation: torch.FloatTensor) -> torch.Tensor: def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: observation_spec = self._pixel_observation(observation_spec) unsqueeze_dim = [1] if self._should_unsqueeze(observation_spec) else [] - observation_spec.shape = torch.Size( - [ - *unsqueeze_dim, - *observation_spec.shape[:-3], - observation_spec.shape[-1], - observation_spec.shape[-3], - observation_spec.shape[-2], - ] - ) + if not self.shape_tolerant or observation_spec.shape[-1] == 3: + observation_spec.shape = torch.Size( + [ + *unsqueeze_dim, + *observation_spec.shape[:-3], + observation_spec.shape[-1], + observation_spec.shape[-3], + observation_spec.shape[-2], + ] + ) observation_spec.dtype = 
self.dtype return observation_spec From 9fd57aeadd0d63f52de8eb733e78b452a310eb43 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 21 Dec 2023 12:34:44 +0000 Subject: [PATCH 7/8] amend --- torchrl/data/datasets/vd4rl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py index 9f4bfa36dee..815a00ca687 100644 --- a/torchrl/data/datasets/vd4rl.py +++ b/torchrl/data/datasets/vd4rl.py @@ -202,7 +202,9 @@ def __init__( else: storage = self._load() if totensor and transform is None: - transform = ToTensorImage(in_keys=["pixels", ("next", "pixels")]) + transform = ToTensorImage( + in_keys=["pixels", ("next", "pixels")], shape_tolerant=True + ) elif totensor and ( not isinstance(transform, Compose) or not any(isinstance(t, ToTensorImage) for t in transform) From 63ea34219841d9b5a691adad623fc392a7497c5b Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 21 Dec 2023 13:14:07 +0000 Subject: [PATCH 8/8] amend --- docs/source/reference/data.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index ea9f448b02a..9bd780acb48 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -280,6 +280,7 @@ Here's an example: MinariExperienceReplay OpenMLExperienceReplay RobosetExperienceReplay + VD4RLExperienceReplay TensorSpec ----------