From dd9364434421e1e602c3befecf15e0a684f8e138 Mon Sep 17 00:00:00 2001
From: Kurt Mohler
Date: Thu, 26 Sep 2024 14:53:41 -0700
Subject: [PATCH] Add env wrapper for Unity MLAgents

---
 .../scripts_unity_mlagents/environment.yml    |  21 +
 .../scripts_unity_mlagents/install.sh         |  60 ++
 .../scripts_unity_mlagents/post_process.sh    |   6 +
 .../run-clang-format.py                       | 356 ++++++++
 .../scripts_unity_mlagents/run_test.sh        |  28 +
 .../scripts_unity_mlagents/setup_env.sh       |  49 +
 .github/workflows/test-linux-libs.yml         |  38 +
 docs/source/reference/envs.rst                |   2 +
 pytest.ini                                    |   2 +
 test/conftest.py                              |  12 +
 test/test_libs.py                             | 130 +++
 torchrl/envs/__init__.py                      |   2 +
 torchrl/envs/libs/__init__.py                 |   1 +
 torchrl/envs/libs/unity_mlagents.py           | 862 ++++++++++++++++++
 14 files changed, 1569 insertions(+)
 create mode 100644 .github/unittest/linux_libs/scripts_unity_mlagents/environment.yml
 create mode 100755 .github/unittest/linux_libs/scripts_unity_mlagents/install.sh
 create mode 100755 .github/unittest/linux_libs/scripts_unity_mlagents/post_process.sh
 create mode 100755 .github/unittest/linux_libs/scripts_unity_mlagents/run-clang-format.py
 create mode 100755 .github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh
 create mode 100755 .github/unittest/linux_libs/scripts_unity_mlagents/setup_env.sh
 create mode 100644 torchrl/envs/libs/unity_mlagents.py

diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/environment.yml b/.github/unittest/linux_libs/scripts_unity_mlagents/environment.yml
new file mode 100644
index 00000000000..6dc82afbc25
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/environment.yml
@@ -0,0 +1,21 @@
+channels:
+  - pytorch
+  - defaults
+dependencies:
+  - python==3.10.12
+  - pip
+  - pip:
+      - mlagents_envs==1.0.0
+      - hypothesis
+      - future
+      - cloudpickle
+      - pytest
+      - pytest-cov
+      - pytest-mock
+      - pytest-instafail
+      - pytest-rerunfailures
+      - pytest-error-for-skips
+      - expecttest
+      - pyyaml
+      - scipy
+      - hydra-core
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/install.sh b/.github/unittest/linux_libs/scripts_unity_mlagents/install.sh
new file mode 100755
index 00000000000..95a4a5a0e29
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/install.sh
@@ -0,0 +1,60 @@
+#!/usr/bin/env bash
+
+unset PYTORCH_VERSION
+# For unit tests, nightly PyTorch is used, as set up in the section below,
+# so there is no need to set PYTORCH_VERSION.
+# In fact, keeping PYTORCH_VERSION would force us to hardcode the PyTorch version in the config.
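+#
+# TORCH_VERSION (nightly or stable) selects which PyTorch build is installed
+# below, and CU_VERSION selects between the CPU and CUDA wheels.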
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
+
+if [ "${CU_VERSION:-}" == cpu ] ; then
+    version="cpu"
+else
+    if [[ ${#CU_VERSION} -eq 4 ]]; then
+        CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}"
+    elif [[ ${#CU_VERSION} -eq 5 ]]; then
+        CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}"
+    fi
+    echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)"
+    version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
+fi
+
+# submodules
+git submodule sync && git submodule update --init --recursive
+
+printf "Installing PyTorch for CU_VERSION=%s\n" "${CU_VERSION:-cpu}"
+if [[ "$TORCH_VERSION" == "nightly" ]]; then
+    if [ "${CU_VERSION:-}" == cpu ] ; then
+        pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
+    else
+        pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
+    fi
+elif [[ "$TORCH_VERSION" == "stable" ]]; then
+    if [ "${CU_VERSION:-}" == cpu ] ; then
+        pip3 install torch --index-url https://download.pytorch.org/whl/cpu
+    else
+        pip3 install torch --index-url https://download.pytorch.org/whl/cu121
+    fi
+else
+    printf "Failed to install PyTorch: unknown TORCH_VERSION '%s'\n" "$TORCH_VERSION"
+    exit 1
+fi
+
+# install tensordict
+if [[ "$RELEASE" == 0 ]]; then
+    pip3 install git+https://github.com/pytorch/tensordict.git
+else
+    pip3 install tensordict
+fi
+
+# smoke test
+python -c "import functorch;import tensordict"
+
+printf "* Installing torchrl\n"
+python setup.py develop
+
+# smoke test
+python -c "import torchrl"
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/post_process.sh b/.github/unittest/linux_libs/scripts_unity_mlagents/post_process.sh
new file mode 100755
index 00000000000..e97bf2a7b1b
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/post_process.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/run-clang-format.py b/.github/unittest/linux_libs/scripts_unity_mlagents/run-clang-format.py
new file mode 100755
index 00000000000..5783a885d86
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/run-clang-format.py
@@ -0,0 +1,356 @@
+#!/usr/bin/env python
+"""
+MIT License
+
+Copyright (c) 2017 Guillaume Papin
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+A wrapper script around clang-format, suitable for linting multiple files
+and to use for continuous integration.
+
+This is an alternative API for the clang-format command line.
+It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. + +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. 
+ # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + 
print_trouble(parser.prog, str(e), use_colors=colored_stderr)
+        return ExitStatus.TROUBLE
+    except OSError as e:
+        print_trouble(
+            parser.prog,
+            f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}",
+            use_colors=colored_stderr,
+        )
+        return ExitStatus.TROUBLE
+
+    retcode = ExitStatus.SUCCESS
+    files = list_files(
+        args.files,
+        recursive=args.recursive,
+        exclude=args.exclude,
+        extensions=args.extensions.split(","),
+    )
+
+    if not files:
+        return
+
+    njobs = args.j
+    if njobs == 0:
+        njobs = multiprocessing.cpu_count() + 1
+    njobs = min(len(files), njobs)
+
+    if njobs == 1:
+        # execute directly instead of in a pool,
+        # less overhead, simpler stacktraces
+        it = (run_clang_format_diff_wrapper(args, file) for file in files)
+        pool = None
+    else:
+        pool = multiprocessing.Pool(njobs)
+        it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files)
+    while True:
+        try:
+            outs, errs = next(it)
+        except StopIteration:
+            break
+        except DiffError as e:
+            print_trouble(parser.prog, str(e), use_colors=colored_stderr)
+            retcode = ExitStatus.TROUBLE
+            sys.stderr.writelines(e.errs)
+        except UnexpectedError as e:
+            print_trouble(parser.prog, str(e), use_colors=colored_stderr)
+            sys.stderr.write(e.formatted_traceback)
+            retcode = ExitStatus.TROUBLE
+            # stop at the first unexpected error,
+            # something could be very wrong,
+            # don't process all files unnecessarily
+            if pool:
+                pool.terminate()
+            break
+        else:
+            sys.stderr.writelines(errs)
+            if outs == []:
+                continue
+            if not args.quiet:
+                print_diff(outs, use_color=colored_stdout)
+            if retcode == ExitStatus.SUCCESS:
+                retcode = ExitStatus.DIFF
+    return retcode
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh b/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh
new file mode 100755
index 00000000000..d5bb8695c44
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
+
+apt-get update && apt-get install -y git wget
+
+export PYTORCH_TEST_WITH_SLOW='1'
+export LAZY_LEGACY_OP=False
+python -m torch.utils.collect_env
+# Avoid error: "fatal: unsafe repository"
+git config --global --add safe.directory '*'
+
+root_dir="$(git rev-parse --show-toplevel)"
+env_dir="${root_dir}/env"
+lib_dir="${env_dir}/lib"
+
+conda deactivate && conda activate ./env
+
+# this workflow only tests the libs
+python -c "import mlagents_envs"
+
+python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestUnityMLAgents --runslow
+
+coverage combine
+coverage xml -i
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/setup_env.sh b/.github/unittest/linux_libs/scripts_unity_mlagents/setup_env.sh
new file mode 100755
index 00000000000..e7b08ab02ff
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/setup_env.sh
@@ -0,0 +1,49 @@
+#!/usr/bin/env bash
+
+# This script sets up the environment in which the unit tests are run.
+# To speed up CI, the resulting environment is cached.
+#
+# Do not install PyTorch and torchvision here, otherwise they also get cached.
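+#
+# PyTorch itself is installed later by install.sh, so the nightly/stable
+# choice is resolved on every run rather than being frozen into the cache.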
+ +set -e +set -v + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" + +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 3. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index 5d185fa9df6..bd394f39fa7 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -339,6 +339,44 @@ jobs: bash .github/unittest/linux_libs/scripts_open_spiel/run_test.sh bash .github/unittest/linux_libs/scripts_open_spiel/post_process.sh + unittests-unity_mlagents: + strategy: + matrix: + python_version: ["3.10.12"] + cuda_arch_version: ["12.1"] + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }} + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + gpu-arch-type: cuda + gpu-arch-version: "11.7" + docker-image: "pytorch/manylinux-cuda124" + timeout: 120 + script: | + if [[ "${{ github.ref }}" =~ release/* ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + + set -euo pipefail + export PYTHON_VERSION="3.10.12" + export CU_VERSION="12.1" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + export BATCHED_PIPE_TIMEOUT=60 + + nvidia-smi + + bash .github/unittest/linux_libs/scripts_unity_mlagents/setup_env.sh + bash .github/unittest/linux_libs/scripts_unity_mlagents/install.sh + bash .github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh + bash .github/unittest/linux_libs/scripts_unity_mlagents/post_process.sh + unittests-minari: strategy: matrix: diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 3578cbfd79f..960daf0fb12 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -1120,6 +1120,8 @@ the following function will return ``1`` when queried: RoboHiveEnv SMACv2Env SMACv2Wrapper + UnityMLAgentsEnv + UnityMLAgentsWrapper VmasEnv VmasWrapper gym_backend diff --git a/pytest.ini b/pytest.ini index 36d047d3055..39fe36617a1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,6 +4,8 @@ addopts = -ra # Make tracebacks shorter --tb=native +markers = + unity_editor testpaths = test xfail_strict = True diff --git a/test/conftest.py b/test/conftest.py index ca418d7b6f2..f2648a18041 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -113,6 +113,18 
@@ def pytest_addoption(parser):
         help="Use 'fork' start method for mp dedicated tests only if there is no cuda device available.",
     )
+    parser.addoption(
+        "--unity_editor",
+        action="store_true",
+        default=False,
+        help="Run tests that require manually pressing play in the Unity editor.",
+    )
+
+
+def pytest_runtest_setup(item):
+    if "unity_editor" in item.keywords and not item.config.getoption("--unity_editor"):
+        pytest.skip("need --unity_editor option to run this test")
+
 
 def pytest_configure(config):
     config.addinivalue_line("markers", "slow: mark test as slow to run")
diff --git a/test/test_libs.py b/test/test_libs.py
index 363d111db46..a165c6916fb 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -5,6 +5,7 @@
 import functools
 import gc
 import importlib.util
+import urllib.error
 
 _has_isaac = importlib.util.find_spec("isaacgym") is not None
 
@@ -18,10 +19,12 @@
 import os
 import time
+import urllib
 from contextlib import nullcontext
 from pathlib import Path
 from sys import platform
 from typing import Optional, Union
+from unittest import mock
 
 import numpy as np
 import pytest
@@ -36,6 +39,7 @@
     PENDULUM_VERSIONED,
     PONG_VERSIONED,
     rand_reset,
+    retry,
     rollout_consistency_assertion,
 )
 from packaging import version
@@ -111,6 +115,11 @@
 from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv
 from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv
 from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env
+from torchrl.envs.libs.unity_mlagents import (
+    _has_unity_mlagents,
+    UnityMLAgentsEnv,
+    UnityMLAgentsWrapper,
+)
 from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper
 from torchrl.envs.transforms import ActionMask, TransformedEnv
 
@@ -3930,6 +3939,127 @@ def test_chance_not_implemented(self):
         OpenSpielEnv("bridge")
+
+
+# NOTE: Each of the registered envs is around 180 MB, so only test a few.
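+# (The full list of registered names is exposed as
+# UnityMLAgentsEnv.available_envs.)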
+_mlagents_registered_envs = [ + "3DBall", + "StrikersVsGoalie", +] + + +@pytest.mark.skipif(not _has_unity_mlagents, reason="mlagents_envs not found") +class TestUnityMLAgents: + @mock.patch("mlagents_envs.env_utils.launch_executable") + @mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator") + def test_env(self, mock_communicator, mock_launcher): + from mlagents_envs.mock_communicator import MockCommunicator + + mock_communicator.return_value = MockCommunicator( + discrete_action=False, visual_inputs=0 + ) + env = UnityMLAgentsEnv(" ") + try: + check_env_specs(env) + finally: + env.close() + + @mock.patch("mlagents_envs.env_utils.launch_executable") + @mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator") + def test_wrapper(self, mock_communicator, mock_launcher): + from mlagents_envs.environment import UnityEnvironment + from mlagents_envs.mock_communicator import MockCommunicator + + mock_communicator.return_value = MockCommunicator( + discrete_action=False, visual_inputs=0 + ) + env = UnityMLAgentsWrapper(UnityEnvironment(" ")) + try: + check_env_specs(env) + finally: + env.close() + + @mock.patch("mlagents_envs.env_utils.launch_executable") + @mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator") + def test_rollout(self, mock_communicator, mock_launcher): + from mlagents_envs.environment import UnityEnvironment + from mlagents_envs.mock_communicator import MockCommunicator + + mock_communicator.return_value = MockCommunicator( + discrete_action=False, visual_inputs=0 + ) + env = UnityMLAgentsWrapper(UnityEnvironment(" ")) + try: + env.rollout( + max_steps=500, break_when_any_done=False, break_when_all_done=False + ) + finally: + env.close() + + @pytest.mark.unity_editor + def test_with_editor(self): + print("Please press play in the Unity editor") # noqa: T201 + env = UnityMLAgentsEnv(timeout_wait=30) + try: + env.reset() + check_env_specs(env) + + # Perform a rollout + td = env.reset() + env.rollout( + max_steps=100, break_when_any_done=False, break_when_all_done=False + ) + + # Step manually + tensordicts = [] + td = env.reset() + tensordicts.append(td) + traj_len = 200 + for _ in range(traj_len - 1): + td = env.step(td.update(env.full_action_spec.rand())) + tensordicts.append(td) + + traj = torch.stack(tensordicts) + assert traj.batch_size == torch.Size([traj_len]) + finally: + env.close() + + @retry( + ( + urllib.error.HTTPError, + urllib.error.URLError, + urllib.error.ContentTooShortError, + ), + 5, + ) + @pytest.mark.parametrize("registered_name", _mlagents_registered_envs) + def test_registered_envs(self, registered_name): + env = UnityMLAgentsEnv( + registered_name=registered_name, + no_graphics=True, + ) + try: + check_env_specs(env) + + # Perform a rollout + td = env.reset() + env.rollout( + max_steps=20, break_when_any_done=False, break_when_all_done=False + ) + + # Step manually + tensordicts = [] + td = env.reset() + tensordicts.append(td) + traj_len = 20 + for _ in range(traj_len - 1): + td = env.step(td.update(env.full_action_spec.rand())) + tensordicts.append(td) + + traj = torch.stack(tensordicts) + assert traj.batch_size == torch.Size([traj_len]) + finally: + env.close() + + @pytest.mark.skipif(not _has_meltingpot, reason="Meltingpot not found") class TestMeltingpot: @pytest.mark.parametrize("substrate", MeltingpotWrapper.available_envs) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index d0d92251b69..047550fa9d7 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ 
-36,6 +36,8 @@ set_gym_backend, SMACv2Env, SMACv2Wrapper, + UnityMLAgentsEnv, + UnityMLAgentsWrapper, VmasEnv, VmasWrapper, ) diff --git a/torchrl/envs/libs/__init__.py b/torchrl/envs/libs/__init__.py index 98b416799fa..7ea113ce46d 100644 --- a/torchrl/envs/libs/__init__.py +++ b/torchrl/envs/libs/__init__.py @@ -23,4 +23,5 @@ from .pettingzoo import PettingZooEnv, PettingZooWrapper from .robohive import RoboHiveEnv from .smacv2 import SMACv2Env, SMACv2Wrapper +from .unity_mlagents import UnityMLAgentsEnv, UnityMLAgentsWrapper from .vmas import VmasEnv, VmasWrapper diff --git a/torchrl/envs/libs/unity_mlagents.py b/torchrl/envs/libs/unity_mlagents.py new file mode 100644 index 00000000000..6ed019c2332 --- /dev/null +++ b/torchrl/envs/libs/unity_mlagents.py @@ -0,0 +1,862 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import importlib.util +from typing import Dict, Optional + +import torch +from tensordict import TensorDict, TensorDictBase + +from torchrl.data.tensor_specs import ( + BoundedContinuous, + Categorical, + Composite, + MultiCategorical, + MultiOneHot, + Unbounded, +) +from torchrl.envs.common import _EnvWrapper +from torchrl.envs.utils import _classproperty, check_marl_grouping + +_has_unity_mlagents = importlib.util.find_spec("mlagents_envs") is not None + + +def _get_registered_envs(): + if not _has_unity_mlagents: + raise ImportError( + "mlagents_envs not found. Consider downloading and installing " + f"mlagents from {UnityMLAgentsWrapper.git_url}." + ) + + from mlagents_envs.registry import default_registry + + return list(default_registry.keys()) + + +class UnityMLAgentsWrapper(_EnvWrapper): + """Unity ML-Agents environment wrapper. + + GitHub: https://github.com/Unity-Technologies/ml-agents + + Documentation: https://unity-technologies.github.io/ml-agents/Python-LLAPI/ + + Args: + env (mlagents_envs.environment.UnityEnvironment): the ML-Agents + environment to wrap. + + Keyword Args: + device (torch.device, optional): if provided, the device on which the data + is to be cast. Defaults to ``None``. + batch_size (torch.Size, optional): the batch size of the environment. + Defaults to ``torch.Size([])``. + allow_done_after_reset (bool, optional): if ``True``, it is tolerated + for envs to be ``done`` just after :meth:`~.reset` is called. + Defaults to ``False``. + categorical_actions (bool, optional): if ``True``, categorical specs + will be converted to the TorchRL equivalent + (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding + will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. 
+ + Attributes: + available_envs: list of registered environments available to build + + Examples: + >>> from mlagents_envs.environment import UnityEnvironment + >>> base_env = UnityEnvironment() + >>> from torchrl.envs import UnityMLAgentsWrapper + >>> env = UnityMLAgentsWrapper(base_env) + >>> td = env.reset() + >>> td = env.step(td.update(env.full_action_spec.rand())) + """ + + git_url = "https://github.com/Unity-Technologies/ml-agents" + libname = "mlagents_envs" + _lib = None + + @_classproperty + def lib(cls): + if cls._lib is not None: + return cls._lib + + import mlagents_envs + import mlagents_envs.environment + + cls._lib = mlagents_envs + return mlagents_envs + + def __init__( + self, + env=None, + *, + categorical_actions: bool = False, + **kwargs, + ): + if env is not None: + kwargs["env"] = env + + self.categorical_actions = categorical_actions + super().__init__(**kwargs) + + def _check_kwargs(self, kwargs: Dict): + mlagents_envs = self.lib + if "env" not in kwargs: + raise TypeError("Could not find environment key 'env' in kwargs.") + env = kwargs["env"] + if not isinstance(env, mlagents_envs.environment.UnityEnvironment): + raise TypeError( + "env is not of type 'mlagents_envs.environment.UnityEnvironment'" + ) + + def _build_env(self, env, requires_grad: bool = False, **kwargs): + self.requires_grad = requires_grad + return env + + def _init_env(self): + self._update_action_mask() + + # Creates a group map where agents are grouped by their group_id. + def _make_group_map(self, env): + group_map = {} + agent_names = [] + agent_name_to_behavior_map = {} + agent_name_to_id_map = {} + + for steps_idx in [0, 1]: + for behavior in env.behavior_specs.keys(): + steps = env.get_steps(behavior)[steps_idx] + is_terminal = steps_idx == 1 + agent_ids = steps.agent_id + group_ids = steps.group_id + + for agent_id, group_id in zip(agent_ids, group_ids): + agent_name = f"agent_{agent_id}" + group_name = f"group_{group_id}" + if group_name not in group_map.keys(): + group_map[group_name] = [] + if agent_name in group_map[group_name]: + # Sometimes in an MLAgents environment, an agent may + # show up in both the decision steps and the terminal + # steps. When that happens, just skip the duplicate. + assert is_terminal + continue + group_map[group_name].append(agent_name) + agent_names.append(agent_name) + agent_name_to_behavior_map[agent_name] = behavior + agent_name_to_id_map[agent_name] = agent_id + + check_marl_grouping(group_map, agent_names) + return group_map, agent_name_to_behavior_map, agent_name_to_id_map + + def _make_specs( + self, env: "mlagents_envs.environment.UnityEnvironment" # noqa: F821 + ) -> None: + # NOTE: We need to reset here because mlagents only initializes the + # agents and behaviors after reset. In order to build specs, we make the + # following assumptions about the mlagents environment: + # * all behaviors are defined on the first step + # * all agents request an action on the first step + # However, mlagents allows you to break these assumptions, so we probably + # will need to detect changes to the behaviors and agents on each step. 
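+        # As a rough illustration (hypothetical single-behavior setup matching
+        # the 3DBall example in the UnityMLAgentsEnv docstring), the specs
+        # built below take the nested form:
+        #   action_spec:      group_0 -> agent_0 -> continuous_action, shape (2,)
+        #   observation_spec: group_0 -> agent_0 -> VectorSensor_size8, shape (8,)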
+ env.reset() + ( + self.group_map, + self.agent_name_to_behavior_map, + self.agent_name_to_id_map, + ) = self._make_group_map(env) + + action_spec = {} + observation_spec = {} + reward_spec = {} + done_spec = {} + + for group_name, agents in self.group_map.items(): + group_action_spec = {} + group_observation_spec = {} + group_reward_spec = {} + group_done_spec = {} + for agent_name in agents: + behavior = self.agent_name_to_behavior_map[agent_name] + behavior_spec = env.behavior_specs[behavior] + + # Create action spec + agent_action_spec = Composite() + env_action_spec = behavior_spec.action_spec + discrete_branches = env_action_spec.discrete_branches + continuous_size = env_action_spec.continuous_size + if len(discrete_branches) > 0: + discrete_action_spec_cls = ( + MultiCategorical if self.categorical_actions else MultiOneHot + ) + agent_action_spec["discrete_action"] = discrete_action_spec_cls( + discrete_branches, + dtype=torch.int32, + device=self.device, + ) + if continuous_size > 0: + # In mlagents, continuous actions can take values between -1 + # and 1 by default: + # https://github.com/Unity-Technologies/ml-agents/blob/22a59aad34ef46a5de05469735426feed758f8f5/ml-agents-envs/mlagents_envs/base_env.py#L395 + agent_action_spec["continuous_action"] = BoundedContinuous( + -1, 1, (continuous_size,), self.device, torch.float32 + ) + group_action_spec[agent_name] = agent_action_spec + + # Create observation spec + agent_observation_spec = Composite() + for obs_idx, env_observation_spec in enumerate( + behavior_spec.observation_specs + ): + if len(env_observation_spec.name) == 0: + obs_name = f"observation_{obs_idx}" + else: + obs_name = env_observation_spec.name + agent_observation_spec[obs_name] = Unbounded( + env_observation_spec.shape, + dtype=torch.float32, + device=self.device, + ) + group_observation_spec[agent_name] = agent_observation_spec + + # Create reward spec + agent_reward_spec = Composite() + agent_reward_spec["reward"] = Unbounded( + (1,), + dtype=torch.float32, + device=self.device, + ) + agent_reward_spec["group_reward"] = Unbounded( + (1,), + dtype=torch.float32, + device=self.device, + ) + group_reward_spec[agent_name] = agent_reward_spec + + # Create done spec + agent_done_spec = Composite() + for done_key in ["done", "terminated", "truncated"]: + agent_done_spec[done_key] = Categorical( + 2, (1,), dtype=torch.bool, device=self.device + ) + group_done_spec[agent_name] = agent_done_spec + + action_spec[group_name] = group_action_spec + observation_spec[group_name] = group_observation_spec + reward_spec[group_name] = group_reward_spec + done_spec[group_name] = group_done_spec + + self.action_spec = Composite(action_spec) + self.observation_spec = Composite(observation_spec) + self.reward_spec = Composite(reward_spec) + self.done_spec = Composite(done_spec) + + def _set_seed(self, seed): + if seed is not None: + raise NotImplementedError("This environment has no seed.") + + def _check_agent_exists(self, agent_name, group_name): + if ( + group_name not in self.full_action_spec.keys() + or agent_name not in self.full_action_spec[group_name].keys() + ): + raise RuntimeError( + ( + "Unity environment added a new agent. This is not yet " + "supported in torchrl." 
+ ) + ) + + def _update_action_mask(self): + for behavior, behavior_spec in self._env.behavior_specs.items(): + env_action_spec = behavior_spec.action_spec + discrete_branches = env_action_spec.discrete_branches + + if len(discrete_branches) > 0: + steps = self._env.get_steps(behavior)[0] + env_action_mask = steps.action_mask + if env_action_mask is not None: + combined_action_mask = torch.cat( + [ + torch.tensor(m, device=self.device, dtype=torch.bool) + for m in env_action_mask + ], + dim=-1, + ).logical_not() + + for agent_id, group_id, agent_action_mask in zip( + steps.agent_id, steps.group_id, combined_action_mask + ): + agent_name = f"agent_{agent_id}" + group_name = f"group_{group_id}" + self._check_agent_exists(agent_name, group_name) + self.full_action_spec[ + group_name, agent_name, "discrete_action" + ].update_mask(agent_action_mask) + + def _make_td_out(self, tensordict_in, is_reset=False): + source = {} + for behavior, behavior_spec in self._env.behavior_specs.items(): + for idx, steps in enumerate(self._env.get_steps(behavior)): + is_terminal = idx == 1 + for steps_idx, (agent_id, group_id) in enumerate( + zip(steps.agent_id, steps.group_id) + ): + agent_name = f"agent_{agent_id}" + group_name = f"group_{group_id}" + self._check_agent_exists(agent_name, group_name) + if group_name not in source: + source[group_name] = {} + if agent_name not in source[group_name]: + source[group_name][agent_name] = {} + + # Add observations + for obs_idx, ( + behavior_observation, + env_observation_spec, + ) in enumerate(zip(steps.obs, behavior_spec.observation_specs)): + observation = torch.tensor( + behavior_observation[steps_idx], + device=self.device, + dtype=torch.float32, + ) + if len(env_observation_spec.name) == 0: + obs_name = f"observation_{obs_idx}" + else: + obs_name = env_observation_spec.name + source[group_name][agent_name][obs_name] = observation + + # Add rewards + if not is_reset: + source[group_name][agent_name]["reward"] = torch.tensor( + steps.reward[steps_idx], + device=self.device, + dtype=torch.float32, + ) + source[group_name][agent_name]["group_reward"] = torch.tensor( + steps.group_reward[steps_idx], + device=self.device, + dtype=torch.float32, + ) + + # Add done + done = is_terminal and not is_reset + source[group_name][agent_name]["done"] = torch.tensor( + done, device=self.device, dtype=torch.bool + ) + source[group_name][agent_name]["truncated"] = torch.tensor( + done and steps.interrupted[steps_idx], + device=self.device, + dtype=torch.bool, + ) + source[group_name][agent_name]["terminated"] = torch.tensor( + done and not steps.interrupted[steps_idx], + device=self.device, + dtype=torch.bool, + ) + + if tensordict_in is not None: + # In MLAgents, a given step will only contain information for agents + # which either terminated or requested a decision during the step. + # Some agents may have neither terminated nor requested a decision, + # so we need to fill in their information from the previous step. 
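+            # (Such carried-over agents keep their previous observations and
+            # receive zero rewards and all-False done flags below.)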
+            for group_name, agents in self.group_map.items():
+                for agent_name in agents:
+                    if group_name not in source.keys():
+                        source[group_name] = {}
+                    if agent_name not in source[group_name].keys():
+                        agent_dict = {}
+                        agent_behavior = self.agent_name_to_behavior_map[agent_name]
+                        behavior_spec = self._env.behavior_specs[agent_behavior]
+                        td_agent_in = tensordict_in[group_name, agent_name]
+
+                        # Add observations
+                        for obs_idx, env_observation_spec in enumerate(
+                            behavior_spec.observation_specs
+                        ):
+                            if len(env_observation_spec.name) == 0:
+                                obs_name = f"observation_{obs_idx}"
+                            else:
+                                obs_name = env_observation_spec.name
+                            agent_dict[obs_name] = td_agent_in[obs_name]
+
+                        # Add rewards
+                        if not is_reset:
+                            # Since the agent didn't request a decision, the
+                            # reward is 0
+                            agent_dict["reward"] = torch.zeros(
+                                (1,), device=self.device, dtype=torch.float32
+                            )
+                            agent_dict["group_reward"] = torch.zeros(
+                                (1,), device=self.device, dtype=torch.float32
+                            )
+
+                        # Add done
+                        agent_dict["done"] = torch.tensor(
+                            False, device=self.device, dtype=torch.bool
+                        )
+                        agent_dict["terminated"] = torch.tensor(
+                            False, device=self.device, dtype=torch.bool
+                        )
+                        agent_dict["truncated"] = torch.tensor(
+                            False, device=self.device, dtype=torch.bool
+                        )
+
+                        source[group_name][agent_name] = agent_dict
+
+        tensordict_out = TensorDict(
+            source=source,
+            batch_size=self.batch_size,
+            device=self.device,
+        )
+
+        return tensordict_out
+
+    def _get_action_from_tensor(self, tensor):
+        if not self.categorical_actions:
+            action = torch.argmax(tensor, dim=-1)
+        else:
+            action = tensor
+        return action
+
+    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
+        # Apply actions
+        for behavior, behavior_spec in self._env.behavior_specs.items():
+            env_action_spec = behavior_spec.action_spec
+            steps = self._env.get_steps(behavior)[0]
+
+            for agent_id, group_id in zip(steps.agent_id, steps.group_id):
+                agent_name = f"agent_{agent_id}"
+                group_name = f"group_{group_id}"
+
+                self._check_agent_exists(agent_name, group_name)
+
+                agent_action_spec = self.full_action_spec[group_name, agent_name]
+                action_tuple = self.lib.base_env.ActionTuple()
+                agent_id = self.agent_name_to_id_map[agent_name]
+                discrete_branches = env_action_spec.discrete_branches
+                continuous_size = env_action_spec.continuous_size
+
+                if len(discrete_branches) > 0:
+                    discrete_spec = agent_action_spec["discrete_action"]
+                    discrete_action = tensordict[
+                        group_name, agent_name, "discrete_action"
+                    ]
+                    if not self.categorical_actions:
+                        discrete_action = discrete_spec.to_categorical(discrete_action)
+                    action_tuple.add_discrete(discrete_action[None, ...].numpy())
+
+                if continuous_size > 0:
+                    continuous_action = tensordict[
+                        group_name, agent_name, "continuous_action"
+                    ]
+                    action_tuple.add_continuous(continuous_action[None, ...].numpy())
+
+                self._env.set_action_for_agent(behavior, agent_id, action_tuple)
+
+        self._env.step()
+        self._update_action_mask()
+        return self._make_td_out(tensordict)
+
+    def _to_tensor(self, value):
+        return torch.tensor(value, device=self.device, dtype=torch.float32)
+
+    def _reset(
+        self, tensordict: TensorDictBase | None = None, **kwargs
+    ) -> TensorDictBase:
+        self._env.reset()
+        return self._make_td_out(tensordict, is_reset=True)
+
+    def close(self):
+        self._env.close()
+
+    @_classproperty
+    def available_envs(cls):
+        if not _has_unity_mlagents:
+            return []
+        return _get_registered_envs()
+
+
+class UnityMLAgentsEnv(UnityMLAgentsWrapper):
+    """Unity ML-Agents environment wrapper.
+
+    GitHub: https://github.com/Unity-Technologies/ml-agents
+
+    Documentation: https://unity-technologies.github.io/ml-agents/Python-LLAPI/
+
+    This class accepts any of the optional initialization arguments that the
+    :class:`mlagents_envs.environment.UnityEnvironment` class provides. For a
+    list of these arguments, see:
+    https://unity-technologies.github.io/ml-agents/Python-LLAPI-Documentation/#__init__
+
+    If both ``file_name`` and ``registered_name`` are given, an error is raised.
+
+    If neither ``file_name`` nor ``registered_name`` is given, the environment
+    setup waits on a localhost port, and the user must launch a Unity ML-Agents
+    environment binary that connects to it.
+
+    Args:
+        file_name (str, optional): if provided, the path to the Unity
+            environment binary. Defaults to ``None``.
+        registered_name (str, optional): if provided, the Unity environment
+            binary is loaded from the default ML-Agents registry. The list of
+            registered environments is in :attr:`~.available_envs`. Defaults to
+            ``None``.
+
+    Keyword Args:
+        device (torch.device, optional): if provided, the device on which the data
+            is to be cast. Defaults to ``None``.
+        batch_size (torch.Size, optional): the batch size of the environment.
+            Defaults to ``torch.Size([])``.
+        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
+            for envs to be ``done`` just after :meth:`~.reset` is called.
+            Defaults to ``False``.
+        categorical_actions (bool, optional): if ``True``, categorical specs
+            will be converted to the TorchRL equivalent
+            (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
+            will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.
+
+    Attributes:
+        available_envs: list of registered environments available to build
+
+    Examples:
+        >>> from torchrl.envs import UnityMLAgentsEnv
+        >>> env = UnityMLAgentsEnv(registered_name='3DBall')
+        >>> td = env.reset()
+        >>> td = env.step(td.update(env.full_action_spec.rand()))
+        >>> td
+        TensorDict(
+            fields={
+                group_0: TensorDict(
+                    fields={
+                        agent_0: TensorDict(
+                            fields={
+                                VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False),
+                        agent_10: TensorDict(
+                            fields={
+                                VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False),
+                        agent_11: TensorDict(
+                            fields={
+                                VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                terminated: 
Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_1: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_2: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_3: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_4: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_5: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_6: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: 
Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_7: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_8: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_9: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + next: TensorDict( + fields={ + group_0: TensorDict( + fields={ + agent_0: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_10: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_11: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: 
Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_1: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_2: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_3: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_4: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_5: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, 
is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_6: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_7: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_8: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_9: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + """ + + def __init__( + self, + file_name: Optional[str] = None, + registered_name: Optional[str] = None, + *, + categorical_actions=False, + **kwargs, + ): + kwargs["file_name"] = file_name + kwargs["registered_name"] = registered_name + super().__init__( + categorical_actions=categorical_actions, + **kwargs, + ) + + def _build_env( + self, + file_name: Optional[str], + registered_name: Optional[str], + **kwargs, + ) -> "mlagents_envs.environment.UnityEnvironment": # noqa: F821 + if not _has_unity_mlagents: + raise ImportError( + 
"mlagents_envs not found, unable to create environment. " + "Consider downloading and installing mlagents from " + f"{self.git_url}" + ) + if file_name is not None and registered_name is not None: + raise ValueError( + "Both `file_name` and `registered_name` were specified, which " + "is not allowed. Specify one of them or neither." + ) + elif registered_name is not None: + from mlagents_envs.registry import default_registry + + env = default_registry[registered_name].make(**kwargs) + else: + env = self.lib.environment.UnityEnvironment(file_name, **kwargs) + requires_grad = kwargs.pop("requires_grad", False) + return super()._build_env( + env, + requires_grad=requires_grad, + ) + + @property + def file_name(self): + return self._constructor_kwargs["file_name"] + + @property + def registered_name(self): + return self._constructor_kwargs["registered_name"] + + def _check_kwargs(self, kwargs: Dict): + pass + + def __repr__(self) -> str: + if self.registered_name is not None: + env_name = self.registered_name + else: + env_name = self.file_name + return f"{self.__class__.__name__}(env={env_name}, batch_size={self.batch_size}, device={self.device})"